In [None]:
import json
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylab as pl
import torch
import torch.nn as nn
from IPython import display  # needed to plot training statistics
from seaborn import heatmap
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import auc
from sklearn.metrics import confusion_matrix, precision_recall_curve, roc_curve
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

%matplotlib inline

# Data prep

In [None]:
# Kernel working directory; the output paths below are absolute paths built from it.
main_path = os.path.abspath('')

# Phenotype tag interpolated into every file name.
pheno = "diab"

# Input files (relative to the working directory).
vcf_path = f"./data/ext_prs.90k.{pheno}.vcf"
ordered_target_path = f"./data/phenotype.{pheno}.ordered"
ordered_covariates_path = f"./data/cov.{pheno}.ordered"

# Output files, all under <main_path>/data.
(target_output_path,
 transposed_feature_matrix_path,
 feature_cov_path,
 feature_cov_hla_path) = (
    os.path.join(main_path, "data", file_name)
    for file_name in (
        f"target_{pheno}.csv",
        f"feature_matrix_{pheno}.csv",
        f"feature_cov_matrix_{pheno}.csv",
        f"feature_cov_hla_matrix_{pheno}.csv",
    )
)

# Model (MLP)

In [None]:
# Configuration and test metrics recorded for an earlier run (run 409, run_id 1684699473).
# The architecture / optimiser entries feed the model below; the metric entries are a record.
# NOTE(review): the 'TP' and 'TN' counts in 'confusion_matrix' look swapped —
# test_precision = 48 / (48 + 2094) and test_recall = 48 / (48 + 21), so 48 is the
# true-positive count and 12523 the true-negative count.
params = {
    # run bookkeeping and optimisation
    'run': 409,
    'overall_epochs': 200,
    'lr': 0.001,
    'run_id': 1684699473,
    # batch-norm momentum and dropout rates
    'bn_momentum': 0.9,
    'first_dropout': 0.9,
    'other_dropouts': 0.9,
    # hidden-layer widths
    'lin1_output': 100,
    'lin2_output': 50,
    'lin3_output': 10,
    # recorded test metrics
    'test_ROC_AUC': 0.8387424608828513,
    'test_recall': 0.6956521739130435,
    'test_precision': 0.022408963585434174,
    'confusion_matrix': {'TP': 12523.0, 'TN': 48.0, 'FP': 2094.0, 'FN': 21.0},
    'test_accuracy': 0.8559852921149393,
    'test_PR_AUC': 0.04604731954223416,
    'F1-score': 0.04341926729986431,
    'best_epochs': 94,
}

In [None]:
# Model / optimisation settings, taken from the recorded `params` run where available.
bn_momentum = params["bn_momentum"]
first_dropout = params["first_dropout"]
other_dropouts = params["other_dropouts"]
lin1_output = params["lin1_output"]
lin2_output = params["lin2_output"]
lin3_output = params["lin3_output"]

batch_size = 4096  # not recorded in `params`; fixed for this notebook
learning_rate = params["lr"]
# Epoch budget follows the recorded run instead of duplicating the literal 200;
# falls back to 200 if `params` has no 'overall_epochs' entry.
epochs = params.get("overall_epochs", 200)

In [None]:
# Early-stopping settings for EarlyStopper: after the last new-best test loss, training
# stops once `patience` epochs had a loss more than `min_delta` above that best value.
# NOTE(review): min_delta is in loss units; 0.08 is a wide tolerance for a BCE loss.
patience = 15
min_delta = 0.08

## Dataset and DataLoader

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
device  # display the selected compute device ('cuda' or 'cpu')

In [None]:
# Targets only: one value per row, no header. The positional row index of this frame
# is what PRS_Dataset uses to select rows from the feature matrix.
df = pd.read_csv(target_output_path, header=None)
print(df.shape)

# Stratified 70/30 split on the target values with a fixed seed, so class proportions
# are preserved in both parts and the split is reproducible across runs.
# (train_test_split is imported in the import cell at the top of the notebook.)
train, test = train_test_split(df, test_size=0.3, stratify=df, random_state=5)
train_index = train.index
test_index = test.index
# Sanity checks: the two index sets are disjoint and together cover every row.
assert len(train_index) + len(test_index) == len(df)
assert train_index.intersection(test_index).empty

In [None]:
class EarlyStopper:
    """Decide when to stop training based on a monitored validation loss.

    A new minimum resets the strike counter. A loss more than `min_delta` above the
    best loss seen so far adds a strike. Losses within `min_delta` of the best neither
    add nor reset strikes. Stopping is signalled once the strikes reach `patience`.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # strikes since the last new minimum
        self.min_validation_loss = np.inf  # best loss seen so far

    def early_stop(self, validation_loss):
        """Record `validation_loss`; return True if training should stop now."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        tolerated = validation_loss <= (self.min_validation_loss + self.min_delta)
        if tolerated:
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
class PRS_Dataset(Dataset):
    """In-memory dataset of PRS features and targets, restricted to a subset of rows.

    Both files are read completely with ``np.loadtxt`` (comma-separated, float32) and
    then indexed by row.

    Args:
        x_path: CSV with one row of features per sample.
        y_path: CSV with one target value per sample, in the same row order as ``x_path``.
        take: Kept only for backward compatibility and NOT used. Rows are selected by
            index (see ``indices``), not by a fraction or a first/last-n rule.
        train: When ``indices`` is None, select rows from the notebook-level
            ``train_index`` (True) or ``test_index`` (False).
        indices: Optional explicit row indices. When given, they override the global split.

    Raises:
        ValueError: if the feature and target files have different numbers of rows.

    Items are ``(features, target)`` pairs. Targets carry a trailing singleton
    dimension, so ``y_data`` has shape (n, 1).
    """

    def __init__(self, x_path, y_path, take, train, indices=None):
        # ndmin keeps a single-column target file 1-D and a single-column feature file 2-D.
        y = np.loadtxt(y_path, delimiter=',', dtype=np.float32, ndmin=1)
        print(y.shape)
        x = np.loadtxt(x_path, delimiter=',', dtype=np.float32, ndmin=2)
        if len(x) != len(y):
            raise ValueError(
                f"feature file has {len(x)} rows but target file has {len(y)} rows")

        if indices is None:
            # Fall back to the split computed earlier in the notebook (hidden global state).
            indices = train_index if train else test_index
        indices = np.asarray(indices)

        self.x_data = torch.from_numpy(x[indices]).to(torch.float32)
        self.y_data = torch.from_numpy(y[indices]).to(torch.float32).unsqueeze(1)
        print(f"x_data {self.x_data.shape}")
        print(f"y_data {self.y_data.shape}")

    # support indexing such that dataset[i] returns the i-th (features, target) pair
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # number of selected samples
    def __len__(self):
        return len(self.y_data)

In [None]:
# Which feature matrix to train on: one of the keys of `feature_paths`.
mode = "gen_cov_hla"

# mode -> feature file. A lookup table (instead of an if/elif chain) makes an unknown
# mode fail loudly here, rather than leaving `feature_path` undefined or silently
# stale from an earlier run.
feature_paths = {
    "gen": transposed_feature_matrix_path,
    "cov": ordered_covariates_path,
    "gen_cov": feature_cov_path,
    "gen_cov_hla": feature_cov_hla_path,
}
if mode not in feature_paths:
    raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(feature_paths)}")
feature_path = feature_paths[mode]

In [None]:
# Build the train/test datasets (row selection follows the train/test split; PRS_Dataset
# ignores `take`) and wrap them in loaders. Only the training loader shuffles.
train_dataset, test_dataset = (
    PRS_Dataset(feature_path, target_output_path, take=0.7, train=True),
    PRS_Dataset(feature_path, target_output_path, take=0.3, train=False),
)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=2)

In [None]:
# Peek at one test batch to learn the feature dimension for the model's input layer.
# The iterator is a temporary rather than a global, so the loader's worker processes
# can be shut down as soon as it is garbage-collected instead of lingering.
x_batch, y_batch = next(iter(test_loader))
input_size = x_batch.shape[1]
print(x_batch.shape)

## Model

In [None]:
class P1(nn.Module):
    """Four-layer MLP producing one sigmoid probability per sample.

    Layout: three hidden stages of Linear -> ReLU -> BatchNorm1d -> Dropout, with widths
    lin1_output, lin2_output and lin3_output. These are followed by Linear -> Sigmoid
    down to a single output. The first stage uses `first_dropout`; the second and third
    stages share one Dropout module with p=`other_dropouts`.

    NOTE(review): PyTorch's BatchNorm `momentum` weights the *current* batch statistic in
    the running estimate, which is the opposite of the Keras convention. A value of 0.9
    therefore makes the running stats follow the latest batch closely — confirm this is
    intended.
    """

    def __init__(self, input_size, bn_momentum=0.9, first_dropout=0.9, other_dropouts=0.9,
                 lin1_output=1000, lin2_output=250, lin3_output=50):
        super().__init__()
        # Layers are created in a fixed order (lin1, bn1, lin2, bn2, lin3, bn3, lin4); the
        # order determines random initialisation and the state_dict key order.
        self.lin1 = nn.Linear(input_size, lin1_output)
        self.bn1 = nn.BatchNorm1d(lin1_output, momentum=bn_momentum)
        self.lin2 = nn.Linear(lin1_output, lin2_output)
        self.bn2 = nn.BatchNorm1d(lin2_output, momentum=bn_momentum)
        self.lin3 = nn.Linear(lin2_output, lin3_output)
        self.bn3 = nn.BatchNorm1d(lin3_output, momentum=bn_momentum)
        self.lin4 = nn.Linear(lin3_output, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout_first = nn.Dropout(p=first_dropout)
        self.dropout_other = nn.Dropout(p=other_dropouts)

    def forward(self, X):
        """Map a (batch, input_size) tensor to (batch, 1) probabilities."""
        stages = (
            (self.lin1, self.bn1, self.dropout_first),
            (self.lin2, self.bn2, self.dropout_other),
            (self.lin3, self.bn3, self.dropout_other),
        )
        hidden = X
        for linear, norm, dropout in stages:
            hidden = dropout(norm(self.relu(linear(hidden))))
        return self.sigmoid(self.lin4(hidden))

In [None]:
# Instantiate the MLP with the hyper-parameters selected above and move it to `device`.
model = P1(
    input_size,
    bn_momentum=bn_momentum,
    first_dropout=first_dropout,
    other_dropouts=other_dropouts,
    lin1_output=lin1_output,
    lin2_output=lin2_output,
    lin3_output=lin3_output,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# P1 already ends in a sigmoid, so plain BCE (not BCEWithLogitsLoss) is the matching loss.
criterion = nn.BCELoss()
# Cosine-annealed learning rate with a warm restart every 10 scheduler steps (T_0=10;
# the cycle length stays constant since T_mult=1), never going below eta_min=1e-4.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-4)

## Training loop

In [None]:
def plot_stats(loss_history, auc_history):
    """Redraw training/test loss and ROC AUC curves in the notebook output.

    Args:
        loss_history: dict with per-epoch lists under 'train' and 'test'.
        auc_history: dict with per-epoch lists under 'train' and 'test'. Each list must
            have the same length as the lists in ``loss_history``.

    Clears the cell's previous output first, so repeated calls inside a training loop
    update a single figure instead of stacking new ones.
    """
    fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(10, 5))
    epoch_count = range(1, len(loss_history['train']) + 1)

    for ax, history, label, ylabel in (
        (ax_loss, loss_history, 'loss', 'Loss'),
        (ax_auc, auc_history, 'ROC AUC', 'ROC AUC'),
    ):
        ax.plot(epoch_count, history['train'], '-r')
        ax.plot(epoch_count, history['test'], '-b')
        ax.legend([f'Training {label}', f'Test {label}'])
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)

    fig.tight_layout()

    display.clear_output(wait=True)
    display.display(fig)
    # Close explicitly. Otherwise every call leaks a figure, and the inline backend
    # re-renders all still-open figures at the end of the cell.
    plt.close(fig)

In [None]:
def plot_curves(cutoff, curves, current_params):
    """Save a 2x2 diagnostic figure for one classification cutoff.

    Panels:
        - precision-recall curve;
        - ROC curve, with a legend showing ``current_params['test_ROC_AUC']``;
        - the cutoff value;
        - a heatmap of ``current_params['confusion_matrix']``.

    Args:
        cutoff: probability threshold at which the confusion matrix was computed.
        curves: ``{'PR': {'precision', 'recall'}, 'ROC': {'false_positive_rate',
            'true_positive_rate'}}`` array-likes.
        current_params: metrics dict containing 'test_ROC_AUC' and 'confusion_matrix'
            (keys 'TP', 'FP', 'FN', 'TN').

    The figure is written to ./figures/PR_test/PR_cutoff_<cutoff>.jpg (directories are
    created if missing) and then closed.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 10))

    # PR curve
    ax1.plot(curves['PR']['recall'], curves['PR']['precision'])
    ax1.title.set_text('Precision-Recall Curve')
    ax1.set_ylabel('Precision')
    ax1.set_xlabel('Recall')

    # ROC curve
    ax2.plot(curves['ROC']['false_positive_rate'], curves['ROC']['true_positive_rate'],
             label='AUC = %0.2f' % current_params['test_ROC_AUC'])
    ax2.title.set_text('ROC Curve')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_xlabel('False Positive Rate')
    ax2.legend(loc='lower right')

    # confusion matrix
    cm = current_params['confusion_matrix']
    # NOTE(review): this grid puts 'TP','FP' in row 0 and 'FN','TN' in row 1, labelled
    # True (rows) / Predicted (columns). That matches sklearn's [[TN, FP], [FN, TP]]
    # layout only if the caller stored sklearn's [0][0] cell under 'TP' and [1][1] under
    # 'TN' — verify against the caller before relabelling either side.
    conf_matrix = np.array([[cm['TP'], cm['FP']],
                            [cm['FN'], cm['TN']]])
    # Draw on ax4 explicitly instead of relying on the pyplot "current axes".
    heatmap(conf_matrix, annot=True, fmt=".1f", ax=ax4)
    ax4.set(xlabel="Predicted Label", ylabel="True Label")
    ax4.title.set_text('Confusion matrix')

    # model info
    ax3.text(0, 0.5, f"Cutoff: {cutoff}", ha='left')
    ax3.axis('off')

    fig.tight_layout()

    # makedirs also creates ./figures when it does not exist yet (os.mkdir would fail).
    dir_path = "./figures/PR_test"
    os.makedirs(dir_path, exist_ok=True)
    fig.savefig(f"{dir_path}/PR_cutoff_{cutoff:.2f}.jpg", dpi=300)

    # Close this specific figure; the function is called once per cutoff in a loop.
    plt.close(fig)

## Run training

In [None]:
# Per-epoch history used for plotting and model selection.
auc_history = {'train': [], 'test': []}
loss_history = {'train': [], 'test': []}
best_test_auc = 0  # best test ROC AUC seen so far
best_epoch = None  # epoch index at which best_test_auc was reached
best_model = None  # cloned state_dict of the best epoch

# Set to a positive N to redraw the loss/AUC plots every N epochs (0 disables plotting).
plot_every = 0

# NOTE(review): model selection and early stopping both use the test split, which is
# also the split reported on later. A separate validation split would avoid this leakage.
early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)


def _epoch_auc(y_parts, pred_parts):
    """ROC AUC over all samples collected in an epoch; NaN if only one class occurs."""
    y_all = np.concatenate(y_parts).ravel()
    pred_all = np.concatenate(pred_parts).ravel()
    if np.unique(y_all).size < 2:
        return float('nan')
    return roc_auc_score(y_all, pred_all)


for epoch in tqdm(range(epochs)):
    # ---- training pass ----
    model.train()
    train_losses, train_y, train_pred = [], [], []
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        pred = model(x_batch)
        loss = criterion(pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics use the predictions made before this step's parameter update.
        train_losses.append(loss.item())
        train_y.append(y_batch.cpu().numpy())
        train_pred.append(pred.detach().cpu().numpy())

    # ---- evaluation on the whole test set (every batch, not only the first) ----
    model.eval()
    test_loss_sum, test_count = 0.0, 0
    test_y, test_pred = [], []
    with torch.no_grad():
        for x_test, y_test in test_loader:
            x_test = x_test.to(device)
            y_test = y_test.to(device)
            pred_test = model(x_test)
            # Weight each batch's mean loss by its size so that a short last batch
            # does not distort the epoch mean.
            test_loss_sum += criterion(pred_test, y_test).item() * len(y_test)
            test_count += len(y_test)
            test_y.append(y_test.cpu().numpy())
            test_pred.append(pred_test.cpu().numpy())

    # scheduler step (once per epoch)
    scheduler.step()

    test_loss = test_loss_sum / test_count
    curr_test_auc = _epoch_auc(test_y, test_pred)
    loss_history['train'].append(float(np.mean(train_losses)))
    loss_history['test'].append(test_loss)
    # Train AUC is computed over the whole epoch. Per-batch AUCs raise a ValueError
    # on any batch that contains only one class.
    auc_history['train'].append(_epoch_auc(train_y, train_pred))
    auc_history['test'].append(curr_test_auc)

    if curr_test_auc > best_test_auc:  # current best model
        best_test_auc = curr_test_auc
        best_epoch = epoch
        # state_dict() returns references to the live tensors. Clone them so that later
        # optimizer steps do not overwrite the saved "best" weights.
        best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if plot_every and epoch % plot_every == 0:
        plot_stats(loss_history, auc_history)

    # early stopper
    if early_stopper.early_stop(test_loss):
        break

## Precision/recall across classification cutoffs on test data

In [None]:
cutoff_values = [round(step / 100, 2) for step in range(1, 100)]  # thresholds 0.01, 0.02, ..., 0.99

In [None]:
# Restore the weights of the epoch with the best test ROC AUC.
if best_model is None:
    raise RuntimeError("no best model recorded - run the training cell first")
model.load_state_dict(best_model)
model.eval()

# The model's scores do not depend on the cutoff, so run inference over the test set once.
y_parts, score_parts = [], []
with torch.no_grad():
    for x_test, y_test in test_loader:
        score_parts.append(model(x_test.to(device)).cpu().numpy())
        y_parts.append(y_test.cpu().numpy())
overall_y_test = np.concatenate(y_parts).ravel()
overall_pred_test = np.concatenate(score_parts).ravel()

# Threshold-free metrics and curves, computed once from the continuous scores.
test_roc_auc = roc_auc_score(overall_y_test, overall_pred_test)
pr_precision, pr_recall, _ = precision_recall_curve(overall_y_test, overall_pred_test)
test_pr_auc = auc(pr_recall, pr_precision)
roc_fpr, roc_tpr, _ = roc_curve(overall_y_test, overall_pred_test)
# Build the PR curve from the scores. Building it from 0/1 predictions at a single
# cutoff would collapse it to a couple of points.
curves = {'ROC': {'false_positive_rate': roc_fpr, 'true_positive_rate': roc_tpr},
          'PR': {'precision': pr_precision, 'recall': pr_recall}}

pr_search = []
for cutoff in tqdm(cutoff_values):
    # score > cutoff -> class 1, otherwise class 0
    overall_pred_test_class = (overall_pred_test > cutoff).astype(np.float32)
    overall_confmatrix = confusion_matrix(overall_y_test, overall_pred_test_class)

    current_params = {
        'cutoff': cutoff,
        'test_ROC_AUC': test_roc_auc,
        'test_recall': recall_score(overall_y_test, overall_pred_test_class),
        'test_precision': precision_score(
            overall_y_test, overall_pred_test_class, zero_division=0),
        # NOTE(review): sklearn's binary confusion_matrix is [[TN, FP], [FN, TP]]. The
        # 'TP' entry below therefore holds true negatives and 'TN' holds true positives;
        # 'FP' and 'FN' are correct. The mapping is kept because plot_curves' heatmap
        # layout relies on it — relabel both places together.
        'confusion_matrix': {'TP': int(overall_confmatrix[0][0]),
                             'TN': int(overall_confmatrix[1][1]),
                             'FP': int(overall_confmatrix[0][1]),
                             'FN': int(overall_confmatrix[1][0])},
        'test_accuracy': accuracy_score(overall_y_test, overall_pred_test_class),
        'test_PR_AUC': test_pr_auc,
        # zero_division=0: with no predicted positives F1 is 0, matching precision above.
        'F1-score': f1_score(overall_y_test, overall_pred_test_class, zero_division=0),
    }

    # plot curves
    plot_curves(cutoff, curves, current_params)

    pr_search.append(current_params)

In [None]:
# Rank the per-cutoff results by test ROC AUC, highest first; the sort is stable, so ties
# keep cutoff order.
# NOTE(review): ROC AUC is computed from the raw scores, so it is (up to floating-point
# noise) identical for every cutoff and this ranking effectively keeps the cutoff order.
# To pick a cutoff, rank by a threshold-dependent metric (e.g. 'F1-score') instead.
results = list(pr_search)
results.sort(key=lambda entry: entry['test_ROC_AUC'], reverse=True)

In [None]:
results[0]  # display the entry ranked first by test ROC AUC

In [None]:
# Append every per-cutoff result to results/<results_pheno>/<results_run_id>/results.json,
# one JSON object per line (JSON Lines), so successive runs can append to the same file.
# Dedicated names are used for the output location so that the notebook-level `pheno`
# (used to build the data paths at the top) is not silently overwritten.
# `json` and `os` come from the import cell at the top of the notebook.
results_pheno = 'test'
results_run_id = 'test'

results_dir = os.path.join("results", results_pheno, results_run_id)
os.makedirs(results_dir, exist_ok=True)
with open(os.path.join(results_dir, "results.json"), "a", encoding="utf-8") as f:
    for entry in results:
        f.write(json.dumps(entry) + '\n')