In [None]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from argparse import ArgumentParser
import random
import torch
from torch.nn import init
import pandas as pd
from sklearn.metrics import roc_auc_score,roc_curve, auc
import torch.utils.data as data
from PIL import Image
import os
import os.path
import scipy.io
import csv
import torchvision
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from imblearn.under_sampling import RandomUnderSampler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score


In [None]:

def load_hla_dataframe(fname):
    """Read a comma-separated TCR/pMHC table into a DataFrame.

    The first line of the file is skipped (presumably the file's own header
    row — confirm) and the columns are given fixed names instead.

    Args:
        fname: Path or file-like object of the comma-separated file.

    Returns:
        pandas.DataFrame with columns CDR3, Antigen, HLA, binder, MHC.
    """
    column_names = ['CDR3', 'Antigen', 'HLA', 'binder', 'MHC']
    return pd.read_table(fname, sep=',', skiprows=[0], names=column_names)

def compute_metrics1(y_labels, y_preds, cutoff=0.5, digit=3):
    """Compute binary-classification metrics from predicted scores.

    Args:
        y_labels: Ground-truth 0/1 labels. Any array-like is accepted (list,
            1-D array or an ``(n, 1)`` column vector); it is flattened to 1-D.
        y_preds: Predicted scores for the positive class, same length as
            ``y_labels``.
        cutoff: Decision threshold; a score ``>= cutoff`` is labelled 1. If it
            is not strictly inside (0, 1), the ROC threshold maximising
            Youden's J statistic (J = TPR - FPR) is used instead.
        digit: Number of decimals each reported value is rounded to.

    Returns:
        dict with keys 'AUC', 'Accuracy', 'Sensitivity', 'Specificity',
        'Threshold' and 'F1'.
    """
    # Flatten so plain lists and (n, 1) column vectors are handled alike.
    y_labels = np.asarray(y_labels).ravel()
    y_preds = np.asarray(y_preds).ravel()

    # If a valid threshold is not specified, use Youden's J statistic to choose
    # the optimal threshold. J is deliberately signed: maximising |TPR - FPR|
    # could select a threshold at which the classifier is worse than chance.
    # np.argmax returns the first maximum, i.e. ties resolve in ROC order.
    if cutoff <= 0 or cutoff >= 1:
        fpr, tpr, thresholds = roc_curve(y_labels, y_preds)
        cutoff = thresholds[np.argmax(tpr - fpr)]

    y_pred_labels = (y_preds >= cutoff).astype(float)

    accuracy = accuracy_score(y_labels, y_pred_labels)
    sensitivity = recall_score(y_labels, y_pred_labels)               # recall of class 1
    specificity = recall_score(y_labels, y_pred_labels, pos_label=0)  # recall of class 0
    f1 = f1_score(y_labels, y_pred_labels)
    auc_score = roc_auc_score(y_labels, y_preds)  # threshold-independent

    return {
        'AUC':         np.round(auc_score,   digit),
        'Accuracy':    np.round(accuracy,    digit),
        'Sensitivity': np.round(sensitivity, digit),
        'Specificity': np.round(specificity, digit),
        'Threshold':   np.round(cutoff,      digit),
        'F1':          np.round(f1,          digit),
    }


# LightGBM hyper-parameters shared by every lgb.train call in this notebook.
params = dict(
    objective='binary',       # binary classification on the `binder` label
    metric='binary_error',    # evaluation metric: misclassification rate
    boosting_type='gbdt',     # gradient-boosted decision trees
    learning_rate=0.03,
    num_leaves=51,
    feature_fraction=0.62,    # fraction of feature columns sampled per tree
    # NOTE(review): with bagging_fraction=1 every row is kept, so bagging_freq
    # has no row-subsampling effect — confirm this is intended.
    bagging_fraction=1,
    bagging_freq=6,
    verbose=0,
    max_depth=20,             # maximum tree depth
)

### TCR-pMHC binding: train and test on a random 80/20 split of the large dataset

In [None]:
def train_and_test(df, data_1, data_2, params, test_size=0.2, random_state=342,
                   num_boost_round=400):
    """Train LightGBM on a random train/test split of one dataset and report metrics.

    The feature matrix is ``data_1`` concatenated column-wise with ``data_2``
    reshaped to 2-D (first axis kept, all remaining axes flattened).

    Args:
        df: DataFrame whose ``binder`` column holds the 0/1 label of each sample.
        data_1: 2-D array-like of per-sample features (the pMHC encoding,
            judging by the notebook's file paths — confirm).
        data_2: Array-like whose first axis is the sample axis.
        params: LightGBM parameter dict passed to ``lgb.train``.
        test_size: Fraction of samples held out for testing.
        random_state: Seed for ``train_test_split``.
        num_boost_round: Number of boosting iterations.

    Returns:
        tuple ``(bst, metrics)``: the trained booster and the metrics dict from
        ``compute_metrics1`` (the metrics are also printed).

    Raises:
        ValueError: if ``df``, ``data_1`` and ``data_2`` disagree on the number
            of samples.
    """
    # 1-D labels: avoids the (n, 1) column vector LightGBM/sklearn must coerce.
    labels = np.asarray(df.binder).ravel()
    n_samples = len(labels)
    if data_1.shape[0] != n_samples or data_2.shape[0] != n_samples:
        raise ValueError(
            f"sample count mismatch: {n_samples} labels, "
            f"{data_1.shape[0]} rows in data_1, {data_2.shape[0]} rows in data_2"
        )

    CDR3_flattened = data_2.reshape(data_2.shape[0], -1)
    features = np.concatenate((data_1, CDR3_flattened), axis=1)

    # Split dataset
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )

    # Train LightGBM model (no validation set / early stopping).
    train_data = lgb.Dataset(X_train, label=y_train)
    bst = lgb.train(params, train_data, num_boost_round=num_boost_round)

    # Without early stopping best_iteration stays 0, which LightGBM's predict
    # treats as "use all iterations".
    test_predictions = bst.predict(X_test, num_iteration=bst.best_iteration)

    metrics = compute_metrics1(y_test, test_predictions)
    print(metrics)
    return bst, metrics


In [None]:
################### Load Data and Train Model ###################
# Encodings and labels for the large dataset; rows are expected to align.
pmhc_encoding_path = "PMHC_encodeing/PMHC_result/pmhc_large_data_encoding.pt"
tcr_encoding_path = "TCR_encodeing/TCR_result/TCR_encoding/TCR_large_encoding.pt"
labels_csv_path = 'PMHC_encodeing/data/large_data.csv'

# NOTE(review): torch.load unpickles the file — only load trusted encodings.
data_1, data_2 = torch.load(pmhc_encoding_path), torch.load(tcr_encoding_path)
df = load_hla_dataframe(labels_csv_path)

train_and_test(df, data_1, data_2, params);


## Generalization test: train on the epiTCR training set, evaluate on the separate test01 set

In [None]:
########### Define Training and Testing Functions ####################

def train(df, data_1, data_2, params, random_state=42, num_boost_round=400):
    """Train a LightGBM booster on a randomly under-sampled copy of a dataset.

    Features are ``data_1`` concatenated column-wise with ``data_2`` reshaped
    to 2-D (first axis kept, remaining axes flattened). ``RandomUnderSampler``
    (default sampling strategy) then drops samples before training; no
    validation set is used.

    Args:
        df: DataFrame whose ``binder`` column holds the 0/1 label of each sample.
        data_1: 2-D array-like of per-sample features.
        data_2: Array-like whose first axis is the sample axis.
        params: LightGBM parameter dict passed to ``lgb.train``.
        random_state: Seed for ``RandomUnderSampler``.
        num_boost_round: Number of boosting iterations.

    Returns:
        The trained ``lgb.Booster``.

    Raises:
        ValueError: if ``df``, ``data_1`` and ``data_2`` disagree on the number
            of samples.
    """
    # 1-D labels: avoids the (n, 1) column vector imblearn/LightGBM must coerce.
    labels = np.asarray(df.binder).ravel()
    n_samples = len(labels)
    if data_1.shape[0] != n_samples or data_2.shape[0] != n_samples:
        raise ValueError(
            f"sample count mismatch: {n_samples} labels, "
            f"{data_1.shape[0]} rows in data_1, {data_2.shape[0]} rows in data_2"
        )

    # Flatten and merge data: one row per sample.
    CDR3_flattened = data_2.reshape(data_2.shape[0], -1)
    features = np.concatenate((data_1, CDR3_flattened), axis=1)

    # Random under-sampling of data
    sampler = RandomUnderSampler(random_state=random_state)
    X_res, y_res = sampler.fit_resample(features, labels)

    # Create LightGBM dataset and train
    train_data = lgb.Dataset(X_res, label=y_res)
    bst = lgb.train(params, train_data, num_boost_round=num_boost_round)
    print('train done !!!!')

    return bst

def test(df_test, data_test_1, data_test_2, bst):
    """Evaluate a trained booster on a held-out dataset and report metrics.

    Test features are built exactly as in training: ``data_test_1``
    concatenated column-wise with ``data_test_2`` reshaped to 2-D.

    Args:
        df_test: DataFrame whose ``binder`` column holds the 0/1 labels.
        data_test_1: 2-D array-like of per-sample features.
        data_test_2: Array-like whose first axis is the sample axis.
        bst: Trained LightGBM booster.

    Returns:
        The metrics dict from ``compute_metrics1`` (also printed).

    Raises:
        ValueError: if ``df_test``, ``data_test_1`` and ``data_test_2``
            disagree on the number of samples.
    """
    labels_test = np.asarray(df_test.binder).ravel()
    n_samples = len(labels_test)
    if data_test_1.shape[0] != n_samples or data_test_2.shape[0] != n_samples:
        raise ValueError(
            f"sample count mismatch: {n_samples} labels, "
            f"{data_test_1.shape[0]} rows in data_test_1, "
            f"{data_test_2.shape[0]} rows in data_test_2"
        )

    # Flatten and merge test data
    test_CDR3_flattened = data_test_2.reshape(data_test_2.shape[0], -1)
    test_features = np.concatenate((data_test_1, test_CDR3_flattened), axis=1)

    # Without early stopping best_iteration stays 0, which LightGBM's predict
    # treats as "use all iterations".
    test_predictions = bst.predict(test_features, num_iteration=bst.best_iteration)

    metrics = compute_metrics1(labels_test, test_predictions)
    print(metrics)
    return metrics


In [None]:
############### Load Training Dataset and Train Model ######################
# epiTCR training split: the pMHC encoding is stored as a .npy array and is
# converted to a tensor; the TCR encoding is a torch file.
train_pmhc_path = 'PMHC_encodeing/PMHC_result/epiTCR_data_encoding/pmhc_train_encoding.npy'
train_tcr_path = "TCR_encodeing/TCR_result/TCR_encoding/TCR_train.pt"
train_csv_path = 'PMHC_encodeing/data/epiTCR_data/train.csv'

numpy_array = np.load(train_pmhc_path)
data_1 = torch.from_numpy(numpy_array)
# NOTE(review): torch.load unpickles the file — only load trusted encodings.
data_2 = torch.load(train_tcr_path)
df = load_hla_dataframe(train_csv_path)

bst = train(df, data_1, data_2, params)


In [None]:
############# Load Testing Dataset and Test Model ################
# epiTCR test01 split, evaluated with the booster trained in the cell above.
test_pmhc_path = "PMHC_encodeing/PMHC_result/epiTCR_data_encoding/pmhc_test01_encoding.pt"
test_tcr_path = "TCR_encodeing/TCR_result/TCR_encoding/TCR_test01.pt"
test_csv_path = 'PMHC_encodeing/data/epiTCR_data/test01.csv'

# NOTE(review): torch.load unpickles the file — only load trusted encodings.
data_test_1, data_test_2 = torch.load(test_pmhc_path), torch.load(test_tcr_path)
df_test = load_hla_dataframe(test_csv_path)

test(df_test, data_test_1, data_test_2, bst);
