In [None]:
from IPython.display import display, Markdown, Latex
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torchattacks import PGD
import random

import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import Omniglot
from PIL import Image

from datetime import datetime

# CUDA_LAUNCH_BLOCKING only takes effect if it is in the environment before CUDA is initialised, so set
# it before touching `torch.cuda`. It makes kernel launches synchronous: useful for debugging CUDA
# errors, but slower.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Will use:", device)

In [None]:
# Root directory the datasets are downloaded to / read from (the current working directory).
data_dir = os.curdir

## Loading datasets

In [None]:
from hypnettorch.data import FashionMNISTData, MNISTData
from hypnettorch.data.dataset import Dataset
from hypnettorch.mnets import LeNet
from hypnettorch.mnets.resnet import ResNet
from hypnettorch.mnets.mlp import MLP
from hypnettorch.hnets import HMLP

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import learn2learn as l2l
import copy

# Seed every RNG the notebook draws from. Python's `random` must be seeded too: it picks the episode
# classes (`random.sample`) and the support/query samples, so without it runs are not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def _load_miniimagenet_split(mode):
    """Load one Mini-ImageNet split (downloading it into `data_dir` if needed) as a learn2learn MetaDataset."""
    raw_split = l2l.vision.datasets.MiniImagenet(root=data_dir, mode=mode, download=True)
    return l2l.data.MetaDataset(raw_split)


# Load Mini-ImageNet: train / validation / test splits.
mini_train = _load_miniimagenet_split('train')
mini_valid = _load_miniimagenet_split('validation')
mini_test = _load_miniimagenet_split('test')

In [None]:
def _load_whole_split(meta_dataset):
    """Load every (image, label) pair of `meta_dataset` in one batch and return them as NumPy arrays."""
    # batch_size == len(dataset) with shuffle=False: the loader yields exactly one batch, in dataset order.
    loader = DataLoader(meta_dataset, batch_size=len(meta_dataset), shuffle=False)
    images, labels = next(iter(loader))
    return images.numpy(), labels.numpy()


# Training split; its distinct labels are the classes training episodes are sampled from.
dataset_train, dataset_lbl_train = _load_whole_split(mini_train)
lbls_0 = np.unique(dataset_lbl_train).tolist()

# Validation split. Its labels are shifted by the number of training classes so both splits share one
# label space without collisions (such that the mapping when we extend the labels works).
dataset_valid, dataset_lbl_valid = _load_whole_split(mini_valid)
dataset_lbl_valid = dataset_lbl_valid + len(lbls_0)
# The shift assumes the training labels are 0..len(lbls_0)-1; make sure no collision remains.
assert not np.isin(dataset_lbl_valid, lbls_0).any(), "validation labels collide with training labels"

# Test split (labels left as loaded).
dataset_test, dataset_lbl_test = _load_whole_split(mini_test)

## Flatten the training and validation splits (validation labels are offset, so the two label sets are disjoint)

In [None]:
# Size of the global label space: training classes plus validation classes (whose labels were shifted
# past the training labels when the splits were loaded).
n_classes = len(np.unique(dataset_lbl_train)) + len(np.unique(dataset_lbl_valid))
# n_classes is only right if the label sets are disjoint, and labels are later used as column indices
# into n_classes-wide predictions, so they must all be < n_classes.
assert not np.isin(dataset_lbl_valid, dataset_lbl_train).any(), "train/validation labels overlap"
assert max(dataset_lbl_train.max(), dataset_lbl_valid.max()) < n_classes, "label outside [0, n_classes)"

# Training pool that episodes are sampled from: one flattened image per row, plus its labels.
dataset_0_lbl = dataset_lbl_train
dataset_0 = dataset_train.reshape(dataset_train.shape[0], -1)

# Validation split, flattened the same way (a no-op if this cell is re-run).
dataset_valid = dataset_valid.reshape(dataset_valid.shape[0], -1)

In [None]:
from scipy.ndimage import zoom, rotate
from scipy.interpolate import interp2d

def rotate_dataset(dataset, angle):
    """Rotate every image of a flattened 3x84x84 dataset by `angle` degrees.

    The rotation acts on the two spatial axes; the output keeps the 84x84 size (corners that rotate in
    are filled by scipy's default constant mode) and is returned flattened again, one image per row.
    """
    flat_size = 3 * 84 * 84
    images = dataset.reshape(-1, 3, 84, 84)
    turned = rotate(images, angle, axes=(2, 3), reshape=False)
    return turned.reshape(-1, flat_size)

def zoom_dataset(dataset, zoom_factor):
    """Zoom every image of a flattened 3x84x84 dataset by `zoom_factor`, keeping the 84x84 size.

    The spatial axes are rescaled with linear interpolation. A zoom-in (factor > 1) is centre-cropped
    back to 84x84; a zoom-out (factor < 1) is centre-padded with zeros back to 84x84.

    Args:
        dataset: array holding N flattened 3*84*84 images.
        zoom_factor: spatial scale factor.

    Returns:
        array of shape (N, 3*84*84).
    """
    images = dataset.reshape(-1, 3, 84, 84)
    zoomed = zoom(images, (1, 1, zoom_factor, zoom_factor), order=1)

    target = images.shape[2]
    current = zoomed.shape[2]
    if current >= target:
        # Zoom-in (or identity): keep the central target x target window.
        start = (current - target) // 2
        resized = zoomed[:, :, start:start + target, start:start + target]
    else:
        # Zoom-out: slicing with a negative offset would be wrong, so pad the border with zeros instead.
        before = (target - current) // 2
        after = target - current - before
        resized = np.pad(zoomed, ((0, 0), (0, 0), (before, after), (before, after)))
    return resized.reshape(-1, 3 * 84 * 84)

### Models definition

In [None]:
# ResNet18 feature extractor used as HyperShot's kernel CNN.
class ResNet(nn.Module):
    """torchvision ResNet-18 mapping 3x84x84 images to `z_length`-dimensional embeddings.

    Differences from the stock model: the stem convolution uses stride 1 (torchvision's uses stride 2),
    presumably to keep more resolution on small 84x84 inputs, and the classification head is replaced by a
    linear projection to `z_length`. Weights are randomly initialised (no pre-training).

    NOTE: this class shadows the `ResNet` imported from `hypnettorch.mnets.resnet` earlier in the notebook.
    NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=None`.
    """

    def __init__(self, z_length):
        super().__init__()
        self.z_length = z_length

        backbone = models.resnet18(pretrained=False)
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        backbone.avgpool = nn.AdaptiveAvgPool2d(1)
        backbone.fc = nn.Linear(backbone.fc.in_features, z_length)
        self.resnet = backbone

    def forward(self, x):
        # Inputs may arrive flattened; restore the (N, 3, 84, 84) image layout first.
        images = x.view(-1, 3, 84, 84)
        return self.resnet(images)

In [None]:
from models import ResNet12

class Hypershot(nn.Module):
    """HyperShot-style few-shot classifier.

    A feature extractor (``kcnn``) embeds the K-shot support set. The per-class mean embeddings are
    L2-normalised and their pairwise dot products form a W x W kernel. A hypernetwork (``hnet``) maps the
    flattened kernel to the weights of a small MLP (``mnet``), which classifies each query sample from its
    cosine similarities to the W class means.

    Args:
        kcnn_input_channels: stored only (used by the commented-out ResNet12 backbone).
        z_length: embedding length produced by the feature extractor.
        kcnn_weights, hnet_weights: checkpoint paths, loaded only when ``load_w`` is True.
        hnet_hidden_layers, hnet_hidden_size: hypernetwork hidden layer count / width.
        mnet_hidden_layers, mnet_hidden_size: main-network hidden layer count / width.
        K: shots (support samples per class). W: ways (classes per episode).
        i_dim, i_cha: samples are viewed as (i_cha, i_dim, i_dim) images before the feature extractor.
        load_w: load ``hnet_weights`` / ``kcnn_weights`` after construction.
    """

    def __init__(self, kcnn_input_channels, z_length, kcnn_weights,
                       hnet_hidden_layers, hnet_hidden_size, hnet_weights,
                       mnet_hidden_layers, mnet_hidden_size,
                       K, W, i_dim=28, i_cha=1, load_w = False):
        super(Hypershot, self).__init__()

        self.kcnn_input_channels = kcnn_input_channels
        self.z_length = z_length
        self.kcnn_weights = kcnn_weights
        self.hnet_hidden_layers = hnet_hidden_layers
        self.hnet_hidden_size = hnet_hidden_size
        self.hnet_weights = hnet_weights
        self.mnet_hidden_layers = mnet_hidden_layers
        self.mnet_hidden_size = mnet_hidden_size

        # Images dimension
        self.i_dim = i_dim
        self.i_cha = i_cha

        self.K = K
        self.W = W
        # Set from a support set by update_kernel / compute_sets_and_features(..., update_kernel=True).
        self.kernel = None
        self.z_space = None

        # self.kcnn = ResNet12(output_size=z_length, hidden_size=64, channels=self.kcnn_input_channels, dropblock_dropout=0, avg_pool=False)
        self.kcnn = ResNet(z_length=self.z_length)
        # Main network: W similarity features in, W logits out; forward() supplies its weights from hnet.
        self.mnet = MLP(n_in=W, n_out=W, hidden_layers=self.mnet_hidden_layers * [self.mnet_hidden_size])
        # The hypernetwork input is the flattened W x W kernel, hence W**2 inputs.
        self.hnet = HMLP(self.mnet.param_shapes, uncond_in_size=W**2, cond_in_size=0,
                         layers=self.hnet_hidden_layers * [self.hnet_hidden_size],
                         num_cond_embs=0)
        self.hnet.apply_hyperfan_init(mnet=self.mnet)

        if load_w:
            # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
            self.hnet.load_state_dict(torch.load(self.hnet_weights, map_location=device))
            self.kcnn.load_state_dict(torch.load(self.kcnn_weights, map_location=device))

    def compute_kernel(self, X, y):
        """
        Compute the Hypershot kernel for a support set X with labels y.
        Each class is represented by the average of its K embeddings, as in the Hypershot paper.

        Args:
            X (tensor): support set, one flattened sample per row; must hold exactly K samples for each
                of W classes
            y (tensor): corresponding labels

        Returns:
            (z_space, kernel): the W L2-normalised class-mean embeddings (sorted by label, W x z_length)
            and their pairwise dot products (W x W).
        """
        # Sort the support set by label so the K samples of each class are contiguous.
        order = torch.argsort(y)
        sorted_X = X[order].to(device)

        images = sorted_X.view(sorted_X.shape[0], self.i_cha, self.i_dim, self.i_dim)
        nn_X = self.kcnn(images)
        assert nn_X.shape == (sorted_X.shape[0], self.z_length)

        # Rows [i*K, (i+1)*K) belong to the i-th smallest label: average them per class.
        mean_X = nn_X.reshape(self.W, self.K, -1).mean(dim=1)
        norm_mean_X = F.normalize(mean_X, p=2, dim=1)

        return norm_mean_X, torch.matmul(norm_mean_X, torch.t(norm_mean_X))

    def get_s_and_q_sets(self, X, y, trgt_lbls, q_size):
        """
        Sample a support set with K samples per class and a query set with q_size samples per class
        for the classes in trgt_lbls.

        For every class, K + q_size distinct samples are drawn without replacement: the first K go to
        the support set and the remaining q_size to the query set, so the two sets never share a sample
        and neither contains duplicates.

        Args:
            X (np.ndarray or tensor): data, one flattened sample per row (may contain other labels)
            y (np.ndarray or tensor): corresponding labels
            trgt_lbls : the labels that end up in the sets (in this order)
            q_size: amount of sample per class in the query set

        Returns:
            support set, support set labels, query set, query set labels (float32 tensors on `device`)

        Raises:
            ValueError: if a target class has fewer than K + q_size samples.
        """
        # Episodes are assembled on the host in NumPy and moved to the device once at the end.
        if torch.is_tensor(X):
            X = X.detach().cpu().numpy()
        if torch.is_tensor(y):
            y = y.detach().cpu().numpy()

        n_lbls = len(trgt_lbls)
        n_needed = self.K + q_size
        s_set = np.zeros((n_lbls * self.K, X.shape[1]))
        s_set_lbl = np.zeros(n_lbls * self.K)
        q_set = np.zeros((n_lbls * q_size, X.shape[1]))
        q_set_lbl = np.zeros(n_lbls * q_size)

        for j, lbl in enumerate(trgt_lbls):
            class_idx = np.flatnonzero(y == lbl)
            if len(class_idx) < n_needed:
                raise ValueError(
                    f"label {lbl} has {len(class_idx)} samples, need K + q_size = {n_needed}")
            picked = random.sample(range(len(class_idx)), n_needed)
            s_idx = class_idx[picked[:self.K]]
            q_idx = class_idx[picked[self.K:]]
            s_set[j*self.K:(j+1)*self.K] = X[s_idx]
            s_set_lbl[j*self.K:(j+1)*self.K] = y[s_idx]
            q_set[j*q_size:(j+1)*q_size] = X[q_idx]
            q_set_lbl[j*q_size:(j+1)*q_size] = y[q_idx]

        # Plain data tensors: no gradient is ever needed w.r.t. the inputs or the labels.
        def to_device(arr):
            return torch.as_tensor(arr, dtype=torch.float32, device=device)

        return to_device(s_set), to_device(s_set_lbl), to_device(q_set), to_device(q_set_lbl)

    def get_q_sample_features(self, X):
        """
        Compute the features the main network classifies, for query samples X.

        Args:
            X (tensor): query samples

        Returns:
            (n_queries, W) tensor: cosine similarity of each L2-normalised query embedding to each
            class-mean embedding in self.z_space (which must already be set).
        """
        X = X.view(-1, self.i_cha, self.i_dim, self.i_dim)
        zs_q = self.kcnn(X)
        zs_q = F.normalize(zs_q, p=2, dim=1)
        zs_q_m = torch.matmul(self.z_space, torch.t(zs_q))
        return torch.t(zs_q_m)

    def compute_sets_and_features(self, X, y, trgt_lbls, q_size, update_kernel):
        """
        Sample support and query sets for the given labels and compute the support-set kernel.
        Optionally stores the kernel and z space on the model.

        Args:
            X: data
            y: data labels
            trgt_lbls : the labels that will be contained inside the returned sets
            q_size : query set's amount of sample per label
            update_kernel (bool) : if True, store the new kernel and z_space on the model

        Returns:
            support set, support set labels, query set, query set labels
        """
        s_set, s_set_lbl, q_set, q_set_lbl = self.get_s_and_q_sets(X, y, trgt_lbls, q_size)

        z_space, kernel = self.compute_kernel(s_set, s_set_lbl)
        if update_kernel:
            self.z_space = z_space
            self.kernel = kernel

        return s_set, s_set_lbl, q_set, q_set_lbl

    def extend_pred_to_nclasses(self, pred, n_c, lbls):
        """Scatter the W per-episode outputs into an (n_queries, n_c) tensor at columns `lbls`.

        Column j of `pred` is written to column int(lbls[j]); all other columns are 0.
        """
        out = torch.zeros((pred.shape[0], n_c), device=pred.device)
        out[:, [int(x) for x in lbls]] = pred
        return out

    def update_kernel(self, s_set, s_set_lbl):
        """Recompute and store the kernel and z space from a support set."""
        z_space, kernel = self.compute_kernel(s_set, s_set_lbl)
        self.z_space = z_space
        self.kernel = kernel

    def forward(self, x):
        """Classify query samples x against the stored support-set kernel; returns (n_queries, W) logits."""
        q_features = self.get_q_sample_features(x)
        W = self.hnet(uncond_input=self.kernel.view(1, -1))
        P = self.mnet.forward(q_features, weights=W)
        return P

### Accuracy computation methods

In [None]:
def calc_accuracy_lbls(X_test, y_test, test_classes, hs, n_c, q_size):
    """
    Computes the prediction accuracy for one episode over the classes test_classes of X_test.
    Mainly used as utility for the calc_accuracy function below.

    A fresh support/query split is sampled and hs's kernel and z_space are overwritten with the new
    support set. The model runs in eval mode (e.g. BatchNorm uses its running statistics instead of
    being updated by evaluation data, matching the adversarial evaluation); its previous
    train/eval mode is restored afterwards.

    Args:
        X_test: entire test set
        y_test: corresponding labels
        test_classes: the classes we want to consider for testing accuracies (should contain hs.W classes)
        hs : hypershot model
        n_c: number of classes in the entire dataset
        q_size : query set's amount of sample per label

    Returns:
        (accuracy, loss), each rounded to 2 decimals
    """
    was_training = hs.training
    hs.eval()
    try:
        with torch.no_grad():
            _, _, q_set_test, q_set_lbl_test = \
                hs.compute_sets_and_features(X_test, y_test, test_classes, q_size, True)

            p = hs.forward(q_set_test)
            prediction_extended_acc = hs.extend_pred_to_nclasses(p, n_c, test_classes)
            targets = q_set_lbl_test.long()
            loss = nn.CrossEntropyLoss()(prediction_extended_acc, targets)
            accuracy = (torch.argmax(prediction_extended_acc, dim=1) == targets).float().mean().item()
    finally:
        hs.train(was_training)
    return round(accuracy, 2), round(loss.item(), 2)


def calc_accuracy(X_test, y_test, hs, n_c, q_size):
    """
    Computes the prediction accuracy for the entire X_test test set.

    The distinct labels of y_test are sorted and cut into consecutive groups of hs.W classes. One
    episode is evaluated per group with calc_accuracy_lbls, and the per-episode accuracies and losses
    are averaged. Trailing labels that do not fill a whole group of hs.W classes are skipped.

    Args:
        X_test: entire test set, one flattened sample per row
        y_test: corresponding labels
        hs : hypershot model (its kernel/z_space are overwritten by every episode)
        n_c: number of classes in the entire dataset
        q_size : query set's amount of sample per label

    Returns:
        average accuracy and loss, each rounded to 2 decimals

    Raises:
        ValueError: if y_test holds fewer than hs.W distinct labels.
    """
    # Only the labels are needed here: calc_accuracy_lbls samples its own episodes from X_test, so the
    # (potentially very large) data matrix is not copied to the device.
    if torch.is_tensor(y_test):
        y_test_t = torch.clone(y_test)
    else:
        y_test_t = torch.FloatTensor(y_test)

    diff_classes = torch.unique(y_test_t)
    n_sets = diff_classes.shape[0] // hs.W
    if n_sets == 0:
        raise ValueError(
            f"calc_accuracy needs at least hs.W={hs.W} distinct labels, got {diff_classes.shape[0]}")

    acc, loss = 0.0, 0.0
    for i in range(n_sets):
        lbls = diff_classes[i*hs.W:(i+1)*hs.W].tolist()
        d_acc, d_loss = calc_accuracy_lbls(X_test, y_test, lbls, hs, n_c, q_size)
        acc += d_acc
        loss += d_loss
    return round(acc / n_sets, 2), round(loss / n_sets, 2)

### Adversarial accuracy computation methods

In [None]:
def calc_accuracy_lbls_adv(q_set_test, q_set_lbl_test, test_classes, hs, n_c):
    """
    Compute accuracy for a given (e.g. adversarially perturbed) query test set. We assume the kernel and
    the z_space of the given hypershot model hs to be set accordingly.

    The model is evaluated in eval mode without building an autograd graph; its previous train/eval mode
    is restored afterwards.

    Args:
        q_set_test (tensor): query samples
        q_set_lbl_test (tensor): their labels (global label space)
        test_classes: the episode's classes, in the column order of the model output
        hs : hypershot model
        n_c: number of classes in the entire dataset

    Returns:
        (accuracy, loss), each rounded to 2 decimals
    """
    was_training = hs.training
    hs.eval()
    try:
        with torch.no_grad():
            p = hs.forward(q_set_test)
            prediction_extended_acc = hs.extend_pred_to_nclasses(p, n_c, test_classes)
            targets = q_set_lbl_test.long()
            loss = nn.CrossEntropyLoss()(prediction_extended_acc, targets)
            accuracy = (torch.argmax(prediction_extended_acc, dim=1) == targets).float().mean().item()
    finally:
        hs.train(was_training)
    return round(accuracy, 2), round(loss.item(), 2)

def calc_accuracy_adv_helper(X_test, y_test, test_classes, hs, n_c, q_size, pgd):
    """
    Sample a support set and a query set for the given test classes, attack the query set with `pgd`
    and evaluate the model on it. Only used within calc_accuracy_adv.

    The model's kernel/z_space are overwritten with the new (clean) support set. Samples are on the
    [0, 255] scale and are rescaled to [0, 1] for the attack and back afterwards.

    Args:
        X_test, y_test: data and labels to sample the episode from
        test_classes: the episode's classes
        hs : hypershot model
        n_c: number of classes in the entire dataset
        q_size : query set's amount of sample per label
        pgd : attack called as pgd(images_in_0_1, targets) where targets index the W model outputs

    Returns:
        (accuracy, loss) on the attacked query set, unrounded
    """
    # The kernel only needs values, not a graph: the attack differentiates w.r.t. the query inputs.
    with torch.no_grad():
        _, _, q_set_test, q_set_lbl_test = \
            hs.compute_sets_and_features(X_test, y_test, test_classes, q_size, True)

    targets = q_set_lbl_test.long()
    # Attack targets: index of each label among the episode's sorted labels (0..W-1).
    attack_targets = torch.unique(targets, return_inverse=True)[1]
    adv_inputs = 255.0 * pgd(q_set_test / 255.0, attack_targets)

    with torch.no_grad():
        p = hs.forward(adv_inputs)
        prediction_extended_acc = hs.extend_pred_to_nclasses(p, n_c, test_classes)
        loss = nn.CrossEntropyLoss()(prediction_extended_acc, targets)
        accuracy = (torch.argmax(prediction_extended_acc, dim=1) == targets).float().mean().item()
    return accuracy, loss.item()

def calc_accuracy_adv(X_test, y_test, hs, n_c, q_size, pgd):
    """
    Computes the prediction accuracy for the entire X_test test set applying PGD.

    The distinct labels of y_test are sorted and cut into consecutive groups of hs.W classes. One
    attacked episode is evaluated per group with calc_accuracy_adv_helper, and the results are averaged.
    Trailing labels that do not fill a whole group are skipped. The model runs in eval mode and its
    previous train/eval mode is restored afterwards.

    Args:
        X_test: test set, one flattened sample per row
        y_test: corresponding labels
        hs : hypershot model
        n_c: number of classes in the entire dataset
        q_size : query set's amount of sample per label
        pgd : a pgd attack from torchattacks

    Returns:
        average accuracy and loss, each rounded to 2 decimals

    Raises:
        ValueError: if y_test holds fewer than hs.W distinct labels.
    """
    # Only the labels are needed here; the helper samples its episodes from X_test directly, so the
    # data matrix is not copied to the device and no gradients w.r.t. data or labels are tracked.
    if torch.is_tensor(y_test):
        y_test_t = torch.clone(y_test)
    else:
        y_test_t = torch.FloatTensor(y_test)

    diff_classes = torch.unique(y_test_t)
    n_sets = diff_classes.shape[0] // hs.W
    if n_sets == 0:
        raise ValueError(
            f"calc_accuracy_adv needs at least hs.W={hs.W} distinct labels, got {diff_classes.shape[0]}")

    was_training = hs.training
    hs.eval()
    try:
        acc, loss = 0.0, 0.0
        for i in range(n_sets):
            lbls = diff_classes[i*hs.W:(i+1)*hs.W].tolist()
            d_acc, d_loss = calc_accuracy_adv_helper(X_test, y_test, lbls, hs, n_c, q_size, pgd)
            acc += d_acc
            loss += d_loss
    finally:
        hs.train(was_training)
    return round(acc / n_sets, 2), round(loss / n_sets, 2)

In [None]:
# Configure training.
# Number of epochs.
# We used 100 for training with adversarial querying and K = 5
# We used 50 for training with adversarial querying and K = 1
# Increasing this would be very likely to improve the scores we obtained substantially
nepochs = 50

# Epoch after which adversarial training starts
# To disable adversarial training, we can put a number higher than nepochs
# We used 50 for training with adversarial and K = 5
# We used 20 for training with adversarial querying and K = 1
# Increasing this would be very likely to improve the scores we obtained substantially
do_adv_train = 20

# K-shot W-way
# We can change the value for K as we want (tested for 1 and 5)
K = 1
W = 5

# Epsilon parameter of the PGD attack
# Here this will be the maximum epsilon, we gradually increase it during training
eps = 8.0 / 255

# Length of the embeddings produced by the CNN (Hypershot parameter)
z_len = 256

load_weights = 0
# Set to 1 to keep training the `hypershot` model already in memory instead of building a new one.
continue_training = 0

# Amount of sample in query sets during training (for one corresponding support set)
q_train = 15
# Amount of sample in query sets during validation and testing (for one corresponding support set)
q_test = 15

# Factor to sample more or less training support sets per epoch. The total will be n_sets * m (12 * m)
m = 5

# Evaluate on the validation split (and save a checkpoint) every `eval_every` epochs.
eval_every = 1


class _Scaled255:
    """Adapter exposing a model that takes [0, 255] inputs to torchattacks, which works on [0, 1] images.

    Calling the adapter rescales the input by 255 before the wrapped model sees it, so the attack's
    gradients are taken at the inputs the model is actually evaluated on. Every other attribute
    (`training`, `eval()`, `train()`, `parameters()`, ...) is forwarded to the wrapped model, so any mode
    switching done by the attack acts on, and reads the state of, the real model.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        return self.model(x * 255.0)

    def __getattr__(self, name):
        # Only reached for names not set on the adapter; avoid recursing before `model` exists.
        if name == "model":
            raise AttributeError(name)
        return getattr(self.model, name)


criterion = nn.CrossEntropyLoss()
# Number of W-way label sets the training classes split into.
n_sets = int(len(lbls_0) / W)
os.makedirs("models", exist_ok=True)

# Loop in case we want to do statistics (not used for now)
for o in range(1):
    print("Iteration", o+1)

    # Build a new model unless we want to train the in-memory one for some more epochs
    if continue_training == 0:
        hypershot = Hypershot(kcnn_input_channels=1, z_length=z_len, kcnn_weights=None,
                              hnet_hidden_layers=1, hnet_hidden_size=256, hnet_weights=None,
                              mnet_hidden_layers=1, mnet_hidden_size=128,
                              K=K, W=W, i_dim=84, i_cha=3, load_w = False).to(device)

    # Optimizer and scheduler initialization
    optimizer = optim.Adam(hypershot.parameters(), lr=0.0001)
    scheduler = CosineAnnealingLR(optimizer, T_max=nepochs, eta_min=0.00001)

    # Main training loop
    for epoch in range(nepochs):
        # Attack budget: 0 up to epoch do_adv_train, then a linear ramp reaching eps nepochs/2 epochs later.
        c_eps = max(0, min(((epoch-do_adv_train) / nepochs) * eps * 2, eps))
        print("current eps", c_eps)
        # Images are on the [0, 255] scale: PGD perturbs [0, 1] images, the adapter rescales them for the model.
        pgd_attack = PGD(_Scaled255(hypershot), eps=c_eps, alpha=2.0 / 255, steps=5)

        print("----------------------- Epoch", epoch, " -----------------------")

        n_episodes = m * n_sets
        # Sum of the episode losses over the epoch (as Python floats, so no autograd graph is kept alive)
        global_loss_float = 0.0

        for l_set_id in range(n_episodes):
            if (l_set_id + 1) % 10 == 0:
                print("Went over", l_set_id+1, "over", n_episodes, "sets.")

            # Sample this episode's W classes and a stratified 95/5 split of their samples (only the 95%
            # part is used). Each split copies the W classes' images, so they are built one at a time.
            c_lbls = sorted(random.sample(lbls_0, W))
            mask_b = np.isin(dataset_0_lbl, np.array(c_lbls))
            pool_X, pool_y = dataset_0[mask_b], dataset_0_lbl[mask_b]
            dataset_l_train, dataset_l_test, dataset_l_lbl_train, dataset_l_lbl_test = \
                train_test_split(pool_X, pool_y, random_state=42, test_size=0.05, stratify=pool_y)

            hypershot.train()

            # This also modifies the kernel and z_space of hypershot
            s_set_train, s_set_lbl_train, q_set_train, q_set_lbl_train = \
                hypershot.compute_sets_and_features(dataset_l_train, dataset_l_lbl_train, c_lbls, q_train, True)

            # Forward pass
            dataset_l_P = hypershot.forward(q_set_train)
            prediction_extended = hypershot.extend_pred_to_nclasses(dataset_l_P, n_classes, c_lbls)
            loss_dataset_l = criterion(prediction_extended, q_set_lbl_train.long())

            # Attack targets: index of each query label among the episode's sorted labels (0..W-1),
            # i.e. the column of the W-way model output it corresponds to.
            q_targets = torch.unique(q_set_lbl_train.long(), return_inverse=True)[1]

            # Adversarial training
            if epoch == do_adv_train and l_set_id == 0:
                print("Adversarial training starts.")
            if epoch >= do_adv_train:
                adv_inputs = 255 * pgd_attack(q_set_train / 255.0, q_targets)
                if l_set_id == 0:
                    print("Norm of pgd", torch.norm(adv_inputs-q_set_train))
                adv_outputs = hypershot.forward(adv_inputs)
                prediction_extended_adv = hypershot.extend_pred_to_nclasses(adv_outputs, n_classes, c_lbls)
                loss_dataset_l = loss_dataset_l + criterion(prediction_extended_adv, q_set_lbl_train.long())

            global_loss_float += loss_dataset_l.item()

            # Notice that the train_metrics and the train_metrics_adv are not computed on the same data, hence it is expected
            # to have no relationship between them during training.
            if l_set_id % 10 == 0:
                train_metrics = calc_accuracy(dataset_l_train, dataset_l_lbl_train, hypershot, n_classes, q_train)
                print("Local train acc and loss at the end of set:", l_set_id, "-->", train_metrics)

                # Training loss when attacking query image only
                adv_inputs_train = 255 * pgd_attack(q_set_train / 255.0, q_targets)
                train_metrics_adv = calc_accuracy_lbls_adv(adv_inputs_train, q_set_lbl_train.long(), c_lbls, hypershot, n_classes)
                print("Local adv train acc and loss (train kernel, q attacked only):", l_set_id, "-->", train_metrics_adv)
                print("----")

            loss_dataset_l.backward()
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

        print("Global loss at the end of epoch:", epoch, ":", global_loss_float)
        if (epoch+1) % eval_every == 0:
            (gta, gtl) = calc_accuracy(dataset_valid, dataset_lbl_valid, hypershot, n_classes, q_test)
            print("--> Global test accuracy after epoch:", epoch, "-->", gta)
            (gta_adv, gtl_adv) = calc_accuracy_adv(dataset_valid, dataset_lbl_valid, hypershot, n_classes, q_test, pgd_attack)
            print("--> Global adv test accuracy after epoch:", epoch, "-->", gta_adv)
            current_time = datetime.now().strftime("%Y%m%d%H%M%S")
            gta_str = str(round(gta, 2)).replace('.', '_')
            gta_str_adv = str(round(gta_adv, 2)).replace('.', '_')
            hs_file = f'models/HS_{K}Shot_{W}Way_{int(eps*255)}eps_{gta_str}Acc_{gta_str_adv}Adv_{current_time}_{epoch}.pth'
            torch.save(hypershot.state_dict(), hs_file)
        print()

    print("END OF ITERATION:",o+1)

In [None]:
# Save the hypernetwork and the feature extractor of the final (last-epoch) model separately.
# `hnet`/`kcnn` only exist as attributes of `hypershot`, not as notebook-level variables.
os.makedirs("models", exist_ok=True)
current_time = datetime.now().strftime("%Y%m%d%H%M%S")

# Create a file name with the current time
hnet_file = f'models/hnet_{current_time}_best.pth'
torch.save(hypershot.hnet.state_dict(), hnet_file)
kcnn_file = f'models/kcnn_{current_time}_best.pth'
torch.save(hypershot.kcnn.state_dict(), kcnn_file)

In [None]:
# Final evaluation of the trained model on the Mini-ImageNet 'test' split: clean accuracy, then accuracy
# under the last epoch's PGD attack. Episodes only mix test classes, so the test labels are used as loaded.
dataset_test_flat = dataset_test.reshape(dataset_test.shape[0], -1)

test_acc, test_loss = calc_accuracy(dataset_test_flat, dataset_lbl_test, hypershot, n_classes, q_test)
print("Test accuracy / loss:", test_acc, test_loss)

test_acc_adv, test_loss_adv = calc_accuracy_adv(dataset_test_flat, dataset_lbl_test, hypershot, n_classes, q_test, pgd_attack)
print("Adversarial test accuracy / loss:", test_acc_adv, test_loss_adv)