In [None]:
import gzip, json, torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm.auto import tqdm
import numpy as np
from rdkit import RDLogger

# Silence RDKit chatter by disabling every log channel matching "rdApp.*".
RDLogger.DisableLog("rdApp.*")


class JSONLPotencyDataset(Dataset):
    """Map-style dataset of (fingerprint, label vector) pairs from a gzipped JSONL file.

    Every non-blank line of the file is parsed as a JSON object and its
    ``"smiles"`` and ``"label_vector"`` entries are kept in memory as a tuple
    in ``self.records``.  Fingerprints are computed on demand in
    ``__getitem__`` through ``fp6144_from_smiles``.
    """

    def __init__(self, path_to_jsonl_gz):
        """Read all records of ``path_to_jsonl_gz`` into ``self.records``.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a record lacks ``"smiles"`` or ``"label_vector"``.
        """
        self.records = []
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with gzip.open(path_to_jsonl_gz, "rt", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                rec = json.loads(line)
                self.records.append((rec["smiles"], rec["label_vector"]))

    def __len__(self):
        """Number of records read from the file."""
        return len(self.records)

    def __getitem__(self, idx):
        """Return ``(fingerprint, labels)`` for record ``idx``.

        If ``fp6144_from_smiles`` returns ``None`` for the requested record,
        the following records are tried in order (wrapping around to the
        start), so the returned item may belong to a different record than
        ``idx``.  ``labels`` is a ``torch.long`` tensor built from the
        record's label vector.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no record in the dataset yields a fingerprint.
        """
        # Plain list indexing keeps IndexError for out-of-range indices.
        smiles, label_vec = self.records[idx]
        fp = fp6144_from_smiles(smiles)

        n_records = len(self.records)
        n_tried = 1
        while fp is None:
            # Bound the scan: once every record has been tried there is
            # nothing left to fall back to, so fail instead of spinning.
            if n_tried >= n_records:
                raise RuntimeError(
                    "no record in the dataset has a SMILES string that can be fingerprinted"
                )
            idx = (idx + 1) % n_records
            smiles, label_vec = self.records[idx]
            fp = fp6144_from_smiles(smiles)
            n_tried += 1
        return fp, torch.tensor(label_vec, dtype=torch.long)

# Fingerprint helper (NOT NEW fingerprint method)
from rdkit import Chem
from rdkit.Chem import AllChem
# One Morgan fingerprint generator per radius (0, 1 and 2), each with fpSize=2048.
GEN0, GEN1, GEN2 = (
    AllChem.GetMorganGenerator(radius=radius, fpSize=2048) for radius in range(3)
)

def fp6144_from_smiles(smiles):
    """Build the concatenated Morgan fingerprint for a SMILES string.

    The molecule is parsed with RDKit and fingerprinted with the module-level
    generators GEN0, GEN1 and GEN2 (in that order); the three bit vectors are
    converted to float32 tensors and concatenated into one 1-D tensor.

    Returns:
        The concatenated float32 tensor, or ``None`` when
        ``Chem.MolFromSmiles`` returns ``None`` for ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    pieces = [
        torch.tensor(list(generator.GetFingerprint(mol)), dtype=torch.float32)
        for generator in (GEN0, GEN1, GEN2)
    ]
    return torch.cat(pieces)  # (6144,) with the 2048-bit generators above

# Model definition 
class MultiLineMLP5(nn.Module):
    """MLP with a shared trunk and one linear head producing per-line class logits.

    The input is batch-normalised, passed through a stack of
    Linear -> BatchNorm1d -> ReLU -> Dropout blocks (one per entry of
    ``hidden_dims``) and projected to ``num_lines * num_classes`` logits,
    which ``forward`` reshapes to ``(batch, num_lines, num_classes)``.

    Args:
        input_dim: size of the input feature vector.
        hidden_dims: widths of the hidden layers, in order; may be any
            sequence (an empty one gives a purely linear model).
        num_lines: number of output lines (second output dimension).
        num_classes: number of classes per line (last output dimension).
        p_drop: dropout probability used after every hidden layer.
    """

    def __init__(self,
                 input_dim=6144,
                 hidden_dims=(1024, 1024, 512, 512, 256),
                 num_lines=60,
                 num_classes=6,
                 p_drop=0.3):
        super().__init__()
        # Remember the head layout so forward() reshapes to what was built
        # rather than to hard-coded sizes.  Plain attributes are not part of
        # the state_dict, so existing checkpoints still load.
        self.num_lines = num_lines
        self.num_classes = num_classes

        self.bn_in = nn.BatchNorm1d(input_dim)
        layers = []
        # Copy into a list: accepts tuples and never touches the caller's object.
        dims = [input_dim] + list(hidden_dims)
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ]
        self.shared = nn.Sequential(*layers)
        # dims[-1] equals hidden_dims[-1], or input_dim when there are no hidden layers.
        self.classifier = nn.Linear(dims[-1], num_lines * num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to logits of shape (batch, num_lines, num_classes)."""
        x = self.bn_in(x)
        feat = self.shared(x)
        logits = self.classifier(feat)
        return logits.view(-1, self.num_lines, self.num_classes)
        
# Datasets: resampled training split plus the untouched validation/test splits.
train_ds = JSONLPotencyDataset("train_resampled.jsonl.gz")
val_ds = JSONLPotencyDataset("val.jsonl.gz")
test_ds = JSONLPotencyDataset("test.jsonl.gz")

# Options shared by all three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=64, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

# --- Training setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiLineMLP5()
model = model.to(device)

# Unweighted cross-entropy; targets equal to -1 are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
# Halve the learning rate when the monitored value stops decreasing for 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,
)

def full_val_accuracy(model, loader, device):
    """Per-line accuracy over ``loader`` and its mean across labelled lines.

    The model is put in eval mode.  For every batch ``(x, y)`` the prediction
    is the argmax over the last axis of ``model(x)``; positions where
    ``y == -1`` are excluded from both the hit count and the total.

    The number of lines is taken from the label tensor instead of being fixed
    at 60.  A line that never carries a label gets NaN (not 0.0) so that it
    does not pull the mean down.

    Args:
        model: module mapping a batch to logits of shape (batch, lines, classes).
        loader: iterable of ``(x, y)`` batches with ``y`` of shape (batch, lines).
        device: device the batches are moved to.

    Returns:
        ``(mean_accuracy, per_line)``: ``per_line`` is a list with one
        accuracy per line (NaN for lines without labels) and
        ``mean_accuracy`` is the float mean over the labelled lines, or NaN
        when there are none.  An empty loader returns ``(nan, [])``.
    """
    model.eval()
    correct = None
    total = None
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(2)
            mask = y != -1
            batch_correct = ((pred == y) & mask).sum(0).cpu().numpy()
            batch_total = mask.sum(0).cpu().numpy()
            if correct is None:
                # Size the accumulators from the data, not a hard-coded 60.
                correct = np.zeros(batch_total.shape[0])
                total = np.zeros(batch_total.shape[0])
            correct += batch_correct
            total += batch_total

    if correct is None:  # the loader produced no batches
        return float("nan"), []

    labelled = total > 0
    per_line = np.full(total.shape, np.nan)
    per_line[labelled] = correct[labelled] / total[labelled]
    # Average only over lines that had labels; avoids numpy's all-NaN warning.
    mean_acc = float(per_line[labelled].mean()) if labelled.any() else float("nan")
    return mean_acc, list(per_line)

# --- Training loop ---
best_acc = 0.0
for epoch in range(1, 41):
    # ---- optimisation pass over the training set ----
    model.train()
    train_loss = 0.0  # sum of per-batch losses (not averaged)
    progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
    for batch_x, batch_y in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        # Flatten (batch, line) so each labelled cell is one classification target.
        batch_loss = criterion(model(batch_x).view(-1, 6), batch_y.view(-1))
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += batch_loss.item()

    # ---- validation loss (also a sum over batches) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            val_logits = model(batch_x)
            val_loss += criterion(val_logits.view(-1, 6), batch_y.view(-1)).item()

    # The scheduler follows the validation loss; checkpoints follow accuracy.
    scheduler.step(val_loss)
    avg_acc, _ = full_val_accuracy(model, val_loader, device)
    print(f"Epoch {epoch:02} | Train Loss {train_loss:.1f} "
          f"| Val Loss {val_loss:.1f} | Acc {avg_acc:.4f} "
          f"| LR {optimizer.param_groups[0]['lr']:.1e}")

    # Written as "not >" so a NaN accuracy never triggers a save.
    if not avg_acc > best_acc:
        continue
    best_acc = avg_acc
    torch.save(model.state_dict(), "best_resampled.pt")
    print(f"  ✓ New best: {best_acc:.4f}")

print("Training complete. Best Avg Val Acc =", best_acc)
http://localhost:9993/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/site-packages/tqdm/auto.py#line=20

In [None]:
# 1) Define the function (one cell)
import torch
import numpy as np

def compute_precision_recall(model, loader, device):
    """Per-class precision and recall pooled over every labelled position.

    The model is switched to eval mode.  Predictions are the argmax over
    axis 2 of the model output; positions whose target is -1 are dropped
    before counting.

    Returns:
        ``(precision, recall)``: two float arrays of length 6, one entry per
        class 0..5.  An entry is NaN when its denominator is zero (class
        never predicted for precision, never present for recall).
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    with torch.no_grad():
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)
            predicted = model(features).argmax(dim=2)
            labelled = targets != -1
            pred_chunks.append(predicted[labelled].cpu().numpy().ravel())
            true_chunks.append(targets[labelled].cpu().numpy().ravel())

    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(true_chunks)

    # Start from NaN and fill in only the ratios that are defined.
    precision = np.full(6, np.nan)
    recall = np.full(6, np.nan)
    for cls in range(6):
        is_pred = y_pred == cls
        is_true = y_true == cls
        tp = (is_pred & is_true).sum()
        fp = (is_pred & ~is_true).sum()
        fn = (~is_pred & is_true).sum()
        if tp + fp > 0:
            precision[cls] = tp / (tp + fp)
        if tp + fn > 0:
            recall[cls] = tp / (tp + fn)
    return precision, recall


In [None]:
# 2) Call it in a fresh cell once training is done
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto `device`, so a checkpoint
# written on a GPU machine can also be restored on a CPU-only one.
state_dict = torch.load("best_resampled.pt", map_location=device)  # pick model
model.load_state_dict(state_dict)
model.to(device)

# Pooled per-class precision/recall on the validation set.
precision, recall = compute_precision_recall(model, val_loader, device)
for c in range(6):
    print(f"class {c}:  prec={precision[c]:.3f},  rec={recall[c]:.3f}")


In [None]:
import torch, numpy as np
from collections import defaultdict

def macro_precision_recall(model, loader, device, num_lines=59, num_classes=6):
    """
    Per-class precision and recall, macro-averaged over cell lines.

    Only the first ``num_lines`` lines of the model output and of the labels
    are evaluated, and positions labelled -1 are ignored.  Counts are
    accumulated per line over the whole loader; for each class the per-line
    ratios are then averaged over the lines that contain at least one true
    instance of that class.

    Two variants are reported:
      * exact      - a prediction counts only if it equals the true class;
      * within-one - for precision, a prediction of class c counts if the
        true class is within 1 of c; for recall, a true instance of class c
        counts if the predicted class is within 1 of c.

    Args:
        model: module mapping a batch to logits of shape (batch, lines, classes).
        loader: iterable of ``(x, y)`` batches, ``y`` of shape (batch, lines).
        device: device the inputs are moved to.
        num_lines: number of leading lines to evaluate.
        num_classes: number of classes (labels are 0 .. num_classes - 1).

    Returns:
        ``(prec, rec)``: two float arrays of shape (num_classes, 2) whose
        columns are [exact, within-one].  An entry is NaN when no line
        provides a defined value for that class.
    """
    model.eval()
    # Per-line count tables, one row per class.
    #   exact : TP, FP, FN
    #   within: precision hits, precision misses, recall hits, recall misses
    exact = defaultdict(lambda: np.zeros((num_classes, 3), dtype=int))
    within = defaultdict(lambda: np.zeros((num_classes, 4), dtype=int))

    with torch.no_grad():
        for x, y in loader:                         # y: (B, lines)
            x = x.to(device)
            logits = model(x)[:, :num_lines]        # keep the first num_lines lines
            preds = logits.argmax(2).cpu().numpy()
            y = y[:, :num_lines].cpu().numpy()

            # Slicing may return fewer than num_lines columns; never index past them.
            for line_idx in range(min(preds.shape[1], y.shape[1])):
                true_line = y[:, line_idx]
                labelled = true_line != -1
                yt = true_line[labelled]
                yp = preds[:, line_idx][labelled]
                if yt.size == 0:                    # line had no labels in this batch
                    continue

                for c in range(num_classes):
                    pred_c = yp == c
                    true_c = yt == c
                    exact[line_idx][c] += (
                        np.sum(pred_c & true_c),
                        np.sum(pred_c & ~true_c),
                        np.sum(~pred_c & true_c),
                    )

                    # Precision conditions on the prediction being c, recall
                    # on the truth being c, so each needs its own hit/miss pair.
                    true_near = np.abs(yt - c) <= 1
                    pred_near = np.abs(yp - c) <= 1
                    within[line_idx][c] += (
                        np.sum(pred_c & true_near),
                        np.sum(pred_c & ~true_near),
                        np.sum(true_c & pred_near),
                        np.sum(true_c & ~pred_near),
                    )

    def _mean_ignoring_nan(values):
        """Mean of the non-NaN entries, or NaN (without a warning) if there are none."""
        arr = np.asarray(values, dtype=float)
        arr = arr[~np.isnan(arr)]
        return float(arr.mean()) if arr.size else np.nan

    # macro-average over cell lines (skip lines with zero support)
    prec = np.full((num_classes, 2), np.nan)
    rec = np.full((num_classes, 2), np.nan)
    for c in range(num_classes):
        p_exact, r_exact, p_within, r_within = [], [], [], []
        for line in exact:
            tp, fp, fn = exact[line][c]
            if tp + fn == 0:    # no true instances of class c in this line
                continue
            p_hit, p_miss, r_hit, r_miss = within[line][c]
            p_exact.append(tp / (tp + fp) if tp + fp > 0 else np.nan)
            r_exact.append(tp / (tp + fn))
            p_within.append(p_hit / (p_hit + p_miss) if p_hit + p_miss > 0 else np.nan)
            # r_hit + r_miss is the number of true c's, i.e. tp + fn > 0 here.
            r_within.append(r_hit / (r_hit + r_miss))
        prec[c, 0] = _mean_ignoring_nan(p_exact)
        prec[c, 1] = _mean_ignoring_nan(p_within)
        rec[c, 0] = _mean_ignoring_nan(r_exact)
        rec[c, 1] = _mean_ignoring_nan(r_within)
    return prec, rec

# -----------------------------------------------------------
# Example usage on your validation set
# `model`, `val_loader` and `device` come from the earlier cells.
prec, rec = macro_precision_recall(model, val_loader, device)

# One row per class; columns: exact precision, within-one precision,
# exact recall, within-one recall, all shown as percentages.
headers = ["Exact P", "±1 P", "Exact R", "±1 R"]
print(f"{'Class':>6}  {headers[0]:>8}  {headers[1]:>8}  "
      f"{headers[2]:>8}  {headers[3]:>8}")
for c in range(6):
    # Column 0 of prec/rec is the exact metric, column 1 the within-one metric.
    print(f"{c:>6}  {prec[c,0]*100:8.1f}%  {prec[c,1]*100:8.1f}%  "
          f"{rec[c,0]*100:8.1f}%  {rec[c,1]*100:8.1f}%")


In [None]:
def micro_precision_recall(model, loader, device, num_lines=59, num_classes=6):
    """
    Per-class precision and recall pooled (micro-averaged) over all cell lines.

    Only the first ``num_lines`` lines of the model output and of the labels
    are used, and positions labelled -1 are ignored.  All remaining
    (prediction, truth) pairs are pooled before counting.

    Two variants are reported:
      * exact      - a prediction counts only if it equals the true class;
      * within-one - for precision, a prediction of class c counts if the
        true class is within 1 of c; for recall, a true instance of class c
        counts if the predicted class is within 1 of c.

    Args:
        model: module mapping a batch to logits of shape (batch, lines, classes).
        loader: iterable of ``(x, y)`` batches, ``y`` of shape (batch, lines).
        device: device the inputs are moved to.
        num_lines: number of leading lines to evaluate.
        num_classes: number of classes (labels are 0 .. num_classes - 1).

    Returns:
        ``(precision, recall, precision_within1, recall_within1)``: four
        float arrays of length ``num_classes``.  An entry is NaN when its
        denominator is zero; with an empty loader every entry is NaN.
    """
    model.eval()
    y_true_all = []
    y_pred_all = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)[:, :num_lines]              # (B, lines, classes)
            preds = logits.argmax(dim=2).cpu().numpy()    # (B, lines)
            y = y[:, :num_lines].cpu().numpy()            # (B, lines)

            mask = (y != -1)
            y_true_all.append(y[mask])
            y_pred_all.append(preds[mask])

    precision = np.full(num_classes, np.nan)
    recall = np.full(num_classes, np.nan)
    precision1 = np.full(num_classes, np.nan)
    recall1 = np.full(num_classes, np.nan)

    # np.concatenate rejects an empty list, so an empty loader returns all-NaN.
    if not y_true_all:
        return precision, recall, precision1, recall1

    y_true = np.concatenate(y_true_all)
    y_pred = np.concatenate(y_pred_all)

    for c in range(num_classes):
        pred_c = y_pred == c
        true_c = y_true == c
        n_pred = np.sum(pred_c)     # TP + FP: everything predicted as c
        n_true = np.sum(true_c)     # TP + FN: everything that truly is c
        tp = np.sum(pred_c & true_c)

        # Precision conditions on the prediction, recall on the truth; the
        # within-one variants only relax what counts as a hit.
        hit_p1 = np.sum(pred_c & (np.abs(y_true - c) <= 1))
        hit_r1 = np.sum(true_c & (np.abs(y_pred - c) <= 1))

        if n_pred > 0:
            precision[c] = tp / n_pred
            precision1[c] = hit_p1 / n_pred
        if n_true > 0:
            recall[c] = tp / n_true
            recall1[c] = hit_r1 / n_true

    return precision, recall, precision1, recall1


In [None]:
# `model`, `val_loader` and `device` come from the earlier cells.  The four
# arrays are, in order: exact precision, exact recall, within-one precision,
# within-one recall (one entry per class).
prec, rec, prec1, rec1 = micro_precision_recall(model, val_loader, device)

# One row per class, values shown as percentages.
headers = ["Exact P", "±1 P", "Exact R", "±1 R"]
print(f"{'Class':>6}  {headers[0]:>8}  {headers[1]:>8}  "
      f"{headers[2]:>8}  {headers[3]:>8}")
for c in range(6):
    print(f"{c:>6}  {prec[c]*100:8.1f}%  {prec1[c]*100:8.1f}%  "
          f"{rec[c]*100:8.1f}%  {rec1[c]*100:8.1f}%")
