This script evaluates two models trained on the datasets $ChEBI_{v200}^{854}$ and $ChEBI_{v148}^{709}$.

In [12]:
from chebai.preprocessing.datasets.chebi import ChEBIOver100
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.models.electra import Electra
from chebai.models.base import ChebaiBaseNet

from torchmetrics.classification import MultilabelF1Score
import numpy as np
from chebai.result import pretraining as eval_pre
from chebai.preprocessing.datasets.pubchem import PubChemDeepSMILES
from chebai.result.base import ResultProcessor
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import chebai.models.electra as electra
from chebai.loss.pretraining import ElectraPreLoss
import torch
import tqdm

DEVICE = "cpu"

In [13]:
model_path_v200 = os.path.join("models", "electra_c100_bce_unweighted.ckpt")
model_path_v148 = os.path.join("models", "electra_c100_bce_unweighted_v148.ckpt")

model_v200 = Electra.load_from_checkpoint(model_path_v200).to(DEVICE)
model_v148 = Electra.load_from_checkpoint(model_path_v148).to(DEVICE)

data_module_v200 = ChEBIOver100(chebi_version=200)
data_module_v148 = ChEBIOver100(chebi_version=200, chebi_version_train=148)

In [14]:
classes_file_v200 = "classes.txt"
classes_file_v148 = f"classes_v148.txt"
with open(os.path.join(data_module_v200.raw_dir, classes_file_v200), "r") as file:
    v200_classes = file.readlines()
with open(os.path.join(data_module_v148.raw_dir, classes_file_v148), "r") as file:
    v148_classes = file.readlines()

In [15]:
# get list of classes that appear in v200 and v148
common_classes = []
for v200_class in v200_classes:
    if v200_class in v148_classes:
        common_classes.append(v200_class)
# get filter if a class in v200/v148 is a common class
common_classes_mask_v200 = torch.tensor([[c in common_classes for c in v200_classes]])
common_classes_mask_v148 = torch.tensor([[c in common_classes for c in v148_classes]])

In [16]:
print(f"Number of classes in ChEBI_v148: {len(v148_classes)}")
print(f"Number of classes in ChEBI_v200: {len(v200_classes)}")
print(f"Number of classes in both versions: {len(common_classes)}")

Number of classes in ChEBI_v148: 709
Number of classes in ChEBI_v200: 854
Number of classes in both versions: 701


In [12]:
def evaluate_model(
    model: ChebaiBaseNet,
    data_module: XYBaseDataModule,
    common_classes_mask=None,
    test_file=None,
):
    collate = data_module.reader.COLLATER()
    if test_file is None:
        test_file = data_module.processed_file_names_dict["test"]
    data_path = os.path.join(data_module.processed_dir, test_file)
    data_list = torch.load(data_path)
    preds_list = []
    labels_list = []

    for row in tqdm.tqdm(data_list):
        processable_data = model._process_batch(collate([row]), 0)
        model_output = model(processable_data)
        preds, labels = model._get_prediction_and_labels(
            processable_data, processable_data["labels"], model_output
        )
        if common_classes_mask is not None:
            preds = preds[common_classes_mask]
            labels = labels[common_classes_mask]
            preds_list.append(preds.unsqueeze(0))
            labels_list.append(labels.unsqueeze(0))
        else:
            preds_list.append(preds)
            labels_list.append(labels)

    test_preds = torch.cat(preds_list)
    test_labels = torch.cat(labels_list)
    print(test_preds.shape)
    print(test_labels.shape)
    f1_macro = MultilabelF1Score(test_preds.shape[1], average="macro")
    f1_micro = MultilabelF1Score(test_preds.shape[1], average="micro")
    print(
        f"Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}"
    )
    print(
        f"Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}"
    )

In [11]:
evaluate_model(model_v200, data_module_v200)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16486/16486 [07:24<00:00, 37.09it/s]


torch.Size([16486, 854])
torch.Size([16486, 854])
Macro-F1 on test set with 854 classes: 0.603181
Micro-F1 on test set with 854 classes: 0.902437


In [40]:
evaluate_model(model_v200, data_module_v200, common_classes_mask_v200)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16486/16486 [07:21<00:00, 37.36it/s]


torch.Size([16486, 701])
torch.Size([16486, 701])
Macro-F1 on test set with 701 classes: 0.623063
Micro-F1 on test set with 701 classes: 0.905059


In [41]:
evaluate_model(model_v148, data_module_v148)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16486/16486 [05:31<00:00, 49.69it/s]


torch.Size([16486, 709])
torch.Size([16486, 709])
Macro-F1 on test set with 709 classes: 0.513283
Micro-F1 on test set with 709 classes: 0.854591


In [42]:
evaluate_model(model_v148, data_module_v148, common_classes_mask_v148)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16486/16486 [05:16<00:00, 52.07it/s]


torch.Size([16486, 701])
torch.Size([16486, 701])
Macro-F1 on test set with 701 classes: 0.519968
Micro-F1 on test set with 701 classes: 0.855442
