In [52]:
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn

from utils_attacker_lstm.ManagerAttackerLSTM import ManagerAttackerLSTM
from utils_io import read_bitarrays
from utils_attacker_lstm import DatasetAttackerLSTMPool, DataLoaderAttackerLSTM, ModelAttackerConvLSTMLinear, LSTMAttackerTrainer, LSTMAttackerTester
from utils_torch import stratified_random_split

In [53]:
# Keep only the first `num_snps` SNP columns of each genome matrix.
num_snps = 40000
data_dir = '../data/test'

# NOTE(review): the `.pkl` inputs are presumably unpickled by `read_bitarrays`;
# only load files from trusted sources.
genomes_pool = read_bitarrays(f'{data_dir}/In_Pop.pkl')[:, :num_snps]
genomes_reference = read_bitarrays(f'{data_dir}/Not_In_Pop.pkl')[:, :num_snps]

# Slicing past the end silently returns fewer columns, so fail loudly if a
# file does not actually contain the requested number of SNPs.
assert genomes_pool.shape[1] == num_snps, (
    f'pool genomes have {genomes_pool.shape[1]} SNPs, expected {num_snps}')
assert genomes_reference.shape[1] == num_snps, (
    f'reference genomes have {genomes_reference.shape[1]} SNPs, expected {num_snps}')

# Rows: pool (member) genomes first, then reference (non-member) genomes.
# The membership labels are concatenated in this same order.
genomes = np.concatenate((genomes_pool, genomes_reference), axis=0)

In [54]:
# Membership labels aligned row-for-row with `genomes`: True for pool
# (In_Pop) genomes, False for reference (Not_In_Pop) genomes.
num_pool_genomes = genomes_pool.shape[0]
num_reference_genomes = genomes_reference.shape[0]
labels_beacon = np.ones(num_pool_genomes, dtype=bool)
labels_reference = np.zeros(num_reference_genomes, dtype=bool)
# Both parts are already bool, so the concatenation is bool as well.
labels = np.concatenate([labels_beacon, labels_reference])

In [55]:
# Column-wise (per-SNP) mean over genomes; with 0/1 entries this is the
# fraction of genomes that carry each bit.
# NOTE(review): both vectors use *all* genomes, including the ones that later
# land in the eval/test splits. Confirm this matches the intended threat model
# (e.g. whether the reference panel should be disjoint from the targets).
frequencies_pool, frequencies_reference = (
    np.mean(genome_matrix, axis=0)
    for genome_matrix in (genomes_pool, genomes_reference)
)

In [56]:
# Seed every global RNG before the first stochastic step. This makes the
# stratified split reproducible across kernel restarts, and so are later model
# initialisation and loader shuffling, assuming they draw from the global
# random / numpy / torch generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# One dataset over all target genomes. It presumably pairs each genome with the
# pool and reference frequency vectors; `targets` carries the membership labels.
dataset = DatasetAttackerLSTMPool(
    target_genomes=genomes,
    pool_frequencies=frequencies_pool,
    reference_frequencies=frequencies_reference,
    targets=labels)
# 70% / 15% / 15% train / eval / test, stratified (presumably on the labels).
subset_train, subset_eval, subset_test = stratified_random_split(dataset, [0.7, 0.15, 0.15])

In [57]:
# Batches of 32 genomes, each cut into windows of 20000 SNPs.
genomes_batch_size, snps_batch_size = 32, 20000
# Only the training loader shuffles; eval/test keep a fixed order.
loader_train, loader_eval, loader_test = (
    DataLoaderAttackerLSTM(subset, genomes_batch_size, snps_batch_size, shuffle=shuffle)
    for subset, shuffle in (
        (subset_train, True),
        (subset_eval, False),
        (subset_test, False),
    )
)

In [58]:
# Pick the accelerator: Apple MPS first, then CUDA, otherwise the CPU.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

In [59]:
# Conv1d + LSTM + Linear attacker, per the printed module repr.
# conv_in_channels=3 presumably covers the target genome plus the pool and
# reference frequency channels. TODO confirm against DatasetAttackerLSTMPool.
# NOTE(review): nn.LSTM applies `dropout` only between stacked layers, so with
# lstm_num_layers=1 the 0.66 value has no effect on the LSTM (its repr shows no
# dropout). Confirm whether the model uses it anywhere else.
model_kwargs = dict(
    conv_in_channels=3,
    conv_out_channels=6,
    conv_kernel_size=20,
    conv_stride=1,
    lstm_hidden_size=9,
    lstm_num_layers=1,
    lstm_bidirectional=False,
    lstm_dropout=0.66,
)
model = ModelAttackerConvLSTMLinear(**model_kwargs)
model.to(device)

ModelAttackerConvLSTMLinear(
  (lstm): LSTM(6, 9, batch_first=True)
  (linear): Linear(in_features=9, out_features=1, bias=True)
  (conv): Conv1d(3, 6, kernel_size=(20,), stride=(1,))
)

In [60]:
# BCEWithLogitsLoss applies the sigmoid internally, so the model's single output
# unit (Linear out_features=1) is consumed as a raw logit.
criterion = nn.BCEWithLogitsLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [61]:
# Wire the model, loss and optimiser to the train/eval loaders on `device`.
trainer = LSTMAttackerTrainer(
    device=device,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=loader_train,
    eval_loader=loader_eval,
    # Presumably gradient-norm clipping (L2 norm capped at 1.0). TODO confirm
    # in LSTMAttackerTrainer.
    max_grad_norm=1.0,
    norm_type=2,
)

In [62]:
# Fixed epoch budget. The logged output shows the trainer saving the model
# whenever the evaluation loss decreases.
num_epochs = 256
trainer.train(num_epochs=num_epochs, verbose=True)

Epoch 1/256
Train Loss: 0.7016, Train Accuracy: 0.50
Evaluation Loss: 0.6981, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: inf -> 0.6981. Saving Model...
Epoch 2/256
Train Loss: 0.6970, Train Accuracy: 0.50
Evaluation Loss: 0.6950, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6981 -> 0.6950. Saving Model...
Epoch 3/256
Train Loss: 0.6943, Train Accuracy: 0.50
Evaluation Loss: 0.6937, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6950 -> 0.6937. Saving Model...
Epoch 4/256
Train Loss: 0.6934, Train Accuracy: 0.50
Evaluation Loss: 0.6935, Evaluation Accuracy: 0.49
Evaluation Loss Decreased: 0.6937 -> 0.6935. Saving Model...
Epoch 5/256
Train Loss: 0.6934, Train Accuracy: 0.48
Evaluation Loss: 0.6935, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6935 -> 0.6935. Saving Model...
Epoch 6/256
Train Loss: 0.6934, Train Accuracy: 0.46
Evaluation Loss: 0.6934, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6935 -> 0.6934. Saving Model...
Epoch 7/256
T

In [63]:
# Timestamp taken when this cell runs (not inside the training loop). It is
# later used to build the model id, so print the same value rather than calling
# datetime.now() a second time.
training_datetime = datetime.now()
summary_lines = (
    f'Finished training at {training_datetime}',
    f'Best evaluation epoch found at: {trainer.best_eval_loss_epoch}',
    f'Best evaluation loss: {trainer.best_eval_loss:.4f}',
    f'Best evaluation accuracy: {trainer.best_eval_accuracy:.2f}',
)
for summary_line in summary_lines:
    print(summary_line)

Finished training at 2024-10-28 11:22:00.843163
Best evaluation epoch found at: 40
Best evaluation loss: 0.6899
Best evaluation accuracy: 0.56


In [64]:
# Evaluate the trained model (same loss function) on the held-out test loader.
tester = LSTMAttackerTester(
    device=device,
    model=model,
    criterion=criterion,
    test_loader=loader_test,
)

In [65]:
# Run evaluation on the test loader; the metrics are read back from the
# tester's attributes (loss, accuracy_score, precision_score, ...) afterwards.
tester.test()

In [66]:
# Report the held-out metrics populated by `tester.test()`.
test_report = (
    f'Test loss: {tester.loss:.4f}',
    f'Test accuracy: {tester.accuracy_score:.2f}',
    f'Test precision: {tester.precision_score:.2f}',
    f'Test recall: {tester.recall_score:.2f}',
    f'Test F1: {tester.f1_score:.2f}',
    f'Test AUC: {tester.auroc_score:.2f}',
)
for report_line in test_report:
    print(report_line)

Test loss: 0.6908
Test accuracy: 0.54
Test precision: 0.56
Test recall: 0.42
Test F1: 0.48
Test AUC: 0.54


In [67]:
# Model registry under ../models; the registry file is presumably the
# models.csv index.
manager = ManagerAttackerLSTM(models_dir='../models', models_file="models.csv")

In [68]:
# NOTE(review): '%m%d%H%M' omits the year and seconds, so ids can collide for
# runs finished in the same minute or in different years. Confirm how the
# manager handles duplicate ids.
model_id = training_datetime.strftime('%m%d%H%M')
manager.add_model(
    model_id=model_id,
    model=model,
    data=dataset,
    loader=loader_train,
    trainer=trainer,
    tester=tester,
)