<H1 align="center">Long Short-Term Memory for Membership Inference Attack on Beacon Data</H1>

<br>
<strong>This notebook trains and evaluates a long short-term memory (LSTM) model that performs a membership inference attack on beacon data, using a fixed set of hyperparameters.</strong>

## Import Libraries

### Python Libraries 

In [None]:
from datetime import datetime

### External Libraries

In [None]:
import torch
import torch.nn as nn

### Custom Libraries

In [None]:
from src.utils_random import set_random_seed
from src.utils_torch.data import stratified_random_split
from src.utils_attacker_lstm.data import DatasetAttackerLSTMBeacon, DataLoaderAttackerLSTM
from src.utils_attacker_lstm.models import ModelAttackerLSTM, TesterAttackerLSTM, TrainerAttackerLSTM, \
    ManagerAttackerLSTM
from src.utils_plot import plot_train_eval_loss_accuracy, plot_receiver_operating_characteristics_curve, \
    plot_confusion_matrix, plot_long_short_term_memory

## Set Parameters

### Model Id

In [None]:
# Identifier for this run: embedded in the plot file names and passed to manager.add_model.
model_id = "999999"

In [None]:
# Seed handed to set_random_seed and recorded via manager.add_model.
random_seed = 42

### Data Params

In [None]:
# Passed as num_snps to DatasetAttackerLSTMBeacon.
num_snps = 40000
# Fractions handed to stratified_random_split; the result is unpacked as (train, eval, test).
train_eval_test_split = [0.7, 0.15, 0.15]

### Loader Params

In [None]:
# Batch sizes forwarded to every DataLoaderAttackerLSTM (train, eval and test).
genome_batch_size = 32
# NOTE(review): this is larger than num_snps (40000); presumably each genome's SNPs
# are then served in a single chunk — confirm this is intended.
snp_batch_size = 80000

### Model Params

#### Conv1d Params

In [None]:
# Conv1d stack configuration, forwarded unchanged to ModelAttackerLSTM.
conv_num_layers = 3
# conv_num_layers + 1 entries — presumably the input channel count followed by each
# layer's output channels; TODO confirm against ModelAttackerLSTM.
conv_channel_size = [3, 16, 32, 16]
# One entry per conv layer.
conv_kernel_size = [20, 10, 10]
conv_stride = [2, 2, 2]
conv_dilation = [1, 1, 1]
conv_groups = [1, 1, 1]

# The lists below have conv_num_layers - 1 entries — presumably they configure the
# blocks between conv layers, with the "Conv1d to LSTM" parameters covering the last
# one; TODO confirm.
conv_activation = [nn.ReLU, nn.ReLU]
conv_activation_kwargs = [{}, {}]

conv_dropout_p = [0.5, 0.5]
conv_dropout_first = [True, True]

conv_batch_norm = [True, True]
conv_batch_norm_momentum = [0.1, 0.1]

#### Conv1d to LSTM Params

In [None]:
# Settings for the transition between the Conv1d stack and the LSTM, forwarded
# unchanged to ModelAttackerLSTM.
conv_lstm_activation = nn.ReLU
conv_lstm_activation_kwargs = {}
conv_lstm_dropout_p = 0.5
conv_lstm_dropout_first = True
conv_lstm_layer_norm = True

#### LSTM Params

In [None]:
# LSTM stack configuration, forwarded unchanged to ModelAttackerLSTM.
lstm_num_layers = 1
# Equals the last entry of conv_channel_size (16).
lstm_input_size = 16
# One entry per LSTM layer.
lstm_hidden_size = [32]
lstm_proj_size = [0]
lstm_bidirectional = [True]

# Empty (lstm_num_layers - 1 == 0 entries) — presumably these configure the blocks
# between stacked LSTM layers; TODO confirm.
lstm_dropout_p = []
lstm_dropout_first = []

lstm_layer_norm = []

#### LSTM to Linear Params

In [None]:
# Settings for the transition between the LSTM and the linear head, forwarded
# unchanged to ModelAttackerLSTM.
lstm_linear_dropout_p = 0.5
lstm_linear_dropout_first = True

lstm_linear_batch_norm = True
lstm_linear_batch_norm_momentum = 0.1

#### Linear Params

In [None]:
# Linear head configuration, forwarded unchanged to ModelAttackerLSTM.
linear_num_layers = 1
# linear_num_layers + 1 entries. Presumably 64 = 2 * lstm_hidden_size for the
# bidirectional LSTM, and 1 is the single logit consumed by BCEWithLogitsLoss — confirm.
linear_num_features = [64, 1]

# Empty (linear_num_layers - 1 == 0 entries) — presumably these configure the blocks
# between stacked linear layers; TODO confirm.
linear_activation = []
linear_activation_kwargs = []

linear_dropout_p = []
linear_dropout_first = []

linear_batch_norm = []
linear_batch_norm_momentum = []

### Trainer Params

In [None]:
# Training length in epochs.
num_epochs = 256
# Learning rate passed to the Adam optimizer.
learning_rate = 0.001

### IO Params

In [None]:
# models_dir / models_file are handed to ManagerAttackerLSTM; plots_dir is the
# output_path of every plot_* call. The directories are relative paths.
models_dir = "../models"
models_file = "models.csv"
plots_dir = "../plots"

## Set Torch Device

In [None]:
# Pick the compute backend in priority order: MPS, then CUDA, then CPU.
# CUDA availability is only probed when MPS is unavailable.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

## Set Random Seed

In [None]:
# Seed the random number generators before the dataset is split and the model is built.
set_random_seed(random_seed)

## Create Dataset

In [None]:
# Load beacon genomes (In_Pop.pkl) and reference genomes (Not_In_Pop.pkl),
# then carve the dataset into train / eval / test subsets.
dataset = DatasetAttackerLSTMBeacon(
    genomes_beacon_path='../data/test/In_Pop.pkl',
    genomes_reference_path='../data/test/Not_In_Pop.pkl',
    num_snps=num_snps,
)
subset_train, subset_eval, subset_test = stratified_random_split(
    dataset,
    train_eval_test_split,
)

## Create Data Loaders

In [None]:
# One loader per subset; only the training loader shuffles.
dataloader_train = DataLoaderAttackerLSTM(
    dataset=subset_train,
    genome_batch_size=genome_batch_size,
    snp_batch_size=snp_batch_size,
    shuffle=True,
)
dataloader_eval = DataLoaderAttackerLSTM(
    dataset=subset_eval,
    genome_batch_size=genome_batch_size,
    snp_batch_size=snp_batch_size,
    shuffle=False,
)
dataloader_test = DataLoaderAttackerLSTM(
    dataset=subset_test,
    genome_batch_size=genome_batch_size,
    snp_batch_size=snp_batch_size,
    shuffle=False,
)

## Create Model

In [None]:
# Assemble the attacker network from the parameter cells above
# (Conv1d -> LSTM -> Linear), then move it to the selected device.
model = ModelAttackerLSTM(
    # Conv1d stack
    conv_num_layers=conv_num_layers,
    conv_channel_size=conv_channel_size,
    conv_kernel_size=conv_kernel_size,
    conv_stride=conv_stride,
    conv_dilation=conv_dilation,
    conv_groups=conv_groups,
    conv_activation=conv_activation,
    conv_activation_kwargs=conv_activation_kwargs,
    conv_dropout_p=conv_dropout_p,
    conv_dropout_first=conv_dropout_first,
    conv_batch_norm=conv_batch_norm,
    conv_batch_norm_momentum=conv_batch_norm_momentum,
    # Conv1d -> LSTM transition
    conv_lstm_activation=conv_lstm_activation,
    conv_lstm_activation_kwargs=conv_lstm_activation_kwargs,
    conv_lstm_dropout_p=conv_lstm_dropout_p,
    conv_lstm_dropout_first=conv_lstm_dropout_first,
    conv_lstm_layer_norm=conv_lstm_layer_norm,
    # LSTM stack
    lstm_num_layers=lstm_num_layers,
    lstm_input_size=lstm_input_size,
    lstm_hidden_size=lstm_hidden_size,
    lstm_proj_size=lstm_proj_size,
    lstm_bidirectional=lstm_bidirectional,
    lstm_dropout_p=lstm_dropout_p,
    lstm_dropout_first=lstm_dropout_first,
    lstm_layer_norm=lstm_layer_norm,
    # LSTM -> Linear transition
    lstm_linear_dropout_p=lstm_linear_dropout_p,
    lstm_linear_dropout_first=lstm_linear_dropout_first,
    lstm_linear_batch_norm=lstm_linear_batch_norm,
    lstm_linear_batch_norm_momentum=lstm_linear_batch_norm_momentum,
    # Linear head
    linear_num_layers=linear_num_layers,
    linear_num_features=linear_num_features,
    linear_activation=linear_activation,
    linear_activation_kwargs=linear_activation_kwargs,
    linear_dropout_p=linear_dropout_p,
    linear_dropout_first=linear_dropout_first,
    linear_batch_norm=linear_batch_norm,
    linear_batch_norm_momentum=linear_batch_norm_momentum,
)
model.to(device)

## Create Trainer

### Create Criterion and Optimizer

In [None]:
# No learning-rate scheduler is used; Adam runs with a constant learning rate.
scheduler = None
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Binary cross-entropy computed on raw logits.
criterion = nn.BCEWithLogitsLoss()

### Create Trainer

In [None]:
# Wire model, loss, optimizer and the train/eval loaders into the trainer.
# max_grad_norm / norm_type are passed as 1.0 and 2.
trainer = TrainerAttackerLSTM(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=dataloader_train,
    eval_loader=dataloader_eval,
    device=device,
    max_grad_norm=1.0,
    norm_type=2,
)

## Create Tester

In [None]:
# The tester shares the model and criterion with the trainer but reads the held-out test loader.
tester = TesterAttackerLSTM(
    model=model,
    criterion=criterion,
    test_loader=dataloader_test,
    device=device,
)

## Create Manager

In [None]:
# Manager configured with the models directory and models file from the IO parameters.
manager = ManagerAttackerLSTM(
    models_dir=models_dir,
    models_file=models_file,
)

## Train Model

### Train Model

In [None]:
# Train for the number of epochs configured in the "Trainer Params" cell
# (previously hard-coded here, so changing num_epochs there had no effect).
trainer.train(num_epochs=num_epochs, verbose=True)

### Print Metrics

In [None]:
# Timestamp taken first so it reflects when this cell started running.
finish_time = datetime.now()

# Summary of the epoch with the lowest evaluation loss.
best_eval_loss_epoch = trainer.best_eval_loss_epoch
best_eval_loss = trainer.best_eval_loss
best_eval_accuracy = trainer.eval_accuracies[best_eval_loss_epoch]

print(
    f'Finished training at {finish_time}',
    f'Best evaluation loss epoch found at: {best_eval_loss_epoch}',
    f'Best evaluation loss found: {best_eval_loss:.4f}',
    f'Best evaluation accuracy found: {best_eval_accuracy:.4f}',
    sep='\n',
)

### Plot Metrics

In [None]:
# Loss/accuracy curves for training and evaluation; the best-eval-loss epoch is passed as saved_epoch.
plot_train_eval_loss_accuracy(
    train_loss=trainer.train_losses,
    train_accuracy=trainer.train_accuracies,
    eval_loss=trainer.eval_losses,
    eval_accuracy=trainer.eval_accuracies,
    saved_epoch=best_eval_loss_epoch,
    output_path=plots_dir,
    output_file=f"model_attacker_beacon_{model_id}_train_eval_loss_acc.png",
)

## Test Model

### Test Model

In [None]:
# Run the tester; the cells below read the resulting metrics from attributes on `tester`.
tester.test()

### Print Metrics

In [None]:
# Report the test-set metrics: loss to four decimals, the rest to two.
print(
    f'Test loss: {tester.loss:.4f}',
    f'Test accuracy: {tester.accuracy_score:.2f}',
    f'Test precision: {tester.precision_score:.2f}',
    f'Test recall: {tester.recall_score:.2f}',
    f'Test f1: {tester.f1_score:.2f}',
    f'Test AUC: {tester.auroc_score:.2f}',
    sep='\n',
)

### Plot ROC Curve

In [None]:
# roc_curve unpacks into three values; only the first two (FPR, TPR) are plotted.
fpr, tpr, _ = tester.roc_curve
plot_receiver_operating_characteristics_curve(
    false_positive_rates=fpr,
    true_positive_rates=tpr,
    auc=tester.auroc_score,
    output_path=plots_dir,
    output_file=f"model_attacker_beacon_{model_id}_roc_curve.png",
)

### Plot Confusion Matrix

In [None]:
# Confusion matrix of the test predictions, plotted with task="binary".
plot_confusion_matrix(
    confusion_matrix=tester.confusion_matrix_scores,
    task="binary",
    output_path=plots_dir,
    output_file=f"model_attacker_beacon_{model_id}_confusion_matrix.png",
)

## Save Model

In [None]:
# Register this run with the manager: id, seed, dataset, training loader, model, trainer and tester.
manager.add_model(
    model_id=model_id,
    random_seed=random_seed,
    data=dataset,
    loader=dataloader_train,
    model=model,
    trainer=trainer,
    tester=tester,
)

## Plot Memory

In [None]:
# Run a single test sample through the model with hidden/cell-state output
# enabled, and keep the per-step states (h, c) for plotting in the next cell.
# NOTE(review): hidden-cell mode stays enabled after this cell; presumably the
# trainer/tester expect it disabled, so re-running earlier cells afterwards may
# fail — confirm and reset the mode if needed.
model.set_hidden_cell_mode(True)
model.eval()
with torch.no_grad():
    x, y = dataloader_test.dataset[0]
    # Add a leading dimension of size 1 (presumably the batch axis) and move to the device.
    x = x.unsqueeze(0).to(device)
    hx = None  # no initial state is supplied
    # Call the module itself rather than .forward() so that any registered
    # hooks run, as recommended by the torch.nn.Module documentation.
    logits, out = model(x, hx)
# Last element of `out` — presumably the final LSTM layer's states; TODO confirm.
out_last = out[-1]
(h, c), (h_last, c_last) = out_last
# Drop the leading size-1 dimension again.
h, c = h.squeeze(0), c.squeeze(0)

In [None]:
# Plot the cell state (long-term) and hidden state (short-term) captured above.
# Both tensors are moved to the CPU first, and the figure is requested with show=False.
plot_long_short_term_memory(
    long_term_memory=c.cpu(),
    short_term_memory=h.cpu(),
    bidirectional=lstm_bidirectional[-1],
    output_path=plots_dir,
    output_file=f"model_attacker_beacon_{model_id}_lstm.png",
    show=False,
)