In [1]:
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

from utils_io import read_bitarrays
from utils_classification import BeaconAttackerDataset, LSTMAttackerDataLoader, LSTMAttacker, LSTMAttackerTrainer, LSTMAttackerTester, stratified_random_split

ImportError: cannot import name 'BeaconAttacker' from 'utils_classification' (D:\Education\NEU\CS7990 - Thesis\Projects\Genomic Privacy\utils_classification\__init__.py)

In [None]:
# Load both genome matrices and keep only the first `num_snps` columns of each.
# NOTE(review): rows are presumably individuals and columns SNPs (suggested by the
# `num_snps` slice) — confirm against `read_bitarrays`.
num_snps = 40000
snp_window = slice(None, num_snps)  # equivalent to `:num_snps`

genomes_beacon = read_bitarrays('../data/test/In_Pop.pkl')[:, snp_window]
genomes_reference = read_bitarrays('../data/test/Not_In_Pop.pkl')[:, snp_window]

# Beacon rows come first, reference rows second; the labels built in the next cell
# are concatenated in the same order.
genomes = np.concatenate([genomes_beacon, genomes_reference], axis=0)

In [None]:
# Membership labels aligned row-for-row with `genomes`:
# True for every beacon genome, False for every reference genome.
labels_beacon = np.full(genomes_beacon.shape[0], True, dtype=bool)
labels_reference = np.full(genomes_reference.shape[0], False, dtype=bool)
# Both halves are already boolean, so the concatenation is boolean without a further cast.
labels = np.concatenate([labels_beacon, labels_reference], axis=0)

In [None]:
# Per-column beacon response: True wherever at least one beacon genome has a
# truthy entry in that column.
presences_beacon = np.any(genomes_beacon, axis=0).astype(bool)

# Per-column mean used as the reference frequency.
# NOTE(review): despite the `_reference` suffix, the mean is taken over the combined
# `genomes` array (beacon + reference rows), not over `genomes_reference` alone —
# confirm this is intended.
frequencies_reference = genomes.mean(axis=0)

In [None]:
# Seed the global NumPy and PyTorch RNGs before any stochastic step (random split,
# shuffled loading, weight initialisation) so that Restart & Run All reproduces the
# same partition and training run.
# NOTE(review): assumes `stratified_random_split` draws from one of these global
# RNGs — if it accepts an explicit seed/generator, pass `SEED` to it as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Wrap the genomes, the per-column beacon responses, the per-column frequencies and
# the membership labels into a single dataset object.
dataset = BeaconAttackerDataset(
    target_genomes=genomes,
    beacon_presences=presences_beacon,
    reference_frequencies=frequencies_reference,
    labels=labels,
)

# Split into train / eval / test subsets with fractions 0.7 / 0.15 / 0.15.
subset_train, subset_eval, subset_test = stratified_random_split(dataset, [0.7, 0.15, 0.15])

In [None]:
# Batch sizes handed positionally to every loader.
# NOTE(review): presumably the number of genomes per batch and the number of SNPs per
# sequence chunk — confirm against `LSTMAttackerDataLoader`.
genomes_batch_size = 32
snps_batch_size = 10000

# Only the training loader shuffles; the eval and test loaders are built with shuffle=False.
loader_train = LSTMAttackerDataLoader(
    subset_train, genomes_batch_size, snps_batch_size, shuffle=True
)
loader_eval = LSTMAttackerDataLoader(
    subset_eval, genomes_batch_size, snps_batch_size, shuffle=False
)
loader_test = LSTMAttackerDataLoader(
    subset_test, genomes_batch_size, snps_batch_size, shuffle=False
)

In [None]:
# Pick the compute device, preferring Apple MPS, then CUDA, then CPU.
# `torch.backends.mps` does not exist on older PyTorch builds, so look it up
# defensively instead of letting the attribute access raise AttributeError.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Instantiate the attacker network with the hyper-parameters below.
# NOTE(review): if `dropout` is forwarded to `torch.nn.LSTM`, it has no effect with
# `num_layers=1` (PyTorch applies LSTM dropout only between stacked layers) —
# confirm how `LSTMAttacker` uses it.
model = LSTMAttacker(
    input_size=3,
    hidden_size=64,
    num_layers=1,
    bidirectional=False,
    dropout=0.5,
)
# Hand the model to the selected device; as the last expression of the cell, the
# return value of `.to` is echoed in the notebook output.
model.to(device)

In [None]:
# Loss, optimizer and learning-rate schedule used for training.
# BCEWithLogitsLoss combines a sigmoid with binary cross-entropy, so it is meant to
# receive raw logits rather than probabilities.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
# NOTE(review): if the trainer steps the scheduler once per epoch, the rate shrinks by
# a factor of 0.9**256 (about 2e-12) over a 256-epoch run — confirm this decay is intended.
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=0.9)

In [None]:
# Bundle the model, loss, optimizer, scheduler, train/eval loaders and device
# into a trainer object (all arguments are passed positionally).
trainer = LSTMAttackerTrainer(
    model,
    criterion,
    optimizer,
    scheduler,
    loader_train,
    loader_eval,
    device,
)

In [None]:
# Run training for 256 epochs with verbose output; `train` returns four values,
# unpacked here as the train/eval losses and accuracies.
# NOTE(review): presumably one entry per epoch — confirm against `LSTMAttackerTrainer.train`.
(
    losses_train,
    accuracies_train,
    losses_eval,
    accuracies_eval,
) = trainer.train(num_epochs=256, verbose=True)

In [None]:
# Save the model under a name that carries the current local time (YYYYmmddHHMMSS).
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
model.save("../models", f"beacon_lstm_attacker_{timestamp}")
# To restore a previously saved model instead, call for example:
# model.load("../models", "beacon_lstm_attacker_20210919123456")

In [None]:
# Build the tester from the model, the same loss criterion, the test loader and
# the device (all arguments are passed positionally).
tester = LSTMAttackerTester(
    model,
    criterion,
    loader_test,
    device,
)

In [None]:
# Run the tester; `test` returns seven values, unpacked below.
# Only the loss is printed in this cell; the remaining metrics stay available in
# their variables for later cells.
(
    loss,
    accuracy,
    precision,
    recall,
    f1,
    auroc,
    cm,
) = tester.test()
print(f"Loss: {loss:.4f}")

In [None]:
# Mean of the test subset's targets, stored as the chance-level reference.
# NOTE(review): for 0/1 targets this is the positive-class fraction p; the accuracy of
# always predicting the majority class would be max(p, 1 - p) — confirm which baseline
# is intended.
chance = np.mean(subset_test.targets)
