# Downstream adaptation with MiniMol

This example shows how MiniMol can featurise molecules that then serve as input to another model trained on a small downstream dataset from TDC ADMET. This transfers the knowledge captured by the pre-trained MiniMol to a new task.

Before we start, let's make sure that the TDC package is installed in the environment. 

In [1]:
%pip install PyTDC

## Step 1: Getting the data
We will build a predictor for the `HIA Hou` dataset, one of the binary classification benchmarks for `absorption`-type problems in the TDC ADMET group. We split the data into training, validation and test sets based on molecular scaffolds.

In [2]:
from tdc.benchmark_group import admet_group

DATASET_NAME = 'hia_hou'

admet = admet_group(path="admet-data/")

mols_test = admet.get(DATASET_NAME)['test']
mols_train, mols_val = admet.get_train_valid_split(benchmark=DATASET_NAME, split_type='scaffold', seed=42)

Found local copy...
generating training, validation splits...
generating training, validation splits...
100%|██████████| 461/461 [00:00<00:00, 3691.96it/s]
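
To make the idea of a scaffold split concrete, here is a minimal sketch in plain Python. It is not TDC's implementation: the `scaffold_split` helper, the scaffold labels and the greedy largest-group-first assignment are all assumptions made for illustration. What it shows is that molecules sharing a scaffold always land in the same split, so the validation molecules are structurally different from the training ones.

```python
from collections import defaultdict

def scaffold_split(scaffold_of, train_frac=0.8):
    """Assign whole scaffold groups to train or validation (illustrative only)."""
    groups = defaultdict(list)
    for mol, scaffold in scaffold_of.items():
        groups[scaffold].append(mol)

    train, val = [], []
    target = train_frac * len(scaffold_of)
    # Largest groups first; a group is never divided between the two splits.
    for scaffold in sorted(groups, key=lambda s: (-len(groups[s]), s)):
        bucket = train if len(train) + len(groups[scaffold]) <= target else val
        bucket.extend(groups[scaffold])
    return train, val

# Hypothetical molecules mapped to hypothetical scaffold labels.
scaffold_of = {
    "mol_a": "benzene", "mol_b": "benzene", "mol_c": "benzene",
    "mol_d": "indole", "mol_e": "indole",
    "mol_f": "piperazine", "mol_g": "piperazine",
    "mol_h": "steroid", "mol_i": "thiazole", "mol_j": "furan",
}
train, val = scaffold_split(scaffold_of)
train_scaffolds = {scaffold_of[m] for m in train}
val_scaffolds = {scaffold_of[m] for m in val}
print(len(train), len(val), train_scaffolds & val_scaffolds)
```

The final print reports the split sizes and the set of scaffolds shared between the two splits, which should be empty.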


In [3]:
print(f"Dataset - {DATASET_NAME}\n")
print(f"Val split ({len(mols_val)} mols): \n{mols_val.head()}\n")
print(f"Test split ({len(mols_test)} mols): \n{mols_test.head()}\n")
print(f"Train split ({len(mols_train)} mols): \n{mols_train.head()}\n")

Dataset - hia_hou

Val split (58 mols): 
                 Drug_ID                                               Drug  Y
0         Atracurium.mol  COc1ccc(C[C@H]2c3cc(OC)c(OC)cc3CC[N@@+]2(C)CCC...  0
1  Succinylsulfathiazole          O=C(O)CCC(=O)Nc1ccc(S(=O)(=O)Nc2nccs2)cc1  0
2            Ticarcillin  CC1(C)S[C@H]2[C@@H](NC(=O)[C@@H](C(=O)O)c3ccsc...  0
3          Raffinose.mol  OC[C@@H]1O[C@@H](OC[C@@H]2O[C@@H](O[C@]3(CO)O[...  0
4          Triamcinolone  C[C@@]12C=CC(=O)C=C1CC[C@@H]1[C@H]3C[C@@H](O)[...  1

Test split (117 mols): 
                Drug_ID                                               Drug  Y
0         Trazodone.mol         O=c1n(CCCN2CCN(c3cccc(Cl)c3)CC2)nc2ccccn12  1
1          Lisuride.mol  CCN(CC)C(=O)N[C@H]1C=C2c3cccc4[nH]cc(c34)C[C@@...  1
2  Methylergonovine.mol  CC[C@H](CO)NC(=O)[C@H]1C=C2c3cccc4[nH]cc(c34)C...  1
3      Methysergide.mol  CC[C@H](CO)NC(=O)[C@H]1C=C2c3cccc4c3c(cn4C)C[C...  1
4       Moclobemide.mol                       O=C(NCCN1CCOCC1)c1ccc(Cl

## Step 2: Generating molecular fingerprints
After splitting the dataset into training, validation and test sets, we will use MiniMol to embed all molecules. The embedding will be added as an extra column in the dataframe returned by TDC.

In [4]:
from minimol import Minimol

featuriser = Minimol()

In [5]:
mols_val['Embedding'] = featuriser(mols_val['Drug'])
mols_test['Embedding'] = featuriser(mols_test['Drug'])
mols_train['Embedding'] = featuriser(mols_train['Drug'])

featurizing_smiles, batch=1:   0%|          | 0/58 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 8947.03it/s]
100%|██████████| 1/1 [00:00<00:00,  4.72it/s]


featurizing_smiles, batch=3:   0%|          | 0/39 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 117/117 [00:00<00:00, 16518.57it/s]
100%|██████████| 2/2 [00:00<00:00,  7.86it/s]


featurizing_smiles, batch=13:   0%|          | 0/31 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 19929.78it/s]
100%|██████████| 5/5 [00:00<00:00,  7.37it/s]


The model is small, so it took us 7.3 seconds to generate the embeddings for almost 600 molecules. Here is a preview after a new column has been added:

In [6]:
print(mols_train.head())

           Drug_ID                                               Drug  Y  \
0        Guanadrel                      N=C(N)NC[C@@H]1COC2(CCCCC2)O1  1   
1      Cefmetazole  CO[C@@]1(NC(=O)CSCC#N)C(=O)N2C(C(=O)O)=C(CSc3n...  0   
2   Zonisamide.mol                           NS(=O)(=O)Cc1noc2ccccc12  1   
3   Furosemide.mol            NS(=O)(=O)c1cc(Cl)cc(NCc2ccco2)c1C(=O)O  1   
4  Telmisartan.mol  CCCc1nc2c(n1Cc1ccc(-c3ccccc3C(=O)O)cc1)=C[C@H]...  1   

                                           Embedding  
0  [0.24859753, 0.18472305, 0.4028932, 0.22700065...  
1  [0.7069565, 0.41227153, 1.0127053, 2.3176281, ...  
2  [0.19019875, -0.14087728, 0.8896561, 1.2718395...  
3  [0.11933186, 0.38785577, 1.5808605, 1.999807, ...  
4  [0.99853146, 1.1408926, 2.2468193, 1.3438487, ...  


## Step 3: Training a model
Now that the molecules are featurised, leveraging the representation MiniMol learned during its pre-training, we will set up the training of a simple Multi-Layer Perceptron model on our newly generated embeddings and the labels from the `HIA Hou` dataset. We will use PyTorch.

Let's start by defining a new class for the dataset and then creating the dataloaders for different splits.

In [7]:
import torch  # needed by AdmetDataset.__getitem__ below
from torch.utils.data import DataLoader, Dataset
    
class AdmetDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples['Embedding'].tolist()
        self.targets = [float(target) for target in samples['Y'].tolist()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = torch.tensor(self.samples[idx])
        target = torch.tensor(self.targets[idx])
        return sample, target

val_loader = DataLoader(AdmetDataset(mols_val), batch_size=128, shuffle=False)
test_loader = DataLoader(AdmetDataset(mols_test), batch_size=128, shuffle=False)
train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)

Our model will be a simple 3-layer perceptron with batch normalisation and dropout. Before the last layer, the input features are concatenated with the output of the previous layer.

In [8]:
import torch  # torch.cat is used in TaskHead.forward
import torch.nn as nn
import torch.nn.functional as F


class TaskHead(nn.Module):
    def __init__(self):
        super(TaskHead, self).__init__()
        self.dense1 = nn.Linear(512, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(1024, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        original_x = x

        x = self.dense1(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = F.relu(x)

        x = self.dense2(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = F.relu(x)

        x = torch.cat((x, original_x), dim=1)
        x = self.dense3(x)
        return self.sigmoid(x)
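
The sigmoid at the end of `TaskHead` turns the single output into a probability, which is the kind of input a binary cross-entropy loss expects. As a worked illustration, the sketch below computes the sigmoid and the mean binary cross-entropy by hand in plain Python. The logits and labels are made-up numbers, and the helper names are ours rather than part of PyTorch.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def binary_cross_entropy(probs, targets):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over all samples."""
    losses = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
              for p, y in zip(probs, targets)]
    return sum(losses) / len(losses)

logits = [2.0, -1.0, 0.0]   # made-up raw outputs of the last linear layer
targets = [1.0, 0.0, 1.0]   # made-up labels
probs = [sigmoid(z) for z in logits]
loss = binary_cross_entropy(probs, targets)
print([round(p, 4) for p in probs], round(loss, 4))
```

Confident correct predictions (the first sample) contribute little to the loss, while an undecided prediction of 0.5 (the third sample) contributes `log(2)`.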

Below we declare the basic hyperparameters and choose the optimiser, loss function, learning-rate scheduler and weight-decay regularisation.

In [9]:
import math
import torch.optim as optim

lr = 0.006
epochs = 25
warmup = 5

loss_fn = nn.BCELoss()

def model_factory():
    model = TaskHead()
    lr_fn = lambda epoch: lr * (1 + math.cos(math.pi * (epoch - warmup) / (epochs - warmup))) / 2
    optimiser = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)
    lr_scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_fn)
    return model, optimiser, lr_scheduler
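
To see what this schedule does, the sketch below evaluates the same cosine expression for a few epochs in plain Python, without PyTorch. One detail worth knowing: `LambdaLR` multiplies the optimiser's base learning rate by the value the function returns, so with the code above the rate actually applied is `lr * lr_fn(epoch)`. The `effective` value reflects that. The helper name `cosine_factor` is ours.

```python
import math

lr = 0.006
epochs = 25
warmup = 5

def cosine_factor(epoch):
    # Same expression as lr_fn in the cell above.
    return lr * (1 + math.cos(math.pi * (epoch - warmup) / (epochs - warmup))) / 2

for epoch in (0, warmup, 15, epochs - 1):
    factor = cosine_factor(epoch)
    effective = lr * factor  # LambdaLR scales the base lr by the returned factor
    print(f"epoch {epoch:2d}  factor {factor:.6f}  effective lr {effective:.2e}")
```

The factor peaks at `epoch == warmup`, where the cosine term equals 1, and decays towards zero by the end of training. Before the peak it rises only slightly, from roughly 85% of the maximum at epoch 0.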

For evaluation we will use both AUROC and Average Precision. The reported loss is the mean of the per-batch losses over the epoch.

In [10]:
import torch
from sklearn.metrics import roc_auc_score, average_precision_score

def evaluate(predictor, dataloader, loss_fn):
    predictor.eval()
    total_loss = 0
    all_probs = []
    all_targets = []

    with torch.no_grad():
        
        for inputs, targets in dataloader:
        
            probs = predictor(inputs).squeeze()

            loss = loss_fn(probs, targets)
            total_loss += loss.item()

            all_probs.extend(probs.tolist())
            all_targets.extend(targets.tolist())

    loss = total_loss / len(dataloader)
    
    return (
        loss,
        roc_auc_score(all_targets, all_probs),
        average_precision_score(all_targets, all_probs)
    )
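
For intuition about the two metrics, here is a small plain-Python sketch that computes them by hand on made-up labels and scores. The helper functions are illustrative and not the scikit-learn implementation. The average-precision helper assumes there are no tied scores, in which case it should agree with `average_precision_score`.

```python
def auroc(targets, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, y in zip(scores, targets) if y == 1]
    neg = [s for s, y in zip(scores, targets) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def average_precision(targets, scores):
    """Mean of the precision measured at the rank of each positive sample."""
    ranked = sorted(zip(scores, targets), key=lambda pair: -pair[0])
    hits, precisions = 0, []
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions)

targets = [1, 1, 0, 1, 0]            # made-up labels
scores = [0.9, 0.8, 0.7, 0.4, 0.2]   # made-up predicted probabilities
print(auroc(targets, scores), average_precision(targets, scores))
```

A ranking no better than chance gives an AUROC near 0.5 but an average precision near the fraction of positive labels. That is consistent with `val_avpr` starting high in the logs below while `val_auroc` starts around 0.5.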

It's time to define a function that trains the model for one epoch:

In [11]:
def train_one_epoch(predictor, optimiser, lr_scheduler, loss_fn, epoch):
    predictor.train()        
    train_loss = 0

    lr_scheduler.step(epoch)
    for inputs, targets in train_loader:
        optimiser.zero_grad()
        probs = predictor(inputs).squeeze()
        loss = loss_fn(probs, targets)
        loss.backward()
        optimiser.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    val_loss, auroc, avpr = evaluate(predictor, val_loader, loss_fn)
    print(
        f"## Epoch {epoch+1}\t"
        f"train_loss: {train_loss:.4f}\t"
        f"val_loss: {val_loss:.4f}\t"
        f"val_auroc: {auroc:.4f}\t"
        f"val_avpr: {avpr:.4f}"
    )
    return predictor
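
Dividing the accumulated loss by `len(train_loader)` averages the per-batch means, which differs slightly from a per-sample average whenever the last batch is smaller than the others. The sketch below illustrates this for 403 samples in batches of 32, the sizes used for the training split above. The loss values themselves are made up.

```python
import math

n_samples, batch_size = 403, 32
n_batches = math.ceil(n_samples / batch_size)
sizes = [batch_size] * (n_samples // batch_size)
if n_samples % batch_size:
    sizes.append(n_samples % batch_size)  # short final batch

# Made-up per-batch mean losses: 0.20 everywhere except the short last batch.
batch_means = [0.20] * (len(sizes) - 1) + [0.50]

mean_of_batch_means = sum(batch_means) / len(batch_means)
per_sample_mean = sum(m * s for m, s in zip(batch_means, sizes)) / n_samples
print(n_batches, sizes[-1], round(mean_of_batch_means, 4), round(per_sample_mean, 4))
```

The short final batch carries the same weight as a full one in the mean of batch means, so an outlier there shifts the reported loss more than it would in a per-sample average. For monitoring training progress the difference is usually negligible.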

And now, let's see how good our model gets after training... 🚀

In [17]:
model, optimiser, lr_scheduler = model_factory()

val_loss, val_auroc, val_avpr = evaluate(model, val_loader, loss_fn)
print(
    f"## Epoch 0\t"
    f"train_loss: ------\t"
    f"val_loss: {val_loss:.4f}\t"
    f"val_auroc: {val_auroc:.4f}\t"
    f"val_avpr: {val_avpr:.4f}"
)

for epoch in range(epochs):
    model = train_one_epoch(model, optimiser, lr_scheduler, loss_fn, epoch)

test_loss, test_auroc, test_avpr = evaluate(model, test_loader, loss_fn)
print(
    f"test_loss: {test_loss:.4f}\n"
    f"test_auroc: {test_auroc:.4f}\n"
    f"test_avpr: {test_avpr:.4f}"
)

## Epoch 0	train_loss: ------	val_loss: 0.6198	val_auroc: 0.4712	val_avpr: 0.8686
## Epoch 1	train_loss: 0.5763	val_loss: 0.5906	val_auroc: 0.4887	val_avpr: 0.8722
## Epoch 2	train_loss: 0.4879	val_loss: 0.5132	val_auroc: 0.4536	val_avpr: 0.8507
## Epoch 3	train_loss: 0.4125	val_loss: 0.4404	val_auroc: 0.4787	val_avpr: 0.8584
## Epoch 4	train_loss: 0.3626	val_loss: 0.3910	val_auroc: 0.5439	val_avpr: 0.8717
## Epoch 5	train_loss: 0.3205	val_loss: 0.3605	val_auroc: 0.5815	val_avpr: 0.8935
## Epoch 6	train_loss: 0.2883	val_loss: 0.3375	val_auroc: 0.5990	val_avpr: 0.8963
## Epoch 7	train_loss: 0.2585	val_loss: 0.3202	val_auroc: 0.6241	val_avpr: 0.9073
## Epoch 8	train_loss: 0.2316	val_loss: 0.3096	val_auroc: 0.6516	val_avpr: 0.9183
## Epoch 9	train_loss: 0.2167	val_loss: 0.2989	val_auroc: 0.6642	val_avpr: 0.9222
## Epoch 10	train_loss: 0.2063	val_loss: 0.2903	val_auroc: 0.6842	val_avpr: 0.9309
## Epoch 11	train_loss: 0.1929	val_loss: 0.2828	val_auroc: 0.7118	val_avpr: 0.9422
## Epoch 12	tr

The model trained in just 1.4s, reaching a test-set AUROC of 0.9671. Pretty good!

The result can be further improved. We adopt two techniques:

- Ensembling. Since the training is so fast, fitting a few additional models is not a big deal. We will train each of the five ensembled models on a different fold generated by the TDC train-val splitting method.

- Checkpoint selection. Rather than taking the model from the last epoch, we will keep the one with the lowest validation loss.
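
The second technique, keeping the epoch with the lowest validation loss, can be sketched independently of PyTorch. The `select_best` helper and the loss values are illustrative assumptions. In the notebook the stored object would be a copy of the model rather than the epoch index.

```python
def select_best(val_losses):
    """Return (epoch_index, loss) of the lowest validation loss seen."""
    best_epoch, best_loss = None, float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:  # strictly better, so the earliest minimum wins
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss

# Made-up validation losses that improve and then start to overfit.
val_losses = [0.55, 0.41, 0.30, 0.26, 0.27, 0.31]
print(select_best(val_losses))
```

The comparison must use the same quantity throughout: mixing AUROC (higher is better) with loss (lower is better) in one running best value would select the wrong epoch.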

Below we implement a function that creates new training and validation dataloaders for each fold, and a function that evaluates an ensemble rather than a single model.

In [18]:
def dataloader_factory(seed):
    mols_train, mols_val = admet.get_train_valid_split(benchmark=DATASET_NAME, split_type='scaffold', seed=seed)

    mols_val['Embedding'] = featuriser(mols_val['Drug'])
    mols_train['Embedding'] = featuriser(mols_train['Drug'])

    val_loader = DataLoader(AdmetDataset(mols_val), batch_size=128, shuffle=False)
    train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)

    return val_loader, train_loader


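def evaluate_ensemble(predictors, dataloader, loss_fn):
    # Make sure dropout and batch norm run in inference mode for every model.
    for predictor in predictors:
        predictor.eval()
    total_loss = 0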
    all_probs = []
    all_targets = []

    with torch.no_grad():
        
        for inputs, targets in dataloader:
            model_outputs = [predictor(inputs).squeeze() for predictor in predictors]
            averaged_output = torch.mean(torch.stack(model_outputs, dim=0), dim=0)

            loss = loss_fn(averaged_output, targets)
            total_loss += loss.item()

            all_probs.extend(averaged_output.tolist())
            all_targets.extend(targets.tolist())

    loss = total_loss / len(dataloader)
    return loss, roc_auc_score(all_targets, all_probs), average_precision_score(all_targets, all_probs)
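
The core of `evaluate_ensemble` is the element-wise mean over the models' predicted probabilities. Here is a plain-Python sketch of that step, with made-up probabilities from three hypothetical models. The `average_predictions` helper is ours.

```python
def average_predictions(per_model_probs):
    """Element-wise mean across models; each inner list holds one model's outputs."""
    n_models = len(per_model_probs)
    return [sum(sample) / n_models for sample in zip(*per_model_probs)]

# Rows are models, columns are samples (made-up values).
per_model_probs = [
    [0.9, 0.2, 0.6],
    [0.7, 0.4, 0.3],
    [0.8, 0.3, 0.6],
]
ensemble = average_predictions(per_model_probs)
print([round(p, 2) for p in ensemble])
```

Averaging smooths out the disagreement of any single model: on the third sample one model is pessimistic, and the ensemble lands between the individual predictions.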

Finally, let's see how much better the model can get!

In [13]:
from copy import deepcopy

seeds = [1, 2, 3, 4, 5]

best_models = []
num_folds = 5

for seed in seeds:
    val_loader, train_loader = dataloader_factory(seed)
    model, optimiser, lr_scheduler = model_factory()

    best_epoch = {"model": None, "result": None}
    for epoch in range(epochs):
        model = train_one_epoch(model, optimiser, lr_scheduler, loss_fn, epoch)
        val_loss, _, _ = evaluate(model, val_loader, loss_fn)

        # Keep a copy of the model from the epoch with the lowest validation loss.
        # The running best is always a loss, so it is never compared against AUROC.
        if best_epoch['model'] is None or val_loss < best_epoch['result']:
            best_epoch['model'] = deepcopy(model)
            best_epoch['result'] = val_loss

    best_models.append(best_epoch['model'])

test_loss, test_auroc, test_avpr = evaluate_ensemble(best_models, test_loader, loss_fn)
print(
    f"test_loss: {test_loss:.4f}\n"
    f"test_auroc: {test_auroc:.4f}\n"
    f"test_avpr: {test_avpr:.4f}"
)

generating training, validation splits...
100%|██████████| 461/461 [00:00<00:00, 3196.61it/s]


featurizing_smiles, batch=1:   0%|          | 0/58 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 15806.99it/s]
100%|██████████| 1/1 [00:00<00:00,  4.36it/s]


featurizing_smiles, batch=13:   0%|          | 0/31 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 16046.94it/s]
100%|██████████| 5/5 [00:00<00:00,  7.57it/s]


## Epoch 1	train_loss: 0.5566	val_loss: 0.5515	val_auroc: 0.3081	val_avpr: 0.8305
## Epoch 2	train_loss: 0.4675	val_loss: 0.4924	val_auroc: 0.5826	val_avpr: 0.8775
## Epoch 3	train_loss: 0.4083	val_loss: 0.4274	val_auroc: 0.6947	val_avpr: 0.9211
## Epoch 4	train_loss: 0.3606	val_loss: 0.3802	val_auroc: 0.7535	val_avpr: 0.9460
## Epoch 5	train_loss: 0.3154	val_loss: 0.3475	val_auroc: 0.7675	val_avpr: 0.9505
## Epoch 6	train_loss: 0.2842	val_loss: 0.3227	val_auroc: 0.8151	val_avpr: 0.9658
## Epoch 7	train_loss: 0.2555	val_loss: 0.3013	val_auroc: 0.8459	val_avpr: 0.9737
## Epoch 8	train_loss: 0.2338	val_loss: 0.2869	val_auroc: 0.8543	val_avpr: 0.9754
## Epoch 9	train_loss: 0.2137	val_loss: 0.2747	val_auroc: 0.8711	val_avpr: 0.9793
## Epoch 10	train_loss: 0.2000	val_loss: 0.2675	val_auroc: 0.8796	val_avpr: 0.9810
## Epoch 11	train_loss: 0.1892	val_loss: 0.2595	val_auroc: 0.8824	val_avpr: 0.9816
## Epoch 12	train_loss: 0.1900	val_loss: 0.2528	val_auroc: 0.8964	val_avpr: 0.9844
## Epoch 13	t

generating training, validation splits...


## Epoch 22	train_loss: 0.1294	val_loss: 0.2317	val_auroc: 0.9132	val_avpr: 0.9873
## Epoch 23	train_loss: 0.1304	val_loss: 0.2315	val_auroc: 0.9132	val_avpr: 0.9873
## Epoch 24	train_loss: 0.1335	val_loss: 0.2314	val_auroc: 0.9132	val_avpr: 0.9873
## Epoch 25	train_loss: 0.1281	val_loss: 0.2307	val_auroc: 0.9132	val_avpr: 0.9873


100%|██████████| 461/461 [00:00<00:00, 3069.32it/s]


featurizing_smiles, batch=1:   0%|          | 0/58 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 17746.54it/s]
100%|██████████| 1/1 [00:00<00:00,  4.08it/s]


featurizing_smiles, batch=13:   0%|          | 0/31 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 19709.71it/s]
100%|██████████| 5/5 [00:00<00:00,  6.62it/s]


## Epoch 1	train_loss: 0.4981	val_loss: 0.5801	val_auroc: 0.4528	val_avpr: 0.9011
## Epoch 2	train_loss: 0.4269	val_loss: 0.4856	val_auroc: 0.6491	val_avpr: 0.9385
## Epoch 3	train_loss: 0.3687	val_loss: 0.3989	val_auroc: 0.7660	val_avpr: 0.9677
## Epoch 4	train_loss: 0.3297	val_loss: 0.3509	val_auroc: 0.8113	val_avpr: 0.9759
## Epoch 5	train_loss: 0.2891	val_loss: 0.3164	val_auroc: 0.8226	val_avpr: 0.9780
## Epoch 6	train_loss: 0.2658	val_loss: 0.2861	val_auroc: 0.8491	val_avpr: 0.9819
## Epoch 7	train_loss: 0.2421	val_loss: 0.2669	val_auroc: 0.8755	val_avpr: 0.9858
## Epoch 8	train_loss: 0.2309	val_loss: 0.2479	val_auroc: 0.8830	val_avpr: 0.9868
## Epoch 9	train_loss: 0.2134	val_loss: 0.2383	val_auroc: 0.8792	val_avpr: 0.9866
## Epoch 10	train_loss: 0.1947	val_loss: 0.2265	val_auroc: 0.8906	val_avpr: 0.9878
## Epoch 11	train_loss: 0.1809	val_loss: 0.2145	val_auroc: 0.8981	val_avpr: 0.9889
## Epoch 12	train_loss: 0.1672	val_loss: 0.2115	val_auroc: 0.9132	val_avpr: 0.9909
## Epoch 13	t

generating training, validation splits...


## Epoch 24	train_loss: 0.1306	val_loss: 0.1802	val_auroc: 0.9434	val_avpr: 0.9944
## Epoch 25	train_loss: 0.1301	val_loss: 0.1778	val_auroc: 0.9434	val_avpr: 0.9943


100%|██████████| 461/461 [00:00<00:00, 3789.72it/s]


featurizing_smiles, batch=1:   0%|          | 0/58 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 14143.58it/s]
100%|██████████| 1/1 [00:00<00:00,  5.68it/s]


featurizing_smiles, batch=13:   0%|          | 0/31 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 16329.24it/s]
100%|██████████| 5/5 [00:00<00:00,  5.46it/s]


## Epoch 1	train_loss: 0.5888	val_loss: 0.5955	val_auroc: 0.3109	val_avpr: 0.8351
## Epoch 2	train_loss: 0.5061	val_loss: 0.5120	val_auroc: 0.6603	val_avpr: 0.9407
## Epoch 3	train_loss: 0.4270	val_loss: 0.4331	val_auroc: 0.8333	val_avpr: 0.9703
## Epoch 4	train_loss: 0.3892	val_loss: 0.3694	val_auroc: 0.8622	val_avpr: 0.9780
## Epoch 5	train_loss: 0.3384	val_loss: 0.3416	val_auroc: 0.8782	val_avpr: 0.9813
## Epoch 6	train_loss: 0.3086	val_loss: 0.3090	val_auroc: 0.9071	val_avpr: 0.9875
## Epoch 7	train_loss: 0.2751	val_loss: 0.2914	val_auroc: 0.9071	val_avpr: 0.9873
## Epoch 8	train_loss: 0.2534	val_loss: 0.2720	val_auroc: 0.9263	val_avpr: 0.9906
## Epoch 9	train_loss: 0.2409	val_loss: 0.2577	val_auroc: 0.9359	val_avpr: 0.9919
## Epoch 10	train_loss: 0.2148	val_loss: 0.2464	val_auroc: 0.9455	val_avpr: 0.9933
## Epoch 11	train_loss: 0.2104	val_loss: 0.2390	val_auroc: 0.9487	val_avpr: 0.9938
## Epoch 12	train_loss: 0.1975	val_loss: 0.2297	val_auroc: 0.9583	val_avpr: 0.9951
## Epoch 13	t

generating training, validation splits...


## Epoch 24	train_loss: 0.1524	val_loss: 0.1998	val_auroc: 0.9712	val_avpr: 0.9966
## Epoch 25	train_loss: 0.1449	val_loss: 0.1999	val_auroc: 0.9712	val_avpr: 0.9966


100%|██████████| 461/461 [00:00<00:00, 3792.48it/s]


featurizing_smiles, batch=1:   0%|          | 0/58 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 15222.43it/s]
100%|██████████| 1/1 [00:00<00:00,  5.83it/s]


featurizing_smiles, batch=13:   0%|          | 0/31 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 14800.75it/s]
100%|██████████| 5/5 [00:00<00:00,  8.25it/s]


## Epoch 1	train_loss: 0.5502	val_loss: 0.5522	val_auroc: 0.6065	val_avpr: 0.9563
## Epoch 2	train_loss: 0.4692	val_loss: 0.4540	val_auroc: 0.7593	val_avpr: 0.9737
## Epoch 3	train_loss: 0.4013	val_loss: 0.3752	val_auroc: 0.7917	val_avpr: 0.9728
## Epoch 4	train_loss: 0.3501	val_loss: 0.3212	val_auroc: 0.8102	val_avpr: 0.9769
## Epoch 5	train_loss: 0.3182	val_loss: 0.2918	val_auroc: 0.8148	val_avpr: 0.9780
## Epoch 6	train_loss: 0.2813	val_loss: 0.2605	val_auroc: 0.8565	val_avpr: 0.9857
## Epoch 7	train_loss: 0.2536	val_loss: 0.2479	val_auroc: 0.8565	val_avpr: 0.9860
## Epoch 8	train_loss: 0.2370	val_loss: 0.2274	val_auroc: 0.8657	val_avpr: 0.9873
## Epoch 9	train_loss: 0.2189	val_loss: 0.2134	val_auroc: 0.8750	val_avpr: 0.9883
## Epoch 10	train_loss: 0.2053	val_loss: 0.2036	val_auroc: 0.8704	val_avpr: 0.9877
## Epoch 11	train_loss: 0.1918	val_loss: 0.1961	val_auroc: 0.8889	val_avpr: 0.9900
## Epoch 12	train_loss: 0.1851	val_loss: 0.1913	val_auroc: 0.8981	val_avpr: 0.9911
## Epoch 13	t

generating training, validation splits...


## Epoch 25	train_loss: 0.1377	val_loss: 0.1738	val_auroc: 0.9074	val_avpr: 0.9921


100%|██████████| 461/461 [00:00<00:00, 3762.20it/s]


featurizing_smiles, batch=2:   0%|          | 0/32 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 64/64 [00:00<00:00, 14105.17it/s]
100%|██████████| 1/1 [00:00<00:00,  4.85it/s]


featurizing_smiles, batch=13:   0%|          | 0/30 [00:00<?, ?it/s]

Casting to FP32: 100%|██████████| 397/397 [00:00<00:00, 17897.21it/s]
100%|██████████| 4/4 [00:00<00:00,  5.44it/s]


## Epoch 1	train_loss: 0.5709	val_loss: 0.5547	val_auroc: 0.5263	val_avpr: 0.9043
## Epoch 2	train_loss: 0.4856	val_loss: 0.4772	val_auroc: 0.6992	val_avpr: 0.9433
## Epoch 3	train_loss: 0.4175	val_loss: 0.4074	val_auroc: 0.7419	val_avpr: 0.9431
## Epoch 4	train_loss: 0.3627	val_loss: 0.3590	val_auroc: 0.7494	val_avpr: 0.9491
## Epoch 5	train_loss: 0.3160	val_loss: 0.3317	val_auroc: 0.7694	val_avpr: 0.9572
## Epoch 6	train_loss: 0.2816	val_loss: 0.3095	val_auroc: 0.7694	val_avpr: 0.9571
## Epoch 7	train_loss: 0.2555	val_loss: 0.2954	val_auroc: 0.7719	val_avpr: 0.9590
## Epoch 8	train_loss: 0.2365	val_loss: 0.2838	val_auroc: 0.7870	val_avpr: 0.9633
## Epoch 9	train_loss: 0.2165	val_loss: 0.2749	val_auroc: 0.8095	val_avpr: 0.9688
## Epoch 10	train_loss: 0.2049	val_loss: 0.2682	val_auroc: 0.8246	val_avpr: 0.9724
## Epoch 11	train_loss: 0.1894	val_loss: 0.2615	val_auroc: 0.8271	val_avpr: 0.9727
## Epoch 12	train_loss: 0.1848	val_loss: 0.2569	val_auroc: 0.8371	val_avpr: 0.9748
## Epoch 13	t

As we can see, in less than 15s our performance improved from 0.9671 to 0.9831 AUROC on the test set. 

The results could be further improved by doing a hyperparameter sweep for the task-specific head. Considering the low cost of fitting these tiny models, it would still be very fast. The results reported in the paper are slightly better than what we show here, thanks to the sweep that we performed for each task, which gave a boost of a few extra percentage points.