# Imports:

In [None]:
from connect import *

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler

In [None]:
import NAR.utils as utils
import NAR.data_management as dm
import NAR.plotting_utils as pu
import NAR.masking as masking
import NAR.models as models

# Reload modules (DEV):

In [None]:
reload_modules([utils, dm, pu, masking, models])

# Checkpoint 1:

In [None]:
df = pd.read_csv(
    'data/pre_processed_unsw.csv',
    low_memory=False)

df = df[df['Micro Attack'] != -1]
df['Micro Attack'] = df['Attack'] + '_' +df['Micro Attack'].astype(str)

In [None]:
categorical_columns_dict = {
       'PROTOCOL': True, 'L7_PROTO': True, 'IN_BYTES': False,
       'IN_PKTS': False, 'OUT_BYTES': False, 'OUT_PKTS': False, 'TCP_FLAGS': True, 'CLIENT_TCP_FLAGS': True,
       'SERVER_TCP_FLAGS': True, 'FLOW_DURATION_MILLISECONDS': False, 'DURATION_IN': False,
       'DURATION_OUT': False, 'MIN_TTL': False, 'MAX_TTL': False, 'LONGEST_FLOW_PKT': False,
       'SHORTEST_FLOW_PKT': False, 'MIN_IP_PKT_LEN': False, 'MAX_IP_PKT_LEN': False,
       'SRC_TO_DST_SECOND_BYTES': False, 'DST_TO_SRC_SECOND_BYTES': False,
       'RETRANSMITTED_IN_BYTES': False, 'RETRANSMITTED_IN_PKTS': False,
       'RETRANSMITTED_OUT_BYTES': False, 'RETRANSMITTED_OUT_PKTS': False,
       'SRC_TO_DST_AVG_THROUGHPUT': False, 'DST_TO_SRC_AVG_THROUGHPUT': False,
       'NUM_PKTS_UP_TO_128_BYTES': False, 'NUM_PKTS_128_TO_256_BYTES': False,
       'NUM_PKTS_256_TO_512_BYTES': False, 'NUM_PKTS_512_TO_1024_BYTES': False,
       'NUM_PKTS_1024_TO_1514_BYTES': False, 'TCP_WIN_MAX_IN': False, 'TCP_WIN_MAX_OUT': False,
       'ICMP_TYPE': True, 'ICMP_IPV4_TYPE': True, 'DNS_QUERY_ID': False, 'DNS_QUERY_TYPE': True,
       'DNS_TTL_ANSWER': True, 'FTP_COMMAND_RET_CODE': True}

categorical_columns = pd.Series(categorical_columns_dict)
categorical_columns = df.columns[list(np.where(np.array(categorical_columns)==True)[0])].tolist()
cont_columns = list(set(df.columns.tolist()) - set(categorical_columns) - set(['Attack', 'Micro Attack']))

In [None]:
X = df.drop(['Attack', 'Micro Attack'], axis=1)
y = df[['Attack', 'Micro Attack']]

categorical_indicator = X.dtypes == 'O'

categorical_columns = X.columns[list(np.where(np.array(categorical_indicator)==True)[0])].tolist()
cont_columns = list(set(X.columns.tolist()) - set(categorical_columns))

cat_idxs = list(np.where(np.array(categorical_indicator)==True)[0])
con_idxs = list(set(range(len(X.columns))) - set(cat_idxs))

# Categories and target classes to natural numbers:
cat_dims = []
for col in categorical_columns:
    l_enc = LabelEncoder()
    X[col] = l_enc.fit_transform(X[col].values)
    cat_dims.append(len(l_enc.classes_))

X = X.values

scaler = MinMaxScaler()
X = scaler.fit_transform(X)

y = y.values
macro_l_enc = LabelEncoder()
micro_l_enc = LabelEncoder()
macro_y = macro_l_enc.fit_transform(y[:, 0])
micro_y = micro_l_enc.fit_transform(y[:, 1])

# num of classes:
num_macro_classes = len(np.unique(macro_y))
num_micro_classes = len(np.unique(micro_y))

In [None]:
# Converting to suitable format:
data = pd.DataFrame(X)
data['Macro Label'] = macro_l_enc.inverse_transform(macro_y)
data['Micro Label'] = micro_l_enc.inverse_transform(micro_y)

micro_zdas = [
        'Shellcode_0',                 # Type A
        'Shellcode_1',                 # Type A
        'Shellcode_2',                 # Type A
        'Worms_0',                     # Type A
        'Worms_1',                     # Type A
        'Worms_2',                     # Type A
        'Fuzzers_1',                   # Type B
        'Reconnaissance_1',            # Type B
        'Backdoor_0',                  # Type B
        'Generic_0',                   # Type B
        'DoS_0',                       # Type B
        'Exploits_0',                  # Type B
        'Analysis_0',                  # Type B
        'Generic_1',                   # Type B
        ]

micro_type_A_ZdAs = [
        'Shellcode_0',           # Type A
        'Shellcode_1',           # Type A
        'Shellcode_2',           # Type A
        'Worms_0',               # Type A
        'Worms_1',               # Type A
        'Worms_2',               # Type A
        ]

micro_type_B_ZdAs = [
        'Fuzzers_1',                   # Type B
        'Reconnaissance_1',            # Type B
        'Backdoor_0',                  # Type B
        'Generic_0',                   # Type B
        'DoS_0',                       # Type B
        'Exploits_0',                  # Type B
        'Analysis_0',                  # Type B
        'Generic_1',                   # Type B
        ]

train_type_B_micro_classes = [
        'Fuzzers_1',                   # Type B
        'Reconnaissance_1',            # Type B
        'Backdoor_0',                  # Type B
        'Generic_0',                   # Type B
        ]

test_type_B_micro_classes = [
        'DoS_0',                       # Type B
        'Exploits_0',                  # Type B
        'Analysis_0',                  # Type B
        'Generic_1',                   # Type B
        ]


test_type_A_macro_classes = [
        'Shellcode'            # Type A
        ]

train_type_A_macro_classes = [
        'Worms'                # Type A
        ]


data = masking.mask_real_data_lowdim(
    data=data,
    micro_zdas=micro_zdas,
    micro_type_A_ZdAs=micro_type_A_ZdAs,
    micro_type_B_ZdAs=micro_type_B_ZdAs
    )

train_data, test_data = masking.split_real_data(
    data,
    train_type_B_micro_classes,
    test_type_B_micro_classes,
    test_type_A_macro_classes,
    train_type_A_macro_classes
    )

micro_classes = data['Micro Label'].unique()
macro_classes = data['Macro Label'].unique()

In [None]:
# Train test split is psuedo-randomic. for this reason we do it just once.
train_data.to_csv('data/unsw_train.csv', index=0)
test_data.to_csv('data/unsw_test.csv', index=0)

# Checkpoint 2:

In [None]:
train_data = pd.read_csv('data/unsw_train.csv', low_memory=False)
test_data = pd.read_csv('data/unsw_test.csv', low_memory=False)

# Training:

## helper code:

In [None]:
def save_stuff(prefix):
    torch.save(
        encoder.state_dict(),
        prefix+'enc.pt')
    torch.save(
        decoder_1_b.state_dict(),
        prefix+'_dec_b.pt')

In [None]:
def first_phase_simple(
        sample_batch):

    global cs_cm_1
    global os_cm_1
    global metrics_dict

    # get masks: THESE ARE NOT COMPLEMETARY!
    zda_mask, \
        known_classes_mask, \
        unknown_1_mask, \
        active_query_mask = utils.get_masks_1(
            sample_batch[1],
            N_QUERY,
            device=device)

    # get one_hot_labels:
    oh_labels = utils.get_oh_labels(
        decimal_labels=sample_batch[1][:, 1].long(),
        total_classes=max_prototype_buffer_micro,
        device=device)

    # mask labels:
    oh_masked_labels = utils.get_one_hot_masked_labels(
        oh_labels,
        unknown_1_mask,
        device=device)

    # encoding input space:
    encoded_inputs = encoder(
        sample_batch[0].float())

    # processing
    decoded_1, hiddens_1, predicted_kernel = processor_1(
        encoded_inputs,
        oh_masked_labels)

    # semantic kernel:
    semantic_kernel = oh_labels @ oh_labels.T
    # Processor regularization:
    proc_1_reg_loss = utils.get_kernel_kernel_loss(
        semantic_kernel,
        predicted_kernel,
        a_w=attr_w,
        r_w=rep_w)

    # Transform lables for Few_shot Closed-set classif.
    # compatible with the design of models.get_centroids functions,
    # wich is called by our GAT processors.
    unique_labels, transformed_labels = sample_batch[1][:, 1][active_query_mask].unique(
        return_inverse=True)

    # closed set classification
    dec_1_loss_a = decoder_1a_criterion(
        decoded_1[active_query_mask],
        transformed_labels)

    # Detach closed from open set gradients
    input_for_os_dec = decoded_1.detach()
    input_for_os_dec.requires_grad = True

    # Unknown cluster prediction:
    predicted_unknown_1s = decoder_1_b(
        scores=input_for_os_dec[unknown_1_mask]
        )

    # open-set loss:
    dec_1_loss_b = decoder_1b_criterion(
        predicted_unknown_1s,
        zda_mask[unknown_1_mask].float().unsqueeze(-1))

    # inverse transform cs preds
    it_preds = utils.inverse_transform_preds(
        transormed_preds=decoded_1[active_query_mask],
        real_labels=unique_labels,
        real_class_num=max_prototype_buffer_micro)

    #
    # REPORTING:
    #

    # Closed set confusion matrix
    cs_cm_1 += utils.efficient_cm(
        preds=it_preds.detach(),
        targets=sample_batch[1][:, 1][active_query_mask].long())

    # Open set confusion matrix
    os_cm_1 += utils.efficient_os_cm(
        preds=(predicted_unknown_1s.detach() > 0.5).long(),
        targets=zda_mask[unknown_1_mask].long()
        )

    # accuracies:
    CS_acc = utils.get_acc(
        logits_preds=it_preds,
        oh_labels=oh_labels[active_query_mask])

    OS_acc = utils.get_binary_acc(
        logits=predicted_unknown_1s.detach(),
        labels=zda_mask[unknown_1_mask].float().unsqueeze(-1))

    OS_b_acc = utils.get_balanced_accuracy(
                os_cm=os_cm_1,
                n_w=balanced_acc_n_w
                )

    # for reporting:
    metrics_dict['losses_1a'].append(dec_1_loss_a.item())
    metrics_dict['proc_reg_loss1'].append(proc_1_reg_loss.item())
    metrics_dict['CS_accuracies'].append(CS_acc.item())
    metrics_dict['losses_1b'].append(dec_1_loss_b.item())
    metrics_dict['OS_accuracies'].append(OS_acc.item())
    metrics_dict['OS_B_accuracies'].append(OS_b_acc.item())

    # Processor loss:
    proc_1_loss = dec_1_loss_a + proc_1_reg_loss

    return proc_1_loss, \
        dec_1_loss_b, \
        hiddens_1, \
        decoded_1

In [None]:
def second_phase_simple(
        sample_batch,
        hiddens_1):

    global cs_cm_2
    global os_cm_2
    global metrics_dict

    # get masks: THESE ARE NOT COMPLEMETARY!
    type_A_mask, known_macro_classes_mask, \
        unknown_2_mask, active_query_mask_2 = utils.get_masks_2(
            sample_batch[1],
            N_QUERY,
            device=device)

    # get one_hot_labels:
    oh_labels = utils.get_oh_labels(
        decimal_labels=sample_batch[1][:, 0].long(),
        total_classes=max_prototype_buffer_macro,
        device=device)

    # mask labels:
    oh_masked_labels = utils.get_one_hot_masked_labels(
        oh_labels,
        unknown_2_mask,
        device=device)

    decoded_2, hiddens_2, predicted_kernel_2 = processor_2(
        hiddens_1,
        oh_masked_labels)

    # semantic kernel:
    semantic_kernel_2 = oh_labels @ oh_labels.T
    # Processor regularization:
    proc_2_reg_loss = utils.get_kernel_kernel_loss(
        semantic_kernel_2,
        predicted_kernel_2,
        a_w=attr_w,
        r_w=rep_w)

    unique_macro_labels, transformed_labels_2 = sample_batch[1][:, 0][active_query_mask_2].unique(
        return_inverse=True)

    # Closed set: should learn to associate type B's to corr. macro cluster.
    # geometrical "break" in the real-data case. (GRadients 2A)
    dec_2_loss_a = decoder_2a_criterion(
        decoded_2[active_query_mask_2],
        transformed_labels_2)

    input_for_os_dec_2 = decoded_2.detach()
    input_for_os_dec_2.requires_grad = True

    # Unknown cluster prediction:
    predicted_unknown_2s = decoder_2_b(
        scores=input_for_os_dec_2[unknown_2_mask]
        )

    # open-set loss:
    dec_2_loss_b = decoder_2b_criterion(
        predicted_unknown_2s,
        type_A_mask[unknown_2_mask].float().unsqueeze(-1))

    # inverse transform cs preds
    it_preds = utils.inverse_transform_preds(
        transormed_preds=decoded_2[active_query_mask_2],
        real_labels=unique_macro_labels,
        real_class_num=max_prototype_buffer_macro)

    # Closed set confusion matrix
    cs_cm_2 += utils.efficient_cm(
        preds=it_preds.detach(),
        targets=sample_batch[1][:, 0][active_query_mask_2].long(),
        )

    # Open set confusion matrix
    os_cm_2 += utils.efficient_os_cm(
        preds=(predicted_unknown_2s.detach() > 0.5).long(),
        targets=type_A_mask[unknown_2_mask].long()
        )

    # accuracies:
    CS_acc_2 = utils.get_acc(
        logits_preds=it_preds,
        oh_labels=oh_labels[active_query_mask_2])

    OS_acc_2 = utils.get_binary_acc(
        logits=predicted_unknown_2s.detach(),
        labels=type_A_mask[unknown_2_mask].float().unsqueeze(-1))

    OS_2_B_acc = utils.get_balanced_accuracy(
                os_cm=os_cm_2,
                n_w=balanced_acc_n_w
                )

    proc_2_loss = dec_2_loss_a + proc_2_reg_loss

    # for reporting:
    metrics_dict['losses_2a'].append(dec_2_loss_a.item())
    metrics_dict['proc_reg_loss2'].append(proc_2_reg_loss.item())
    metrics_dict['losses_2b'].append(dec_2_loss_b.item())
    metrics_dict['CS_2_accuracies'].append(CS_acc_2.item())
    metrics_dict['OS_2_accuracies'].append(OS_acc_2.item())
    metrics_dict['OS_2_B_accuracies'].append(OS_2_B_acc.item())

    return proc_2_loss, \
        dec_2_loss_b, \
        hiddens_2, \
        decoded_2

## init data:

In [44]:
natural_inputs_dim = 39
save = True
wb = True

In [45]:
# Generate Data
torch_seed = 1234
torch.manual_seed(torch_seed)

#
#
# Initialize
#
#
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset and Dataloader:
train_dataset = dm.RealFewShotDataset_LowDim(
    features=train_data.drop(columns=[
        'Micro Label',
        'Macro Label',
        'ZdA',
        'Type_A_ZdA',
        'Type_B_ZdA']).values,
    df=train_data)

test_dataset = dm.RealFewShotDataset_LowDim(
    features=test_data.drop(columns=[
        'Micro Label',
        'Macro Label',
        'ZdA',
        'Type_A_ZdA',
        'Type_B_ZdA']).values,
    df=test_data)

# Number of classes per task :
# two of them are ZdAs, one is a type B and the other a type A
N_WAY = 5
N_SHOT = 5   # Number of samples per class in the support set
N_QUERY = 15  # Number of samples per class in the query set

n_train_tasks = 500    # For speedy tests, reduce here...
n_eval_tasks = 50     # For speedy tests, reduce here...


num_of_test_classes = len(test_dataset.micro_classes)
num_of_train_classes = len(train_dataset.micro_classes)

train_loader = DataLoader(
    dataset=train_dataset,
    sampler=dm.FewShotSampler(
                dataset=train_dataset,
                n_tasks=n_train_tasks,
                classes_per_it=N_WAY,
                k_shot=N_SHOT,
                q_shot=N_QUERY),
    num_workers=4,
    drop_last=True,
    collate_fn=dm.convenient_cf)

test_loader = DataLoader(
    dataset=test_dataset,
    sampler=dm.FewShotSampler(
                dataset=test_dataset,
                n_tasks=n_eval_tasks,
                classes_per_it=N_WAY,
                k_shot=N_SHOT,
                q_shot=N_QUERY),
    num_workers=4,
    drop_last=True,
    collate_fn=dm.convenient_cf)

# reproducibility
train_loader.sampler.reset()
test_loader.sampler.reset()

## init architectures:

In [48]:
###
# Hyper params:
###
max_prototype_buffer_micro = 40
max_prototype_buffer_macro = 10
lr = 0.001
# training parameters
n_epochs = 60
norm = "batch"
dropout = 0.1
patience = 20
lambda_os = 1

processor_attention_heads = 8
h_dim = 1024
report_step_frequency = 100

pos_weight_1 = 2.5
pos_weight_2 = 5

architectures = 'GATV5 Confidence Dec'
balanced_acc_n_w = 0.5
attr_w = 1
rep_w = 1
run_name = f'UNSW from scratch'

if wb:
    wandb.init(project='Nero_1.1',
               name=run_name,
               config={"N_SHOT": N_SHOT,
                       "N_QUERY": N_QUERY,
                       "N_WAY": N_WAY,
                       "num_of_test_classes": num_of_test_classes,
                       "num_of_train_classes": num_of_train_classes,
                       "train_batch_size": N_WAY * (N_SHOT + N_QUERY),
                       "len(train_loader)": train_loader.sampler.n_tasks,
                       "len(test_dataset)": test_loader.sampler.n_tasks,
                       "max_prototype_buffer_micro": max_prototype_buffer_micro,
                       "max_prototype_buffer_macro": max_prototype_buffer_macro,
                       "device": device,
                       "natural_inputs_dim": natural_inputs_dim,
                       "h_dim": h_dim,
                       "lr": lr,
                       "n_epochs": n_epochs,
                       "norm": norm,
                       "dropout": dropout,
                       "patience": patience,
                       'zdas': data[data.ZdA == True]['Micro Label'].unique(),
                       "lambda_os": lambda_os,
                       "positive_weight_1": pos_weight_1,
                       "positive_weight_2": pos_weight_2,
                       "architectures": architectures,
                       "balanced_acc_n_w": balanced_acc_n_w,
                       "attr_w": attr_w,
                       "rep_w": rep_w
                       })
else:
    print(run_name)

# Encoder
encoder = models.Encoder(
    in_features=natural_inputs_dim,
    out_features=h_dim,
    norm=norm,
    dropout=dropout,
    ).to(device)

# First phase:
processor_1 = models.GAT_V5_Processor(
                h_dim=h_dim,
                processor_attention_heads=processor_attention_heads,
                dropout=dropout,
                device=device
                ).to(device)

decoder_1a_criterion = nn.CrossEntropyLoss()

decoder_1_b = models.Confidence_Decoder(
                in_dim=N_WAY-2,  # Subtract 1 ZdA
                dropout=dropout,
                device=device
                ).to(device)

decoder_1b_criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.Tensor([pos_weight_1])).to(device)


# Second phase:
processor_2 = models.GAT_V5_Processor(
                h_dim=h_dim,
                processor_attention_heads=processor_attention_heads,
                dropout=dropout,
                device=device
                ).to(device)

decoder_2a_criterion = nn.CrossEntropyLoss().to(device)

decoder_2_b = models.Confidence_Decoder(
                in_dim=N_WAY-1,
                dropout=dropout,
                device=device
                ).to(device)

decoder_2b_criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.Tensor([pos_weight_2])).to(device)


params_for_processor_optimizer = \
        list(encoder.parameters()) + \
        list(processor_1.parameters()) + \
        list(processor_2.parameters())


processor_optimizer = optim.Adam(
    params_for_processor_optimizer,
    lr=lr)

params_for_os_optimizer = \
        list(decoder_1_b.parameters()) + \
        list(decoder_2_b.parameters())

os_optimizer = optim.Adam(
    params_for_os_optimizer,
    lr=lr)


# TRAINING
max_eval_TNR = torch.zeros(1, device=device)
epochs_without_improvement = 0

In [49]:
wandb.watch(processor_1)
wandb.watch(encoder)
wandb.watch(decoder_1_b)

[]

## Train (from scratch):

In [50]:
for epoch in tqdm(range(n_epochs)):

    # TRAIN
    encoder.train()
    processor_1.train()
    decoder_1_b.train()
    processor_2.train()
    decoder_2_b.train()

    # reset conf Mats
    cs_cm_1 = torch.zeros(
        [max_prototype_buffer_micro, max_prototype_buffer_micro],
        device=device)
    os_cm_1 = torch.zeros([2, 2], device=device)
    cs_cm_2 = torch.zeros(
        [max_prototype_buffer_macro, max_prototype_buffer_macro],
        device=device)
    os_cm_2 = torch.zeros([2, 2], device=device)

    # reset metrics dict
    metrics_dict = utils.reset_metrics_dict()

    # go!
    for batch_idx, sample_batch in enumerate(train_loader):
        # go to cuda:
        sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

        # PHASE 1
        proc_1_loss, \
            os_1_loss, \
            hiddens_1, \
            decoded_1 = first_phase_simple(
                                sample_batch)

        # PHASE 2
        proc_2_loss, \
            os_2_loss, \
            hiddens_2, \
            decoded_2 = second_phase_simple(
                sample_batch,
                hiddens_1)

        # Learning
        proc_loss = proc_1_loss + proc_2_loss
        processor_optimizer.zero_grad()
        proc_loss.backward()
        processor_optimizer.step()

        os_loss = os_1_loss + os_2_loss
        os_optimizer.zero_grad()
        os_loss.backward()
        os_optimizer.step()

        # Reporting
        step = batch_idx + (epoch * train_loader.sampler.n_tasks)

        if step % report_step_frequency == 0:
            utils.reporting_simple(
                'train',
                epoch,
                metrics_dict,
                step,
                wb,
                wandb)

    pu.super_plotting_function(
                phase='Training',
                labels=sample_batch[1].cpu(),
                hiddens_1=hiddens_1.detach().cpu(),
                hiddens_2=hiddens_2.detach().cpu(),
                scores_1=decoded_1.detach().cpu(),
                scores_2=decoded_2.detach().cpu(),
                cs_cm_1=cs_cm_1.cpu(),
                cs_cm_2=cs_cm_2.cpu(),
                os_cm_1=os_cm_1.cpu(),
                os_cm_2=os_cm_2.cpu(),
                wb=wb,
                wandb=wandb,
                complete_micro_classes=micro_classes,
                complete_macro_classes=macro_classes
                )

    with torch.inference_mode():

        # Evaluation
        encoder.eval()
        processor_1.eval()
        decoder_1_b.eval()
        processor_2.eval()
        decoder_2_b.eval()

        # reset conf Mats
        cs_cm_1 = torch.zeros(
            [max_prototype_buffer_micro, max_prototype_buffer_micro],
            device=device)
        os_cm_1 = torch.zeros([2, 2], device=device)
        cs_cm_2 = torch.zeros(
            [max_prototype_buffer_macro, max_prototype_buffer_macro],
            device=device)
        os_cm_2 = torch.zeros([2, 2], device=device)

        # reset metrics dict
        metrics_dict = utils.reset_metrics_dict()

        # go!
        for batch_idx, sample_batch in enumerate(test_loader):
            # go to cuda:
            sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

            # PHASE 1
            proc_1_loss, \
                os_1_loss, \
                hiddens_1, \
                decoded_1 = first_phase_simple(
                                    sample_batch)

            # PHASE 2
            proc_2_loss, \
                os_2_loss, \
                hiddens_2, \
                decoded_2 = second_phase_simple(
                    sample_batch,
                    hiddens_1)

            # Reporting
            step = batch_idx + (epoch * test_loader.sampler.n_tasks)

            if step % report_step_frequency == 0:
                utils.reporting_simple(
                    'eval',
                    epoch,
                    metrics_dict,
                    step,
                    wb,
                    wandb)

        pu.super_plotting_function(
                phase='Evaluation',
                labels=sample_batch[1].cpu(),
                hiddens_1=hiddens_1.detach().cpu(),
                hiddens_2=hiddens_2.detach().cpu(),
                scores_1=decoded_1.detach().cpu(),
                scores_2=decoded_2.detach().cpu(),
                cs_cm_1=cs_cm_1.cpu(),
                cs_cm_2=cs_cm_2.cpu(),
                os_cm_1=os_cm_1.cpu(),
                os_cm_2=os_cm_2.cpu(),
                wb=wb,
                wandb=wandb,
                complete_micro_classes=micro_classes,
                complete_macro_classes=macro_classes
            )

        # Checking for improvement
        curr_TNR = np.array(metrics_dict['OS_2_B_accuracies']).mean()

        if curr_TNR > max_eval_TNR:
            max_eval_TNR = curr_TNR
            epochs_without_improvement = 0
            print(f'saving models at epoch {epoch}')
            save_stuff(run_name)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f'Early stopping at episode {step}')
            if wb:
                wandb.log({'Early stopping at episode': step})
            break

print(f'max_eval_TNR: {max_eval_TNR}')

if wb:
    wandb.finish()

  0%|          | 0/60 [00:00<?, ?it/s]

saving models at epoch 0
saving models at epoch 1
saving models at epoch 2
saving models at epoch 4
saving models at epoch 9
Early stopping at episode 1499
max_eval_TNR: 0.773594981431961




VBox(children=(Label(value='24.762 MB of 24.811 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.9980…

0,1
CS1 accuracy_eval:,▇▃█▆▇██▅▅██▁██▃
CS1 accuracy_train:,▂▄▆▆▆▆▆▆▆▇▆▆█▆▇▇▇█▇▇▁▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇█▇
CS2 accuracy_eval:,▂▆▃██▅▆▇▂▁▆▅█▃▁
CS2 accuracy_train:,▁▄▆▆▇▆▇▆▇▇▇▇▇▇▇▇▇█▇▇▅▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇█▇
Early stopping at episode,▁
OS1 Bal. accuracy_eval:,▁▂▁▂▄▂▂▅▄▁▆▄█▅▅
OS1 Bal. accuracy_train:,▁▁▁▁▁▁▁▁▁▃▁▂▂▂▂▂▃▄▃▄▃▄▄▅▅▆▅▅▅▅▅▆▅▄▆▆▆▆█▆
OS1 accuracy_eval:,▁▂▁▂▄▂▂▅▅▁▆▅█▆▅
OS1 accuracy_train:,▁▁▁▁▁▁▁▁▁▃▁▂▂▂▂▂▃▄▃▄▃▄▄▄▅▆▅▅▅▅▅▆▅▄▆▆▆▆█▆
OS2 Bal. accuracy_eval:,▁▆█▂▂▇▁▂▂▄▂▂▇█▂

0,1
CS1 accuracy_eval:,0.77778
CS1 accuracy_train:,0.9642
CS2 accuracy_eval:,0.63333
CS2 accuracy_train:,0.95033
Early stopping at episode,1499.0
OS1 Bal. accuracy_eval:,0.725
OS1 Bal. accuracy_train:,0.84675
OS1 accuracy_eval:,0.74118
OS1 accuracy_train:,0.84002
OS2 Bal. accuracy_eval:,0.5


In [None]:
wandb.finish()

## Train (fine tune a pre-trained Neural algorithmic processor):

In [None]:
processor_1.load_state_dict(torch.load('GENNARO-processor.pt'))
processor_1.eval()

# Freeze the pre-trained processor
for param in processor_1.parameters():
    param.requires_grad = False

In [None]:
for epoch in tqdm(range(n_epochs)):

    # TRAIN
    encoder.train()
    decoder_1_b.train()

    # reset conf Mats
    cs_cm_1 = torch.zeros(
        [max_prototype_buffer, max_prototype_buffer],
        device=device)
    os_cm_1 = torch.zeros([2, 2], device=device)

    # reset metrics dict
    metrics_dict = utils.reset_metrics_dict()

    # go!
    for batch_idx, sample_batch in enumerate(train_loader):
        # go to cuda:
        sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

        # PHASE 1
        proc_loss, \
            os_loss, \
            hiddens_1, \
            decoded_1 = first_phase_simple(
                                sample_batch)

        # Learning
        processor_optimizer.zero_grad()
        proc_loss.backward()
        processor_optimizer.step()

        os_loss = os_loss
        os_optimizer.zero_grad()
        os_loss.backward()
        os_optimizer.step()

        # Reporting
        step = batch_idx + (epoch * len(train_loader))

        if step % report_step_frequency == 0:
            utils.reporting_gennaro(
                'train',
                epoch,
                metrics_dict,
                step,
                wb,
                wandb)

    pu.super_plotting_function_gennaro(
                phase='Training',
                labels=sample_batch[1].cpu(),
                hiddens_1=hiddens_1.detach().cpu(),
                scores_1=decoded_1.detach().cpu(),
                cs_cm_1=cs_cm_1.cpu(),
                os_cm_1=os_cm_1.cpu(),
                wb=wb,
                wandb=wandb,
                complete_classes=classes,
                )

    with torch.inference_mode():

        # Evaluation
        encoder.eval()
        decoder_1_b.eval()

        # reset conf Mats
        cs_cm_1 = torch.zeros(
            [max_prototype_buffer, max_prototype_buffer],
            device=device)
        os_cm_1 = torch.zeros([2, 2], device=device)

        # reset metrics dict
        metrics_dict = utils.reset_metrics_dict()

        # go!
        for batch_idx, sample_batch in enumerate(test_loader):
            # go to cuda:
            sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

            # PHASE 1
            proc_1_loss, \
                os_1_loss, \
                hiddens_1, \
                decoded_1 = first_phase_simple(
                                    sample_batch)

            # Reporting
            step = batch_idx + (epoch * len(test_loader))

            if step % report_step_frequency == 0:
                utils.reporting_gennaro(
                    'eval',
                    epoch,
                    metrics_dict,
                    step,
                    wb,
                    wandb)

        pu.super_plotting_function_gennaro(
                phase='Evaluation',
                labels=sample_batch[1].cpu(),
                hiddens_1=hiddens_1.detach().cpu(),
                scores_1=decoded_1.detach().cpu(),
                cs_cm_1=cs_cm_1.cpu(),
                os_cm_1=os_cm_1.cpu(),
                wb=wb,
                wandb=wandb,
                complete_classes=classes,
            )

        # Checking for improvement
        curr_TNR = utils.get_balanced_accuracy(
                pos_labels=sample_batch[1][:, 1].long(),
                n_tasks=n_eval_tasks,
                os_cm=os_cm_1,
                n_w=balanced_acc_n_w
                )

        if curr_TNR > max_eval_TNR:
            max_eval_TNR = curr_TNR
            epochs_without_improvement = 0
            save_stuff(run_name)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f'Early stopping at episode {step}')
            if wb:
                wandb.log({'Early stopping at episode': step})
            break

if wb:
    wandb.finish()

In [None]:
wandb.finish()