# Libraries

In [None]:
import os
import sys
import copy
import time
import math
import pickle
import random
import pathlib
import sqlite3
import tempfile
import importlib
import subprocess

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm
from sru import SRU, SRUCell
from opacus import PrivacyEngine

from configparser import ConfigParser
from distutils.spawn import find_executable
from sklearn.metrics import confusion_matrix
from collections import OrderedDict, namedtuple


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.nn.functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so link sampling, DataLoader
# shuffling and weight initialisation are repeatable.
# (torch.manual_seed also seeds every CUDA device.)
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): GPU index 2 is hard-coded for the original machine.
torch.cuda.set_device(2)
torch.cuda.manual_seed(0)
torch.cuda.empty_cache()

# cuDNN autotuning favours speed. GPU results may therefore still differ
# slightly between runs despite the seeds above.
torch.backends.cudnn.benchmark = True

# Models

In [None]:
class MLP(nn.Module):
    """Fully connected classifier mapping 3 input features to 2 class logits.

    Four 50-wide hidden layers with SELU activations; AlphaDropout (p=0.5)
    is applied once, just before the final projection.
    """

    def __init__(self):
        super().__init__()
        # Registration order matches the original definition, so parameter
        # initialisation under a fixed seed is unchanged.
        self.linear1 = nn.Linear(3, 50)
        self.selu = nn.SELU()
        self.drop = nn.AlphaDropout(p=0.5)
        self.relu = nn.ReLU()  # registered but not used by forward()
        self.linear2 = nn.Linear(50, 50)
        self.linear3 = nn.Linear(50, 50)
        self.linear4 = nn.Linear(50, 50)
        self.linear5 = nn.Linear(50, 2)

    def forward(self, input):
        hidden = input
        for hidden_layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = self.selu(hidden_layer(hidden))
        return self.linear5(self.drop(hidden))

In [5]:
class CBSDNN(nn.Module):
    """Conv1d + bidirectional SRU classifier producing 2 class logits.

    forward() flattens the conv output to 16*3 features per sample.
    It therefore expects input of shape (N, 1, 3): one channel of length 3.
    """

    def __init__(self):
        super().__init__()
        # Registration order is unchanged (keeps seeded initialisation identical).
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        self.linear1 = nn.Linear(16*3, 100)
        # `bidirectional` is a boolean flag. The former string 'True' only
        # worked because any non-empty string (even 'False') is truthy.
        self.sru = SRU(input_size=100, hidden_size=100, num_layers=2, bidirectional=True)
        # Both SRU directions are concatenated, hence 100*2 input features.
        self.linear2 = nn.Linear(100*2, 100*2)
        self.linear3 = nn.Linear(100*2, 2)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(0.4)
        self.dropout2 = nn.Dropout(0.4)
        # Not used by forward(), which returns raw logits; kept so the
        # module layout is unchanged.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        output = self.relu(self.conv1(input))
        output = self.relu(self.conv2(output))
        output = self.dropout1(output)

        # Flatten 16 channels x length 3 into one feature vector per sample.
        features = self.linear1(output.view(-1, 16*3))
        n_samples = len(features)

        # NOTE(review): SRU consumes (length, batch, features). This reshape
        # feeds the whole mini-batch as ONE sequence of length n_samples with
        # batch size 1 — confirm that is intended.
        output, _ = self.sru(features.view(n_samples, 1, -1))
        output = output.view(n_samples, -1)

        output = self.relu(self.linear2(output))
        output = self.dropout2(output)
        return self.linear3(output)

# Select Scrimmage

In [None]:
scr = 4  # Scrimmage 4 or 5
if scr not in (4, 5):
    raise ValueError(f'scr must be 4 or 5, got {scr!r}')
# Several later cells refer to the capitalised name `Scr`. Bind both
# spellings to the same value so those cells do not raise NameError.
Scr = scr

# Load Data

In [None]:
# Load the per-link dataset of the selected scrimmage. The following cells
# index each entry as [0] (features) and [1] (labels).
# pickle.load can execute arbitrary code: only load trusted files.
dataset_path = {
    4: 'scrimmage4_edge_node_dataset.pickle',
    5: 'scrimmage5_edge_node_dataset.pickle',
}.get(scr)
if dataset_path is not None:
    with open(dataset_path, 'rb') as file:
        link_data = pickle.load(file)

In [None]:
# Keep only feature columns 0, 1 and 3 of every link; labels pass through.
cols = torch.LongTensor([0, 1, 3])
link_data = [(features[:, cols], labels) for features, labels, *_ in link_data]

In [None]:
# Per-scrimmage (mean, std) used to map the normalised feature in column 0
# back to SNR: value = normalised * std + mean.
scr4_mean, scr4_std = 31.827541, 7.5468507
scr5_mean, scr5_std = 33.17964, 6.6672482

In [None]:
# Fractions of each link's samples for train, validation and test. The split
# takes contiguous slices in stored order: [0, 40%), [40%, 50%), [50%, 100%).
datadist_edge = [0.4, 0.1, 0.50]

test_x_npn, test_y_npn = [], []
edge_train_x_npn, edge_train_y_npn = [], []
edge_val_x_npn, edge_val_y_npn = [], []

for link in link_data:
    features, labels = link[0], link[1]
    datalen = len(labels)
    edge_trainlen = int(datalen*sum(datadist_edge[:1]))
    edge_vallen = int(datalen*sum(datadist_edge[:2]))

    partitions = (
        (edge_train_x_npn, edge_train_y_npn, slice(0, edge_trainlen)),
        (edge_val_x_npn, edge_val_y_npn, slice(edge_trainlen, edge_vallen)),
        (test_x_npn, test_y_npn, slice(edge_vallen, datalen)),
    )
    for x_store, y_store, span in partitions:
        x_store.append(features[span].numpy())
        y_store.append(labels[span].numpy())

In [None]:
# Prepare Tensors for Batch Processing
# Per-link tensors for edge training, validation and testing:
# float32 features and int64 labels.
edge_train_x, edge_val_x, test_x = (
    [torch.from_numpy(arr).type(torch.float) for arr in split]
    for split in (edge_train_x_npn, edge_val_x_npn, test_x_npn)
)
edge_train_y, edge_val_y, test_y = (
    [torch.from_numpy(arr).type(torch.long) for arr in split]
    for split in (edge_train_y_npn, edge_val_y_npn, test_y_npn)
)

# 1. Local Training Only

In [None]:
# Hyper-parameters and initial model/optimizer for local-only training.
NUM_EPOCHS = 100
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 64

# Best-epoch bookkeeping.
best_val_accuracy, best_val_loss = 0, 100
number_epoch_until_best = 1
training_time = training_time_until_best = average_time_per_epoch = 0

model = MLP().cuda()
loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Number of epochs in edge training
NUM_EPOCHS_EDGE = 100
prd = []
trg = []
snr = []

for i in range(len(link_data)):
    # "Local training only": every link gets a freshly initialised model and
    # optimizer. Reusing one model for all links made each link start from
    # the previous link's weights and momentum.
    model = MLP().cuda()
    loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Train
    edge_train_dataloader = data.DataLoader(data.TensorDataset(edge_train_x[i], edge_train_y[i]),
                                            batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                            num_workers=16, pin_memory=True)

    edge_val_dataloader = data.DataLoader(data.TensorDataset(edge_val_x[i], edge_val_y[i]),
                                          batch_size=VAL_BATCH_SIZE, shuffle=False,
                                          num_workers=16, pin_memory=True)

    # -1 guarantees the first epoch writes 'mlp_local.pt', so the reload
    # below never picks up a stale checkpoint from an earlier link. The
    # early-stopping test reads both best-epoch counters, so initialise them.
    best_val_accuracy = -1.0
    best_val_loss = math.inf
    number_epoch_until_best_accuracy = 0
    number_epoch_until_best_loss = 0

    for epoch_idx in range(NUM_EPOCHS_EDGE):

        progress_edge_training_epoch = tqdm(
            edge_train_dataloader,
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Training',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_train_dataloader), smoothing=.9)

        train_loss = 0.0
        train_size = 0
        model.train()
        for sentence, tags in progress_edge_training_epoch:
            sentence = sentence.cuda()
            tags = tags.cuda()
            optimizer.zero_grad()
            tag_scores = model(sentence)
            loss = loss_function(tag_scores, tags)
            loss.backward()
            optimizer.step()
            # Accumulate a Python float. Summing the loss tensor itself would
            # keep every batch's autograd history alive for the whole epoch.
            train_loss += loss.item() * tags.size(0)
            train_size += tags.size(0)

        progress_edge_validation_epoch = tqdm(
            edge_val_dataloader,
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Validation',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_val_dataloader), smoothing=.9)

        val_loss_sum = 0.0
        val_size = 0
        val_num_correct = 0
        model.eval()
        with torch.no_grad():
            for sentence, tags in progress_edge_validation_epoch:
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                loss = loss_function(tag_scores, tags)
                val_loss_sum += loss.item() * tags.size(0)
                val_size += tags.size(0)
                val_num_correct += torch.eq(tag_scores.argmax(dim=1), tags).sum().item()
        val_accuracy = val_num_correct / val_size
        val_loss = val_loss_sum / val_size

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            number_epoch_until_best_accuracy = epoch_idx
            torch.save(model.state_dict(), 'mlp_local.pt')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            number_epoch_until_best_loss = epoch_idx

        print(f'epoch:{epoch_idx}, '
              f'training loss:{train_loss/train_size: .5f}, '
              f'validation loss:{val_loss: .5f}, '
              f'accuracy: {val_accuracy: .4f}, '
              f'best accuracy: {best_val_accuracy: .4f}')

        # Early stop once neither accuracy nor loss improved for >4 epochs.
        if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
            break

    print(f'number of epochs: {number_epoch_until_best_accuracy}')
    # Restore the checkpoint with the best validation accuracy for testing.
    model.load_state_dict(torch.load('mlp_local.pt'))
    model.cuda()

    # Test
    test_dataloader = data.DataLoader(data.TensorDataset(test_x[i], test_y[i]),
                                      batch_size=VAL_BATCH_SIZE, shuffle=False,
                                      num_workers=16, pin_memory=True)

    progress_test_epoch = tqdm(
        test_dataloader,
        desc=f'Link {i}, Test',
        miniters=1, ncols=88, position=0,
        leave=True, total=len(test_dataloader), smoothing=.9)

    predict = []
    target = []
    snrss = []
    model.eval()
    with torch.no_grad():
        for sentence, tags in progress_test_epoch:
            sentence = sentence.cuda()
            tags = tags.cuda()
            tag_scores = model(sentence)
            predict.append(tag_scores.argmax(dim=1).cpu().numpy())
            target.append(tags.cpu().numpy())
            # Column 0 is the normalised SNR feature; kept for per-SNR analysis.
            snrss.append(sentence.cpu().numpy()[:, 0])

    prd.append(np.concatenate(predict, axis=0))
    trg.append(np.concatenate(target, axis=0))
    snr.append(np.concatenate(snrss, axis=0))

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 4:
    mlp_scr4_edge_res = {
        'mlp_scr4_prd': prd,
        'mlp_scr4_trg': trg,
        # Undo the feature normalisation so saved SNRs are in original units.
        'mlp_scr4_snr': [s * scr4_std + scr4_mean for s in snr],
    }

    # Name of pickle file to be saved
    outname = 'mlp_scr4_nodes_local_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)
    # Keep the previous result as '<name>.old' instead of overwriting it.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr4_edge_res, file)

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 5:
    mlp_scr5_edge_res = {
        'mlp_scr5_prd': prd,
        'mlp_scr5_trg': trg,
        # Undo the feature normalisation so saved SNRs are in original units.
        'mlp_scr5_snr': [s * scr5_std + scr5_mean for s in snr],
    }

    # Name of pickle file to be saved
    outname = 'mlp_scr5_nodes_local_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)
    # Keep the previous result as '<name>.old' instead of overwriting it.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr5_edge_res, file)

# 2. FedAvg

In [None]:
def average_weights(w):
    """Return the element-wise mean of a list of state dicts (FedAvg).

    Args:
        w: non-empty sequence of state dicts with identical keys and shapes.

    Returns:
        A new dict; the inputs are not modified. Each tensor is
        ``sum(w[i][key]) / len(w)``. Each averaged tensor keeps its original
        dtype (integer buffers are truncated), so the result loads like the
        inputs.

    Raises:
        ValueError: if ``w`` is empty.
    """
    if not w:
        raise ValueError('average_weights() needs at least one state dict')
    w_avg = copy.deepcopy(w[0])
    n_models = len(w)
    for key in w_avg.keys():
        orig_dtype = w_avg[key].dtype
        for other in w[1:]:
            w_avg[key] += other[key]
        averaged = torch.div(w_avg[key], n_models)
        # torch.div performs true division, promoting integer tensors to float.
        w_avg[key] = averaged if averaged.dtype == orig_dtype else averaged.to(orig_dtype)
    return w_avg

In [None]:
num_total_links = len(link_data)
n_round = 10  # number of communication rounds
n_repeat = 10  # maximum number of rounds a single link may be selected for
n_links = int(num_total_links/(n_repeat+1))  # links sampled per round

# Per-round link indices (as floats), pre-filled with the out-of-range value
# num_total_links and overwritten row by row.
link_occurance_count = np.zeros(num_total_links)
randlist = np.full((n_round, n_links), num_total_links, dtype=float)
link_list = np.arange(0, num_total_links)

for round_idx in range(n_round):
    # Only links that have not yet reached n_repeat selections are eligible.
    eligible = link_list[link_occurance_count < n_repeat]
    picked = np.random.choice(eligible, size=n_links, replace=False)
    randlist[round_idx] = picked
    # replace=False => no duplicate indices, so fancy-index increment is safe.
    link_occurance_count[picked] += 1

In [None]:
# Initialize global model
global_model = MLP().cuda()

# Number of epochs in edge training
NUM_EPOCHS = 10
NUM_EPOCHS_EDGE = 10
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 64

# Checkpoints of the best local model of the edge currently being trained:
# raw weights (.pt) and the same weights wrapped as {'state_dict': ...} (.pth).
ckpt_weights_path = f'mlp_fd_scr{scr}.pt'
ckpt_state_path = f'mlp_fd_scr{scr}.pth'

state_dicts = []
global_weights = []
for comm_round in range(n_round):

    selected_edges = randlist[comm_round].astype(int)

    local_weights = []
    states = []

    for count, i in enumerate(selected_edges):  # Train on each edge
        print('Round :', comm_round+1, ' of ', n_round,', Link :', count+1, ' of ', n_links)

        # FedAvg: every edge in a round must start from the same global
        # weights. Training `global_model` itself would chain the edges
        # sequentially, so train an independent copy (as the DP-Fed loop does).
        model = MLP().cuda()
        model.load_state_dict(copy.deepcopy(global_model.state_dict()))

        loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

        edge_train_dataloader = data.DataLoader(data.TensorDataset(edge_train_x[i], edge_train_y[i]),
                                                batch_size=TRAIN_BATCH_SIZE, shuffle=False,
                                                num_workers=16, pin_memory=True)

        edge_val_dataloader = data.DataLoader(data.TensorDataset(edge_val_x[i], edge_val_y[i]),
                                              batch_size=VAL_BATCH_SIZE, shuffle=False,
                                              num_workers=16, pin_memory=True)

        # -1 guarantees a checkpoint is written for THIS edge. Otherwise the
        # files of a previous edge could be loaded below.
        best_val_accuracy = -1.0
        best_val_loss = math.inf
        number_epoch_until_best_accuracy = 0
        number_epoch_until_best_loss = 0
        for epoch_idx in range(NUM_EPOCHS_EDGE):

            progress_edge_training_epoch = tqdm(
                edge_train_dataloader,
                desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Training',
                miniters=1, ncols=88, position=0,
                leave=True, total=len(edge_train_dataloader), smoothing=.9, disable=True)

            progress_edge_validation_epoch = tqdm(
                edge_val_dataloader,
                desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Validation',
                miniters=1, ncols=88, position=0,
                leave=True, total=len(edge_val_dataloader), smoothing=.9, disable=True)

            train_loss = 0.0
            train_size = 0
            model.train()
            for sentence, tags in progress_edge_training_epoch:
                sentence = sentence.cuda()
                tags = tags.cuda()
                optimizer.zero_grad()
                tag_scores = model(sentence)
                loss = loss_function(tag_scores, tags)
                loss.backward()
                optimizer.step()
                # Python float: avoids retaining each batch's autograd history.
                train_loss += loss.item() * tags.size(0)
                train_size += tags.size(0)

            val_loss_sum = 0.0
            val_size = 0
            val_num_correct = 0
            model.eval()
            with torch.no_grad():
                for sentence, tags in progress_edge_validation_epoch:
                    sentence = sentence.cuda()
                    tags = tags.cuda()
                    tag_scores = model(sentence)
                    loss = loss_function(tag_scores, tags)
                    val_loss_sum += loss.item() * tags.size(0)
                    val_size += tags.size(0)
                    val_num_correct += torch.eq(tag_scores.argmax(dim=1), tags).sum().item()

            val_accuracy = val_num_correct / val_size
            val_loss = val_loss_sum / val_size
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                number_epoch_until_best_accuracy = epoch_idx
                # Save Model
                torch.save(model.state_dict(), ckpt_weights_path)
                torch.save({'state_dict': model.state_dict()}, ckpt_state_path)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                number_epoch_until_best_loss = epoch_idx

            print(f'epoch:{epoch_idx}, '
                  f'training loss:{train_loss/train_size: .5f}, '
                  f'validation loss:{val_loss: .5f}, '
                  f'accuracy: {val_accuracy: .4f}, '
                  f'best accuracy: {best_val_accuracy: .4f}')

            # Early stop once neither accuracy nor loss improved for >4 epochs.
            if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
                break

        local_weights.append(torch.load(ckpt_weights_path))
        states.append(torch.load(ckpt_state_path))

    # Plain FedAvg: unweighted mean of the selected edges' best local weights.
    global_weight = average_weights(local_weights)
    global_weights.append(global_weight)

    # Edge based weighting (equal weights). The weighted sum below matches
    # global_weight up to floating-point rounding and is kept for evaluation.
    num_edges = n_links
    weights_1 = np.ones(num_edges)/num_edges

    for edge_state, edge_weight in zip(states, weights_1):
        for key in edge_state['state_dict'].keys():
            edge_state['state_dict'][key] *= edge_weight

    # Initialise with the weighted first edge, then add the others.
    state_dict_fed = states[0]['state_dict'].copy()
    for edge_state in states[1:]:
        for key in edge_state['state_dict'].keys():
            state_dict_fed[key] = state_dict_fed[key] + edge_state['state_dict'][key]

    # Save comm round state dicts
    state_dicts.append(state_dict_fed)
    global_model.load_state_dict(global_weight)


In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 4:
    # Keep any previous file as '<name>.old' before writing the new one.
    for outname, payload in (('fed_multiround_scr4_states.pickle', state_dicts),
                             ('fed_multiround_scr4_weights.pickle', global_weights)):
        outfile = os.path.join(os.getcwd(), outname)
        if os.path.exists(outfile):
            os.replace(outfile, outfile + ".old")
        with open(outfile, 'wb') as file:
            pickle.dump(payload, file)

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 5:
    # Keep any previous file as '<name>.old' before writing the new one.
    for outname, payload in (('fed_multiround_scr5_states.pickle', state_dicts),
                             ('fed_multiround_scr5_weights.pickle', global_weights)):
        outfile = os.path.join(os.getcwd(), outname)
        if os.path.exists(outfile):
            os.replace(outfile, outfile + ".old")
        with open(outfile, 'wb') as file:
            pickle.dump(payload, file)

In [None]:
# Reload the per-round FedAvg state dicts for the selected scrimmage
# (`scr` is the variable defined by the "Select Scrimmage" cell).
if scr in (4, 5):
    with open(f'fed_multiround_scr{scr}_states.pickle', 'rb') as file:
        state_dicts = pickle.load(file)

In [None]:
TEST_BATCH_SIZE = 64

# (mean, std) mapping the normalised SNR feature (column 0) back to SNR.
snr_mean, snr_std = {4: (scr4_mean, scr4_std), 5: (scr5_mean, scr5_std)}[scr]

prds = []
trgs = []
snrs = []

# Evaluate on a private copy so loading each round's weights does not
# overwrite global_model.
model = copy.deepcopy(global_model)
model.cuda()

for j in range(n_round):

    model.load_state_dict(state_dicts[j])
    model.eval()

    prd = []
    trg = []
    snr = []

    for i in range(num_total_links):  # Test on each edge
        print('Round :', j+1, ' of ', n_round,', Link :', i+1, ' of ', num_total_links)

        test_dataloader = data.DataLoader(data.TensorDataset(test_x[i], test_y[i]),
                                          batch_size=TEST_BATCH_SIZE, shuffle=False,
                                          num_workers=16, pin_memory=True)

        progress_test_epoch = tqdm(
            test_dataloader,
            desc=f'Link {i}, test',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(test_dataloader), smoothing=.9, disable=True)

        predict = []
        target = []
        snrss = []
        with torch.no_grad():
            for sentence, tags in progress_test_epoch:
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                predict.append(tag_scores.argmax(dim=1).cpu().numpy())
                target.append(tags.cpu().numpy())
                snrss.append(sentence.cpu().numpy()[:, 0])

        prd.append(np.concatenate(predict, axis=0))
        trg.append(np.concatenate(target, axis=0))
        snr.append(np.concatenate(snrss, axis=0))

    # The extra list nesting is kept for compatibility with the saved format.
    prds.append([prd])
    trgs.append([trg])
    snrs.append([[s * snr_std + snr_mean for s in snr]])

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 4:
    mlp_scr4_edge_fed_res_multiround = {
        'mlp_scr4_prd': prds,
        'mlp_scr4_trg': trgs,
        'mlp_scr4_snr': snrs,
    }

    outname = 'mlp_scr4_edge_fed_res_multiround.pickle'
    outfile = os.path.join(os.getcwd(), outname)
    # Keep the previous result as '<name>.old' instead of overwriting it.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr4_edge_fed_res_multiround, file)

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 5:
    mlp_scr5_edge_fed_res_multiround = {
        'mlp_scr5_prd': prds,
        'mlp_scr5_trg': trgs,
        'mlp_scr5_snr': snrs,
    }

    outname = 'mlp_scr5_edge_fed_res_multiround.pickle'
    outfile = os.path.join(os.getcwd(), outname)
    # Keep the previous result as '<name>.old' instead of overwriting it.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr5_edge_fed_res_multiround, file)

# 3. DP-Fed

In [10]:
def copy_params(self, state_dict, coefficient_transfer=100):
    """Copy tensors from ``state_dict`` into module ``self`` in place.

    Only keys also present in ``self.state_dict()`` are copied; any other
    keys are skipped. ``coefficient_transfer`` is accepted but unused.
    """
    own_entries = self.state_dict()
    for key in filter(own_entries.__contains__, state_dict):
        own_entries[key].copy_(state_dict[key])

In [None]:
num_total_links = len(link_data)
n_round = 10  # number of communication rounds
n_repeat = 10  # maximum number of rounds a single link may be selected for
n_links = int(num_total_links/(n_repeat+1))  # links sampled per round

# Per-round link indices (as floats), pre-filled with the out-of-range value
# num_total_links and overwritten row by row.
link_occurance_count = np.zeros(num_total_links)
randlist = np.full((n_round, n_links), num_total_links, dtype=float)
link_list = np.arange(0, num_total_links)

for round_idx in range(n_round):
    # Only links that have not yet reached n_repeat selections are eligible.
    eligible = link_list[link_occurance_count < n_repeat]
    picked = np.random.choice(eligible, size=n_links, replace=False)
    randlist[round_idx] = picked
    # replace=False => no duplicate indices, so fancy-index increment is safe.
    link_occurance_count[picked] += 1

In [None]:
# Initialize global model
global_model = MLP().cuda()


# Number of epochs in edge training
NUM_EPOCHS = 10
NUM_EPOCHS_EDGE = 10
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 64
# Opacus DP-SGD settings passed to PrivacyEngine below.
MAX_GRAD_NORM = 1.2
NOISE_MULTIPLIER = 1
sigma = 0.01  # NOTE(review): not used by this cell

# Checkpoints of the best local model of the edge currently being trained:
# raw weights (.pt) and the same weights wrapped as {'state_dict': ...} (.pth).
ckpt_weights_path = f'mlp_fdd_scr{scr}.pt'
ckpt_state_path = f'mlp_fdd_scr{scr}.pth'

state_dicts = []
global_weights = []
for comm_round in range(n_round):

    selected_edges = randlist[comm_round].astype(int)

    local_weights = []
    states = []

    for count, i in enumerate(selected_edges):  # Train on each edge
        print('Round :', comm_round+1, ' of ', n_round,', Link :', count+1, ' of ', n_links)

        # Each edge starts from an independent copy of the global weights.
        model = MLP().cuda()
        model.load_state_dict(copy.deepcopy(global_model.state_dict()))
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        privacy_engine = PrivacyEngine(
            model,
            batch_size=TRAIN_BATCH_SIZE,
            sample_size=len(edge_train_x[i]),
            alphas=range(10, 100),
            noise_multiplier=NOISE_MULTIPLIER,
            max_grad_norm=MAX_GRAD_NORM,
        )
        privacy_engine.attach(optimizer)
        loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()

        edge_train_dataloader = data.DataLoader(data.TensorDataset(edge_train_x[i], edge_train_y[i]),
                                                batch_size=TRAIN_BATCH_SIZE, shuffle=False,
                                                num_workers=16, pin_memory=True)

        edge_val_dataloader = data.DataLoader(data.TensorDataset(edge_val_x[i], edge_val_y[i]),
                                              batch_size=VAL_BATCH_SIZE, shuffle=False,
                                              num_workers=16, pin_memory=True)

        # -1 guarantees a checkpoint is written for THIS edge. Otherwise the
        # files of a previous edge could be loaded below.
        best_val_accuracy = -1.0
        best_val_loss = math.inf
        number_epoch_until_best_accuracy = 0
        number_epoch_until_best_loss = 0
        for epoch_idx in range(NUM_EPOCHS_EDGE):

            progress_edge_training_epoch = tqdm(
                edge_train_dataloader,
                desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Training',
                miniters=1, ncols=88, position=0,
                leave=True, total=len(edge_train_dataloader), smoothing=.9, disable=True)

            progress_edge_validation_epoch = tqdm(
                edge_val_dataloader,
                desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Validation',
                miniters=1, ncols=88, position=0,
                leave=True, total=len(edge_val_dataloader), smoothing=.9, disable=True)

            train_loss = 0.0
            train_size = 0
            model.train()
            for sentence, tags in progress_edge_training_epoch:
                sentence = sentence.cuda()
                tags = tags.cuda()
                optimizer.zero_grad()
                tag_scores = model(sentence)
                loss = loss_function(tag_scores, tags)
                loss.backward()
                optimizer.step()
                model.zero_grad()
                # Python float: avoids retaining each batch's autograd history.
                train_loss += loss.item() * tags.size(0)
                train_size += tags.size(0)

            val_loss_sum = 0.0
            val_size = 0
            val_num_correct = 0
            model.eval()
            with torch.no_grad():
                for sentence, tags in progress_edge_validation_epoch:
                    sentence = sentence.cuda()
                    tags = tags.cuda()
                    tag_scores = model(sentence)
                    loss = loss_function(tag_scores, tags)
                    val_loss_sum += loss.item() * tags.size(0)
                    val_size += tags.size(0)
                    val_num_correct += torch.eq(tag_scores.argmax(dim=1), tags).sum().item()

            val_accuracy = val_num_correct / val_size
            val_loss = val_loss_sum / val_size
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                number_epoch_until_best_accuracy = epoch_idx
                # Save Model
                torch.save(model.state_dict(), ckpt_weights_path)
                torch.save({'state_dict': model.state_dict()}, ckpt_state_path)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                number_epoch_until_best_loss = epoch_idx

            print(f'epoch:{epoch_idx}, '
                  f'training loss:{train_loss/train_size: .5f}, '
                  f'validation loss:{val_loss: .5f}, '
                  f'accuracy: {val_accuracy: .4f}, '
                  f'best accuracy: {best_val_accuracy: .4f}')

            # Early stop once neither accuracy nor loss improved for >4 epochs.
            if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
                break

        local_weights.append(torch.load(ckpt_weights_path))
        states.append(torch.load(ckpt_state_path))

    # Plain FedAvg: unweighted mean of the selected edges' best local weights.
    global_weight = average_weights(local_weights)
    global_weights.append(global_weight)

    # Edge based weighting (equal weights). The weighted sum below matches
    # global_weight up to floating-point rounding and is kept for evaluation.
    num_edges = n_links
    weights_1 = np.ones(num_edges)/num_edges

    for edge_state, edge_weight in zip(states, weights_1):
        for key in edge_state['state_dict'].keys():
            edge_state['state_dict'][key] *= edge_weight

    # Initialise with the weighted first edge, then add the others.
    state_dict_fed = states[0]['state_dict'].copy()
    for edge_state in states[1:]:
        for key in edge_state['state_dict'].keys():
            state_dict_fed[key] = state_dict_fed[key] + edge_state['state_dict'][key]

    # Save comm round state dicts
    state_dicts.append(state_dict_fed)
    global_model.load_state_dict(global_weight)


In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 4:
    # Keep any previous file as '<name>.old' before writing the new one.
    for outname, payload in (('fedd_multiround_scr4_states.pickle', state_dicts),
                             ('fedd_multiround_scr4_weights.pickle', global_weights)):
        outfile = os.path.join(os.getcwd(), outname)
        if os.path.exists(outfile):
            os.replace(outfile, outfile + ".old")
        with open(outfile, 'wb') as file:
            pickle.dump(payload, file)

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 5:
    # Keep any previous file as '<name>.old' before writing the new one.
    for outname, payload in (('fedd_multiround_scr5_states.pickle', state_dicts),
                             ('fedd_multiround_scr5_weights.pickle', global_weights)):
        outfile = os.path.join(os.getcwd(), outname)
        if os.path.exists(outfile):
            os.replace(outfile, outfile + ".old")
        with open(outfile, 'wb') as file:
            pickle.dump(payload, file)

In [None]:
# Reload the per-round DP-Fed state dicts for the selected scrimmage
# (`scr` is the variable defined by the "Select Scrimmage" cell).
if scr in (4, 5):
    with open(f'fedd_multiround_scr{scr}_states.pickle', 'rb') as file:
        state_dicts = pickle.load(file)

In [None]:
TEST_BATCH_SIZE = 64

# (mean, std) mapping the normalised SNR feature (column 0) back to SNR.
snr_mean, snr_std = {4: (scr4_mean, scr4_std), 5: (scr5_mean, scr5_std)}[scr]

prds = []
trgs = []
snrs = []

# Evaluate on a private copy so loading each round's weights does not
# overwrite global_model.
model = copy.deepcopy(global_model)
model.cuda()

for j in range(n_round):

    model.load_state_dict(state_dicts[j])
    model.eval()

    prd = []
    trg = []
    snr = []

    for i in range(num_total_links):  # Test on each edge
        print('Round :', j+1, ' of ', n_round,', Link :', i+1, ' of ', num_total_links)

        test_dataloader = data.DataLoader(data.TensorDataset(test_x[i], test_y[i]),
                                          batch_size=TEST_BATCH_SIZE, shuffle=False,
                                          num_workers=16, pin_memory=True)

        progress_test_epoch = tqdm(
            test_dataloader,
            desc=f'Link {i}, test',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(test_dataloader), smoothing=.9, disable=True)

        predict = []
        target = []
        snrss = []
        with torch.no_grad():
            for sentence, tags in progress_test_epoch:
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                predict.append(tag_scores.argmax(dim=1).cpu().numpy())
                target.append(tags.cpu().numpy())
                snrss.append(sentence.cpu().numpy()[:, 0])

        prd.append(np.concatenate(predict, axis=0))
        trg.append(np.concatenate(target, axis=0))
        snr.append(np.concatenate(snrss, axis=0))

    # The extra list nesting is kept for compatibility with the saved format.
    prds.append([prd])
    trgs.append([trg])
    snrs.append([[s * snr_std + snr_mean for s in snr]])

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 4:

    mlp_scr4_edge_fed_res_multiround = {
        'mlp_scr4_prd': prds,
        'mlp_scr4_trg': trgs,
        'mlp_scr4_snr': snrs,
    }

    outname = 'mlp_scr4_edge_dpfed_res_multiround.pickle'
    outfile = os.path.join(os.getcwd(), outname)
    # Keep the previous result as '<name>.old' instead of overwriting it.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr4_edge_fed_res_multiround, file)

In [None]:
# `scr` is the variable actually defined by the "Select Scrimmage" cell.
if scr == 5:

    mlp_scr5_edge_fed_res_multiround = {
        'mlp_scr5_prd': prds,
        'mlp_scr5_trg': trgs,
        'mlp_scr5_snr': snrs,
    }

    outname = 'mlp_scr5_edge_dpfed_res_multiround.pickle'
    outfile = os.path.join(os.getcwd(), outname)
    # Keep the previous result as '<name>.old' instead of overwriting it.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr5_edge_fed_res_multiround, file)

# 4. KD-Scr

In [None]:
# Define kd loss func

def loss_kd(outputs, labels, teacher_outputs, alpha, T):
    """Knowledge-distillation loss.

    Combines two terms:
      (1 - alpha) * cross_entropy(outputs, labels)
      + alpha * T*T * KL(softmax(teacher/T) || softmax(outputs/T)),
    where the KL term uses batch-mean reduction and softmax runs over dim 1.

    Args:
        outputs: student logits.
        labels: ground-truth class indices.
        teacher_outputs: teacher logits, same shape as ``outputs``.
        alpha: weight of the distillation term (0 = pure cross-entropy).
        T: softmax temperature.
    """
    hard_loss = F.cross_entropy(outputs, labels)
    student_log_probs = F.log_softmax(outputs/T, dim=1)
    teacher_probs = F.softmax(teacher_outputs/T, dim=1)
    # Conventional T^2 scaling of the soft-target term (Hinton et al., 2015).
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)
    return alpha*soft_loss + (1. - alpha)*hard_loss

In [9]:
# Central training for all links: pool every link's edge train/val split into
# single tensors (features float, labels long). Features are then reshaped
# to (N, 1, 3).
train_x = torch.cat([torch.from_numpy(arr).type(torch.float) for arr in edge_train_x_npn], dim=0)
val_x = torch.cat([torch.from_numpy(arr).type(torch.float) for arr in edge_val_x_npn], dim=0)
train_y = torch.cat([torch.from_numpy(arr).type(torch.long) for arr in edge_train_y_npn], dim=0)
val_y = torch.cat([torch.from_numpy(arr).type(torch.long) for arr in edge_val_y_npn], dim=0)

train_x = train_x.view(-1, 1, 3)
val_x = val_x.view(-1, 1, 3)

In [None]:
# Hyper-parameters plus a fresh model / loss / optimizer for central training.
NUM_EPOCHS = 100
TRAIN_BATCH_SIZE = 1024
VAL_BATCH_SIZE = 128

# Book-keeping for early stopping and timing.
best_val_accuracy, best_val_loss = 0, 100
number_epoch_until_best = 1
training_time = training_time_until_best = average_time_per_epoch = 0

model = CBSDNN().cuda()
# Equal weights for both classes (equivalent to unweighted cross-entropy).
loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
# Central (teacher) training on the pooled data of all links.
# Runs up to NUM_EPOCHS epochs with early stopping: training stops once
# neither the validation accuracy nor the validation loss has improved for
# more than 4 epochs. The weights of the best-accuracy epoch are saved for
# the active scenario `Scr`.
train_dataloader=data.DataLoader(data.TensorDataset(train_x,train_y),batch_size=TRAIN_BATCH_SIZE, 
                                   shuffle=True, num_workers=16, pin_memory=True)

val_dataloader=data.DataLoader(data.TensorDataset(val_x,val_y),batch_size=VAL_BATCH_SIZE, 
                                 shuffle=False, num_workers=16, pin_memory=True)

# Output file stem per scenario; other Scr values train without saving.
central_ckpt_stems = {4: 'cbsdnn_scr4_nodes', 5: 'cbsdnn_scr5_nodes'}

# Epoch indices of the latest improvements. Initialised here so the
# early-stopping test is always defined.
number_epoch_until_best_accuracy = 0
number_epoch_until_best_loss = 0

start_time = time.time()

for epoch_idx in range(NUM_EPOCHS):  

    progress_training_epoch = tqdm(
        train_dataloader, 
        desc=f'Epoch {epoch_idx}/{NUM_EPOCHS}, Training',
        miniters=1, ncols=88, position=0,
        leave=True, total=len(train_dataloader), smoothing=.9)

    progress_validation_epoch = tqdm(
        val_dataloader, 
        desc=f'Epoch {epoch_idx}/{NUM_EPOCHS}, Validation',
        miniters=1, ncols=88, position=0, 
        leave=True, total=len(val_dataloader), smoothing=.9)

    train_loss = 0
    train_size = 0

    # Re-enter training mode every epoch. The validation pass below calls
    # model.eval(), which would otherwise stay active for all later epochs.
    model.train()
    epoch_start = time.time()
    for idx, (batch_input, batch_target) in enumerate(progress_training_epoch):
        batch_input = batch_input.cuda()
        batch_target = batch_target.cuda()
        model.zero_grad()
        batch_predict = model(batch_input)
        loss = loss_function(batch_predict, batch_target)
        loss.backward()
        optimizer.step()
        # detach(): accumulate the value only, not every batch's autograd graph.
        train_loss += loss.detach() * batch_target.size(0)
        train_size += batch_target.size(0)

    # Add only this epoch's training time. Subtracting the fixed `start_time`
    # would re-add the whole elapsed time every epoch.
    training_time += time.time() - epoch_start

    test_loss = 0
    test_size = 0
    test_total_num_correct = 0

    predict = []
    target = []
    model.eval()
    with torch.no_grad():
        for idx, (batch_input, batch_target) in enumerate(progress_validation_epoch):
            batch_input = batch_input.cuda()
            batch_target = batch_target.cuda()
            batch_predict = model(batch_input)
            loss = loss_function(batch_predict, batch_target)
            predict.append(batch_predict.argmax(dim=1).cpu().numpy())
            target.append(batch_target.cpu().numpy())
            test_loss += loss * batch_target.size(0)
            test_size += batch_target.size(0)
            test_total_num_correct += torch.eq(batch_predict.argmax(dim=1), batch_target).sum()
    val_accuracy = test_total_num_correct.item()/test_size
    val_loss = test_loss.item()/test_size
    predict = np.concatenate(predict, axis=0)
    target = np.concatenate(target, axis=0)
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        number_epoch_until_best_accuracy = epoch_idx
        training_time_until_best = training_time
        # Save Model (plain weights + resumable checkpoint)
        ckpt_stem = central_ckpt_stems.get(Scr)
        if ckpt_stem is not None:
            torch.save(model.state_dict(), f'{ckpt_stem}.pt')
            state = {
                'epoch': epoch_idx,  # epoch at which these weights were best
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            savepath = f'{ckpt_stem}_checkpoint.pth'
            torch.save(state, savepath)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        number_epoch_until_best_loss = epoch_idx

    print(f'epoch:{epoch_idx}, '
          f'training loss:{train_loss.item()/train_size: .5f}, '
          f'validation loss:{val_loss: .5f}, '
          f'accuracy: {val_accuracy: .4f}, '
          f'best accuracy: {best_val_accuracy: .4f}')

    if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
        break


print(f'total training time: {training_time_until_best}')
print(f'number of epochs: {number_epoch_until_best_accuracy}')
# The best epoch index is 0-based, so best+1 epochs were trained up to it.
# Dividing by the raw index raised ZeroDivisionError when epoch 0 was best.
print(f'time per epoch: {(training_time_until_best/(number_epoch_until_best_accuracy + 1)): .2f}')

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

In [None]:
# Load the best central checkpoint for the active scenario into the teacher
# network `model1`.
if Scr == 4:
    checkpoint = torch.load('cbsdnn_scr4_nodes_checkpoint.pth')
elif Scr == 5:
    checkpoint = torch.load('cbsdnn_scr5_nodes_checkpoint.pth')

teacher_weights = checkpoint['state_dict']
model1.load_state_dict(teacher_weights)

In [None]:
# Knowledge distillation per link ("KD-Scr"). For every link, the student
# `model` is trained on that link's edge data with loss_kd, using the frozen
# teacher `model1` as the soft-label source. The best-accuracy student
# weights are then reloaded and evaluated on the link's test split.
# Predictions, targets and column 0 of the test inputs are collected per link
# in prd / trg / snr.

# Number of epochs in edge training. Must be an int, because
# range(100.) raises TypeError.
NUM_EPOCHS_EDGE = 100
prd = []
trg = []
snr = []

# Parameters for kd loss func
if Scr == 4:
    alpha = 0.4
    T = 10.0
if Scr == 5:
    alpha = 0.5
    T = 10.0

# Best-student weight file per scenario. For other Scr values nothing is
# saved or loaded.
student_path = {4: 'mlp_scr4_kd_nodes.pt', 5: 'mlp_scr5_kd_nodes.pt'}.get(Scr)

# NOTE(review): the student is whatever object is bound to `model` when this
# cell runs. If the central-training cell above ran last, that is a CBSDNN
# (the teacher architecture), trained here on un-reshaped inputs while the
# teacher receives `.view(-1,1,3)`. The "mlp_" file names suggest an MLP
# student was intended, as in the TF-KD section. `optimizer` is whichever
# optimizer was created last. `model` is also not re-initialised between
# links. TODO confirm.
for i in range(len(link_data)):
    
    # Train
    edge_train_dataloader=data.DataLoader(data.TensorDataset(edge_train_x[i],edge_train_y[i]),
                                     batch_size=TRAIN_BATCH_SIZE, shuffle=True, 
                                     num_workers=16, pin_memory=True)

    edge_val_dataloader=data.DataLoader(data.TensorDataset(edge_val_x[i],edge_val_y[i]),
                                   batch_size=VAL_BATCH_SIZE, shuffle=False, 
                                   num_workers=16, pin_memory=True)
    start_time = time.time()
    best_val_accuracy = 0
    best_val_loss = 100
    # Reset per link so early stopping never uses the previous link's
    # epoch indices (or hits an undefined name on the first link).
    number_epoch_until_best_accuracy = 0
    number_epoch_until_best_loss = 0
    training_time = 0
    training_time_until_best = 0
    for epoch_idx in range(NUM_EPOCHS_EDGE): 

        progress_edge_training_epoch = tqdm(
            edge_train_dataloader, 
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Training',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_train_dataloader), smoothing=.9,disable=True)

        progress_edge_validation_epoch = tqdm(
            edge_val_dataloader, 
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Validation',
            miniters=1, ncols=88, position=0, 
            leave=True, total=len(edge_val_dataloader), smoothing=.9,disable=True)
        
        train_loss = 0
        train_size = 0
        model1.eval()
        model.train()
        epoch_start = time.time()
        for idx, (sentence, tags) in enumerate(progress_edge_training_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            optimizer.zero_grad()
            tag_scores = model(sentence)
            # Teacher is frozen: no gradients flow through it.
            with torch.no_grad():
                teacher_outputs = model1(sentence.view(-1,1,3))
            loss = loss_kd(tag_scores, tags, teacher_outputs, alpha, T)
            loss.backward()
            optimizer.step()
            # detach(): keep the value, not each batch's autograd graph.
            train_loss += loss.detach() * tags.size(0)
            train_size += tags.size(0)
        training_time += time.time() - epoch_start

        test_loss = 0
        test_size = 0    
        test_total_num_correct = 0
        predict = []
        target = []
        model.eval()
        model1.eval()
        with torch.no_grad():
            for idx, (sentence, tags) in enumerate(progress_edge_validation_epoch):
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                teacher_outputs = model1(sentence.view(-1,1,3))
                loss = loss_kd(tag_scores, tags, teacher_outputs, alpha, T)
                predict.append(tag_scores.argmax(dim=1).cpu().numpy())
                target.append(tags.cpu().numpy())
                test_loss += loss * tags.size(0)
                test_size += tags.size(0)
                test_total_num_correct += torch.eq(tag_scores.argmax(dim=1), tags).sum()
        val_accuracy = test_total_num_correct.item()/test_size
        val_loss = test_loss.item()/test_size
        predict = np.concatenate(predict, axis=0)
        target = np.concatenate(target, axis=0)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            number_epoch_until_best_accuracy = epoch_idx
            training_time_until_best = training_time
            if student_path is not None:
                torch.save(model.state_dict(), student_path)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            number_epoch_until_best_loss = epoch_idx
        
        print(f'epoch:{epoch_idx}, '
              f'training loss:{train_loss.item()/train_size: .5f}, '
              f'validation loss:{val_loss: .5f}, '
              f'accuracy: {val_accuracy: .4f}, '
              f'best accuracy: {best_val_accuracy: .4f}')

        if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
            break

    print(f'number of epochs: {number_epoch_until_best_accuracy}')
    # Restore the best student of this link before testing.
    if student_path is not None:
        model.load_state_dict(torch.load(student_path))
    model.cuda()
    
    # Test

    test_dataloader=data.DataLoader(data.TensorDataset(test_x[i],test_y[i]), 
                                       batch_size=VAL_BATCH_SIZE, shuffle=False,
                                       num_workers=16, pin_memory=True)

    progress_test_epoch = tqdm(
        test_dataloader, 
        desc=f'Link {i}, Test',
        miniters=1, ncols=88, position=0, 
        leave=True, total=len(test_dataloader), smoothing=.9)

    predict = []
    target = []
    snrss = []
    model.eval()
    with torch.no_grad():
        for idx, (sentence, tags) in enumerate(progress_test_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            tag_scores = model(sentence)
            predict.append(tag_scores.argmax(dim=1).cpu().numpy())
            target.append(tags.cpu().numpy())
            snrss.append(sentence.cpu().numpy()[:,0])

    prd.append(np.concatenate(predict, axis=0))
    trg.append(np.concatenate(target, axis=0))
    snr.append(np.concatenate(snrss, axis=0))

In [None]:
if Scr == 4:
    # Per-link KD-Scr test results. SNR values are mapped back from the
    # normalised scale with scr4_std / scr4_mean.
    mlp_scr4_edge_res = {
        'mlp_scr4_prd': prd,
        'mlp_scr4_trg': trg,
        'mlp_scr4_snr': [link_snr * scr4_std + scr4_mean for link_snr in snr],
    }

    outname = 'mlp_scr4_nodes_kd_scr_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)

    # Keep any previous result file as a ".old" backup.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr4_edge_res, file)

In [None]:
if Scr == 5:
    # Per-link KD-Scr test results. SNR values are mapped back from the
    # normalised scale with scr5_std / scr5_mean.
    mlp_scr5_edge_res = {
        'mlp_scr5_prd': prd,
        'mlp_scr5_trg': trg,
        'mlp_scr5_snr': [link_snr * scr5_std + scr5_mean for link_snr in snr],
    }

    outname = 'mlp_scr5_nodes_kd_scr_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)

    # Keep any previous result file as a ".old" backup.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr5_edge_res, file)

# 5. KD-Smote

In [None]:
# Load Data: SMOTE-oversampled per-link training data for the active scenario.
# The second branch previously tested `Scr == 4` again. With Scr 4 that
# overwrote the scr4 data with the scr5 file; with Scr 5 it loaded nothing.
# Note: pickle.load must only be used on trusted, locally produced files.
if Scr == 4:
    with open('scr4_edge_smote_nodes_train.pickle', 'rb') as file:
        link_data_smote = pickle.load(file)
if Scr == 5:
    with open('scr5_edge_smote_nodes_train.pickle', 'rb') as file:
        link_data_smote = pickle.load(file)

In [None]:
# Keep only feature columns 0, 1 and 3 of every SMOTE link. Three input
# features match the `.view(-1,1,3)` reshapes used for the central model.
cols = torch.LongTensor([0,1,3]) # Select Columns

# Iterate the SMOTE data loaded above. The previous code iterated `link_data`,
# which silently discarded the SMOTE data just loaded.
link_data_smote = [(link_datas[0][:,cols], link_datas[1]) for link_datas in link_data_smote]

In [None]:
# Split every SMOTE link, in order and without shuffling, into a training
# part (first 80 %) and a validation part (remaining 20 %).
datadist = [0.8,0.2]

train_x_npn, train_y_npn = [], []
val_x_npn, val_y_npn = [], []

for i, link in enumerate(link_data_smote):
    link_feats, link_labels = link[0], link[1]
    datalen = len(link_labels)
    trainlen = int(datalen*sum(datadist[:1]))
    vallen = int(datalen*sum(datadist[:2]))

    train_x_npn.append(link_feats[0:trainlen].numpy())
    train_y_npn.append(link_labels[0:trainlen].numpy())
    val_x_npn.append(link_feats[trainlen:vallen].numpy())
    val_y_npn.append(link_labels[trainlen:vallen].numpy())
   

In [None]:
# Central training for all links: pool every SMOTE link's train/val split into
# single tensors (features float, labels long). Features are then reshaped
# to (N, 1, 3).
train_x = torch.cat([torch.from_numpy(arr).type(torch.float) for arr in train_x_npn], dim=0)
val_x = torch.cat([torch.from_numpy(arr).type(torch.float) for arr in val_x_npn], dim=0)
train_y = torch.cat([torch.from_numpy(arr).type(torch.long) for arr in train_y_npn], dim=0)
val_y = torch.cat([torch.from_numpy(arr).type(torch.long) for arr in val_y_npn], dim=0)

train_x = train_x.view(-1, 1, 3)
val_x = val_x.view(-1, 1, 3)

In [None]:
# Hyper-parameters plus a fresh model / loss / optimizer for central training
# on the SMOTE data.
NUM_EPOCHS = 100
TRAIN_BATCH_SIZE = 1024
VAL_BATCH_SIZE = 128

# Book-keeping for early stopping and timing.
best_val_accuracy, best_val_loss = 0, 100
number_epoch_until_best = 1
training_time = training_time_until_best = average_time_per_epoch = 0

model = CBSDNN().cuda()
# Equal weights for both classes (equivalent to unweighted cross-entropy).
loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
# Central (teacher) training on the pooled SMOTE data of all links.
# Runs up to NUM_EPOCHS epochs with early stopping: training stops once
# neither the validation accuracy nor the validation loss has improved for
# more than 4 epochs. The weights of the best-accuracy epoch are saved for
# the active scenario `Scr`.
train_dataloader=data.DataLoader(data.TensorDataset(train_x,train_y),batch_size=TRAIN_BATCH_SIZE, 
                                   shuffle=True, num_workers=16, pin_memory=True)

val_dataloader=data.DataLoader(data.TensorDataset(val_x,val_y),batch_size=VAL_BATCH_SIZE, 
                                 shuffle=False, num_workers=16, pin_memory=True)

# Output file stem per scenario; other Scr values train without saving.
central_ckpt_stems = {4: 'cbsdnn_smote4_nodes', 5: 'cbsdnn_smote5_nodes'}

# Epoch indices of the latest improvements. Initialised here so the
# early-stopping test is always defined.
number_epoch_until_best_accuracy = 0
number_epoch_until_best_loss = 0

start_time = time.time()

for epoch_idx in range(NUM_EPOCHS):  

    progress_training_epoch = tqdm(
        train_dataloader, 
        desc=f'Epoch {epoch_idx}/{NUM_EPOCHS}, Training',
        miniters=1, ncols=88, position=0,
        leave=True, total=len(train_dataloader), smoothing=.9)

    progress_validation_epoch = tqdm(
        val_dataloader, 
        desc=f'Epoch {epoch_idx}/{NUM_EPOCHS}, Validation',
        miniters=1, ncols=88, position=0, 
        leave=True, total=len(val_dataloader), smoothing=.9)

    train_loss = 0
    train_size = 0

    # Re-enter training mode every epoch. The validation pass below calls
    # model.eval(), which would otherwise stay active for all later epochs.
    model.train()
    epoch_start = time.time()
    for idx, (batch_input, batch_target) in enumerate(progress_training_epoch):
        batch_input = batch_input.cuda()
        batch_target = batch_target.cuda()
        model.zero_grad()
        batch_predict = model(batch_input)
        loss = loss_function(batch_predict, batch_target)
        loss.backward()
        optimizer.step()
        # detach(): accumulate the value only, not every batch's autograd graph.
        train_loss += loss.detach() * batch_target.size(0)
        train_size += batch_target.size(0)

    # Add only this epoch's training time. Subtracting the fixed `start_time`
    # would re-add the whole elapsed time every epoch.
    training_time += time.time() - epoch_start

    test_loss = 0
    test_size = 0
    test_total_num_correct = 0

    predict = []
    target = []
    model.eval()
    with torch.no_grad():
        for idx, (batch_input, batch_target) in enumerate(progress_validation_epoch):
            batch_input = batch_input.cuda()
            batch_target = batch_target.cuda()
            batch_predict = model(batch_input)
            loss = loss_function(batch_predict, batch_target)
            predict.append(batch_predict.argmax(dim=1).cpu().numpy())
            target.append(batch_target.cpu().numpy())
            test_loss += loss * batch_target.size(0)
            test_size += batch_target.size(0)
            test_total_num_correct += torch.eq(batch_predict.argmax(dim=1), batch_target).sum()
    val_accuracy = test_total_num_correct.item()/test_size
    val_loss = test_loss.item()/test_size
    predict = np.concatenate(predict, axis=0)
    target = np.concatenate(target, axis=0)
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        number_epoch_until_best_accuracy = epoch_idx
        training_time_until_best = training_time
        # Save Model (plain weights + resumable checkpoint)
        ckpt_stem = central_ckpt_stems.get(Scr)
        if ckpt_stem is not None:
            torch.save(model.state_dict(), f'{ckpt_stem}.pt')
            state = {
                'epoch': epoch_idx,  # epoch at which these weights were best
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            savepath = f'{ckpt_stem}_checkpoint.pth'
            torch.save(state, savepath)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        number_epoch_until_best_loss = epoch_idx

    print(f'epoch:{epoch_idx}, '
          f'training loss:{train_loss.item()/train_size: .5f}, '
          f'validation loss:{val_loss: .5f}, '
          f'accuracy: {val_accuracy: .4f}, '
          f'best accuracy: {best_val_accuracy: .4f}')

    if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
        break


print(f'total training time: {training_time_until_best}')
print(f'number of epochs: {number_epoch_until_best_accuracy}')
# The best epoch index is 0-based, so best+1 epochs were trained up to it.
# Dividing by the raw index raised ZeroDivisionError when epoch 0 was best.
print(f'time per epoch: {(training_time_until_best/(number_epoch_until_best_accuracy + 1)): .2f}')

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

In [None]:
# Load the best SMOTE-trained central checkpoint for the active scenario into
# the teacher network `model1`.
if Scr == 4:
    checkpoint = torch.load('cbsdnn_smote4_nodes_checkpoint.pth')
elif Scr == 5:
    checkpoint = torch.load('cbsdnn_smote5_nodes_checkpoint.pth')

teacher_weights = checkpoint['state_dict']
model1.load_state_dict(teacher_weights)

In [None]:
# Knowledge distillation per link ("KD-Smote"). For every link, the student
# `model` is trained on that link's edge data with loss_kd, using the frozen
# SMOTE-trained teacher `model1` as the soft-label source. The best-accuracy
# student weights are then reloaded and evaluated on the link's test split.
# Predictions, targets and column 0 of the test inputs are collected per link
# in prd / trg / snr.

# Number of epochs in edge training. Must be an int, because
# range(100.) raises TypeError.
NUM_EPOCHS_EDGE = 100
prd = []
trg = []
snr = []

# Parameters for kd loss func
if Scr == 4:
    alpha = 0.4
    T = 10.0
if Scr == 5:
    alpha = 0.5
    T = 10.0

# Best-student weight file per scenario. For other Scr values nothing is
# saved or loaded.
student_path = {4: 'mlp_scr4_kd_nodes_2.pt', 5: 'mlp_scr5_kd_nodes_2.pt'}.get(Scr)

# NOTE(review): the student is whatever object is bound to `model` when this
# cell runs. If the central-training cell above ran last, that is a CBSDNN
# (the teacher architecture), trained here on un-reshaped inputs while the
# teacher receives `.view(-1,1,3)`. The "mlp_" file names suggest an MLP
# student was intended, as in the TF-KD section. `optimizer` is whichever
# optimizer was created last. `model` is also not re-initialised between
# links. TODO confirm.
for i in range(len(link_data)):
    
    # Train
    edge_train_dataloader=data.DataLoader(data.TensorDataset(edge_train_x[i],edge_train_y[i]),
                                     batch_size=TRAIN_BATCH_SIZE, shuffle=True, 
                                     num_workers=16, pin_memory=True)

    edge_val_dataloader=data.DataLoader(data.TensorDataset(edge_val_x[i],edge_val_y[i]),
                                   batch_size=VAL_BATCH_SIZE, shuffle=False, 
                                   num_workers=16, pin_memory=True)
    start_time = time.time()
    best_val_accuracy = 0
    best_val_loss = 100
    # Reset per link so early stopping never uses the previous link's
    # epoch indices (or hits an undefined name on the first link).
    number_epoch_until_best_accuracy = 0
    number_epoch_until_best_loss = 0
    training_time = 0
    training_time_until_best = 0
    for epoch_idx in range(NUM_EPOCHS_EDGE): 

        progress_edge_training_epoch = tqdm(
            edge_train_dataloader, 
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Training',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_train_dataloader), smoothing=.9,disable=True)

        progress_edge_validation_epoch = tqdm(
            edge_val_dataloader, 
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Validation',
            miniters=1, ncols=88, position=0, 
            leave=True, total=len(edge_val_dataloader), smoothing=.9,disable=True)
        
        train_loss = 0
        train_size = 0
        model1.eval()
        model.train()
        epoch_start = time.time()
        for idx, (sentence, tags) in enumerate(progress_edge_training_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            optimizer.zero_grad()
            tag_scores = model(sentence)
            # Teacher is frozen: no gradients flow through it.
            with torch.no_grad():
                teacher_outputs = model1(sentence.view(-1,1,3))
            loss = loss_kd(tag_scores, tags, teacher_outputs, alpha, T)
            loss.backward()
            optimizer.step()
            # detach(): keep the value, not each batch's autograd graph.
            train_loss += loss.detach() * tags.size(0)
            train_size += tags.size(0)
        training_time += time.time() - epoch_start

        test_loss = 0
        test_size = 0    
        test_total_num_correct = 0
        predict = []
        target = []
        model.eval()
        model1.eval()
        with torch.no_grad():
            for idx, (sentence, tags) in enumerate(progress_edge_validation_epoch):
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                teacher_outputs = model1(sentence.view(-1,1,3))
                loss = loss_kd(tag_scores, tags, teacher_outputs, alpha, T)
                predict.append(tag_scores.argmax(dim=1).cpu().numpy())
                target.append(tags.cpu().numpy())
                test_loss += loss * tags.size(0)
                test_size += tags.size(0)
                test_total_num_correct += torch.eq(tag_scores.argmax(dim=1), tags).sum()
        val_accuracy = test_total_num_correct.item()/test_size
        val_loss = test_loss.item()/test_size
        predict = np.concatenate(predict, axis=0)
        target = np.concatenate(target, axis=0)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            number_epoch_until_best_accuracy = epoch_idx
            training_time_until_best = training_time
            if student_path is not None:
                torch.save(model.state_dict(), student_path)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            number_epoch_until_best_loss = epoch_idx
        
        print(f'epoch:{epoch_idx}, '
              f'training loss:{train_loss.item()/train_size: .5f}, '
              f'validation loss:{val_loss: .5f}, '
              f'accuracy: {val_accuracy: .4f}, '
              f'best accuracy: {best_val_accuracy: .4f}')

        if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
            break

    print(f'number of epochs: {number_epoch_until_best_accuracy}')
    # Restore the best student of this link before testing.
    if student_path is not None:
        model.load_state_dict(torch.load(student_path))
    model.cuda()
    
    # Test

    test_dataloader=data.DataLoader(data.TensorDataset(test_x[i],test_y[i]), 
                                       batch_size=VAL_BATCH_SIZE, shuffle=False,
                                       num_workers=16, pin_memory=True)

    progress_test_epoch = tqdm(
        test_dataloader, 
        desc=f'Link {i}, Test',
        miniters=1, ncols=88, position=0, 
        leave=True, total=len(test_dataloader), smoothing=.9)

    predict = []
    target = []
    snrss = []
    model.eval()
    with torch.no_grad():
        for idx, (sentence, tags) in enumerate(progress_test_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            tag_scores = model(sentence)
            predict.append(tag_scores.argmax(dim=1).cpu().numpy())
            target.append(tags.cpu().numpy())
            snrss.append(sentence.cpu().numpy()[:,0])

    prd.append(np.concatenate(predict, axis=0))
    trg.append(np.concatenate(target, axis=0))
    snr.append(np.concatenate(snrss, axis=0))

In [None]:
if Scr == 4:
    # Per-link KD-Smote test results. SNR values are mapped back from the
    # normalised scale with scr4_std / scr4_mean.
    mlp_scr4_edge_res = {
        'mlp_scr4_prd': prd,
        'mlp_scr4_trg': trg,
        'mlp_scr4_snr': [link_snr * scr4_std + scr4_mean for link_snr in snr],
    }

    outname = 'mlp_scr4_nodes_kd_smote_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)

    # Keep any previous result file as a ".old" backup.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr4_edge_res, file)

In [None]:
if Scr == 5:
    # Per-link KD-Smote test results. SNR values are mapped back from the
    # normalised scale with scr5_std / scr5_mean.
    mlp_scr5_edge_res = {
        'mlp_scr5_prd': prd,
        'mlp_scr5_trg': trg,
        'mlp_scr5_snr': [link_snr * scr5_std + scr5_mean for link_snr in snr],
    }

    outname = 'mlp_scr5_nodes_kd_smote_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)

    # Keep any previous result file as a ".old" backup.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as file:
        pickle.dump(mlp_scr5_edge_res, file)

# 6. TF-KD

In [12]:
# SMOTE data for edge training: split each link, in order and without
# shuffling, into a training part (first 80 %) and a validation part (rest).
datadist_smote = [0.8, 0.2]  # fractions for train and validation

edge_train_smote_x_npn, edge_train_smote_y_npn = [], []
edge_val_smote_x_npn, edge_val_smote_y_npn = [], []

for i, link in enumerate(link_data_smote):
    link_feats, link_labels = link[0], link[1]
    datalen = len(link_labels)
    edge_trainlen_smote = int(datalen*sum(datadist_smote[:1]))
    edge_vallen_smote = int(datalen*sum(datadist_smote[:2]))

    edge_train_smote_x_npn.append(link_feats[0:edge_trainlen_smote].numpy())
    edge_train_smote_y_npn.append(link_labels[0:edge_trainlen_smote].numpy())
    edge_val_smote_x_npn.append(link_feats[edge_trainlen_smote:edge_vallen_smote].numpy())
    edge_val_smote_y_npn.append(link_labels[edge_trainlen_smote:edge_vallen_smote].numpy())

In [14]:
# Per-link tensors for SMOTE edge training/validation: features as float32,
# labels as int64.
to_float = lambda arr: torch.from_numpy(arr).type(torch.float)
to_long = lambda arr: torch.from_numpy(arr).type(torch.long)

edge_train_smote_x = list(map(to_float, edge_train_smote_x_npn))
edge_val_smote_x = list(map(to_float, edge_val_smote_x_npn))

edge_train_smote_y = list(map(to_long, edge_train_smote_y_npn))
edge_val_smote_y = list(map(to_long, edge_val_smote_y_npn))

In [None]:
# TF-KD setup: hyper-parameters, a fresh MLP student (`model`), a CBSDNN
# teacher shell (`model1`, weights loaded in the next cell), loss and SGD
# optimizer for the student.
NUM_EPOCHS = 100
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 64

# Book-keeping for early stopping and timing.
best_val_accuracy, best_val_loss = 0, 100
number_epoch_until_best = 1
training_time = training_time_until_best = average_time_per_epoch = 0

model = MLP().cuda()
model1 = CBSDNN().cuda()
# Equal weights for both classes (equivalent to unweighted cross-entropy).
loss_function = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0])).cuda()

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Load the SMOTE-trained central checkpoint for the active scenario into the
# teacher network `model1`.
if Scr == 4:
    checkpoint = torch.load('cbsdnn_smote4_nodes_checkpoint.pth')
elif Scr == 5:
    checkpoint = torch.load('cbsdnn_smote5_nodes_checkpoint.pth')

teacher_weights = checkpoint['state_dict']
model1.load_state_dict(teacher_weights)

In [None]:
# Number of epochs in edge training, plus per-link result accumulators.
NUM_EPOCHS_EDGE = 100
prd, trg, snr = [], [], []

# loss_kd parameters per scenario: alpha weights the distillation (KL) term,
# T is the softmax temperature.
if Scr == 4:
    alpha, T = 0.4, 10.0
elif Scr == 5:
    alpha, T = 0.5, 10.0

for i in range(len(link_data_smote)):
    
    # Train
    edge_train_smote_dataloader=data.DataLoader(data.TensorDataset(edge_train_smote_x[i],edge_train_smote_y[i]),
                                     batch_size=TRAIN_BATCH_SIZE, shuffle=True, 
                                     num_workers=16, pin_memory=True)

    edge_val_smote_dataloader=data.DataLoader(data.TensorDataset(edge_val_smote_x[i],edge_val_smote_y[i]),
                                   batch_size=VAL_BATCH_SIZE, shuffle=False, 
                                   num_workers=16, pin_memory=True)

    start_time = time.time()
    best_val_accuracy = 0
    best_val_loss = 100
    number_epoch_until_best = 1
    training_time = 0
    training_time_until_best = 0
    average_time_per_epoch = 0
    for epoch_idx in range(NUM_EPOCHS_EDGE): 

        progress_edge_training_smote_epoch = tqdm(
            edge_train_smote_dataloader, 
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS}, Training',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_train_smote_dataloader), smoothing=.9, disable = True)

        progress_edge_validation_smote_epoch = tqdm(
            edge_val_smote_dataloader, 
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS}, Validation',
            miniters=1, ncols=88, position=0, 
            leave=True, total=len(edge_val_smote_dataloader), smoothing=.9, disable = True)
        
        train_loss = 0
        train_size = 0
        model1.eval()
        model.train()           
        for idx, (sentence, tags) in enumerate(progress_edge_training_smote_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            optimizer.zero_grad()
            tag_scores = model(sentence)
            with torch.no_grad():
                teacher_outputs = model1(sentence.view(-1,1,3))
            loss =loss_kd(tag_scores,tags,teacher_outputs,alpha,T)
            loss.backward()
            optimizer.step()
            train_loss += loss * tags.size()[0]
            train_size += tags.size()[0]
                
        test_loss = 0
        test_size = 0    
        test_total_num_correct = 0
        predict = []
        target = []
        model.eval()
        model1.eval()
        with torch.no_grad():
            for idx, (sentence, tags) in enumerate(progress_edge_validation_smote_epoch):
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                teacher_outputs = model1(sentence.view(-1,1,3))
                loss =loss_kd(tag_scores,tags,teacher_outputs,alpha,T)
                #loss = loss_function(tag_scores, tags)
                predict.append(tag_scores.argmax(dim=1).cpu().numpy())
                target.append(tags.cpu().numpy())        
                test_loss += loss * tags.size()[0]
                test_size += tags.size()[0]
                test_total_num_correct += torch.eq(tag_scores.argmax(dim=1), tags).sum()  
        val_accuracy = test_total_num_correct.item()/test_size
        val_loss = test_loss.item()/test_size
        predict = np.concatenate(predict, axis=0)
        target = np.concatenate(target, axis=0)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            number_epoch_until_best_accuracy = epoch_idx
            training_time_until_best = training_time
            if Scr == 4:
                torch.save(model.state_dict(), 'mlp_smote4_nodes_tf_kd.pt')
                # Save Model
                state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),}
                savepath='mlp_smote4_nodes_tf_kd.pth'
                torch.save(state,savepath)
            if Scr == 5:
                torch.save(model.state_dict(), 'mlp_smote5_nodes_tf_kd.pt')
                # Save Model
                state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),}
                savepath='mlp_smote5_nodes_tf_kd.pth'
                torch.save(state,savepath)
                
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            number_epoch_until_best_loss = epoch_idx
        
        print(f'epoch:{epoch_idx}, '
              f'training loss:{train_loss.item()/train_size: .5f}, '
              f'validation loss:{val_loss: .5f}, '
              f'accuracy: {val_accuracy: .4f}, '
              f'best accuracy: {best_val_accuracy: .4f}')

        if epoch_idx > number_epoch_until_best_accuracy+4 and epoch_idx > number_epoch_until_best_loss+4:
            break

    print(f'number of epochs: {number_epoch_until_best_accuracy}')

    # Transfer Learning
    # Warm-start the edge fine-tuning from the best first-phase checkpoint of
    # this link, then keep the best edge-phase weights in a separate file.
    source_checkpoint_paths = {
        4: 'mlp_smote4_nodes_tf_kd.pth',
        5: 'mlp_smote5_nodes_tf_kd.pth',
    }
    edge_best_paths = {
        4: 'mlp_scr4_nodes_tf_kd_2.pt',
        5: 'mlp_scr5_nodes_tf_kd_2.pt',
    }
    if Scr not in source_checkpoint_paths:
        # Without this guard `checkpoint` would be undefined (NameError) or a
        # stale object left over from an earlier notebook cell.
        raise ValueError(f'Transfer learning is only configured for Scr 4 or 5, got {Scr!r}')
    checkpoint = torch.load(source_checkpoint_paths[Scr])

    model.load_state_dict(checkpoint['state_dict'])
    model.cuda()

    edge_train_dataloader = data.DataLoader(
        data.TensorDataset(edge_train_x[i], edge_train_y[i]),
        batch_size=TRAIN_BATCH_SIZE, shuffle=True,
        num_workers=16, pin_memory=True)

    edge_val_dataloader = data.DataLoader(
        data.TensorDataset(edge_val_x[i], edge_val_y[i]),
        batch_size=VAL_BATCH_SIZE, shuffle=False,
        num_workers=16, pin_memory=True)

    start_time_1 = time.time()
    best_val_accuracy_1 = 0
    best_val_loss_1 = 100
    number_epoch_until_best_1 = 1  # not read by the loop; kept so the name still exists
    # The early-stopping test reads these two trackers; reset them per link so
    # they are never undefined nor carried over from the previous link.
    number_epoch_until_best_accuracy_1 = 0
    number_epoch_until_best_loss_1 = 0
    training_time_1 = 0
    training_time_until_best_1 = 0
    average_time_per_epoch_1 = 0
    for epoch_idx in range(NUM_EPOCHS_EDGE):

        progress_edge_training_epoch = tqdm(
            edge_train_dataloader,
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Training',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_train_dataloader), smoothing=.9, disable=True)

        progress_edge_validation_epoch = tqdm(
            edge_val_dataloader,
            desc=f'Link {i}, Epoch {epoch_idx+1}/{NUM_EPOCHS_EDGE}, Validation',
            miniters=1, ncols=88, position=0,
            leave=True, total=len(edge_val_dataloader), smoothing=.9, disable=True)

        # ---- training pass ----
        train_loss_1 = 0
        train_size_1 = 0
        epoch_start_1 = time.time()
        model.train()
        for idx, (sentence, tags) in enumerate(progress_edge_training_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            optimizer.zero_grad()
            tag_scores = model(sentence)
            loss = loss_function(tag_scores, tags)
            loss.backward()
            optimizer.step()
            # detach() so the running sum does not chain every batch's autograd
            # graph together for the whole epoch.
            train_loss_1 += loss.detach() * tags.size()[0]
            train_size_1 += tags.size()[0]
        # Wait for queued GPU work so the wall-clock measurement covers it.
        torch.cuda.synchronize()
        training_time_1 += time.time() - epoch_start_1
        average_time_per_epoch_1 = training_time_1 / (epoch_idx + 1)

        # ---- validation pass ----
        test_loss_1 = 0
        test_size_1 = 0
        test_total_num_correct_1 = 0
        predict_1 = []
        target_1 = []
        model.eval()
        with torch.no_grad():
            for idx, (sentence, tags) in enumerate(progress_edge_validation_epoch):
                sentence = sentence.cuda()
                tags = tags.cuda()
                tag_scores = model(sentence)
                loss = loss_function(tag_scores, tags)
                predict_1.append(tag_scores.argmax(dim=1).cpu().numpy())
                target_1.append(tags.cpu().numpy())
                test_loss_1 += loss * tags.size()[0]
                test_size_1 += tags.size()[0]
                test_total_num_correct_1 += torch.eq(tag_scores.argmax(dim=1), tags).sum()
        val_accuracy_1 = test_total_num_correct_1.item()/test_size_1
        val_loss_1 = test_loss_1.item()/test_size_1
        predict_1 = np.concatenate(predict_1, axis=0)
        target_1 = np.concatenate(target_1, axis=0)

        # New best edge accuracy: record it and snapshot the weights.
        if val_accuracy_1 > best_val_accuracy_1:
            best_val_accuracy_1 = val_accuracy_1
            number_epoch_until_best_accuracy_1 = epoch_idx
            training_time_until_best_1 = training_time_1
            torch.save(model.state_dict(), edge_best_paths[Scr])

        if val_loss_1 < best_val_loss_1:
            best_val_loss_1 = val_loss_1
            number_epoch_until_best_loss_1 = epoch_idx

        print(f'epoch:{epoch_idx}, '
              f'Main training loss:{train_loss_1.item()/train_size_1: .5f}, '
              f'Main validation loss:{val_loss_1: .5f}, '
              f'Main accuracy: {val_accuracy_1: .4f}, '
              f'Main best accuracy: {best_val_accuracy_1: .4f}')

        # Early stopping: more than 4 epochs without improving either the best
        # accuracy or the best loss.
        if epoch_idx > number_epoch_until_best_accuracy_1+4 and epoch_idx > number_epoch_until_best_loss_1+4:
            break

    print(f'number of epochs: {number_epoch_until_best_accuracy_1}')
    # Restore the best edge-phase weights before testing.
    model.load_state_dict(torch.load(edge_best_paths[Scr]))

    model.cuda()

    # Test

    test_dataloader = data.DataLoader(
        data.TensorDataset(test_x[i], test_y[i]),
        batch_size=VAL_BATCH_SIZE, shuffle=False,
        num_workers=16, pin_memory=True)

    progress_test_epoch = tqdm(
        test_dataloader,
        desc=f'Link {i}, Test',
        miniters=1, ncols=88, position=0,
        leave=True, total=len(test_dataloader), smoothing=.9)

    # Collect predictions, labels and the first input column for this link.
    predict = []
    target = []
    snrss = []
    model.eval()
    with torch.no_grad():
        for idx, (sentence, tags) in enumerate(progress_test_epoch):
            sentence = sentence.cuda()
            tags = tags.cuda()
            tag_scores = model(sentence)
            predict.append(tag_scores.argmax(dim=1).cpu().numpy())
            target.append(tags.cpu().numpy())
            snrss.append(sentence.cpu().numpy()[:,0])

    prd.append(np.concatenate(predict, axis=0))
    trg.append(np.concatenate(target, axis=0))
    snr.append(np.concatenate(snrss, axis=0))

In [None]:
if Scr == 4:
    # Bundle the per-link test outputs. The first input column (kept in `snr`)
    # is rescaled with scr4_std / scr4_mean — presumably undoing the input
    # standardisation; confirm against where those constants are computed.
    mlp_scr4_edge_res = {
        'mlp_scr4_prd': prd,
        'mlp_scr4_trg': trg,
        'mlp_scr4_snr': [column * scr4_std + scr4_mean for column in snr],
    }

    outname = 'mlp_scr4_nodes_tf_kd_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)

    # Keep the previous run's results as a single ".old" backup.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as fh:
        pickle.dump(mlp_scr4_edge_res, fh)

In [None]:
if Scr == 5:
    # Bundle the per-link test outputs. The first input column (kept in `snr`)
    # is rescaled with scr5_std / scr5_mean — presumably undoing the input
    # standardisation; confirm against where those constants are computed.
    mlp_scr5_edge_res = {
        'mlp_scr5_prd': prd,
        'mlp_scr5_trg': trg,
        'mlp_scr5_snr': [column * scr5_std + scr5_mean for column in snr],
    }

    outname = 'mlp_scr5_nodes_tf_kd_res.pickle'
    outfile = os.path.join(os.getcwd(), outname)

    # Keep the previous run's results as a single ".old" backup.
    if os.path.exists(outfile):
        os.replace(outfile, outfile + ".old")

    with open(outfile, 'wb') as fh:
        pickle.dump(mlp_scr5_edge_res, fh)