# Fetch checkpoint models and test

In [1]:
from dotmap import DotMap
from src.tdc_constant import TDC

# Run-wide testbench settings, read elsewhere in this notebook as attributes:
#   args.device           - device string used for batches, the helper and
#                           torch.load's map_location
#   args.pred_layer_depth - second constructor argument of TestbenchHelper
#   args.output           - path of the CSV written at the end of the run
args = DotMap(
    dict(
        device='cuda:2',
        pred_layer_depth=2,
        output='testbench.csv',
    )
)

# Directories that are handed to find_pt_files() to locate checkpoint files.
folder_list = [
    f"ckpts/{run_name}"
    for run_name in (
        "MTL1_Scaled",
        "MTL2_DropRatio0.2",
        "MTL3_DropRatio0.5",
        "MTL4_DropRatio0.7",
    )
]

In [2]:
# IPython magics (not plain Python): load the autoreload extension and switch
# it to mode 2, so imported modules are re-imported automatically before each
# cell runs. Edits to the project's src/ modules then take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from src.testbench_code import find_pt_files, extract_tdcList

# Collect the checkpoint paths found under the folders in folder_list and
# print them for a quick visual check (the printed paths end in ".pt").
modelPathList = find_pt_files(folder_list)
print(modelPathList)

['ckpts/MTL1_Scaled/MolCLR_[X, CYP3A4, X, Solubility]_sc-12.13_1617.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, X, X, Solubility]_sc-12.06_0842.pt', 'ckpts/MTL1_Scaled/MolCLR_[X, X, Clearance, X]_sc-12.13_1630.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, CYP3A4, Clearance, Solubility]-12.06_0025.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, X, X, X]_sc-12.06_0034.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, CYP3A4, X, X]_sc-12.06_0827.pt', 'ckpts/MTL1_Scaled/MolCLR_[X, CYP3A4, X, X]_sc-12.13_1617.pt', 'ckpts/MTL1_Scaled/MolCLR_[X, X, X, Solubility]_sc-12.13_1631.pt', 'ckpts/MTL1_Scaled/MolCLR_[X, CYP3A4, Clearance, Solubility]_sc-12.06_1445.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, X, Clearance, Solubility]_sc-12.06_0801.pt', 'ckpts/MTL1_Scaled/MolCLR_[X, CYP3A4, Clearance, X]_sc-12.13_1617.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, CYP3A4, Clearance, X]_sc-12.06_0914.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, X, Clearance, X]_sc-12.06_0843.pt', 'ckpts/MTL1_Scaled/MolCLR_[BBB, CYP3A4, X, Solubility]_sc-12.06_0759.pt', 'ckpts/MTL1_Sca

In [4]:
from src.testbench_code import TestbenchHelper

# Shared helper used by the evaluation code in this notebook to obtain test
# loaders, loss criteria, label scalers and model instances.
helper = TestbenchHelper(args.device, args.pred_layer_depth)

In [5]:
def eval_task(model, tdc, predIndex, scaled):
    """Return the mean per-batch test loss of one output column of ``model``.

    Args:
        model: Network called as ``model(batch)``; its output is indexed as
            ``[:, predIndex]``.
        tdc: Task identifier passed to ``helper`` to obtain the test loader,
            the loss criterion and (when ``scaled``) the scaler.
        predIndex: Column of the model output that belongs to ``tdc``.
        scaled: If true, predictions are passed through
            ``scaler.inverse_transform`` before the loss is computed.

    Returns:
        The sum of the per-batch criterion values divided by the number of
        batches in the loader.

    Raises:
        ValueError: If the test loader contains no batches.
    """
    loader = helper.get_testloader(tdc)
    criterion = helper.get_criterion(tdc)

    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(f"Test loader for {tdc} is empty; cannot average the loss")

    # The scaler does not depend on the batch, so look it up once instead of
    # once per iteration.
    scaler = helper.get_scaler(tdc) if scaled else None

    model.eval()  # Set to eval mode
    eval_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(args.device)
            # reshape(-1) instead of squeeze(): squeeze() turns a batch of one
            # sample into a 0-dim tensor, which no longer matches the shape of
            # ``pred`` (always 1-D after the column selection below).
            label = batch.y.reshape(-1)
            pred = model(batch)[:, predIndex]

            if scaled:
                # Map predictions back to the original label scale.
                # presumably an sklearn-style scaler working on CPU data of
                # shape (n, 1) - only inverse_transform is relied on here.
                restored = scaler.inverse_transform(pred.cpu().reshape(-1, 1))
                pred = torch.Tensor(restored.flatten()).to(args.device)
            eval_loss += criterion(pred, label).item()

    # NOTE(review): every batch is weighted equally; if the criterion is a
    # per-batch mean, a smaller final batch counts as much as a full one.
    return eval_loss / num_batches

In [6]:
def eval_accuracy(model, tdc, predIndex):
    """Return the test-set accuracy (in percent) of one output column.

    The selected model output is passed through a sigmoid and counted as
    correct when it lies within 0.5 of the label.

    Args:
        model: Network called as ``model(batch)``; its output is indexed as
            ``[:, predIndex]``.
        tdc: Task identifier passed to ``helper.get_testloader``.
        predIndex: Column of the model output that belongs to ``tdc``.

    Returns:
        A 0-dim tensor holding ``100 * correct / total``.
    """
    loader = helper.get_testloader(tdc)

    # Put the model in eval mode here as well (as eval_task does) so the
    # result does not depend on the caller having done so.
    model.eval()

    preds = []
    ys = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(args.device)
            pred = model(batch)[:, predIndex]

            # Flatten both to 1-D so batches of any size concatenate cleanly.
            preds.append(pred.reshape(-1))
            ys.append(batch.y.reshape(-1))

    preds = torch.cat(preds, dim=0)  # flatten into 1 dimension
    ys = torch.cat(ys, dim=0)

    pred_final = torch.sigmoid(preds)
    # A prediction is correct when it is closer than 0.5 to its label.
    correct = (torch.abs(pred_final - ys) < 0.5).float().sum()
    accuracy = 100 * correct / len(pred_final)

    return accuracy

In [None]:
import torch
import pandas as pd

# Result table: rows are added per checkpoint path; there is one loss column
# per task in TDC.allList plus a "<task>_Acc" column for every task that is
# not a regression task.
df = pd.DataFrame()
for tdc in TDC.allList:
    df[str(tdc)] = pd.Series(dtype='object')
    if not tdc.isRegression():
        df[f"{tdc}_Acc"] = pd.Series(dtype='object')

# pred = [tdc1, tdc2, ...] with order specified in path
for model_path in modelPathList:
    tdcList, scaled = extract_tdcList(model_path)
    model = helper.get_model(tdcList)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source, or pass
    # weights_only=True if the installed torch version supports it.
    model.load_state_dict(torch.load(model_path, map_location=args.device))
    model.eval()  # Set to eval mode

    # The position of a task in tdcList is its column in the model output.
    for predIndex, tdc in enumerate(tdcList):
        is_regression = tdc.isRegression()
        test_loss = eval_task(
            model, tdc=tdc, predIndex=predIndex, scaled=scaled and is_regression
        )
        df.at[model_path, str(tdc)] = f"{test_loss:.4f}"

        # Accuracy is only reported for non-regression tasks.
        if not is_regression:
            accuracy = eval_accuracy(model, tdc=tdc, predIndex=predIndex)
            df.at[model_path, f"{tdc}_Acc"] = f"{accuracy:.1f}"

df.to_csv(args.output)