# Anti-Malarial Demand Forecasting Analysis

This notebook tests Airflow DAG submission by running a simple demand-forecasting analysis for anti-malarial drugs. It:
- Generates sample demand data and uploads it as a CSV file to a MinIO bucket.
- Reads the CSV back and processes it with PySpark.
- Writes the results to an Apache Iceberg table.
- Queries the Iceberg table with Trino to verify the write.

In [None]:
# Parameters injected by Papermill. Tag this cell `parameters` so that
# Papermill inserts the DAG's overrides immediately after it.
minio_bucket = 'demand-data'  # Default value, overridden by DAG
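Papermill does not edit the parameters cell in place. It inserts a new cell (tagged `injected-parameters`) directly after the cell tagged `parameters`, so the injected assignment shadows the default; if no cell carries the `parameters` tag, the injected cell goes at the top of the notebook and the default assignment then overwrites it. The snippet below is a plain-Python simulation of that ordering, not Papermill itself, and the bucket name `'demand-data-prod'` is a hypothetical override.

```python
# Plain-Python simulation of Papermill's cell ordering (not Papermill itself).
default_cell = "minio_bucket = 'demand-data'"         # the cell tagged `parameters`
injected_cell = "minio_bucket = 'demand-data-prod'"   # hypothetical value from the DAG


def run_cells(cells):
    """Execute cell sources in order in one shared namespace, like a kernel."""
    namespace = {}
    for source in cells:
        exec(source, namespace)
    return namespace


# Tagged correctly: the injected cell runs right after the parameters cell.
tagged = run_cells([default_cell, injected_cell])
print(tagged['minio_bucket'])    # demand-data-prod

# No `parameters` tag: the injected cell runs first, so the default wins.
untagged = run_cells([injected_cell, default_cell])
print(untagged['minio_bucket'])  # demand-data
```

On the calling side, the override corresponds to the `parameters` argument of `papermill.execute_notebook`; how this particular DAG passes it is not shown in the notebook.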

In [None]:
# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, lit
from minio import Minio
import trino
import pandas as pd
import os

# Initialize SparkSession with Iceberg configurations
spark = SparkSession.builder \
    .appName('DemandForecasting') \
    .config('spark.master', 'spark://spark-master:7077') \
    .config('spark.jars.packages', 'org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.6.1,org.apache.kafka:kafka-clients:3.6.2') \
    .config('spark.sql.extensions', 'org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions') \
    .config('spark.sql.catalog.iceberg', 'org.apache.iceberg.spark.SparkCatalog') \
    .config('spark.sql.catalog.iceberg.type', 'rest') \
    .config('spark.sql.catalog.iceberg.uri', 'http://iceberg-rest:8181') \
    .config('spark.sql.catalog.iceberg.warehouse', 's3a://iceberg-warehouse/') \
    .config('spark.hadoop.fs.s3a.endpoint', 'http://minio:9000') \
    .config('spark.hadoop.fs.s3a.access.key', 'minioadmin') \
    .config('spark.hadoop.fs.s3a.secret.key', 'minioadmin') \
    .config('spark.hadoop.fs.s3a.path.style.access', 'true') \
    .config('spark.hadoop.fs.s3a.impl', 'org.apache.hadoop.fs.s3a.S3AFileSystem') \
    .getOrCreate()

# Initialize MinIO client
minio_client = Minio(
    'minio:9000',
    access_key='minioadmin',
    secret_key='minioadmin',
    secure=False
)
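The setup cell hard-codes the MinIO endpoint and `minioadmin` credentials twice, once for Spark's S3A connector and once for the MinIO client. A minimal sketch of deriving both from one place is shown below, assuming the hypothetical environment variables `MINIO_ENDPOINT`, `MINIO_ACCESS_KEY` and `MINIO_SECRET_KEY`; the defaults reproduce the values used above.

```python
import os


def s3a_settings(env=os.environ):
    """Build Spark S3A settings and MinIO client kwargs from one source.

    MINIO_ENDPOINT, MINIO_ACCESS_KEY and MINIO_SECRET_KEY are assumed
    (hypothetical) variable names; defaults match this notebook.
    """
    endpoint = env.get('MINIO_ENDPOINT', 'minio:9000')   # host:port, no scheme
    access_key = env.get('MINIO_ACCESS_KEY', 'minioadmin')
    secret_key = env.get('MINIO_SECRET_KEY', 'minioadmin')
    spark_conf = {
        'spark.hadoop.fs.s3a.endpoint': f'http://{endpoint}',
        'spark.hadoop.fs.s3a.access.key': access_key,
        'spark.hadoop.fs.s3a.secret.key': secret_key,
        'spark.hadoop.fs.s3a.path.style.access': 'true',
    }
    # Keyword arguments for minio.Minio(...); secure=False means plain HTTP.
    minio_kwargs = {
        'endpoint': endpoint,
        'access_key': access_key,
        'secret_key': secret_key,
        'secure': False,
    }
    return spark_conf, minio_kwargs


spark_conf, minio_kwargs = s3a_settings({'MINIO_SECRET_KEY': 's3cr3t'})
print(spark_conf['spark.hadoop.fs.s3a.endpoint'])  # http://minio:9000
print(minio_kwargs['secret_key'])                  # s3cr3t
```

Each pair in `spark_conf` could be applied with `SparkSession.builder.config(key, value)`, and the client created with `Minio(**minio_kwargs)`, so the two never drift apart.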

## Step 1: Create Sample Data in MinIO

Create a sample CSV file with anti-malarial drug demand data and upload it to the MinIO bucket named by `minio_bucket`, creating the bucket if it does not exist.

In [None]:
# Create sample data
sample_data = pd.DataFrame({
    'region': ['East Africa', 'West Africa', 'Southern Africa', 'East Africa', 'West Africa'],
    'drug': ['Artemether', 'Artemether', 'Lumefantrine', 'Lumefantrine', 'Artemether'],
    'year': [2023, 2023, 2023, 2024, 2024],
    'demand_units': [1000, 1500, 800, 1200, 1700]
})

# Save to temporary CSV
sample_csv_path = '/tmp/sample_demand_data.csv'
sample_data.to_csv(sample_csv_path, index=False)

# Create bucket if it doesn't exist
if not minio_client.bucket_exists(minio_bucket):
    minio_client.make_bucket(minio_bucket)

# Upload sample data to MinIO
minio_client.fput_object(
    minio_bucket,
    'sample_demand_data.csv',
    sample_csv_path
)
print(f'Uploaded sample data to {minio_bucket}/sample_demand_data.csv')

# Clean up temporary file
os.remove(sample_csv_path)
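The temporary file is only a staging step. As a sketch of an alternative, the CSV can be built in memory and streamed with the MinIO client's `put_object`, which takes a file-like object and its length. The helper name `sample_csv_bytes` is hypothetical, and only two of the sample rows are shown for brevity; with pandas, `sample_data.to_csv(index=False)` (no path) likewise returns the CSV as a string.

```python
import csv
import io


def sample_csv_bytes(rows, fieldnames):
    """Serialize a list of dicts to UTF-8 CSV bytes with a header row."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue().encode('utf-8')


rows = [
    {'region': 'East Africa', 'drug': 'Artemether', 'year': 2023, 'demand_units': 1000},
    {'region': 'West Africa', 'drug': 'Artemether', 'year': 2023, 'demand_units': 1500},
]
payload = sample_csv_bytes(rows, ['region', 'drug', 'year', 'demand_units'])
print(payload.decode('utf-8').splitlines()[0])  # region,drug,year,demand_units

# Upload without a temp file (needs minio_client and minio_bucket from above):
# minio_client.put_object(minio_bucket, 'sample_demand_data.csv',
#                         io.BytesIO(payload), length=len(payload),
#                         content_type='text/csv')
```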

## Step 2: Read and Process Data with Spark

Read the CSV file back from MinIO, compute the average demand per region and drug, and derive a naive 2025 forecast by applying a flat 10% increase to that average.

In [None]:
# Read CSV from MinIO; inferSchema makes year and demand_units numeric
# instead of leaving every column as a string
input_path = f's3a://{minio_bucket}/sample_demand_data.csv'
df = spark.read.option('header', 'true').option('inferSchema', 'true').csv(input_path)

# Calculate average demand per region and drug (across all years in the data)
avg_demand = df.groupBy('region', 'drug').agg(avg('demand_units').alias('avg_demand_units'))

# Naive forecast: apply a flat 10% increase to the average for 2025
forecast_df = avg_demand.withColumn('forecast_year', lit(2025)) \
                       .withColumn('forecast_demand_units', col('avg_demand_units') * 1.10)

forecast_df.show()
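To sanity-check what `forecast_df.show()` should print, the sketch below reproduces the same group-average-then-uplift logic in plain Python on the five sample rows from Step 1. It is a cross-check only, not a replacement for the Spark job, and like the Spark cell it pools 2023 and 2024 into one average rather than modeling a trend. The function name `naive_forecast` is hypothetical.

```python
import csv
import io
from collections import defaultdict

# The same five rows uploaded in Step 1, as CSV text.
CSV_TEXT = """region,drug,year,demand_units
East Africa,Artemether,2023,1000
West Africa,Artemether,2023,1500
Southern Africa,Lumefantrine,2023,800
East Africa,Lumefantrine,2024,1200
West Africa,Artemether,2024,1700
"""

UPLIFT = 1.10        # flat 10% increase, as in the Spark cell
FORECAST_YEAR = 2025


def naive_forecast(csv_text):
    """Average demand per (region, drug), then apply the flat uplift."""
    totals = defaultdict(lambda: [0, 0])  # (region, drug) -> [sum, count]
    for row in csv.DictReader(io.StringIO(csv_text)):
        # CSV fields arrive as strings, so cast before doing arithmetic.
        key = (row['region'], row['drug'])
        totals[key][0] += int(row['demand_units'])
        totals[key][1] += 1
    return {
        key: (FORECAST_YEAR, (total / count) * UPLIFT)
        for key, (total, count) in totals.items()
    }


for (region, drug), (year, units) in sorted(naive_forecast(CSV_TEXT).items()):
    print(f'{region:16} {drug:13} {year} {units:8.1f}')
```

For the sample data this yields four groups with forecasts of 1100, 1320, 880 and 1760 units (up to floating-point rounding), which is what the Spark output should match.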

## Step 3: Write Results to Iceberg Table

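Write the forecast data to the Iceberg table `iceberg.default.demand_table` (catalog `iceberg`, namespace `default`), replacing any previous contents.

In [None]:
# Make sure the target namespace exists, then create or replace the table
spark.sql('CREATE NAMESPACE IF NOT EXISTS iceberg.default')
forecast_df.writeTo('iceberg.default.demand_table').createOrReplace()
print('Wrote forecast data to iceberg.default.demand_table')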

## Step 4: Verify with Trino Query

Query the Iceberg table using Trino to ensure data was written correctly.
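Spark and Trino interpret dotted table names differently. In Spark, a leading part that matches a configured catalog (here `iceberg`) selects that catalog. Trino instead reads a two-part name as `schema.table` inside the session catalog, so with a connection using `catalog='iceberg'`, `iceberg.demand_table` points at a schema named `iceberg`, not at the `default` namespace. The hypothetical helper below mimics Trino's rule for unquoted names, for illustration only; it is not part of the Trino client.

```python
def resolve_trino_name(name, session_catalog, session_schema):
    """Expand a dotted, unquoted table name the way Trino does.

    1 part  -> table in the session catalog and schema
    2 parts -> schema.table in the session catalog
    3 parts -> catalog.schema.table
    """
    parts = name.split('.')
    if len(parts) == 1:
        return (session_catalog, session_schema, parts[0])
    if len(parts) == 2:
        return (session_catalog, parts[0], parts[1])
    if len(parts) == 3:
        return tuple(parts)
    raise ValueError(f'Too many name parts: {name!r}')


print(resolve_trino_name('iceberg.demand_table', 'iceberg', 'default'))
# ('iceberg', 'iceberg', 'demand_table')
print(resolve_trino_name('iceberg.default.demand_table', 'iceberg', 'default'))
# ('iceberg', 'default', 'demand_table')
```

The fully qualified three-part form is the unambiguous way to name the same table from both engines.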

In [None]:
# Connect to Trino
conn = trino.dbapi.connect(
    host='trino-coordinator',
    port=8084,
    user='trino',
    catalog='iceberg',
    schema='default'
)

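# Execute query. Trino reads a two-part name as schema.table, so use the
# fully qualified catalog.schema.table name
cursor = conn.cursor()
cursor.execute('SELECT * FROM iceberg.default.demand_table LIMIT 5')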
results = cursor.fetchall()

# Display results
print('Trino query results:')
for row in results:
    print(row)

# Close connection
cursor.close()
conn.close()
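Printing the rows only eyeballs the result. A sketch of an automated check is shown below; it is not part of the original notebook, and the function name `rows_match` is hypothetical. It compares fetched rows against expected forecasts with a tolerance, since the 10% uplift is floating-point arithmetic, and it assumes the column order of `forecast_df` (`region`, `drug`, `avg_demand_units`, `forecast_year`, `forecast_demand_units`), which `SELECT *` returns for this table. The hand-written rows only mimic the shape of `cursor.fetchall()` output and are not captured results.

```python
def rows_match(actual_rows, expected, tolerance=1e-6):
    """Compare Trino rows with expected forecasts keyed by (region, drug).

    Each row is assumed to follow forecast_df's column order:
    region, drug, avg_demand_units, forecast_year, forecast_demand_units.
    """
    seen = {}
    for region, drug, _avg_units, _year, forecast_units in actual_rows:
        seen[(region, drug)] = forecast_units
    # Every expected group must be present, and no unexpected ones.
    if seen.keys() != expected.keys():
        return False
    return all(abs(seen[key] - value) <= tolerance for key, value in expected.items())


# Expected 2025 forecasts for the Step 1 sample data (average x 1.10).
expected = {
    ('East Africa', 'Artemether'): 1100.0,
    ('East Africa', 'Lumefantrine'): 1320.0,
    ('Southern Africa', 'Lumefantrine'): 880.0,
    ('West Africa', 'Artemether'): 1760.0,
}

# Hand-written rows shaped like cursor.fetchall() output (illustrative only;
# the trailing digits in one value show the kind of rounding noise tolerated).
sample_rows = [
    ['West Africa', 'Artemether', 1600.0, 2025, 1760.0000000000002],
    ['East Africa', 'Artemether', 1000.0, 2025, 1100.0],
    ['Southern Africa', 'Lumefantrine', 800.0, 2025, 880.0],
    ['East Africa', 'Lumefantrine', 1200.0, 2025, 1320.0],
]
print(rows_match(sample_rows, expected))      # True
print(rows_match(sample_rows[:3], expected))  # False
```

In the notebook the call would be `rows_match(results, expected)`. `LIMIT 5` is sufficient here because the sample data produces only four region/drug groups.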

## Step 5: Cleanup

Stop the Spark session.

In [None]:
# Stop Spark session
spark.stop()
print('Spark session stopped')