In [None]:
import gzip, json, torch
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import AllChem

# Morgan fingerprint generators for radii 0, 1 and 2, each producing a
# 2048-bit vector; fp6144_from_smiles concatenates their outputs in this order.
GEN0, GEN1, GEN2 = (
    AllChem.GetMorganGenerator(radius=r, fpSize=2048) for r in (0, 1, 2)
)

def fp6144_from_smiles(smiles):
    """
    Featurise a SMILES string as three concatenated Morgan bit fingerprints.

    The molecule is parsed with RDKit and fingerprinted with GEN0, GEN1 and
    GEN2 (in that order). Each bit vector is converted to a float32 tensor and
    the three tensors are joined end to end.

    Returns
    -------
    None when RDKit cannot parse `smiles`; otherwise a 1-D float32 tensor
    (3 x 2048 = 6144 entries with the generators defined in this notebook).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    pieces = [
        torch.tensor(list(gen.GetFingerprint(mol)), dtype=torch.float32)
        for gen in (GEN0, GEN1, GEN2)
    ]
    return torch.cat(pieces)

class JSONLPotencyDataset(Dataset):
    """
    Loads the new *.jsonl.gz split files created in the parsing notebook.
    Each line contains {"smiles": ..., "label_vector": [...] }.

    Items are ``(fingerprint, labels)`` pairs: the fingerprint comes from
    ``fp6144_from_smiles`` and ``labels`` is the record's label vector as a
    ``torch.long`` tensor.
    """
    def __init__(self, path_to_jsonl_gz):
        """Read every non-blank line of the gzip-compressed JSONL file into memory."""
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with gzip.open(path_to_jsonl_gz, "rt", encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline) are skipped instead of
            # crashing json.loads.
            self.records = [(rec["smiles"], rec["label_vector"])
                            for rec in (json.loads(line) for line in f if line.strip())]

    def __len__(self):
        """Number of records read from the file (parsable or not)."""
        return len(self.records)

    def __getitem__(self, idx):
        """
        Return ``(fingerprint, labels)`` for record `idx`.

        If RDKit cannot parse the SMILES at `idx`, the following records are
        tried in order (wrapping around at the end) and the first parsable one
        is returned in its place, labels included.

        Raises
        ------
        RuntimeError
            If no record in the dataset has a parsable SMILES string. Without
            this bound the fallback scan would never terminate.
        """
        n_records = len(self.records)
        smiles, label_vec = self.records[idx]
        fp = fp6144_from_smiles(smiles)

        # fall-through if RDKit fails on this SMILES; every record is tried at
        # most once so a file with only bad SMILES fails loudly.
        tried = 1
        while fp is None:
            if tried >= n_records:
                raise RuntimeError(
                    "JSONLPotencyDataset: no record has a SMILES string that "
                    "RDKit can parse")
            idx = (idx + 1) % n_records
            smiles, label_vec = self.records[idx]
            fp = fp6144_from_smiles(smiles)
            tried += 1

        labels = torch.tensor(label_vec, dtype=torch.long)
        return fp, labels




In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MultiLineMLP(nn.Module):
    """
    Shared-trunk MLP that emits one `num_classes`-way classification per line.

    Architecture: BatchNorm1d on the input -> two Linear/ReLU/Dropout layers
    (dropout 0.3 then 0.6) -> a single Linear head of width
    ``num_lines * num_classes`` whose output is reshaped to
    ``(batch, num_lines, num_classes)``.

    Parameters
    ----------
    input_dim : int
        Length of the input feature vector.
    hidden_dim : int
        Width of both hidden layers.
    num_lines : int
        Number of independent classification tasks (second output axis).
    num_classes : int
        Number of classes per task (last output axis).
    """
    def __init__(self,
                 input_dim=6144,
                 hidden_dim=768,     # ↑ from 256
                 num_lines=60,
                 num_classes=6):
        super().__init__()
        # Remembered so forward() reshapes with the configured sizes rather
        # than hard-coded ones. Plain ints: they do not enter the state_dict.
        self.num_lines = num_lines
        self.num_classes = num_classes

        self.bn_in = nn.BatchNorm1d(input_dim)           # NEW

        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.6),
        )

        # one big classifier → (batch, num_lines * num_classes)
        self.classifier = nn.Linear(hidden_dim, num_lines * num_classes)

    def forward(self, x):                                # x: (batch, input_dim)
        """Return logits of shape ``(batch, num_lines, num_classes)``."""
        x = self.bn_in(x)
        feat = self.shared(x)                            # (batch, hidden_dim)
        logits = self.classifier(feat)                   # (batch, num_lines * num_classes)
        return logits.view(-1, self.num_lines, self.num_classes)




In [None]:
import torch, json, numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
from rdkit import RDLogger

# Silence RDKit chatter
RDLogger.DisableLog("rdApp.*")

# Seed PyTorch's global RNG so weight initialisation and the shuffling order
# drawn from it are repeatable from one run of the notebook to the next.
SEED = 42
torch.manual_seed(SEED)

# Loader settings gathered in one place so they are easy to tune.
BATCH_SIZE = 64
NUM_WORKERS = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = JSONLPotencyDataset("train.jsonl.gz")
val_dataset   = JSONLPotencyDataset("val.jsonl.gz")

# Pinned host memory only speeds up host→GPU copies. On a CPU-only machine it
# is of no use and recent PyTorch versions warn about it, so tie it to the device.
loader_kwargs = dict(batch_size=BATCH_SIZE,
                     num_workers=NUM_WORKERS,
                     pin_memory=(device.type == "cuda"))

train_loader  = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader    = DataLoader(val_dataset, **loader_kwargs)

model = MultiLineMLP().to(device)   # the new BN + shared‑head version


In [None]:
# TRAINING cell 
from tqdm.auto import tqdm
import torch.nn.functional as F

# class‑imbalance weights
# Count labels straight from the label vectors already held in memory.
# Iterating train_loader here would recompute every molecule's fingerprints
# (a full featurisation pass) only to throw them away.
# NOTE(review): records whose SMILES RDKit cannot parse are counted under their
# own labels here, whereas the dataset substitutes a neighbouring record for
# them at load time; the histogram can differ marginally from what the loader
# yields if such records exist.
hist = torch.zeros(6)
records_per_chunk = 8192            # bounds the size of the temporary tensor
for start in range(0, len(train_dataset.records), records_per_chunk):
    chunk_labels = torch.tensor(
        [vec for _, vec in train_dataset.records[start:start + records_per_chunk]],
        dtype=torch.long,
    )
    # Drop the -1 "missing" marker and anything outside the class range.
    in_range = chunk_labels[(chunk_labels >= 0) & (chunk_labels < hist.numel())]
    hist += torch.bincount(in_range, minlength=hist.numel()).to(hist.dtype)

# Inverse-frequency weights, rescaled so they sum to the number of classes.
weights = 1.0 / (hist + 1e-6)       # epsilon avoids division by zero
weights = (weights / weights.sum()) * hist.numel()
weights = weights.to(torch.float32).to(device)

# Label -1 marks a missing measurement and contributes nothing to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=weights)

# full‑set validation accuracy 
def full_val_accuracy(model, loader, device):
    """
    Compute per-line accuracy over every batch of `loader`.

    The model is switched to eval mode (and left there). Predictions are the
    argmax over the last axis of ``model(x)``; label positions equal to -1 are
    treated as missing and excluded from both numerator and denominator.

    Parameters
    ----------
    model : callable module returning logits of shape (batch, lines, classes).
    loader : iterable of ``(x, y)`` batches, `y` of shape (batch, lines).
    device : device the batches are moved to before the forward pass.

    Returns
    -------
    (avg, accs)
        `accs` has one entry per line: the accuracy as a float, or None when
        that line has no labelled sample. `avg` is the mean of the non-None
        entries, or NaN if there are none (including an empty loader, in
        which case `accs` is an empty list).
    """
    correct = total = None
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            p    = model(x).argmax(2)
            m    = y != -1
            hits = ((p == y) & m).sum(0).cpu().numpy()
            seen = m.sum(0).cpu().numpy()
            if total is None:
                # Size the accumulators from the data instead of assuming 60 lines.
                correct = np.zeros(hits.shape)
                total   = np.zeros(seen.shape)
            correct += hits
            total   += seen

    if total is None:                       # loader produced no batches
        return float("nan"), []

    accs = [float(c / t) if t > 0 else None for c, t in zip(correct, total)]
    # Average only the lines that were scored; np.nanmean cannot digest the
    # None placeholders and would raise TypeError.
    scored = [a for a in accs if a is not None]
    avg = float(np.mean(scored)) if scored else float("nan")
    return avg, accs

#  optimiser, scheduler, loop
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
# Halve the learning rate when the validation loss has not improved for
# more than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

best_acc = 0.0
num_epochs = 40
for epoch in range(1, num_epochs + 1):
    # TRAIN
    model.train()
    epoch_loss = 0.0            # sum of per-batch mean losses
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        # Skip batches that would break the update:
        #  * a single-sample batch makes BatchNorm1d raise in train mode;
        #  * a batch whose labels are all -1 gives a NaN mean-reduced loss
        #    under ignore_index, and one NaN backward pass would poison
        #    every weight.
        if x.size(0) < 2 or not (y != -1).any():
            continue
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        # Flatten (batch, lines, classes) → (batch*lines, classes) for the loss.
        loss   = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        epoch_loss += loss.item()

    # VALIDATE
    val_loss = 0.0              # sum of per-batch mean losses
    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            # A fully unlabelled batch would add NaN to val_loss and mislead
            # the plateau scheduler.
            if not (y != -1).any():
                continue
            x, y = x.to(device), y.to(device)
            logits = model(x)
            val_loss += criterion(
                logits.view(-1, logits.size(-1)), y.view(-1)
            ).item()

    scheduler.step(val_loss)
    avg_acc, _ = full_val_accuracy(model, val_loader, device)
    print(f"Epoch {epoch:02}/{num_epochs} | "
          f"Train Loss: {epoch_loss:.1f} | Val Loss: {val_loss:.1f} | "
          f"Avg Val Acc: {avg_acc:.4f} | "
          f"LR: {optimizer.param_groups[0]['lr']:.1e}")

    # Checkpoint whenever the mean per-line validation accuracy improves.
    if avg_acc > best_acc:
        best_acc = avg_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"  ✓ saved new best ({best_acc:.4f})")

print(f"Done. Best Avg Val Acc = {best_acc:.4f}")
