<a href="https://colab.research.google.com/github/2403A51L33/PfDS-PROJECT/blob/main/DEEP%20LEARNING%20ALGORITHMS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os, re, json, math, random, argparse, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from collections import Counter

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

try:
    import shap
    SHAP_OK = True
except Exception:
    SHAP_OK = False

SEED = 42
# Seed every RNG the script draws from (Python, NumPy, PyTorch) so runs repeat.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

def get_args():
    """Build the command-line parser and return the parsed options.

    When the code runs as a script (``__file__`` is defined in the module
    globals), the real command line is parsed. Otherwise, e.g. inside a
    notebook kernel whose ``sys.argv`` does not belong to this script, the
    defaults are used via ``parse_args([])``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--csv", type=str, default="/mnt/data/realistic_drug_labels_side_effects.csv")
    p.add_argument("--model", type=str, default="fusion",
                   choices=["mlp_tab","text_cnn","bilstm_attn","transformer","tab_transformer","fusion"])
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--max_vocab", type=int, default=30000)
    p.add_argument("--max_len", type=int, default=256)
    p.add_argument("--embed_dim", type=int, default=128)
    p.add_argument("--hidden", type=int, default=256)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--dropout", type=float, default=0.2)
    p.add_argument("--text_encoder", type=str, default="transformer",
                   choices=["transformer","bilstm_attn","text_cnn"])  # used by fusion
    p.add_argument("--no_shap", action="store_true", help="Skip SHAP even if available")
    # The module attribute is the dunder "__file__"; a single-underscore
    # "_file_" is never defined, which would make CLI flags always ignored.
    return p.parse_args() if "__file__" in globals() else p.parse_args([])

ARGS = get_args()

# Every artifact written later (checkpoint, metrics, probabilities, saliency, SHAP) goes here.
OUTDIR = "./dl_outputs"
os.makedirs(OUTDIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the CSV path: keep --csv when it exists, otherwise fall back to a
# copy in the working directory, and fail loudly when neither is present.
if not os.path.exists(ARGS.csv):
    alt = "./realistic_drug_labels_side_effects.csv"
    if not os.path.exists(alt):
        raise FileNotFoundError(f"CSV not found at {ARGS.csv} or {alt}")
    ARGS.csv = alt

df = pd.read_csv(ARGS.csv)

# Keep only the expected columns that actually exist in this CSV.
TEXT_COLS = [c for c in ["indications","side_effects","contraindications","warnings"] if c in df.columns]
NUM_COLS  = [c for c in ["dosage_mg","price_usd","approval_year"] if c in df.columns]
CAT_COLS  = [c for c in ["drug_class","administration_route","approval_status","manufacturer"] if c in df.columns]

for c in TEXT_COLS:
    df[c] = df[c].fillna("")
for c in NUM_COLS:
    # Coerce stray non-numeric entries to NaN, then impute with the column
    # median (0.0 if the whole column is missing). Otherwise NaN would be
    # carried through StandardScaler into the network and poison the loss.
    col = pd.to_numeric(df[c], errors="coerce").astype(float)
    df[c] = col.fillna(col.median()).fillna(0.0)
for c in CAT_COLS:
    # Fill missing values *before* casting: astype(str) turns NaN into the
    # literal string "nan", after which fillna has nothing left to fill.
    df[c] = df[c].fillna("UNK").astype(str)

def _to_num(v):
    try: return float(v)
    except: return {"low":0,"mild":0,"moderate":1,"medium":1,"high":2,"severe":2}.get(str(v).lower().strip(), np.nan)

# Derive a 3-class severity target (low / moderate / high) from
# df["side_effect_severity"], which may be numeric, ordinal words, or free text.
NUM_CLASSES = 3
CLASS_NAMES = ["low","moderate","high"]


def _severity_classes(scores):
    """Map numeric severity scores to ordered class ids in ``range(NUM_CLASSES)``.

    * Scores that already are class ids (only values 0..NUM_CLASSES-1, e.g.
      ``_to_num`` output for mild/moderate/severe words) are used directly;
      missing entries get the rounded median.
    * Scores with at most NUM_CLASSES distinct values are rank-encoded.
    * Otherwise, after median imputation, the 33rd/66th percentiles are the
      cut points: ``<= q33`` -> 0, ``<= q66`` -> 1, else 2.

    Percentile cuts are avoided for few-level scores because, with imbalanced
    levels, both cut points can land on one level and merge classes (e.g.
    20% mild / 50% moderate / 30% severe would put mild and moderate in 0).

    Raises:
        ValueError: if every score is missing.
    """
    s = pd.Series(scores, dtype=float)
    observed = s.dropna()
    if observed.empty:
        raise ValueError("side_effect_severity has no usable numeric values")
    levels = np.sort(observed.unique())
    if set(levels.tolist()) <= set(range(NUM_CLASSES)):
        return s.fillna(float(round(observed.median()))).astype(int).values
    # Impute before computing quantiles: np.quantile over NaN returns NaN cut
    # points, which would push every row into the top class.
    s = s.fillna(observed.median())
    if len(levels) <= NUM_CLASSES:
        return np.searchsorted(levels, s.values).astype(int)
    q = np.quantile(s.values, [0.33, 0.66])
    return np.digitize(s.values, q, right=True).astype(int)


y_raw = df["side_effect_severity"]
if y_raw.dtype.kind in "ifu":
    y = _severity_classes(y_raw)
else:
    tmp = y_raw.apply(_to_num)
    if tmp.isna().mean() < 0.5:
        y = _severity_classes(tmp)
    else:
        # Crude fallback for unrecognised labels: alphabetical codes folded into
        # NUM_CLASSES buckets, so these class ids carry no severity ordering.
        cats = {k: i for i, k in enumerate(sorted(y_raw.astype(str).unique()))}
        y = y_raw.astype(str).map(cats).values % NUM_CLASSES

_TOKEN_JUNK = re.compile(r"[^A-Za-z0-9\s\-\_/\.]")


def simple_tokenize(s):
    """Lower-case *s*, blank out disallowed characters, and split on whitespace.

    ASCII letters, digits, whitespace and the characters ``-``, ``_``, ``/``
    and ``.`` are kept; every other character is replaced by a space.
    """
    return _TOKEN_JUNK.sub(" ", s.lower()).split()

# One document per row: the available text columns joined with spaces
# (an all-empty Series when the CSV has none of them).
if TEXT_COLS:
    full_text = df[TEXT_COLS].apply(lambda r: " ".join(map(str, r.values)), axis=1)
else:
    full_text = pd.Series([""]*len(df))
tokens = list(map(simple_tokenize, full_text))

# Vocabulary: ids 0 and 1 are reserved for <pad> and <unk>, followed by the
# max_vocab-2 most frequent tokens. Counted over every row, train and test alike.
freq = Counter()
for doc_tokens in tokens:
    freq.update(doc_tokens)
most = [word for word, _count in freq.most_common(ARGS.max_vocab-2)]
itos = ["<pad>", "<unk>"] + most
stoi = {word: idx for idx, word in enumerate(itos)}

def encode(ts, max_len=ARGS.max_len):
    """Turn a token list into a fixed-length int64 id vector.

    Tokens missing from ``stoi`` map to 1 (``<unk>``); the sequence is cut to
    ``max_len`` ids and right-padded with 0 (``<pad>``).
    """
    ids = [stoi.get(tok, 1) for tok in ts[:max_len]]
    ids.extend([0] * (max_len - len(ids)))
    return np.array(ids, dtype=np.int64)

# Token-id matrix (n_rows, max_len); all padding when there are no text columns.
if TEXT_COLS:
    X_text_ids = np.vstack([encode(ts) for ts in tokens])
else:
    X_text_ids = np.zeros((len(df), ARGS.max_len), dtype=np.int64)

# Standardised numeric features. NOTE(review): the scaler is fit on every row,
# including the rows the split below later holds out for testing.
scaler = StandardScaler()
if NUM_COLS:
    X_num = scaler.fit_transform(df[NUM_COLS].values)
else:
    X_num = np.zeros((len(df), 0), dtype=np.float32)

# Integer-encode each categorical column. For every column we keep the
# value->id map (id 0 reserved for "<unk>", observed values sorted after it),
# the number of ids (embedding table size), and the encoded id column.
cat_maps, cat_card, _cat_id_columns = [], [], []
for col in CAT_COLS:
    col_values = df[col].astype(str).values
    col_levels = ["<unk>"] + sorted(set(col_values))
    lookup = {level: idx for idx, level in enumerate(col_levels)}
    cat_maps.append(lookup)
    cat_card.append(len(col_levels))
    _cat_id_columns.append(np.array([lookup.get(v, 0) for v in col_values], dtype=np.int64))
if CAT_COLS:
    X_cat_ids = np.stack(_cat_id_columns, axis=1)
else:
    X_cat_ids = np.zeros((len(df), 0), dtype=np.int64)

# Stratified 80/20 split of row indices, applied identically to every feature array.
X_train_idx, X_test_idx = train_test_split(
    np.arange(len(df)), test_size=0.2, random_state=SEED, stratify=y
)


def split_arr(arr):
    """Return ``(train_rows, test_rows)`` of *arr* using the fixed split indices."""
    return arr[X_train_idx], arr[X_test_idx]


(Xtr_text, Xte_text), (Xtr_num, Xte_num), (Xtr_cat, Xte_cat), (y_train, y_test) = (
    split_arr(arr) for arr in (X_text_ids, X_num, X_cat_ids, y)
)

# Inverse-frequency class weights from the training labels, rescaled to mean 1.
class_counts = np.bincount(y_train, minlength=NUM_CLASSES)
inv_freq = class_counts.sum() / (class_counts + 1e-6)
weights = inv_freq / inv_freq.mean()
CLASS_WEIGHTS = torch.tensor(weights, dtype=torch.float32, device=DEVICE)

class DS(torch.utils.data.Dataset):
    """Row-indexed dataset yielding ``(text_ids, num, cat, label)`` tensors.

    The four arrays are indexed by row; an item's parts are converted to
    long / float32 / long / long tensors respectively.
    """

    _DTYPES = (torch.long, torch.float32, torch.long, torch.long)

    def __init__(self, text_ids, num, cat, y):
        self.text_ids = text_ids
        self.num = num
        self.cat = cat
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        arrays = (self.text_ids, self.num, self.cat, self.y)
        return tuple(torch.tensor(arr[i], dtype=dt) for arr, dt in zip(arrays, self._DTYPES))
def _make_loader(text_ids, num, cat, labels, shuffle):
    """Wrap one split in a DataLoader over DS with the configured batch size."""
    return torch.utils.data.DataLoader(DS(text_ids, num, cat, labels), batch_size=ARGS.batch_size, shuffle=shuffle)

tr_loader = _make_loader(Xtr_text, Xtr_num, Xtr_cat, y_train, shuffle=True)
te_loader = _make_loader(Xte_text, Xte_num, Xte_cat, y_test, shuffle=False)

class TabularMLP(nn.Module):
    """Two-layer MLP over numeric features plus learned categorical embeddings.

    A categorical column with cardinality ``c`` gets an embedding of width
    ``min(emb_dim, max(4, round(c ** 0.25)))``. The numeric features and all
    embeddings are concatenated and passed through Linear-ReLU-Dropout twice.
    ``out_dim`` (equal to ``hidden``) is the width of the returned vector.
    """

    def __init__(self, num_dim, cat_cards, emb_dim=32, hidden=256, dropout=0.2):
        super().__init__()
        widths = [min(emb_dim, max(4, int(round(card ** 0.25)))) for card in cat_cards]
        self.cat_embs = nn.ModuleList([nn.Embedding(card, w) for card, w in zip(cat_cards, widths)])
        self.mlp = nn.Sequential(
            nn.Linear(num_dim + sum(widths), hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout),
        )
        self.out_dim = hidden

    def forward(self, x_num, x_cat):
        parts = [x_num]
        if len(self.cat_embs):
            # One embedding lookup per categorical column (column j of x_cat).
            parts.extend(emb(x_cat[:, j]) for j, emb in enumerate(self.cat_embs))
        else:
            parts.append(torch.zeros(x_num.size(0), 0, device=x_num.device))
        return self.mlp(torch.cat(parts, dim=1))

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of sequences.

    Args:
        d_model: feature width of the inputs (odd widths are supported).
        max_len: longest sequence length that can be encoded.

    ``forward(x)`` expects ``x`` shaped (batch, seq_len, d_model) with
    ``seq_len <= max_len`` and returns ``x + pe[:seq_len]``.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than there are
        # frequencies; trim so the assignment shapes match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TextCNN(nn.Module):
    """Text classifier: parallel 1-D convolutions (widths 3, 4, 5) with max-over-time pooling.

    ``forward(x_ids)`` takes token ids (0 = padding) shaped (batch, seq_len),
    where seq_len must be at least 5 (the widest kernel). It returns
    ``(z, logits)``, with ``z`` the ``d_model``-wide representation.
    """
    def __init__(self, vocab, d_model=128, num_classes=3, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model, padding_idx=0)
        self.convs = nn.ModuleList([nn.Conv1d(d_model, d_model, k) for k in [3,4,5]])
        self.fc = nn.Sequential(nn.Linear(d_model*3, d_model), nn.ReLU(), nn.Dropout(dropout))
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x_ids):
        x = self.emb(x_ids).transpose(1, 2)  # (batch, d_model, seq_len) for Conv1d
        feats = []
        for conv in self.convs:
            fmap = F.relu(conv(x))  # compute each convolution once, then pool over time
            feats.append(F.max_pool1d(fmap, kernel_size=fmap.size(-1)).squeeze(-1))
        h = self.fc(torch.cat(feats, dim=1))
        return h, self.head(h)

class BiLSTMAttn(nn.Module):
    """Bidirectional LSTM with additive attention pooling over non-padding tokens.

    ``forward(x_ids)`` takes token ids (0 = padding) shaped (batch, seq_len)
    and returns ``(z, logits)``, with ``z`` of width ``hidden``.
    """
    def __init__(self, vocab, d_model=128, hidden=128, num_classes=3, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model, padding_idx=0)
        self.lstm = nn.LSTM(d_model, hidden, num_layers=1, batch_first=True, bidirectional=True)
        self.attn = nn.Linear(2*hidden, 1)
        self.fc   = nn.Sequential(nn.Linear(2*hidden, hidden), nn.ReLU(), nn.Dropout(dropout))
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x_ids):
        x = self.emb(x_ids)
        h, _ = self.lstm(x)                  # (batch, seq_len, 2*hidden)
        scores = self.attn(h).squeeze(-1)    # (batch, seq_len)
        # Keep padding positions (id 0) out of the attention weights. A large
        # negative fill instead of -inf keeps all-padding rows finite: they
        # fall back to uniform weights rather than NaN.
        scores = scores.masked_fill(x_ids == 0, torch.finfo(scores.dtype).min)
        a = torch.softmax(scores, dim=1)
        ctx = (h * a.unsqueeze(-1)).sum(1)
        z = self.fc(ctx)
        return z, self.head(z)

class SimpleTransformer(nn.Module):
    """Transformer-encoder text classifier with masked mean pooling.

    Token embeddings (id 0 = padding) plus sinusoidal positions pass through
    ``num_layers`` encoder layers. The non-padding positions are mean-pooled
    and fed to Linear-ReLU-Dropout. ``forward(x_ids)`` takes ids shaped
    (batch, seq_len) and returns ``(z, logits)``, with ``z`` of width ``d_model``.
    """
    def __init__(self, vocab, d_model=128, nhead=4, num_layers=2, num_classes=3, dropout=0.2, max_len=2048):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model, padding_idx=0)
        self.pos = PositionalEncoding(d_model, max_len)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=4*d_model, dropout=dropout, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc  = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Dropout(dropout))
        self.head= nn.Linear(d_model, num_classes)

    def forward(self, x_ids):
        x = self.pos(self.emb(x_ids))
        mask = (x_ids==0)  # True at padding positions
        # In a row that is entirely padding every key would be masked, and the
        # attention softmax can then yield NaN, which also leaks into gradients.
        # Let such rows attend to their first position. The pooling below still
        # uses the real padding mask, so their pooled vector stays zero.
        attn_mask = mask.clone()
        attn_mask[:, :1] &= ~mask.all(dim=1, keepdim=True)
        h = self.enc(x, src_key_padding_mask=attn_mask)
        pooled = (h.masked_fill(mask.unsqueeze(-1), 0).sum(1) /
                  (~mask).sum(1).clamp(min=1).unsqueeze(-1))
        z = self.fc(pooled)
        return z, self.head(z)

class TabTransformer(nn.Module):
    """Contextualise categorical embeddings with a Transformer encoder, then classify.

    Each categorical column becomes one ``d_model``-wide token. The encoded
    tokens are flattened, the numeric features are appended, and the result
    goes through a two-layer 256-unit MLP. ``forward(x_num, x_cat)`` returns
    ``(z, logits)``, with ``z`` the 256-wide representation.
    """

    def __init__(self, cat_cards, num_dim, d_model=128, nhead=4, num_layers=2, num_classes=3, dropout=0.2):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(card, d_model) for card in cat_cards])
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True,
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        flat_dim = d_model * len(cat_cards) + num_dim
        self.mlp = nn.Sequential(
            nn.Linear(flat_dim, 256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 256), nn.ReLU(), nn.Dropout(dropout),
        )
        self.head = nn.Linear(256, num_classes)

    def forward(self, x_num, x_cat):
        if not len(self.embs):
            flat = torch.zeros(x_num.size(0), 0, device=x_num.device)
        else:
            col_tokens = [emb(x_cat[:, j]) for j, emb in enumerate(self.embs)]
            encoded = self.enc(torch.stack(col_tokens, dim=1))  # (batch, n_cat, d_model)
            flat = encoded.reshape(encoded.size(0), -1)          # (batch, n_cat * d_model)
        z = self.mlp(torch.cat([flat, x_num], dim=1))
        return z, self.head(z)

class FusionNet(nn.Module):
    """Late fusion of a text encoder and a TabularMLP.

    ``text_encoder`` selects the text branch: "transformer", "bilstm_attn",
    or any other value for TextCNN. Only the text branch's representation
    (width ``d_model``) is used; its own classification head is ignored. The
    text and tabular vectors are concatenated, passed through a two-layer MLP
    of width ``hidden``, then a linear head. ``forward`` returns ``(z, logits)``.
    """

    def __init__(self, vocab, num_dim, cat_cards, text_encoder="transformer", d_model=128, hidden=256, num_classes=3, dropout=0.2):
        super().__init__()
        text_kwargs = dict(d_model=d_model, num_classes=num_classes, dropout=dropout)
        if text_encoder == "transformer":
            self.text = SimpleTransformer(vocab, nhead=4, num_layers=2, **text_kwargs)
        elif text_encoder == "bilstm_attn":
            self.text = BiLSTMAttn(vocab, hidden=d_model, **text_kwargs)
        else:
            self.text = TextCNN(vocab, **text_kwargs)
        self.tab = TabularMLP(num_dim=num_dim, cat_cards=cat_cards, emb_dim=32, hidden=hidden, dropout=dropout)
        fused_in = d_model + self.tab.out_dim
        self.fuse = nn.Sequential(
            nn.Linear(fused_in, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(dropout),
        )
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x_ids, x_num, x_cat):
        text_z, _unused_text_logits = self.text(x_ids)
        tab_z = self.tab(x_num, x_cat)
        fused = self.fuse(torch.cat([text_z, tab_z], dim=1))
        return fused, self.head(fused)

# Sizes derived from the preprocessed data, used to build every model.
VOCAB_SIZE, NUM_DIM, CAT_CARDS = len(itos), Xtr_num.shape[1], cat_card

def make_model(name):
    """Instantiate the architecture selected by *name* (one of the --model choices).

    Text-only models are returned unwrapped and are called as ``model(ids)``.
    The tabular and fusion models are wrapped so they are called as
    ``model(ids, num, cat)``. Every model returns ``(z, logits)``.

    Raises:
        ValueError: if *name* is not a known model.
    """
    text_only = {
        "text_cnn": lambda: TextCNN(VOCAB_SIZE, d_model=ARGS.embed_dim, num_classes=NUM_CLASSES, dropout=ARGS.dropout),
        "bilstm_attn": lambda: BiLSTMAttn(VOCAB_SIZE, d_model=ARGS.embed_dim, hidden=ARGS.embed_dim, num_classes=NUM_CLASSES, dropout=ARGS.dropout),
        "transformer": lambda: SimpleTransformer(VOCAB_SIZE, d_model=ARGS.embed_dim, nhead=4, num_layers=2, num_classes=NUM_CLASSES, dropout=ARGS.dropout, max_len=ARGS.max_len),
    }
    if name in text_only:
        return text_only[name]()

    if name == "mlp_tab":
        class MLPWrap(nn.Module):
            # TabularMLP body plus a linear classification head; ids are ignored.
            def __init__(self):
                super().__init__()
                self.base = TabularMLP(NUM_DIM, CAT_CARDS, emb_dim=32, hidden=ARGS.hidden, dropout=ARGS.dropout)
                self.head = nn.Linear(self.base.out_dim, NUM_CLASSES)

            def forward(self, ids, num, cat):
                z = self.base(num, cat)
                return z, self.head(z)

        return MLPWrap()

    if name == "tab_transformer":
        class TabWrap(nn.Module):
            # Adapts TabTransformer to the (ids, num, cat) signature; ids are ignored.
            def __init__(self):
                super().__init__()
                self.base = TabTransformer(CAT_CARDS, NUM_DIM, d_model=ARGS.embed_dim, nhead=4, num_layers=2, num_classes=NUM_CLASSES, dropout=ARGS.dropout)
                # NOTE(review): this head is never used in forward (the base
                # model's own head produces the logits). Removing it would
                # change the saved state_dict keys.
                self.head = nn.Linear(self.base.mlp[3].out_features, NUM_CLASSES)

            def forward(self, ids, num, cat):
                return self.base(num, cat)

        return TabWrap()

    if name == "fusion":
        class FusWrap(nn.Module):
            # Thin wrapper so the fusion model's parameters live under "base.".
            def __init__(self):
                super().__init__()
                self.base = FusionNet(VOCAB_SIZE, NUM_DIM, CAT_CARDS, text_encoder=ARGS.text_encoder, d_model=ARGS.embed_dim, hidden=ARGS.hidden, num_classes=NUM_CLASSES, dropout=ARGS.dropout)

            def forward(self, ids, num, cat):
                return self.base(ids, num, cat)

        return FusWrap()

    raise ValueError("Unknown model")

# The selected architecture on DEVICE, AdamW over all of its parameters, and
# cross-entropy weighted by CLASS_WEIGHTS.
model = make_model(ARGS.model)
model.to(DEVICE)  # nn.Module.to moves parameters in place
optim_ = optim.AdamW(model.parameters(), lr=ARGS.lr, weight_decay=ARGS.weight_decay)
criterion = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)

def step_batch(batch, train=True):
    """Run one mini-batch through the global ``model``; optionally update weights.

    Args:
        batch: ``(ids, num, cat, y)`` tensors as yielded by the DataLoaders.
        train: when True, back-propagate and step ``optim_``; when False, the
            forward pass runs with gradient tracking disabled.

    Returns:
        ``(loss_value, logits, targets)``, with logits and targets detached.
    """
    ids, num, cat, yy = [b.to(DEVICE) for b in batch]
    # Evaluation needs no autograd graph; building one only wastes memory and time.
    with torch.set_grad_enabled(train):
        if ARGS.model in ["text_cnn","bilstm_attn","transformer"]:
            _, logits = model(ids)  # text-only models take just the token ids
        else:  # tab_transformer, mlp_tab and fusion all take (ids, num, cat)
            _, logits = model(ids, num, cat)
        loss = criterion(logits, yy)
    if train:
        optim_.zero_grad()
        loss.backward()
        optim_.step()
    return loss.item(), logits.detach(), yy.detach()

def run_epoch(loader, train=True):
    """Pass every batch of *loader* through ``step_batch`` and summarise the epoch.

    Puts ``model`` in train or eval mode according to *train*.

    Returns:
        ``(mean_batch_loss, report, confusion_matrix, logits, y_true)``. Here
        ``report`` is sklearn's ``classification_report(output_dict=True)``,
        the confusion matrix is a nested list over all NUM_CLASSES labels, and
        ``logits`` / ``y_true`` are NumPy arrays covering every row.
    """
    model.train(mode=train)
    labels = list(range(NUM_CLASSES))
    losses, all_logits, all_y = [], [], []
    for batch in loader:
        loss_val, batch_logits, batch_y = step_batch(batch, train=train)
        losses.append(loss_val)
        all_logits.append(batch_logits.cpu())
        all_y.append(batch_y.cpu())
    logits = torch.cat(all_logits)
    ytrue = torch.cat(all_y).numpy()
    ypred = logits.argmax(1).numpy()
    # Pass labels explicitly. Otherwise sklearn infers the label set from
    # ytrue/ypred and raises when fewer than NUM_CLASSES classes occur (e.g.
    # a model that predicts a single class), because target_names would be
    # longer than the inferred label list.
    report = classification_report(ytrue, ypred, labels=labels, target_names=CLASS_NAMES,
                                   output_dict=True, zero_division=0)
    cm = confusion_matrix(ytrue, ypred, labels=labels).tolist()
    return float(np.mean(losses)), report, cm, logits.numpy(), ytrue

# Train for ARGS.epochs epochs, evaluating on the test split after each one.
# Whenever test macro-F1 improves, checkpoint the weights and write that
# epoch's metrics and test-set class probabilities.
best_f1 = -1.0
best_state = None
for ep in range(1, ARGS.epochs+1):
    tr_loss, tr_rep, tr_cm, _, _ = run_epoch(tr_loader, train=True)
    te_loss, te_rep, te_cm, te_logits, te_y = run_epoch(te_loader, train=False)
    macro_f1 = te_rep["macro avg"]["f1-score"]
    print(f"Epoch {ep:02d} | train loss {tr_loss:.4f} | test loss {te_loss:.4f} | macroF1 {macro_f1:.3f}")
    if macro_f1 > best_f1:
        best_f1 = macro_f1
        # In-memory copy so the best weights can be restored after training.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(model.state_dict(), os.path.join(OUTDIR, f"model_{ARGS.model}.pt"))
        with open(os.path.join(OUTDIR, f"metrics_{ARGS.model}.json"), "w", encoding="utf-8") as f:
            json.dump({"train":tr_rep,"test":te_rep,"confusion_matrix_test":te_cm}, f, indent=2)
        np.save(os.path.join(OUTDIR, f"probs_{ARGS.model}.npy"), F.softmax(torch.tensor(te_logits),dim=1).numpy())

# The saved checkpoint, metrics and probabilities describe the best epoch, so
# restore those weights. The explanation steps that follow then describe the
# same model rather than the last-epoch weights.
if best_state is not None:
    model.load_state_dict(best_state)

def compute_saliency_text(ids_batch):
    """Gradient-times-input token saliency for the text-only models.

    Back-propagates the sum of each row's top-class probability to the output
    of the model's embedding layer. Returns ``|grad * emb|`` summed over the
    embedding dimension as a NumPy array of shape (rows, seq_len), using at
    most the first 64 rows of *ids_batch*. Returns None for models that are
    not text-only (mlp_tab, tab_transformer, fusion).
    """
    if ARGS.model not in ["text_cnn","bilstm_attn","transformer"]:
        return None
    ids = torch.tensor(ids_batch[:64], dtype=torch.long, device=DEVICE)
    emb_layer = model.emb if hasattr(model, "emb") else model.text.emb
    captured = {}

    def _keep_embedding(module, inputs, output):
        # Capture the embedding output and keep its gradient. The model then
        # runs its own forward, instead of a hand-copied one that can drift
        # from it.
        output.retain_grad()
        captured["emb"] = output

    handle = emb_layer.register_forward_hook(_keep_embedding)
    try:
        _, logits = model(ids)
    finally:
        handle.remove()
    probs = F.softmax(logits, dim=1)
    probs.max(1).values.sum().backward()
    emb = captured["emb"]
    sal = (emb.grad * emb).abs().sum(-1).detach().cpu().numpy()  # (rows, seq_len)
    # Parameter gradients from this backward pass are not needed; clear them.
    model.zero_grad(set_to_none=True)
    return sal

def compute_saliency_tab(num_batch, ids_batch=None, cat_batch=None):
    """Gradient-times-input saliency for the numeric features.

    Args:
        num_batch: numeric feature rows; at most the first 128 are used.
        ids_batch: optional token-id rows aligned with *num_batch*. Defaults
            to all-padding ids.
        cat_batch: optional categorical-id rows aligned with *num_batch*.
            Defaults to all-zero ids (the "<unk>" slot of every column).

    Returns:
        NumPy array (rows, n_numeric) of ``|d(sum of top-class prob)/d num * num|``.

    Only valid for models called as ``model(ids, num, cat)``
    (mlp_tab, tab_transformer, fusion).
    """
    num = torch.tensor(num_batch[:128], dtype=torch.float32, device=DEVICE, requires_grad=True)
    n = num.size(0)
    if ids_batch is None:
        ids = torch.zeros((n, ARGS.max_len), dtype=torch.long, device=DEVICE)
    else:
        ids = torch.tensor(ids_batch[:n], dtype=torch.long, device=DEVICE)
    if cat_batch is None:
        cat = torch.zeros((n, Xtr_cat.shape[1]), dtype=torch.long, device=DEVICE)
    else:
        cat = torch.tensor(cat_batch[:n], dtype=torch.long, device=DEVICE)
    _, logits = model(ids, num, cat)
    probs = F.softmax(logits, dim=1)
    probs.max(1).values.sum().backward()
    sal = (num.grad * num).abs().detach().cpu().numpy()
    # Parameter gradients from this backward pass are not needed; clear them.
    model.zero_grad(set_to_none=True)
    return sal


# Best-effort saliency artifacts: failures are reported, not raised.
try:
    ids_batch = Xte_text[:64]
    num_batch = Xte_num[:128] if NUM_DIM>0 else None
    s_text = compute_saliency_text(ids_batch) if TEXT_COLS else None
    # Explain the numeric features at real test rows by passing the matching
    # text and categorical inputs. All-padding / "<unk>" placeholders would put
    # the model at inputs unlike anything it was trained on.
    s_tab = (compute_saliency_tab(num_batch, ids_batch=Xte_text[:128], cat_batch=Xte_cat[:128])
             if (ARGS.model in ["mlp_tab","tab_transformer","fusion"] and NUM_DIM>0) else None)
    np.savez(os.path.join(OUTDIR, f"saliency_{ARGS.model}.npz"), text=s_text, tab=s_tab)
except Exception as e:
    print("Saliency error:", e)

# Optional SHAP attributions for the tabular inputs (numeric + categorical
# columns). Each explained test row gets its own KernelExplainer. Its
# prediction function holds that row's text fixed and receives the tabular
# features from the explainer, so the attributions refer to real model inputs.
# (A prediction function that ignores its input yields meaningless values.)
if SHAP_OK and not ARGS.no_shap:
    try:
        model.eval()
        n_tab = NUM_DIM + Xte_cat.shape[1]
        if ARGS.model in ["text_cnn","bilstm_attn","transformer"] or n_tab == 0:
            print("SHAP skipped: model has no tabular inputs to attribute.")
        else:
            def _tab_matrix(num, cat):
                """Stack [numeric | categorical ids] into one float matrix for the explainer."""
                return np.hstack([np.asarray(num, dtype=np.float64), np.asarray(cat, dtype=np.float64)])

            def _make_pred_fn(text_row):
                """Return f(X) -> class probabilities for tabular rows X, with text fixed to *text_row*."""
                def pred_fn(X):
                    X = np.asarray(X, dtype=np.float64)
                    out = []
                    with torch.no_grad():
                        # Chunk so large explainer batches keep memory bounded.
                        for start in range(0, len(X), ARGS.batch_size):
                            chunk = X[start:start + ARGS.batch_size]
                            num = torch.tensor(chunk[:, :NUM_DIM], dtype=torch.float32, device=DEVICE)
                            # Categorical ids were cast to float by _tab_matrix; round back before lookup.
                            cat = torch.tensor(np.rint(chunk[:, NUM_DIM:]), dtype=torch.long, device=DEVICE)
                            ids = torch.tensor(np.repeat(text_row[None, :], len(chunk), axis=0),
                                               dtype=torch.long, device=DEVICE)
                            _, logits = model(ids, num, cat)
                            out.append(F.softmax(logits, dim=1).cpu().numpy())
                    return np.concatenate(out, axis=0)
                return pred_fn

            rng = np.random.default_rng(SEED)
            bg_rows = rng.choice(len(Xtr_num), size=min(30, len(Xtr_num)), replace=False)
            bg = _tab_matrix(Xtr_num[bg_rows], Xtr_cat[bg_rows])
            n_explain = min(20, len(Xte_num))
            sample = _tab_matrix(Xte_num[:n_explain], Xte_cat[:n_explain])
            vals = []
            for i in range(n_explain):
                expl = shap.KernelExplainer(_make_pred_fn(Xte_text[i]), bg)
                vals.append(np.asarray(expl.shap_values(sample[i:i + 1], nsamples=100)))
            np.save(os.path.join(OUTDIR, f"shap_values_{ARGS.model}.npy"), np.stack(vals), allow_pickle=True)
    except Exception as e:
        print("SHAP error:", e)

print("Done. Artifacts in:", OUTDIR)