In [None]:
import itertools
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylab as pl
from IPython import display
from seaborn import heatmap
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from tqdm import tqdm

%matplotlib inline

In [None]:
from prs_dataset_standard import PRS_Dataset
import compare_auc_delong_xu

# Settings

In [None]:
main_path = os.path.abspath("")

# Phenotype under study; set_nn_params knows "bmi", "diab" and "psor".
pheno = "bmi"

# ---- input paths (relative to the current working directory) ----
vcf_path = f"./data/ext_prs.90k.{pheno}.vcf"
ordered_target_path = f"./data/phenotype.{pheno}.ordered"
ordered_covariates_path = f"./data/cov.{pheno}.ordered"

# ---- output paths (absolute, under <main_path>/data) ----
(
    target_output_path,
    transposed_feature_matrix_path,
    feature_cov_path,
    feature_cov_hla_path,
) = (
    os.path.join(main_path, "data", file_name)
    for file_name in (
        f"target_{pheno}.csv",
        f"feature_matrix_{pheno}.csv",
        f"feature_cov_matrix_{pheno}.csv",
        f"feature_cov_hla_matrix_{pheno}.csv",
    )
)

# ---- JSON files: per-cycle raw results and the cross-cycle summary ----
json_output_raw = os.path.join(main_path, f"delong_cycle_cv_raw_{pheno}.json")
json_output_summary = os.path.join(main_path, f"delong_cycle_cv_summary_{pheno}.json")

In [None]:
# Early stopping settings (used to build the EarlyStopper in training_loop):
# training stops once `patience` epochs since the last improvement ended with a
# validation loss more than `min_delta` above the best one seen.
min_delta, patience = 0.08, 10

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Data and dataset settings

In [None]:
# How many CV cycles: each cycle draws a new stratified train/val/test split
# (seeded with the cycle index) and retrains every model.
num_of_cycles: int = 5

In [None]:
# Data settings
# When enabled, the controls of each training split are randomly subsampled
# to UNDERSAMPLE_N rows (all cases are kept).
DO_UNDERSAMPLE, UNDERSAMPLE_N = True, 10000

batch_size = 4096  # DataLoader batch size for train / val / test
imbalance_type = "SMOTE"  # "ROS", "SMOTE", or any other value for no resampling
mode = "gen_cov"  # feature set: "gen", "cov", "gen_cov" or "gen_cov_hla"

In [None]:
def write_json(current_results, file_name="./delong_cycle_cv_raw.json"):
    """Append ``current_results`` to ``file_name`` as an indented JSON document.

    Every call appends a newline followed by one JSON value, so the file is a
    sequence of JSON documents rather than a single one. numpy scalars/arrays
    are serialised through ``NpEncoder``.

    Args:
        current_results: JSON-serialisable object (numpy values allowed).
        file_name: path of the append-only results log.
    """
    # Serialise before touching the file: if a value cannot be encoded, the
    # exception is raised before anything (not even the separator) is appended.
    payload = json.dumps(current_results, indent=1, cls=NpEncoder)
    with open(file_name, "a") as file:
        file.write("\n" + payload)


def write_json_2dicts(dict1, dict2, file_name="./delong_cycle_cv_summary.json"):
    """Append ``[dict1, dict2]`` to ``file_name`` as one indented JSON array.

    Like ``write_json``, each call appends a newline plus one JSON value, so the
    file accumulates a sequence of documents. numpy values are handled by
    ``NpEncoder``.

    Args:
        dict1: first element of the appended array.
        dict2: second element of the appended array.
        file_name: path of the append-only summary log.
    """
    # Serialise first so an encoding error leaves the file untouched.
    payload = json.dumps([dict1, dict2], indent=1, cls=NpEncoder)
    with open(file_name, "a") as file:
        file.write("\n" + payload)


class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts numpy scalars and arrays.

    numpy integers, floats and booleans are converted to the matching Python
    ``int`` / ``float`` / ``bool``; arrays become (nested) lists. Any other
    unsupported object falls back to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is neither an np.integer nor a subclass of Python bool,
            # so it would otherwise reach the base class and raise TypeError.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

## Model

In [None]:
# Dense
from model_dense import Model as dense_model

In [None]:
# CNN
from model_cnn import Model as cnn_model

In [None]:
# RNN
from model_rnn import Model as rnn_model

In [None]:
# RNN CNN
from model_rnn_cnn import Model as rnn_cnn_model

## Training loop

In [None]:
class EarlyStopper:
    """Tell the training loop when validation loss has stopped improving.

    ``early_stop`` is meant to be called once per epoch. A loss strictly below
    the best seen so far becomes the new best and resets the counter; a loss
    more than ``min_delta`` above the best increments the counter; anything in
    between leaves the state unchanged. ``True`` is returned once the counter
    reaches ``patience``.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # "bad" epochs since the last improvement
        self.min_validation_loss = np.inf  # best validation loss so far

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return ``True`` when training should stop."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss, self.counter = validation_loss, 0
            return False
        # Losses within the tolerance band are neither an improvement nor a strike.
        if not (validation_loss > self.min_validation_loss + self.min_delta):
            return False
        self.counter += 1
        return bool(self.counter >= self.patience)

In [None]:
def plot_curves(curves, current_params):
    """Display a 2x2 evaluation figure for one trained model, then release it.

    Panels: precision-recall curve (top-left), ROC curve with the test ROC AUC in
    the legend (top-right), a text listing of ``current_params`` (bottom-left)
    and the confusion matrix (bottom-right).

    Args:
        curves: ``{"PR": {"precision": ..., "recall": ...},
            "ROC": {"false_positive_rate": ..., "true_positive_rate": ...}}``.
        current_params: run summary; must contain ``"test_ROC_AUC"`` and
            ``"confusion_matrix"`` with keys ``"TP"``, ``"TN"``, ``"FP"``, ``"FN"``.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 10))

    # PR curve
    ax1.plot(curves["PR"]["recall"], curves["PR"]["precision"])
    ax1.set_title("Precision-Recall Curve")
    ax1.set_ylabel("Precision")
    ax1.set_xlabel("Recall")

    # ROC curve
    ax2.plot(
        curves["ROC"]["false_positive_rate"],
        curves["ROC"]["true_positive_rate"],
        label="AUC = %0.2f" % current_params["test_ROC_AUC"],
    )
    ax2.set_title("ROC Curve")
    ax2.set_ylabel("True Positive Rate")
    ax2.set_xlabel("False Positive Rate")
    ax2.legend(loc="lower right")

    # Confusion matrix in sklearn orientation: rows = true label (0, 1),
    # columns = predicted label (0, 1). Each cell is annotated with its key so
    # the picture always states exactly which stored number it shows.
    cm = current_params["confusion_matrix"]
    layout = [["TN", "FP"], ["FN", "TP"]]
    counts = np.array([[float(cm[key]) for key in row] for row in layout])
    labels = np.array([[f"{key}\n{float(cm[key]):.0f}" for key in row] for row in layout])
    # Draw on ax4 explicitly instead of relying on matplotlib's "current axes".
    heatmap(counts, annot=labels, fmt="", ax=ax4, xticklabels=[0, 1], yticklabels=[0, 1])
    ax4.set(xlabel="Predicted Label", ylabel="True Label")
    ax4.set_title("Confusion matrix")

    # model info
    text = "Training with the following parameters:\n" + "".join(
        f"{k}: {v}\n" for k, v in current_params.items()
    )
    ax3.text(0, 0.5, text, ha="left")
    ax3.axis("off")

    fig.tight_layout()

    # A figure closed before the cell finishes is never rendered by the inline
    # backend, so display it explicitly; then close it so it does not leak.
    display.display(fig)
    plt.close(fig)

In [None]:
def plot_stats(loss_history, auc_history):
    """Redraw the live training dashboard (per-epoch loss and ROC AUC) in the notebook.

    The previous dashboard in the cell output is replaced by the new one.

    Args:
        loss_history: ``{"train": [...], "val": [...]}`` per-epoch losses.
        auc_history: ``{"train": [...], "val": [...]}`` per-epoch ROC AUC values;
            assumed to have the same length as ``loss_history["train"]``.
    """
    fig, (ax1, ax2) = pl.subplots(1, 2, figsize=(10, 5))

    epoch_count = range(1, len(loss_history["train"]) + 1)

    # loss
    ax1.plot(epoch_count, loss_history["train"], "-r")
    ax1.plot(epoch_count, loss_history["val"], "-b")
    ax1.legend(["Training loss", "Val loss"])
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")

    # auc
    ax2.plot(epoch_count, auc_history["train"], "-r")
    ax2.plot(epoch_count, auc_history["val"], "-b")
    ax2.legend(["Training ROC AUC", "Val ROC AUC"])
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("ROC AUC")

    fig.tight_layout()

    display.clear_output(wait=True)
    display.display(fig)
    # Close after displaying: otherwise every call leaves an open figure behind,
    # they pile up (matplotlib warns past 20) and the inline backend re-renders
    # all of them when the cell finishes.
    pl.close(fig)
    # time.sleep(1.0)

In [None]:
def _predict_loader(model, loader, criterion=None):
    """Run ``model`` over every batch of ``loader`` in eval mode, without gradients.

    Batches are moved to the notebook-global ``device``.

    Returns:
        ``(y_true, y_score, mean_loss)``: flattened labels and model outputs as 1-D
        numpy arrays, and the sample-weighted mean of ``criterion`` over the
        loader (``None`` when no criterion is given).
    """
    model.eval()
    y_parts, p_parts = [], []
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            pred = model(x)
            if criterion is not None:
                loss_sum += criterion(pred, y).item() * len(y)
            n_seen += len(y)
            y_parts.append(y.cpu().numpy().ravel())
            p_parts.append(pred.cpu().numpy().ravel())
    mean_loss = loss_sum / n_seen if criterion is not None else None
    return np.concatenate(y_parts), np.concatenate(p_parts), mean_loss


def training_loop(n, model, epochs, run_id, learning_rate, **kwargs):
    """
    Train a single net with the supplied hyper-parameters and evaluate it on the test set.

    Relies on notebook globals set by the CV loop and settings cells:
    ``input_size``, ``train_loader``, ``val_loader``, ``test_loader``, ``device``,
    ``patience`` and ``min_delta``.

    Args:
        n: CV cycle index (stored as ``"run"``).
        model: model class, instantiated as ``model(input_size, **kwargs)``.
        epochs: maximum number of epochs; early stopping may end training sooner.
        run_id: session identifier stored in the result.
        learning_rate: initial Adam learning rate.
        **kwargs: architecture hyper-parameters forwarded to the model.

    Returns:
        ``(current_params, predictions)``: a dict with the hyper-parameters,
        ``"best_epoch"`` and the test metrics (ROC AUC, DeLong AUC/variance,
        recall, precision, confusion matrix, accuracy, PR AUC, F1), and the test
        outputs as an ``(N, 1)`` array. Testing uses the weights of the epoch with
        the best validation ROC AUC.
    """
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

    model = model(input_size, **kwargs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,  # scheduler steps until the first restart (stepped once per epoch below)
        T_mult=1,  # restart period stays constant
        eta_min=1e-4,  # minimum learning rate
    )
    # summary of the current model
    current_params_temp = {
        "run": n,
        "overall_epochs": epochs,
        "lr": learning_rate,
        "run_id": run_id,
        **kwargs,
    }

    loss_history = {"train": [], "val": []}
    auc_history = {"train": [], "val": []}
    best_val_auc = -np.inf
    best_epoch = None
    best_state = None

    print("Amount of epochs")
    for epoch in tqdm(range(epochs)):
        # ---- train one epoch ----
        model.train()
        train_loss_sum, train_seen = 0.0, 0
        train_y, train_p = [], []
        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            pred = model(x_batch)
            loss = criterion(pred, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_sum += loss.item() * len(y_batch)
            train_seen += len(y_batch)
            train_y.append(y_batch.detach().cpu().numpy().ravel())
            train_p.append(pred.detach().cpu().numpy().ravel())

        # ---- validate ----
        val_y, val_p, validation_loss = _predict_loader(model, val_loader, criterion)
        scheduler.step()

        # ROC AUC over the whole epoch instead of an average of per-batch AUCs:
        # batch averages are biased and roc_auc_score raises on a single-class batch.
        curr_val_auc = roc_auc_score(val_y, val_p)
        loss_history["train"].append(train_loss_sum / train_seen)
        loss_history["val"].append(validation_loss)
        auc_history["train"].append(
            roc_auc_score(np.concatenate(train_y), np.concatenate(train_p))
        )
        auc_history["val"].append(curr_val_auc)

        if curr_val_auc > best_val_auc:  # current best model
            best_val_auc = curr_val_auc
            best_epoch = epoch
            # state_dict() holds references to the live parameters; clone them, or
            # later optimizer steps silently overwrite this "best" snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if early_stopper.early_stop(validation_loss):
            break

        if epoch % 10 == 0:
            plot_stats(loss_history, auc_history)

    # ---- test with the best-validation weights ----
    if best_state is not None:
        model.load_state_dict(best_state)
    current_params_temp["best_epoch"] = best_epoch

    y_true, y_score, _ = _predict_loader(model, test_loader)
    y_class = np.rint(y_score)

    current_params_temp["test_ROC_AUC"] = roc_auc_score(y_true, y_score)
    # 1-D inputs, the same shape the classic-model evaluation passes.
    current_params_temp["auc_delong"], current_params_temp["variances_delong"] = (
        compare_auc_delong_xu.delong_roc_variance(y_true, y_score)
    )
    current_params_temp["test_recall"] = recall_score(y_true, y_class)
    current_params_temp["test_precision"] = precision_score(y_true, y_class)

    # sklearn lays out the binary confusion matrix as [[TN, FP], [FN, TP]]
    # (rows = true label, columns = predicted); labels=[0, 1] keeps it 2x2 even
    # when a class is absent from the predictions.
    tn, fp, fn, tp = confusion_matrix(y_true, y_class, labels=[0, 1]).ravel()
    current_params_temp["confusion_matrix"] = {"TP": tp, "TN": tn, "FP": fp, "FN": fn}

    current_params_temp["test_accuracy"] = accuracy_score(y_true, y_class)

    precision, recall, _ = precision_recall_curve(y_true, y_score)
    current_params_temp["test_PR_AUC"] = auc(recall, precision)
    current_params_temp["F1-score"] = f1_score(y_true, y_class)

    # Both curves use the continuous outputs; hard 0/1 predictions would collapse
    # the PR curve to a few points and disagree with test_PR_AUC.
    fpr, tpr, _ = roc_curve(y_true, y_score)
    curves = {
        "ROC": {"false_positive_rate": fpr, "true_positive_rate": tpr},
        "PR": {"precision": precision, "recall": recall},
    }

    plot_curves(curves, current_params_temp)

    return current_params_temp, y_score.reshape(-1, 1)

In [None]:
def set_nn_params(pheno):
    """Build the grid-search configuration of the four neural architectures for ``pheno``.

    Returns ``[dense_params, cnn_params, rnn_params, rnn_cnn_params]``. Each dict
    starts with ``"model"`` (class), ``"model_name"``, ``"epochs"`` and ``"lr"``;
    every remaining entry is a list of candidate values whose Cartesian product
    is searched. New dicts and lists are created on every call, so callers may
    pop keys from them freely.

    Raises:
        ValueError: if ``pheno`` is not ``"bmi"``, ``"diab"`` or ``"psor"``.
    """
    if pheno not in ("bmi", "diab", "psor"):
        raise ValueError("Incorrect phenotype")

    # ---- phenotype-specific parts of the grids ----
    if pheno == "bmi":
        lin1_sizes, lin2_sizes, lin3_sizes = [300], [100, 150], [10, 25, 35]
    else:  # "diab" and "psor" share the smaller MLP
        lin1_sizes, lin2_sizes, lin3_sizes = [100], [50, 75], [10, 20, 40]

    if pheno == "diab":
        cnn_lr, cnn_first, cnn_linear, cnn_drop = 0.01, [500], [100], [0.9, 0.8]
    else:  # "bmi" and "psor" share the CNN grid
        cnn_lr, cnn_first, cnn_linear, cnn_drop = 0.001, [500, 250], [100, 50], [0.9]

    rnn_cnn_drop = [0.9, 0.8] if pheno == "psor" else [0.9]

    # ---- assemble; key order matters because it fixes the grid iteration order ----
    dense_params = {
        "model": dense_model,
        "model_name": "mlp",
        "epochs": 200,
        "lr": 0.001,
        "bn_momentum": [0.9],
        "first_dropout": [0.9],
        "other_dropouts": [0.9],
        "lin1_output": lin1_sizes,
        "lin2_output": lin2_sizes,
        "lin3_output": lin3_sizes,
    }

    cnn_params = {
        "model": cnn_model,
        "model_name": "cnn",
        "epochs": 150,
        "lr": cnn_lr,
        "out_channels_first": cnn_first,
        "out_channels_second": [150],
        "linear_first": cnn_linear,
        "kernel_size": [1],
        "stride": [1, 2],
        "drop1": cnn_drop,
    }

    rnn_params = {
        "model": rnn_model,
        "model_name": "rnn",
        "epochs": 150,
        "lr": 0.001,
        "hidden_dim": [300, 200, 100, 50],
        "target_size": [1],
    }

    rnn_cnn_params = {
        "model": rnn_cnn_model,
        "model_name": "rnn_cnn",
        "epochs": 150,
        "lr": 0.001,
        "bn_momentum": [0.8],
        "drop": rnn_cnn_drop,
        "hidden1": [500],
        "conv1": [1000],
        "conv2": [2000, 1000],
        "lin1": [500, 250],
    }

    return [dense_params, cnn_params, rnn_params, rnn_cnn_params]

In [None]:
import time

# Identifier of this notebook session (whole seconds since the Unix epoch); it is
# passed to gridsearch/training_loop and stored in every neural-net result dict.
run_id = int(time.time())
print("Run id is", run_id)

## Logistic regression baseline ("lasso"; currently fit with an L2 penalty)

In [None]:
def lasso_train(train_dataset, penalty="l2"):
    """Fit the logistic-regression baseline and evaluate it on the global ``test_dataset``.

    NOTE: despite the function name, the default ``penalty="l2"`` is a ridge
    penalty, not a lasso (L1) one. Pass ``penalty="l1"`` for a true lasso; the
    ``saga`` solver supports both. The name is kept because results are stored
    under the ``"lasso"`` key.

    Args:
        train_dataset: dataset exposing ``x_data`` / ``y_data`` (converted with
            ``.numpy()``) used for fitting.
        penalty: regularisation passed to ``LogisticRegression``.

    Returns:
        ``(current_params, predictions)``: test metrics (ROC AUC, DeLong
        AUC/variance, recall, precision, confusion matrix, accuracy, PR AUC, F1)
        and the 1-D array of predicted class-1 probabilities.
    """
    logreg = LogisticRegression(max_iter=1000, penalty=penalty, solver="saga")
    logreg.fit(train_dataset.x_data.numpy(), train_dataset.y_data.numpy().ravel())

    # Evaluated on the notebook-global `test_dataset` built by the CV loop.
    overall_y_test = test_dataset.y_data.numpy().ravel()
    overall_pred_test = logreg.predict_proba(test_dataset.x_data.numpy())[:, 1]
    overall_pred_test_class = np.rint(overall_pred_test)

    current_params = {}
    current_params["test_ROC_AUC"] = roc_auc_score(overall_y_test, overall_pred_test)

    current_params["auc_delong"], current_params["variances_delong"] = (
        compare_auc_delong_xu.delong_roc_variance(overall_y_test, overall_pred_test)
    )

    current_params["test_recall"] = recall_score(
        overall_y_test, overall_pred_test_class
    )
    current_params["test_precision"] = precision_score(
        overall_y_test, overall_pred_test_class
    )

    # sklearn lays out the binary confusion matrix as [[TN, FP], [FN, TP]]
    # (rows = true label, columns = predicted label).
    tn, fp, fn, tp = confusion_matrix(
        overall_y_test, overall_pred_test_class, labels=[0, 1]
    ).ravel()
    current_params["confusion_matrix"] = {"TP": tp, "TN": tn, "FP": fp, "FN": fn}

    current_params["test_accuracy"] = accuracy_score(
        overall_y_test, overall_pred_test_class
    )

    precision, recall, _ = precision_recall_curve(overall_y_test, overall_pred_test)
    current_params["test_PR_AUC"] = auc(recall, precision)
    current_params["F1-score"] = f1_score(overall_y_test, overall_pred_test_class)

    # Both curves use probabilities; hard 0/1 predictions would collapse the PR
    # curve to a few points and disagree with test_PR_AUC.
    fpr, tpr, _ = roc_curve(overall_y_test, overall_pred_test)
    curves = {
        "ROC": {"false_positive_rate": fpr, "true_positive_rate": tpr},
        "PR": {"precision": precision, "recall": recall},
    }

    # plot curves
    plot_curves(curves, current_params)

    return current_params, overall_pred_test

In [None]:
def gridsearch(i, model, run_id, learning_rate, params_dict):
    """Train ``model`` on every hyper-parameter combination and keep the best one.

    Args:
        i: CV cycle index, forwarded to ``training_loop``.
        model: model class to instantiate.
        run_id: session identifier, forwarded to ``training_loop``.
        learning_rate: Adam learning rate, forwarded to ``training_loop``.
        params_dict: ``{"epochs": int, <hyper-parameter>: [candidates], ...}``;
            the Cartesian product of the candidate lists is searched.
            The dict is not modified.

    Returns:
        ``(best_result, best_predictions)`` for the combination with the highest
        ``"test_ROC_AUC"`` (the first one wins on ties).

    Raises:
        ValueError: if the grid contains no combination (an empty candidate list).
    """
    # Work on a copy so the caller's dict is left intact.
    grid = dict(params_dict)
    epochs = grid.pop("epochs")

    print(f"GS will run {np.prod(np.array([len(v) for v in grid.values()]))} cycles")

    best_result, best_pred = None, None
    for combination in itertools.product(*grid.values()):
        combo_params = dict(zip(grid.keys(), combination))

        print(f"params_dict: {combo_params}")

        performance, preds = training_loop(
            i, model, epochs, run_id, learning_rate, **combo_params
        )

        # Keep only the best run's predictions instead of every run's.
        if best_result is None or performance["test_ROC_AUC"] > best_result["test_ROC_AUC"]:
            best_result, best_pred = performance, preds

    if best_result is None:
        raise ValueError("Hyper-parameter grid is empty")

    return best_result, best_pred

In [None]:
def metircs_average_prep(cv_average_metics, current_metrics):
    """Append one CV cycle's headline test metrics to the running per-model lists.

    Args:
        cv_average_metics: ``{model_name: {metric_name: [values...]}}``, updated in
            place. Every model key must already exist; metric lists are created
            on first use.
        current_metrics: ``{model_name: {metric_name: value}}`` for this cycle.
    """
    model_list = [
        "lasso",
        "mlp",
        "cnn",
        "rnn",
        "rnn_cnn",
    ]
    metric_list = [
        "test_ROC_AUC",
        "test_recall",
        "test_precision",
        "test_accuracy",
        "test_PR_AUC",
        "F1-score",
    ]

    for model in model_list:
        per_model = cv_average_metics[model]
        for metric in metric_list:
            # setdefault creates the list AND records the value, so the very
            # first cycle is included in the averages too.
            per_model.setdefault(metric, []).append(current_metrics[model][metric])


def metircs_average(cv_average_metics):
    """Summarise per-cycle metrics as ``"<metric> = <mean> +- <standard error>"`` strings.

    Args:
        cv_average_metics: ``{model_name: {metric_name: [per-cycle values]}}``.

    Returns:
        ``{model_name: [one summary string per metric]}`` in a fixed model and
        metric order. Mean and standard error are rounded to 3 decimals; the
        standard error uses the sample standard deviation (``ddof=1``), so a
        single cycle yields ``nan``.
    """
    models = ("lasso", "mlp", "cnn", "rnn", "rnn_cnn")
    metrics = (
        "test_ROC_AUC",
        "test_recall",
        "test_precision",
        "test_accuracy",
        "test_PR_AUC",
        "F1-score",
    )

    def summarise(metric, values):
        arr = np.array(values)
        mean = round(np.mean(arr), 3)
        sem = round(np.std(arr, ddof=1) / np.sqrt(np.size(arr)), 3)
        return f"{metric} = {mean} +- {sem}"

    return {
        model: [summarise(metric, cv_average_metics[model][metric]) for metric in metrics]
        for model in models
    }

# Repeated hold-out validation (`num_of_cycles` random stratified train/val/test splits, not k-fold)

In [None]:
cv_p_value_results = {}
cv_average_metics = {}

cv_current_metics = {
    "lasso": {},
    "mlp": {},
    "cnn": {},
    "rnn": {},
    "rnn_cnn": {},
}

# Feature matrix read by PRS_Dataset for each supported `mode`.
feature_paths = {
    "gen": transposed_feature_matrix_path,
    "cov": ordered_covariates_path,
    "gen_cov": feature_cov_path,
    "gen_cov_hla": feature_cov_hla_path,
}
if mode not in feature_paths:
    # Fail fast instead of running with an undefined or stale `feature_path`.
    raise ValueError(f"Unknown mode {mode!r}; expected one of {list(feature_paths)}")
feature_path = feature_paths[mode]

# Resampling applied to the training sets; any other imbalance_type means none.
imbalance_kwargs = (
    {"imbalance": imbalance_type} if imbalance_type in ("ROS", "SMOTE") else {}
)

DEBUG_CASE_CONTROL_AMOUNT = True

# The target table is identical for every cycle, so it is read only once.
df = pd.read_csv(target_output_path, header=None)
print("All dfs shape", df.shape)


def _undersample_controls(frame, seed):
    """Keep every case (column 0 == 1) plus UNDERSAMPLE_N controls (== 0) drawn with `seed`."""
    cases = frame[frame[0] == 1]
    controls = frame[frame[0] == 0].sample(UNDERSAMPLE_N, random_state=seed)
    return pd.concat([cases, controls])


for i in range(num_of_cycles):
    print(f"Run {i}/{num_of_cycles}")

    # Stratified 60/20/20 train/val/test split seeded by the cycle index.
    train_classic, test = train_test_split(
        df, test_size=0.2, stratify=df, random_state=i
    )
    train, val = train_test_split(
        train_classic, test_size=0.25, stratify=train_classic, random_state=i
    )  # 0.25 x 0.8 = 0.2

    ######################## dataset prep ########################
    # Neural nets fit on `train` and use `val` for model selection; classic
    # models fit on `train_classic` (train + val before undersampling).
    if DO_UNDERSAMPLE:
        # Seeded with the cycle index so every cycle is reproducible.
        train = _undersample_controls(train, seed=i)
        print("New train shape", train.shape)
        train_classic = _undersample_controls(train_classic, seed=i)
        print("New classic train shape", train_classic.shape)

    train_index = train.index
    train_index_classic = train_classic.index
    val_index = val.index
    test_index = test.index

    #### for neural nets ####
    train_dataset = PRS_Dataset(
        feature_path,
        target_output_path,
        "train",
        train_index,
        test_index,
        val_index,
        **imbalance_kwargs,
    )

    val_dataset = PRS_Dataset(
        feature_path, target_output_path, "val", train_index, test_index, val_index
    )

    train_loader = DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )

    val_loader = DataLoader(
        dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    #### for classic models ####
    train_dataset_classic = PRS_Dataset(
        feature_path,
        target_output_path,
        "train",
        train_index_classic,
        test_index,
        val_index,
        **imbalance_kwargs,
    )

    test_dataset = PRS_Dataset(
        feature_path, target_output_path, "test", train_index, test_index, val_index
    )

    test_loader = DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    # sanity check: number of samples and of cases in each split
    if DEBUG_CASE_CONTROL_AMOUNT:
        print(imbalance_type)
        for split_name, split_dataset in (
            ("Train", train_dataset),
            ("Val", val_dataset),
            ("Test", test_dataset),
        ):
            print(split_name)
            print(split_dataset.y_data.shape)
            print("Ill", int((split_dataset.y_data[:, 0] == 1).sum()))

    # Input width for the model constructors (read as a global by training_loop).
    x_batch, y_batch = next(iter(train_loader))
    input_size = x_batch.shape[1]
    print(f"input size is {input_size}")

    ######################## train and test ########################
    models_performances = {}
    models_pred = {}

    # Neural nets: set_nn_params builds fresh dicts, so popping keys is safe.
    for params in set_nn_params(pheno):
        learning_rate = params.pop("lr")
        model_name = params.pop("model_name")
        model = params.pop("model")

        # Grid search over the remaining hyper-parameter lists
        models_performances[model_name], models_pred[model_name] = gridsearch(
            i, model, run_id, learning_rate, params
        )

    # Lasso (evaluated on the global test_dataset)
    models_performances["lasso"], models_pred["lasso"] = lasso_train(
        train_dataset_classic
    )

    # save everything to json
    write_json(models_performances, json_output_raw)

    ######################## ROC AUC compare ########################
    overall_y_test = test_dataset.y_data.numpy().ravel()

    print(models_pred)
    # DeLong test of lasso vs each net. delong_roc_test returns log10(p-value),
    # hence 10 ** x; `.item()` accepts a scalar or any size-1 array (float() on
    # an ndarray with ndim > 0 is deprecated in recent numpy).
    p_values = {}
    for nn_name in ("mlp", "cnn", "rnn", "rnn_cnn"):
        log10_p = compare_auc_delong_xu.delong_roc_test(
            overall_y_test, models_pred["lasso"], models_pred[nn_name].reshape(-1)
        )
        p_values[nn_name] = 10 ** np.asarray(log10_p).item()

    cv_p_value_results[i] = p_values
    print(cv_current_metics)
    ######################## metrics average ########################
    metircs_average_prep(cv_current_metics, models_performances)

cv_average_metics = metircs_average(cv_current_metics)

write_json_2dicts(cv_p_value_results, cv_average_metics, json_output_summary)

In [None]:
# DeLong p-values per CV cycle: {cycle: {net_name: p-value of lasso vs that net}}
cv_p_value_results

In [None]:
# Per-model summary strings "<metric> = <mean> +- <standard error>" across cycles
cv_average_metics