# 2. MAAI Model Development Walkthrough

This notebook is an interactive, step-by-step guide to preparing the already-preprocessed data (scaling and sequencing) and then building, training, and evaluating the MAAI model.

### Step 1: Imports and Setup

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
import sys
import os

# Add root directory to path to import our custom modules
sys.path.append(os.path.abspath(os.path.join('..')))

from data_preprocessing import config
from model_development.maai_model import build_maai_model
from model_development.utils import get_feature_sets, create_sequences, plot_training_history, plot_roc_curve, plot_pr_curve

### Step 2: Load Preprocessed Data
This assumes you have already run `python data_preprocessing/run_preprocessing.py`.

In [None]:
# Resolve the location of the feature matrix inside the configured
# processed-data directory, then read it into memory.
feature_matrix_file = "processed_feature_matrix.parquet"
data_path = os.path.join(config.PROCESSED_DATA_PATH, feature_matrix_file)

df = pd.read_parquet(data_path)

# Report the size of what was loaded; the trailing expression renders a preview.
print(f"Loaded data with shape: {df.shape}")
df.head()

### Step 3: Scale Features and Create Sequences
We will now scale the continuous features and reshape the flat dataframe into 3D sequences (`samples`, `timesteps`, `features`) suitable for LSTMs.

In [None]:
# Column-name groups for the three input modalities, as defined by the project config.
vitals_cols, labs_cols, meds_cols = get_feature_sets(config)

# Only vitals and labs are standardised; the medication columns are left untouched.
feature_cols_to_scale = vitals_cols + labs_cols

# NOTE(review): the scaler is fitted on every row of `df`, i.e. before the
# train/test split performed in Step 4, so test-set statistics influence the
# scaling. Consider fitting on training rows only.
scaler = StandardScaler()
scaler.fit(df[feature_cols_to_scale])

# Overwrite the selected columns of `df` in place with their standardised values.
df[feature_cols_to_scale] = scaler.transform(df[feature_cols_to_scale])
print("Features scaled.")

In [None]:
# Turn the flat frame into one array per modality plus a label array.
# The last two arguments are the target column and the lookback length from config.
sequences, y = create_sequences(
    df,
    vitals_cols,
    labs_cols,
    meds_cols,
    config.TARGET_VARIABLE,
    config.LOOKBACK_WINDOW_HOURS,
)
X_vitals, X_labs, X_meds = sequences

# Shapes are printed so the reader can sanity-check the reshaping.
print(f"Shape of Vitals sequences: {X_vitals.shape}")
print(f"Shape of Labs sequences: {X_labs.shape}")
print(f"Shape of Meds sequences: {X_meds.shape}")
print(f"Shape of labels: {y.shape}")

### Step 4: Split Data into Training and Test Sets

In [None]:
# Split by sample index so the same partition can be applied to every modality.
# NOTE(review): the split is over individual sequences; if create_sequences
# produces overlapping windows from the same stay, related windows may land in
# both sets — confirm against create_sequences.
n_samples = X_vitals.shape[0]
indices = np.arange(n_samples)

train_indices, test_indices = train_test_split(
    indices,
    test_size=config.TEST_SPLIT_SIZE,
    random_state=config.RANDOM_STATE,
    stratify=y,  # keep the class ratio the same in both partitions
)

# Apply the two index sets to each modality array and to the labels.
X_train_v, X_train_l, X_train_m, y_train = (
    arr[train_indices] for arr in (X_vitals, X_labs, X_meds, y)
)
X_test_v, X_test_l, X_test_m, y_test = (
    arr[test_indices] for arr in (X_vitals, X_labs, X_meds, y)
)

print(f"Training set size: {len(y_train)}")
print(f"Test set size: {len(y_test)}")

### Step 5: Build and Train the MAAI Model

In [None]:
# Seed the Python, NumPy and TensorFlow random number generators *before* the
# model is built, so weight initialisation (and the shuffling done later during
# training) is repeatable under Restart Kernel -> Run All.
# The same seed that drives the train/test split is reused; it is skipped when
# the config leaves the seed unset, because set_random_seed requires an integer.
if config.RANDOM_STATE is not None:
    tf.keras.utils.set_random_seed(config.RANDOM_STATE)

# The per-modality feature counts are read from the last axis of the training
# arrays so the model always matches the data that was actually produced.
model = build_maai_model(
    n_features_vitals=X_train_v.shape[2],
    n_features_labs=X_train_l.shape[2],
    n_features_meds=X_train_m.shape[2],
    lookback_window=config.LOOKBACK_WINDOW_HOURS
)
model.summary()

In [None]:
# Calculate class weights to handle imbalance.
# The labels are flattened and cast to int first: np.bincount rejects float or
# column-vector input, and minlength=2 keeps the result two-element even when
# the positive class is absent, so the problem can be reported clearly.
train_labels = np.asarray(y_train).ravel().astype(int)
class_counts = np.bincount(train_labels, minlength=2)
if class_counts.size != 2 or (class_counts == 0).any():
    raise ValueError(
        "Expected binary labels with both classes present in y_train; "
        f"got class counts {class_counts.tolist()}"
    )

neg, pos = class_counts
total = neg + pos
# Each class is weighted inversely to its frequency, normalised so that the
# weights average to 1 over the training set. Stored as plain floats so the
# printed dict is readable.
class_weight = {
    0: float((1 / neg) * (total / 2.0)),
    1: float((1 / pos) * (total / 2.0)),
}
print(f"Class weights: {class_weight}")

# Stop when the validation AUPRC stops improving and roll back to the best epoch.
# NOTE(review): 'val_auprc' only exists if build_maai_model compiles the model
# with a metric named 'auprc' — confirm in model_development/maai_model.py.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_auprc', mode='max', patience=10, restore_best_weights=True)
]

# Keras takes the validation samples from the end of the provided arrays
# (before shuffling); the training arrays were already shuffled by
# train_test_split in Step 4.
history = model.fit(
    [X_train_v, X_train_l, X_train_m], y_train,
    validation_split=config.VALIDATION_SPLIT_SIZE / (1 - config.TEST_SPLIT_SIZE), # Rescale so the validation share is relative to the full dataset, not just the training portion
    epochs=100,
    batch_size=256,
    class_weight=class_weight,
    callbacks=callbacks,
    verbose=1
)

### Step 6: Evaluate the Model

In [None]:
# Make sure the results directory exists before a figure path inside it is
# handed to the plotting helper (presumably the helper saves to this path —
# confirm in model_development/utils.py). exist_ok makes this a no-op on re-runs.
os.makedirs(config.RESULTS_PATH, exist_ok=True)
plot_training_history(history, os.path.join(config.RESULTS_PATH, 'notebook_training_history.png'))

In [None]:
# Score the held-out set; the three modality arrays are passed in the same
# order used for training, and the predictions are flattened to 1-D.
test_inputs = [X_test_v, X_test_l, X_test_m]
y_pred_proba = model.predict(test_inputs).ravel()

# Target file paths for the two diagnostic curves.
roc_curve_path = os.path.join(config.RESULTS_PATH, 'notebook_roc_curve.png')
pr_curve_path = os.path.join(config.RESULTS_PATH, 'notebook_pr_curve.png')

plot_roc_curve(y_test, y_pred_proba, roc_curve_path)
plot_pr_curve(y_test, y_pred_proba, pr_curve_path)

In [None]:
# Convert predicted probabilities into hard 0/1 labels: strictly above 0.5 -> 1.
y_pred_class = np.greater(y_pred_proba, 0.5).astype(int)

# Per-class precision / recall / F1 on the test set at that threshold.
print("\nClassification Report (Threshold = 0.5):")
report = classification_report(y_test, y_pred_class)
print(report)

In [None]:
# Display names for class 0 and class 1, in that order.
class_names = ['No ABX', 'Start ABX']

# Pin the label set so the matrix is always 2x2 in (0, 1) order; otherwise its
# shape depends on which labels happen to occur, and the hard-coded tick labels
# below could be attached to the wrong rows/columns.
cm = confusion_matrix(y_test, y_pred_class, labels=[0, 1])

# Explicit figure/axes so every call targets this plot rather than whichever
# figure pyplot currently considers active.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_title('Confusion Matrix')
ax.set_ylabel('Actual Label')
ax.set_xlabel('Predicted Label')
plt.show()