In [None]:
import sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from sklearn.metrics import roc_auc_score, recall_score, precision_score, accuracy_score, f1_score
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve
from sklearn.metrics import auc

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

import numpy as np
import pandas as pd

from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import neptune
from seaborn import heatmap

# Settings

In [None]:
# Notebook working directory; every generated artefact is written under <main_path>/data.
main_path = os.path.abspath('')
data_dir = os.path.join(main_path, "data")

# Phenotype under study -- it is baked into every input and output file name below.
pheno = "diab"

# --- input paths (relative to the notebook directory) ---
vcf_path = f"./data/ext_prs.90k.{pheno}.vcf"
ordered_target_path = f"./data/phenotype.{pheno}.ordered"
ordered_covariates_path = f"./data/cov.{pheno}.ordered"

# --- output paths (absolute) ---
target_output_path = os.path.join(data_dir, f"target_{pheno}.csv")
transposed_feature_matrix_path = os.path.join(data_dir, f"feature_matrix_{pheno}.csv")
feature_cov_path = os.path.join(data_dir, f"feature_cov_matrix_{pheno}.csv")
feature_cov_hla_path = os.path.join(data_dir, f"feature_cov_hla_matrix_{pheno}.csv")

# Class-imbalance strategy label.
# NOTE(review): no visible cell reads this; the dataset cell uses its own `imbalance_type`.
imbalance = "SMOTE"

# Model training (dense MLP, CNN, RNN or RNN+CNN)

In [None]:
# Early-stopping configuration shared by every training run in this notebook.
min_delta = 0.08  # tolerated rise of the validation loss above the best value seen so far
patience = 10     # how many "too high" epochs are tolerated before stopping

class EarlyStopper:
    """Decide when to stop training because the validation loss keeps rising.

    The lowest validation loss seen so far is tracked. A new minimum resets the
    counter; a loss more than ``min_delta`` above the minimum increments it; a
    loss within ``min_delta`` of the minimum leaves the counter unchanged.
    ``early_stop`` returns True once the counter reaches ``patience``.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            # New best: remember it and forgive all previous bad epochs.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        if validation_loss > self.min_validation_loss + self.min_delta:
            self.counter += 1
            return self.counter >= self.patience
        return False

## Dataset and DataLoader

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = 'cpu'  # uncomment to force CPU runs

In [None]:
from sklearn.model_selection import train_test_split

DO_UNDERSAMPLE = True
UNDERSAMPLE_N = 10000  # number of healthy (label 0) training rows kept when undersampling
SPLIT_SEED = 5         # seed for the stratified splits and for the undersampling draw

# training (60%), validation (20%) and testing (20%), stratified on the label column (column 0)
df = pd.read_csv(target_output_path, header=None)
print("All dfs shape", df.shape)

train, test = train_test_split(df, test_size=0.2, stratify=df[0], random_state=SPLIT_SEED)
train, val = train_test_split(train, test_size=0.25, stratify=train[0], random_state=SPLIT_SEED)  # 0.25 x 0.8 = 0.2

if DO_UNDERSAMPLE:
    train_healthy = train[train[0] == 0]
    train_ill = train[train[0] == 1]
    # Seeded so the kept controls are reproducible; capped so a small control set
    # does not make .sample() raise.
    n_keep = min(UNDERSAMPLE_N, len(train_healthy))
    train_healthy = train_healthy.sample(n_keep, random_state=SPLIT_SEED)

    train = pd.concat([train_ill, train_healthy])
    print("New train shape", train.shape)

# Row indices of each split, passed to PRS_Dataset below.
train_index = train.index
val_index = val.index
test_index = test.index

In [None]:
from prs_dataset_standard import PRS_Dataset

In [None]:
batch_size = 4096
imbalance_type = "SMOTE"  # "ROS" or "SMOTE" resample the training split; any other value disables it
mode = "gen_cov_hla"

# Feature matrix used by each experiment mode.
feature_paths = {
    "gen": transposed_feature_matrix_path,
    "cov": ordered_covariates_path,
    "gen_cov": feature_cov_path,
    "gen_cov_hla": feature_cov_hla_path,
}
if mode not in feature_paths:
    # Fail loudly instead of leaving feature_path unbound (or stale from a previous run).
    raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(feature_paths)}")
feature_path = feature_paths[mode]

# Only the training split is built with the `imbalance` option; val/test are built without it.
train_kwargs = {'imbalance': imbalance_type} if imbalance_type in ("ROS", "SMOTE") else {}
split_indices = (train_index, test_index, val_index)

train_dataset = PRS_Dataset(feature_path, target_output_path,
                            'train', *split_indices, **train_kwargs)
val_dataset = PRS_Dataset(feature_path, target_output_path,
                          'val', *split_indices)
test_dataset = PRS_Dataset(feature_path, target_output_path,
                           'test', *split_indices)


def _make_loader(dataset, shuffle):
    """DataLoader with the notebook-wide batch size and two worker processes."""
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=2)


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

In [None]:
DEBUG_CASE_CONTROL_AMOUNT = True


def _report_case_count(title, dataset):
    """Print a split's label shape and how many rows have 1 in the first label column."""
    print(title)
    print(dataset.y_data.shape)
    print("Ill", sum(1 for row in dataset.y_data if row[0] == 1))


if DEBUG_CASE_CONTROL_AMOUNT:
    print(imbalance_type)
    for split_title, split_dataset in (("Train", train_dataset),
                                       ("Val", val_dataset),
                                       ("Test", test_dataset)):
        _report_case_count(split_title, split_dataset)

## Model

In [None]:
# Peek at one test batch to learn the feature width passed to Model(input_size, ...).
x_batch, y_batch = next(iter(test_loader))
input_size = x_batch.shape[1]
print(x_batch.shape)

In [None]:
# Dense
from models.model_dense import Model

In [None]:
# CNN
from models.model_cnn import Model

In [None]:
# RNN
from models.model_rnn import Model

In [None]:
# RNN CNN
from models.model_rnn_cnn import Model

## Training loop

In [None]:
def plot_curves(run, curves, current_params):
    """Draw a 2x2 summary figure for one run, save it to disk and upload it to Neptune.

    Panels: precision-recall curve, ROC curve (legend shows ``test_ROC_AUC``),
    a text listing of ``current_params``, and the confusion-matrix heatmap.
    The figure is saved as ``./figures/<run_id>/run<run>_training.jpg`` and
    uploaded to ``run["curves"]``.

    Args:
        run: Neptune run handle (anything supporting ``run["curves"].upload(path)``).
        curves: ``{'PR': {'precision', 'recall'},
                   'ROC': {'false_positive_rate', 'true_positive_rate'}}``.
        current_params: must contain 'run', 'run_id', 'test_ROC_AUC' and
            'confusion_matrix' (a dict with keys 'TP', 'FP', 'FN', 'TN').
    """
    fig, ((ax_pr, ax_roc), (ax_info, ax_cm)) = plt.subplots(2, 2, figsize=(10, 10))

    # PR curve
    ax_pr.plot(curves['PR']['recall'], curves['PR']['precision'])
    ax_pr.set_title('Precision-Recall Curve')
    ax_pr.set_ylabel('Precision')
    ax_pr.set_xlabel('Recall')

    # ROC curve
    ax_roc.plot(curves['ROC']['false_positive_rate'], curves['ROC']['true_positive_rate'],
                label='AUC = %0.2f' % current_params['test_ROC_AUC'])
    ax_roc.set_title('ROC Curve')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.legend(loc='lower right')

    # confusion matrix
    cm = current_params['confusion_matrix']
    # NOTE(review): the producers in this notebook store sklearn's cm[0][0] under 'TP' and
    # cm[1][1] under 'TN'. For labels [0, 1] sklearn's layout is [[TN, FP], [FN, TP]], so those
    # two keys are swapped and this arrangement reproduces sklearn's layout
    # (rows = true label 0/1, columns = predicted label 0/1). If the producers' keys are
    # corrected, build [[TN, FP], [FN, TP]] here at the same time.
    conf_matrix = np.array([[cm['TP'], cm['FP']],
                            [cm['FN'], cm['TN']]])
    # Draw on the intended axes explicitly rather than on pyplot's "current" axes.
    heatmap(conf_matrix, annot=True, fmt=".0f", ax=ax_cm)
    ax_cm.set(xlabel="Predicted Label", ylabel="True Label")
    ax_cm.set_title('Confusion matrix')

    # model info
    text = f"Run number: {current_params['run']}\nTraining with the following parameters:\n"
    text += "".join(f"{k}: {v}\n" for k, v in current_params.items())
    ax_info.text(0, 0.5, text, ha='left')
    ax_info.axis('off')

    fig.tight_layout()

    dir_path = f"./figures/{current_params['run_id']}"
    # makedirs also creates ./figures itself on first use (os.mkdir would fail there).
    os.makedirs(dir_path, exist_ok=True)
    fp = f"{dir_path}/run{current_params['run']}_training.jpg"
    fig.savefig(fp, dpi=300)

    run["curves"].upload(fp)

    plt.close(fig)

In [None]:
def plot_stats(run, loss_history, auc_history, current_params):
    """Log per-epoch loss and ROC AUC histories to a Neptune run.

    Every value of ``loss_history[split]`` / ``auc_history[split]`` (split in
    'train', 'val') is appended, in order, to the ``"<split>/loss"`` /
    ``"<split>/auc"`` series of ``run``. Nothing is drawn locally.
    ``current_params`` is accepted for interface compatibility and is unused.
    """
    for metric_name, history in (("loss", loss_history), ("auc", auc_history)):
        for split in ("train", "val"):
            for value in history[split]:
                run[f"{split}/{metric_name}"].append(value)

In [None]:
def _predict(model, loader):
    """Run ``model`` over every batch of ``loader`` without gradients.

    Returns ``(targets, probabilities)`` as CPU tensors concatenated along dim 0.
    """
    model.eval()
    targets, probabilities = [], []
    with torch.no_grad():
        for x, y in loader:
            probabilities.append(model(x.to(device)).cpu())
            targets.append(y.cpu())
    return torch.cat(targets), torch.cat(probabilities)


def _train_one_epoch(model, optimizer, criterion):
    """One optimisation pass over ``train_loader``; returns (mean batch loss, mean batch ROC AUC)."""
    model.train()
    batch_losses, batch_aucs = [], []
    for x_batch, y_batch in train_loader:
        # forward pass
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        pred = model(x_batch)
        loss = criterion(pred, y_batch)
        batch_losses.append(loss.item())
        # NOTE: roc_auc_score raises ValueError if a training batch holds a single class.
        batch_aucs.append(roc_auc_score(
            y_batch.detach().cpu().numpy(), pred.detach().cpu().numpy()))

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return float(np.mean(batch_losses)), float(np.mean(batch_aucs))


def _test_metrics(y_true, y_prob):
    """Test-set metrics and ROC/PR curve points from 1-D labels and positive-class probabilities.

    Returns ``(metrics, curves)``; ``metrics`` keys are inserted in the order they are logged.
    """
    y_class = np.rint(y_prob)  # hard labels at a 0.5 cut-off

    metrics = {}
    metrics['test_ROC_AUC'] = roc_auc_score(y_true, y_prob)
    metrics['test_recall'] = recall_score(y_true, y_class)
    metrics['test_precision'] = precision_score(y_true, y_class)
    # Computed once over the whole test set; labels=[0, 1] guarantees a 2x2 matrix.
    cm = confusion_matrix(y_true, y_class, labels=[0, 1])
    # NOTE(review): sklearn's layout is [[TN, FP], [FN, TP]], so 'TP'/'TN' below actually hold
    # true negatives / true positives. plot_curves relies on this key order to draw sklearn's
    # layout, so rename the keys only together with plot_curves. int() keeps results JSON-safe.
    metrics['confusion_matrix'] = {'TP': int(cm[0][0]),
                                   'TN': int(cm[1][1]),
                                   'FP': int(cm[0][1]),
                                   'FN': int(cm[1][0])}
    metrics['test_accuracy'] = accuracy_score(y_true, y_class)

    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    metrics['test_PR_AUC'] = auc(recall, precision)
    metrics['F1-score'] = f1_score(y_true, y_class)

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    # The plotted PR curve uses the same probability-based points as test_PR_AUC.
    curves = {'ROC': {'false_positive_rate': fpr, 'true_positive_rate': tpr},
              'PR': {'precision': precision, 'recall': recall}}
    return metrics, curves


def training_loop(n, epochs, run_id, learning_rate, **kwargs):
    """
    Train a single net with the supplied params and evaluate it on the test set.

    The weights from the epoch with the highest validation ROC AUC are restored
    before testing. Learning curves, metrics and plots are logged to a new
    Neptune run (credentials from NEPTUNE_PROJECT / NEPTUNE_API_TOKEN).

    Args:
        n: index of this run inside the grid search (used in plot text/file names).
        epochs: maximum number of epochs; early stopping on validation loss may end sooner.
        run_id: identifier shared by all runs of one grid search.
        learning_rate: initial Adam learning rate.
        **kwargs: forwarded to ``Model(input_size, **kwargs)``.

    Returns:
        dict with the hyperparameters and test metrics of this run.
    """
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

    model = Model(input_size, **kwargs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,  # scheduler steps (one per epoch here) before the first restart
                                            T_mult=1,  # restart period stays constant
                                            eta_min=1e-4)  # minimum learning rate

    # summary of the current model
    current_params = {
        'run': n,
        'overall_epochs': epochs,
        'lr': learning_rate,
        "run_id": run_id,
        **kwargs
    }

    auc_history = {'train': [], 'val': []}
    loss_history = {'train': [], 'val': []}
    best_val_auc = -np.inf  # best validation ROC AUC over epochs
    best_epoch = None       # epoch at which it was reached
    best_state = None

    for epoch in tqdm(range(epochs)):
        train_loss, train_auc = _train_one_epoch(model, optimizer, criterion)

        # Validation loss and AUC over the whole validation set: a per-batch AUC fails on
        # single-class batches (likely with imbalanced labels) and its mean is not the set AUC.
        y_val, pred_val = _predict(model, val_loader)
        val_loss = criterion(pred_val, y_val).item()
        val_auc = roc_auc_score(y_val.numpy().ravel(), pred_val.numpy().ravel())

        scheduler.step()

        loss_history['train'].append(train_loss)
        loss_history['val'].append(val_loss)
        auc_history['train'].append(train_auc)
        auc_history['val'].append(val_auc)

        if val_auc > best_val_auc:  # current best model
            best_val_auc = val_auc
            best_epoch = epoch
            # state_dict() returns references to the live parameters; clone them, otherwise
            # later optimizer steps silently overwrite the "best" snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if early_stopper.early_stop(val_loss):
            break

    # load best model params (keep the final weights if no epoch produced a comparable AUC)
    if best_state is not None:
        model.load_state_dict(best_state)

    y_test, pred_test = _predict(model, test_loader)
    metrics, curves = _test_metrics(y_test.numpy().ravel(), pred_test.numpy().ravel())
    current_params.update(metrics)
    current_params['best_epochs'] = best_epoch

    # Credentials come from the environment instead of being hardcoded in the notebook.
    run = neptune.init_run(
        project=os.environ.get("NEPTUNE_PROJECT"),
        api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    )
    try:
        run["parameters"] = {**current_params, "pheno": pheno}
        plot_stats(run, loss_history, auc_history, current_params)
        plot_curves(run, curves, current_params)
    finally:
        run.stop()  # always close the Neptune run, even if logging/plotting fails

    return current_params

In [None]:
# Dense model info
params = {
    'epochs': [100],
    'lr': [0.001],
    'bn_momentum': [0.7, 0.9],  # batch norm momentum
    'first_dropout': [0.7, 0.9],
    'other_dropouts': [0.7, 0.9],
    'lin1_output': [100, 250, 500],  # edit to change shapes of the linear layers
    'lin2_output': [50],
    'lin3_output': [10, 20, 40]
}

In [None]:
# Conv model info
params = {
    'epochs': [200],
    'input_dim' : [input_size],
    'kernel_size':[2,3],
    'stride': [2,3,4],
    'kernel_size2':[2,3],
    'stride2':[2,3,4],
    'dropout' : [0.7, 0.8, 0.9],
    'out_channels_first':[1000, 500, 250],
    'out_channels_second':[1000, 500, 250],
    'lr' : [0.0001],
    'linear_first':[250, 100]
}


In [None]:
# RNN model info
params = {
    'epochs': [200],
    'input_dim' : [input_size],
    'hidden_dim' : hidden_dim,
    'dropout' : [0.7, 0.8, 0.9],
    'bi_value' : [False, True],
    'lr' : [0.0001]
}

In [None]:
# RNN CNN model info -- each value is the list of candidates tried by the grid search.
# NOTE(review): `hidden_dim` is not defined in any visible cell (NameError on a fresh kernel);
# it must be a list of candidate sizes defined in an earlier cell.
params = {
    'epochs': [200],
    # input_size is the feature width measured from a test batch (was the undefined `input_dim`)
    'input_dim': [input_size],
    'hidden_dim': hidden_dim,
    'bi_value': [True, False],
    'kernel_size': [2, 3],
    'stride': [2, 3, 4],
    'kernel_size2': [2, 3],
    'stride2': [2, 3, 4],
    'dropout': [0.7, 0.8, 0.9],
    'out_channels_first': [1000, 500, 250],
    'out_channels_second': [1000, 500, 250],
    'lr': [0.0001],
    'linear_first': [250, 100]
}


In [None]:
import time

# One id per notebook session; figures go to ./figures/<run_id>/ and results to results/<pheno>/<run_id>/.
run_id = time.time_ns() // 1_000_000_000  # whole seconds since the epoch
print(f"Run id is {run_id}")

In [None]:
import itertools
import traceback

# Exhaustive hyperparameter grid search: 'epochs' and 'lr' are consumed here,
# every other key of `params` is forwarded to Model through training_loop(**kwargs).
grid_search = []
param_names = list(params.keys())
combinations = list(itertools.product(*params.values()))
total_searches = len(combinations)

for i, params_combination in enumerate(combinations, start=1):
    params_dict = dict(zip(param_names, params_combination))
    epochs_number = params_dict.pop("epochs")
    learning_rate = params_dict.pop("lr")

    print(f"Grid search step {i} of total {total_searches}")
    try:
        grid_search.append(training_loop(i, epochs_number, run_id, learning_rate, **params_dict))
    except Exception:
        # Best-effort: one failing configuration must not abort the whole search,
        # but keep the full traceback so the failure can be diagnosed afterwards.
        print(f"!!! Error in grid search step {i} with {params_dict}:")
        traceback.print_exc()

In [None]:
import json

# Rank the successful configurations by test ROC AUC, best first.
results = sorted(grid_search, key=lambda d: d['test_ROC_AUC'], reverse=True)
if not results:
    print("No grid-search run finished successfully; results.json will be empty.")

results_dir = os.path.join("results", pheno, str(run_id))
os.makedirs(results_dir, exist_ok=True)
with open(os.path.join(results_dir, "results.json"), "w") as f:
    # numpy scalars (e.g. np.int64 counts) are not JSON-serialisable; unwrap them with .item().
    json.dump(results, f, default=lambda o: o.item() if hasattr(o, "item") else str(o))

In [None]:
best_result = results[0]  # highest test ROC AUC (results is sorted best-first)
best_result

In [None]:
# Position(s) of the best configuration in the original (unsorted) grid order.
best_positions = [idx for idx, entry in enumerate(grid_search) if entry == results[0]]
for idx in best_positions:
    print(idx)

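# Linear model baseline
- Run either a single fit (below) or a grid search (e.g. with `GridSearchCV`, imported above).

## Logistic regression (the cell below uses an L2 penalty; set `penalty='l1'` for a true Lasso)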

In [None]:
# L2-penalised logistic regression baseline on the (resampled) training split.
# The saga solver shuffles the data, so random_state is fixed for reproducible fits.
logreg = LogisticRegression(max_iter=1000, penalty='l2', solver='saga', random_state=5)
logreg.fit(train_dataset.x_data.numpy(), train_dataset.y_data.numpy().ravel())

In [None]:
# Test labels, positive-class probabilities, and hard labels at a 0.5 cut-off.
overall_y_test = test_dataset.y_data.numpy().ravel()
overall_pred_test = logreg.predict_proba(test_dataset.x_data.numpy())[:, 1]
overall_pred_test_class = np.rint(overall_pred_test)  # was np.rint(y_pred): `y_pred` is undefined

In [None]:
# Evaluate the logistic-regression baseline with the same metrics and plots as the neural nets.
current_params = {}

current_params['test_ROC_AUC'] = roc_auc_score(overall_y_test, overall_pred_test)
current_params['test_recall'] = recall_score(overall_y_test, overall_pred_test_class)
current_params['test_precision'] = precision_score(overall_y_test, overall_pred_test_class)

# labels=[0, 1] guarantees a 2x2 matrix even if one class is never predicted.
ovarall_confmatrix = confusion_matrix(overall_y_test, overall_pred_test_class, labels=[0, 1])
# NOTE(review): sklearn's layout is [[TN, FP], [FN, TP]], so 'TP'/'TN' below actually hold
# true negatives / true positives. plot_curves relies on this key order to draw sklearn's
# layout, so rename the keys only together with plot_curves. int() drops numpy scalar types.
conf_matrix = {'TP': int(ovarall_confmatrix[0][0]),
               'TN': int(ovarall_confmatrix[1][1]),
               'FP': int(ovarall_confmatrix[0][1]),
               'FN': int(ovarall_confmatrix[1][0])}
current_params['confusion_matrix'] = conf_matrix

current_params['test_accuracy'] = accuracy_score(overall_y_test, overall_pred_test_class)

precision, recall, thresholds = precision_recall_curve(overall_y_test, overall_pred_test)
pr_auc = auc(recall, precision)
current_params['test_PR_AUC'] = pr_auc
current_params['F1-score'] = f1_score(overall_y_test, overall_pred_test_class)

current_params["run"] = -1  # marks the baseline in figure names (run-1_training.jpg)
current_params["run_id"] = run_id

curves = {'ROC': {}, 'PR': {}}
curves['ROC']['false_positive_rate'], curves['ROC']['true_positive_rate'], _ = roc_curve(
    overall_y_test, overall_pred_test)
# Plot the probability-based PR curve (the one test_PR_AUC is computed from);
# hard 0/1 labels would collapse it to a couple of points.
curves['PR']['precision'], curves['PR']['recall'] = precision, recall

# Credentials come from the environment instead of being hardcoded in the notebook.
run = neptune.init_run(
    project=os.environ.get("NEPTUNE_PROJECT"),
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
)
try:
    run["parameters"] = {**current_params, "pheno": pheno}
    # plot curves
    plot_curves(run, curves, current_params)
finally:
    run.stop()  # always close the Neptune run, even if plotting fails

In [None]:
# Metrics of the logistic-regression baseline, shown as this cell's output.
current_params