# Imports:

In [None]:
from connect import *

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
import pickle

In [None]:
import NAR.utils as utils
import NAR.data_management as dm
import NAR.plotting_utils as pu
import NAR.masking as masking
import NAR.models as models

# Reload modules (DEV):

In [None]:
# DEV convenience: re-load the NAR project modules so edits to them are
# picked up without restarting the kernel.
# NOTE(review): reload_modules is not defined here; it presumably comes from
# the wildcard `from connect import *` — confirm.
reload_modules([utils, dm, pu, masking, models])

# Checkpoint:

### Edge-IIoTset dataset:

In [None]:
# Load the pre-processed Edge-IIoTset table and discard every row whose
# Attack_type is 'MITM_-1'.
df = (
    pd.read_csv('data/pre_processed_iiotset.csv', low_memory=False)
    .loc[lambda frame: frame['Attack_type'] != 'MITM_-1']
)

In [None]:
# Split the table into features (X) and the two label levels (y):
# macro category and micro attack type.
X = df.drop(['Attack_label', 'attack_macro_cat', 'Attack_type'], axis=1)
y = df[['attack_macro_cat', 'Attack_type']]

# A column is categorical when it holds text: classic object dtype, or the
# dedicated pandas string dtype (used for text columns by newer pandas, where
# a plain `== 'O'` test would silently miss them).
categorical_indicator = X.dtypes.apply(
    lambda dt: dt == object or isinstance(dt, pd.StringDtype))
cat_mask = categorical_indicator.to_numpy(dtype=bool)

# Names and positions, kept in the original column order. A set difference
# would make the order of `cont_columns` depend on string-hash randomization
# and therefore vary between runs.
categorical_columns = X.columns[cat_mask].tolist()
cont_columns = X.columns[~cat_mask].tolist()

cat_idxs = list(np.flatnonzero(cat_mask))
con_idxs = np.flatnonzero(~cat_mask).tolist()

# Categories and target classes to natural numbers:
cat_dims = []
for col in categorical_columns:
    l_enc = LabelEncoder()
    X[col] = l_enc.fit_transform(X[col].values)
    cat_dims.append(len(l_enc.classes_))

X = X.values

# NOTE(review): the scaler is fitted on the whole dataset (label-encoded
# categorical columns included) before the train/test split performed later,
# so test-set statistics leak into the scaling — confirm this is intended.
scaler = MinMaxScaler()
X = scaler.fit_transform(X)

y = y.values
macro_l_enc = LabelEncoder()
micro_l_enc = LabelEncoder()
macro_y = macro_l_enc.fit_transform(y[:, 0])
micro_y = micro_l_enc.fit_transform(y[:, 1])

# num of classes:
num_macro_classes = len(np.unique(macro_y))
num_micro_classes = len(np.unique(micro_y))

In [None]:
# Display: non-null counts of every column per (macro category, attack type)
# pair, i.e. a quick look at how many rows each class has.
df.groupby(['attack_macro_cat', 'Attack_type']).count()

In [None]:
# Converting to suitable format:
# Converting to suitable format: scaled features plus the string labels.
data = pd.DataFrame(X)
data['Macro Label'] = macro_l_enc.inverse_transform(macro_y)
data['Micro Label'] = micro_l_enc.inverse_transform(micro_y)

# Zero-day attacks (ZdAs) simulated by masking, in two groups.
# Presumably type A ZdAs come from macro categories that are unseen as a
# whole, while type B ZdAs are new attack types of known macro categories —
# confirm against NAR.masking.
micro_type_A_ZdAs = [
        'MITM_0',
        'MITM_1',
        'MITM_2',
        'MITM_3',
        'Fingerprinting',
        'Port_Scanning',
        'Vulnerability_scanner',
        ]

micro_type_B_ZdAs = [
        'Backdoor',
        'DDoS_ICMP',
        'DDoS_HTTP',
        'SQL_injection',
        'Uploading',
        'Password',
        ]

# Every ZdA: the type A ones first, then the type B ones.
micro_zdas = micro_type_A_ZdAs + micro_type_B_ZdAs

# Type B ZdAs assigned to the train / test side.
train_type_B_micro_classes = ['Backdoor', 'DDoS_ICMP', 'Uploading']
test_type_B_micro_classes = ['DDoS_HTTP', 'SQL_injection', 'Password']

# Type A macro categories assigned to the train / test side.
train_type_A_macro_classes = ['Information_Gathering']
test_type_A_macro_classes = ['MITM']

data = masking.mask_real_data_lowdim(
    data=data,
    micro_zdas=micro_zdas,
    micro_type_A_ZdAs=micro_type_A_ZdAs,
    micro_type_B_ZdAs=micro_type_B_ZdAs,
)

# Positional order (test type A before train type A) kept as in the
# original call.
train_data, test_data = masking.split_real_data(
    data,
    train_type_B_micro_classes,
    test_type_B_micro_classes,
    test_type_A_macro_classes,
    train_type_A_macro_classes,
)

micro_classes, macro_classes = (
    data[label_col].unique() for label_col in ('Micro Label', 'Macro Label'))

In [None]:
# Train test split is psuedo-randomic. for this reason we do it just once.
# The train/test split is pseudo-random, so it is generated once and cached
# on disk; "Checkpoint 2" reloads these files instead of re-splitting.
train_data.to_csv('data/iiotset_train.csv', index=False)
test_data.to_csv('data/iiotset_test.csv', index=False)

# "Checkpoint 2" also unpickles the label encoders, which nothing else in this
# notebook writes. Persist the encoders fitted above next to the split they
# belong to, so the two artefacts always match.
# NOTE(review): assumes masking keeps the 'Micro Label' / 'Macro Label'
# strings unchanged, so these encoders cover every label in the CSVs.
with open('data/micro_label_encoder.pkl', 'wb') as file:
    pickle.dump(micro_l_enc, file)

with open('data/macro_label_encoder.pkl', 'wb') as file:
    pickle.dump(macro_l_enc, file)

# Checkpoint 2:

In [None]:
# Reload the cached split written at the first checkpoint, so the
# pseudo-random split is not regenerated on every run.
train_data, test_data = (
    pd.read_csv(csv_path, low_memory=False)
    for csv_path in ('data/iiotset_train.csv', 'data/iiotset_test.csv'))

# Both splits stacked (index labels are not reset); used for the full class
# inventories.
data = pd.concat([train_data, test_data])
micro_classes, macro_classes = (
    data[label_col].unique() for label_col in ('Micro Label', 'Macro Label'))

# Label encoders handed to both datasets.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('data/micro_label_encoder.pkl', 'rb') as fh:
    micro_label_encoder = pickle.load(fh)

with open('data/macro_label_encoder.pkl', 'rb') as fh:
    macro_label_encoder = pickle.load(fh)

# Training:

## helper code:

In [None]:
def save_stuff(prefix):
    """Save the state dicts of the four global networks.

    Files are written as ``<prefix>_micro_classifier.pt``,
    ``<prefix>_macro_classifier.pt``, ``<prefix>_micro_anomaly_detector.pt``
    and ``<prefix>_macro_anomaly_detector.pt``, in that order.

    Args:
        prefix: path prefix (may include a directory) for the output files.
    """
    checkpoints = (
        ('_micro_classifier.pt', micro_classifier),
        ('_macro_classifier.pt', macro_classifier),
        ('_micro_anomaly_detector.pt', micro_anomaly_detector),
        ('_macro_anomaly_detector.pt', macro_anomaly_detector),
    )
    for suffix, network in checkpoints:
        torch.save(network.state_dict(), prefix + suffix)

In [None]:
def micro_classif(
        sample_batch,
        batch_idx,
        closed_set=True,
        open_set=True):
    """Run the micro (attack-type) head on one few-shot episode.

    Always updates the module-level closed-set confusion matrix ``cs_cm_1``
    and ``metrics_dict['CS_accuracies'][batch_idx]``, whatever the flags are.
    When ``open_set`` is True it also updates the open-set 2x2 confusion
    matrix ``os_cm_1`` and the open-set entries of ``metrics_dict``.

    Args:
        sample_batch: ``(features, labels)`` pair already moved to ``device``
            by the callers; ``labels[:, 1]`` is used as the encoded micro
            label.
        batch_idx: episode index; the row written in ``metrics_dict``.
        closed_set: if True, compute the cross-entropy loss of
            ``micro_classifier`` on the active query samples.
        open_set: if True, score the ``unknown_1_mask`` samples with
            ``micro_anomaly_detector`` and compute its BCE loss against
            ``zda_mask``.

    Returns:
        ``(micro_loss, dec_1_loss_b, micro_logits)``. A loss is ``None`` when
        its branch is disabled.
    """

    global cs_cm_1
    global os_cm_1
    global metrics_dict

    # get masks (THESE ARE NOT COMPLEMENTARY!). As used below: zda_mask is
    # the open-set target, known_classes_mask selects logit COLUMNS fed to
    # the detector, unknown_1_mask selects the samples the detector scores,
    # active_query_mask selects samples for the closed-set loss/metrics.
    zda_mask, \
        known_classes_mask, \
        unknown_1_mask, \
        active_query_mask = utils.get_masks_1(
            sample_batch[1],
            N_QUERY,
            device=device)

    dec_1_loss_b = None
    micro_loss = None
    
    # forward passes
    micro_logits = micro_classifier(sample_batch[0])
    
    # get one_hot_labels (width = classifier output size):
    oh_labels = utils.get_oh_labels(
        decimal_labels=sample_batch[1][:, 1].long(),
        total_classes=max_prototype_buffer_micro,
        device=device)
    
    # Closed set confusion matrix (accumulated over the epoch)
    cs_cm_1 = cs_cm_1 + utils.efficient_cm(
        preds=micro_logits[active_query_mask].detach(),
        targets=sample_batch[1][:, 1][active_query_mask].long())

    # accuracies:
    CS_acc = utils.get_acc(
        logits_preds=micro_logits[active_query_mask].detach(),
        oh_labels=oh_labels[active_query_mask])

    metrics_dict['CS_accuracies'][batch_idx] = CS_acc.detach()

    if closed_set:

        # loss computation
        micro_loss = micro_multiclass_error(
            micro_logits[active_query_mask],
            sample_batch[1][:, 1][active_query_mask].long())        

        # for reporting:
        metrics_dict['losses_1a'][batch_idx] = micro_loss.detach()

    if open_set:
        
        # Our decoding: the detector sees only the known-class logit columns.
        inputs_for_os = micro_logits[unknown_1_mask]
        predicted_unknown_1s = micro_anomaly_detector(
            inputs_for_os[:,known_classes_mask]).squeeze(-1)
        
        # open-set loss:
        # NOTE(review): BCEWithLogitsLoss treats predicted_unknown_1s as raw
        # logits, but the confusion matrix below thresholds it at 0.5 as if
        # it were a probability (logit threshold would be 0). One of the two
        # is likely wrong — check Confidence_Decoder's output layer.
        dec_1_loss_b = decoder_1b_criterion(
            predicted_unknown_1s,
            zda_mask[unknown_1_mask].float())

        # Open set confusion matrix
        os_cm_1 += utils.efficient_os_cm(
            preds=(predicted_unknown_1s.detach() > 0.5).long(),
            targets=zda_mask[unknown_1_mask].long()
            )
        
        # Disabled alternative (Bovenzi-style decoding on the max logit); the
        # string below is never executed.
        """
        
        # Bovenzi OS decoding:
        predicted_unknown_1s = micro_anomaly_detector(
            micro_logits[unknown_1_mask].max(1)[0])
        
        # open-set loss:
        dec_1_loss_b = decoder_1b_criterion(
            1 - predicted_unknown_1s,
            zda_mask[unknown_1_mask].float())

        # Open set confusion matrix
        os_cm_1 += utils.efficient_os_cm(
            preds=(predicted_unknown_1s.detach() == 0).long(),
            targets=zda_mask[unknown_1_mask].long()
            )
        """
        
        # Balanced accuracy of the cumulative os_cm_1, i.e. of the whole
        # epoch so far (not just this episode).
        OS_b_acc = utils.get_balanced_accuracy(
                    os_cm=os_cm_1,
                    n_w=balanced_acc_n_w
                    )

        metrics_dict['losses_1b'][batch_idx] = dec_1_loss_b.detach()
        metrics_dict['OS_B_accuracies'][batch_idx] = OS_b_acc

    return micro_loss, \
        dec_1_loss_b, \
        micro_logits

In [None]:
def macro_classif(
        sample_batch,
        batch_idx,
        closed_set=True,
        open_set=True):
    """Run the macro (attack-category) head on one few-shot episode.

    Mirrors ``micro_classif`` at the macro level. Always updates the
    module-level ``cs_cm_2`` and ``metrics_dict['CS_2_accuracies'][batch_idx]``.
    When ``open_set`` is True it also updates ``os_cm_2`` and the macro
    open-set entries of ``metrics_dict``.

    Args:
        sample_batch: ``(features, labels)`` pair already moved to ``device``
            by the callers; ``labels[:, 0]`` is used as the encoded macro
            label.
        batch_idx: episode index; the row written in ``metrics_dict``.
        closed_set: if True, compute the cross-entropy loss of
            ``macro_classifier`` on the active query samples.
        open_set: if True, score the ``unknown_2_mask`` samples with
            ``macro_anomaly_detector`` and compute its BCE loss against
            ``type_A_mask``.

    Returns:
        ``(macro_loss, dec_2_loss_b, macro_logits)``. A loss is ``None`` when
        its branch is disabled.
    """

    global cs_cm_2
    global os_cm_2
    global metrics_dict

    # get masks (THESE ARE NOT COMPLEMENTARY!). As used below: type_A_mask is
    # the open-set target, known_macro_classes_mask selects logit COLUMNS fed
    # to the detector, unknown_2_mask selects the samples it scores,
    # active_query_mask_2 selects samples for the closed-set loss/metrics.
    type_A_mask, known_macro_classes_mask, \
        unknown_2_mask, active_query_mask_2 = utils.get_masks_2(
            sample_batch[1],
            N_QUERY,
            device=device)

    dec_2_loss_b = None
    macro_loss = None
    
    # forward passes
    macro_logits = macro_classifier(sample_batch[0])
    
    # get one_hot_labels (width = classifier output size):
    oh_labels = utils.get_oh_labels(
        decimal_labels=sample_batch[1][:, 0].long(),
        total_classes=max_prototype_buffer_macro,
        device=device)
    
    # accuracies:
    CS_acc_2 = utils.get_acc(
        logits_preds=macro_logits[active_query_mask_2].detach(),
        oh_labels=oh_labels[active_query_mask_2])

    # Closed set confusion matrix (accumulated over the epoch)
    cs_cm_2 = cs_cm_2 + utils.efficient_cm(
        preds=macro_logits[active_query_mask_2].detach(),
        targets=sample_batch[1][:, 0][active_query_mask_2].long())
    
    metrics_dict['CS_2_accuracies'][batch_idx] = CS_acc_2.detach()

    if closed_set:
        
        # loss computation
        macro_loss = macro_multiclass_error(
            macro_logits[active_query_mask_2],
            sample_batch[1][:, 0][active_query_mask_2].long())

        # for reporting:
        metrics_dict['losses_2a'][batch_idx] = macro_loss.detach()

    if open_set:
        
        # our decoding: the detector sees only the known-macro-class columns.
        inputs_for_decoding = macro_logits[unknown_2_mask]
        predicted_unknown_2s = macro_anomaly_detector(
            inputs_for_decoding[:,known_macro_classes_mask]).squeeze(-1)
            
        # open-set loss:
        # NOTE(review): as in micro_classif, the loss treats the detector
        # output as logits while the 0.5 threshold below treats it as a
        # probability — confirm which one matches Confidence_Decoder.
        dec_2_loss_b = decoder_2b_criterion(
            predicted_unknown_2s,
            type_A_mask[unknown_2_mask].float())

        # Open set confusion matrix
        os_cm_2 += utils.efficient_os_cm(
            preds=(predicted_unknown_2s.detach() > 0.5).long(),
            targets=type_A_mask[unknown_2_mask].long()
            )
        # Disabled alternative (Bovenzi-style decoding on the max logit); the
        # string below is never executed.
        """
        
        # Bovenzi decoding:
        predicted_unknown_2s = macro_anomaly_detector(
            macro_logits[unknown_2_mask].max(1)[0])
        
        # open-set loss:
        dec_2_loss_b = decoder_2b_criterion(
            1 - predicted_unknown_2s,
            type_A_mask[unknown_2_mask].float())

        # Open set confusion matrix
        os_cm_2 += utils.efficient_os_cm(
            preds=(predicted_unknown_2s.detach() == 0).long(),
            targets=type_A_mask[unknown_2_mask].long()
            )
        """
        
        # Balanced accuracy of the cumulative os_cm_2 (epoch so far).
        OS_2_B_acc = utils.get_balanced_accuracy(
                    os_cm=os_cm_2,
                    n_w=balanced_acc_n_w
                    )

        metrics_dict['losses_2b'][batch_idx] = dec_2_loss_b.detach()
        metrics_dict['OS_2_B_accuracies'][batch_idx] = OS_2_B_acc.detach()

    return macro_loss, \
        dec_2_loss_b, \
        macro_logits

## init data:

In [42]:
# Run-level switches:
#   natural_inputs_dim -- width of each feature vector fed to the classifiers
#   save               -- write checkpoints (save_stuff) on improvement
#   wb                 -- log to Weights & Biases
natural_inputs_dim, save, wb = 46, False, True

In [43]:
# Generate Data
torch_seed = 777
torch.manual_seed(torch_seed)

#
#
# Initialize
#
#
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset and Dataloader:
train_dataset = dm.RealFewShotDataset_LowDim(
    features=train_data.drop(columns=[
        'Micro Label',
        'Macro Label',
        'ZdA',
        'Type_A_ZdA',
        'Type_B_ZdA']).values,
    df=train_data,
    micro_label_enc=micro_label_encoder,
    macro_label_enc=macro_label_encoder)

test_dataset = dm.RealFewShotDataset_LowDim(
    features=test_data.drop(columns=[
        'Micro Label',
        'Macro Label',
        'ZdA',
        'Type_A_ZdA',
        'Type_B_ZdA']).values,
    df=test_data,
    micro_label_enc=micro_label_encoder,
    macro_label_enc=macro_label_encoder)


# Number of classes per task :
# two of them are ZdAs, one is a type B and the other a type A
N_WAY = 4
N_SHOT = 5   # Number of samples per class in the support set
N_QUERY = 15  # Number of samples per class in the query set

n_train_tasks = 500    # For speedy tests, reduce here...
n_eval_tasks = 50     # For speedy tests, reduce here...


num_of_test_micro_classes = len(test_dataset.micro_classes)
train_micro_classes = train_dataset.micro_classes
num_of_train_micro_classes = len(train_micro_classes)
train_macro_classes = np.unique(train_dataset.macro_labels)
num_of_train_macro_classes = len(train_macro_classes)

train_loader = DataLoader(
    dataset=train_dataset,
    sampler=dm.FewShotSampler(
                dataset=train_dataset,
                n_tasks=n_train_tasks,
                classes_per_it=N_WAY,
                k_shot=N_SHOT,
                q_shot=N_QUERY),
    num_workers=4,
    drop_last=True,
    collate_fn=dm.convenient_cf)


test_loader = DataLoader(
    dataset=test_dataset,
    sampler=dm.FewShotSampler(
                dataset=test_dataset,
                n_tasks=n_eval_tasks,
                classes_per_it=N_WAY,
                k_shot=N_SHOT,
                q_shot=N_QUERY),
    num_workers=4,
    drop_last=True,
    collate_fn=dm.convenient_cf)


# reproducibility
train_loader.sampler.reset()
test_loader.sampler.reset()

## init architectures:

In [45]:
###
# Hyper params:
###
max_prototype_buffer_micro = 40
max_prototype_buffer_macro = 10
lr = 0.001
# training parameters
n_epochs = 100
norm = "batch"
dropout = 0.1
patience = 60
lambda_os = "NAN"

processor_attention_heads = "NAN"
h_dim = 1024
report_step_frequency = 100

pos_weight_1 = 1
pos_weight_2 = 3

architectures = 'MLP + relu (learnable)'
balanced_acc_n_w = 0.5
attr_w = "NAN"
rep_w = "NAN"
tau_u = 0.85
run_name = f'Bvnzi confec online'

if wb:
    wandb.init(project='Nero_1.1',
               name=run_name,
               config={"N_SHOT": N_SHOT,
                       "N_QUERY": N_QUERY,
                       "N_WAY": N_WAY,
                       "num_of_test_classes": num_of_test_micro_classes,
                       "num_of_train_classes": num_of_train_micro_classes,
                       "train_batch_size": N_WAY * (N_SHOT + N_QUERY),
                       "len(train_loader)": train_loader.sampler.n_tasks,
                       "len(test_dataset)": test_loader.sampler.n_tasks,
                       "max_prototype_buffer_micro": max_prototype_buffer_micro,
                       "max_prototype_buffer_macro": max_prototype_buffer_macro,
                       "device": device,
                       "natural_inputs_dim": natural_inputs_dim,
                       "h_dim": h_dim,
                       "lr": lr,
                       "n_epochs": n_epochs,
                       "norm": norm,
                       "dropout": dropout,
                       "patience": patience,
                       'zdas': data[data.ZdA == True]['Micro Label'].unique(),
                       "lambda_os": lambda_os,
                       "positive_weight_1": pos_weight_1,
                       "positive_weight_2": pos_weight_2,
                       "architectures": architectures,
                       "balanced_acc_n_w": balanced_acc_n_w,
                       "attr_w": attr_w,
                       "rep_w": rep_w
                       })
else:
    print(run_name)


# Initialize the autoencoder model, loss function, and optimizer
micro_classifier = models.Simple_MLP_Classifier(
    input_size=natural_inputs_dim,
    hidden_size=h_dim,
    output_size=max_prototype_buffer_micro,
    dropout=dropout).to(torch.float64).to(device)

"""
micro_anomaly_detector = models.LearnableAnomalyDetectionModule(tau_u)\
    .to(torch.float64).to(device)
"""
"""
micro_anomaly_detector = models.FixedAnomalyDetectionModule(tau_u)\
    .to(torch.float64).to(device)

"""
micro_anomaly_detector = models.Confidence_Decoder(
                in_dim=N_WAY-2,  # Subtract 1 ZdA
                dropout=dropout,
                device=device
                ).to(torch.float64).to(device)

decoder_1b_criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.Tensor([pos_weight_1])).to(device)

micro_multiclass_error = nn.CrossEntropyLoss().to(device)

macro_classifier = models.Simple_MLP_Classifier(
    input_size=natural_inputs_dim,
    hidden_size=h_dim,
    output_size=max_prototype_buffer_macro,
    dropout=dropout).to(torch.float64).to(device)

"""
macro_anomaly_detector = models.LearnableAnomalyDetectionModule(tau_u)\
    .to(torch.float64).to(device)
"""
"""
macro_anomaly_detector = models.FixedAnomalyDetectionModule(tau_u)\
    .to(torch.float64).to(device)
"""
macro_anomaly_detector = models.Confidence_Decoder(
                in_dim=N_WAY-1,
                dropout=dropout,
                device=device
                ).to(torch.float64).to(device)

decoder_2b_criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.Tensor([pos_weight_2])).to(device)

macro_multiclass_error = nn.CrossEntropyLoss().to(device)


params_for_optimizer = \
        list(macro_classifier.parameters()) + \
        list(micro_classifier.parameters()) + \
        list(micro_anomaly_detector.parameters()) + \
        list(macro_anomaly_detector.parameters())

optimizer = optim.Adam(
    params_for_optimizer,
    lr=lr)


# TRAINING
max_acc_classif_micro = torch.zeros(1, device=device)
epochs_without_improvement = 0

In [46]:
# Track gradients/parameters of the four networks in Weights & Biases.
# wandb.watch requires an active run, and wandb.init only runs when `wb` is
# True, so guard it like the other W&B calls.
if wb:
    wandb.watch(macro_classifier)
    wandb.watch(micro_classifier)
    wandb.watch(macro_anomaly_detector)
    wandb.watch(micro_anomaly_detector)

[]

## Train Online:

In [47]:
# Joint ("online") training: every episode trains the closed-set classifiers
# and the open-set anomaly detectors together on the summed loss. Each epoch
# ends with a full pass over the evaluation episodes, a plot of the last
# evaluation episode and a best-model check on micro closed-set accuracy.
for epoch in tqdm(range(n_epochs)):

    # TRAIN
    macro_classifier.train()
    micro_classifier.train()
    macro_anomaly_detector.train()
    micro_anomaly_detector.train()

    # reset conf Mats (read/updated as globals by micro_classif/macro_classif)
    cs_cm_1 = torch.zeros(
        [max_prototype_buffer_micro, max_prototype_buffer_micro],
        device=device)
    os_cm_1 = torch.zeros([2, 2], device=device)
    cs_cm_2 = torch.zeros(
        [max_prototype_buffer_macro, max_prototype_buffer_macro],
        device=device)
    os_cm_2 = torch.zeros([2, 2], device=device)

    # reset metrics dict
    metrics_dict = utils.reset_metrics_dict_optimized(
        train_loader.sampler.n_tasks,
        device)

    for batch_idx, sample_batch in enumerate(train_loader):
        # go to cuda:
        sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

        micro_loss, os_1_loss, micro_logits = micro_classif(
            sample_batch,
            batch_idx)

        macro_loss, os_2_loss, macro_logits = macro_classif(
            sample_batch,
            batch_idx)

        # Backward pass and optimization
        loss = micro_loss + macro_loss + os_1_loss + os_2_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reporting (global step counts episodes across epochs)
        step = batch_idx + (epoch * train_loader.sampler.n_tasks)

        if step % report_step_frequency == 0:
            utils.reporting_simple_optimized(
                'train',
                epoch,
                metrics_dict,
                batch_idx,
                report_step_frequency,
                wb,
                wandb)

    with torch.inference_mode():

        # Evaluation
        macro_classifier.eval()
        micro_classifier.eval()
        macro_anomaly_detector.eval()
        micro_anomaly_detector.eval()

        # reset conf Mats
        cs_cm_1 = torch.zeros(
            [max_prototype_buffer_micro, max_prototype_buffer_micro],
            device=device)
        os_cm_1 = torch.zeros([2, 2], device=device)
        cs_cm_2 = torch.zeros(
            [max_prototype_buffer_macro, max_prototype_buffer_macro],
            device=device)
        os_cm_2 = torch.zeros([2, 2], device=device)

        # reset metrics dict
        metrics_dict = utils.reset_metrics_dict_optimized(
            test_loader.sampler.n_tasks,
            device)

        # go!
        for batch_idx, sample_batch in enumerate(test_loader):
            # go to cuda:
            sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

            micro_loss, os_1_loss, micro_logits = micro_classif(
                sample_batch,
                batch_idx)

            macro_loss, os_2_loss, macro_logits = macro_classif(
                sample_batch,
                batch_idx)

            # Reporting
            # NOTE(review): with n_eval_tasks=50 and report_step_frequency=100
            # this condition holds only at batch_idx == 0 of even epochs, i.e.
            # before most eval episodes have run — confirm this is what
            # reporting_simple_optimized expects.
            step = batch_idx + (epoch * test_loader.sampler.n_tasks)

            if step % report_step_frequency == 0:
                utils.reporting_simple_optimized(
                    'eval',
                    epoch,
                    metrics_dict,
                    batch_idx,
                    report_step_frequency,
                    wb,
                    wandb)

        # Plots use the last evaluation episode plus the epoch-level
        # confusion matrices.
        pu.super_plotting_function(
                phase='Evaluation',
                labels=sample_batch[1].cpu(),
                hiddens_1=micro_logits.detach().cpu(),
                hiddens_2=macro_logits.detach().cpu(),
                scores_1=micro_logits.detach().cpu(),
                scores_2=macro_logits.detach().cpu(),
                cs_cm_1=cs_cm_1.cpu(),
                cs_cm_2=cs_cm_2.cpu(),
                os_cm_1=os_cm_1.cpu(),
                os_cm_2=os_cm_2.cpu(),
                wb=wb,
                wandb=wandb,
                complete_micro_classes=micro_classes,
                complete_macro_classes=macro_classes
            )

        # Checking for improvement: mean micro closed-set accuracy over all
        # eval episodes; checkpoints are written only when `save` is True.
        curr_acc_classif_micro = metrics_dict['CS_accuracies'].mean().item()

        if curr_acc_classif_micro > max_acc_classif_micro:
            max_acc_classif_micro = curr_acc_classif_micro
            epochs_without_improvement = 0
            if save:
                print(f'saving models at epoch {epoch}')
                save_stuff(run_name)
        else:
            epochs_without_improvement += 1

        # Early stopping is disabled: the block below is a string literal and
        # is never executed.
        '''
        if epochs_without_improvement >= patience:
            print(f'Early stopping at episode {step}')
            if wb:
                wandb.log({'Early stopping at episode': step})
            break
        '''
print(f'max_acc_classif_micro: {max_acc_classif_micro}')

if wb:
    wandb.finish()

  0%|          | 0/100 [00:00<?, ?it/s]

max_acc_classif_micro: 0.9520000219345093




VBox(children=(Label(value='32.129 MB of 32.129 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
CS1 accuracy_eval:,▃█▇▇█▇▁██▂▆▆██▆▇█▃▃▄▄▁▇█▃▃█▃▇████▇▂█▆█▆▂
CS1 accuracy_train:,▆█▆▆▆▆▆▆▆▇▅▆▁▆▆▆▆▅▆▆▅▇▆█▆▆▅▆▅▆▆█▆▆▅▆▆▆▆█
CS2 accuracy_eval:,▁█▃▃▃▄▅█▃▁▃▂▇▇█▃▃▂▆▂▂▅▃▄▆██▃▂▃▃▃▄▃▃▃▃█▂▃
CS2 accuracy_train:,▄█▅▆▆▅▅▆▆▇▅▅▁▅▅▆▅▅▆▅▅▆▅█▅▆▇▇▇▆▆█▇▆▅▇▆▆▆▇
OS1 Bal. accuracy_eval:,▂▅▅▅▅▅▆█▅▂▃▅███▅█▇▇▄▂▃▅█▄▄█▁▅▅▅██▅▁█▅█▅▁
OS1 Bal. accuracy_train:,▁▄▆▆▆▆▆▇▆▆▆▆▆▆▆▆▆▆▆▆▅▆▆█▆▆▆▆▆▆▆▅▆▆▅▆▆▆▆▆
OS1 accuracy_eval:,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
OS1 accuracy_train:,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
OS2 Bal. accuracy_eval:,▂▄▁██▇▇▇█▇▇█▆▆▇▅▇▅▇▅▇▆█▅▆▆▆▇███▆▅█▇▆█▆█▇
OS2 Bal. accuracy_train:,▂▁▂▂▄▄▄▄▃▃▄▄▆▄▄▄▄▄▄▄▅▅▅█▅▅▅▅▅▅▅█▆▅▇▅▅▆▆█

0,1
CS1 accuracy_eval:,0.73333
CS1 accuracy_train:,0.91782
CS2 accuracy_eval:,0.66667
CS2 accuracy_train:,0.95776
OS1 Bal. accuracy_eval:,0.36667
OS1 Bal. accuracy_train:,0.87676
OS1 accuracy_eval:,0.0
OS1 accuracy_train:,0.0
OS2 Bal. accuracy_eval:,0.83333
OS2 Bal. accuracy_train:,0.69266


In [None]:
# Close the W&B run so the next wandb.init starts a new one.
# NOTE(review): presumably a no-op when no run is active — confirm for the
# installed wandb version.
wandb.finish()

# Train (Offline)

## first phase: closed set classification:

In [None]:
# Offline phase 1: train only the closed-set classifiers (micro_classif /
# macro_classif are called with closed_set=True, open_set=False).
# NOTE(review): max_acc_classif_micro and epochs_without_improvement are not
# reset here, so after running the online cell in the same kernel this phase
# only saves if it beats the online best — confirm this is intended.
# NOTE(review): the shared optimizer also holds the anomaly-detector
# parameters; they get no new gradients here, but on torch versions where
# zero_grad() zeroes instead of clearing grads, Adam momentum from earlier
# training could still move them.
for epoch in tqdm(range(n_epochs)):

    # TRAIN
    macro_classifier.train()
    micro_classifier.train()

    # reset conf Mats
    cs_cm_1 = torch.zeros(
        [max_prototype_buffer_micro, max_prototype_buffer_micro],
        device=device)
    cs_cm_2 = torch.zeros(
        [max_prototype_buffer_macro, max_prototype_buffer_macro],
        device=device)

    # reset metrics dict
    metrics_dict = utils.reset_metrics_dict_optimized(
        train_loader.sampler.n_tasks,
        device)

    for batch_idx, sample_batch in enumerate(train_loader):
        # go to cuda:
        sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

        micro_loss, _, micro_logits = micro_classif(
            sample_batch,
            batch_idx,
            True,
            False)

        macro_loss, _, macro_logits = macro_classif(
            sample_batch,
            batch_idx,
            True,
            False)

        # Backward pass and optimization
        loss = micro_loss + macro_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reporting
        step = batch_idx + (epoch * train_loader.sampler.n_tasks)

        if step % report_step_frequency == 0:
            utils.reporting_simple_optimized(
                'train',
                epoch,
                metrics_dict,
                batch_idx,
                report_step_frequency,
                wb,
                wandb)

    with torch.inference_mode():

        # Evaluation
        macro_classifier.eval()
        micro_classifier.eval()

        # reset conf Mats
        cs_cm_1 = torch.zeros(
            [max_prototype_buffer_micro, max_prototype_buffer_micro],
            device=device)
        cs_cm_2 = torch.zeros(
            [max_prototype_buffer_macro, max_prototype_buffer_macro],
            device=device)

        # reset metrics dict
        metrics_dict = utils.reset_metrics_dict_optimized(
            test_loader.sampler.n_tasks,
            device)

        # go!
        for batch_idx, sample_batch in enumerate(test_loader):
            # go to cuda:
            sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

            micro_loss, _, micro_logits = micro_classif(
                sample_batch,
                batch_idx,
                True,
                False)

            macro_loss, _, macro_logits = macro_classif(
                sample_batch,
                batch_idx,
                True,
                False)

            # Reporting (see the online loop: with 50 eval tasks this fires
            # only at batch_idx == 0 of even epochs)
            step = batch_idx + (epoch * test_loader.sampler.n_tasks)

            if step % report_step_frequency == 0:
                utils.reporting_simple_optimized(
                    'eval',
                    epoch,
                    metrics_dict,
                    batch_idx,
                    report_step_frequency,
                    wb,
                    wandb)

        # Checking for improvement (mean micro closed-set eval accuracy)
        curr_acc_classif_micro = metrics_dict['CS_accuracies'].mean().item()

        if curr_acc_classif_micro > max_acc_classif_micro:
            max_acc_classif_micro = curr_acc_classif_micro
            epochs_without_improvement = 0
            if save:
                print(f'saving models at epoch {epoch}')
                save_stuff(run_name)
        else:
            epochs_without_improvement += 1


print(f'max_acc_classif_micro: {max_acc_classif_micro}')

if wb:
    wandb.finish()

## second phase: open set classification:

In [None]:
# Load the phase-1 closed-set classifiers and freeze them, so that phase 2
# only updates the anomaly detectors.
# NOTE(review): save_stuff() writes '<run_name>_micro_classifier.pt' and
# run_name is currently 'Bvnzi confec online'; these 'Bovnz_' paths must
# match the prefix of the run that produced the checkpoints — confirm.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
micro_classifier.load_state_dict(
    torch.load('Bovnz_micro_classifier.pt', map_location=device))
macro_classifier.load_state_dict(
    torch.load('Bovnz_macro_classifier.pt', map_location=device))
micro_classifier.eval()
macro_classifier.eval()

# Freeze the pre-trained classifiers.
for frozen_net in (micro_classifier, macro_classifier):
    for param in frozen_net.parameters():
        param.requires_grad = False
        # The shared Adam optimizer still holds these parameters. Adam skips
        # parameters whose .grad is None, but a stale or zeroed grad would
        # still let its momentum move the "frozen" weights.
        param.grad = None

In [None]:
# Offline phase 2: train only the open-set anomaly detectors on top of the
# classifiers frozen in the previous cell (closed_set=False, open_set=True).
# Model selection uses the micro open-set balanced accuracy of each eval
# epoch. The frozen classifiers' closed-set accuracy cannot change during
# this phase, so it cannot reflect detector progress.
max_os_bal_acc_micro = 0.0
epochs_without_improvement = 0

for epoch in tqdm(range(n_epochs)):

    # TRAIN: the detectors learn; the frozen classifiers stay in eval mode
    # (dropout off and, if they use batch norm, no running-stat updates).
    macro_classifier.eval()
    micro_classifier.eval()
    macro_anomaly_detector.train()
    micro_anomaly_detector.train()

    # reset conf Mats
    os_cm_1 = torch.zeros([2, 2], device=device)
    os_cm_2 = torch.zeros([2, 2], device=device)

    # reset metrics dict
    metrics_dict = utils.reset_metrics_dict_optimized(
        train_loader.sampler.n_tasks,
        device)

    for batch_idx, sample_batch in enumerate(train_loader):
        # go to cuda:
        sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

        _, os_micro_loss, micro_logits = micro_classif(
            sample_batch,
            batch_idx,
            False,
            True)

        _, os_macro_loss, macro_logits = macro_classif(
            sample_batch,
            batch_idx,
            False,
            True)

        # Backward pass and optimization. set_to_none=True keeps the frozen
        # classifier parameters at grad=None, so the shared Adam optimizer
        # skips them instead of applying leftover momentum.
        loss = os_micro_loss + os_macro_loss
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Reporting
        step = batch_idx + (epoch * train_loader.sampler.n_tasks)

        if step % report_step_frequency == 0:
            utils.reporting_simple_optimized(
                'train',
                epoch,
                metrics_dict,
                batch_idx,
                report_step_frequency,
                wb,
                wandb)

    with torch.inference_mode():

        # Evaluation: every network in eval mode.
        macro_classifier.eval()
        micro_classifier.eval()
        macro_anomaly_detector.eval()
        micro_anomaly_detector.eval()

        # reset conf Mats
        os_cm_1 = torch.zeros([2, 2], device=device)
        os_cm_2 = torch.zeros([2, 2], device=device)

        # reset metrics dict
        metrics_dict = utils.reset_metrics_dict_optimized(
            test_loader.sampler.n_tasks,
            device)

        # go!
        for batch_idx, sample_batch in enumerate(test_loader):
            # go to cuda:
            sample_batch = sample_batch[0].to(device), sample_batch[1].to(device)

            micro_loss, _, micro_logits = micro_classif(
                sample_batch,
                batch_idx,
                False,
                True)

            macro_loss, _, macro_logits = macro_classif(
                sample_batch,
                batch_idx,
                False,
                True)

            # Reporting
            step = batch_idx + (epoch * test_loader.sampler.n_tasks)

            if step % report_step_frequency == 0:
                utils.reporting_simple_optimized(
                    'eval',
                    epoch,
                    metrics_dict,
                    batch_idx,
                    report_step_frequency,
                    wb,
                    wandb)

        # Checking for improvement on the detector this phase trains:
        # balanced accuracy of the micro open-set confusion matrix
        # accumulated over the whole eval epoch.
        curr_os_bal_acc_micro = float(utils.get_balanced_accuracy(
            os_cm=os_cm_1,
            n_w=balanced_acc_n_w))

        if curr_os_bal_acc_micro > max_os_bal_acc_micro:
            max_os_bal_acc_micro = curr_os_bal_acc_micro
            epochs_without_improvement = 0
            if save:
                print(f'saving models at epoch {epoch}')
                save_stuff(run_name)
        else:
            epochs_without_improvement += 1


print(f'max_os_bal_acc_micro: {max_os_bal_acc_micro}')

if wb:
    wandb.finish()

In [None]:
# Close any W&B run left open by an interrupted cell.
# NOTE(review): presumably a no-op when no run is active — confirm for the
# installed wandb version.
wandb.finish()