In [1]:
#!/usr/bin/env python3
import gzip, json, random
from collections import Counter

def resample_jsonl(input_path, output_path, seed=42):
    random.seed(seed)
    records = []
    with gzip.open(input_path, "rt") as f:
        for line in f:
            records.append(json.loads(line))
    N = len(records)

    total_counts = Counter()
    for rec in records:
        for lbl in rec["label_vector"]:
            if lbl != -1:
                total_counts[lbl] += 1

    total_pairs = sum(total_counts.values())
    class_weights = {
        c: (total_pairs / 6) / total_counts[c]
        for c in range(6)
    }

    sample_weights = []
    for rec in records:
        w = [
            class_weights[lbl]
            for lbl in rec["label_vector"]
            if lbl != -1
        ]
        sample_weights.append(sum(w) / len(w))

    # Normalize to sum to 1
    tot = sum(sample_weights)
    sample_weights = [w / tot for w in sample_weights]

 
    indices = random.choices(range(N), weights=sample_weights, k=N)

    # --- 6. Write out the new balanced JSONL ---
    with gzip.open(output_path, "wt") as out_f:
        for i in indices:
            out_f.write(json.dumps(records[i]) + "\n")
    print(f"Wrote {len(indices)} records to {output_path}")

resample_jsonl("train.jsonl.gz", "train_resampled.jsonl.gz")


Wrote 46663 records to train_resampled.jsonl.gz


In [7]:
#MODIFICATION OF CELL BELOW DELETE NEXT CELL IF WORKS
import gzip, json, torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from rdkit import RDLogger
from rdkit import Chem
from rdkit.Chem import AllChem
# ---------- Silence RDKit chatter ----------
RDLogger.DisableLog("rdApp.*")


class JSONLPotencyDataset(Dataset):
    def __init__(self, path_to_jsonl_gz):
        self.records = []
        with gzip.open(path_to_jsonl_gz, "rt") as f:
            for line in f:
                rec = json.loads(line)
                self.records.append((rec["smiles"], rec["label_vector"]))
    def __len__(self):
        return len(self.records)
    def __getitem__(self, idx):
        smiles, label_vec = self.records[idx]
        fp = fp6144_from_smiles(smiles)
        while fp is None:  # skip invalid SMILES
            idx = (idx + 1) % len(self.records)
            smiles, label_vec = self.records[idx]
            fp = fp6144_from_smiles(smiles)
        return fp, torch.tensor(label_vec, dtype=torch.long)

GEN0 = AllChem.GetMorganGenerator(radius=0, fpSize=2048)
GEN1 = AllChem.GetMorganGenerator(radius=1, fpSize=2048)
GEN2 = AllChem.GetMorganGenerator(radius=2, fpSize=2048)

def fp6144_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fp0 = torch.tensor(list(GEN0.GetFingerprint(mol)), dtype=torch.float32)
    fp1 = torch.tensor(list(GEN1.GetFingerprint(mol)), dtype=torch.float32)
    fp2 = torch.tensor(list(GEN2.GetFingerprint(mol)), dtype=torch.float32)
    return torch.cat([fp0, fp1, fp2])  # (6144,)

def full_val_accuracy(model, loader, device):
    model.eval()
    correct = np.zeros(60); total = np.zeros(60)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(2)
            mask = y != -1
            correct += ((pred == y) & mask).sum(0).cpu().numpy()
            total   += mask.sum(0).cpu().numpy()
    accs = [c / t if t>0 else 0.0 for c, t in zip(correct, total)]
    return float(np.nanmean(accs)), accs

class MultiLineMLP5(nn.Module):
    def __init__(self,
                 input_dim=6144,
                 hidden_dims=[1024, 1024, 512, 512, 256],
                 num_lines=60,
                 num_classes=6,
                 p_drop=0.3):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(input_dim)
        layers = []
        dims = [input_dim] + hidden_dims
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ]
        self.shared = nn.Sequential(*layers)
        self.classifier = nn.Linear(hidden_dims[-1], num_lines * num_classes)
    def forward(self, x):
        x = self.bn_in(x)
        feat = self.shared(x)
        logits = self.classifier(feat)
        return logits.view(-1, 60, 6)

train_ds = JSONLPotencyDataset("train.jsonl.gz")
val_ds   = JSONLPotencyDataset("val.jsonl.gz")
test_ds  = JSONLPotencyDataset("test.jsonl.gz")

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,  num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False, num_workers=4, pin_memory=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MultiLineMLP5().to(device)

criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)


def full_metrics(model, loader, device):
    """Return per‑line precision, recall, accuracy and support."""
    model.eval()
    tp = np.zeros(60); fp = np.zeros(60); fn = np.zeros(60)
    correct = np.zeros(60); total = np.zeros(60)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(2)
            mask = y != -1
            for line in range(60):
                m = mask[:, line]
                if m.sum() == 0:
                    continue
                p = pred[m, line].cpu().numpy()
                t = y[m, line].cpu().numpy()
                correct[line] += (p == t).sum()
                total[line]   += len(t)
                for cls in range(6):
                    tp[line] += ((p == cls) & (t == cls)).sum()
                    fp[line] += ((p == cls) & (t != cls)).sum()
                    fn[line] += ((p != cls) & (t == cls)).sum()
    acc  = np.divide(correct, total, out=np.zeros_like(correct), where=total>0)
    prec = np.divide(tp, tp + fp + 1e-9)
    rec  = np.divide(tp, tp + fn + 1e-9)
    return acc, prec, rec, total

history = {"train_loss": [], "val_loss": [], "val_acc": []}
plots_dir = Path("plots"); plots_dir.mkdir(exist_ok=True)

best_acc = 0.0
for epoch in range(1, 41):
    # ----- Train -----
    model.train()
    running_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1, 6), y.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        running_loss += loss.item()
    train_loss = running_loss / len(train_loader)

    # ----- Validation -----
    model.eval(); val_running = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_running += criterion(model(x).view(-1, 6), y.view(-1)).item()
    val_loss = val_running / len(val_loader)
    scheduler.step(val_loss)

    # Accuracy
    acc_avg, _ = full_val_accuracy(model, val_loader, device)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(acc_avg)

    print(f"Epoch {epoch:02} | TL {train_loss:.3f} | VL {val_loss:.3f} | Acc {acc_avg:.4f} | LR {optimizer.param_groups[0]['lr']:.1e}")

    if acc_avg > best_acc:
        best_acc = acc_avg
        torch.save({"epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "history": history}, "best_resampled.pt")
        print("   ✓ New best accuracy")

# ---------- Curves ----------
plt.figure();
plt.plot(history["train_loss"], label="Train Loss")
plt.plot(history["val_loss"], label="Val Loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.tight_layout()
plt.savefig(plots_dir/"loss_curve.png", dpi=120)
plt.close()

plt.figure();
plt.plot(history["val_acc"], label="Val Accuracy")
plt.xlabel("Epoch"); plt.ylabel("Accuracy")
plt.tight_layout(); plt.savefig(plots_dir/"accuracy_curve.png", dpi=120); plt.close()

# ---------- Test‑set metrics (Table 2 + Figure 1) ----------
acc, prec, rec, support = full_metrics(model, test_loader, device)

df = pd.DataFrame({
    "CellLine": np.arange(60),
    "Precision": prec,
    "Recall": rec,
    "Accuracy": acc,
    "Support": support,
})

df.to_csv("table2_metrics.csv", index=False)
df.to_latex("table2.tex", index=False, float_format="%.3f")

plt.figure();
plt.scatter(rec, prec)
for i, (x, y) in enumerate(zip(rec, prec)):
    plt.annotate(str(i), (x, y), fontsize=6, alpha=0.7)
plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title("Precision vs Recall per Cell Line")
plt.tight_layout(); plt.savefig(plots_dir/"precision_recall_scatter.png", dpi=120); plt.close()

print("Finished. Best validation accuracy:", best_acc)
print("Table 2 saved to table2_metrics.csv / table2.tex; curves + Figure 1 saved in", plots_dir)



                                                                                                 

Epoch 01 | TL 1.173 | VL 1.242 | Acc 0.4958 | LR 5.0e-04
   ✓ New best accuracy


                                                                                                 

Epoch 02 | TL 1.021 | VL 1.341 | Acc 0.4910 | LR 5.0e-04


                                                                                                 

Epoch 03 | TL 0.948 | VL 1.347 | Acc 0.4965 | LR 5.0e-04
   ✓ New best accuracy


                                                                                                 

Epoch 04 | TL 0.896 | VL 1.345 | Acc 0.5112 | LR 2.5e-04
   ✓ New best accuracy


                                                                                                 

Epoch 05 | TL 0.800 | VL 1.437 | Acc 0.5145 | LR 2.5e-04
   ✓ New best accuracy


                                                                                                 

Epoch 06 | TL 0.756 | VL 1.397 | Acc 0.5063 | LR 2.5e-04


                                                                                                 

Epoch 07 | TL 0.728 | VL 1.548 | Acc 0.4974 | LR 1.3e-04


                                                                                                 

Epoch 08 | TL 0.674 | VL 1.430 | Acc 0.5185 | LR 1.3e-04
   ✓ New best accuracy


                                                                                                 

Epoch 09 | TL 0.647 | VL 1.553 | Acc 0.5151 | LR 1.3e-04


                                                                                                 

Epoch 10 | TL 0.637 | VL 1.454 | Acc 0.5328 | LR 6.3e-05
   ✓ New best accuracy


                                                                                                 

Epoch 11 | TL 0.609 | VL 1.530 | Acc 0.5284 | LR 6.3e-05


                                                                                                 

Epoch 12 | TL 0.597 | VL 1.518 | Acc 0.5247 | LR 6.3e-05


                                                                                                 

Epoch 13 | TL 0.589 | VL 1.567 | Acc 0.5246 | LR 3.1e-05


                                                                                                 

Epoch 14 | TL 0.578 | VL 1.437 | Acc 0.5338 | LR 3.1e-05
   ✓ New best accuracy


                                                                                                 

Epoch 15 | TL 0.574 | VL 1.480 | Acc 0.5320 | LR 3.1e-05


                                                                                                 

Epoch 16 | TL 0.569 | VL 1.495 | Acc 0.5265 | LR 1.6e-05


                                                                                                 

Epoch 17 | TL 0.566 | VL 1.515 | Acc 0.5291 | LR 1.6e-05


                                                                                                 

Epoch 18 | TL 0.559 | VL 1.525 | Acc 0.5323 | LR 1.6e-05


                                                                                                 

Epoch 19 | TL 0.560 | VL 1.513 | Acc 0.5294 | LR 7.8e-06


                                                                                                 

Epoch 20 | TL 0.556 | VL 1.512 | Acc 0.5370 | LR 7.8e-06
   ✓ New best accuracy


                                                                                                 

Epoch 21 | TL 0.555 | VL 1.471 | Acc 0.5365 | LR 7.8e-06


                                                                                                 

Epoch 22 | TL 0.555 | VL 1.493 | Acc 0.5324 | LR 3.9e-06


                                                                                                 

Epoch 23 | TL 0.552 | VL 1.476 | Acc 0.5357 | LR 3.9e-06


                                                                                                 

Epoch 24 | TL 0.551 | VL 1.540 | Acc 0.5253 | LR 3.9e-06


                                                                                                 

Epoch 25 | TL 0.552 | VL 1.502 | Acc 0.5257 | LR 2.0e-06


                                                                                                 

Epoch 26 | TL 0.550 | VL 1.492 | Acc 0.5326 | LR 2.0e-06


                                                                                                 

Epoch 27 | TL 0.549 | VL 1.490 | Acc 0.5369 | LR 2.0e-06


                                                                                                 

Epoch 28 | TL 0.550 | VL 1.429 | Acc 0.5446 | LR 9.8e-07
   ✓ New best accuracy


                                                                                                 

Epoch 29 | TL 0.548 | VL 1.453 | Acc 0.5397 | LR 9.8e-07


                                                                                                 

Epoch 30 | TL 0.550 | VL 1.398 | Acc 0.5406 | LR 9.8e-07


                                                                                                 

Epoch 31 | TL 0.549 | VL 1.480 | Acc 0.5372 | LR 4.9e-07


                                                                                                 

Epoch 32 | TL 0.551 | VL 1.507 | Acc 0.5349 | LR 4.9e-07


                                                                                                 

Epoch 33 | TL 0.549 | VL 1.489 | Acc 0.5372 | LR 4.9e-07


                                                                                                 

Epoch 34 | TL 0.550 | VL 1.481 | Acc 0.5376 | LR 2.4e-07


                                                                                                 

Epoch 35 | TL 0.549 | VL 1.519 | Acc 0.5360 | LR 2.4e-07


                                                                                                 

Epoch 36 | TL 0.551 | VL 1.477 | Acc 0.5351 | LR 2.4e-07


                                                                                                 

Epoch 37 | TL 0.549 | VL 1.461 | Acc 0.5333 | LR 1.2e-07


                                                                                                 

Epoch 38 | TL 0.549 | VL 1.458 | Acc 0.5371 | LR 1.2e-07


                                                                                                 

Epoch 39 | TL 0.550 | VL 1.561 | Acc 0.5337 | LR 1.2e-07


                                                                                                 

Epoch 40 | TL 0.548 | VL 1.526 | Acc 0.5338 | LR 6.1e-08
Finished. Best validation accuracy: 0.5446336553684783
Table 2 saved to table2_metrics.csv / table2.tex; curves + Figure 1 saved in plots


In [5]:
import gzip, json, torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm.auto import tqdm
import numpy as np
from rdkit import RDLogger

# Silence RDKit chatter
RDLogger.DisableLog("rdApp.*")

# --- Dataset class (unchanged) ---
class JSONLPotencyDataset(Dataset):
    def __init__(self, path_to_jsonl_gz):
        self.records = []
        with gzip.open(path_to_jsonl_gz, "rt") as f:
            for line in f:
                rec = json.loads(line)
                self.records.append((rec["smiles"], rec["label_vector"]))

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        smiles, label_vec = self.records[idx]
        fp = fp6144_from_smiles(smiles)
        while fp is None:
            idx = (idx + 1) % len(self.records)
            smiles, label_vec = self.records[idx]
            fp = fp6144_from_smiles(smiles)
        return fp, torch.tensor(label_vec, dtype=torch.long)

# --- Fingerprint helper (unchanged) ---
from rdkit import Chem
from rdkit.Chem import AllChem
GEN0 = AllChem.GetMorganGenerator(radius=0, fpSize=2048)
GEN1 = AllChem.GetMorganGenerator(radius=1, fpSize=2048)
GEN2 = AllChem.GetMorganGenerator(radius=2, fpSize=2048)

def fp6144_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fp0 = torch.tensor(list(GEN0.GetFingerprint(mol)), dtype=torch.float32)
    fp1 = torch.tensor(list(GEN1.GetFingerprint(mol)), dtype=torch.float32)
    fp2 = torch.tensor(list(GEN2.GetFingerprint(mol)), dtype=torch.float32)
    return torch.cat([fp0, fp1, fp2])  # (6144,)

# --- Model definition (same 5-layer MLP) ---
class MultiLineMLP5(nn.Module):
    def __init__(self,
                 input_dim=6144,
                 hidden_dims=[1024, 1024, 512, 512, 256],
                 num_lines=60,
                 num_classes=6,
                 p_drop=0.3):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(input_dim)
        layers = []
        dims = [input_dim] + hidden_dims
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ]
        self.shared = nn.Sequential(*layers)
        self.classifier = nn.Linear(hidden_dims[-1], num_lines * num_classes)

    def forward(self, x):
        x = self.bn_in(x)
        feat = self.shared(x)
        logits = self.classifier(feat)
        return logits.view(-1, 60, 6)
        
train_ds = JSONLPotencyDataset("train_resampled.jsonl.gz")
val_ds   = JSONLPotencyDataset("val.jsonl.gz")
test_ds  = JSONLPotencyDataset("test.jsonl.gz")

train_loader = DataLoader(train_ds, batch_size=64,
                          shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=64,
                          shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=64,
                          shuffle=False, num_workers=4, pin_memory=True)

# --- Training setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MultiLineMLP5().to(device)

# No class-weights here:
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

def full_val_accuracy(model, loader, device):
    model.eval()
    correct = np.zeros(60); total = np.zeros(60)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(2)
            mask = y != -1
            correct += ((pred == y) & mask).sum(0).cpu().numpy()
            total   += mask.sum(0).cpu().numpy()
    accs = [c / t if t>0 else 0.0 for c, t in zip(correct, total)]
    return float(np.nanmean(accs)), accs

# --- Training loop ---
best_acc = 0.0
for epoch in range(1, 41):
    model.train()
    train_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1,6), y.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()

    # validation
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(
                model(x).view(-1,6), y.view(-1)
            ).item()

    scheduler.step(val_loss)
    avg_acc, _ = full_val_accuracy(model, val_loader, device)
    print(f"Epoch {epoch:02} | Train Loss {train_loss:.1f} "
          f"| Val Loss {val_loss:.1f} | Acc {avg_acc:.4f} "
          f"| LR {optimizer.param_groups[0]['lr']:.1e}")

    if avg_acc > best_acc:
        best_acc = avg_acc
        torch.save(model.state_dict(), "best_resampled.pt")
        print(f"  ✓ New best: {best_acc:.4f}")

print("Training complete. Best Avg Val Acc =", best_acc)


                                                                                           

Epoch 01 | Train Loss 910.9 | Val Loss 168.1 | Acc 0.4055 | LR 5.0e-04
  ✓ New best: 0.4055


                                                                                           

Epoch 02 | Train Loss 729.6 | Val Loss 161.8 | Acc 0.4362 | LR 5.0e-04
  ✓ New best: 0.4362


                                                                                           

Epoch 03 | Train Loss 660.5 | Val Loss 165.4 | Acc 0.4447 | LR 5.0e-04
  ✓ New best: 0.4447


                                                                                           

Epoch 04 | Train Loss 616.2 | Val Loss 154.2 | Acc 0.4580 | LR 5.0e-04
  ✓ New best: 0.4580


                                                                                           

Epoch 05 | Train Loss 581.0 | Val Loss 152.6 | Acc 0.4620 | LR 5.0e-04
  ✓ New best: 0.4620


                                                                                           

Epoch 06 | Train Loss 557.6 | Val Loss 156.5 | Acc 0.4664 | LR 5.0e-04
  ✓ New best: 0.4664


                                                                                           

Epoch 07 | Train Loss 540.2 | Val Loss 149.2 | Acc 0.4782 | LR 5.0e-04
  ✓ New best: 0.4782


                                                                                           

Epoch 08 | Train Loss 524.8 | Val Loss 156.2 | Acc 0.4632 | LR 5.0e-04


                                                                                           

Epoch 09 | Train Loss 514.6 | Val Loss 138.2 | Acc 0.4891 | LR 5.0e-04
  ✓ New best: 0.4891


                                                                                           

Epoch 10 | Train Loss 507.0 | Val Loss 140.5 | Acc 0.4876 | LR 5.0e-04


                                                                                           

Epoch 11 | Train Loss 497.5 | Val Loss 141.1 | Acc 0.4879 | LR 5.0e-04


                                                                                           

Epoch 12 | Train Loss 491.3 | Val Loss 132.2 | Acc 0.5022 | LR 5.0e-04
  ✓ New best: 0.5022


                                                                                           

Epoch 13 | Train Loss 488.1 | Val Loss 133.3 | Acc 0.5010 | LR 5.0e-04


                                                                                           

Epoch 14 | Train Loss 482.4 | Val Loss 134.9 | Acc 0.4934 | LR 5.0e-04


                                                                                           

Epoch 15 | Train Loss 479.6 | Val Loss 131.6 | Acc 0.5056 | LR 5.0e-04
  ✓ New best: 0.5056


                                                                                           

Epoch 16 | Train Loss 472.1 | Val Loss 137.3 | Acc 0.4928 | LR 5.0e-04


                                                                                           

Epoch 17 | Train Loss 472.5 | Val Loss 138.9 | Acc 0.4931 | LR 5.0e-04


                                                                                           

Epoch 18 | Train Loss 469.0 | Val Loss 132.4 | Acc 0.5075 | LR 2.5e-04
  ✓ New best: 0.5075


                                                                                           

Epoch 19 | Train Loss 440.9 | Val Loss 131.5 | Acc 0.5048 | LR 2.5e-04


                                                                                           

Epoch 20 | Train Loss 423.8 | Val Loss 133.0 | Acc 0.5107 | LR 2.5e-04
  ✓ New best: 0.5107


                                                                                           

Epoch 21 | Train Loss 420.5 | Val Loss 128.1 | Acc 0.5172 | LR 2.5e-04
  ✓ New best: 0.5172


                                                                                           

Epoch 22 | Train Loss 419.2 | Val Loss 130.4 | Acc 0.5203 | LR 2.5e-04
  ✓ New best: 0.5203


                                                                                           

Epoch 23 | Train Loss 416.6 | Val Loss 134.7 | Acc 0.5105 | LR 2.5e-04


                                                                                           

Epoch 24 | Train Loss 412.6 | Val Loss 134.2 | Acc 0.5189 | LR 1.3e-04


                                                                                           

Epoch 25 | Train Loss 402.4 | Val Loss 132.3 | Acc 0.5225 | LR 1.3e-04
  ✓ New best: 0.5225


                                                                                           

Epoch 26 | Train Loss 394.3 | Val Loss 132.5 | Acc 0.5192 | LR 1.3e-04


                                                                                           

Epoch 27 | Train Loss 390.4 | Val Loss 127.8 | Acc 0.5269 | LR 1.3e-04
  ✓ New best: 0.5269


                                                                                           

Epoch 28 | Train Loss 389.4 | Val Loss 132.6 | Acc 0.5277 | LR 1.3e-04
  ✓ New best: 0.5277


                                                                                           

Epoch 29 | Train Loss 386.1 | Val Loss 133.4 | Acc 0.5207 | LR 1.3e-04


                                                                                           

Epoch 30 | Train Loss 385.6 | Val Loss 126.2 | Acc 0.5307 | LR 1.3e-04
  ✓ New best: 0.5307


                                                                                           

Epoch 31 | Train Loss 382.7 | Val Loss 128.0 | Acc 0.5269 | LR 1.3e-04


                                                                                           

Epoch 32 | Train Loss 380.5 | Val Loss 125.4 | Acc 0.5366 | LR 1.3e-04
  ✓ New best: 0.5366


                                                                                           

Epoch 33 | Train Loss 378.9 | Val Loss 130.7 | Acc 0.5289 | LR 1.3e-04


                                                                                           

Epoch 34 | Train Loss 378.7 | Val Loss 130.5 | Acc 0.5284 | LR 1.3e-04


                                                                                           

Epoch 35 | Train Loss 377.0 | Val Loss 134.0 | Acc 0.5312 | LR 6.3e-05


                                                                                           

Epoch 36 | Train Loss 372.2 | Val Loss 127.5 | Acc 0.5376 | LR 6.3e-05
  ✓ New best: 0.5376


                                                                                           

Epoch 37 | Train Loss 368.2 | Val Loss 125.1 | Acc 0.5377 | LR 6.3e-05
  ✓ New best: 0.5377


                                                                                           

Epoch 38 | Train Loss 366.5 | Val Loss 127.7 | Acc 0.5340 | LR 6.3e-05


                                                                                           

Epoch 39 | Train Loss 364.9 | Val Loss 129.8 | Acc 0.5335 | LR 6.3e-05


                                                                                           

Epoch 40 | Train Loss 364.6 | Val Loss 127.8 | Acc 0.5364 | LR 3.1e-05
Training complete. Best Avg Val Acc = 0.5377233967144139


In [6]:
# 1) Define the function (one cell)
import torch
import numpy as np

def compute_precision_recall(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=2)
            mask = (y != -1)
            all_preds.append( preds[mask].cpu().numpy().ravel() )
            all_labels.append(y[mask].cpu().numpy().ravel())
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)
    precision = np.zeros(6); recall = np.zeros(6)
    for c in range(6):
        tp = ((y_pred==c)&(y_true==c)).sum()
        fp = ((y_pred==c)&(y_true!=c)).sum()
        fn = ((y_pred!=c)&(y_true==c)).sum()
        precision[c] = tp/(tp+fp) if tp+fp>0 else np.nan
        recall[c]    = tp/(tp+fn) if tp+fn>0 else np.nan
    return precision, recall


In [7]:
# 2) Call it in a fresh cell once training is done
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("best_model.pt"))  # if you restarted kernel
model.to(device)

precision, recall = compute_precision_recall(model, val_loader, device)
for c in range(6):
    print(f"class {c}:  prec={precision[c]:.3f},  rec={recall[c]:.3f}")


class 0:  prec=0.770,  rec=0.473
class 1:  prec=0.513,  rec=0.406
class 2:  prec=0.494,  rec=0.357
class 3:  prec=0.209,  rec=0.398
class 4:  prec=0.088,  rec=0.469
class 5:  prec=0.063,  rec=0.630


In [8]:
import torch, numpy as np
from collections import defaultdict

def macro_precision_recall(model, loader, device, num_lines=59):
    """
    Returns two 2-D arrays of shape (6 classes, 2 metrics):
        [exact , within-one] for precision and recall.
    """
    model.eval()
    # per-line confusion buckets
    exact = defaultdict(lambda: np.zeros((6, 3), dtype=int))       # TP, FP, FN
    within = defaultdict(lambda: np.zeros((6, 3), dtype=int))      # TPw, FPw, FNw
    
    with torch.no_grad():
        for x, y in loader:                         # y: (B, 60)
            x = x.to(device)
            logits = model(x)[:, :num_lines]        # (B, 59, 6)  drop 60th line
            preds  = logits.argmax(2).cpu().numpy()
            y      = y[:, :num_lines].cpu().numpy()
            
            for line_idx in range(num_lines):
                true_line   = y[:, line_idx]
                pred_line   = preds[:, line_idx]
                mask        = true_line != -1
                yt          = true_line[mask]
                yp          = pred_line[mask]
                if yt.size == 0:                    # line had no labels
                    continue
                
                for c in range(6):
                    tp = np.sum((yp == c) & (yt == c))
                    fp = np.sum((yp == c) & (yt != c))
                    fn = np.sum((yp != c) & (yt == c))
                    exact[line_idx][c] += (tp, fp, fn)
                    
                    # within-one: hit if |pred-true| <= 1
                    tp_w = np.sum((yp == c) & (np.abs(yt - c) <= 1))
                    fp_w = np.sum((yp == c) & (np.abs(yt - c) >  1))
                    fn_w = np.sum((yp != c) & (np.abs(yp - c) <= 1) & (yt == c))
                    within[line_idx][c] += (tp_w, fp_w, fn_w)
    
    # macro-average over cell lines (skip lines with zero support)
    prec = np.zeros((6, 2))
    rec  = np.zeros((6, 2))
    for c in range(6):
        # gather per-line metrics then average
        p_exact, r_exact, p_within, r_within = [], [], [], []
        for line in exact.keys():
            tp, fp, fn         = exact[line][c]
            tpw, fpw, fnw      = within[line][c]
            if tp + fn == 0:    # no true instances of class c in this line
                continue
            p_exact.append(  tp / (tp + fp) if tp+fp>0 else np.nan )
            r_exact.append(  tp / (tp + fn)                 )
            p_within.append( tpw / (tpw + fpw) if tpw+fpw>0 else np.nan )
            r_within.append( tpw / (tpw + fnw)               )
        prec[c, 0] = np.nanmean(p_exact)
        prec[c, 1] = np.nanmean(p_within)
        rec[c, 0]  = np.nanmean(r_exact)
        rec[c, 1]  = np.nanmean(r_within)
    return prec, rec

# -----------------------------------------------------------
# Example usage on your validation set
prec, rec = macro_precision_recall(model, val_loader, device)

headers = ["Exact P", "±1 P", "Exact R", "±1 R"]
print(f"{'Class':>6}  {headers[0]:>8}  {headers[1]:>8}  "
      f"{headers[2]:>8}  {headers[3]:>8}")
for c in range(6):
    print(f"{c:>6}  {prec[c,0]*100:8.1f}%  {prec[c,1]*100:8.1f}%  "
          f"{rec[c,0]*100:8.1f}%  {rec[c,1]*100:8.1f}%")


 Class   Exact P      ±1 P   Exact R      ±1 R
     0      76.3%      94.7%      46.7%      74.2%
     1      51.0%      98.2%      40.3%      72.2%
     2      49.7%      84.0%      35.8%      59.8%
     3      21.1%      66.3%      40.2%      73.3%
     4       8.9%      27.9%      47.3%      77.7%
     5       6.5%      13.2%      61.6%      85.2%


In [13]:
#RECALL MODEL

import numpy as np
from torch.utils.data import WeightedRandomSampler, DataLoader
from rdkit import RDLogger


# 1) Count how often each potency label (0-5) appears anywhere in the training set
class_freq = np.zeros(6, dtype=int)         # one slot per class
for _, lbl_vec in train_ds.records:    # lbl_vec length = 59 (or 60) cell lines
    labels  = np.asarray(lbl_vec, dtype=int)
    present = labels[labels >= 0]           # drop -1 (missing) entries
    class_freq += np.bincount(present, minlength=6)

# 2) Inverse-√ frequency weights, normalised so their mean ≈ 1
inv_sqrt = 1 / np.sqrt(class_freq + 1e-6)
inv_sqrt = inv_sqrt / inv_sqrt.mean()

# 3) Weight each molecule by the *average* weight of the labels it actually has
sample_weights = []
for _, lbl_vec in train_ds.records:
    labels  = np.asarray(lbl_vec, dtype=int)
    present = labels[labels >= 0]
    sample_weights.append(float(inv_sqrt[present].mean()))

# 4) Build a weighted sampler and the DataLoader
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)

train_loader = DataLoader(train_ds,
                          batch_size=64,
                          sampler=sampler,
                          num_workers=4,
                          pin_memory=True)

val_loader   = DataLoader(val_ds,   batch_size=64,
                          shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=64,
                          shuffle=False, num_workers=4, pin_memory=True)

# --- Training setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MultiLineMLP5().to(device)

# No class-weights here:
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

def full_val_accuracy(model, loader, device):
    model.eval()
    correct = np.zeros(60); total = np.zeros(60)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(2)
            mask = y != -1
            correct += ((pred == y) & mask).sum(0).cpu().numpy()
            total   += mask.sum(0).cpu().numpy()
    accs = [c / t if t>0 else 0.0 for c, t in zip(correct, total)]
    return float(np.nanmean(accs)), accs

# --- Training loop ---
best_acc = 0.0
for epoch in range(1, 41):
    model.train()
    train_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1,6), y.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()

    # validation
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(
                model(x).view(-1,6), y.view(-1)
            ).item()

    scheduler.step(val_loss)
    avg_acc, _ = full_val_accuracy(model, val_loader, device)
    print(f"Epoch {epoch:02} | Train Loss {train_loss:.1f} "
          f"| Val Loss {val_loss:.1f} | Acc {avg_acc:.4f} "
          f"| LR {optimizer.param_groups[0]['lr']:.1e}")

    if avg_acc > best_acc:
        best_acc = avg_acc
        torch.save(model.state_dict(), "best_resampled.pt")
        print(f"  ✓ New best: {best_acc:.4f}")

print("Training complete. Best Avg Val Acc =", best_acc)

                                                                                           

Epoch 01 | Train Loss 877.7 | Val Loss 185.0 | Acc 0.3957 | LR 5.0e-04
  ✓ New best: 0.3957


                                                                                           

Epoch 02 | Train Loss 702.8 | Val Loss 186.3 | Acc 0.4179 | LR 5.0e-04
  ✓ New best: 0.4179


                                                                                           

Epoch 03 | Train Loss 639.8 | Val Loss 173.5 | Acc 0.4303 | LR 5.0e-04
  ✓ New best: 0.4303


                                                                                           

Epoch 04 | Train Loss 600.8 | Val Loss 163.0 | Acc 0.4524 | LR 5.0e-04
  ✓ New best: 0.4524


                                                                                           

Epoch 05 | Train Loss 568.9 | Val Loss 158.9 | Acc 0.4540 | LR 5.0e-04
  ✓ New best: 0.4540


                                                                                           

Epoch 06 | Train Loss 549.8 | Val Loss 155.0 | Acc 0.4540 | LR 5.0e-04


                                                                                           

Epoch 07 | Train Loss 534.1 | Val Loss 158.2 | Acc 0.4586 | LR 5.0e-04
  ✓ New best: 0.4586


                                                                                           

Epoch 08 | Train Loss 521.0 | Val Loss 153.8 | Acc 0.4661 | LR 5.0e-04
  ✓ New best: 0.4661


                                                                                           

Epoch 09 | Train Loss 511.2 | Val Loss 151.5 | Acc 0.4789 | LR 5.0e-04
  ✓ New best: 0.4789


                                                                                           

Epoch 10 | Train Loss 502.4 | Val Loss 148.1 | Acc 0.4804 | LR 5.0e-04
  ✓ New best: 0.4804


                                                                                           

Epoch 11 | Train Loss 497.3 | Val Loss 143.7 | Acc 0.4909 | LR 5.0e-04
  ✓ New best: 0.4909


                                                                                           

Epoch 12 | Train Loss 488.5 | Val Loss 141.7 | Acc 0.4861 | LR 5.0e-04


                                                                                           

Epoch 13 | Train Loss 485.8 | Val Loss 153.0 | Acc 0.4720 | LR 5.0e-04


                                                                                           

Epoch 14 | Train Loss 479.6 | Val Loss 140.6 | Acc 0.4907 | LR 5.0e-04


                                                                                           

Epoch 15 | Train Loss 475.4 | Val Loss 135.8 | Acc 0.4947 | LR 5.0e-04
  ✓ New best: 0.4947


                                                                                           

Epoch 16 | Train Loss 470.2 | Val Loss 132.5 | Acc 0.4954 | LR 5.0e-04
  ✓ New best: 0.4954


                                                                                           

Epoch 17 | Train Loss 467.2 | Val Loss 139.8 | Acc 0.4922 | LR 5.0e-04


                                                                                           

Epoch 18 | Train Loss 462.5 | Val Loss 129.7 | Acc 0.5096 | LR 5.0e-04
  ✓ New best: 0.5096


                                                                                           

Epoch 19 | Train Loss 459.1 | Val Loss 136.3 | Acc 0.4916 | LR 5.0e-04


                                                                                           

Epoch 20 | Train Loss 456.8 | Val Loss 134.1 | Acc 0.4983 | LR 5.0e-04


                                                                                           

Epoch 21 | Train Loss 456.8 | Val Loss 139.1 | Acc 0.4962 | LR 2.5e-04


                                                                                           

Epoch 22 | Train Loss 434.4 | Val Loss 129.9 | Acc 0.5193 | LR 2.5e-04
  ✓ New best: 0.5193


                                                                                           

Epoch 23 | Train Loss 416.4 | Val Loss 126.2 | Acc 0.5184 | LR 2.5e-04


                                                                                           

Epoch 24 | Train Loss 411.9 | Val Loss 133.1 | Acc 0.5157 | LR 2.5e-04


                                                                                           

Epoch 25 | Train Loss 408.8 | Val Loss 126.8 | Acc 0.5248 | LR 2.5e-04
  ✓ New best: 0.5248


                                                                                           

Epoch 26 | Train Loss 407.1 | Val Loss 127.7 | Acc 0.5237 | LR 1.3e-04


                                                                                           

Epoch 27 | Train Loss 396.0 | Val Loss 129.3 | Acc 0.5210 | LR 1.3e-04


                                                                                           

Epoch 28 | Train Loss 387.2 | Val Loss 132.5 | Acc 0.5238 | LR 1.3e-04


                                                                                           

Epoch 29 | Train Loss 386.1 | Val Loss 124.6 | Acc 0.5302 | LR 1.3e-04
  ✓ New best: 0.5302


                                                                                           

Epoch 30 | Train Loss 383.6 | Val Loss 123.7 | Acc 0.5381 | LR 1.3e-04
  ✓ New best: 0.5381


                                                                                           

Epoch 31 | Train Loss 381.9 | Val Loss 122.0 | Acc 0.5378 | LR 1.3e-04


                                                                                           

Epoch 32 | Train Loss 379.0 | Val Loss 121.5 | Acc 0.5402 | LR 1.3e-04
  ✓ New best: 0.5402


                                                                                           

Epoch 33 | Train Loss 376.6 | Val Loss 125.4 | Acc 0.5354 | LR 1.3e-04


                                                                                           

Epoch 34 | Train Loss 374.6 | Val Loss 130.0 | Acc 0.5314 | LR 1.3e-04


                                                                                           

Epoch 35 | Train Loss 376.9 | Val Loss 124.0 | Acc 0.5360 | LR 6.3e-05


                                                                                           

Epoch 36 | Train Loss 370.3 | Val Loss 124.2 | Acc 0.5360 | LR 6.3e-05


                                                                                           

Epoch 37 | Train Loss 366.1 | Val Loss 124.7 | Acc 0.5416 | LR 6.3e-05
  ✓ New best: 0.5416


                                                                                           

Epoch 38 | Train Loss 366.2 | Val Loss 124.4 | Acc 0.5384 | LR 3.1e-05


                                                                                           

Epoch 39 | Train Loss 360.8 | Val Loss 124.2 | Acc 0.5397 | LR 3.1e-05


                                                                                           

Epoch 40 | Train Loss 360.5 | Val Loss 126.6 | Acc 0.5398 | LR 3.1e-05
Training complete. Best Avg Val Acc = 0.5416469308812588


In [14]:
def micro_precision_recall(model, loader, device, num_lines=59):
    """
    Micro-averaged precision and recall across all cell lines and samples.
    Returns: (precision_exact, recall_exact), (precision_within1, recall_within1)
    """
    model.eval()
    y_true_all = []
    y_pred_all = []
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)[:, :num_lines]     # (B, 59, 6)
            preds = logits.argmax(dim=2).cpu().numpy()   # (B, 59)
            y     = y[:, :num_lines].cpu().numpy()        # (B, 59)

            mask = (y != -1)
            y_true_all.append(y[mask])
            y_pred_all.append(preds[mask])

    y_true = np.concatenate(y_true_all)
    y_pred = np.concatenate(y_pred_all)

    precision = []
    recall    = []
    precision1 = []
    recall1    = []

    for c in range(6):
        tp  = np.sum((y_pred == c) & (y_true == c))
        fp  = np.sum((y_pred == c) & (y_true != c))
        fn  = np.sum((y_pred != c) & (y_true == c))

        tp1 = np.sum((y_pred == c) & (np.abs(y_true - c) <= 1))
        fp1 = np.sum((y_pred == c) & (np.abs(y_true - c) >  1))
        fn1 = np.sum((y_pred != c) & (y_true == c) & (np.abs(y_pred - c) <= 1))

        p  = tp  / (tp + fp) if (tp + fp) > 0 else np.nan
        r  = tp  / (tp + fn) if (tp + fn) > 0 else np.nan
        p1 = tp1 / (tp1 + fp1) if (tp1 + fp1) > 0 else np.nan
        r1 = tp1 / (tp1 + fn1) if (tp1 + fn1) > 0 else np.nan

        precision.append(p)
        recall.append(r)
        precision1.append(p1)
        recall1.append(r1)

    return np.array(precision), np.array(recall), np.array(precision1), np.array(recall1)


In [15]:
prec, rec, prec1, rec1 = micro_precision_recall(model, val_loader, device)

headers = ["Exact P", "±1 P", "Exact R", "±1 R"]
print(f"{'Class':>6}  {headers[0]:>8}  {headers[1]:>8}  "
      f"{headers[2]:>8}  {headers[3]:>8}")
for c in range(6):
    print(f"{c:>6}  {prec[c]*100:8.1f}%  {prec1[c]*100:8.1f}%  "
          f"{rec[c]*100:8.1f}%  {rec1[c]*100:8.1f}%")


 Class   Exact P      ±1 P   Exact R      ±1 R
     0      70.2%      91.3%      62.2%      75.7%
     1      48.7%      97.1%      49.9%      68.8%
     2      48.4%      84.1%      51.8%      72.1%
     3      35.4%      76.3%      42.8%      69.3%
     4      30.5%      69.4%      34.8%      66.3%
     5      37.3%      62.1%      44.1%      68.9%


In [9]:
import torch

# 1. Set up model as usual
model = MultiLineMLP5(hidden_dims=[1024,1024,512,512,256], p_drop=0.3).to(device)

# 2. Load the checkpoint
ckpt = torch.load("best_resampled.pt", map_location=device)

# 3. Restore model state
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# 4. (Optional) check epoch / history
print(f"Checkpoint from epoch {ckpt['epoch']}")
if "history" in ckpt:
    print("Validation accuracy at save time:", ckpt["history"]["val_acc"][-1])

# 5. Run a quick accuracy check
def evaluate(loader):
    correct = total = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            p = model(x).argmax(2)
            m = y != -1
            correct += ((p == y) & m).sum().item()
            total   += m.sum().item()
    return correct / total if total else 0.0

val_acc = evaluate(val_loader)
print("Validation accuracy after reload:", val_acc)


Checkpoint from epoch 28
Validation accuracy at save time: 0.5446336553684783
Validation accuracy after reload: 0.5453763468470059
