In [29]:
import json
from pathlib import Path
import os

import numpy as np
import torch
from tqdm import tqdm

from conplex_dti import featurizer
from conplex_dti.dataset import datamodules
from conplex_dti.model import architectures


In [30]:
def _read_json(path):
    """Parse the JSON document stored at ``path``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    parsing finishes (or fails), instead of being left to the garbage collector.
    """
    with open(path) as handle:
        return json.load(handle)


# Training histories keyed by file name up to its first "." (e.g. "baseline.json"
# -> "baseline"). Only "*.json" entries are read, so stray directory entries such
# as checkpoint folders or editor backups no longer abort the whole load.
histories = {
    file.split(".")[0]: _read_json(os.path.join("../histories", file))
    for file in os.listdir("../histories")
    if file.endswith(".json")
}

# Maps each saved run name to the architecture class used to rebuild it.
# The run name doubles as the file stem of "../models/<name>.pt" and
# "../histories/<name>.json" elsewhere in this notebook. Some names share a class.
# Insertion order is preserved, so downstream loops visit runs in this order.
model_architectures = dict(
    baseline=architectures.SimpleCoembedding,
    baseline_davis=architectures.SimpleCoembedding,
    CNN=architectures.CNN,
    cross_attention=architectures.CrossAttentionCoembedding,
    duo_davis=architectures.DuoLayerPerceptron,
    duo_layer=architectures.DuoLayerPerceptron,
    large_layer=architectures.LargeDuoLayerPerceptron,
    quintuple=architectures.QuintupleLayerPerceptron,
    residual=architectures.ResidualCoembedding,
    small_duo=architectures.SmallDuoLayerPerceptron,
    triple_layer=architectures.TripleLayerPerceptron,
)

In [31]:
# All inference in this notebook runs on the CPU; checkpoints are remapped to it on load.
device = torch.device("cpu")

# Rebuild every architecture with its default constructor arguments, restore the
# weights saved under ../models/<name>.pt, and store the model in eval mode.
models = {}
for name in model_architectures:
    model = model_architectures[name]()
    state_dict = torch.load(f"../models/{name}.pt", map_location=device)
    model.load_state_dict(state_dict)
    # eval() and to() both hand back the module, so the calls can be chained.
    model = model.eval().to(device)
    models[name] = model

In [32]:
# Resolve the directory of the "DAVIS" task below ../datasets.
task_dir = datamodules.get_task_dir("DAVIS", database_root=Path("../datasets"))

# Featurizers are created drug-first, then target, both saving into the task directory.
morgan_featurizer = featurizer.get_featurizer("MorganFeaturizer", save_dir=task_dir)
protbert_featurizer = featurizer.get_featurizer("ProtBertFeaturizer", save_dir=task_dir)

# NOTE(review): shuffle=True is passed through unchanged; confirm whether the
# datamodule also applies it to the test loader, which would make the stored
# sample order differ between runs.
datamodule_options = dict(
    drug_featurizer=morgan_featurizer,
    target_featurizer=protbert_featurizer,
    device=device,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
datamodule = datamodules.DTIDataModule(task_dir, **datamodule_options)

# Prepare and set up the data before asking for the test split loader.
datamodule.prepare_data()
datamodule.setup()
test_loader = datamodule.test_dataloader()

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Drug and target featurizers already exist
Morgan: 100%|██████████| 68/68 [00:00<00:00, 1430.33it/s]
ProtBert: 100%|████████

In [33]:
# Run every loaded model over the test loader and collect, per model, the raw
# predictions together with the labels in the order the loader produced them.
evaluation = {}
for name, model in models.items():
    preds, labels = [], []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Evaluating {name}"):
            drug, target, label = batch
            drug = drug.to(device)
            target = target.to(device)

            pred = model(drug, target)

            # Flatten before converting: iterating a tensor directly raises on a
            # 0-d output (e.g. a squeezed single-sample batch) and creates one
            # tensor object per sample. tolist() yields plain Python numbers.
            preds.extend(pred.cpu().reshape(-1).tolist())
            labels.extend(torch.as_tensor(label).reshape(-1).tolist())

    # Flattening must not hide a shape mismatch between outputs and labels.
    if len(preds) != len(labels):
        raise ValueError(
            f"{name}: got {len(preds)} predictions for {len(labels)} labels"
        )

    # Store JSON-serializable floats (labels may arrive as integers).
    evaluation[name] = {
        "preds": [float(p) for p in preds],
        "labels": [float(l) for l in labels]
    }

Evaluating baseline: 100%|██████████| 188/188 [00:00<00:00, 196.38it/s]
Evaluating baseline_davis: 100%|██████████| 188/188 [00:00<00:00, 218.69it/s]
Evaluating CNN: 100%|██████████| 188/188 [00:10<00:00, 17.11it/s]
Evaluating cross_attention: 100%|██████████| 188/188 [00:01<00:00, 150.88it/s]
Evaluating duo_davis: 100%|██████████| 188/188 [00:00<00:00, 251.81it/s]
Evaluating duo_layer: 100%|██████████| 188/188 [00:00<00:00, 235.83it/s]
Evaluating large_layer: 100%|██████████| 188/188 [00:01<00:00, 141.03it/s]
Evaluating quintuple: 100%|██████████| 188/188 [00:01<00:00, 119.77it/s]
Evaluating residual: 100%|██████████| 188/188 [00:07<00:00, 26.28it/s]
Evaluating small_duo: 100%|██████████| 188/188 [00:01<00:00, 174.40it/s]
Evaluating triple_layer: 100%|██████████| 188/188 [00:01<00:00, 139.84it/s]


In [35]:
# Attach each model's evaluation results to its training history and write the
# combined record back to ../histories/<name>.json.
for name in model_architectures.keys():
    histories[name]["evaluation"] = evaluation[name]

    # Serialize first: open(..., "w") truncates the file immediately, so an error
    # raised while encoding would otherwise leave a truncated history on disk.
    serialized = json.dumps(histories[name])

    with open(f"../histories/{name}.json", "w") as f:
        f.write(serialized)