In [53]:
import pandas as pd
import pickle
import numpy as np
import joblib
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GroupShuffleSplit
from train_src import merge_embed, filter_train, make_balanced_df, make_lama_df, merge_embed, Metrics, form_Xy


from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.tasks import Task


# One seed for the row shuffle, the ESM/Ankh train split and LightAutoML's reader.
SEED: int = 42

In [54]:
# LightAutoML column roles: "label" is the target column, "cluster" the group column
# (LightAutoML uses the group role when building cross-validation folds).
roles = dict(target="label", group="cluster")

# Binary-classification task object, shared by both the ESM and the Ankh AutoML models.
task = Task("binary")

In [55]:
# Load the combined dataset and shuffle all rows once, reproducibly (SEED).
df = (
    pd.read_csv("../data/ready_data/train_p2.csv")
    .sample(frac=1, random_state=SEED)
)

In [56]:
# Held-out benchmark: every row whose "source" is "pdb2272".
test_df = df.loc[df["source"].eq("pdb2272")]

In [57]:
# Training pool: "train_p2" rows, passed through filter_train together with the test set
# (presumably to drop entries that overlap with pdb2272 — TODO confirm in train_src).
train = filter_train(
    df.loc[df["source"].eq("train_p2")],
    test_df,
)

In [58]:
# Split the filtered training pool into two random halves: one trains the ESM-embedding
# model, the other the Ankh-embedding model. `train_test_split` is imported in the
# first cell, so it is not re-imported here.
esm_train, ankh_train = train_test_split(train, test_size=0.5, random_state=SEED)

In [59]:
# Attach the ESM embeddings to the ESM half of the training data and build the frame
# that `esm_model.fit_predict` consumes together with `roles`.
# NOTE(review): not idempotent — re-running overwrites `esm_train` with the already
# converted frame. `X_train`/`y_train`/`clusters_train` are globals that the Ankh
# feature cell later overwrites, so run the cells in order.
esm_train = merge_embed(esm_train, "../data/embeddings/esm_embeddings/train_p2.pkl")
X_train, y_train, clusters_train = form_Xy(esm_train, clusters="Yes")
esm_train = make_lama_df(X_train, y_train, clusters=clusters_train)

In [60]:
# ESM-embedding model: LightAutoML tabular preset, seeded through the reader.
# fit_predict trains it and returns out-of-fold predictions for esm_train.
esm_model = TabularAutoML(task=task, reader_params={"random_state": SEED})
oof_pred = esm_model.fit_predict(esm_train, roles=roles)

KeyboardInterrupt: 

In [None]:
# Out-of-fold metrics of the ESM model on its own training half.
# The ground truth is read from the target column (roles["target"]) of the very frame
# passed to fit_predict, rather than from the global `y_train`, which the Ankh feature
# cell overwrites. Presumably this equals `y_train` from form_Xy; TODO confirm make_lama_df.
y_valid = esm_train[roles["target"]].to_numpy()
valid_prob = oof_pred.data[:, 0]
# Guards against a stale `oof_pred` (e.g. left over from the Ankh model after an interrupted fit).
assert len(valid_prob) == len(y_valid), "oof_pred does not match esm_train; re-run the ESM fit cell"
valid_pred = (valid_prob > 0.5) * 1
metrics = Metrics(y_valid, valid_pred, valid_prob, "esm_valid")
valid_metrics = metrics.get_metrics()
valid_metrics

In [28]:
# Attach the ESM embeddings to the pdb2272 test rows and build the frame for esm_model.predict.
# NOTE(review): `X_test`/`y_test` are globals that the Ankh test-data cell later overwrites.
esm_test = merge_embed(test_df, "../data/embeddings/esm_embeddings/pdb2272.pkl")
X_test, y_test = form_Xy(esm_test)
esm_test = make_lama_df(X_test, y_test)

In [30]:
# Evaluate the ESM model on pdb2272: take the positive-class probability column
# explicitly (same `[:, 0]` convention as the validation cells), threshold at 0.5,
# and compute metrics. Distinct names avoid reusing one variable for both the
# LightAutoML prediction object and the 0/1 label array.
esm_test_prob = esm_model.predict(esm_test).data[:, 0]
esm_test_label = (esm_test_prob > 0.5) * 1
esm_test_metrics = Metrics(y_test, esm_test_label, esm_test_prob, "esm_pdb2272").get_metrics()
esm_test_metrics

Unnamed: 0,accuracy,sensitivity,specificity,precision,AUC,F1,MCC
esm_pdb186,0.790323,0.784946,0.795699,0.793478,0.840791,0.789189,0.580679


In [31]:
# Attach the Ankh embeddings to the Ankh half of the training data and build the frame
# that `ankh_model.fit_predict` consumes together with `roles`.
# NOTE(review): not idempotent — re-running overwrites `ankh_train` with the already
# converted frame. This cell also overwrites the globals `X_train`/`y_train`/
# `clusters_train` set by the ESM feature cell.
ankh_train = merge_embed(ankh_train, "../data/embeddings/ankh_embeddings/train_p2.pkl")
X_train, y_train, clusters_train = form_Xy(ankh_train, clusters="Yes")
ankh_train = make_lama_df(X_train, y_train, clusters=clusters_train)

In [32]:
# Ankh-embedding model: LightAutoML tabular preset, seeded through the reader.
# fit_predict trains it and returns out-of-fold predictions for ankh_train.
ankh_model = TabularAutoML(task=task, reader_params={"random_state": SEED})
oof_pred = ankh_model.fit_predict(ankh_train, roles=roles)

Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric period is 5 because AUC is/are not implemented for GPU
Default metric perio

In [33]:
# Out-of-fold metrics of the Ankh model on its own training half.
# The ground truth is read from the target column (roles["target"]) of the very frame
# passed to fit_predict, rather than from the shared global `y_train`.
# Presumably this equals `y_train` from form_Xy; TODO confirm make_lama_df.
y_valid = ankh_train[roles["target"]].to_numpy()
valid_prob = oof_pred.data[:, 0]
# Guards against a stale `oof_pred` left over from the ESM model.
assert len(valid_prob) == len(y_valid), "oof_pred does not match ankh_train; re-run the Ankh fit cell"
valid_pred = (valid_prob > 0.5) * 1
metrics = Metrics(y_valid, valid_pred, valid_prob, "ankh_valid")
valid_metrics = metrics.get_metrics()
valid_metrics

Unnamed: 0,accuracy,sensitivity,specificity,precision,AUC,F1,MCC
valid,0.956363,0.939745,0.973006,0.972117,0.989274,0.955657,0.913236


In [34]:
# Attach the Ankh embeddings to the pdb2272 test rows and build the frame for ankh_model.predict.
# NOTE(review): overwrites the globals `X_test`/`y_test` set by the ESM test-data cell;
# the `y_test` from this cell is the ground truth used by the later evaluation cells.
ankh_test = merge_embed(test_df, "../data/embeddings/ankh_embeddings/pdb2272.pkl")
X_test, y_test = form_Xy(ankh_test)
ankh_test = make_lama_df(X_test, y_test)

In [36]:
# Evaluate the Ankh model on pdb2272: take the positive-class probability column
# explicitly (same `[:, 0]` convention as the validation cells), threshold at 0.5,
# and compute metrics. Distinct names avoid reusing one variable for both the
# LightAutoML prediction object and the 0/1 label array.
ankh_test_prob = ankh_model.predict(ankh_test).data[:, 0]
ankh_test_label = (ankh_test_prob > 0.5) * 1
ankh_test_metrics = Metrics(y_test, ankh_test_label, ankh_test_prob, "ankh_pdb2272").get_metrics()
ankh_test_metrics

Unnamed: 0,accuracy,sensitivity,specificity,precision,AUC,F1,MCC
ankh_pdb186,0.790323,0.763441,0.817204,0.806818,0.863915,0.78453,0.581486


Усреднение — averaging the predicted probabilities of the ESM and Ankh models on pdb2272

In [39]:
# Raw prediction arrays of both models on pdb2272. `.data` is taken in the same
# statement as `predict`, so the cell is idempotent: re-running it never applies
# `.data` to an already-extracted array.
ankh_pred = ankh_model.predict(ankh_test).data
esm_pred = esm_model.predict(esm_test).data

In [47]:
# Ensemble: row-wise mean of the two models' prediction columns. The result is kept
# 1-D, like the probabilities passed to Metrics in the single-model cells (the previous
# version passed an (n, 1) array).
# NOTE(review): assumes ankh_test and esm_test list the same pdb2272 rows in the same
# order (both come from merge_embed(test_df, ...)) — TODO confirm merge_embed preserves
# row order; the shape checks below only catch count mismatches.
assert ankh_pred.shape == esm_pred.shape, "ESM and Ankh test predictions are not aligned"
avg_prob = np.concatenate((ankh_pred, esm_pred), axis=1).mean(axis=1)
assert len(avg_prob) == len(y_test), "averaged predictions do not match y_test"
avg_pred = (avg_prob > 0.5) * 1
metrics = Metrics(y_test, avg_pred, avg_prob, "avg_pdb2272")
test_metrics = metrics.get_metrics()
test_metrics

Unnamed: 0,accuracy,sensitivity,specificity,precision,AUC,F1,MCC
avg_pdb186,0.817204,0.817204,0.817204,0.817204,0.864724,0.817204,0.634409
