# Linear Probing Using Extracted Features with EdgePred

In [1]:
import random

import numpy as np
import torch
from torch import nn
from torch.optim import Adam

from benchmol.dataloader.feat_dataset import TrainValTestFromCSVFactory
from benchmol.model_pools.base_utils import get_predictor
from benchmol.trainer import Trainer

In [2]:
# Data
# Pre-extracted molecular features (EdgePred embeddings) plus the ToxCast label CSV.
feat_pkl_path = "../datasets/toys/toxcast/processed/EdgePred.pkl"  # features extracted by EdgePred
csv_path = "../datasets/toys/toxcast/processed/toxcast_processed_ac.csv"
# n_feat must match the feature width stored in the pickle (presumably EdgePred's
# 300-d embedding -- TODO confirm); num_tasks is the number of ToxCast label columns.
n_feat, num_tasks = 300, 617
device = "cpu"  # or cuda:0
seed = 42  # fixed seed so weight init and batch shuffling are reproducible across runs

# Seed every RNG the run may touch *before* the model is built (weight init) and the
# loaders are created (shuffling). torch.manual_seed seeds all devices, CUDA included.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Take the feature modality and classification task as an example
modality = "feature"
task_type = "classification"

# define Model with arch3 MLP for n tasks
model = get_predictor(arch="arch3", in_features=n_feat, num_tasks=num_tasks)

# define Dataset
# Pinned host memory only speeds up host->GPU transfers; on a CPU run it is wasted
# work, so enable it only when training on a CUDA device.
factory = TrainValTestFromCSVFactory(
    csv_path, feat_pkl_path, task_type, y_column="label", split_column="scaffold_split",
    batch_size=8, num_workers=8, pin_memory=device.startswith("cuda")
)
train_loader, valid_loader, test_loader = (
    factory.get_dataloader(split=split) for split in ("train", "valid", "test")
)

# define Trainer
# reduction="none" keeps per-element losses; presumably the Trainer masks entries whose
# label equals label_empty (-1, missing assay results) before averaging -- verify in Trainer.
trainer = Trainer(
    model, modality, train_loader, valid_loader, test_loader, task_type,
    criterion=nn.BCEWithLogitsLoss(reduction="none"),
    optimizer=Adam(model.parameters(), lr=0.001, weight_decay=1e-5),
    label_empty=-1, device=device
)

# training and evaluation
# valid_select="max" starting from min_value=-inf: a checkpoint is saved only when the
# validation ROCAUC improves. save_dir has no trailing slash because the Trainer appears
# to append its own "/" ("./experiments/" produced "./experiments//valid_best.pth").
results = trainer.train(
    num_epochs=10, eval_metric="ROCAUC", valid_select="max", min_value=-np.inf, dataset="toxcast",
    save_finetune_ckpt=True, save_dir="./experiments"
)

# Output model results
print("results: {}\n".format(results))

[train] dataset: toxcast; epoch: 0 total loss: 0.570: 100%|██████████| 858/858 [00:05<00:00, 150.27it/s]
[eval on train set] dataset: toxcast; epoch: 0 total loss: 0.567: 100%|██████████| 858/858 [00:01<00:00, 487.82it/s]
[eval on valid set] dataset: toxcast; epoch: 0 total loss: 0.476: 100%|██████████| 108/108 [00:00<00:00, 135.02it/s]
[eval on test set] dataset: toxcast; epoch: 0:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 0 total loss: 0.494: 100%|██████████| 108/108 [00:00<00:00, 141.05it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 0, 'Loss': 0.5666267377235394, 'metric': 'ROCAUC', 'train': 0.6702432400322601, 'valid': 0.5492058675004362, 'test': 0.524549900865828}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 1 total loss: 0.392: 100%|██████████| 858/858 [00:02<00:00, 325.37it/s]
[eval on train set] dataset: toxcast; epoch: 1 total loss: 0.356: 100%|██████████| 858/858 [00:02<00:00, 335.96it/s]
[eval on valid set] dataset: toxcast; epoch: 1 total loss: 0.372: 100%|██████████| 108/108 [00:00<00:00, 152.81it/s]
[eval on test set] dataset: toxcast; epoch: 1:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 1 total loss: 0.398: 100%|██████████| 108/108 [00:00<00:00, 149.21it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 1, 'Loss': 0.35581817271270394, 'metric': 'ROCAUC', 'train': 0.6860053472650769, 'valid': 0.5580567861350569, 'test': 0.5330715364122453}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 2 total loss: 0.324: 100%|██████████| 858/858 [00:02<00:00, 317.86it/s]
[eval on train set] dataset: toxcast; epoch: 2 total loss: 0.274: 100%|██████████| 858/858 [00:02<00:00, 354.43it/s]
[eval on valid set] dataset: toxcast; epoch: 2 total loss: 0.323: 100%|██████████| 108/108 [00:01<00:00, 93.26it/s]
[eval on test set] dataset: toxcast; epoch: 2:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 2 total loss: 0.353: 100%|██████████| 108/108 [00:00<00:00, 135.72it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 2, 'Loss': 0.27417207375550884, 'metric': 'ROCAUC', 'train': 0.6959876115554002, 'valid': 0.5598498818270679, 'test': 0.5385805732850703}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 3 total loss: 0.272: 100%|██████████| 858/858 [00:02<00:00, 331.96it/s]
[eval on train set] dataset: toxcast; epoch: 3 total loss: 0.270: 100%|██████████| 858/858 [00:02<00:00, 346.52it/s]
[eval on valid set] dataset: toxcast; epoch: 3 total loss: 0.299: 100%|██████████| 108/108 [00:00<00:00, 114.42it/s]
[eval on test set] dataset: toxcast; epoch: 3:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 3 total loss: 0.330: 100%|██████████| 108/108 [00:00<00:00, 118.66it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 3, 'Loss': 0.27049370967980585, 'metric': 'ROCAUC', 'train': 0.7039584773521083, 'valid': 0.5623088786568775, 'test': 0.5380729343636014}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 4 total loss: 0.264: 100%|██████████| 858/858 [00:02<00:00, 330.73it/s]
[eval on train set] dataset: toxcast; epoch: 4 total loss: 0.384: 100%|██████████| 858/858 [00:02<00:00, 344.78it/s]
[eval on valid set] dataset: toxcast; epoch: 4 total loss: 0.286: 100%|██████████| 108/108 [00:00<00:00, 130.21it/s]
[eval on test set] dataset: toxcast; epoch: 4:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 4 total loss: 0.319: 100%|██████████| 108/108 [00:00<00:00, 126.16it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 4, 'Loss': 0.3840466123638731, 'metric': 'ROCAUC', 'train': 0.7118933845293426, 'valid': 0.563530446798586, 'test': 0.5403093097607599}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 5 total loss: 0.265: 100%|██████████| 858/858 [00:02<00:00, 390.01it/s]
[eval on train set] dataset: toxcast; epoch: 5 total loss: 0.245: 100%|██████████| 858/858 [00:02<00:00, 368.98it/s]
[eval on valid set] dataset: toxcast; epoch: 5 total loss: 0.279: 100%|██████████| 108/108 [00:00<00:00, 128.73it/s]
[eval on test set] dataset: toxcast; epoch: 5:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 5 total loss: 0.313: 100%|██████████| 108/108 [00:00<00:00, 129.77it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 5, 'Loss': 0.24518851967124672, 'metric': 'ROCAUC', 'train': 0.7200329220929907, 'valid': 0.5674412154189835, 'test': 0.5405272610319071}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 6 total loss: 0.254: 100%|██████████| 858/858 [00:02<00:00, 328.39it/s]
[eval on train set] dataset: toxcast; epoch: 6 total loss: 0.242: 100%|██████████| 858/858 [00:02<00:00, 313.42it/s]
[eval on valid set] dataset: toxcast; epoch: 6 total loss: 0.276: 100%|██████████| 108/108 [00:05<00:00, 19.38it/s]
[eval on test set] dataset: toxcast; epoch: 6:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 6 total loss: 0.309: 100%|██████████| 108/108 [00:05<00:00, 19.21it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 6, 'Loss': 0.24181629172016134, 'metric': 'ROCAUC', 'train': 0.7258804422994105, 'valid': 0.5694439250077379, 'test': 0.5415806791683212}
model has been saved as ./experiments//valid_best.pth


[train] dataset: toxcast; epoch: 7 total loss: 0.252: 100%|██████████| 858/858 [00:20<00:00, 42.10it/s] 
[eval on train set] dataset: toxcast; epoch: 7 total loss: 0.264: 100%|██████████| 858/858 [00:14<00:00, 60.40it/s] 
[eval on valid set] dataset: toxcast; epoch: 7 total loss: 0.273: 100%|██████████| 108/108 [00:03<00:00, 30.00it/s]
[eval on test set] dataset: toxcast; epoch: 7:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 7 total loss: 0.308: 100%|██████████| 108/108 [00:04<00:00, 22.41it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 7, 'Loss': 0.2635254759888549, 'metric': 'ROCAUC', 'train': 0.7299309522798388, 'valid': 0.5692742668894935, 'test': 0.5436785123890283}


[train] dataset: toxcast; epoch: 8 total loss: 0.252: 100%|██████████| 858/858 [00:30<00:00, 27.77it/s] 
[eval on train set] dataset: toxcast; epoch: 8 total loss: 0.244: 100%|██████████| 858/858 [00:12<00:00, 70.28it/s] 
[eval on valid set] dataset: toxcast; epoch: 8 total loss: 0.272: 100%|██████████| 108/108 [00:05<00:00, 20.92it/s]
[eval on test set] dataset: toxcast; epoch: 8:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 8 total loss: 0.307: 100%|██████████| 108/108 [00:04<00:00, 22.81it/s]
  0%|          | 0/858 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 8, 'Loss': 0.24352465365038609, 'metric': 'ROCAUC', 'train': 0.7342290161684275, 'valid': 0.5683418261234248, 'test': 0.5444607985955275}


[train] dataset: toxcast; epoch: 9 total loss: 0.240: 100%|██████████| 858/858 [02:13<00:00,  6.43it/s]
[eval on train set] dataset: toxcast; epoch: 9 total loss: 0.237: 100%|██████████| 858/858 [00:44<00:00, 19.40it/s]
[eval on valid set] dataset: toxcast; epoch: 9 total loss: 0.272: 100%|██████████| 108/108 [00:06<00:00, 15.93it/s]
[eval on test set] dataset: toxcast; epoch: 9:   0%|          | 0/108 [00:00<?, ?it/s]

Some target is missing! Missing ratio: 0.01 [610/617]


[eval on test set] dataset: toxcast; epoch: 9 total loss: 0.306: 100%|██████████| 108/108 [00:07<00:00, 14.10it/s]


Some target is missing! Missing ratio: 0.01 [610/617]
{'dataset': 'toxcast', 'epoch': 9, 'Loss': 0.23722279155171, 'metric': 'ROCAUC', 'train': 0.7375540053961169, 'valid': 0.572109682898607, 'test': 0.5467083229900723}
model has been saved as ./experiments//valid_best.pth
results: {'highest_valid': 0.572109682898607, 'final_train': 0.7375540053961169, 'final_test': 0.5467083229900723, 'highest_train': 0.7375540053961169}

