In [1]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier
from catboost import CatBoostClassifier

import pickle

In [2]:
def load_dataset(folder_path: str):
    """Load every ``*.csv`` file in a folder into one labelled DataFrame.

    The files are read in sorted path order and concatenated with a fresh
    integer index. Rows whose ``aa`` or ``dssp`` value is missing are dropped,
    and a ``dssp3`` column is added that collapses the ``dssp`` codes into the
    three classes ``'H'``, ``'B'`` and ``'.'``.

    Args:
        folder_path: Directory that directly contains the CSV files.

    Returns:
        A new DataFrame holding the kept rows plus the ``dssp3`` column. The
        index keeps the positions from the concatenated frame, so it has gaps
        wherever rows were dropped. ``dssp`` codes that are not in the mapping
        below yield NaN in ``dssp3``.

    Raises:
        ValueError: If the folder contains no ``*.csv`` files.
    """
    # Sort so the row order does not depend on filesystem enumeration order.
    csv_paths = sorted(Path(folder_path).glob("*.csv"))
    if not csv_paths:
        raise ValueError(f"No CSV files found in {folder_path!r}")

    combined_df = pd.concat(
        [pd.read_csv(csv_path) for csv_path in csv_paths],
        ignore_index=True,
    )

    has_labels = combined_df["aa"].notna() & combined_df["dssp"].notna()
    # .copy() makes the filtered frame independent, so adding a column below
    # does not trigger pandas' SettingWithCopyWarning.
    combined_df = combined_df[has_labels].copy()
    combined_df['dssp3'] = combined_df['dssp'].map({
        'H': 'H',
        'G': 'H',
        'I': 'H',
        'P': 'H',
        'B': 'B',
        'E': 'B',
        '.': '.',
        'T': '.',
        'S': '.'
    })

    return combined_df

In [3]:
# Load the three feature sets (one folder each) and index them by a short name.
df_alpha, df_beta, df_com = map(
    load_dataset,
    ("data/ca-features", "data/cb-features", "data/com-features"),
)

dfs = dict(alpha=df_alpha, beta=df_beta, com=df_com)

In [4]:
def _load_pickled_model(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes, including when unpickling raises.
    """
    with open(path, "rb") as model_file:
        return pickle.load(model_file)


# SECURITY: unpickling can execute arbitrary code. Only load model files that
# come from a trusted source.
xgb_model = _load_pickled_model("models/xgb_model.pkl")
rf_model = _load_pickled_model("models/rf_model.pkl")
cb_model = _load_pickled_model("models/cb_model.pkl")

models = {
    "rf": rf_model,
    "xgb": xgb_model,
    "catboost": cb_model
}

In [5]:
def preproc_random_forest(X, y, encoder=None):
    """One-hot encode the ``aa`` column of ``X`` for the random-forest model.

    Args:
        X: DataFrame with an ``aa`` column; all other columns are passed
            through unchanged. ``X`` itself is not modified.
        y: Target labels. Returned as-is.
        encoder: Optional already-fitted encoder exposing ``transform`` and
            ``get_feature_names_out`` and producing dense output. When
            omitted, a new ``OneHotEncoder`` is fitted on ``X`` itself, so the
            generated columns depend on which ``aa`` values occur in ``X``.
            Pass the encoder used at training time to keep the column set
            stable across datasets.

    Returns:
        ``(X_processed, y)`` where ``X_processed`` has a fresh 0..n-1 index,
        the non-``aa`` columns of ``X`` first, and the one-hot columns after.
    """
    aa_column = X[["aa"]]
    if encoder is None:
        encoder = OneHotEncoder(sparse_output=False)
        aa_encoded = encoder.fit_transform(aa_column)
    else:
        # Reuse the caller's fitted encoder; do not re-fit on evaluation data.
        aa_encoded = encoder.transform(aa_column)

    aa_encoded_df = pd.DataFrame(aa_encoded, columns=encoder.get_feature_names_out(["aa"]))
    # The encoded frame has a fresh RangeIndex, so reset X's index as well;
    # otherwise concat(axis=1) would align on mismatched index labels.
    other_features = X.drop(columns=["aa"]).reset_index(drop=True)
    X_processed = pd.concat([other_features, aa_encoded_df], axis=1)

    return X_processed, y

def test_random_forest(X_test, y_test, model):
    """Evaluate a fitted random-forest model and display the results.

    Preprocesses the inputs with ``preproc_random_forest``, predicts with
    ``model``, prints a classification report, and shows a confusion-matrix
    heatmap.

    Args:
        X_test: Feature DataFrame containing an ``aa`` column.
        y_test: True labels for ``X_test``.
        model: Fitted classifier exposing ``predict`` and ``classes_``.

    Returns:
        None. Output is printed and plotted.
    """
    X_test, y_test = preproc_random_forest(X_test, y_test)

    y_pred = model.predict(X_test)

    print("Classification Report - Random Forest:")
    print(classification_report(y_test, y_pred))

    class_labels = model.classes_
    # Pin rows/columns to the model's classes so the matrix always matches the
    # tick labels, even if a class is absent from both y_test and y_pred.
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - Random Forest')
    plt.show()

In [6]:
# Integer class ids for the three dssp3 labels, used for the XGBoost targets.
xgboost_label_mapping = {'H': 0, 'B': 1, '.': 2}

def preproc_xgboost(X, y, categories=None):
    """Prepare features and labels for the XGBoost model.

    Args:
        X: DataFrame with an ``aa`` column. A copy is modified; ``X`` itself
            is left untouched.
        y: Series of labels; each value is looked up in
            ``xgboost_label_mapping``. Values missing from the mapping become
            NaN.
        categories: Optional ordered collection of ``aa`` values to use as the
            categorical levels. When omitted, the levels (and therefore the
            integer category codes) are inferred from the values present in
            ``X``, so they can differ between datasets. ``aa`` values not in
            ``categories`` become NaN.
            NOTE(review): depending on the XGBoost version the model may
            consume the raw category codes; confirm whether the training-time
            levels should be passed here.

    Returns:
        ``(X_processed, y_processed)``: the copied frame with ``aa`` as a
        categorical column, and the mapped labels.
    """
    X_processed = X.copy()
    if categories is None:
        aa_dtype = 'category'
    else:
        # Fixed levels keep the value -> code assignment independent of X.
        aa_dtype = pd.CategoricalDtype(categories=list(categories))
    X_processed['aa'] = X_processed['aa'].astype(aa_dtype)

    y_processed = y.map(xgboost_label_mapping)

    return X_processed, y_processed

def test_xgboost(X_test, y_test, model):
    """Evaluate a fitted XGBoost model and display the results.

    Preprocesses the inputs with ``preproc_xgboost``, predicts with ``model``,
    prints a classification report, and shows a confusion-matrix heatmap.

    Args:
        X_test: Feature DataFrame containing an ``aa`` column.
        y_test: Series of true labels using the keys of
            ``xgboost_label_mapping``.
        model: Fitted classifier exposing ``predict``.

    Returns:
        None. Output is printed and plotted.
    """
    X_test, y_test = preproc_xgboost(X_test, y_test)

    y_pred = model.predict(X_test)

    # Dicts preserve insertion order, so keys[i] is the display name of the
    # integer label labels[i].
    keys = list(xgboost_label_mapping.keys())
    labels = list(xgboost_label_mapping.values())

    # Pass labels explicitly so the report rows and matrix axes always line up
    # with the names, even if a class is absent from both y_test and y_pred.
    print("Classification Report - XGBoost:")
    print(classification_report(
        y_test,
        y_pred,
        labels=labels,
        target_names=keys
        )
    )

    cm = confusion_matrix(
        y_test,
        y_pred,
        labels=labels
    )
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=keys, yticklabels=keys)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - XGBoost')
    plt.show()

In [7]:
def preproc_catboost(X, y):
    """Prepare features and labels for the CatBoost model.

    Args:
        X: DataFrame with an ``aa`` column. It is not modified.
        y: Series of labels.

    Returns:
        ``(features, targets)``: a new frame equal to ``X`` except that ``aa``
        is a pandas categorical column, and ``y`` converted to categorical.
    """
    # astype with a per-column mapping returns a new frame, so X stays as-is.
    features = X.astype({'aa': 'category'})
    targets = y.astype('category')
    return features, targets

def test_catboost(X_test, y_test, model):
    """Evaluate a fitted CatBoost model and display the results.

    Preprocesses the inputs with ``preproc_catboost``, predicts with
    ``model``, prints a classification report, and shows a confusion-matrix
    heatmap.

    Args:
        X_test: Feature DataFrame containing an ``aa`` column.
        y_test: True labels for ``X_test``.
        model: Fitted classifier exposing ``predict`` and ``classes_``.

    Returns:
        None. Output is printed and plotted.
    """
    X_test, y_test = preproc_catboost(X_test, y_test)

    y_pred = model.predict(X_test)

    print("Classification Report - CatBoost:")
    print(classification_report(y_test, y_pred))

    class_labels = model.classes_
    # Pin rows/columns to the model's classes so the matrix always matches the
    # tick labels, even if a class is absent from both y_test and y_pred.
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - CatBoost')
    plt.show()

In [None]:
results = []

for df in dfs:
    for model in models:

        
