In [106]:
import copy
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

from utils_io import read_bitarrays
from utils_plot import plot_receiver_operating_characteristics_curve
from utils_classification import PoolAttackerDataset, LSTMAttacker, LSTMAttackerTrainer

In [None]:
num_snps = 40000

In [107]:
genomes_pool = read_bitarrays("../data/test/In_Pop.pkl")[:, :num_snps]
genomes_reference = read_bitarrays("../data/test/Not_In_Pop.pkl")[:, :num_snps]
genomes = np.concatenate((genomes_pool, genomes_reference), axis=0)

In [109]:
labels_pool = np.ones(genomes_pool.shape[0], dtype=bool)
labels_reference = np.zeros(genomes_reference.shape[0], dtype=bool)
labels = np.concatenate((labels_pool, labels_reference), axis=0).astype(bool)

In [None]:
frequencies_pool = np.mean(genomes_pool, axis=0)
frequencies_reference = np.mean(genomes_reference, axis=0)

In [114]:
dataset = PoolAttackerDataset(
    target_genomes=genomes,
    pool_frequencies=frequencies_pool,
    reference_frequencies=frequencies_reference,
    labels=labels)

In [117]:
num_genomes = len(dataset)
num_genomes_train = int(0.7 * num_genomes)
num_genomes_eval = int(0.15 * num_genomes)
num_genomes_test = num_genomes - num_genomes_train - num_genomes_eval

In [None]:
dataset_train, dataset_eval, dataset_test = random_split(
    dataset,
    [num_genomes_train, num_genomes_eval, num_genomes_test])

In [118]:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [119]:
model = LSTMAttacker(input_size=2, hidden_size=64, num_layers=1, bidirectional=False, dropout=0.5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Attacker(
  (lstm): LSTM(2, 64, batch_first=True)
  (linear): Linear(in_features=64, out_features=1, bias=True)
)

In [120]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [121]:
num_epochs = 256
best_epoch = -1
best_val_loss = np.inf
best_state_dict = None

for epoch in range(num_epochs):
    model.train()
    train_losses = []
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        output = model(x_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    model.eval()
    val_losses = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            output = model(x_batch)
            loss = criterion(output, y_batch)
            val_losses.append(loss.item())
    train_loss = np.mean(train_losses)
    val_loss = np.mean(val_losses)
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = val_loss
        best_state_dict = copy.deepcopy(model.state_dict())
model.load_state_dict(best_state_dict)
print(f"Best Model found at Epoch {best_epoch + 1}")

Epoch 1/256, Train Loss: 0.6939, Val Loss: 0.6932
Epoch 2/256, Train Loss: 0.6933, Val Loss: 0.6932
Epoch 3/256, Train Loss: 0.6933, Val Loss: 0.6931
Epoch 4/256, Train Loss: 0.6934, Val Loss: 0.6932
Epoch 5/256, Train Loss: 0.6939, Val Loss: 0.6931
Epoch 6/256, Train Loss: 0.6932, Val Loss: 0.6931
Epoch 7/256, Train Loss: 0.6937, Val Loss: 0.6931
Epoch 8/256, Train Loss: 0.6935, Val Loss: 0.6931
Epoch 9/256, Train Loss: 0.6938, Val Loss: 0.6931
Epoch 10/256, Train Loss: 0.6933, Val Loss: 0.6930
Epoch 11/256, Train Loss: 0.6936, Val Loss: 0.6931
Epoch 12/256, Train Loss: 0.6932, Val Loss: 0.6930
Epoch 13/256, Train Loss: 0.6932, Val Loss: 0.6930
Epoch 14/256, Train Loss: 0.6932, Val Loss: 0.6929
Epoch 15/256, Train Loss: 0.6932, Val Loss: 0.6929
Epoch 16/256, Train Loss: 0.6933, Val Loss: 0.6929
Epoch 17/256, Train Loss: 0.6933, Val Loss: 0.6929
Epoch 18/256, Train Loss: 0.6932, Val Loss: 0.6928
Epoch 19/256, Train Loss: 0.6934, Val Loss: 0.6928
Epoch 20/256, Train Loss: 0.6931, Val Lo

In [122]:
model.save("../models", f"pool_lstm_attacker_{datetime.now().strftime('%Y%m%d%H%M%S')}")
# model.load("../models", "attacker_20240916085039")

In [123]:
model.eval()
test_losses = []
correct = 0
total = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        output = model(x_batch)
        loss = criterion(output, y_batch)
        test_losses.append(loss.item())
        predictions = (torch.sigmoid(output) >= 0.5).float()
        correct += (predictions == y_batch).sum().item()
        total += len(y_batch)
test_loss = np.mean(test_losses)
test_accuracy = correct / total

In [124]:
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

Test Loss: 0.6833, Test Accuracy: 0.5813


In [125]:
chance = y_test.mean()
print(f"Chance: {chance}")

Chance: 0.5
