In [None]:
import random
from datetime import datetime

In [None]:
import torch
from torch import nn

In [None]:
from utils_random import set_random_seed
from utils_torch.data import stratified_random_split
from utils_attacker_lstm.data import DatasetAttackerLSTMPool, DataLoaderAttackerLSTM
from utils_attacker_lstm.models import ModelAttackerLSTMNew, TrainerAttackerLSTM, TesterAttackerLSTM, ManagerAttackerLSTM
from utils_plot import plot_train_eval_loss_accuracy, plot_receiver_operating_characteristics_curve, plot_confusion_matrix

In [None]:
# Run identifier: month, day, hour, minute of launch (e.g. "03141527"); used in plot and model file names.
model_id = format(datetime.now(), "%m%d%H%M")

In [None]:
# Draw a fresh 32-bit seed for this run; it is stored alongside the model so the run can be replayed.
random_seed = random.randrange(2 ** 32)

In [None]:
# Seed the run with the sampled seed (later passed to manager.add_model) so the
# hyper-parameter draws below can be replayed.
# NOTE(review): presumably seeds Python's `random` as well as torch — confirm in utils_random.
set_random_seed(random_seed)

In [None]:
# Number of SNPs used per genome: one of 10k, 20k, ..., 100k.
num_snps = random.choice([10_000 * k for k in range(1, 11)])
# Fractions of the dataset assigned to the train / eval / test splits.
train_eval_test_split = [0.7, 0.15, 0.15]

In [None]:
# Genomes per mini-batch are sampled; the SNP batch spans all `num_snps` SNPs at once.
genome_batch_size, snp_batch_size = random.randint(16, 64), num_snps

In [None]:
# Convolutional front-end hyper-parameters. Per-layer settings are lists with one
# entry per conv layer; dropout settings have one fewer entry (one per gap between
# consecutive layers).
conv_num_layers = random.randint(1, 8)
conv_channel_size = [random.randint(16, 64) for _ in range(conv_num_layers)]
conv_kernel_size = [random.randint(4, 32) for _ in range(conv_num_layers)]
# Each layer's stride is bounded by *that layer's* kernel size. Passing the whole
# list to randint (as before) raised a TypeError.
conv_stride = [random.randint(1, kernel_size) for kernel_size in conv_kernel_size]
conv_dilation = [random.randint(1, 4) for _ in range(conv_num_layers)]
conv_groups = [1 for _ in range(conv_num_layers)]

# A single activation type is shared by all conv layers; kwargs are per layer.
conv_activation = random.choice(['relu', 'tanh', 'sigmoid', 'leaky_relu'])
conv_activation_kwargs = [{} for _ in range(conv_num_layers)]

conv_dropout_p = [random.uniform(0, 0.66) for _ in range(conv_num_layers - 1)]
conv_dropout_first = [random.choice([True, False]) for _ in range(conv_num_layers - 1)]

conv_batch_norm = [random.choice([True, False]) for _ in range(conv_num_layers)]
conv_batch_norm_momentum = [random.uniform(0, 1) for _ in range(conv_num_layers)]

In [None]:
# Settings for the bridge between the conv stack and the LSTM
# (activation, dropout, optional layer norm). Draw order is kept stable for seeding.
conv_lstm_activation, conv_lstm_activation_kwargs = random.choice(('relu', 'tanh', 'sigmoid', 'leaky_relu')), dict()

conv_lstm_dropout_p, conv_lstm_dropout_first = random.uniform(0, 0.66), random.choice((True, False))

conv_lstm_layer_norm = random.choice((True, False))

In [None]:
# Stacked-LSTM hyper-parameters: per-layer lists of length `lstm_num_layers`,
# except the inter-layer dropout settings, which have one entry fewer.
lstm_num_layers = random.randint(1, 4)
# The first LSTM consumes the channels produced by the last conv layer.
lstm_input_size = conv_channel_size[-1]
lstm_hidden_size = [random.randint(4, 64) for _ in range(lstm_num_layers)]
# proj_size 0 means no output projection (torch.nn.LSTM convention).
lstm_proj_size = [0] * lstm_num_layers
lstm_bidirectional = [random.choice([True, False]) for _ in range(lstm_num_layers)]

num_lstm_gaps = lstm_num_layers - 1
lstm_dropout_p = [random.uniform(0, 0.66) for _ in range(num_lstm_gaps)]
lstm_dropout_first = [random.choice([True, False]) for _ in range(num_lstm_gaps)]

lstm_layer_norm = [random.choice([True, False]) for _ in range(lstm_num_layers)]

In [None]:
# Regularisation for the LSTM -> linear transition (presumably applied to the
# LSTM output before the linear head — confirm in ModelAttackerLSTMNew).
lstm_linear_dropout_p, lstm_linear_dropout_first = random.uniform(0, 0.66), random.choice((True, False))

lstm_linear_batch_norm, lstm_linear_batch_norm_momentum = random.choice((True, False)), random.uniform(0, 1)

In [None]:
# Linear head hyper-parameters.
linear_num_layers = random.randint(1, 4)
# `linear_num_layers` linear layers need linear_num_layers + 1 feature sizes:
# the LSTM output width, (linear_num_layers - 1) random hidden widths, and a
# single output logit. The previous `linear_num_layers - 2` produced too few
# sizes for 2+ layers (and the same 2 entries for both 1 and 2 layers).
# NOTE(review): the input width doubles if *any* LSTM layer is bidirectional; if only
# the last layer's direction determines the output width this should be
# `lstm_bidirectional[-1]` — confirm against ModelAttackerLSTMNew.
linear_input_features = lstm_hidden_size[-1] * (2 if any(lstm_bidirectional) else 1)
linear_hidden_features = [random.randint(4, 64) for _ in range(linear_num_layers - 1)]
linear_num_features = [linear_input_features] + linear_hidden_features + [1]

linear_activation = random.choice(['relu', 'tanh', 'sigmoid', 'leaky_relu'])
linear_activation_kwargs = [{} for _ in range(linear_num_layers)]

linear_dropout_p = [random.uniform(0, 0.66) for _ in range(linear_num_layers - 1)]
linear_dropout_first = [random.choice([True, False]) for _ in range(linear_num_layers - 1)]

linear_batch_norm = [random.choice([True, False]) for _ in range(linear_num_layers)]
linear_batch_norm_momentum = [random.uniform(0, 1) for _ in range(linear_num_layers)]

In [None]:
# Optimisation settings (fixed, not sampled).
learning_rate = 1e-3
num_epochs = 256

In [None]:
# Output locations, relative to the notebook's working directory.
plots_dir = "../plots"
models_dir, models_file = "../models", "models.csv"

In [None]:
# Pick the accelerator: Apple MPS first, then CUDA, otherwise the CPU.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

In [None]:
# Build the dataset from two genome files: the pool ("In_Pop") and a
# reference set ("Not_In_Pop"), using `num_snps` SNPs (exact windowing/truncation
# semantics live in DatasetAttackerLSTMPool — TODO confirm).
# NOTE(review): `.pkl` suggests pickle files; only load trusted data, since
# unpickling can execute arbitrary code.
# The reference path previously ended in a duplicated ".pkl.pkl".
dataset = DatasetAttackerLSTMPool(
    genomes_pool_path="../data/test/In_Pop.pkl",
    genomes_reference_path="../data/test/Not_In_Pop.pkl",
    num_snps=num_snps)
# Split into train / eval / test subsets using the configured ratios
# (stratified — presumably by class label).
subset_train, subset_eval, subset_test = stratified_random_split(dataset=dataset,
                                                                 ratios=train_eval_test_split)

In [None]:
def _make_attacker_loader(subset, shuffle):
    """Wrap one dataset split in a DataLoaderAttackerLSTM using the sampled batch sizes."""
    return DataLoaderAttackerLSTM(
        dataset=subset,
        genome_batch_size=genome_batch_size,
        snp_batch_size=snp_batch_size,
        shuffle=shuffle)


# Only the training split is shuffled; eval and test keep a fixed order.
dataloader_train = _make_attacker_loader(subset_train, shuffle=True)
dataloader_eval = _make_attacker_loader(subset_eval, shuffle=False)
dataloader_test = _make_attacker_loader(subset_test, shuffle=False)

In [None]:
# Gather every sampled hyper-parameter, grouped by sub-network, and unpack it
# into the model constructor.
model_kwargs = dict(
    # convolutional front-end
    conv_num_layers=conv_num_layers,
    conv_channel_size=conv_channel_size,
    conv_kernel_size=conv_kernel_size,
    conv_stride=conv_stride,
    conv_dilation=conv_dilation,
    conv_groups=conv_groups,
    conv_activation=conv_activation,
    conv_activation_kwargs=conv_activation_kwargs,
    conv_dropout_p=conv_dropout_p,
    conv_dropout_first=conv_dropout_first,
    conv_batch_norm=conv_batch_norm,
    conv_batch_norm_momentum=conv_batch_norm_momentum,
    # conv -> LSTM bridge
    conv_lstm_activation=conv_lstm_activation,
    conv_lstm_activation_kwargs=conv_lstm_activation_kwargs,
    conv_lstm_dropout_p=conv_lstm_dropout_p,
    conv_lstm_dropout_first=conv_lstm_dropout_first,
    conv_lstm_layer_norm=conv_lstm_layer_norm,
    # stacked LSTM
    lstm_num_layers=lstm_num_layers,
    lstm_input_size=lstm_input_size,
    lstm_hidden_size=lstm_hidden_size,
    lstm_proj_size=lstm_proj_size,
    lstm_bidirectional=lstm_bidirectional,
    lstm_dropout_p=lstm_dropout_p,
    lstm_dropout_first=lstm_dropout_first,
    lstm_layer_norm=lstm_layer_norm,
    # LSTM -> linear bridge
    lstm_linear_dropout_p=lstm_linear_dropout_p,
    lstm_linear_dropout_first=lstm_linear_dropout_first,
    lstm_linear_batch_norm=lstm_linear_batch_norm,
    lstm_linear_batch_norm_momentum=lstm_linear_batch_norm_momentum,
    # linear head
    linear_num_layers=linear_num_layers,
    linear_num_features=linear_num_features,
    linear_activation=linear_activation,
    linear_activation_kwargs=linear_activation_kwargs,
    linear_dropout_p=linear_dropout_p,
    linear_dropout_first=linear_dropout_first,
    linear_batch_norm=linear_batch_norm,
    linear_batch_norm_momentum=linear_batch_norm_momentum,
)
model = ModelAttackerLSTMNew(**model_kwargs)
model.to(device)

In [None]:
# No learning-rate scheduler is used for this run.
scheduler = None
# Single-logit binary output -> BCEWithLogitsLoss (applies the sigmoid internally).
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(lr=learning_rate, params=model.parameters())

In [None]:
# Trainer wiring: model + loss + optimiser, train/eval loaders, and gradient-norm
# settings (max_grad_norm=1.0 with an L2 norm — presumably used for clipping).
trainer = TrainerAttackerLSTM(
    model=model, criterion=criterion,
    optimizer=optimizer, scheduler=scheduler,
    train_loader=dataloader_train, eval_loader=dataloader_eval,
    device=device,
    max_grad_norm=1.0, norm_type=2,
)

In [None]:
# Tester shares the model and loss with the trainer but reads only the held-out test loader.
tester = TesterAttackerLSTM(model=model, criterion=criterion, test_loader=dataloader_test, device=device)

In [None]:
# Bookkeeping for trained models, stored under `models_dir` with index file `models_file`.
manager = ManagerAttackerLSTM(models_file=models_file, models_dir=models_dir)

In [None]:
# Run training for `num_epochs` epochs with verbose output enabled.
trainer.train(verbose=True, num_epochs=num_epochs)

In [None]:
# Summarise training around the epoch with the lowest evaluation loss
# (the same epoch is marked as `saved_epoch` in the loss/accuracy plot).
finish_time = datetime.now()
best_eval_loss_epoch, best_eval_loss = trainer.best_eval_loss_epoch, trainer.best_eval_loss
# NOTE(review): indexes eval_accuracies directly by epoch number; confirm the trainer
# records epochs 0-based, otherwise this reads the neighbouring epoch.
best_eval_accuracy = trainer.eval_accuracies[best_eval_loss_epoch]

print(f'Finished training at {finish_time}')
print(f'Best evaluation loss epoch found at: {best_eval_loss_epoch}')
print(f'Best evaluation loss found: {best_eval_loss:.4f}')
print(f'Best evaluation accuracy found: {best_eval_accuracy:.4f}')

In [None]:
# Learning curves for train and eval, with the best-eval-loss epoch highlighted.
train_eval_plot_file = f"model_attacker_pool_{model_id}_train_eval_loss_acc.png"
plot_train_eval_loss_accuracy(
    train_loss=trainer.train_losses,
    train_accuracy=trainer.train_accuracies,
    eval_loss=trainer.eval_losses,
    eval_accuracy=trainer.eval_accuracies,
    saved_epoch=best_eval_loss_epoch,
    output_path=plots_dir,
    output_file=train_eval_plot_file,
)

In [None]:
# Evaluate on the held-out test loader; the resulting metrics are read from
# `tester` attributes (loss, accuracy_score, roc_curve, ...) in the cells below.
tester.test()

In [None]:
# (printed label, tester attribute, format spec) for each reported test metric.
for label, attr, spec in (
    ('loss', 'loss', '.4f'),
    ('accuracy', 'accuracy_score', '.2f'),
    ('precision', 'precision_score', '.2f'),
    ('recall', 'recall_score', '.2f'),
    ('f1', 'f1_score', '.2f'),
    ('AUC', 'auroc_score', '.2f'),
):
    print(f'Test {label}: {getattr(tester, attr):{spec}}')

In [None]:
# tester.roc_curve unpacks into three parts: false-positive rates, true-positive
# rates, and a third element (presumably thresholds) that is not plotted.
fpr, tpr, _ = tester.roc_curve
roc_plot_file = f"model_attacker_pool_{model_id}_roc_curve.png"
plot_receiver_operating_characteristics_curve(
    false_positive_rates=fpr,
    true_positive_rates=tpr,
    auc=tester.auroc_score,
    output_path=plots_dir,
    output_file=roc_plot_file,
)

In [None]:
# Confusion matrix for the binary test predictions.
confusion_plot_file = f"model_attacker_pool_{model_id}_confusion_matrix.png"
plot_confusion_matrix(
    confusion_matrix=tester.confusion_matrix_scores,
    task="binary",
    output_path=plots_dir,
    output_file=confusion_plot_file,
)

In [None]:
# Register this run with the model manager: its id and seed, the dataset and
# training loader, and the trained model with its trainer/tester state.
manager.add_model(
    model_id=model_id, random_seed=random_seed,
    data=dataset, loader=dataloader_train,
    model=model, trainer=trainer, tester=tester,
)