In [36]:
from pathlib import Path
import pandas as pd
from catboost import CatBoostClassifier
import xgboost as xgb
import lightgbm as lgb

def load_data(data_path: Path):
    df_train = pd.read_csv(data_path / "train.csv")
    df_valid = pd.read_csv(data_path / "val.csv")
    df_test = pd.read_csv(data_path / "test.csv")

    X_train = df_train.drop(columns=["label"])
    y_train = df_train["label"]
    X_valid = df_valid.drop(columns=["label"])
    y_valid = df_valid["label"]
    X_test = df_test.drop(columns=["label"])
    y_test = df_test["label"]

    return X_train, y_train, X_valid, y_valid, X_test, y_test

def load_model(model_path: Path):
    try:
        model = CatBoostClassifier()
        model.load_model(model_path)
        return model
    except Exception:
        pass
    
    try:
        model = xgb.XGBClassifier()
        model.load_model(model_path)
        return model
    except Exception:
        pass
    
    try:
        model = lgb.Booster(model_file=model_path)
        return model
    except Exception:
        pass

    return None

In [42]:
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
import numpy as np

def evaluate_model(clf, X_test, y_test, metric):
    y_pred = clf.predict(X_test)
    if y_pred.ndim == 1:
        y_pred = (y_pred > 0.5).astype(int)
    else:
        y_pred = np.argmax(y_pred, axis=1)
        
    if metric == "F1":
        return f1_score(y_test, y_pred, average="binary")
    elif metric == "Accuracy":
        return accuracy_score(y_test, y_pred)
    else:
        return matthews_corrcoef(y_test, y_pred)

In [46]:
embeddings_dir = Path("../data/embeddings")

DATASETS = {
    "enhancers": embeddings_dir / "enhancers",
    "promoter_all": embeddings_dir / "promoter_all",
    "splice_sites_all": embeddings_dir / "splice_sites_all",
    "H3K9me3": embeddings_dir / "H3K9me3",
    "H4K20me1": embeddings_dir / "H4K20me1",
}

METRICS = {
    "enhancers": "MCC",
    "promoter_all": "F1",
    "splice_sites_all": "Accuracy",
    "H3K9me3": "MCC",
    "H4K20me1": "MCC",
}

save_model_dir = Path("../models")

In [47]:
import os

def evaluate_models():
    results = {}
    for dataset_name in DATASETS:
        results[dataset_name] = {}
        dataset_path = DATASETS[dataset_name]
        metric = METRICS[dataset_name]
        _, _, _, _, X_test, y_test = load_data(dataset_path)
        models_dir = save_model_dir / dataset_name
        for file in os.listdir(models_dir):
            if file.endswith(".pkl"):
                model_name = file.split(".")[0]
                model_path = models_dir / file
                model = load_model(model_path)
                print(f"Evaluating {model_name} on {dataset_name} with {metric} metric")
                score = evaluate_model(model, X_test, y_test, metric)
                results[dataset_name][model_name] = score
    return results

In [48]:
results = evaluate_models()

Evaluating xgboost on enhancers with MCC metric
Evaluating catboost on enhancers with MCC metric
Evaluating lightgbm on enhancers with MCC metric
Evaluating xgboost on promoter_all with F1 metric
Evaluating catboost on promoter_all with F1 metric
Evaluating lightgbm on promoter_all with F1 metric
Evaluating xgboost on splice_sites_all with Accuracy metric
Evaluating catboost on splice_sites_all with Accuracy metric
Evaluating lightgbm on splice_sites_all with Accuracy metric
Evaluating xgboost on H3K9me3 with MCC metric
Evaluating catboost on H3K9me3 with MCC metric
Evaluating lightgbm on H3K9me3 with MCC metric
Evaluating xgboost on H4K20me1 with MCC metric
Evaluating catboost on H4K20me1 with MCC metric
Evaluating lightgbm on H4K20me1 with MCC metric


In [52]:
from tabulate import tabulate

headers = ['Dataset', 'XGBoost', 'CatBoost', 'LightGBM']
table_data = []

for dataset, models in results.items():
    row = [
        dataset,
        f"{models['xgboost']:.4f}",
        f"{models['catboost']:.4f}",
        f"{models['lightgbm']:.4f}"
    ]
    table_data.append(row)

print(tabulate(table_data, headers=headers, tablefmt='grid'))

+------------------+-----------+------------+------------+
| Dataset          |   XGBoost |   CatBoost |   LightGBM |
| enhancers        |    0.4721 |     0.4758 |     0.4638 |
+------------------+-----------+------------+------------+
| promoter_all     |    0.8508 |     0.85   |     0.8557 |
+------------------+-----------+------------+------------+
| splice_sites_all |    0.4807 |     0.343  |     0.5437 |
+------------------+-----------+------------+------------+
| H3K9me3          |    0.2801 |     0.2896 |     0.2894 |
+------------------+-----------+------------+------------+
| H4K20me1         |    0.5742 |     0.5832 |     0.5843 |
+------------------+-----------+------------+------------+
