In [12]:
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

from utils_io import read_bitarrays
from utils_attacker_lstm import DatasetAttackerLSTMBeacon, LSTMAttackerDataLoader, ModelAttackerLSTM, LSTMAttackerTrainer, LSTMAttackerTester
from utils_torch import stratified_random_split

In [13]:
# Number of leading SNP columns kept from every genome.
num_snps = 40_000

# Read both cohorts and keep only the first `num_snps` columns of each.
genomes_beacon = read_bitarrays('../data/test/In_Pop.pkl')[:, :num_snps]
genomes_reference = read_bitarrays('../data/test/Not_In_Pop.pkl')[:, :num_snps]

# Row-wise stack: beacon genomes first, reference genomes after.
genomes = np.concatenate([genomes_beacon, genomes_reference], axis=0)

In [14]:
# One boolean label per genome row: True for beacon genomes, False for
# reference genomes. The order matches the row order of `genomes`
# (beacon rows first, then reference rows).
labels_beacon = np.full(genomes_beacon.shape[0], True, dtype=bool)
labels_reference = np.full(genomes_reference.shape[0], False, dtype=bool)
labels = np.concatenate([labels_beacon, labels_reference])

In [15]:
# Per-SNP flag: True where at least one beacon genome has a truthy entry.
presences_beacon = np.any(genomes_beacon, axis=0).astype(bool)

# Per-SNP mean over ALL stacked genomes (beacon + reference).
# NOTE(review): the name suggests reference-only frequencies, but the pooled
# `genomes` array is averaged here rather than `genomes_reference` - confirm
# that this is intentional.
frequencies_reference = genomes.mean(axis=0)

In [16]:
# Seed the global NumPy and PyTorch RNGs so that re-running the notebook from
# a fresh kernel reproduces the same random draws for everything that relies
# on those global generators from this point on.
# NOTE(review): assumes stratified_random_split draws from one of these two
# global RNGs - if it uses its own generator, pass the seed to it explicitly.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Bundle genomes, beacon presences, frequencies and labels into one dataset.
dataset = DatasetAttackerLSTMBeacon(
    target_genomes=genomes,
    beacon_presences=presences_beacon,
    reference_frequencies=frequencies_reference,
    labels=labels)

# Split into train / eval / test subsets using the fractions 0.7 / 0.15 / 0.15.
subset_train, subset_eval, subset_test = stratified_random_split(dataset, [0.7, 0.15, 0.15])

In [17]:
# Batch sizes shared by all three loaders.
genomes_batch_size = 32
snps_batch_size = 10000

# Only the training loader is built with shuffle=True; the evaluation and
# test loaders are built with shuffle=False.
loader_train = LSTMAttackerDataLoader(
    subset_train, genomes_batch_size, snps_batch_size, shuffle=True)
loader_eval = LSTMAttackerDataLoader(
    subset_eval, genomes_batch_size, snps_batch_size, shuffle=False)
loader_test = LSTMAttackerDataLoader(
    subset_test, genomes_batch_size, snps_batch_size, shuffle=False)

In [18]:
# Pick the compute device: Apple MPS first, then CUDA, then CPU.
# `torch.backends.mps` is looked up defensively because PyTorch builds that
# do not ship the MPS backend lack the attribute, and a direct access would
# raise AttributeError instead of falling through to CUDA / CPU.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [19]:
# Build the attacker model with the hyper-parameters listed below.
# NOTE(review): with num_layers=1, torch.nn.LSTM does not apply dropout
# (it only acts between stacked layers) - confirm how ModelAttackerLSTM
# uses the `dropout` argument.
model = ModelAttackerLSTM(
    input_size=3,
    hidden_size=64,
    num_layers=1,
    bidirectional=False,
    dropout=0.5,
)

# Kept as the final expression so the notebook echoes the returned value.
model.to(device)

LSTMAttacker(
  (lstm): LSTM(3, 64, batch_first=True)
  (linear): Linear(in_features=64, out_features=1, bias=True)
)

In [20]:
# BCEWithLogitsLoss applies the sigmoid internally.
# NOTE(review): this presumes the model emits raw logits - confirm.
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Learning-rate decay is currently disabled:
# scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

In [21]:
# Wire the model, loss, optimizer, train/eval loaders and device into the trainer.
trainer = LSTMAttackerTrainer(
    model,
    criterion,
    optimizer,
    loader_train,
    loader_eval,
    device,
)

In [22]:
# Run training for up to 256 epochs with verbose logging and keep the four
# returned values (train/eval losses and accuracies) for later inspection.
(
    losses_train,
    accuracies_train,
    losses_eval,
    accuracies_eval,
) = trainer.train(num_epochs=256, verbose=True)

Epoch 1/256
Train Loss: 0.6947, Train Accuracy: 0.50
Evaluation Loss: 0.6935, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: inf -> 0.6935. Saving Model...
Epoch 2/256
Train Loss: 0.6931, Train Accuracy: 0.50
Evaluation Loss: 0.6930, Evaluation Accuracy: 0.52
Evaluation Loss Decreased: 0.6935 -> 0.6930. Saving Model...
Epoch 3/256
Train Loss: 0.6930, Train Accuracy: 0.51
Evaluation Loss: 0.6929, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6930 -> 0.6929. Saving Model...
Epoch 4/256
Train Loss: 0.6931, Train Accuracy: 0.50
Evaluation Loss: 0.6929, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6929 -> 0.6929. Saving Model...
Epoch 5/256
Train Loss: 0.6931, Train Accuracy: 0.50
Evaluation Loss: 0.6928, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6929 -> 0.6928. Saving Model...
Epoch 6/256
Train Loss: 0.6931, Train Accuracy: 0.50
Evaluation Loss: 0.6928, Evaluation Accuracy: 0.50
Evaluation Loss Decreased: 0.6928 -> 0.6928. Saving Model...
Epoch 7/256
T

KeyboardInterrupt: 

In [None]:
# Save the model under ../models with a timestamped name (YYYYmmddHHMMSS).
model.save(
    "../models",
    f"beacon_lstm_attacker_{datetime.now().strftime('%Y%m%d%H%M%S')}",
)

# To restore a previously saved model instead, use e.g.:
# model.load("../models", "beacon_lstm_attacker_20210919123456")

In [None]:
# Tester built from the current model, the loss, the test loader and the device.
tester = LSTMAttackerTester(
    model,
    criterion,
    loader_test,
    device,
)

In [None]:
# Run the test pass and unpack all seven returned values; only the loss is
# printed here, the remaining names stay available to later cells.
(
    loss,
    accuracy,
    precision,
    recall,
    f1,
    auroc,
    cm,
) = tester.test()
print(f"Loss: {loss:.4f}")

In [None]:
# Mean of the test subset's `targets`.
# NOTE(review): presumably `targets` holds the 0/1 membership labels, which
# would make this the positive-class fraction of the test split rather than
# majority-class accuracy - confirm the intended baseline.
# NOTE(review): assumes the object returned by stratified_random_split
# exposes a `targets` attribute - verify against utils_torch.
chance = np.mean(subset_test.targets)
