In [1]:
# 1. Imports and Setup
import gzip, json, torch, numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn, torch.optim as optim
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 2. Morgan Fingerprint Helper
# One Morgan fingerprint generator per radius (0, 1 and 2), each producing 2048 bits.
GEN0, GEN1, GEN2 = (
    AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=2048)
    for morgan_radius in (0, 1, 2)
)

def fp6144_from_smiles(s):
    """Featurise a SMILES string as three concatenated Morgan bit fingerprints.

    The fingerprints are produced by the module-level generators ``GEN0``,
    ``GEN1`` and ``GEN2`` and concatenated in that order.

    Args:
        s: SMILES string handed to ``Chem.MolFromSmiles``.

    Returns:
        A 1-D float32 tensor with the three fingerprints back to back
        (3 x 2048 = 6144 values with the generators defined in this notebook),
        or ``None`` when RDKit cannot parse ``s``.
    """
    m = Chem.MolFromSmiles(s)
    if m is None:
        return None
    # GetFingerprintAsNumPy fills a NumPy array directly; iterating the bit
    # vector with list() costs one Python-level access per bit, three times
    # per molecule, on every dataset access.
    bits = np.concatenate([gen.GetFingerprintAsNumPy(m) for gen in (GEN0, GEN1, GEN2)])
    return torch.from_numpy(bits.astype(np.float32))  # (6144,)


In [3]:
# 3. Dataset Class
class JSONLPotencyDataset(Dataset):
    """Dataset over a gzip-compressed JSON-lines file of (SMILES, label vector) records.

    Each non-blank line must be a JSON object with the keys ``'smiles'`` and
    ``'label_vector'``. Molecules are featurised lazily, on every
    ``__getitem__`` call, via ``fp6144_from_smiles``.
    """

    def __init__(self, path):
        """Read every record from ``path`` (a gzip-compressed text file) into memory."""
        self.records = []
        # Explicit encoding: the default for text mode depends on the locale.
        with gzip.open(path, 'rt', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                r = json.loads(line)
                self.records.append((r['smiles'], r['label_vector']))

    def __len__(self):
        """Number of records read from the file."""
        return len(self.records)

    def __getitem__(self, i):
        """Return ``(fingerprint, labels)`` for record ``i``.

        If the SMILES of record ``i`` cannot be featurised, the next record
        that can (wrapping around at the end) is returned instead, together
        with that record's own labels.

        Raises:
            IndexError: if ``i`` is out of range (this is also what terminates
                plain ``for item in dataset`` iteration).
            RuntimeError: if no record in the dataset can be featurised.
        """
        n = len(self.records)
        s, lbl = self.records[i]
        fp = fp6144_from_smiles(s)
        attempts = 1
        while fp is None:
            if attempts >= n:
                # Every record has been tried once; without this bound the
                # loop would spin forever.
                raise RuntimeError('no record in the dataset has a parseable SMILES string')
            i = (i + 1) % n
            s, lbl = self.records[i]
            fp = fp6144_from_smiles(s)
            attempts += 1
        return fp, torch.tensor(lbl, dtype=torch.long)


In [5]:
# 4. Model Definition
class MultiLineMLP(nn.Module):
    """Shared MLP trunk with one linear head emitting ``num_classes`` logits per line.

    Args:
        input_dim: size of the input feature vector.
        hidden_dims: widths of the hidden layers, in order. Each hidden layer
            is Linear -> BatchNorm1d -> ReLU -> Dropout.
        num_lines: number of independent classification targets.
        num_classes: number of classes per target.
        p_drop: dropout probability used after every hidden layer.
    """

    def __init__(self, input_dim=6144, hidden_dims=(1024, 1024, 512, 512, 256),
                 num_lines=60, num_classes=6, p_drop=0.3):
        super().__init__()
        self.num_lines = num_lines
        self.num_classes = num_classes
        self.bn = nn.BatchNorm1d(input_dim)
        layers = []
        dims = [input_dim, *hidden_dims]
        for a, b in zip(dims, dims[1:]):
            layers += [nn.Linear(a, b), nn.BatchNorm1d(b), nn.ReLU(), nn.Dropout(p_drop)]
        self.shared = nn.Sequential(*layers)
        # dims[-1] equals hidden_dims[-1] whenever hidden layers exist and falls
        # back to input_dim when hidden_dims is empty.
        self.classifier = nn.Linear(dims[-1], num_lines * num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to logits of shape (batch, num_lines, num_classes)."""
        x = self.bn(x)
        f = self.shared(x)
        out = self.classifier(f)
        # Reshape with the configured sizes; a hard-coded class count silently
        # mis-shapes the output whenever num_classes differs from it.
        return out.view(out.size(0), self.num_lines, self.num_classes)

# 4b. Training Options
# Run-level switches: hidden-layer layout, how training samples are drawn,
# and whether the loss is class-weighted.
MODEL_CONFIGS = {'5layer': [1024, 1024, 512, 512, 256]}  # layout name -> hidden layer widths
MODEL_CONFIG = MODEL_CONFIGS['5layer']  # always 5-layer

# One of: 'none', 'resampled', 'weighted_sampler'
SAMPLING_METHOD = 'weighted_sampler'

# True -> class weights are used in the loss; False -> plain loss
WEIGHTED_LOSS = True


In [14]:
# 5. Configuration
from rdkit import RDLogger

# Silence RDKit chatter
RDLogger.DisableLog("rdApp.*")

TRAIN_JSON = 'train.jsonl.gz'   # training records (gzip-compressed JSON lines)
VAL_JSON   = 'val.jsonl.gz'     # validation records
BATCH      = 64                 # mini-batch size for both loaders
LR         = 5e-4               # Adam learning rate
WD         = 1e-4               # Adam weight decay
EPOCHS     = 20                 # number of training epochs
DROP       = 0.3                # dropout probability passed to the model
SEED       = 42                 # RNG seed for weight init, dropout and sampling

# Weight initialisation, dropout and the (weighted) sampler all draw random
# numbers; seed them so a Restart & Run All repeats the same run.
# (Some GPU kernels can still be non-deterministic.)
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 6. Instantiate Dataset and DataLoader with Sampling
NUM_CLASSES  = 6    # class count used for bincount(minlength=...) and the class-weight rescaling
IGNORE_INDEX = -1   # label value left out of the statistics and passed to the loss as ignore_index

train_ds = JSONLPotencyDataset(TRAIN_JSON)
val_ds   = JSONLPotencyDataset(VAL_JSON)

_needs_sampler = SAMPLING_METHOD in ('resampled', 'weighted_sampler')

if _needs_sampler or WEIGHTED_LOSS:
    # Collect the training labels ONCE. Every train_ds[i] call featurises a
    # molecule, so separate passes for the class frequencies, the per-sample
    # weights and the loss weights would each repeat that work.
    all_labels = torch.stack([train_ds[i][1] for i in range(len(train_ds))])  # (N, num_lines)
    mask = all_labels != IGNORE_INDEX
    freq = torch.bincount(all_labels[mask], minlength=NUM_CLASSES).float()

if _needs_sampler:
    if SAMPLING_METHOD == 'resampled':
        # class weights inversely proportional to class freq
        per_class = 1.0 / (freq + 1e-6)
    else:
        # sqrt-inverse weighting, rescaled to mean 1
        per_class = 1.0 / torch.sqrt(freq + 1e-6)
        per_class /= per_class.mean()
    # Sample weight = mean class weight over the sample's valid labels.
    # A sample with no valid label gets weight 0 (it would contribute nothing
    # to the loss) instead of the NaN an empty mean produces, which would make
    # the sampler fail.
    valid_per_sample = mask.sum(dim=1)
    weight_sum = (per_class[all_labels.clamp(min=0)] * mask).sum(dim=1)
    sample_weights = (weight_sum / valid_per_sample.clamp(min=1)).tolist()
    sampler = torch.utils.data.WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler)
else:
    train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True)

val_loader = DataLoader(val_ds, batch_size=BATCH)

# 7. Loss & Optimizer Setup
# instantiate model
dims = MODEL_CONFIG
model = MultiLineMLP(hidden_dims=dims, p_drop=DROP).to(device)

# criterion
if WEIGHTED_LOSS:
    # inverse-frequency class weights, rescaled to sum to NUM_CLASSES
    class_weights = (1.0 / (freq + 1e-6))
    class_weights = (class_weights / class_weights.sum()) * float(NUM_CLASSES)
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, weight=class_weights.to(device))
else:
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)

# 8. Quick Test
# Draw a single batch from the training loader and report the tensor shapes.
first_batch = next(iter(train_loader))
test_x, test_y = first_batch
print('Test batch loaded:', test_x.shape, test_y.shape)


Test batch loaded: torch.Size([64, 6144]) torch.Size([64, 60])


In [None]:
# Train for EPOCHS epochs; after each epoch measure validation accuracy over
# the labelled entries (label != -1) and checkpoint the best model so far.
best_acc = 0.0
# The checkpoint name encodes the sampling / loss configuration of this run.
ckpt = f"best_{SAMPLING_METHOD}_{'weighted' if WEIGHTED_LOSS else 'unweighted'}.pt"
for epoch in range(1, EPOCHS + 1):
    # ---- training pass ----
    model.train()
    tloss = 0.0
    n_batches = 0
    for x, y in tqdm(train_loader, desc=f'Epoch {epoch}'):
        x, y = x.to(device), y.to(device)
        if not (y != -1).any():
            # Every label in this batch is ignored: the mean-reduced loss
            # would be NaN and back-propagating it would corrupt the weights.
            continue
        optimizer.zero_grad()
        logits = model(x)
        # Flatten (batch, lines, classes) -> (batch*lines, classes); take the
        # class count from the logits rather than hard-coding it.
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        tloss += loss.item()
        n_batches += 1
    avg_loss = tloss / max(n_batches, 1)

    # ---- validation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for vx, vy in val_loader:
            vx, vy = vx.to(device), vy.to(device)
            preds = model(vx).argmax(2)
            labelled = (vy != -1)
            correct += ((preds == vy) & labelled).sum().item()
            total += labelled.sum().item()
    val_acc = correct / total if total > 0 else 0.0
    print(f'Epoch {epoch} | Loss {avg_loss:.4f} | Val Acc {val_acc:.4f}')

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), ckpt)
        print(f'  ✓ New best: {best_acc:.4f} saved to {ckpt}')
print(f'Training complete. Best Val Acc = {best_acc:.4f}')

Epoch 1: 100%|███████████████████████████████████████████████| 730/730 [02:37<00:00,  4.62it/s]


Epoch 1 | Loss 0.6284 | Val Acc 0.4178
  ✓ New best: 0.4178 saved to best_weighted_sampler_weighted.pt


Epoch 2: 100%|███████████████████████████████████████████████| 730/730 [02:35<00:00,  4.69it/s]


Epoch 2 | Loss 0.6119 | Val Acc 0.4547
  ✓ New best: 0.4547 saved to best_weighted_sampler_weighted.pt


Epoch 3: 100%|███████████████████████████████████████████████| 730/730 [02:35<00:00,  4.68it/s]


Epoch 3 | Loss 0.6179 | Val Acc 0.4387


Epoch 4: 100%|███████████████████████████████████████████████| 730/730 [02:34<00:00,  4.72it/s]


Epoch 4 | Loss 0.6107 | Val Acc 0.4003


Epoch 5: 100%|███████████████████████████████████████████████| 730/730 [02:38<00:00,  4.60it/s]


Epoch 5 | Loss 0.6114 | Val Acc 0.4288


Epoch 6: 100%|███████████████████████████████████████████████| 730/730 [02:35<00:00,  4.68it/s]


Epoch 6 | Loss 0.6123 | Val Acc 0.4067


Epoch 7: 100%|███████████████████████████████████████████████| 730/730 [02:37<00:00,  4.64it/s]


Epoch 7 | Loss 0.6089 | Val Acc 0.4217


Epoch 8: 100%|███████████████████████████████████████████████| 730/730 [02:37<00:00,  4.62it/s]


Epoch 8 | Loss 0.6100 | Val Acc 0.4319


Epoch 9: 100%|███████████████████████████████████████████████| 730/730 [02:36<00:00,  4.67it/s]


Epoch 9 | Loss 0.6019 | Val Acc 0.4422


Epoch 10: 100%|██████████████████████████████████████████████| 730/730 [02:36<00:00,  4.67it/s]


Epoch 10 | Loss 0.5949 | Val Acc 0.4074


Epoch 11: 100%|██████████████████████████████████████████████| 730/730 [02:36<00:00,  4.67it/s]


Epoch 11 | Loss 0.5981 | Val Acc 0.4237


Epoch 12: 100%|██████████████████████████████████████████████| 730/730 [02:38<00:00,  4.61it/s]


Epoch 12 | Loss 0.5990 | Val Acc 0.4328


Epoch 13: 100%|██████████████████████████████████████████████| 730/730 [02:35<00:00,  4.68it/s]


Epoch 13 | Loss 0.6062 | Val Acc 0.4140


Epoch 14: 100%|██████████████████████████████████████████████| 730/730 [02:36<00:00,  4.66it/s]


Epoch 14 | Loss 0.5909 | Val Acc 0.4317


Epoch 15:  42%|███████████████████▎                          | 307/730 [01:05<01:29,  4.75it/s]

In [11]:
def evaluate(loader, net=None, dev=None):
    """Accuracy of a model over the labelled entries of ``loader``.

    Entries whose label is -1 are excluded from both the numerator and the
    denominator. The model is switched to eval mode and left in it.

    Args:
        loader: iterable of ``(x, y)`` batches, where the model output for
            ``x`` has the class dimension at index 2 and ``y`` holds the
            class indices.
        net: model to evaluate; defaults to the notebook-level ``model``.
        dev: device the batches are moved to; defaults to the notebook-level
            ``device``.

    Returns:
        Fraction of labelled entries predicted correctly, or 0.0 when the
        loader yields no labelled entry (the same convention as the
        validation step in the training loop).
    """
    net = model if net is None else net
    dev = device if dev is None else dev
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            preds = net(x).argmax(2)
            mask = y != -1
            correct += ((preds == y) & mask).sum().item()
            total += mask.sum().item()
    return correct / total if total > 0 else 0.0

# Score whatever weights `model` currently holds in memory on the validation set.
# NOTE(review): this may differ from the best checkpoint saved during training — confirm which is intended.
val_accuracy = evaluate(val_loader)
acc = val_accuracy
print('Validation accuracy:', acc)

Validation accuracy: 0.3871216230529495
