In [None]:
!pip install pytorch-lightning transformers torch

In [None]:
# full_pipeline.py
import os
import random
import joblib

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
import pytorch_lightning as pl

from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel

# ─── 1. LOAD & PREPARE TABULAR DATA ──────────────────────────────────────────────

DATA_PATH = "/content/cleaned_lab_data_withNewTests.csv"  # your CSV
df = pd.read_csv(DATA_PATH)

# 1a. Label
# NOTE(review): the label is later cast to torch.long and fed to
# CrossEntropyLoss, so it is presumably already integer-encoded as 0..K-1;
# confirm against the CSV.
y = df["ICD_Label"].values

# 1b. Lab continuous + missing flags
lab_cont_features = [
    'ALT (SGPT)','AST (SGOT)','Bilirubin','Albumin','Platelet Count',
    'Total Cholesterol','BP Systolic','BP Diastolic','Troponin','Ejection Fraction',
    'HbA1c','Fasting Glucose','Postprandial Glucose','Triglycerides','Insulin Level',
    'WBC Count','Fever','Hematocrit','Hemoglobin','Ferritin','CRP','Widal_Test',
    'Chest_Xray_Result','Age','Height_cm'
]
# Every column whose name ends in "_missing" is treated as an extra numeric input.
lab_missing_flags = [c for c in df.columns if c.endswith('_missing')]
lab_cont_features += lab_missing_flags

# Standardize lab continuous.
# The scaler statistics are estimated from the training rows only; fitting on
# the full frame would leak validation/test information into the features.
# The call below uses the same arguments (test_size, stratify, random_state)
# as the stratified split performed later in this file, so it selects exactly
# the rows that end up in the training set.
_scaler_fit_rows, _ = train_test_split(
    np.arange(len(df)), test_size=0.2, stratify=y, random_state=42
)
scaler = StandardScaler()
scaler.fit(df.iloc[_scaler_fit_rows][lab_cont_features])
df[lab_cont_features] = scaler.transform(df[lab_cont_features])
joblib.dump(scaler, "lab_scaler.pkl")  # persisted so inference can reuse the same scaling

# 1c. Encode categorical metadata
# Categories not seen during fit are encoded as -1 at transform time.
# NOTE: -1 is not a valid nn.Embedding index, so rows carrying it must be
# remapped before they reach an embedding layer.
cat_features = ["Gender","Occupation","Region"]
ord_enc = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
df[cat_features] = ord_enc.fit_transform(df[cat_features])
joblib.dump(ord_enc, "cat_encoder.pkl")

# ─── 2. COMPUTE BERT EMBEDDINGS FOR TEXT ─────────────────────────────────────────

# Pretrained checkpoint used as a frozen text feature extractor.
MODEL_NAME = "bert-base-uncased"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the encoder, switch it to inference mode, then move it to `device`.
bert = AutoModel.from_pretrained(MODEL_NAME)
bert = bert.eval()
bert = bert.to(device)


def embed_texts(texts, batch_size=16, max_length=128):
    """Embed texts with the module-level BERT encoder using masked mean pooling.

    Each text is tokenized (padded to the longest item of its batch and
    truncated to ``max_length`` tokens), run through ``bert`` without
    gradients, and reduced to one vector by averaging the last hidden states
    of its real tokens only. Padding positions are excluded through the
    attention mask, so a text's embedding does not depend on which other
    texts share its batch.

    Args:
        texts: Sequence of strings. ``None`` and NaN entries (what pandas
            produces for empty cells) are embedded as the empty string.
        batch_size: Number of texts encoded per forward pass.
        max_length: Maximum number of tokens kept per text.

    Returns:
        A 2-D NumPy array with one row per input text. For empty input the
        array has zero rows and ``bert.config.hidden_size`` columns.
    """
    # The tokenizer only accepts str; NaN is the one value for which t != t.
    texts = ["" if t is None or t != t else str(t) for t in texts]
    if not texts:
        # np.vstack([]) raises, so return an explicitly shaped empty result.
        return np.empty((0, bert.config.hidden_size), dtype=np.float32)

    all_embs = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        enc = tokenizer(batch, padding=True, truncation=True,
                        max_length=max_length, return_tensors="pt").to(device)
        with torch.no_grad():
            hidden = bert(**enc).last_hidden_state
            # 1 for real tokens, 0 for padding; broadcast over the hidden dim.
            mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_sum = (hidden * mask).sum(dim=1)
            # clamp guards against division by zero for an all-padding row.
            token_count = mask.sum(dim=1).clamp(min=1)
            pooled = token_sum / token_count
        all_embs.append(pooled.cpu().numpy())
    return np.vstack(all_embs)


# Embed each free-text column separately, in a fixed column order.
sym_emb, med_emb, exp_emb = (
    embed_texts(df[text_col].tolist())
    for text_col in ("Symptoms", "Medications", "Exposure History")
)

# Place the three per-column embeddings side by side along the feature axis.
conv_cont_emb = np.concatenate((sym_emb, med_emb, exp_emb), axis=1)

# ─── 3. DATASET & DATALOADERS ──────────────────────────────────────────────────

class HealthcareDataset(Dataset):
    """In-memory dataset yielding one dict of tensors per sample.

    The four inputs are converted to tensors once at construction time:
    ``lab_cont`` and ``conv_cont`` as float32, ``lab_cat`` and ``labels`` as
    int64. Indexing returns the matching row of each tensor under the keys
    ``"lab_cont"``, ``"lab_cat"``, ``"conv_cont"`` and ``"label"``.
    """

    def __init__(self, lab_cont, lab_cat, conv_cont, labels):
        float_kind = torch.float32
        index_kind = torch.long
        self.lab_cont = torch.tensor(lab_cont, dtype=float_kind)
        self.lab_cat = torch.tensor(lab_cat, dtype=index_kind)
        self.conv_cont = torch.tensor(conv_cont, dtype=float_kind)
        self.labels = torch.tensor(labels, dtype=index_kind)

    def __len__(self):
        # The label tensor defines the number of samples.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        sample["lab_cont"] = self.lab_cont[idx]
        sample["lab_cat"] = self.lab_cat[idx]
        sample["conv_cont"] = self.conv_cont[idx]
        sample["label"] = self.labels[idx]
        return sample

# Prepare inputs
# Pull the model inputs out of the frame as plain NumPy arrays.
labels_array   = df["ICD_Label"].values
lab_cat_array  = df[cat_features].values
lab_cont_array = df[lab_cont_features].values

# Stratified 80/10/10 split over row positions: first carve off 20 %, then
# halve that remainder into validation and test.
idx = np.arange(len(df))
train_idx, temp_idx = train_test_split(
    idx, test_size=0.2, stratify=y, random_state=42
)
val_idx, test_idx = train_test_split(
    temp_idx, test_size=0.5, stratify=y[temp_idx], random_state=42
)


def _dataset_for(rows):
    """Build a HealthcareDataset from the samples at the given row positions."""
    return HealthcareDataset(
        lab_cont_array[rows],
        lab_cat_array[rows],
        conv_cont_emb[rows],
        labels_array[rows],
    )


train_ds = _dataset_for(train_idx)
val_ds   = _dataset_for(val_idx)
test_ds  = _dataset_for(test_idx)

# Only the training loader shuffles; the evaluation loaders keep row order.
_loader_opts = {"batch_size": 32, "num_workers": 4}
train_loader = DataLoader(train_ds, shuffle=True,  **_loader_opts)
val_loader   = DataLoader(val_ds,   shuffle=False, **_loader_opts)
test_loader  = DataLoader(test_ds,  shuffle=False, **_loader_opts)

# ─── 4. MODEL DEFINITION ───────────────────────────────────────────────────────

class ResidualBlock(nn.Module):
    """Two-layer MLP block with a skip connection.

    The main path is Linear -> ReLU -> Dropout -> Linear. The skip path is an
    identity when input and output widths match and a learned Linear
    projection otherwise. The two paths are summed and passed through ReLU.

    Args:
        in_f: Width of the input features.
        out_f: Width of the output features.
        dropout: Dropout probability applied after the first activation.
    """

    def __init__(self, in_f, out_f, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_f, out_f)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(out_f, out_f)
        self.drop = nn.Dropout(dropout)
        # A projection is only needed when the widths differ.
        if in_f == out_f:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_f, out_f)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.drop(hidden)
        main_path = self.fc2(hidden)
        skip_path = self.shortcut(x)
        return self.relu(main_path + skip_path)

class DualEncoder(pl.LightningModule):
    """Fuse lab values, categorical metadata and text embeddings into class logits.

    Three feature groups are encoded separately and concatenated:

    * continuous lab features -> two ``ResidualBlock`` layers (64 units);
      skipped entirely when ``lab_cont_dim`` is 0,
    * each categorical column -> its own ``nn.Embedding``,
    * precomputed text embeddings -> a single linear projection (64 units).

    The concatenation is classified by a small MLP trained with cross-entropy.

    Args:
        lab_cont_dim: Number of continuous lab features.
        lab_cat_dims: Number of known categories for each categorical column.
        text_emb_dim: Width of the precomputed text embedding vector.
        embedding_dim: Width of every categorical embedding.
        num_classes: Number of output classes.
        lr: Learning rate for Adam.
    """

    def __init__(self, lab_cont_dim, lab_cat_dims, text_emb_dim, embedding_dim, num_classes, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Lab continuous path
        self.lab_enc = nn.Sequential(
            ResidualBlock(lab_cont_dim, 64),
            ResidualBlock(64, 64)
        ) if lab_cont_dim > 0 else None

        # Lab categorical embeddings. Each table has one row more than the
        # number of known categories; forward() uses that last row for
        # unknown categories.
        self.lab_cat_embs = nn.ModuleList([
            nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims
        ])

        # Text path is already embedded: just a linear proj
        self.text_proj = nn.Linear(text_emb_dim, 64)

        # Fusion classifier
        total_dim = (64 if self.lab_enc is not None else 0) + \
                    embedding_dim * len(lab_cat_dims) + \
                    64
        self.classifier = nn.Sequential(
            nn.Linear(total_dim, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, lab_cont, lab_cat, conv_cont):
        """Return unnormalized class logits for a batch.

        ``lab_cat`` is indexed as ``lab_cat[:, i]``, one integer column per
        categorical embedding. Negative codes (e.g. the -1 an encoder emits
        for an unseen category) are redirected to the spare last row of the
        corresponding embedding table instead of raising an index error.
        """
        parts = []
        # Explicit None check: the truthiness of an nn.Sequential is its length.
        if self.lab_enc is not None:
            parts.append(self.lab_enc(lab_cont))
        for col, emb in enumerate(self.lab_cat_embs):
            codes = lab_cat[:, col]
            unknown_row = torch.full_like(codes, emb.num_embeddings - 1)
            parts.append(emb(torch.where(codes < 0, unknown_row, codes)))
        parts.append(self.text_proj(conv_cont))
        return self.classifier(torch.cat(parts, dim=1))

    def _loss_and_accuracy(self, batch):
        """Run the model on a batch and return (cross-entropy loss, accuracy)."""
        logits = self(batch["lab_cont"], batch["lab_cat"], batch["conv_cont"])
        loss = self.loss_fn(logits, batch["label"])
        preds = torch.argmax(logits, dim=1)
        acc = (preds == batch["label"]).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        logits = self(batch["lab_cont"], batch["lab_cat"], batch["conv_cont"])
        loss = self.loss_fn(logits, batch["label"])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch)
        # "val_loss" is the quantity the LR scheduler monitors.
        self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Needed for trainer.test(); mirrors validation_step with test_* keys.
        loss, acc = self._loss_and_accuracy(batch)
        self.log_dict({"test_loss": loss, "test_acc": acc}, prog_bar=True)

    def configure_optimizers(self):
        """Adam plus a plateau scheduler that halves the LR when val_loss stalls."""
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        sch = {"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                   opt, mode="min", factor=0.5, patience=2),
               "monitor": "val_loss"}
        return {"optimizer": opt, "lr_scheduler": sch}


# ─── 5. TRAINING ────────────────────────────────────────────────────────────────

# Model dimensions derived from the fitted encoder and the prepared arrays.
num_classes  = len(np.unique(y))
text_emb_dim = conv_cont_emb.shape[1]
# One entry per categorical column: how many categories the encoder learned.
lab_cat_dims = [len(categories) for categories in ord_enc.categories_]

model = DualEncoder(
    lab_cont_dim=len(lab_cont_features),
    lab_cat_dims=lab_cat_dims,
    text_emb_dim=text_emb_dim,
    embedding_dim=16,
    num_classes=num_classes,
)

# Let Lightning pick the available accelerator; train for a fixed 15 epochs.
trainer = pl.Trainer(accelerator="auto", max_epochs=15)
trainer.fit(model, train_loader, val_loader)

# Optionally evaluate on test set:
# trainer.test(model, test_loader)


In [None]:
# Save the trained model
import torch
# Persist only the learned weights (state dict), not the full module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "dual_encoder_model_withNewTests.pth")
print("Model saved successfully.")

In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import numpy as np

def evaluate_model(model, dataloader, class_names=None):
    """Print a classification report for ``model`` and plot per-class metrics.

    The model is switched to eval mode and run without gradients over every
    batch of ``dataloader``. Predictions are the argmax of the logits. A
    text report is printed and a grouped bar chart of precision, recall and
    F1 per class is shown.

    Args:
        model: Module called as ``model(lab_cont, lab_cat, conv_cont)``.
        dataloader: Iterable of dict batches with the keys ``"lab_cont"``,
            ``"lab_cat"``, ``"conv_cont"`` and ``"label"``.
        class_names: Optional display names. When omitted, names of the form
            ``ICD-<id>`` are generated for every class id that occurs in the
            labels or the predictions. When given and the number of observed
            class ids differs from ``len(class_names)``, the names are
            interpreted as indexed by class id (``class_names[i]`` names
            class ``i``) so that classes missing from this split still get a
            row.

    Raises:
        ValueError: If ``class_names`` cannot cover the observed class ids.
    """
    model.eval()
    device = next(model.parameters()).device

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            lab_cont = batch["lab_cont"].to(device)
            lab_cat = batch["lab_cat"].to(device)
            conv_cont = batch["conv_cont"].to(device)
            labels = batch["label"].to(device)

            logits = model(lab_cont, lab_cat, conv_cont)
            preds = torch.argmax(logits, dim=1)

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    # Sorted class ids that appear in either the targets or the predictions.
    observed_ids = np.union1d(y_true, y_pred).tolist()

    if class_names is None:
        label_ids = observed_ids
        class_names = [f"ICD-{i}" for i in label_ids]
    else:
        class_names = list(class_names)
        if len(observed_ids) == len(class_names):
            # Same pairing scikit-learn would infer on its own.
            label_ids = observed_ids
        else:
            # Some classes are absent from this split (or only predicted):
            # pin the label set so target_names and labels always line up.
            if observed_ids and (observed_ids[0] < 0
                                 or observed_ids[-1] >= len(class_names)):
                raise ValueError(
                    f"Observed class ids {observed_ids} do not fit "
                    f"{len(class_names)} class names"
                )
            label_ids = list(range(len(class_names)))

    # zero_division=0 reports 0 for classes without predictions/support
    # instead of emitting a warning per class.
    report_kwargs = dict(labels=label_ids, target_names=class_names,
                         digits=3, zero_division=0)

    print("📊 Classification Report:")
    report_dict = classification_report(
        y_true, y_pred, output_dict=True, **report_kwargs
    )
    print(classification_report(y_true, y_pred, **report_kwargs))

    # Plot precision, recall, F1-score
    labels_clean = class_names
    precision = [report_dict[label]['precision'] for label in labels_clean]
    recall = [report_dict[label]['recall'] for label in labels_clean]
    f1 = [report_dict[label]['f1-score'] for label in labels_clean]

    x = np.arange(len(labels_clean))
    width = 0.25  # three bars per class, centred on the tick

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(x - width, precision, width, label='Precision', color='tab:blue')
    ax.bar(x, recall, width, label='Recall', color='tab:orange')
    ax.bar(x + width, f1, width, label='F1-score', color='tab:green')

    ax.set_ylabel('Score')
    ax.set_title('Per-Class Evaluation Metrics')
    ax.set_xticks(x)
    ax.set_xticklabels(labels_clean, rotation=45, ha='right')
    ax.set_ylim([0, 1.1])
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()
# Placeholder display names, one per class id; replace with real label names if available.
class_names = list(map("ICD-{}".format, range(num_classes)))
evaluate_model(model, val_loader, class_names)
