# Setup

In [None]:
import os
# Expose only GPUs 0 and 1 to this process. The variable is assigned before
# torch is imported so it is already in place when torch first looks for devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import re
import time
import pickle

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [None]:
# Seed the NumPy and PyTorch random number generators with the same fixed
# value so that runs are repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Load every pickled sample of each split into memory.
dataset = {'train': [], 'validation': [], 'test': []}
dataset_path = '/data/graphner_embeddings/ae_emb_npy_2000_15epochs/'

for split in dataset:
    split_dir = os.path.join(dataset_path, split)
    files_list = os.listdir(split_dir)
    # Samples are addressed by index ("0.pickle", "1.pickle", ...) instead of
    # by directory entry: the listing is only used to know how many there are,
    # which keeps the samples in numeric rather than lexicographic order.
    for i in tqdm(range(len(files_list)), total=len(files_list)):
        sample_path = os.path.join(split_dir, str(i) + '.pickle')
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # dataset_path at trusted files.
        with open(sample_path, 'rb') as sample_file:
            dataset[split].append(pickle.load(sample_file))

In [None]:
# Load the label list and derive the label -> integer id mapping from the
# position of each label in that list.
# NOTE(security): pickle.load can execute arbitrary code; labels.pickle must
# come from a trusted source.
with open('labels.pickle', 'rb') as labels_file:
    labels = pickle.load(labels_file)
label2id = {label: idx for idx, label in enumerate(labels)}
print(label2id)

# Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """In-memory map-style dataset over one split of pre-computed samples.

    ``labels`` and ``Y`` are plain instance attributes (read them as
    ``ds.labels`` / ``ds.Y``). Accessor methods with the same names would be
    shadowed by these attributes and could never be called, so none exist.

    Attributes:
        X: list of feature tensors, one per sample.
        Y: list of label tensors, one per sample.
        X_len: number of samples in the split.
        labels: sorted keys of ``label2id``.
        label2id: the label -> id mapping passed to the constructor.
    """

    def __init__(self, dataset, split, label2id=label2id):
        """Convert every ``(x, y)`` pair of ``dataset[split]`` into tensors.

        Args:
            dataset: mapping from split name to a sequence of ``(x, y)`` pairs.
            split: key of the split to wrap.
            label2id: label -> id mapping. The default is the module-level
                ``label2id`` object as it existed when this class was defined.
        """
        samples = dataset[split]
        # zip(*[]) produces nothing to unpack, so an empty split is handled
        # explicitly instead of raising ValueError.
        X, Y = zip(*samples) if samples else ((), ())

        self.X = [torch.tensor(x) for x in X]
        self.Y = [torch.tensor(y) for y in Y]
        self.X_len = len(self.X)
        self.labels = sorted(label2id)
        self.label2id = label2id

    def __len__(self):
        """Return the number of samples in the split."""
        return self.X_len

    def __getitem__(self, index):
        """Return the ``(features, label)`` tensor pair stored at ``index``."""
        # torch.tensor() creates tensors with requires_grad=False, so there is
        # nothing to reset before handing the features out.
        return self.X[index], self.Y[index]

In [None]:
# Wrap each split of the loaded data in a Dataset (train, validation, test).
train_set, dev_set, test_set = (
    Dataset(dataset, split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
batch_size = 64
num_workers = 4

# Options shared by all three loaders; only the training loader shuffles.
loader_options = {'batch_size': batch_size, 'num_workers': num_workers}

train_loader = DataLoader(train_set, shuffle=True, **loader_options)
dev_loader = DataLoader(dev_set, shuffle=False, **loader_options)
test_loader = DataLoader(test_set, shuffle=False, **loader_options)

In [None]:
# Peek at the first training batch to read the feature dimension (axis 1 of
# the feature tensor). input_dim stays 0 when the loader yields nothing.
input_dim = 0
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    local_features, local_labels = first_batch
    input_dim = local_features.shape[1]
    print(local_features.shape)
    print(local_labels.shape)

In [None]:
# Notebook display: the feature dimension read from the first training batch.
input_dim

In [None]:
# Count how many training samples carry each label id.
training_counter = Counter(label_tensor.item() for label_tensor in train_set.Y)
print(training_counter)

In [None]:
# Notebook display: the label list loaded from labels.pickle.
labels

# The Model

In [None]:
def backprop(batch_X, batch_Y, model, optimizer, loss_fn):
    """Run one optimisation step on a single batch and return its loss.

    Gradients are not zeroed here, so any gradient already stored on the
    parameters is accumulated into this step.

    Args:
        batch_X: inputs passed to ``model``.
        batch_Y: targets passed to ``loss_fn`` together with the model output.
        model: callable producing predictions from ``batch_X``.
        optimizer: object whose ``step()`` applies the update.
        loss_fn: callable ``(predictions, targets) -> loss tensor``.

    Returns:
        The loss of this batch as a Python number (``loss.item()``).
    """
    batch_loss = loss_fn(model(batch_X), batch_Y)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

class FeedForwardNetwork(nn.Module):
    """Feed-forward classifier with two hidden layers.

    Layout: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    ``forward`` returns the raw output of the last linear layer (logits);
    no softmax is applied inside the module.
    """

    def __init__(self, input_dim=input_dim, hidden_dim=512, output_dim=5, dropout_rate=0.2):
        """Build the layers.

        Args:
            input_dim: size of each input feature vector; the default is the
                module-level ``input_dim`` at class-definition time.
            hidden_dim: width of both hidden layers.
            output_dim: number of output logits.
            dropout_rate: drop probability used after each hidden activation.
        """
        super().__init__()
        # The creation order of the Linear layers determines how the seeded
        # RNG is consumed for initialisation, so it is kept fixed.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fch = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        # Regularisation, shared by both hidden layers.
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Return the logits for the batch ``x``."""
        first_hidden = self.dropout(self.relu(self.fc1(x)))
        second_hidden = self.dropout(self.relu(self.fch(first_hidden)))
        return self.fc2(second_hidden)

In [None]:
# Build the classifier and move it to the selected device. The output size is
# tied to the number of labels rather than left at the constructor default,
# because the class-weight vector given to the loss has one entry per label
# and the loss requires both sizes to match.
ffnet = FeedForwardNetwork(
    hidden_dim=1024,
    output_dim=len(labels),
    dropout_rate=0.2,
).to(device)

In [None]:
# Per-epoch bookkeeping: mean training loss and validation results.
logs = {'loss/train': {}, 'dev': {}}

# Suffix appended to the auto-generated TensorBoard run directory name.
# NOTE(review): the tag says "lr1e3" while the optimizer cell uses lr=5e-3 —
# confirm which value is intended before comparing runs by name.
run_tag = 'xp5-autoreg-wei2-lr1e3-mom0.9-wd5e4-hd1024-dr0.2-bs64-dim2000-15'
writer = SummaryWriter(log_dir=None, comment=run_tag)

In [None]:
# Class weights for the loss, derived from label frequencies in the training set.
label_counter = Counter(y.item() for y in train_set.Y)

# Loop-invariant quantities, computed once instead of once per label.
n_labels = len(labels)
total_count = sum(label_counter.values())
min_count = min(label_counter.values())
sqrt_min_count = np.sqrt(min_count)

# Every label id needs at least one training sample, otherwise the ratios
# below divide by zero; fail early with an explicit message instead.
missing_labels = [label for label in range(n_labels) if label_counter[label] == 0]
if missing_labels:
    raise ValueError(f'label ids without training samples: {missing_labels}')

# Relative frequency of each label.
labels_freqs = [label_counter[label] / total_count for label in range(n_labels)]
# Inverse frequency, scaled so the rarest label gets weight 1.
labels_weights1 = [min_count / label_counter[label] for label in range(n_labels)]
# Square-root-damped variant of the inverse-frequency weights.
labels_weights2 = [sqrt_min_count / np.sqrt(label_counter[label]) for label in range(n_labels)]

weights = torch.tensor(labels_weights2, dtype=torch.float32).to(device)
print(weights)

In [None]:
# Hyper-parameters forwarded as keyword arguments to torch.optim.SGD.
optimizer_params = {
    'lr': 5e-3,
    'momentum': 0.9,
    'weight_decay': 5e-4,
}

# Print training progress roughly twice per epoch. Clamped to at least 1:
# with a single-batch loader the halved length is 0 and the
# `batch % log_interval` check in the training loop would divide by zero.
log_interval = max(1, len(train_loader) // 2)

# Cross-entropy with per-class weights, optimised with SGD.
loss_fn = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(ffnet.parameters(), **optimizer_params)

In [None]:
%%time
max_epochs = 1

# Epoch numbering continues from the number of epochs already stored in
# `logs`, so re-running this cell resumes training instead of overwriting.
first_epoch = len(logs['loss/train'])

for epoch in range(first_epoch, first_epoch + max_epochs):

    # Training
    ffnet.train()
    print('Epoch', epoch)
    logs['loss/train'][epoch] = []
    writer.add_scalar("Learning_rate", optimizer_params['lr'], epoch)

    for batch, (batch_X, batch_Y) in enumerate(tqdm(train_loader)):
        # Move the batch to the target device
        batch_X, batch_Y = batch_X.float().to(device), batch_Y.to(device)
        # backprop() does not clear gradients, so do it before each step
        optimizer.zero_grad()
        batch_loss = backprop(batch_X, batch_Y, ffnet, optimizer, loss_fn)
        logs['loss/train'][epoch].append(batch_loss)

        if batch % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch * len(batch_X), len(train_loader.dataset),
                100. * batch / len(train_loader), batch_loss))

    # Replace the per-batch loss list with its mean for this epoch
    logs['loss/train'][epoch] = np.mean(logs['loss/train'][epoch])
    writer.add_scalar("Loss/train", logs['loss/train'][epoch], epoch)
    print(f'Average loss on epoch {epoch}: {logs["loss/train"][epoch]}')

    # Validation
    ffnet.eval()
    with torch.no_grad():
        preds = []
        gt = []
        for batch, (batch_X, batch_Y) in enumerate(tqdm(dev_loader)):
            # Move the features to the target device; labels stay on the CPU
            batch_X = batch_X.float().to(device)
            output = nn.Softmax(dim=1)(ffnet(batch_X))
            preds.append(output.cpu())
            gt.append(batch_Y)

        # Flatten the per-batch results into one prediction / label per sample
        all_out = [np.argmax(probs) for batch_probs in preds for probs in batch_probs.numpy()]
        all_gt = [label for batch_labels in gt for label in batch_labels.numpy()]

        # scikit-learn expects (y_true, y_pred); passing them the other way
        # round swaps precision with recall and reports the wrong support.
        print(classification_report(all_gt, all_out, digits=4))

        micro_F1 = metrics.f1_score(all_gt, all_out, average='micro')
        macro_F1 = metrics.f1_score(all_gt, all_out, average='macro')
        weighted_F1 = metrics.f1_score(all_gt, all_out, average='weighted')
        writer.add_scalar("micro_F1/dev", micro_F1, epoch)
        writer.add_scalar("macro_F1/dev", macro_F1, epoch)
        writer.add_scalar("weighted_F1/dev", weighted_F1, epoch)
        # Stored order is (micro, weighted, macro, (ground truth, predictions))
        logs['dev'][epoch] = (micro_F1, weighted_F1, macro_F1, (all_gt, all_out))

In [None]:
# Final evaluation on the held-out test split.
ffnet.eval()
with torch.no_grad():
    preds = []
    gt = []
    for batch, (batch_X, batch_Y) in enumerate(tqdm(test_loader)):
        # Move the features to the target device; labels stay on the CPU
        batch_X = batch_X.float().to(device)
        output = nn.Softmax(dim=1)(ffnet(batch_X))
        preds.append(output.cpu())
        gt.append(batch_Y)

    # Flatten the per-batch results into one prediction / label per sample
    all_out = [np.argmax(probs) for batch_probs in preds for probs in batch_probs.numpy()]
    all_gt = [label for batch_labels in gt for label in batch_labels.numpy()]

    # scikit-learn expects (y_true, y_pred); passing them the other way round
    # swaps precision with recall and reports the wrong support.
    print(classification_report(all_gt, all_out, digits=4))