In [None]:
import pandas as pd

from chebai.result.classification import (
    evaluate_model,
    load_results_from_buffer,
    print_metrics,
)
from chebai.models.electra import Electra
from chebai.preprocessing.datasets.chebi import ChEBIOver50, ChEBIOver100
import os
import tqdm
import torch
import pickle

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Show which device was selected.
print(DEVICE)

In [None]:
# Use the same model class and data module that were used during training.
model_class = Electra
data_module = ChEBIOver50(chebi_version=227)

# Dataset split to evaluate; replace with "train" / "validation" to run on those sets.
kind = "test"

# Checkpoint to evaluate: file name without the ".ckpt" extension, looked up under "logs".
checkpoint_name = "best_epoch=99_val_loss=0.0096_val_macro-f1=0.5358_val_micro-f1=0.8968"
checkpoint_path = os.path.join("logs", f"{checkpoint_name}.ckpt")

# Results are buffered in a directory specific to this checkpoint and split.
buffer_dir = os.path.join("results_buffer", checkpoint_name, kind)

In [None]:
# Restore the trained model from the checkpoint file.
model = model_class.load_from_checkpoint(checkpoint_path)

# Run the model on the processed file of the selected split; results are stored in buffer_dir.
# NOTE(review): confirm that evaluate_model still returns a (preds, labels) pair
# when buffer_dir is set — the unpacking below depends on it.
preds, labels = evaluate_model(
    model,
    data_module,
    filename=data_module.processed_file_names_dict[kind],
    buffer_dir=buffer_dir,
    batch_size=10,
)

In [None]:
# Load the buffered predictions and labels back from buffer_dir. The result has to be
# bound to names: without the assignment the loaded data is discarded immediately, and
# this cell cannot stand in for the (expensive) evaluation cell after a kernel restart.
# NOTE(review): assumes load_results_from_buffer returns a (preds, labels) pair — confirm against chebai.
preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)

# Read the class names, one per line, from classes.txt in the data module's raw directory.
with open(os.path.join(data_module.raw_dir, "classes.txt"), "r") as f:
    classes = [line.strip() for line in f]

In [None]:
# Report the evaluation metrics for the predictions.
# Labels are cast to an integer tensor before being handed to print_metrics.
print_metrics(
    preds, labels.to(torch.int), DEVICE,
    classes=classes,
    markdown_output=False,
    top_k=10,
)