<a href="https://colab.research.google.com/github/2403A51L33/PfDS-PROJECT/blob/main/REINFORCEMENT%20LEARNING%20ALGORITHMS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import os
import random
import json
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

try:
    import shap
    SHAP_AVAILABLE = True
except Exception:
    SHAP_AVAILABLE = False

# Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Directory that receives every artifact this script writes; created if missing.
OUTDIR = "./rl_outputs"
os.makedirs(OUTDIR, exist_ok=True)

# Primary dataset location (presumably the Colab working directory — confirm).
DATA_PATH = "/content/realistic_drug_labels_side_effects.csv"

# Fall back to a copy in the current directory; fail fast if neither path exists.
if not os.path.exists(DATA_PATH):
    alt = "./realistic_drug_labels_side_effects.csv"
    if os.path.exists(alt):
        DATA_PATH = alt
    else:
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} or {alt}")

# The whole dataset is loaded into memory as a single DataFrame.
df = pd.read_csv(DATA_PATH)

def to_three_bins(x):
    """Convert a raw severity value into a numeric score.

    Anything ``float()`` accepts is returned as that float unchanged (so the
    result is not restricted to 0/1/2, and NaN passes through as NaN).
    Everything else is treated as a text label: it is stripped, lower-cased
    and looked up in a fixed table (low/mild -> 0, moderate/medium -> 1,
    high/severe -> 2). Unrecognised labels fall back to 1.

    Args:
        x: A number, numeric string, severity label, or any other object.

    Returns:
        A float for numeric input, otherwise an int in {0, 1, 2}.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures select the label path; anything else
        # (e.g. KeyboardInterrupt) is allowed to propagate.
        label = str(x).strip().lower()
        mapping = {"low": 0, "mild": 0, "moderate": 1, "medium": 1, "high": 2, "severe": 2}
        return mapping.get(label, 1)

# Build the integer class labels `y` (0, 1, 2) from "side_effect_severity".
# A numeric column is used as-is; any other dtype is first converted to
# numeric scores with to_three_bins().
if pd.api.types.is_numeric_dtype(df["side_effect_severity"]):
    severity_scores = df["side_effect_severity"]
else:
    approx = df["side_effect_severity"].apply(to_three_bins).astype(float)
    severity_scores = approx

# Cut points at the 33% and 66% quantiles. nanquantile ignores missing
# scores; plain quantile would return NaN cut points as soon as one score is
# NaN, which sends every row to class 2.
q = np.nanquantile(severity_scores, [0.33, 0.66])


def bin_numeric(v):
    """Return 0, 1 or 2 depending on where `v` falls relative to the cut points `q`.

    A NaN `v` fails both comparisons and therefore lands in class 2.
    """
    if v <= q[0]:
        return 0
    if v <= q[1]:
        return 1
    return 2


y = severity_scores.apply(bin_numeric).astype(int).values

# Number of severity classes and their display names; position i names label i.
num_classes = 3
class_names = ["low", "moderate", "high"]

# Candidate feature columns, grouped by how they are encoded below.
text_cols = ["indications", "side_effects", "contraindications", "warnings"]
num_cols = ["dosage_mg", "price_usd", "approval_year"]
cat_cols = ["drug_class", "administration_route", "approval_status", "manufacturer"]

# Keep only the candidates that actually exist in the loaded DataFrame.
text_cols = [c for c in text_cols if c in df.columns]
num_cols = [c for c in num_cols if c in df.columns]
cat_cols = [c for c in cat_cols if c in df.columns]

# Missing text becomes the empty string so the join below never sees NaN.
for c in text_cols:
    df[c] = df[c].fillna("")

# One space-joined document per row, built from all text columns.
X_text = df[text_cols].apply(lambda r: " ".join([str(v) for v in r.values]), axis=1)

# TF-IDF over unigrams and bigrams, capped at 5000 features; terms must occur
# in at least two documents.
# NOTE(review): the vectorizer (and the scaler below) are fitted on the full
# dataset before the train/test split, so test rows influence the features —
# confirm this is intended.
tfidf = TfidfVectorizer(max_features=5000, ngram_range=(1,2), min_df=2)
X_text_mat = tfidf.fit_transform(X_text)

# Numeric gaps are filled with the column median, categorical gaps with "UNK".
X_num = df[num_cols].fillna(df[num_cols].median()) if num_cols else pd.DataFrame(index=df.index)
X_cat = df[cat_cols].fillna("UNK") if cat_cols else pd.DataFrame(index=df.index)

# Standardise numeric columns; with none present use a zero-width array so the
# stacking below still works.
if not X_num.empty:
    scaler = StandardScaler()
    X_num_scaled = scaler.fit_transform(X_num.values)
else:
    X_num_scaled = np.zeros((len(df), 0))

# One-hot encode categoricals, dropping the first level of each column.
if not X_cat.empty:
    X_cat_dummies = pd.get_dummies(X_cat, drop_first=True, dtype=np.float32)
else:
    X_cat_dummies = pd.DataFrame(index=df.index)

# Final design matrix (CSR): TF-IDF columns first, then scaled numerics, then dummies.
from scipy import sparse
X_other = np.hstack([X_num_scaled, X_cat_dummies.values]) if X_cat_dummies.shape[1] > 0 else X_num_scaled
X = sparse.hstack([X_text_mat, sparse.csr_matrix(X_other)], format="csr")

# 80/20 split, stratified on the class label and seeded with SEED.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED, stratify=y
)

# Reward table, read by compute_reward() as R[true_class, action].
# Diagonal entries (action equals the true class) are positive, off-diagonal
# entries are negative. The largest penalty, R[2, 0] = -2.0, is a true
# class-2 ("high") row answered with action 0 ("low").
R = np.array([
    [ 1.0,      -0.2,   -0.5],
    [-0.2,       1.0,   -0.2],
    [-2.0,      -0.5,    1.2],
])

def batch_csr_to_torch(X_csr):
    """Densify a SciPy sparse matrix into a float32 torch tensor.

    The matrix is converted to COO form, wrapped as a torch sparse COO tensor
    with int64 indices and float32 values, and then materialised densely.

    Args:
        X_csr: Any SciPy sparse matrix exposing ``tocoo()``.

    Returns:
        A dense ``torch.float32`` tensor with the same shape as `X_csr`.
    """
    coo = X_csr.tocoo()
    # Row indices on the first line, column indices on the second.
    coords = np.vstack((coo.row, coo.col))
    sparse_tensor = torch.sparse_coo_tensor(
        torch.tensor(coords, dtype=torch.long),
        torch.tensor(coo.data, dtype=torch.float32),
        torch.Size(coo.shape),
    )
    return sparse_tensor.to_dense()

def sample_minibatch(X_csr, y_arr, batch_size=64):
    """Draw a random minibatch (with replacement) from a sparse dataset.

    Args:
        X_csr: Sparse feature matrix that supports row indexing by an index array.
        y_arr: Label array indexable by the same index array.
        batch_size: Number of rows to draw.

    Returns:
        ``(features, labels)`` where `features` is the dense tensor produced
        by ``batch_csr_to_torch`` and `labels` is an int64 tensor.
    """
    rows = np.random.randint(0, X_csr.shape[0], size=batch_size)
    picked_features = X_csr[rows]
    picked_labels = y_arr[rows]
    features = batch_csr_to_torch(picked_features)
    labels = torch.tensor(picked_labels, dtype=torch.long)
    return features, labels

def compute_reward(y_true, a):
    """Look up the reward for taking action `a` when the true class is `y_true`.

    The value comes straight from the module-level table ``R``, indexed as
    ``R[y_true, a]``.
    """
    reward = R[y_true, a]
    return reward

def evaluate_policy(pred, y_true):
    """Score predicted actions against the true classes.

    Args:
        pred: Predicted class index per sample.
        y_true: True class index per sample.

    Returns:
        ``(cm, report)``: the confusion matrix over all classes and the
        ``classification_report`` as a nested dict.
    """
    # Use the same explicit label set for both metrics. Without `labels`,
    # classification_report infers the classes from the data and raises when
    # their number differs from len(class_names) — e.g. when a class occurs in
    # neither `y_true` nor `pred`.
    labels = list(range(len(class_names)))
    cm = confusion_matrix(y_true, pred, labels=labels)
    report = classification_report(
        y_true,
        pred,
        labels=labels,
        target_names=class_names,
        output_dict=True,
        # Classes that are never predicted score 0 instead of emitting a warning.
        zero_division=0,
    )
    return cm, report

class EpsilonGreedyBandit:
    """Context-free epsilon-greedy bandit with running-mean value estimates.

    Attributes:
        nA: Number of actions.
        eps: Probability of picking a uniformly random action.
        counts: How many times each action has been updated.
        values: Running mean reward per action.
    """

    def __init__(self, n_actions=3, eps=0.1):
        self.nA = n_actions
        self.eps = eps
        self.counts = np.zeros(n_actions, dtype=int)
        self.values = np.zeros(n_actions, dtype=float)

    def select(self):
        """Return a random action with probability `eps`, else the best-valued one."""
        explore = np.random.rand() < self.eps
        if not explore:
            return int(np.argmax(self.values))
        return np.random.randint(self.nA)

    def update(self, a, r):
        """Fold reward `r` into the running mean of action `a`."""
        self.counts[a] += 1
        error = r - self.values[a]
        self.values[a] += error / self.counts[a]

class UCB1Bandit:
    """Context-free UCB1 bandit.

    Attributes:
        nA: Number of actions.
        counts: How many times each action has been updated.
        values: Running mean reward per action.
        t: Number of ``select`` calls made so far.
    """

    def __init__(self, n_actions=3):
        self.nA = n_actions
        self.counts = np.zeros(n_actions, dtype=int)
        self.values = np.zeros(n_actions, dtype=float)
        self.t = 0

    def select(self):
        """Return the lowest-indexed untried action, else the highest-UCB action."""
        self.t += 1
        untried = np.flatnonzero(self.counts == 0)
        if untried.size > 0:
            return int(untried[0])
        bonus = np.sqrt(2 * np.log(self.t) / self.counts)
        return int(np.argmax(self.values + bonus))

    def update(self, a, r):
        """Fold reward `r` into the running mean of action `a`."""
        self.counts[a] += 1
        error = r - self.values[a]
        self.values[a] += error / self.counts[a]

class ThompsonBandit:
    """Context-free Gaussian Thompson-sampling bandit.

    Each action keeps a normal belief over its mean reward, stored as a mean
    and a precision.

    Attributes:
        nA: Number of actions.
        mu: Belief mean per action (starts at 0).
        lambda_prec: Belief precision per action (starts at 1).
        tau: Precision credited to every observed reward.
    """

    def __init__(self, n_actions=3):
        self.nA = n_actions
        self.mu = np.zeros(n_actions)
        self.lambda_prec = np.ones(n_actions)  # precision
        self.tau = 1.0

    def select(self):
        """Sample one value per action from its belief and return the argmax."""
        samples = np.random.normal(self.mu, 1.0 / np.sqrt(self.lambda_prec))
        return int(np.argmax(samples))

    def update(self, a, r):
        """Update the belief of action `a` after observing reward `r`.

        The new mean is the precision-weighted average of the old mean and
        the reward. The reward is weighted by `tau`; the previous code left
        that weight out, which is only correct while `tau` equals 1.
        """
        prior_prec = self.lambda_prec[a]
        post_prec = prior_prec + self.tau
        self.mu[a] = (self.mu[a] * prior_prec + self.tau * r) / post_prec
        self.lambda_prec[a] = post_prec

class LinUCB:
    """Disjoint LinUCB: one ridge-regression model per action.

    Attributes:
        nA: Number of actions.
        alpha: Weight of the exploration bonus.
        A: Per-action design matrices, ``l2 * I`` plus the sum of ``x xᵀ``.
        b: Per-action reward-weighted feature sums.
        A_inv: Per-action inverses of ``A``, maintained incrementally.

    `A_inv` is only kept in sync by ``update``; code that modifies ``A``
    directly must refresh it. `l2` must be positive so the initial inverse
    exists. Holding both `A` and `A_inv` doubles the d×d storage per action.
    """

    def __init__(self, d, n_actions=3, alpha=1.0, l2=1.0):
        self.nA = n_actions
        self.alpha = alpha
        self.A = [l2 * np.eye(d) for _ in range(n_actions)]
        self.b = [np.zeros((d,)) for _ in range(n_actions)]
        # The inverse of l2 * I is I / l2; it is then updated rank-1 at a time.
        self.A_inv = [np.eye(d) / l2 for _ in range(n_actions)]

    def select(self, x):
        """Return the action with the highest upper confidence bound for context `x`."""
        p = np.zeros(self.nA)
        for a in range(self.nA):
            A_inv = self.A_inv[a]
            theta = A_inv @ self.b[a]
            # Clamp at zero so rounding error cannot put a negative under the root.
            width = np.sqrt(max(float(x @ A_inv @ x), 0.0))
            p[a] = theta @ x + self.alpha * width
        return int(np.argmax(p))

    def update(self, x, a, r):
        """Add observation (`x`, `r`) to the model of action `a`."""
        self.A[a] += np.outer(x, x)
        self.b[a] += r * x
        # Sherman–Morrison rank-1 update of the inverse: O(d^2) per step,
        # instead of a fresh O(d^3) inversion for every action in every select().
        A_inv = self.A_inv[a]
        Ax = A_inv @ x
        A_inv -= np.outer(Ax, Ax) / (1.0 + x @ Ax)

class LogisticTS:
    """Per-action logistic model with weight-noise exploration.

    Attributes:
        nA: Number of actions.
        d: Feature dimension.
        l2: Weight-decay coefficient used in the gradient step.
        W: Weight matrix of shape ``(n_actions, d)``, initialised to zero.
    """

    def __init__(self, d, n_actions=3, l2=1.0):
        self.nA = n_actions
        self.d = d
        self.l2 = l2
        self.W = np.zeros((n_actions, d))

    def _sigmoid(self, z):
        """Logistic function of `z`."""
        return 1 / (1 + np.exp(-z))

    def select(self, x):
        """Perturb all weights with N(0, 0.1) noise and return the best-scoring action."""
        perturbed = self.W + np.random.normal(0, 0.1, size=self.W.shape)
        scores = perturbed @ x
        return int(np.argmax(scores))

    def update(self, x, a, r):
        """Take one gradient step (rate 0.05) on the weights of action `a`.

        The reward is binarised: a positive `r` is the positive label.
        """
        label = 1 if r > 0 else 0
        prob = self._sigmoid(self.W[a] @ x)
        grad = (label - prob) * x - self.l2 * self.W[a]
        self.W[a] += 0.05 * grad

class MLP(nn.Module):
    """Fully connected network: two ReLU hidden layers of equal width, linear output."""

    def __init__(self, in_dim, out_dim, hidden=256):
        super().__init__()
        widths = [in_dim, hidden, hidden, out_dim]
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            # A ReLU sits between consecutive linear layers, not after the last one.
            if position > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (unnormalised) outputs for input `x`."""
        return self.net(x)

class DQNAgent:
    """Minimal DQN: an online Q-network, a target network and one-sample TD updates.

    Attributes:
        q: Online Q-network (an ``MLP``).
        target: Target network, initialised as a copy of ``q``.
        gamma: Discount factor.
        optim: Adam optimiser over the online network.
        eps: Exploration probability used by ``act``.
        nA: Number of actions.
        losses: Loss value of every ``update`` call, in order.
    """

    def __init__(self, in_dim, n_actions, gamma=0.99, lr=1e-3, eps=0.1):
        self.q = MLP(in_dim, n_actions, hidden=256)
        self.target = MLP(in_dim, n_actions, hidden=256)
        self.target.load_state_dict(self.q.state_dict())
        self.gamma = gamma
        self.optim = optim.Adam(self.q.parameters(), lr=lr)
        self.eps = eps
        self.nA = n_actions
        self.losses = []

    def act(self, x):
        """Epsilon-greedy action for a single state `x` (a batch of one row)."""
        if np.random.rand() < self.eps:
            return np.random.randint(self.nA)
        with torch.no_grad():
            return int(torch.argmax(self.q(x), dim=-1).item())

    def update(self, x, a, r, xn, done):
        """Take one TD(0) step on a single transition.

        Args:
            x: Current state, a batch of one row.
            a: Index of the action taken.
            r: Reward, as a Python number or a one-element tensor.
            xn: Next state, a batch of one row; only used when `done` is false.
            done: Whether the transition is terminal (no bootstrap term).
        """
        q_sa = self.q(x)[0, a]
        with torch.no_grad():
            # Bring the target to the same 0-d shape as q_sa. The previous
            # code passed a shape-[1] target to mse_loss, which only worked
            # through broadcasting and raised a size-mismatch warning.
            td_target = torch.as_tensor(r, dtype=q_sa.dtype).reshape(q_sa.shape)
            if not done:
                best_next = self.target(xn).max(dim=-1).values
                td_target = td_target + self.gamma * best_next.reshape(q_sa.shape)
        loss = F.mse_loss(q_sa, td_target)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.losses.append(loss.item())

    def soft_update(self, tau=0.01):
        """Move the target network a fraction `tau` of the way toward the online one."""
        with torch.no_grad():
            for t, s in zip(self.target.parameters(), self.q.parameters()):
                t.copy_((1 - tau) * t + tau * s)

class PolicyNet(nn.Module):
    """One-hidden-layer policy network that outputs log-probabilities over actions."""

    def __init__(self, in_dim, n_actions, hidden=256):
        super().__init__()
        stack = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return log-softmax of the action scores along the last dimension."""
        scores = self.net(x)
        return F.log_softmax(scores, dim=-1)

class ValueNet(nn.Module):
    """One-hidden-layer network that outputs a single scalar per input row."""

    def __init__(self, in_dim, hidden=256):
        super().__init__()
        stack = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the value estimate, with a trailing dimension of size one."""
        estimate = self.net(x)
        return estimate

class REINFORCEAgent:
    """Policy-gradient agent that updates on one (state, action, return) at a time.

    Attributes:
        policy: A ``PolicyNet`` producing log-probabilities.
        optim: Adam optimiser over the policy parameters.
    """

    def __init__(self, in_dim, n_actions, lr=1e-3):
        self.policy = PolicyNet(in_dim, n_actions)
        self.optim = optim.Adam(self.policy.parameters(), lr=lr)

    def act(self, x):
        """Sample one action from the policy's distribution for state `x` (one row)."""
        with torch.no_grad():
            probs = torch.exp(self.policy(x))
            choice = torch.multinomial(probs, num_samples=1)
            return int(choice.item())

    def update(self, x, a, G):
        """Take one gradient step on ``-log pi(a | x) * G``."""
        log_prob = self.policy(x)[0, a]
        objective = -log_prob * G
        self.optim.zero_grad()
        objective.backward()
        self.optim.step()

class A2CAgent:
    """One-step advantage actor-critic with separate policy and value networks.

    Attributes:
        policy: A ``PolicyNet`` producing log-probabilities.
        value: A ``ValueNet`` producing a scalar state value.
        op: A single Adam optimiser over both networks.
        gamma: Discount factor.
    """

    def __init__(self, in_dim, n_actions, lr=1e-3, gamma=0.99):
        self.policy = PolicyNet(in_dim, n_actions)
        self.value = ValueNet(in_dim)
        self.op = optim.Adam(list(self.policy.parameters()) + list(self.value.parameters()), lr=lr)
        self.gamma = gamma

    def act(self, x):
        """Sample one action from the policy's distribution for state `x` (one row)."""
        with torch.no_grad():
            logp = self.policy(x)
            p = torch.exp(logp)
            a = torch.multinomial(p, num_samples=1)
            return int(a.item())

    def update(self, x, a, r, xn, done):
        """Take one actor-critic step on a single transition.

        Args:
            x: Current state, a batch of one row.
            a: Index of the action taken.
            r: Reward, as a Python number or a one-element tensor.
            xn: Next state, a batch of one row; only used when `done` is false.
            done: Whether the transition is terminal (no bootstrap term).
        """
        # Work with 0-d tensors throughout. The previous code compared a
        # [1, 1] value with a [1] target in mse_loss, which relied on
        # broadcasting and raised a size-mismatch warning.
        value = self.value(x).reshape(())
        with torch.no_grad():
            target = torch.as_tensor(r, dtype=value.dtype).reshape(())
            if not done:
                target = target + self.gamma * self.value(xn).reshape(())
            # Computed without gradient tracking: the advantage only scales
            # the actor's gradient.
            adv = target - value
        logp = self.policy(x)[0, a]
        actor_loss = -logp * adv
        critic_loss = F.mse_loss(value, target)
        loss = actor_loss + critic_loss
        self.op.zero_grad()
        loss.backward()
        self.op.step()

class BehaviorCloning:
    """Supervised classifier (an ``MLP``) trained with cross-entropy on logged labels.

    Attributes:
        net: The classifier network.
        op: Adam optimiser over the network parameters.
    """

    def __init__(self, in_dim, n_actions, lr=1e-3):
        self.net = MLP(in_dim, n_actions)
        self.op = optim.Adam(self.net.parameters(), lr=lr)

    def fit(self, X_t, y_t, steps=500, bs=64):
        """Run `steps` minibatch updates; batches are drawn with replacement.

        Args:
            X_t: Sparse feature matrix; densified once up front.
            y_t: Integer class labels aligned with the rows of `X_t`.
            steps: Number of gradient steps.
            bs: Minibatch size.

        Returns:
            ``self``, so calls can be chained.
        """
        n_rows = X_t.shape[0]
        features = batch_csr_to_torch(X_t)
        targets = torch.tensor(y_t, dtype=torch.long)
        for _ in range(steps):
            batch = np.random.randint(0, n_rows, size=bs)
            batch_loss = F.cross_entropy(self.net(features[batch]), targets[batch])
            self.op.zero_grad()
            batch_loss.backward()
            self.op.step()
        return self

    def predict(self, X_csr):
        """Return the arg-max class per row of `X_csr` as a NumPy array."""
        with torch.no_grad():
            scores = self.net(batch_csr_to_torch(X_csr))
            return scores.argmax(dim=-1).cpu().numpy()

class CQLlite(DQNAgent):
    """DQN variant whose loss adds a small penalty on the size of all Q-values.

    The penalty is ``1e-3 * mean(Q(x, ·)²)``.
    NOTE(review): this is a plain magnitude penalty, not a log-sum-exp
    conservative term — confirm that the simplification is intended.
    """

    def update(self, x, a, r, xn, done):
        """Take one penalised TD(0) step on a single transition.

        Args:
            x: Current state, a batch of one row.
            a: Index of the logged action.
            r: Reward, as a Python number or a one-element tensor.
            xn: Next state, a batch of one row; only used when `done` is false.
            done: Whether the transition is terminal (no bootstrap term).
        """
        q_all = self.q(x)
        q_sa = q_all[0, a]
        with torch.no_grad():
            # Match the target's shape to the 0-d q_sa so mse_loss does not
            # have to broadcast a shape-[1] target.
            td_target = torch.as_tensor(r, dtype=q_sa.dtype).reshape(q_sa.shape)
            if not done:
                best_next = self.target(xn).max(dim=-1).values
                td_target = td_target + self.gamma * best_next.reshape(q_sa.shape)
        cql_penalty = 1e-3 * (q_all.pow(2).mean())
        loss = F.mse_loss(q_sa, td_target) + cql_penalty
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        # Record the loss the same way the parent class's update() does.
        self.losses.append(loss.item())

def train_non_contextual_bandits():
    """Train the three context-free bandits and score each one's single best action.

    Every agent is trained on rewards for randomly drawn training labels.
    Because these agents ignore features, the learned policy is one constant
    action, which is then evaluated on the whole test set.

    Returns:
        Dict mapping agent name to its confusion matrix (as lists) and report.
    """
    labels = y_train
    n_labels = len(labels)
    agents = {
        "eps_greedy": EpsilonGreedyBandit(n_actions=num_classes, eps=0.1),
        "ucb1": UCB1Bandit(n_actions=num_classes),
        "thompson": ThompsonBandit(n_actions=num_classes),
    }
    n_steps = min(5000, n_labels * 2)
    results = {}
    for name, agent in agents.items():
        for _ in range(n_steps):
            action = agent.select()
            true_class = int(labels[np.random.randint(0, n_labels)])
            agent.update(action, compute_reward(true_class, action))
        # Agents expose their estimates as either `values` or `mu`.
        estimates = agent.values if hasattr(agent, "values") else agent.mu
        constant_pred = np.full_like(y_test, int(np.argmax(estimates)))
        cm, rep = evaluate_policy(constant_pred, y_test)
        results[name] = {"confusion_matrix": cm.tolist(), "report": rep}
    return results

def train_contextual_bandits():
    """Train LinUCB and LogisticTS on random training rows and score them on the test set.

    Both agents see the same sampled row at every step. Test predictions are
    produced with each agent's ``select`` method, i.e. the same rule used
    during training.

    Returns:
        Dict mapping agent name to its confusion matrix (as lists) and report.
    """
    train_X = batch_csr_to_torch(X_train)
    test_X = batch_csr_to_torch(X_test)
    n_train, n_features = train_X.shape

    # Insertion order matters: LinUCB acts before LogisticTS at every step.
    agents = {
        "linucb": LinUCB(n_features, n_actions=num_classes, alpha=0.5, l2=1.0),
        "logistic_ts": LogisticTS(n_features, n_actions=num_classes, l2=1.0),
    }

    for _ in range(min(2000, n_train * 2)):
        row = np.random.randint(0, n_train)
        x = train_X[row].numpy()
        true_class = int(y_train[row])
        for agent in agents.values():
            action = agent.select(x)
            agent.update(x, action, compute_reward(true_class, action))

    def eval_agent(agent, denseX):
        """Return the agent's selected action for every row of `denseX`."""
        return np.array([agent.select(denseX[i].numpy()) for i in range(denseX.shape[0])])

    predictions = {name: eval_agent(agent, test_X) for name, agent in agents.items()}

    results = {}
    for name, pred in predictions.items():
        cm, rep = evaluate_policy(pred, y_test)
        results[name] = {"confusion_matrix": cm.tolist(), "report": rep}
    return results

def train_deep_rl():
    """Train DQN, REINFORCE and A2C agents in turn and score each on the test set.

    Every training step draws one random training row as the state; the
    reward comes from ``compute_reward`` and each transition is terminal.

    Returns:
        ``(results, policy, train_dense, test_dense)``: the per-agent metrics
        dict, the trained REINFORCE policy network, and the dense train and
        test feature tensors.
    """
    Xtr_dense = batch_csr_to_torch(X_train)
    Xte_dense = batch_csr_to_torch(X_test)
    n_train, in_dim = Xtr_dense.shape
    n_steps = min(2000, n_train * 2)

    def draw_transition_start():
        """Pick a random training row; return it as a one-row batch plus its label."""
        row = np.random.randint(0, n_train)
        return Xtr_dense[row].unsqueeze(0), int(y_train[row])

    def draw_next_state():
        """Pick an unrelated random training row to stand in as the next state."""
        return Xtr_dense[np.random.randint(0, n_train)].unsqueeze(0)

    def reward_tensor(true_class, action):
        """Wrap the table reward in a one-element float32 tensor."""
        return torch.tensor([compute_reward(true_class, action)], dtype=torch.float32)

    # DQN
    dqn = DQNAgent(in_dim, num_classes, lr=1e-3, eps=0.1)
    for _ in range(n_steps):
        state, true_class = draw_transition_start()
        action = dqn.act(state)
        reward = reward_tensor(true_class, action)
        next_state = draw_next_state()
        dqn.update(state, action, reward, next_state, done=True)
        dqn.soft_update(0.01)
    with torch.no_grad():
        pred_dqn = dqn.q(Xte_dense).argmax(dim=-1).cpu().numpy()
    cm_dqn, rep_dqn = evaluate_policy(pred_dqn, y_test)

    # REINFORCE
    rei = REINFORCEAgent(in_dim, num_classes, lr=1e-3)
    for _ in range(n_steps):
        state, true_class = draw_transition_start()
        action = rei.act(state)
        episode_return = reward_tensor(true_class, action)
        rei.update(state, action, episode_return)
    with torch.no_grad():
        pred_rei = torch.exp(rei.policy(Xte_dense)).argmax(dim=-1).cpu().numpy()
    cm_rei, rep_rei = evaluate_policy(pred_rei, y_test)

    # A2C
    a2c = A2CAgent(in_dim, num_classes, lr=1e-3)
    for _ in range(n_steps):
        state, true_class = draw_transition_start()
        action = a2c.act(state)
        reward = reward_tensor(true_class, action)
        next_state = draw_next_state()
        a2c.update(state, action, reward, next_state, done=True)
    with torch.no_grad():
        pred_a2c = torch.exp(a2c.policy(Xte_dense)).argmax(dim=-1).cpu().numpy()
    cm_a2c, rep_a2c = evaluate_policy(pred_a2c, y_test)

    scored = (
        ("dqn", cm_dqn, rep_dqn),
        ("reinforce", cm_rei, rep_rei),
        ("a2c", cm_a2c, rep_a2c),
    )
    results = {
        name: {"confusion_matrix": cm.tolist(), "report": rep}
        for name, cm, rep in scored
    }
    return results, rei.policy, Xtr_dense, Xte_dense

def train_offline_rl():
    """Train behaviour cloning and CQL-lite on logged data and score both.

    The logged data is the training set itself: the logged action for every
    row is its true label.

    Returns:
        Dict mapping method name to its confusion matrix (as lists) and report.
    """
    train_X = batch_csr_to_torch(X_train)
    test_X = batch_csr_to_torch(X_test)
    n_train, in_dim = train_X.shape

    # Behaviour cloning: plain supervised fit on (features, label).
    bc = BehaviorCloning(in_dim, num_classes, lr=1e-3)
    bc.fit(X_train, y_train, steps=750, bs=64)
    cm_bc, rep_bc = evaluate_policy(bc.predict(X_test), y_test)

    # CQL-lite: terminal transitions whose action is always the true label.
    cql = CQLlite(in_dim, num_classes, lr=1e-3, eps=0.1)
    for _ in range(min(2000, n_train * 2)):
        row = np.random.randint(0, n_train)
        state = train_X[row].unsqueeze(0)
        logged_action = int(y_train[row])
        reward = torch.tensor(
            [compute_reward(logged_action, logged_action)], dtype=torch.float32
        )
        next_state = train_X[np.random.randint(0, n_train)].unsqueeze(0)
        cql.update(state, logged_action, reward, next_state, done=True)
        cql.soft_update(0.01)

    with torch.no_grad():
        pred_cql = cql.q(test_X).argmax(dim=-1).cpu().numpy()
    cm_cql, rep_cql = evaluate_policy(pred_cql, y_test)

    return {
        "behavior_cloning": {"confusion_matrix": cm_bc.tolist(), "report": rep_bc},
        "cql_lite": {"confusion_matrix": cm_cql.tolist(), "report": rep_cql},
    }

def explain_policy_with_shap(policy_net, X_background, X_explain, topK=10):
    """Compute KernelSHAP attributions for a policy network and save them to OUTDIR.

    Args:
        policy_net: Module returning log-probabilities for a float32 batch.
        X_background: Dense tensor from which up to 50 background rows are sampled.
        X_explain: Dense tensor from which up to 25 rows to explain are sampled.
        topK: Accepted but not used by this function.

    Returns:
        ``(shap_values, explained_rows)``, or ``(None, None)`` when the `shap`
        package could not be imported.
    """
    if not SHAP_AVAILABLE:
        return None, None

    def f_predict(x_np):
        """Map a NumPy batch to the policy's action probabilities."""
        batch = torch.tensor(x_np, dtype=torch.float32)
        with torch.no_grad():
            log_probs = policy_net(batch)
            return torch.exp(log_probs).numpy()  # probabilities

    n_background = min(50, X_background.shape[0])
    background_rows = np.random.choice(X_background.shape[0], size=n_background, replace=False)
    background = X_background[background_rows].numpy()
    explainer = shap.KernelExplainer(f_predict, background, link="identity")

    n_explain = min(25, X_explain.shape[0])
    explain_rows = np.random.choice(X_explain.shape[0], size=n_explain, replace=False)
    examples = X_explain[explain_rows].numpy()
    shap_values = explainer.shap_values(examples, nsamples=100)

    # Persist both the attributions and the rows they refer to.
    for filename, payload in (("shap_values.npy", shap_values), ("shap_examples.npy", examples)):
        np.save(os.path.join(OUTDIR, filename), payload, allow_pickle=True)
    return shap_values, examples

def distill_policy_to_tree(policy_net, X_csr, y_true, max_depth=4):
    """Fit a shallow decision tree that imitates the policy and write it to OUTDIR.

    The tree is trained on the policy's arg-max actions for `X_csr`. It is
    then scored against the module-level ``X_test`` / ``y_test``, and both
    the metrics (JSON) and the textual rules are written to disk.

    Args:
        policy_net: Module returning log-probabilities for a float32 batch.
        X_csr: Sparse feature matrix the tree is trained on.
        y_true: Accepted but not used by this function.
        max_depth: Maximum depth of the surrogate tree.

    Returns:
        ``(tree, rules)``: the fitted classifier and its ``export_text`` rules.
    """
    features = batch_csr_to_torch(X_csr).numpy()
    with torch.no_grad():
        log_probs = policy_net(torch.tensor(features, dtype=torch.float32))
        action_probs = torch.exp(log_probs).numpy()
    pseudo_labels = np.argmax(action_probs, axis=1)

    tree = DecisionTreeClassifier(max_depth=max_depth, random_state=SEED)
    tree.fit(features, pseudo_labels)

    held_out = batch_csr_to_torch(X_test).numpy()
    cm, rep = evaluate_policy(tree.predict(held_out), y_test)
    metrics_path = os.path.join(OUTDIR, "surrogate_tree_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump({"confusion_matrix": cm.tolist(), "report": rep}, f, indent=2)

    # Features are named by column position: f0, f1, ...
    column_names = [f"f{i}" for i in range(features.shape[1])]
    rules = export_text(tree, feature_names=column_names)
    rules_path = os.path.join(OUTDIR, "surrogate_tree_rules.txt")
    with open(rules_path, "w") as f:
        f.write(rules)
    return tree, rules

# Script entry point: run every experiment family, then the two
# explainability steps, and write one JSON summary to OUTDIR.
if __name__ == "__main__":
    # Collects the metrics of each experiment family under its own key.
    summary = {}

    summary["non_contextual_bandits"] = train_non_contextual_bandits()

    summary["contextual_bandits"] = train_contextual_bandits()

    # train_deep_rl also returns the REINFORCE policy and the dense feature
    # tensors, which the explainability steps below reuse.
    deep_results, policy_for_xai, Xtr_d, Xte_d = train_deep_rl()
    summary["deep_rl"] = deep_results

    summary["offline_rl"] = train_offline_rl()

    # Best-effort: a SHAP failure is recorded in the summary instead of
    # aborting the run.
    try:
        shap_vals, shap_x = explain_policy_with_shap(policy_for_xai, Xtr_d, Xte_d)
        summary["shap_saved"] = SHAP_AVAILABLE
    except Exception as e:
        summary["shap_error"] = str(e)

    # Best-effort as well: a distillation failure is recorded, not raised.
    try:
        tree, rules = distill_policy_to_tree(policy_for_xai, X_train, y_train)
        summary["surrogate_tree"] = "ok"
    except Exception as e:
        summary["surrogate_tree_error"] = str(e)

    with open(os.path.join(OUTDIR, "RUN_SUMMARY.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print("Done. Artifacts in:", OUTDIR)

  0%|          | 0/25 [00:00<?, ?it/s]

Done. Artifacts in: ./rl_outputs
