In [None]:
import numpy as np
import torch
from models.rnn import CellType
from models.geo_route_lstm import GeoRouteLSTM
from torch.utils.data import Dataset, DataLoader
from dataset.geo_route import GeoRouteDataset, prepare_tensors

%load_ext autoreload
%autoreload 2

In [None]:
# Train on the GPU when torch reports at least one CUDA device, otherwise on the CPU
device = "cuda" if torch.cuda.device_count() else "cpu"
print("Device", device)

In [None]:
# ---- Data ------------------------------------------------------------------
# File that contains the training data
dataset_file = "dataset_training.pkl.gz"

# ---- Optimisation ----------------------------------------------------------
# Initial learning rate handed to the optimizer
lr = 1e-3
# Number of passes over the training data
n_epochs = 100
# Number of samples per training batch
batch_size = 8192

# ---- Architecture ----------------------------------------------------------
# NOTE(review): in the visible notebook these values are only written into the
# saved checkpoint; the network is built as GeoRouteLSTM(device=device) without
# them -- confirm they match the defaults used inside the model.
# Recurrent cell type (LSTM | GRU | RNN)
cell_type = CellType.LSTM
# Number of stacked RNN layers
num_layers = 3
# Dimension of the embeddings
embedding_dim = 32
# Hidden size of each RNN layer
hidden_size = 256
# Maximum sequence length
max_length = 39
# True for bidirectional RNN layers, False for unidirectional ones
bidirectional = True

In [None]:
# Load the dataset from disk and wrap it in a shuffling dataloader.
# drop_last=True discards a trailing batch that is smaller than batch_size.
dataset = GeoRouteDataset(dataset_file)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Shape of the array of distinct values found in dataset.dest_cc
np.shape(np.unique(dataset.dest_cc))

In [None]:
# Instantiate the network and move it to the selected device.
# NOTE(review): only `device` is passed to the constructor; the architecture
# settings defined in the configuration cell (cell_type, num_layers,
# embedding_dim, hidden_size, max_length, bidirectional) are not forwarded --
# confirm that GeoRouteLSTM's defaults agree with them.
net = GeoRouteLSTM(device=device).to(device)

# AdamW optimizer over all network parameters
net_optimizer = torch.optim.AdamW(net.parameters(), lr=lr)
# Multiply the learning rate by 0.95 on every scheduler step (the training loop
# steps the scheduler once per epoch). step_size is an integer count of steps,
# so it is passed as the int 1 rather than the float 1.0.
net_scheduler = torch.optim.lr_scheduler.StepLR(net_optimizer, step_size=1, gamma=0.95)

# Negative log-likelihood loss.
# NOTE(review): NLLLoss expects log-probabilities as input -- confirm that the
# network's output layer applies log_softmax (otherwise CrossEntropyLoss is the
# matching criterion for raw logits).
criterion = torch.nn.NLLLoss()

In [None]:
# Train the network for n_epochs passes over the dataloader. Every batch is
# scored with a class-balanced NLL loss, per-class accuracies are printed, and
# the learning rate is decayed once at the end of each epoch.
for epoch in range(n_epochs):
    # Get a batch of training data
    for src_as, dest_as, src_cc, dest_cc, lat, long, asn, ip_source, geo_cc, labels in dataloader:
        # Let prepare_tensors (from dataset.geo_route) process the batch for `device`
        src_as, dest_as, src_cc, dest_cc, lat, long, asn, ip_source, geo_cc, labels = prepare_tensors(
            src_as, dest_as, src_cc, dest_cc, lat, long, asn, ip_source, geo_cc, labels, device=device
        )

        # NLLLoss takes integer class indices as targets
        labels = labels.to(torch.long).to(device)

        # Boolean masks selecting the samples labelled 0 and 1
        mask_class_0 = labels.squeeze() == 0
        mask_class_1 = labels.squeeze() == 1

        # Number of samples of each class in this batch
        n_class_0 = mask_class_0.sum().item()
        n_class_1 = mask_class_1.sum().item()

        # A batch without any class-0 or class-1 sample carries no training
        # signal; skip it instead of back-propagating an undefined loss.
        if n_class_0 + n_class_1 == 0:
            continue

        # Set gradients of all model parameters to zero
        net_optimizer.zero_grad()

        # Get the network output for each of the two classes
        logits = net(
            lat=lat,
            long=long,
            asn=asn,
            ip_source=ip_source,
            geo_cc=geo_cc,
            src_as=src_as,
            dest_as=dest_as,
            src_cc=src_cc,
            dest_cc=dest_cc,
        )

        # Index of the most likely class for each input
        topi = logits.topk(1).indices

        # Class-balanced loss: each class contributes with weight 0.5 no matter
        # how many samples it has in the batch. A class that is absent from the
        # batch is left out, because the mean-reduced NLL loss over an empty
        # selection is NaN and would corrupt the weights on the optimizer step.
        loss = logits.new_zeros(())
        if n_class_0 > 0:
            loss = loss + 0.5 * criterion(logits[mask_class_0].squeeze(), labels[mask_class_0].squeeze())
        if n_class_1 > 0:
            loss = loss + 0.5 * criterion(logits[mask_class_1].squeeze(), labels[mask_class_1].squeeze())

        # Compute total accuracy and accuracies for both positive and negative samples.
        # The total is divided by the number of compared samples rather than the
        # configured batch_size, so it stays correct for any actual batch size.
        matchings = labels.squeeze() == topi.squeeze()
        accuracy_total = matchings.sum().item() / matchings.numel()
        accuracy_class_0 = matchings[mask_class_0].sum().item() / n_class_0 if n_class_0 > 0 else 0.0
        accuracy_class_1 = matchings[mask_class_1].sum().item() / n_class_1 if n_class_1 > 0 else 0.0

        # Report the metrics of this batch
        print(f"LOSS after epoch {epoch}", loss.item() / (labels.size(1)), "AccAll", round(accuracy_total, 3), "Acc0", round(accuracy_class_0, 3), "Acc1", round(accuracy_class_1, 3))

        # Compute gradient
        loss.backward()

        # Update weights of network
        net_optimizer.step()

    # Adjust the learning rate once per epoch
    net_scheduler.step()

In [None]:
import pickle
from datetime import datetime

# Base name of the checkpoint: cell type, layer count and a timestamp.
# The ".pt" extension is appended only when the file is written.
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
filename = f"{cell_type}_{num_layers}layers_{timestamp}"

# Training state (last epoch, weights, optimizer state, last loss) together
# with the hyperparameters from the configuration cell
checkpoint = {
    'epoch': epoch,
    'net_state_dict': net.state_dict(),
    'net_optimizer_state_dict': net_optimizer.state_dict(),
    'loss': loss,
    "lr": lr,
    "cell_type": cell_type,
    "embedding_dim": embedding_dim,
    "hidden_size": hidden_size,
    "batch_size": batch_size,
    "max_length": max_length,
    "num_layers": num_layers,
    "bidirectional": bidirectional,
}
torch.save(checkpoint, filename + ".pt")

print(str(datetime.now()), "Saved model: " + filename)