In [7]:
import gzip, json, torch
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import AllChem

GEN0 = AllChem.GetMorganGenerator(radius=0, fpSize=2048)
GEN1 = AllChem.GetMorganGenerator(radius=1, fpSize=2048)
GEN2 = AllChem.GetMorganGenerator(radius=2, fpSize=2048)

def fp6144_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    fp0 = torch.tensor(list(GEN0.GetFingerprint(mol)), dtype=torch.float32)
    fp1 = torch.tensor(list(GEN1.GetFingerprint(mol)), dtype=torch.float32)
    fp2 = torch.tensor(list(GEN2.GetFingerprint(mol)), dtype=torch.float32)
    return torch.cat([fp0, fp1, fp2])        # shape (6144,)

class JSONLPotencyDataset(Dataset):
    """
    Loads the new *.jsonl.gz split files created in the parsing notebook.
    Each line contains {"smiles": ..., "label_vector": [...] }.
    """
    def __init__(self, path_to_jsonl_gz):
        with gzip.open(path_to_jsonl_gz, "rt") as f:
            self.records = [(rec["smiles"], rec["label_vector"])
                            for rec in (json.loads(l) for l in f)]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        smiles, label_vec = self.records[idx]
        fp = fp6144_from_smiles(smiles)

        # fall‑through if RDKit fails on this SMILES
        while fp is None:
            idx = (idx + 1) % len(self.records)
            smiles, label_vec = self.records[idx]
            fp = fp6144_from_smiles(smiles)

        labels = torch.tensor(label_vec, dtype=torch.long)   # (60,)
        return fp, labels




In [8]:
import torch.nn as nn
import torch.nn.functional as F

class MultiLineMLP(nn.Module):
    def __init__(self,
                 input_dim=6144,
                 hidden_dim=768,     # ↑ from 256
                 num_lines=60,
                 num_classes=6):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(input_dim)           # NEW

        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.6),
        )

        # one big classifier → (batch, 60, 6)
        self.classifier = nn.Linear(hidden_dim, num_lines * num_classes)

    def forward(self, x):                                # x: (batch, 6144)
        x = self.bn_in(x)
        feat = self.shared(x)                            # (batch, 768)
        logits = self.classifier(feat)                   # (batch, 360)
        return logits.view(-1, 60, 6)                    # (batch, 60, 6)




In [9]:
import torch, json, numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
from rdkit import RDLogger

# Silence RDKit chatter
RDLogger.DisableLog("rdApp.*")

train_dataset = JSONLPotencyDataset("train.jsonl.gz")
val_dataset   = JSONLPotencyDataset("val.jsonl.gz")

train_loader  = DataLoader(train_dataset, batch_size=64,
                           shuffle=True, num_workers=4, pin_memory=True)
val_loader    = DataLoader(val_dataset,   batch_size=64,
                           num_workers=4, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MultiLineMLP().to(device)   # the new BN + shared‑head version


In [10]:
# TRAINING cell (class weights, full accuracy, scheduler, checkpoint)
from tqdm.auto import tqdm
import torch.nn.functional as F

# ---------- 1.  class‑imbalance weights  ----------------------------------
hist = torch.zeros(6)
for _, labels in train_loader:
    mask = labels != -1
    for c in range(6):
        hist[c] += ((labels == c) & mask).sum()

weights = 1.0 / (hist + 1e-6)
weights = (weights / weights.sum()) * 6
weights = weights.to(torch.float32).to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=-1, weight=weights)

# ---------- 2.  full‑set validation accuracy ------------------------------
def full_val_accuracy(model, loader, device):
    correct = np.zeros(60); total = np.zeros(60)
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            p    = model(x).argmax(2)
            m    = y != -1
            correct += ((p == y) & m).sum(0).cpu().numpy()
            total   += m.sum(0).cpu().numpy()
    accs = [c / t if t > 0 else None for c, t in zip(correct, total)]
    return float(np.nanmean(accs)), accs

# ---------- 3.  optimiser, scheduler, loop --------------------------------
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

best_acc = 0.0
num_epochs = 40
for epoch in range(1, num_epochs + 1):
    # TRAIN
    model.train()
    epoch_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss   = criterion(logits.view(-1, 6), y.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        epoch_loss += loss.item()

    # VALIDATE
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            val_loss += criterion(
                model(x).view(-1, 6), y.view(-1)
            ).item()

    scheduler.step(val_loss)
    avg_acc, _ = full_val_accuracy(model, val_loader, device)
    print(f"Epoch {epoch:02}/{num_epochs} | "
          f"Train Loss: {epoch_loss:.1f} | Val Loss: {val_loss:.1f} | "
          f"Avg Val Acc: {avg_acc:.4f} | "
          f"LR: {optimizer.param_groups[0]['lr']:.1e}")

    if avg_acc > best_acc:
        best_acc = avg_acc
        torch.save(model.state_dict(), "best_model.pt")
        print(f"  ✓ saved new best ({best_acc:.4f})")

print(f"Done. Best Avg Val Acc = {best_acc:.4f}")


                                                                                

Epoch 01/40 | Train Loss: 1058.4 | Val Loss: 153.6 | Avg Val Acc: 0.4358 | LR: 5.0e-04
  ✓ saved new best (0.4358)


                                                                                

Epoch 02/40 | Train Loss: 861.9 | Val Loss: 178.1 | Avg Val Acc: 0.4139 | LR: 5.0e-04


                                                                                

Epoch 03/40 | Train Loss: 768.5 | Val Loss: 171.1 | Avg Val Acc: 0.4650 | LR: 5.0e-04
  ✓ saved new best (0.4650)


                                                                                

Epoch 04/40 | Train Loss: 704.2 | Val Loss: 207.4 | Avg Val Acc: 0.4658 | LR: 2.5e-04
  ✓ saved new best (0.4658)


                                                                                

Epoch 05/40 | Train Loss: 623.3 | Val Loss: 202.4 | Avg Val Acc: 0.4676 | LR: 2.5e-04
  ✓ saved new best (0.4676)


                                                                                

Epoch 06/40 | Train Loss: 584.8 | Val Loss: 206.7 | Avg Val Acc: 0.4653 | LR: 2.5e-04


                                                                                

Epoch 07/40 | Train Loss: 561.0 | Val Loss: 214.3 | Avg Val Acc: 0.4779 | LR: 1.3e-04
  ✓ saved new best (0.4779)


                                                                                

Epoch 08/40 | Train Loss: 522.4 | Val Loss: 220.3 | Avg Val Acc: 0.4859 | LR: 1.3e-04
  ✓ saved new best (0.4859)


                                                                                

Epoch 09/40 | Train Loss: 506.1 | Val Loss: 220.8 | Avg Val Acc: 0.4825 | LR: 1.3e-04


                                                                                

Epoch 10/40 | Train Loss: 497.7 | Val Loss: 221.1 | Avg Val Acc: 0.4891 | LR: 6.3e-05
  ✓ saved new best (0.4891)


                                                                                

Epoch 11/40 | Train Loss: 476.5 | Val Loss: 222.9 | Avg Val Acc: 0.4908 | LR: 6.3e-05
  ✓ saved new best (0.4908)


                                                                                

Epoch 12/40 | Train Loss: 469.1 | Val Loss: 223.5 | Avg Val Acc: 0.4885 | LR: 6.3e-05


                                                                                

Epoch 13/40 | Train Loss: 463.8 | Val Loss: 227.9 | Avg Val Acc: 0.4893 | LR: 3.1e-05


                                                                                

Epoch 14/40 | Train Loss: 457.9 | Val Loss: 226.5 | Avg Val Acc: 0.4904 | LR: 3.1e-05


                                                                                

Epoch 15/40 | Train Loss: 452.1 | Val Loss: 235.3 | Avg Val Acc: 0.4858 | LR: 3.1e-05


                                                                                

Epoch 16/40 | Train Loss: 450.3 | Val Loss: 224.0 | Avg Val Acc: 0.4924 | LR: 1.6e-05
  ✓ saved new best (0.4924)


                                                                                

Epoch 17/40 | Train Loss: 446.7 | Val Loss: 218.2 | Avg Val Acc: 0.4928 | LR: 1.6e-05
  ✓ saved new best (0.4928)


                                                                                

Epoch 18/40 | Train Loss: 444.1 | Val Loss: 226.8 | Avg Val Acc: 0.5004 | LR: 1.6e-05
  ✓ saved new best (0.5004)


                                                                                

Epoch 19/40 | Train Loss: 443.1 | Val Loss: 227.5 | Avg Val Acc: 0.4936 | LR: 7.8e-06


                                                                                

Epoch 20/40 | Train Loss: 440.3 | Val Loss: 237.7 | Avg Val Acc: 0.4945 | LR: 7.8e-06


                                                                                

Epoch 21/40 | Train Loss: 440.9 | Val Loss: 233.6 | Avg Val Acc: 0.4957 | LR: 7.8e-06


                                                                                

Epoch 22/40 | Train Loss: 439.4 | Val Loss: 231.9 | Avg Val Acc: 0.4986 | LR: 3.9e-06


                                                                                

Epoch 23/40 | Train Loss: 436.0 | Val Loss: 229.3 | Avg Val Acc: 0.4946 | LR: 3.9e-06


                                                                                

Epoch 24/40 | Train Loss: 437.8 | Val Loss: 231.6 | Avg Val Acc: 0.4950 | LR: 3.9e-06


                                                                                

Epoch 25/40 | Train Loss: 438.3 | Val Loss: 236.5 | Avg Val Acc: 0.4987 | LR: 2.0e-06


                                                                                

Epoch 26/40 | Train Loss: 436.2 | Val Loss: 227.5 | Avg Val Acc: 0.4862 | LR: 2.0e-06


                                                                                

Epoch 27/40 | Train Loss: 436.0 | Val Loss: 222.5 | Avg Val Acc: 0.4970 | LR: 2.0e-06


                                                                                

Epoch 28/40 | Train Loss: 436.9 | Val Loss: 222.2 | Avg Val Acc: 0.4934 | LR: 9.8e-07


                                                                                

Epoch 29/40 | Train Loss: 436.7 | Val Loss: 226.5 | Avg Val Acc: 0.4927 | LR: 9.8e-07


                                                                                

Epoch 30/40 | Train Loss: 436.9 | Val Loss: 221.3 | Avg Val Acc: 0.4962 | LR: 9.8e-07


                                                                                

Epoch 31/40 | Train Loss: 435.1 | Val Loss: 221.8 | Avg Val Acc: 0.5009 | LR: 4.9e-07
  ✓ saved new best (0.5009)


                                                                                

Epoch 32/40 | Train Loss: 437.5 | Val Loss: 226.9 | Avg Val Acc: 0.4997 | LR: 4.9e-07


                                                                                

Epoch 33/40 | Train Loss: 436.0 | Val Loss: 237.5 | Avg Val Acc: 0.4902 | LR: 4.9e-07


                                                                                

Epoch 34/40 | Train Loss: 434.3 | Val Loss: 238.0 | Avg Val Acc: 0.4957 | LR: 2.4e-07


                                                                                

Epoch 35/40 | Train Loss: 435.6 | Val Loss: 228.2 | Avg Val Acc: 0.4929 | LR: 2.4e-07


                                                                                

Epoch 36/40 | Train Loss: 436.3 | Val Loss: 240.6 | Avg Val Acc: 0.4962 | LR: 2.4e-07


                                                                                

Epoch 37/40 | Train Loss: 435.8 | Val Loss: 223.1 | Avg Val Acc: 0.5047 | LR: 1.2e-07
  ✓ saved new best (0.5047)


                                                                                

Epoch 38/40 | Train Loss: 436.1 | Val Loss: 227.3 | Avg Val Acc: 0.4970 | LR: 1.2e-07


                                                                                

Epoch 39/40 | Train Loss: 437.0 | Val Loss: 240.2 | Avg Val Acc: 0.4975 | LR: 1.2e-07


                                                                                

Epoch 40/40 | Train Loss: 436.0 | Val Loss: 240.7 | Avg Val Acc: 0.4953 | LR: 6.1e-08
Done. Best Avg Val Acc = 0.5047


In [None]:
# OLD Training
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from rdkit import RDLogger

# Hide all RDKit warnings (keep “error” if you still want fatal messages)
RDLogger.DisableLog("rdApp.*")     # most common choice
# RDLogger.DisableLog("rdApp.*")         # silence everything

def accuracy_by_cell_line(logits, labels):
    """
    logits: (batch, 60, 6)
    labels: (batch, 60)
    """
    preds = logits.argmax(dim=2)          # (batch, 60)
    correct = ((preds == labels) & (labels != -1))
    total   = (labels != -1)

    correct_by_line = correct.sum(dim=0).cpu().numpy()   # length 60
    total_by_line   = total.sum(dim=0).cpu().numpy()

    accuracies = [
        (c / t) if t > 0 else None
        for c, t in zip(correct_by_line, total_by_line)
    ]
    return accuracies



# Load datasets
train_dataset = JSONLPotencyDataset("train.jsonl.gz")
val_dataset   = JSONLPotencyDataset("val.jsonl.gz")


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

# Initialize model, loss, optimizer
model = MultiLineMLP()
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(),
                       lr=5e-4,                # ↓ from 1e‑3
                       weight_decay=1e-4)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,      
)


criterion = nn.CrossEntropyLoss(ignore_index=-1)

num_epochs = 10
for epoch in range(num_epochs):
    # -------------------- TRAIN --------------------
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        logits = model(inputs)                   # (B, 60, 6)
        loss = criterion(
            logits.view(-1, 6),                  # (B*60, 6)
            labels.view(-1)                      # (B*60,)
        )

        if torch.isnan(loss):
            raise RuntimeError(f"NaN detected at epoch {epoch}")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        running_loss += loss.item()

    # ------------------ VALIDATION -----------------
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            val_loss += criterion(
                logits.view(-1, 6),
                labels.view(-1)
            ).item()

    scheduler.step(val_loss)

    # accuracy over the **last batch** (quick view)
    val_accs = accuracy_by_cell_line(logits, labels)
    avg_val_acc = (
        sum(a for a in val_accs if a is not None) /
        len([a for a in val_accs if a is not None])
    )

    print(f"Epoch {epoch+1}/{num_epochs} | "
          f"Train Loss: {running_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Avg Val Acc: {avg_val_acc:.4f} | "
          f"LR: {optimizer.param_groups[0]['lr']:.2e}")


In [19]:
import os, json

sample_path = os.path.join("train", os.listdir("train")[0])
with open(sample_path) as f:
    rec = json.load(f)

print("File:", sample_path)
print("Keys:", rec.keys())
print("SMILES:", rec.get("SMILES"))
print("Label:", rec.get("label"))
print("Label type:", type(rec.get("label")))
if isinstance(rec["label"], list):
    print("Label length:", len(rec["label"]))


File: train/787283.json
Keys: dict_keys(['NSC', 'SMILES', 'mol_concentration', 'cancer_type', 'potency', 'label'])
SMILES: CCC1=Nc2cc(\C=C\C(=O)NO)ccc2C(=O)N1CCc3ccc(OC)cc3
Label: 2
Label type: <class 'int'>
