In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
from multimodal_contrastive.data.dataset import TestDataset, CustomSubset
from multimodal_contrastive.networks.models import MultiTask_FP_PL
from multimodal_contrastive.analysis.utils import make_eval_data_loader
import torch

## Load Model

In [2]:
ckpt = '/home/mila/s/stephen.lu/gfn_gene/res/mmc/models/puma_assay_epoch=139.ckpt'

if torch.cuda.is_available() == False:
    device = torch.device('cpu')
else:
    device = None

model = MultiTask_FP_PL.load_from_checkpoint(ckpt, map_location=device)
model.eval()

MultiTask_FP_PL(
  (loss): MultiTaskLoss()
  (encoder): FP_MLP(
    (mlp): MultiLayerPerceptron(
      (module): Sequential(
        (0): Linear(in_features=2048, out_features=64, bias=True)
        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0.0, inplace=False)
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Dropout(p=0.0, inplace=False)
        (8): Linear(in_features=64, out_features=37, bias=True)
      )
    )
  )
)

## Load Data

In [20]:
# load dataset and test set smiles
data_dir = '/home/mila/s/stephen.lu/scratch/mmc/datasets/'
dataset = TestDataset(data_dir + 'assay_matrix_discrete_37_assays_canonical.csv')
target_assay_smi = np.load(data_dir + 'assay_targets_smi.npy')

In [21]:
# Get the rows of the dataset corresponding to the test set smiles
target_idx = [np.where(np.array(dataset.ids)==_)[0][0] for _ in target_assay_smi if _ in set(dataset.ids)]
print(target_idx)

target_dataset = CustomSubset(dataset, [i for i in range(len(dataset)) if i in target_idx])
# target_dataset = CustomSubset(dataset, np.arange(len(dataset)))
print(len(target_dataset))

[36, 821, 1648, 2040, 6470, 8362, 9552, 13210]
8


In [24]:
# measure the performance using accuracy, precision, recall, f1
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# get predictions for the test set
loader = make_eval_data_loader(target_dataset.dataset, batch_size=len(target_dataset))

for batch in loader:
    y = batch['labels']
    y_hat = model(batch)

    idx_to_keep = ~torch.isnan(y)
    y = y[idx_to_keep].cpu().numpy()
    y_hat = y_hat[0][idx_to_keep].detach().cpu().numpy()

    # choose the best threshold depending on f1 score
    thresholds = np.linspace(0, 1, 100)
    best_f1 = 0

    for threshold in thresholds:
        y_pred = y_hat > threshold
        f1 = f1_score(y, y_pred)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    y_pred = y_hat > best_threshold

    # measure the performance using accuracy, precision, recall, f1
    print('Accuracy:', accuracy_score(y, y_pred))
    print('Precision:', precision_score(y, y_pred))
    print('Recall:', recall_score(y, y_pred))
    print('F1:', f1_score(y, y_pred))

    break

tensor([[False, False,  True, False, False, False, False, False,  True, False,
         False, False, False, False, False, False, False, False, False, False,
          True, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [False, False,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True,  True, False,
          True, False,  True, False,  True, False, False],
        [False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False,  True, False,  True,
         False, False, False, False, False, False, False, False, False,  True,
         False, False, False,  True, False, False, False],
        [False, False,  True,  True, False, False, False, False,  True, False,
         False, False, False, Fal