In [3]:
import gzip, json, torch
import torch.nn as nn
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import AllChem
# -- Morgan‑FP helpers (unchanged) -----------------------------------------
GEN0 = AllChem.GetMorganGenerator(radius=0, fpSize=2048)
GEN1 = AllChem.GetMorganGenerator(radius=1, fpSize=2048)
GEN2 = AllChem.GetMorganGenerator(radius=2, fpSize=2048)

def fp6144_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fp0 = torch.tensor(list(GEN0.GetFingerprint(mol)), dtype=torch.float32)
    fp1 = torch.tensor(list(GEN1.GetFingerprint(mol)), dtype=torch.float32)
    fp2 = torch.tensor(list(GEN2.GetFingerprint(mol)), dtype=torch.float32)
    return torch.cat([fp0, fp1, fp2])        # shape (6144,)

# -- NEW dataset class -----------------------------------------------------
class JSONLPotencyDataset(Dataset):
    """
    Loads the new *.jsonl.gz split files created in the parsing notebook.
    Each line contains {"smiles": ..., "label_vector": [...] }.
    """
    def __init__(self, path_to_jsonl_gz):
        with gzip.open(path_to_jsonl_gz, "rt") as f:
            self.records = [(rec["smiles"], rec["label_vector"])
                            for rec in (json.loads(l) for l in f)]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        smiles, label_vec = self.records[idx]
        fp = fp6144_from_smiles(smiles)

        # fall‑through if RDKit fails on this SMILES
        while fp is None:
            idx = (idx + 1) % len(self.records)
            smiles, label_vec = self.records[idx]
            fp = fp6144_from_smiles(smiles)

        labels = torch.tensor(label_vec, dtype=torch.long)   # (60,)
        return fp, labels

In [4]:
class MultiLineMLP5(nn.Module):
    def __init__(self,
                 input_dim=6144,
                 hidden_dims=[1024, 1024, 512, 512, 256],
                 num_lines=60,
                 num_classes=6,
                 p_drop=0.3):
        super().__init__()

        self.bn_in = nn.BatchNorm1d(input_dim)

        layers = []
        dims = [input_dim] + hidden_dims
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers.extend([
                nn.Linear(d_in, d_out),
                nn.BatchNorm1d(d_out),
                nn.ReLU(),
                nn.Dropout(p_drop)
            ])
        self.shared = nn.Sequential(*layers)

        self.classifier = nn.Linear(hidden_dims[-1], num_lines * num_classes)

    def forward(self, x):
        x = self.bn_in(x)
        x = self.shared(x)
        logits = self.classifier(x)
        return logits.view(-1, 60, 6)



In [5]:
import torch, json, numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
from rdkit import RDLogger

# Silence RDKit chatter
RDLogger.DisableLog("rdApp.*")

# ---------- DATA ----------------------------------------------------------
train_dataset = JSONLPotencyDataset("train.jsonl.gz")
val_dataset   = JSONLPotencyDataset("val.jsonl.gz")

train_loader  = DataLoader(train_dataset, batch_size=64,
                           shuffle=True, num_workers=4, pin_memory=True)
val_loader    = DataLoader(val_dataset,   batch_size=64,
                           num_workers=4, pin_memory=True)

# ---------- MODEL ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MultiLineMLP5().to(device)


In [6]:
from tqdm.auto import tqdm
import torch.nn.functional as F

# ---------- 1.  class‑imbalance weights  ----------------------------------
hist = torch.zeros(6)
for _, labels in train_loader:
    mask = labels != -1
    for c in range(6):
        hist[c] += ((labels == c) & mask).sum()

weights = 1.0 / (hist + 1e-6)
weights = (weights / weights.sum()) * 6
weights = weights.to(torch.float32).to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=weights)

# ---------- 2.  full‑set validation accuracy ------------------------------
def full_val_accuracy(model, loader, device):
    correct = np.zeros(60); total = np.zeros(60)
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            p    = model(x).argmax(2)
            m    = y != -1
            correct += ((p == y) & m).sum(0).cpu().numpy()
            total   += m.sum(0).cpu().numpy()
    accs = [c / t if t > 0 else None for c, t in zip(correct, total)]
    return float(np.nanmean(accs)), accs

# ---------- 3.  optimiser, scheduler, loop --------------------------------
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

best_acc = 0.0
num_epochs = 40
for epoch in range(1, num_epochs + 1):
    # TRAIN
    model.train()
    epoch_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss   = criterion(logits.view(-1, 6), y.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        epoch_loss += loss.item()

    # VALIDATE
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(
                model(x).view(-1, 6), y.view(-1)
            ).item()

    scheduler.step(val_loss)
    avg_acc, _ = full_val_accuracy(model, val_loader, device)
    print(f"Epoch {epoch:02}/{num_epochs} | "
          f"Train Loss: {epoch_loss:.1f} | Val Loss: {val_loss:.1f} | "
          f"Avg Val Acc: {avg_acc:.4f} | "
          f"LR: {optimizer.param_groups[0]['lr']:.1e}")

    if avg_acc > best_acc:
        best_acc = avg_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"  ✓ saved new best ({best_acc:.4f})")

print(f"Done. Best Avg Val Acc = {best_acc:.4f}")


  from .autonotebook import tqdm as notebook_tqdm
                                                                                           

Epoch 01/40 | Train Loss: 1082.8 | Val Loss: 158.7 | Avg Val Acc: 0.3560 | LR: 5.0e-04
  ✓ saved new best (0.3560)


                                                                                           

Epoch 02/40 | Train Loss: 913.2 | Val Loss: 171.2 | Avg Val Acc: 0.3709 | LR: 5.0e-04
  ✓ saved new best (0.3709)


                                                                                           

Epoch 03/40 | Train Loss: 844.2 | Val Loss: 184.7 | Avg Val Acc: 0.3666 | LR: 5.0e-04


                                                                                           

Epoch 04/40 | Train Loss: 806.9 | Val Loss: 173.7 | Avg Val Acc: 0.3578 | LR: 2.5e-04


                                                                                           

Epoch 05/40 | Train Loss: 727.3 | Val Loss: 195.9 | Avg Val Acc: 0.3676 | LR: 2.5e-04


                                                                                           

Epoch 06/40 | Train Loss: 697.2 | Val Loss: 194.4 | Avg Val Acc: 0.3735 | LR: 2.5e-04
  ✓ saved new best (0.3735)


                                                                                           

Epoch 07/40 | Train Loss: 676.0 | Val Loss: 190.0 | Avg Val Acc: 0.3504 | LR: 1.3e-04


                                                                                           

Epoch 08/40 | Train Loss: 630.7 | Val Loss: 198.6 | Avg Val Acc: 0.3663 | LR: 1.3e-04


                                                                                           

Epoch 09/40 | Train Loss: 609.4 | Val Loss: 184.7 | Avg Val Acc: 0.4004 | LR: 1.3e-04
  ✓ saved new best (0.4004)


                                                                                           

Epoch 10/40 | Train Loss: 596.5 | Val Loss: 191.0 | Avg Val Acc: 0.3908 | LR: 6.3e-05


                                                                                           

Epoch 11/40 | Train Loss: 573.1 | Val Loss: 195.9 | Avg Val Acc: 0.3993 | LR: 6.3e-05


                                                                                           

Epoch 12/40 | Train Loss: 563.2 | Val Loss: 206.4 | Avg Val Acc: 0.3773 | LR: 6.3e-05


                                                                                           

Epoch 13/40 | Train Loss: 553.6 | Val Loss: 199.2 | Avg Val Acc: 0.3936 | LR: 3.1e-05


                                                                                           

Epoch 14/40 | Train Loss: 542.3 | Val Loss: 195.8 | Avg Val Acc: 0.3957 | LR: 3.1e-05


                                                                                           

Epoch 15/40 | Train Loss: 538.5 | Val Loss: 204.9 | Avg Val Acc: 0.3939 | LR: 3.1e-05


                                                                                           

Epoch 16/40 | Train Loss: 532.0 | Val Loss: 190.3 | Avg Val Acc: 0.4060 | LR: 1.6e-05
  ✓ saved new best (0.4060)


                                                                                           

Epoch 17/40 | Train Loss: 526.8 | Val Loss: 199.5 | Avg Val Acc: 0.3907 | LR: 1.6e-05


                                                                                           

Epoch 18/40 | Train Loss: 523.8 | Val Loss: 193.4 | Avg Val Acc: 0.4123 | LR: 1.6e-05
  ✓ saved new best (0.4123)


                                                                                           

Epoch 19/40 | Train Loss: 523.5 | Val Loss: 197.1 | Avg Val Acc: 0.4083 | LR: 7.8e-06


                                                                                           

Epoch 20/40 | Train Loss: 519.6 | Val Loss: 202.9 | Avg Val Acc: 0.4047 | LR: 7.8e-06


                                                                                           

Epoch 21/40 | Train Loss: 516.9 | Val Loss: 186.5 | Avg Val Acc: 0.4226 | LR: 7.8e-06
  ✓ saved new best (0.4226)


                                                                                           

Epoch 22/40 | Train Loss: 518.2 | Val Loss: 194.7 | Avg Val Acc: 0.4059 | LR: 3.9e-06


                                                                                           

Epoch 23/40 | Train Loss: 515.1 | Val Loss: 192.1 | Avg Val Acc: 0.4229 | LR: 3.9e-06
  ✓ saved new best (0.4229)


                                                                                           

Epoch 24/40 | Train Loss: 514.6 | Val Loss: 192.8 | Avg Val Acc: 0.4235 | LR: 3.9e-06
  ✓ saved new best (0.4235)


                                                                                           

Epoch 25/40 | Train Loss: 512.2 | Val Loss: 185.2 | Avg Val Acc: 0.4129 | LR: 2.0e-06


                                                                                           

Epoch 26/40 | Train Loss: 513.6 | Val Loss: 199.2 | Avg Val Acc: 0.4135 | LR: 2.0e-06


                                                                                           

Epoch 27/40 | Train Loss: 514.1 | Val Loss: 191.1 | Avg Val Acc: 0.4164 | LR: 2.0e-06


                                                                                           

Epoch 28/40 | Train Loss: 513.3 | Val Loss: 194.2 | Avg Val Acc: 0.4194 | LR: 9.8e-07


                                                                                           

Epoch 29/40 | Train Loss: 511.5 | Val Loss: 198.4 | Avg Val Acc: 0.3885 | LR: 9.8e-07


                                                                                           

Epoch 30/40 | Train Loss: 512.9 | Val Loss: 196.4 | Avg Val Acc: 0.4184 | LR: 9.8e-07


                                                                                           

Epoch 31/40 | Train Loss: 511.9 | Val Loss: 206.3 | Avg Val Acc: 0.4170 | LR: 4.9e-07


                                                                                           

Epoch 32/40 | Train Loss: 511.3 | Val Loss: 189.8 | Avg Val Acc: 0.4167 | LR: 4.9e-07


Exception in thread Thread-136:██████████████████▋       | 603/730 [00:42<00:07, 17.77it/s]
Traceback (most recent call last):
  File "/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
    do_one_step()
  File "/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/nbilic/miniconda3/envs/Nandos/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _F

KeyboardInterrupt: 