In [None]:
import pickle
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from catboost import CatBoostClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tqdm.contrib.concurrent import thread_map

In [None]:
def load_dataset(folder_path: str):
    """Load every ``*.csv`` file in ``folder_path`` into one labelled frame.

    Files are read concurrently but concatenated in sorted filename order, so
    the row order of the result is the same on every run. That matters to
    callers that split the frame with a fixed ``random_state``.

    Args:
        folder_path: Directory containing the CSV files.

    Returns:
        A DataFrame holding the rows whose ``aa`` and ``dssp`` values are both
        present, with an extra ``dssp3`` column that collapses the ``dssp``
        code into one of ``'H'``, ``'E'`` or ``'.'``. Codes that are not in the
        mapping table end up as NaN in ``dssp3``.

    Raises:
        ValueError: If no CSV file in the folder could be read.
    """
    # Sort so that concatenation order does not depend on the filesystem or
    # on which worker thread happens to finish first.
    csv_paths = sorted(Path(folder_path).glob("*.csv"))

    def read_one(csv_path):
        # Hand the failure back to the caller instead of mutating shared
        # state from worker threads.
        try:
            return pd.read_csv(csv_path)
        except Exception as exc:
            return exc

    # thread_map returns the results in the same order as its input.
    results = thread_map(read_one, csv_paths)

    data_frames = []
    corrupt_files = {}
    for csv_path, result in zip(csv_paths, results):
        if isinstance(result, Exception):
            corrupt_files[csv_path.stem] = result
        else:
            data_frames.append(result)

    if corrupt_files:
        # Unreadable files are still skipped (best effort), but no longer silently.
        warnings.warn(
            f"Skipped {len(corrupt_files)} unreadable CSV file(s) in "
            f"{folder_path}: {sorted(corrupt_files)}",
            stacklevel=2,
        )

    if not data_frames:
        raise ValueError(f"No readable CSV files found in {folder_path!r}")

    combined_df = pd.concat(data_frames, ignore_index=True)

    # Keep only residues that have both an amino-acid code and a DSSP label.
    # The explicit copy lets the new column be added without pandas'
    # SettingWithCopyWarning.
    has_labels = combined_df["aa"].notna() & combined_df["dssp"].notna()
    combined_df = combined_df[has_labels].copy()

    # Collapse the DSSP codes into three classes
    # (presumably helix / strand / other - confirm with the data producer).
    combined_df['dssp3'] = combined_df['dssp'].map({
        'H': 'H',
        'G': 'H',
        'I': 'H',
        'P': 'H',
        'B': 'E',
        'E': 'E',
        '.': '.',
        'T': '.',
        'S': '.'
    })

    return combined_df

In [None]:
# Feature-set name -> folder that holds the CSV files for that feature set.
dfs = dict(
    alpha="data/ca-features",
    beta="data/cb-features",
    com="data/com-features",
)

In [None]:
def preproc_catboost(X, y):
    """Return categorical-typed copies of the features and the target.

    Args:
        X: Feature frame; must contain an ``aa`` column.
        y: Target labels.

    Returns:
        A ``(features, target)`` tuple: a copy of ``X`` whose ``aa`` column has
        the pandas ``category`` dtype, and ``y`` cast to ``category``. The
        inputs themselves are left untouched.
    """
    # assign() builds a new frame, so the caller's X is not modified.
    features = X.assign(aa=X['aa'].astype('category'))
    target = y.astype('category')
    return features, target

def train_catboost(X_train, y_train):
    """Tune a multiclass CatBoost classifier with a 3-fold grid search.

    Args:
        X_train: Training features; must contain the categorical ``aa`` column.
        y_train: Training labels.

    Returns:
        The ``CatBoostClassifier`` instance the grid search was run on
        (presumably refitted with the best parameters, which is CatBoost's
        default - confirm against the installed version).
    """
    features, target = preproc_catboost(X_train, y_train)

    # Every combination of these values is evaluated by the search.
    search_space = {
        'iterations': [100, 500, 1000],
        'depth': [4, 7, 10],
        'learning_rate': [0.1, 0.5, 0.75]
    }

    classifier = CatBoostClassifier(
        loss_function='MultiClass',
        cat_features=['aa'],
        eval_metric='Accuracy'
    )
    classifier.grid_search(
        param_grid=search_space,
        X=features,
        y=target,
        cv=3,
        verbose=False
    )

    return classifier

def test_catboost(X_test, y_test, model):
    """Evaluate a fitted model and draw its confusion matrix.

    A new matplotlib figure is created and left as the current figure, so the
    caller can save or close it afterwards.

    Args:
        X_test: Evaluation features; must contain the ``aa`` column.
        y_test: True labels for ``X_test``.
        model: Fitted classifier exposing ``predict`` and ``classes_``.

    Returns:
        The text produced by ``classification_report`` for the predictions.
    """
    X_test, y_test = preproc_catboost(X_test, y_test)

    y_pred = model.predict(X_test)

    # Pin rows/columns to the model's own class list. Without ``labels`` the
    # matrix is sized from the labels present in the data, which can disagree
    # with the tick labels below when a class is missing from this split.
    class_labels = list(model.classes_)
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - CatBoost')

    return classification_report(y_test, y_pred)

In [None]:
# Train one CatBoost model per feature set and pickle it under models/.
for data_name, df_path in dfs.items():

    df = load_dataset(df_path)

    dist_cols = [col for col in df.columns if col.startswith('dist_')]
    angle_cols = [col for col in df.columns if col.startswith('angle_') or col.startswith('dihedral_')]
    neighbor_cols = [col for col in df.columns if col.startswith('neighbor_')]

    X = df[angle_cols + dist_cols + neighbor_cols +["aa"]]
    y = df["dssp3"]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    print(f"Training on dataset {data_name}")

    model_key = f"catboost_{data_name}"
    outname = f"models/{model_key}_model.pkl"

    # Create the output folder before the (expensive) grid search, so a
    # missing directory cannot throw away a finished training run.
    Path(outname).parent.mkdir(parents=True, exist_ok=True)

    model = train_catboost(X_train, y_train)

    # The context manager guarantees the file is flushed and closed.
    with open(outname, "wb") as fh:
        pickle.dump(model, fh)

In [None]:
# Test models on other data
# Every pickled model found in models/ is evaluated on the held-out 20% of
# every feature set; a text report and a confusion-matrix PNG go to reports/.
Path("reports").mkdir(parents=True, exist_ok=True)

for data_name, df_path in dfs.items():

    df = load_dataset(df_path)

    dist_cols = [col for col in df.columns if col.startswith('dist_')]
    angle_cols = [col for col in df.columns if col.startswith('angle_') or col.startswith('dihedral_')]
    neighbor_cols = [col for col in df.columns if col.startswith('neighbor_')]

    # The features and the split depend only on the dataset, so they are
    # computed once per dataset rather than once per model.
    X = df[angle_cols + dist_cols + neighbor_cols +["aa"]]
    y = df["dssp3"]

    # Same split parameters as the training cell.
    # NOTE(review): the held-out rows match the training run only if
    # load_dataset returns the rows in the same order on every call - confirm.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    for model_file in sorted(Path("models").glob("*.pkl")):

        # NOTE: unpickling can execute arbitrary code; only load model files
        # that were written by this notebook.
        with open(model_file, "rb") as fh:
            model = pickle.load(fh)

        report = test_catboost(X_test, y_test, model)

        # Built outside the f-string: reusing the same quote character inside
        # an f-string is a SyntaxError before Python 3.12.
        model_stem = model_file.stem.replace("_model", "")
        eval_name = f"{model_stem}_eval_{data_name}"

        report_name = f"reports/{eval_name}_report.txt"

        with open(report_name, "w") as f:
            f.write(report)

        plt.savefig(f"reports/{eval_name}.png")
        # Close (not just clear) the figure so figures do not pile up in
        # memory across iterations.
        plt.close()