In [12]:
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn

from utils_io import read_bitarrays
from utils_attacker_lstm import DatasetAttackerLSTMPool, LSTMAttackerDataLoader, ModelAttackerLSTMLinear, LSTMAttackerTrainer, LSTMAttackerTester
from utils_torch import stratified_random_split

In [13]:
# Upper bound on how many second-axis entries of each loaded array are kept.
num_snps = 40000

# Read the two cohort files (pool first, then reference) and truncate each to
# the first `num_snps` columns.
genomes_pool, genomes_reference = (
    read_bitarrays(pickle_path)[:, :num_snps]
    for pickle_path in ('../data/test/In_Pop.pkl', '../data/test/Not_In_Pop.pkl')
)

# Stack along the first axis: pool rows come first, reference rows second.
genomes = np.concatenate([genomes_pool, genomes_reference], axis=0)

In [14]:
# One boolean label per row: True for rows of the pool array, False for rows
# of the reference array.
labels_beacon = np.full(genomes_pool.shape[0], True, dtype=bool)
labels_reference = np.full(genomes_reference.shape[0], False, dtype=bool)

# Pool labels first, reference labels second — the same order in which the
# two genome arrays are concatenated into `genomes`.
labels = np.concatenate([labels_beacon, labels_reference], axis=0).astype(bool)

In [15]:
# Per-column mean over the rows of each cohort (reduction along axis 0).
frequencies_pool, frequencies_reference = (
    np.mean(cohort, axis=0) for cohort in (genomes_pool, genomes_reference)
)

In [16]:
# Seed the global NumPy and PyTorch generators before the first stochastic
# step so that a fresh "Restart Kernel -> Run All" reproduces the same run.
# NOTE(review): assumes stratified_random_split (and the later shuffling /
# weight initialisation) draw from one of these global generators — confirm in
# utils_torch / utils_attacker_lstm; if Python's `random` module is used there,
# seed it as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Wrap the concatenated genomes, both frequency vectors and the membership
# labels into the attacker dataset.
dataset = DatasetAttackerLSTMPool(
    target_genomes=genomes,
    pool_frequencies=frequencies_pool,
    reference_frequencies=frequencies_reference,
    labels=labels)

# Split into train / eval / test subsets using the fractions 0.7 / 0.15 / 0.15.
subset_train, subset_eval, subset_test = stratified_random_split(dataset, [0.7, 0.15, 0.15])

In [17]:
# Batch sizes handed positionally to every loader, in this order.
genomes_batch_size = 32
snps_batch_size = 20000
batch_sizes = (genomes_batch_size, snps_batch_size)

# Only the training loader shuffles; eval and test keep a fixed order.
loader_train = LSTMAttackerDataLoader(subset_train, *batch_sizes, shuffle=True)
loader_eval = LSTMAttackerDataLoader(subset_eval, *batch_sizes, shuffle=False)
loader_test = LSTMAttackerDataLoader(subset_test, *batch_sizes, shuffle=False)

In [18]:
# Pick the compute device in priority order: MPS, then CUDA, then CPU.
# The conditional expression short-circuits, so CUDA is only probed when MPS
# is unavailable.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

In [19]:
# Build the attacker network: a single-layer, unidirectional LSTM configuration
# with 3 input features and a hidden size of 16.
# NOTE(review): lstm_dropout=0.5 is combined with lstm_num_layers=1, and the
# module summary printed below this cell shows no dropout on the LSTM.
# torch.nn.LSTM only applies dropout between stacked layers, so this value is
# presumably a no-op — confirm how ModelAttackerLSTMLinear uses it.
model = ModelAttackerLSTMLinear(
    lstm_input_size=3,
    lstm_hidden_size=16,
    lstm_num_layers=1,
    lstm_bidirectional=False,
    lstm_dropout=0.5,
)

# Kept as the cell's last expression so the notebook displays its result.
model.to(device)

LSTMAttacker(
  (lstm): LSTM(3, 16, batch_first=True)
  (linear): Linear(in_features=16, out_features=1, bias=True)
)

In [20]:
# Binary cross-entropy computed on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Adam over every model parameter with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Disabled learning-rate schedule, kept for reference (StepLR is not imported
# in this notebook's import cell):
# scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

In [21]:
# Bundle the model, loss, optimizer, the train/eval loaders and the target
# device into the trainer (all passed positionally, in this order).
trainer = LSTMAttackerTrainer(
    model,
    criterion,
    optimizer,
    loader_train,
    loader_eval,
    device,
)

In [22]:
# Run training for 256 epochs with verbose logging; the trainer returns four
# sequences: train losses, train accuracies, eval losses, eval accuracies.
(
    losses_train,
    accuracies_train,
    losses_eval,
    accuracies_eval,
) = trainer.train(num_epochs=256, verbose=True)

# Lowest evaluation loss seen during training; used later in the saved file name.
min_loss = min(losses_eval)

Epoch 1/256
Train Loss: 0.6940, Train Accuracy: 0.50
Evaluation Loss: 0.6934, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: inf -> 0.6934. Saving Model...
Epoch 2/256
Train Loss: 0.6934, Train Accuracy: 0.50
Evaluation Loss: 0.6931, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6934 -> 0.6931. Saving Model...
Epoch 3/256
Train Loss: 0.6932, Train Accuracy: 0.50
Evaluation Loss: 0.6930, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6931 -> 0.6930. Saving Model...
Epoch 4/256
Train Loss: 0.6931, Train Accuracy: 0.50
Evaluation Loss: 0.6929, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6930 -> 0.6929. Saving Model...
Epoch 5/256
Train Loss: 0.6931, Train Accuracy: 0.47
Evaluation Loss: 0.6928, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6929 -> 0.6928. Saving Model...
Epoch 6/256
Train Loss: 0.6931, Train Accuracy: 0.51
Evaluation Loss: 0.6928, Evaluation Accuracy: 0.52
Evaluation Loss Decreased: 0.6928 -> 0.6928. Saving Model...
Epoch 7/256
T

In [23]:
# File name encodes the SNP count, the best eval loss scaled by 10000 and
# truncated to an int, and a month-day-hour-minute timestamp.
# NOTE(review): the name carries the *minimum* eval loss, but what is written
# is whatever weights `model` holds after training. The training log prints
# "Saving Model..." on each improvement — confirm whether the trainer restores
# those best weights into `model` before this save.
checkpoint_name = f"Attacker_LSTM_Pool_SNP{num_snps}_LSS{int(min_loss * 10000)}_DTT{datetime.now().strftime('%m%d%H%M')}"
model.save("../models", checkpoint_name)
# To reuse a previously saved model instead of retraining:
# model.load("../models", "pool_lstm_attacker_20210919123456")

In [24]:
# Evaluation harness over the held-out test loader, reusing the training
# criterion and device (all arguments passed positionally, in this order).
tester = LSTMAttackerTester(
    model,
    criterion,
    loader_test,
    device,
)

In [25]:
# Run the test pass; the tester returns seven values, the last of which (`cm`)
# is not printed in this cell.
loss, accuracy, precision, recall, f1, auroc, cm = tester.test()

# Assemble the report first, then emit it in a single print (one metric per
# line, four decimal places each).
metric_report = [
    f"Loss: {loss:.4f}",
    f"Accuracy: {accuracy:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1: {f1:.4f}",
    f"AUROC: {auroc:.4f}",
]
print("\n".join(metric_report))

Loss: 0.7240
Accuracy: 0.5500
Precision: 0.5789
Recall: 0.3667
F1: 0.4490
AUROC: 0.5500


In [26]:
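# TODO: report the chance-level baseline alongside the test metrics above;
# left disabled until `dataset.get_chance()` is confirmed to be available.
# chance = dataset.get_chance()
# print(f"Chance: {chance:.4f}")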