# Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, auc, classification_report, confusion_matrix
from sklearn.preprocessing import label_binarize
import re
import random
import seaborn as sns
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import os
import time
import copy

def set_seed(seed=42):
    """Seed every RNG the notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU
    generator; when CUDA is available it also seeds the CUDA generators and
    switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        # Seed the current device, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Give up cuDNN autotuning in exchange for repeatable kernels.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    print(f"Random seed set to {seed}")


set_seed(42)

# Config, Load & Preprocess

In [None]:
# The 12 MedMCQA subject categories kept for training and evaluation.
CATEGORIES = ['Pathology', 'Anatomy', 'Pharmacology', 'Microbiology', 'Gynaecology & Obstetrics', 'Dental', 'Physiology',
              'Biochemistry', 'Pediatrics', 'Ophthalmology', 'Psychiatry', 'Radiology']


# Maximum number of tokens (special tokens included) kept per question.
MAX_LENGTH = 64

tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

data_path = "data/"

try:
    # Each split is read as JSON Lines (one record per line).
    train_df_raw = pd.read_json(os.path.join(data_path, "train.json"), lines=True)
    dev_df_raw = pd.read_json(os.path.join(data_path, "dev.json"), lines=True)
    test_df_raw = pd.read_json(os.path.join(data_path, "test.json"), lines=True)
except (OSError, ValueError) as e:
    # Report, then re-raise: every later cell needs these frames, and carrying
    # on would only surface as an unrelated NameError further down.
    print(f"Error loading JSON files: {e}")
    raise
else:
    print(f"Loaded raw data: train({len(train_df_raw)}), dev({len(dev_df_raw)}), test({len(test_df_raw)})")


# Option letter (lower-cased) -> slot in the 4-element answer vector.
_OPTION_INDEX = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
# A single option letter delimited by '.', '(', ')' or whitespace, e.g. "(b)" or " c.".
_ANSWER_LETTER_RE = re.compile(r'[.\(\s]([a-dA-D])[.\)\s]')


def parse_correct_answers(row):
    """Heuristically derive a multi-hot answer vector from a row's explanation.

    Every option letter a-d / A-D that appears in ``row['exp']`` delimited by
    '.', '(', ')' or whitespace marks that option as correct.

    NOTE(review): the pattern also matches ordinary words such as the article
    " a " (or "A " mid-sentence), which flags option A; treat these labels as
    noisy.

    Args:
        row: Mapping or pandas row providing an ``'exp'`` entry.

    Returns:
        list[int] of length 4; slot k is 1 if option k was found, else 0.
    """
    correct_answers = [0, 0, 0, 0]
    exp_value = row['exp']

    # Test for missing values on the raw cell: after str() conversion
    # None/NaN become "None"/"nan", which are never null.
    if pd.api.types.is_scalar(exp_value) and pd.isna(exp_value):
        return correct_answers

    for letter in set(_ANSWER_LETTER_RE.findall(str(exp_value))):
        correct_answers[_OPTION_INDEX[letter.lower()]] = 1
    return correct_answers

# Whole-word negation cues. Word boundaries also catch forms such as "EXCEPT:",
# "not?" or a question-final "except", which space-padded substrings miss.
_NEGATION_RE = re.compile(r'\b(?:except|not|false|incorrect)\b', re.IGNORECASE)


def preprocess_data(df, label_encoder=None, is_train=True):
    """Filter a MedMCQA split to CATEGORIES and add model-ready columns.

    Adds three columns:

    * ``subject_encoded``: integer label id.
    * ``answer_labels``: 4-slot multi-hot list parsed from ``exp`` by
      ``parse_correct_answers``, or all zeros when ``exp`` is absent.
    * ``negation_feature``: 1.0 if the question contains a negation cue word,
      else 0.0.

    Args:
        df: Raw split with at least ``subject_name`` and ``question`` columns.
        label_encoder: Encoder fitted on train; required when ``is_train`` is False.
        is_train: Fit a new ``LabelEncoder`` on this split instead of reusing one.

    Returns:
        ``(frame, label_encoder)`` when ``is_train`` is True, else just the frame.

    Raises:
        ValueError: ``is_train`` is False and no ``label_encoder`` was given.
    """
    df_filtered = df[df['subject_name'].isin(CATEGORIES)].copy()

    if is_train:
        label_encoder = LabelEncoder()
        df_filtered['subject_encoded'] = label_encoder.fit_transform(df_filtered['subject_name'])
    else:
        if label_encoder is None:
            raise ValueError("label_encoder must be provided for dev/test data")
        # Drop subjects the encoder never saw, then encode in one vectorised call.
        # .copy() so the column assignments below act on an owned frame, not a view.
        known = df_filtered['subject_name'].isin(label_encoder.classes_)
        df_filtered = df_filtered[known].copy()
        df_filtered['subject_encoded'] = label_encoder.transform(df_filtered['subject_name'])

    if 'exp' in df_filtered.columns:
        print("Found 'exp' column. Parsing answer labels...")
        # A list comprehension (unlike DataFrame.apply) still yields a valid
        # column when the filtered frame is empty.
        df_filtered['answer_labels'] = [
            parse_correct_answers(row) for _, row in df_filtered.iterrows()
        ]
    else:
        print("Warning: 'exp' column not found. Creating dummy answer labels [0,0,0,0].")
        df_filtered['answer_labels'] = [[0, 0, 0, 0] for _ in range(len(df_filtered))]

    # Create Negation Feature
    df_filtered['negation_feature'] = df_filtered['question'].apply(
        lambda q: 1.0 if _NEGATION_RE.search(str(q)) else 0.0
    )

    if is_train:
        return df_filtered, label_encoder
    return df_filtered

print("Preprocessing data...")

# Fit the subject encoder on train only, then reuse it for dev and test so all
# splits share the same label ids (subjects unseen in train are dropped).
train_df, label_encoder = preprocess_data(train_df_raw, is_train=True)

dev_df, test_df = (
    preprocess_data(raw_split, label_encoder=label_encoder, is_train=False)
    for raw_split in (dev_df_raw, test_df_raw)
)

print(f"Filtered data: train({len(train_df)}), dev({len(dev_df)}), test({len(test_df)})")
print("Preprocessing complete.")

# Data Visualization

In [None]:
print("Generating data visualizations")

print("Generating Word Cloud...")
plt.figure(figsize=(10, 5))
# astype(str): a missing (NaN) question would otherwise make str.join raise TypeError.
question_corpus = ' '.join(train_df['question'].astype(str))
wordcloud = WordCloud(width=800, height=400, background_color='white').generate(question_corpus)
plt.imshow(wordcloud, interpolation='bilinear')
plt.axis('off')
plt.title('Most Frequent Terms in MedMCQA Questions', fontsize=14)
plt.show()

print("Generating Correct Option Distribution...")
# sort_index: draw slices and legend in option order rather than by frequency.
option_counts = train_df['cop'].value_counts().sort_index()
colors = ['#F05454', '#30475E', '#F9D14B', '#DFFB95']
# One offset per slice. A fixed 4-tuple makes plt.pie raise ValueError whenever
# the split does not contain exactly four distinct option values.
explode = (0.05,) * len(option_counts)
plt.figure(figsize=(8, 8))
plt.pie(option_counts,
        labels=[f'Option {idx}' for idx in option_counts.index],
        autopct='%1.1f%%',
        startangle=90,
        colors=colors,
        explode=explode,
        shadow=True)
plt.title('Correct Option Distribution', fontsize=14)
plt.axis('equal')
plt.legend(option_counts.index, loc='upper right')
plt.show()

print("Generating Subject Category Distribution...")
subject_counts = train_df['subject_name'].value_counts()

plt.figure(figsize=(15, 8))
bars = plt.bar(subject_counts.index, subject_counts.values, color='#30475E', width=0.6)
plt.xlabel('Subject Name', fontsize=12)
plt.ylabel('Number of Questions', fontsize=12)
plt.title('Frequency of Questions per Subject Category', fontsize=14)
plt.xticks(rotation=75, fontsize=12)
plt.yticks(fontsize=12)
plt.grid(axis='y', linestyle='--', alpha=0.7)
# Annotate each bar with its exact count.
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, yval, int(yval), va='bottom', ha='center', fontsize=10)
plt.tight_layout()
plt.show()

plt.figure(figsize=(10, 10))
colors_12 = ['#FF6F61', '#6B5B93', '#88B04B', '#F7CAC9', '#92A8D1', '#955251',
             '#B68D40', '#F6BDC0', '#6A0588', '#F1C40F', '#FF9F00', '#00BFFF']
explode_12 = (0.05,) * len(subject_counts)
plt.pie(subject_counts,
        labels=subject_counts.index,
        autopct='%1.1f%%',
        startangle=90,
        colors=colors_12,
        explode=explode_12,
        shadow=True,
        pctdistance=0.85)
plt.title('Percentage Distribution of Medical Subjects', fontsize=14)
plt.axis('equal')
plt.show()

# Dataset Class

In [None]:
class MedicalQADataset(Dataset):
    """Tokenise MedMCQA questions on the fly for a PyTorch ``DataLoader``.

    Each item holds the padded and truncated ``input_ids`` and
    ``attention_mask`` for the question, plus its scalar ``negation_feature``.
    Unless ``is_test`` is set, it also holds the ``subject_labels`` (long) and
    ``answer_labels`` (float multi-hot) targets.
    """

    def __init__(self, dataframe, tokenizer, max_length, is_test=False):
        """
        Args:
            dataframe: Frame with ``question`` and ``negation_feature`` columns,
                plus ``subject_encoded`` / ``answer_labels`` when ``is_test`` is False.
            tokenizer: Tokenizer exposing ``encode_plus`` (a BertTokenizer here).
            max_length: Every question is padded or truncated to this many tokens.
            is_test: When True, label tensors are not attached to items.
        """
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_test = is_test

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the tensor dict for the ``index``-th row (positional, not label)."""
        # Fetch the row once instead of repeating iloc for every field.
        row = self.data.iloc[index]
        question = str(row['question'])
        negation_feature = torch.tensor(row['negation_feature'], dtype=torch.float)

        encoding = self.tokenizer.encode_plus(
            question,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        item = {
            # return_tensors='pt' adds a leading batch dim of 1; flatten drops it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'negation_feature': negation_feature
        }

        if not self.is_test:
            item['subject_labels'] = torch.tensor(row['subject_encoded'], dtype=torch.long)
            item['answer_labels'] = torch.tensor(row['answer_labels'], dtype=torch.float)

        return item

# Model Architecture

In [None]:
class AttentionPooling(nn.Module):
    """Collapse a sequence of hidden states into one vector with learned attention.

    A linear layer gives every position a scalar score. Positions whose mask
    is 0 are excluded, and the softmax-weighted sum of the hidden states is
    returned.
    """

    def __init__(self, hidden_size):
        super(AttentionPooling, self).__init__()
        # Scores each position's hidden vector with a single scalar.
        self.attention_weights = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        """
        Args:
            hidden_states: assumed (batch, seq_len, hidden_size) tensor.
            attention_mask: (batch, seq_len) tensor with 0 at positions to
                ignore, or None to attend everywhere.

        Returns:
            (batch, hidden_size) pooled tensor.
        """
        attention_scores = self.attention_weights(hidden_states).squeeze(-1)
        if attention_mask is not None:
            # Use the most negative value representable in the scores' dtype.
            # A literal -1e9 overflows float16, and masked_fill raises under AMP.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(attention_mask == 0, fill_value)
        attention_weights = F.softmax(attention_scores, dim=1)
        return torch.sum(hidden_states * attention_weights.unsqueeze(-1), dim=1)

class BRT_Cell(nn.Module):
    """Memory cell that summarises a token sequence into one vector.

    The cell works in four steps:

    1. A bank of ``num_states`` learned state vectors cross-attends over the
       encoder hidden states.
    2. The attended states then self-attend.
    3. LSTM-style forget / input / candidate gates blend the result back into
       the initial states.
    4. The updated states are mean-pooled into a single ``state_dim`` vector
       per example.
    """

    def __init__(self, bert_hidden_size, num_states, state_dim, num_heads=8):
        super().__init__()
        self.num_states = num_states
        self.state_dim = state_dim

        # Learned initial memory slots, shared by every example in a batch.
        self.state_vectors = nn.Parameter(torch.randn(num_states, state_dim))

        # States (queries of width state_dim) read from the encoder tokens
        # (keys/values of width bert_hidden_size).
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=state_dim,
            num_heads=num_heads,
            kdim=bert_hidden_size,
            vdim=bert_hidden_size,
            batch_first=True,
        )

        # States exchange information among themselves.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=state_dim,
            num_heads=num_heads,
            batch_first=True,
        )

        # Every gate reads [cross-attended ; self-attended] features.
        gate_in = 2 * state_dim
        self.forget_gate = nn.Linear(gate_in, state_dim)
        self.input_gate = nn.Linear(gate_in, state_dim)
        self.candidate_gate = nn.Linear(gate_in, state_dim)

    def forward(self, hidden_states, attention_mask):
        """
        Args:
            hidden_states: assumed (batch, seq_len, bert_hidden_size) encoder output.
            attention_mask: (batch, seq_len) mask; 0 marks padding to ignore.

        Returns:
            (batch, state_dim) tensor: mean of the gated, updated states.
        """
        initial = self.state_vectors.unsqueeze(0).expand(hidden_states.size(0), -1, -1)

        attended, _ = self.cross_attn(
            initial, hidden_states, hidden_states,
            key_padding_mask=(attention_mask == 0),
        )
        mixed, _ = self.self_attn(attended, attended, attended)

        gate_input = torch.cat((attended, mixed), dim=-1)
        forget = torch.sigmoid(self.forget_gate(gate_input))
        admit = torch.sigmoid(self.input_gate(gate_input))
        candidate = torch.tanh(self.candidate_gate(gate_input))

        updated = (initial * forget) + (candidate * admit)
        return updated.mean(dim=1)

class MedicalQAModel(nn.Module):
    """BioBERT encoder with a subject-classification head and a 4-option answer head.

    BioBERT token states are pooled by one of three front-ends, checked in
    this order:

    * ``use_brt``: ``BRT_Cell`` memory cell (256-d output).
    * ``use_bigru``: BiGRU followed by ``AttentionPooling``.
    * otherwise: ``AttentionPooling`` directly on the BioBERT outputs.

    When ``use_negation`` is set, the scalar negation feature is appended to
    the (dropout-regularised) pooled vector before both linear heads.
    """

    def __init__(self, n_classes, gru_hidden_size=128, dropout_rate=0.25,
                 use_bigru=True, use_negation=True, use_brt=False):
        """
        Args:
            n_classes: Number of outputs of the subject head.
            gru_hidden_size: Hidden size per GRU direction (BiGRU mode only).
            dropout_rate: Dropout applied to the pooled vector.
            use_bigru: Use the BiGRU front-end (ignored when ``use_brt`` is True).
            use_negation: Concatenate the negation feature before the heads.
            use_brt: Use the BRT_Cell front-end.
        """
        super(MedicalQAModel, self).__init__()
        self.use_bigru = use_bigru
        self.use_negation = use_negation
        self.use_brt = use_brt

        self.bert = BertModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
        bert_hidden_size = self.bert.config.hidden_size

        if self.use_brt:
            print("--- Initializing model with BRT_Cell ---")
            brt_state_dim = 256
            self.brt_cell = BRT_Cell(
                bert_hidden_size=bert_hidden_size,
                num_states=64,
                state_dim=brt_state_dim
            )
            current_features_size = brt_state_dim

        elif self.use_bigru:
            print("--- Initializing model with BiGRU ---")
            self.bigru = nn.GRU(
                input_size=bert_hidden_size,
                hidden_size=gru_hidden_size,
                num_layers=1,
                bidirectional=True,
                batch_first=True
            )
            self.attention_pooling = AttentionPooling(hidden_size=gru_hidden_size * 2)
            current_features_size = gru_hidden_size * 2  # both directions concatenated

        else:
            print("--- Initializing model with Attention Pooling Only ---")
            self.attention_pooling = AttentionPooling(hidden_size=bert_hidden_size)
            current_features_size = bert_hidden_size

        classifier_input_size = current_features_size
        if self.use_negation:
            classifier_input_size += 1  # one extra slot for the negation scalar

        self.dropout = nn.Dropout(dropout_rate)

        self.subject_classifier = nn.Linear(classifier_input_size, n_classes)
        self.answer_classifier = nn.Linear(classifier_input_size, 4)

    def _run_bigru(self, hidden_state, attention_mask):
        """Run the BiGRU over real tokens only and return padded features.

        Packing by per-example length stops the backward direction from
        starting on [PAD] embeddings. Padded positions come back as zeros and
        are masked out again by the attention pooling.
        """
        # Number of real tokens per example. Assumes right padding (the
        # BertTokenizer default); TODO confirm if the padding side changes.
        lengths = attention_mask.sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            hidden_state, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.bigru(packed)
        features, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=hidden_state.size(1)
        )
        return features

    def forward(self, input_ids, attention_mask, negation_feature):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len) mask, 0 at padding.
            negation_feature: (batch,) float tensor.

        Returns:
            ``(subject_logits, answer_logits)`` of shapes (batch, n_classes) and (batch, 4).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state

        if self.use_brt:
            pooled_output = self.brt_cell(hidden_state, attention_mask)
        elif self.use_bigru:
            features = self._run_bigru(hidden_state, attention_mask)
            pooled_output = self.attention_pooling(features, attention_mask)
        else:
            pooled_output = self.attention_pooling(hidden_state, attention_mask)

        pooled_output = self.dropout(pooled_output)

        if self.use_negation:
            # (batch,) -> (batch, 1) so it concatenates as one extra feature column.
            negation_f = negation_feature.unsqueeze(1)
            combined_output = torch.cat((pooled_output, negation_f), dim=1)
        else:
            combined_output = pooled_output

        subject_logits = self.subject_classifier(combined_output)
        answer_logits = self.answer_classifier(combined_output)

        return subject_logits, answer_logits

# Hyperparameters & DataLoaders

In [None]:
# Optimisation settings.
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 2e-5
MAX_LENGTH = 64
# Focal-loss scale (alpha) and focusing exponent (gamma) for the subject head.
FOCAL_ALPHA = 0.25
FOCAL_GAMMA = 2
ANSWER_THRESHOLD = 0.3 # As reported in the paper
# NOTE(review): assumes every category survives filtering in train; the
# LabelEncoder is fitted on the subjects actually present there.
NUM_CLASSES = len(CATEGORIES)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# All three splits keep their labels (is_test=False) because dev and test are
# both scored with losses and metrics.
train_dataset, dev_dataset, test_dataset = (
    MedicalQADataset(split_df, tokenizer, MAX_LENGTH, is_test=False)
    for split_df in (train_df, dev_df, test_df)
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders keep row order so predictions line up with the frames.
dev_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=False)
    for split_ds in (dev_dataset, test_dataset)
)

print(f"DataLoaders created with Batch Size: {BATCH_SIZE}")

# Loss, Optimizer, Scheduler

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: cross-entropy down-weighted on easy examples.

    Per example: ``alpha * (1 - p) ** gamma * CE``, where ``CE`` is the
    unreduced cross-entropy and ``p = exp(-CE)`` is the predicted probability
    of the true class.

    Args:
        alpha: Constant scale applied to every example's loss.
        gamma: Focusing exponent; 0 gives alpha-scaled cross-entropy.
        reduction: 'mean', 'sum', or any other value for per-example losses.
    """

    def __init__(self, alpha=FOCAL_ALPHA, gamma=FOCAL_GAMMA, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_example_ce = F.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = torch.exp(-per_example_ce)
        modulated = self.alpha * (1 - true_class_prob) ** self.gamma * per_example_ce

        if self.reduction == 'mean':
            return modulated.mean()
        if self.reduction == 'sum':
            return modulated.sum()
        return modulated

# Subject head: focal loss over classes. Answer head: independent per-option BCE.
loss_fn_subject = FocalLoss().to(device)
loss_fn_answers = nn.BCEWithLogitsLoss().to(device)

# NOTE(review): the ablation cell further down builds its own model, optimizer
# and scheduler for every experiment, so this trio is not used there. It still
# keeps a full BioBERT on `device` for the whole run; consider dropping it if
# nothing else needs it.
model = MedicalQAModel(n_classes=NUM_CLASSES, use_bigru=True, use_negation=True).to(device)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Linear decay over every optimisation step, with no warm-up.
total_steps = EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Training & Evaluation Functions

In [None]:
def train_epoch(model, data_loader, loss_fn_subject, loss_fn_answers, optimizer, device, scheduler,
                threshold=ANSWER_THRESHOLD):
    """Run one optimisation pass over ``data_loader``.

    For every batch: forward pass, joint loss (subject + answers), backward
    pass, gradient-norm clipping to 1.0, optimizer step, then scheduler step.

    Args:
        model: Network returning ``(subject_logits, answer_logits)``.
        data_loader: Yields dicts with input tensors and both label tensors.
        loss_fn_subject: Loss on subject logits vs. class ids.
        loss_fn_answers: Loss on answer logits vs. multi-hot targets.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device to move batches to.
        scheduler: LR scheduler stepped once per batch.
        threshold: Sigmoid cut-off for answer predictions. Defaults to
            ANSWER_THRESHOLD, the same cut-off ``eval_model`` uses, so train
            and dev answer accuracies are comparable.

    Returns:
        ``(avg_loss, subject_accuracy, answer_accuracy)``; all 0.0 for an empty loader.
        Answer accuracy is measured per option slot.
    """
    model.train()
    total_loss = 0.0
    total_subject_correct = 0
    total_answer_correct = 0
    total_answer_slots = 0
    total_samples = 0
    num_batches = 0

    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        negation_feature = batch['negation_feature'].to(device)
        subject_labels = batch['subject_labels'].to(device)
        answer_labels = batch['answer_labels'].to(device)

        optimizer.zero_grad()

        subject_logits, answer_logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            negation_feature=negation_feature
        )

        loss_subject = loss_fn_subject(subject_logits, subject_labels)
        loss_answers = loss_fn_answers(answer_logits, answer_labels)
        loss = loss_subject + loss_answers

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

        # Metrics only: keep them off the autograd graph.
        with torch.no_grad():
            subject_preds = torch.argmax(subject_logits, dim=1)
            total_subject_correct += (subject_preds == subject_labels).sum().item()

            answer_preds = (torch.sigmoid(answer_logits) > threshold).float()
            total_answer_correct += (answer_preds == answer_labels).sum().item()

        total_answer_slots += answer_labels.numel()
        total_samples += subject_labels.size(0)

    if total_samples == 0:
        return 0.0, 0.0, 0.0

    avg_loss = total_loss / num_batches
    subject_accuracy = total_subject_correct / total_samples
    answer_accuracy = total_answer_correct / total_answer_slots

    return avg_loss, subject_accuracy, answer_accuracy


def eval_model(model, data_loader, loss_fn_subject, loss_fn_answers, device, threshold=ANSWER_THRESHOLD):
    """Score ``model`` on ``data_loader`` without updating it.

    Args:
        model: Network returning ``(subject_logits, answer_logits)``.
        data_loader: Yields dicts with input tensors and both label tensors.
        loss_fn_subject: Loss on subject logits vs. class ids.
        loss_fn_answers: Loss on answer logits vs. multi-hot targets.
        device: Device to move batches to.
        threshold: Sigmoid cut-off above which an option counts as selected.

    Returns:
        ``(avg_loss, subject_accuracy, answer_accuracy, subject_precision,
        subject_recall, subject_f1)``. Precision, recall and F1 are
        support-weighted averages over subject classes; answer accuracy is per
        option slot (4 per sample).
    """
    model.eval()
    running_loss = 0
    subj_hits = 0
    ans_hits = 0
    n_seen = 0
    subj_pred_log = []
    subj_true_log = []

    with torch.no_grad():
        for batch in data_loader:
            inputs = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'negation_feature')
            }
            subj_true = batch['subject_labels'].to(device)
            ans_true = batch['answer_labels'].to(device)

            subj_logits, ans_logits = model(**inputs)

            batch_loss = loss_fn_subject(subj_logits, subj_true) + loss_fn_answers(ans_logits, ans_true)
            running_loss += batch_loss.item()

            subj_guess = subj_logits.argmax(dim=1)
            subj_hits += (subj_guess == subj_true).sum().item()

            ans_guess = (ans_logits.sigmoid() > threshold).float()
            ans_hits += (ans_guess == ans_true).sum().item()

            n_seen += subj_true.size(0)

            subj_pred_log.extend(subj_guess.cpu().numpy())
            subj_true_log.extend(subj_true.cpu().numpy())

    avg_loss = running_loss / len(data_loader)
    subject_accuracy = subj_hits / n_seen
    answer_accuracy = ans_hits / (n_seen * 4)

    weighted = dict(average='weighted', zero_division=0)
    subject_precision, subject_recall, subject_f1 = (
        metric(subj_true_log, subj_pred_log, **weighted)
        for metric in (precision_score, recall_score, f1_score)
    )

    return avg_loss, subject_accuracy, answer_accuracy, subject_precision, subject_recall, subject_f1

# Define Save Path

In [None]:
# Root folder; every experiment writes into its own sub-directory beneath it.
BASE_SAVE_PATH = "model_outputs"

if os.path.exists(BASE_SAVE_PATH):
    print(f"Directory already exists: {BASE_SAVE_PATH}")
else:
    os.makedirs(BASE_SAVE_PATH)
    print(f"Created directory: {BASE_SAVE_PATH}")

# Main Training Loop

In [None]:
def run_training_experiment(model, optimizer, scheduler, experiment_name):
    """Train ``model`` with early stopping and keep its best dev checkpoint.

    The function relies on these module-level names:

    * ``train_loader`` and ``dev_loader``
    * ``loss_fn_subject`` and ``loss_fn_answers``
    * ``device``, ``EPOCHS``, ``ANSWER_THRESHOLD`` and ``BASE_SAVE_PATH``

    After every epoch the dev set is scored. The ``state_dict`` with the
    highest weighted subject F1 is written to
    ``<BASE_SAVE_PATH>/<experiment_name>/best_model_state.bin``. Training
    stops after ``patience`` consecutive epochs without improvement. Per-epoch
    metrics go to ``training_history.csv`` in the same folder.

    Args:
        model: Model to train (already on ``device``).
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Per-batch LR scheduler.
        experiment_name: Sub-folder name for this run's artefacts.

    Returns:
        Path of the best checkpoint. One is always written when at least one
        epoch runs.
    """
    print(f"\n{'='*20} STARTING EXPERIMENT: {experiment_name} {'='*20}")

    EXPERIMENT_PATH = os.path.join(BASE_SAVE_PATH, experiment_name)
    os.makedirs(EXPERIMENT_PATH, exist_ok=True)

    BEST_MODEL_PATH = os.path.join(EXPERIMENT_PATH, 'best_model_state.bin')
    HISTORY_PATH = os.path.join(EXPERIMENT_PATH, 'training_history.csv')

    print(f"Best model will be saved to: {BEST_MODEL_PATH}")

    # Start below any achievable F1 so the first epoch always writes a
    # checkpoint. Starting at 0.0 left nothing to load if dev F1 stayed at 0.
    best_dev_f1 = float('-inf')
    history = {
        'train_loss': [], 'train_subj_acc': [], 'train_ans_acc': [],
        'dev_loss': [], 'dev_subj_acc': [], 'dev_ans_acc': [], 'dev_subj_f1': []
    }

    patience = 3
    epochs_no_improve = 0

    for epoch in range(EPOCHS):
        start_time = time.time()

        train_loss, train_subj_acc, train_answer_acc = train_epoch(
            model, train_loader, loss_fn_subject, loss_fn_answers, optimizer, device, scheduler
        )

        dev_loss, dev_subj_acc, dev_answer_acc, dev_subj_prec, dev_subj_rec, dev_subj_f1 = eval_model(
            model, dev_loader, loss_fn_subject, loss_fn_answers, device, threshold=ANSWER_THRESHOLD
        )

        epoch_mins, epoch_secs = divmod(time.time() - start_time, 60)

        history['train_loss'].append(train_loss)
        history['train_subj_acc'].append(train_subj_acc)
        history['train_ans_acc'].append(train_answer_acc)
        history['dev_loss'].append(dev_loss)
        history['dev_subj_acc'].append(dev_subj_acc)
        history['dev_ans_acc'].append(dev_answer_acc)
        history['dev_subj_f1'].append(dev_subj_f1)

        print(f"\n--- Epoch {epoch + 1}/{EPOCHS} | Time: {int(epoch_mins)}m {int(epoch_secs)}s ---")
        print(f"\tTrain Loss: {train_loss:.4f} | Train Subj Acc: {train_subj_acc:.2%} | Train Ans Acc: {train_answer_acc:.2%}")
        print(f"\tDev Loss:   {dev_loss:.4f} | Dev Subj Acc:   {dev_subj_acc:.2%} | Dev Ans Acc:   {dev_answer_acc:.2%}")
        print(f"\tDev Subject F1: {dev_subj_f1:.4f} (P: {dev_subj_prec:.4f}, R: {dev_subj_rec:.4f})")

        if dev_subj_f1 > best_dev_f1:
            best_dev_f1 = dev_subj_f1
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"\t*** New best model saved with F1: {best_dev_f1:.4f} ***")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"\tNo improvement in F1 for {epochs_no_improve} epoch(s). Patience: {patience}")

        if epochs_no_improve >= patience:
            print(f"\n--- EARLY STOPPING triggered after {epoch + 1} epochs. ---")
            break

    print(f"\nTraining complete for experiment: {experiment_name}")
    print(f"Best model state (F1: {best_dev_f1:.4f}) saved to: {BEST_MODEL_PATH}")

    # Best effort: a failed history write should not discard the trained checkpoint.
    try:
        pd.DataFrame(history).to_csv(HISTORY_PATH, index=False)
        print(f"Training history saved to {HISTORY_PATH}")
    except OSError as e:
        print(f"Error saving training history: {e}")

    return BEST_MODEL_PATH

# Analysis & Visualization Helpers

In [None]:
def display_predictions(df, y_pred_subj, y_pred_ans, y_prob_ans, label_encoder, num_samples=5):
    """Print a few questions side by side with the model's predictions.

    Row ``i`` of ``df`` (after resetting its index) is paired with entry ``i``
    of every prediction array, so the arrays must follow the frame's row
    order. When the frame has more than ``num_samples`` rows, that many are
    drawn without replacement using NumPy's global RNG; otherwise every row
    is shown.

    Args:
        df: Frame with question, subject_name, answer_labels and opa..opd columns.
        y_pred_subj: Predicted subject ids per row.
        y_pred_ans: Per-row 4-slot 0/1 answer predictions.
        y_prob_ans: Per-row 4-slot answer probabilities.
        label_encoder: Object whose ``inverse_transform`` maps ids to subject names.
        num_samples: Maximum number of rows to print.
    """
    banner = "=" * 30
    print("\n" + banner)
    print("  QUALITATIVE PREDICTION SAMPLES")
    print(banner + "\n")

    rows = df.reset_index(drop=True)
    total = len(rows)
    picks = (np.random.choice(total, num_samples, replace=False)
             if total > num_samples else range(total))

    for idx in picks:
        record = rows.iloc[idx]
        print(f"\n--- Sample ID: {record.get('id', 'N/A')} ---")
        print(f"Question: {record['question']}")

        actual_subject = record['subject_name']
        guessed_subject = label_encoder.inverse_transform([y_pred_subj[idx]])[0]
        verdict = "✓" if actual_subject == guessed_subject else "✗"
        print(f"Subject: [True: {actual_subject}] | [Pred: {guessed_subject}] {verdict}")

        gold = record['answer_labels']
        chosen = y_pred_ans[idx]
        probs = y_prob_ans[idx]

        for slot, letter in enumerate(['A', 'B', 'C', 'D']):
            option_text = record[f'op{letter.lower()}']
            gold_bit = int(gold[slot])
            chosen_bit = int(chosen[slot])
            status = "✓" if gold_bit == chosen_bit else "✗"
            pred_mark = f"Selected (Prob: {probs[slot]:.3f})" if chosen_bit == 1 else "Not Selected"
            true_mark = "Correct" if gold_bit == 1 else "Incorrect"
            print(f"  Opt {letter}: {option_text}")
            print(f"    Prediction: {pred_mark} | Actual: {true_mark} | Status: {status}")

    print("\n" + "-" * 80)


def plot_subject_roc(y_true, y_pred, labels, save_path=None, y_score=None):
    """Plot one-vs-rest ROC curves for the subject classifier.

    Args:
        y_true: True class ids (ints in ``range(len(labels))``).
        y_pred: Predicted class ids; used only when ``y_score`` is None. Hard
            labels give each curve a single operating point, so the reported
            "AUC" then equals (TPR + 1 - FPR) / 2 for that class.
        labels: Class names indexed by class id.
        save_path: Optional file path the figure is saved to.
        y_score: Optional (n_samples, n_classes) per-class scores or
            probabilities; when given, true ROC curves are drawn from them.

    Classes with no positive (or no negative) examples in ``y_true`` have an
    undefined ROC and are skipped with a message.
    """
    print("Generating Subject Classification ROC Curve")

    class_ids = list(range(len(labels)))

    def _one_hot(ids):
        # label_binarize returns a single column for two classes; expand it
        # so there is always one column per class.
        binarized = label_binarize(ids, classes=class_ids)
        if binarized.shape[1] == 1 and len(class_ids) == 2:
            binarized = np.hstack([1 - binarized, binarized])
        return binarized

    y_true_bin = _one_hot(y_true)
    scores = _one_hot(y_pred) if y_score is None else np.asarray(y_score)

    plt.figure(figsize=(12, 9))

    for i in range(y_true_bin.shape[1]):
        positives = int(y_true_bin[:, i].sum())
        if positives == 0 or positives == len(y_true_bin):
            print(f"Skipping ROC for '{labels[i]}': only one class present in y_true.")
            continue
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], scores[:, i])
        plt.plot(fpr, tpr, lw=2, label=f'{labels[i]} (AUC = {auc(fpr, tpr):.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('Subject Classification ROC Curve (Per-Class)', fontsize=14)
    plt.legend(loc="lower right")
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
        print(f"ROC curve saved to {save_path}")
    plt.show()

# Final Test Set Evaluation Functions

In [None]:
def get_test_predictions_and_labels(model, data_loader, device, threshold=ANSWER_THRESHOLD):
    """Collect test-set predictions, labels and per-sample inference latency.

    Args:
        model: Network returning ``(subject_logits, answer_logits)``.
        data_loader: Yields dicts with input tensors and both label tensors.
        device: Device to run on (``torch.device`` or string).
        threshold: Sigmoid cut-off above which an answer option is selected.

    Returns:
        Tuple of NumPy arrays ``(subject_labels, subject_preds, answer_labels,
        answer_preds, answer_probs)`` followed by the mean forward-pass time
        per sample in milliseconds (0.0 for an empty loader).
    """
    model.eval()

    all_subject_preds = []
    all_subject_labels = []
    all_answer_preds = []
    all_answer_labels = []
    all_answer_probs = []

    total_inference_time = 0.0
    total_samples = 0

    # CUDA kernels run asynchronously. Without a synchronize, the timer would
    # mostly measure kernel launch rather than the forward pass itself.
    on_cuda = torch.device(device).type == 'cuda'

    def _sync():
        if on_cuda:
            torch.cuda.synchronize(device)

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            negation_feature = batch['negation_feature'].to(device)
            subject_labels = batch['subject_labels'].to(device)
            answer_labels = batch['answer_labels'].to(device)

            _sync()
            start_batch_time = time.perf_counter()

            subject_logits, answer_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                negation_feature=negation_feature
            )

            _sync()
            total_inference_time += time.perf_counter() - start_batch_time
            total_samples += input_ids.size(0)

            subject_preds = torch.argmax(subject_logits, dim=1)
            all_subject_preds.extend(subject_preds.cpu().numpy())
            all_subject_labels.extend(subject_labels.cpu().numpy())

            answer_probs_batch = torch.sigmoid(answer_logits)
            answer_preds = (answer_probs_batch > threshold).float()

            all_answer_preds.extend(answer_preds.cpu().numpy())
            all_answer_labels.extend(answer_labels.cpu().numpy())
            all_answer_probs.extend(answer_probs_batch.cpu().numpy())

    avg_inference_time_ms = (total_inference_time / total_samples) * 1000 if total_samples else 0.0

    return (
        np.array(all_subject_labels), np.array(all_subject_preds),
        np.array(all_answer_labels), np.array(all_answer_preds),
        np.array(all_answer_probs),
        avg_inference_time_ms
    )


def run_final_evaluation(model, best_model_path, experiment_name, labels, test_df_for_display):
    """Reload the best checkpoint and report test-set metrics and plots.

    Loads ``best_model_path`` into ``model``, mapped onto the module-level
    ``device``, and predicts on the module-level ``test_loader``. It then:

    * prints a metrics report and writes it to ``final_metrics_report.txt``;
    * saves ``confusion_matrix.png`` and ``roc_curve.png``;
    * prints a few qualitative samples.

    Files go under ``<BASE_SAVE_PATH>/<experiment_name>``. Plot and display
    failures are reported and skipped (best effort).

    Args:
        model: Freshly built model with the same architecture as the checkpoint.
        best_model_path: Path of the saved ``state_dict``.
        experiment_name: Name of the experiment's output sub-folder.
        labels: Subject class names indexed by class id (e.g. ``label_encoder.classes_``).
        test_df_for_display: Frame whose row order matches ``test_loader``.

    Returns:
        None; returns early if the checkpoint cannot be loaded.
    """
    print(f"\n{'='*20} STARTING FINAL EVALUATION: {experiment_name} {'='*20}")

    EXPERIMENT_PATH = os.path.join(BASE_SAVE_PATH, experiment_name)
    os.makedirs(EXPERIMENT_PATH, exist_ok=True)

    try:
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"Successfully loaded best model from {best_model_path}")
    except Exception as e:
        print(f"Error loading model state: {e}. Aborting evaluation.")
        return

    print("Running predictions on the test set...")
    y_true_subj, y_pred_subj, y_true_ans, y_pred_ans, y_prob_ans, avg_inference_time_ms = get_test_predictions_and_labels(
        model,
        test_loader,
        device,
        threshold=ANSWER_THRESHOLD
    )

    print("\n" + "="*30)
    print("  FINAL TEST SET METRICS")
    print("="*30 + "\n")

    # Pin the class ids so the report and the confusion matrix keep one
    # row/column per entry of `labels`, even when a subject is absent from
    # the test split (classification_report would otherwise raise on the
    # target_names length mismatch).
    class_ids = np.arange(len(labels))

    report_lines = []

    report_lines.append("--- Practical Deployment Metrics ---")
    report_lines.append(f"Average Inference Time per Sample: {avg_inference_time_ms:.2f} ms")
    report_lines.append("\n")

    overall_subject_accuracy = np.mean(y_true_subj == y_pred_subj)
    subject_precision = precision_score(y_true_subj, y_pred_subj, average='weighted', zero_division=0)
    subject_recall = recall_score(y_true_subj, y_pred_subj, average='weighted', zero_division=0)
    subject_f1 = f1_score(y_true_subj, y_pred_subj, average='weighted', zero_division=0)

    report_lines.append("--- Overall Subject Classification ---")
    report_lines.append(f"Overall Subject Accuracy: {overall_subject_accuracy:.4f}")
    report_lines.append(f"Overall Subject Precision: {subject_precision:.4f}")
    report_lines.append(f"Overall Subject Recall: {subject_recall:.4f}")
    report_lines.append(f"Overall Subject F1-Score: {subject_f1:.4f}")
    report_lines.append("\n")

    # Per option slot: every (sample, option) pair counts once.
    answer_accuracy = np.mean(y_true_ans == y_pred_ans)
    report_lines.append("--- Overall Answer Prediction Metrics ---")
    report_lines.append(f"Overall Answer Accuracy (Label-Based): {answer_accuracy:.4f}")
    report_lines.append("\n")

    report_lines.append("--- Per-Class Subject Classification Report ---")
    class_report = classification_report(
        y_true_subj, y_pred_subj, labels=class_ids, target_names=labels, zero_division=0
    )
    report_lines.append(class_report)
    report_lines.append("\n")

    for line in report_lines:
        print(line)

    report_path = os.path.join(EXPERIMENT_PATH, 'final_metrics_report.txt')
    try:
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(report_lines))
        print(f"\nFinal metrics report saved to {report_path}")
    except OSError as e:
        print(f"Error saving metrics report: {e}")

    print("\nGenerating Confusion Matrix...")
    try:
        conf_matrix = confusion_matrix(y_true_subj, y_pred_subj, labels=class_ids)
        plt.figure(figsize=(12, 10))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                    xticklabels=labels, yticklabels=labels)
        plt.title(f'Confusion Matrix - {experiment_name}')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        cm_path = os.path.join(EXPERIMENT_PATH, 'confusion_matrix.png')
        plt.savefig(cm_path)
        plt.show()
        print(f"Confusion matrix saved to {cm_path}")
    except Exception as e:
        print(f"Error generating confusion matrix: {e}")

    try:
        roc_path = os.path.join(EXPERIMENT_PATH, 'roc_curve.png')
        plot_subject_roc(y_true_subj, y_pred_subj, labels=labels, save_path=roc_path)
    except Exception as e:
        print(f"Error generating ROC curve: {e}")

    try:
        display_predictions(
            test_df_for_display, y_pred_subj, y_pred_ans, y_prob_ans, label_encoder, num_samples=5
        )
    except Exception as e:
        print(f"Error displaying qualitative samples: {e}")

# Run Ablation Study Experiments

In [None]:
# Every ablation run decays its learning rate over the same number of steps.
total_steps = len(train_loader) * EPOCHS


def _launch_ablation(experiment_name, **arch_flags):
    """Build a fresh model, AdamW (weight decay 0.01) and linear schedule, then train.

    ``arch_flags`` are forwarded to ``MedicalQAModel``. Returns the model,
    optimizer, scheduler and best-checkpoint path so the caller can keep them
    bound under the usual notebook names.
    """
    net = MedicalQAModel(n_classes=NUM_CLASSES, **arch_flags).to(device)
    opt = AdamW(net.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    sched = get_linear_schedule_with_warmup(opt, 0, total_steps)
    ckpt_path = run_training_experiment(net, opt, sched, experiment_name)
    return net, opt, sched, ckpt_path


print("--- LAUNCHING EXPERIMENT 1: Base_Model_BioBERT_Attention ---")
model_base, optimizer_base, scheduler_base, best_model_path_base = _launch_ablation(
    "Base_Model_BioBERT_Attention", use_bigru=False, use_negation=False, use_brt=False
)
print("--- FINISHED EXPERIMENT 1 ---")


print("\n--- LAUNCHING EXPERIMENT 2: BioBERT_BiGRU_Model_No_Feature ---")
model_bigru, optimizer_bigru, scheduler_bigru, best_model_path_bigru = _launch_ablation(
    "BioBERT_BiGRU_Model_No_Feature", use_bigru=True, use_negation=False, use_brt=False
)
print("--- FINISHED EXPERIMENT 2 ---")


print("\n--- LAUNCHING EXPERIMENT 3: BioBERT_BiGRU_Model_with_Negation ---")
model_bigru_neg, optimizer_bigru_neg, scheduler_bigru_neg, best_model_path_bigru_neg = _launch_ablation(
    "BioBERT_BiGRU_Model_with_Negation", use_bigru=True, use_negation=True, use_brt=False
)
print("--- FINISHED EXPERIMENT 3 ---")


print("\n--- LAUNCHING EXPERIMENT 4: BRT_Model_with_Negation ---")
model_brt_neg, optimizer_brt_neg, scheduler_brt_neg, best_model_path_brt_neg = _launch_ablation(
    "BRT_Model_with_Negation", use_bigru=False, use_negation=True, use_brt=True
)
print("--- FINISHED EXPERIMENT 4 ---")


print("\n\n" + "="*30)
print("  ALL EXPERIMENTS COMPLETE. RUNNING FINAL EVALUATIONS ON TEST SET...")
print("="*30 + "\n")


def _evaluate_checkpoint(banner, checkpoint_path, experiment_name, **arch_flags):
    """Print ``banner``, rebuild the architecture, reload its checkpoint and score it on test.

    Returns the rebuilt model so it stays bound under its usual notebook name.
    """
    print(banner)
    fresh = MedicalQAModel(n_classes=NUM_CLASSES, **arch_flags).to(device)
    run_final_evaluation(
        fresh,
        checkpoint_path,
        experiment_name,
        labels=label_encoder.classes_,
        test_df_for_display=test_df
    )
    return fresh


model_base_eval = _evaluate_checkpoint(
    "\n--- Evaluating Base_Model_BioBERT_Attention ---",
    best_model_path_base, "Base_Model_BioBERT_Attention",
    use_bigru=False, use_negation=False, use_brt=False,
)

model_bigru_eval = _evaluate_checkpoint(
    "\n--- Evaluating BioBERT-BiGRU_Model_No_Feature ---",
    best_model_path_bigru, "BioBERT_BiGRU_Model_No_Feature",
    use_bigru=True, use_negation=False, use_brt=False,
)

model_bigru_neg_eval = _evaluate_checkpoint(
    "\n--- Evaluating BioBERT-BiGRU_Model_with_Negation ---",
    best_model_path_bigru_neg, "BioBERT_BiGRU_Model_with_Negation",
    use_bigru=True, use_negation=True, use_brt=False,
)

model_brt_neg_eval = _evaluate_checkpoint(
    "\n--- Evaluating BRT_Model_with_Negation ---",
    best_model_path_brt_neg, "BRT_Model_with_Negation",
    use_bigru=False, use_negation=True, use_brt=True,
)

print("\n\n--- FULL REVISION EXPERIMENTS & FINAL EVALUATION COMPLETE ---")