In [1]:
from IPython.display import display, Markdown, Latex
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import Omniglot
from PIL import Image

from datetime import datetime

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Will use:", device)

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

Will use: cuda


In [2]:
data_dir = '.'

## Loading datasets

In [3]:
from hypnettorch.data import FashionMNISTData, MNISTData
from hypnettorch.data.dataset import Dataset
from hypnettorch.mnets import LeNet
from hypnettorch.mnets.resnet import ResNet
from hypnettorch.mnets.mlp import MLP
from hypnettorch.hnets import HMLP

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import learn2learn as l2l
import copy

np.random.seed(42)
torch.manual_seed(42)

mnist = MNISTData(data_dir, use_one_hot=True, validation_size=0)
fmnist = FashionMNISTData(data_dir, use_one_hot=True, validation_size=0)

omniglot = l2l.vision.datasets.FullOmniglot(root=data_dir,
                                            transform=transforms.Compose([
                                                transforms.Resize(28, interpolation=Image.LANCZOS),
                                                transforms.ToTensor(),
                                                lambda x: 1.0 - x,
                                            ]),
                                            download=True)
omniglot = l2l.data.MetaDataset(omniglot)


Reading MNIST dataset ...
Elapsed time to read dataset: 0.223741 sec
Files already downloaded and verified
Files already downloaded and verified


## Convert the dataset to numpy for easier manipulation

In [4]:
# Create a DataLoader for batching and shuffling the data
batch_size = len(omniglot)  # Set batch size to the total number of examples to load all data at once
data_loader = DataLoader(omniglot, batch_size=batch_size, shuffle=False)

# Iterate through the DataLoader
for batch in data_loader:
    images, labels = batch
    # Convert PyTorch tensors to NumPy arrays
    dataset = images.numpy()
    dataset_lbl = labels.numpy()    
    sizes = dataset.shape
    
print("Dataset dimension:", dataset.shape)
print("Labels dimension:", dataset_lbl.shape)
print(np.min(dataset_lbl))
print(np.max(dataset_lbl))
    

Dataset dimension: (32460, 1, 28, 28)
Labels dimension: (32460,)
0
1622


## Create 2 different datasets for two disjoint set of labels (deterministic for now)

In [5]:
# Get a batch of training samples from each data handler.
# mnist_inps, mnist_trgts = mnist.next_train_batch(4)
# dataset_inps, dataset_trgts = dataset.next_train_batch(4)
# dataset_full, dataset_full_lbl = dataset.next_train_batch(60000)
print(dataset_lbl)

n_classes = len(np.unique(dataset_lbl))
dataset_full = dataset.reshape((dataset.shape[0], dataset.shape[2]*dataset.shape[3]))
dataset_full_lbl = dataset_lbl

print(dataset_full.shape)
print(dataset_full_lbl.shape)

# TODO-yz: you will need to use the same split as is used in evaluation (pull and merge with main to get access to datasets.get_benchmark_tasksets which gives you train/val/test split over classes)
sep = 1100
lbls_0 = [i for i in range(sep)]
lbls_1 = [i for i in range(sep, n_classes)]

mask_0 = np.isin(dataset_full_lbl, np.array(lbls_0))
mask_1 = np.isin(dataset_full_lbl, np.array(lbls_1))
dataset_0, dataset_0_lbl = dataset_full[mask_0], dataset_full_lbl[mask_0]

print("Shape of the dataset_0:",dataset_0.shape)

dataset_1, dataset_1_lbl = dataset_full[mask_1], dataset_full_lbl[mask_1]

print("Shape of the dataset_1:",dataset_1.shape)

print("Some labels in set 1:", dataset_0_lbl[0:10])
print("Some labels in set 2:", dataset_1_lbl[0:10])
assert(np.all(np.isin(dataset_0_lbl, lbls_0)))
assert(np.all(np.isin(dataset_1_lbl, lbls_1)))

# mnist.plot_samples('MNIST Examples', mnist_inps, outputs=mnist_trgts)
# dataset.plot_samples('FashionMNIST Examples with lbl < sep', dataset_0[0:4], outputs=dataset_0_lbl[0:4])
# dataset.plot_samples('FashionMNIST Examples with lbl >= sep', dataset_1[0:4], outputs=dataset_1_lbl[0:4])

torch_dataset = torch.tensor(dataset_full_lbl)
unique_values, counts = torch.unique(torch_dataset, return_counts=True)

print("Minimum and maximum amount of sample per classes in the dataset")
print("Each classes contains at least", torch.min(counts).item(), "samples")
print("Each classes contains at most", torch.max(counts).item(), "samples")

[   0    0    0 ... 1622 1622 1622]
(32460, 784)
(32460,)
Shape of the dataset_0: (22000, 784)
Shape of the dataset_1: (10460, 784)
Some labels in set 1: [0 0 0 0 0 0 0 0 0 0]
Some labels in set 2: [1100 1100 1100 1100 1100 1100 1100 1100 1100 1100]
Minimum and maximum amount of sample per classes in the dataset
Each classes contains at least 20 samples
Each classes contains at most 20 samples


In [6]:
from scipy.ndimage import zoom, rotate
from scipy.interpolate import interp2d

def rotate_dataset(dataset, angle):
    dataset_unflatten = dataset.reshape(-1, 1, 28, 28)
    rotated_data = rotate(dataset_unflatten, angle, axes=(2, 3), reshape=False)
    return rotated_data.reshape(-1, 784)

def zoom_dataset(dataset, zoom_factor):
    dataset_unflatten = dataset.reshape(-1, 1, 28, 28)
    zoomed_dataset = zoom(dataset_unflatten, (1, 1, zoom_factor, zoom_factor), order=1)
    
    original_size = dataset_unflatten.shape
    zoomed_size = zoomed_dataset.shape
    diff = int((zoomed_size[2] - original_size[2])/2)
    interpolated_data = zoomed_dataset[:,:,diff:diff+original_size[2], diff:diff+original_size[2]]
    return interpolated_data.reshape(-1, 28 * 28)

### Compute a pgd attack on test set to assert robustness

In [7]:
class ResNet(nn.Module):
    def __init__(self, z_length):
        super(ResNet, self).__init__()
        self.z_length = z_length
        resnet18 = models.resnet18(pretrained=False)
        resnet18.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3), bias=False)
        resnet18.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, self.z_length)
        self.resnet = resnet18

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        return self.resnet(x)
    
class LeNet(nn.Module):
    def __init__(self, z_length, p):
        super(LeNet, self).__init__()
        self.z_length = z_length
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 16, 120)
        self.dropout1 = nn.Dropout(p=p)
        self.fc2 = nn.Linear(120, 84)
        self.dropout2 = nn.Dropout(p=p)
        self.fc3 = nn.Linear(84, z_length)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
    
class DropResNet(nn.Module):
    def __init__(self, z_length, dropout_prob=0.5):
        super(DropResNet, self).__init__()
        
        # Load the pre-trained ResNet-18 model
        self.z_length = z_length
        resnet18 = models.resnet18(pretrained=True)
        resnet18.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3), bias=False)
        resnet18.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        
        # Remove the last fully connected layer
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        
        # Add custom fully connected layers with dropout
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(512, self.z_length)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.fc_layers(x)
        return x
    

In [8]:

def compute_kernel(X, y, cnn, K):
    """
    Compute Hypershot kernel for a support set X and label y
    It takes the average of the z's for each label as suggested in the Hypershot paper
    
    Args:
        X (tensor): Support set used to compute the kernel
        y (tensor): corresponding labels
        cnn : CNN used to compute the embeddings
        K: the K of K-shot K-way learning

    Returns:
        type: embeddings, kernel
    """
    # Obtain the indices that would sort y_test
    indices = torch.argsort(y)

    # Use the indices to sort the rows of X_test
    sorted_X = X[indices].to(device)
    sorted_y = y[indices].to(device)
    
    reshaped_X = sorted_X.view(sorted_X.shape[0], 1, 28, 28).to(device)
    nn_X = cnn(reshaped_X)
    
    # TODO-yz: think this can be turned into a one-liner to make things faster. Sorting looks correct to me btw
    mean_X = torch.zeros((int(nn_X.shape[0] / K), nn_X.shape[1])).to(device)
    for i in range(K):
        mean_X[i] = torch.mean(nn_X[i*K:(i+1)*K], dim = 0)
    norm_mean_X = F.normalize(mean_X, p=2, dim=1)
    norm_X = F.normalize(nn_X, p=2, dim=1)
    
    assert(nn_X.shape==(sorted_X.shape[0], cnn.z_length))
    
    # TODO-yz: in the paper they used normalized dot product which is not the same as making the features 0-1-gaussian. See formula (6) in paper. Think it is okay to have f as the identity (its essentially the same as extending the feature extractor)
    return mean_X, torch.matmul(norm_mean_X, torch.t(norm_mean_X))

def get_s_and_q_sets(X, y, trgt_lbls, K, q_size):
    # TODO-yz: this should become much easier if you use the datasets.py
    """
    Computes a support set for data X for classes in y with K sample per classes
    and corresponding query sets of size q_size.
    
    Args:
        X (tensor): Data used to compute the sets (can contain label you do not want for your sets)
        y (tensor): corresponding labels
        trgt_lbls : the labels that end up in the sets
        K: the K of K-shot K-way learning
        q_size: amount of sample per classes in query set

    Returns:
        type: support set, support set labels, query set, query set labels
    """
    s_set = np.zeros((len(trgt_lbls) * K, X.shape[1]))
    s_set_lbl = np.zeros((len(trgt_lbls) * K))
    
    q_set = np.zeros((len(trgt_lbls) * q_size, X.shape[1]))
    q_set_lbl = np.zeros((len(trgt_lbls) * q_size))
    
    for j, l in enumerate(trgt_lbls):
        mask = (y == l)
        masked_data = X[mask]
        masked_lbls = y[mask]
        s_set[j*K:(j+1)*K] = masked_data[0:K]
        s_set_lbl[j*K:(j+1)*K] = masked_lbls[0:K]
        q_set[j*q_size:(j+1)*q_size] = masked_data[K:K+q_size]
        q_set_lbl[j*q_size:(j+1)*q_size] = masked_lbls[K:K+q_size]
    
    s_set = torch.tensor(s_set, requires_grad=True).to(device).float()
    s_set_lbl = torch.tensor(s_set_lbl, requires_grad=True).to(device).float()
    q_set = torch.tensor(q_set, requires_grad=True).to(device).float()
    q_set_lbl = torch.tensor(q_set_lbl, requires_grad=True).to(device).float()
    
    return s_set, s_set_lbl, q_set, q_set_lbl

def get_q_sample_features(X, cnn, kernel, zs):
    """
    Computes the final features used for classification, given a query sample mx
    
    Args:
        X (tensor): query sample 
        cnn: the cnn trained to compute the desired features
        kernel: the kernel corresponding to the corresponding X's support set
        zs: z space of the support set corresponding to the query sample

    Returns:
        type: final flattened features use by the main network
    """
    # TODO-yz: you could pass the entire query set at once instead of iterating over every i in range(q_set.shape[0]). But you'd need to make sure that the first dimension corresponds to the images of the query set and the second dimension to all the ways. Also flattening would have to be done starting at dim 1
    X = X.view(-1, 1, 28, 28)
    # TODO-yz: not sure why you don't need to unsqueeze here (1,28,28) has no batch_dim
    zs_q = cnn(X.view(X.shape[0], 1, 28, -1))
    zs_q = F.normalize(zs_q, p=2, dim=1)
    zs_q_m = torch.matmul(zs, torch.t(zs_q))
    return torch.t(zs_q_m)

def compute_sets_and_features(X, y, trgt_lbls, cnn, K, q_size):
    s_set, s_set_lbl, q_set, q_set_lbl = get_s_and_q_sets(X, y, trgt_lbls, K, q_size)

    # Kernel computation
    z_space, kernel = compute_kernel(s_set, s_set_lbl, cnn, K)
    all_q_features = get_q_sample_features(q_set, cnn, kernel, z_space).to(device)
    all_q_features_lbls = torch.tensor(q_set_lbl).to(device)
        
    return s_set, s_set_lbl, q_set, q_set_lbl, z_space, kernel, all_q_features, all_q_features_lbls

def extend_pred_to_nclasses(pred, n_c, lbls):
    out = torch.zeros((pred.shape[0], n_classes)).to(device)
    for i in range(out.shape[0]):
        # TODO-yz: not sure how this runs, probably have some global c_lbls somewhere bc it's not defined here
        out[i][lbls] = pred[i]
    return out

In [9]:
# TODO-yz: could wrap the entire forward pass of your model into a custom nn.Module and then you can use torchattacks.PGD (you will have to do this anyway for evaluation to work + it's cleaner ;)
def project(x_adv, x_orig):
    epsilon = 8/255.0
    x_adv_eps = torch.minimum(torch.maximum(x_adv, x_orig-epsilon), x_orig+epsilon)
    return torch.clamp(x_adv_eps, 0, 1)

def pgd_attack_data(X, y, t_mnet, t_hnet, K, cnn, kernel, zs):
    criterion = nn.CrossEntropyLoss()    
    x_adv = torch.clone(X).detach()
   
    for i in range(20):
        x_adv = x_adv.requires_grad_(True)
        x_features = torch.zeros((x_adv.shape[0], K)).to(device)
        x_features_lbls = torch.zeros((x_adv.shape[0])).to(device)
        for j in range(x_adv.shape[0]):
            mx = x_adv[j].view(-1, X.shape[1])
            x_sample_features = get_q_sample_features(mx, cnn, kernel, zs)
            x_features[j] = x_sample_features
            x_features_lbls[j] = y[j]
            
        # Apply to test set
        W_mnet = t_hnet(cond_id=0)
        logits = t_mnet.forward(x_features, weights=W_mnet)
        loss_adv = criterion(logits, x_features_lbls.long())
        loss_adv.backward(retain_graph=True)
        
        grad = x_adv.grad.detach()

        with torch.no_grad():
            x_adv = x_adv + 0.1 * torch.sign(grad)  # take a gradient update step to minimize the objective
            x_adv = project(x_adv, X)               # ensure we stay in the allowed range
            
    return x_adv

In [10]:
def calc_accuracy_lbls(X_test, y_test, test_classes, hnet, mnet, Ks, cnn, n_c, q_size):
    """
    Computes the prediction accuracy for the sample with label test_classes in X_test.
    Mainly used as utility for the calc_accuracy function below.
    
    Args:
        X_test (tensor): entire test set
        y_test (tensor): corresponding labels
        test_classes: the classes we want to consider for testing accuracies (should contain Ks classes)
        mnet : main net trained by the hypernetwork
        Ks: the K of K-shot K-way
        s_cnn: the cnn trained to compute the desired features

    Returns:
        type: accuracy
    """
    
    with torch.no_grad():
        s_set_test, s_set_lbl_test, q_set_test, q_set_lbl_test = get_s_and_q_sets(X_test, y_test, \
                                                                                 test_classes, Ks, q_size)
        z_space, K = compute_kernel(s_set_test, s_set_lbl_test, cnn, Ks)
        
        all_q_features = get_q_sample_features(q_set_test, cnn, K, z_space).to(device)
        all_q_features_lbls = torch.tensor(q_set_lbl_test).to(device)

        # TODO-yz: forward pass looks pretty correct to me now :)
        W_dataset_l_acc =  hnet(uncond_input=K.view(1, -1))
        dataset_l_P_acc = mnet.forward(all_q_features, weights=W_dataset_l_acc)
        # TODO-yz: not sure but seems like you could probably use torch.nn.functional.one_hot for this
        prediction_extended_acc = extend_pred_to_nclasses(dataset_l_P_acc, n_c, test_classes)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(prediction_extended_acc, all_q_features_lbls.long())
        accuracy = (torch.argmax(prediction_extended_acc,dim=1) == all_q_features_lbls.long()).float().mean().item()
        # print("Correctly predicted samples had labels:", all_q_features_lbls[torch.argmax(prediction_extended_acc,dim=1) == all_q_features_lbls.long()])
    return accuracy, loss.item()


def calc_accuracy(X_test, y_test, hnet, mnet, Ks, cnn, n_c, q_size):
    """
    Computes the prediction accuracy for the entire X_test test set.
    
    Args:
        X_test (tensor): entire test set
        y_test (tensor): corresponding labels
        mnet : main net trained by the hypernetwork
        Ks: the K of K-shot K-way
        s_cnn: the cnn trained to compute the desired features

    Returns:
        type: average accuracy over all the label batch (of Ks different labels each time)
    """
    if not torch.is_tensor(X_test):
        X_test_t = torch.FloatTensor(X_test).to(device)
    else:  
        X_test_t = torch.clone(X_test)
        
    if not torch.is_tensor(y_test):
        y_test_t = torch.FloatTensor(y_test).to(device)
    else:
        y_test_t = torch.clone(y_test)
        
    diff_classes = torch.unique(y_test_t)
    n_diff_classes = diff_classes.shape[0]
    n_sets = int(n_diff_classes / Ks)
    acc, loss = 0.0, 0.0
    for i in range(n_sets):
        lbls = diff_classes[i*Ks:(i+1)*Ks].tolist()
        d_acc, d_loss = calc_accuracy_lbls(X_test, y_test, lbls, hnet, mnet, Ks, cnn, n_c, q_size)
        acc += d_acc
        loss += d_loss
    acc = acc / n_sets
    loss = loss / n_sets
    return acc, loss

In [11]:
def calc_accuracy_lbls_adv(X_test, y_test, test_classes, mnet, Ks, s_cnn, q_set_test_adv):
    """
    Same as the calc_accuracy_lbls function but replace the query set with an attacked version of itself.
    """
    
    with torch.no_grad():
        s_set_test, s_set_lbl_test, q_set_test, q_set_lbl_test = get_s_and_q_sets(X_test, y_test, \
                                                                                 test_classes, Ks, 5) 
        q_set_test = q_set_test_adv
        z_space, K = compute_kernel(s_set_test, s_set_lbl_test, s_cnn, Ks)
        
        # Accuracy
        all_q_features = torch.zeros((q_set_test.shape[0], Ks)).to(device)
        all_q_features_lbls = torch.zeros((q_set_test.shape[0])).to(device)
        for i in range(q_set_test.shape[0]):
            mx = q_set_test[i].view(-1, q_set_test.shape[1])
            my = torch.argmax(q_set_lbl_test[i])
            q_sample_features = get_q_sample_features(mx, s_cnn, K, z_space)
            all_q_features[i] = q_sample_features
            all_q_features_lbls[i] = my

        # TODO-yz: here the forward pass doesn't look right, really recommend having one model where you call forward
        W_dataset_l = hnet(cond_id=0)
        dataset_l_P = mnet.forward(all_q_features, weights=W_dataset_l)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(dataset_l_P, all_q_features_lbls.long())
        accuracy = (torch.argmax(dataset_l_P,dim=1) == all_q_features_lbls.long()).float().mean().item()
        # print("Correctly predicted labels:", all_q_features_lbls.long()[torch.argmax(dataset_l_P,dim=1) == all_q_features_lbls.long()])
    return accuracy, loss.item()

In [18]:
# General behavior during trainig:
# Using LeNet, it very very slowly (400 epoch) converge to a slightly overfitting case with very bad precision
# Using DropResNet it converges faster but very quickly overfit too. Using this, we can achieve around 30% accuracy on
# full Omniglot dataset but not more and it is not stable at all during training
# Using ResNet, like DropResNet but worse in terms of overfitting of course (especially if we do not do this averaging of the
# z-space)
# Notice that in any of the 3 cases above, we always observe a pleateau around local loss of 1.6
# Test and validation accuracies highly varies depending on the value of q_test

# Configure training.
nepochs=200
# epoch after which adversarial training starts
do_adv_train = 10000
# K-shot k-way
Ks = 5
# Length of the embeddings produced by the CNN
z_len = 50

load_weights = 0
continue_training = 0

# Amount of sample in query sets during training
q_train = 5
# Amount of sample in query sets during validation and testing
q_test = 5

# Array storing statistics (not used for now)
accuracies_dataset_0 = []
accuracies_dataset_0_adv = []
accuracies_dataset_1 = []
accuracies_dataset_1_adv = []

# Loop in case we want to do statistics (not sued for now)
for o in range(1):
    print("Iteration", o+1)
    
    if continue_training == 0:
        # Models definition
        kcnn = DropResNet(z_len, 0.5).to(device)
        # kcnn = LeNet(z_len, p=0.0).to(device)
        mnet = MLP(n_in=Ks, n_out=Ks, hidden_layers=[16, 8]).to(device)
        hnet = HMLP(mnet.param_shapes, uncond_in_size=Ks**2, cond_in_size=0,
                    layers=[32, 32], num_cond_embs=0).to(device)
        params = hnet.conditional_params.copy()
        hnet.apply_hyperfan_init(mnet=mnet)
        criterion = nn.CrossEntropyLoss()

        # If we want to load weights from anywhere
        if load_weights == 1:
            file_path = 'models/hnet_20231229022719_49.pth'
            hnet.load_state_dict(torch.load(file_path))
            file_path = 'models/kcnn_20231229022719_49.pth'
            kcnn.load_state_dict(torch.load(file_path))

        # The amount of sets of Ks labels we can do during training
        n_sets = int(len(lbls_0) / Ks)

        # Compute training and validation sets for each of the n_sets labels sets
        train_test_sets = []
        all_test_sets = np.empty((0, dataset_0.shape[1]))
        all_test_sets_lbl = np.empty((0))
        for l_set_id in range(n_sets):
            c_lbls = lbls_0[l_set_id*Ks:(l_set_id+1)*Ks]
            if (l_set_id+1) % 100 == 0:
                print("Generated train-test split for", l_set_id+1,"/",n_sets)
            mask_b = np.isin(dataset_0_lbl, np.array(c_lbls))
            dataset_0_b, dataset_0_lbl_b = dataset_0[mask_b], dataset_0_lbl[mask_b]
            dataset_0_train, dataset_0_test, dataset_0_lbl_train, dataset_0_lbl_test = \
                            train_test_split(dataset_0_b, dataset_0_lbl_b, random_state=42, test_size=0.5, stratify=dataset_0_lbl_b)
            
            rotate_m10 = rotate_dataset(dataset_0_train, -10)
            rotate_m5 = rotate_dataset(dataset_0_train, -5)
            rotate_10 = rotate_dataset(dataset_0_train, 10)
            rotate_5 = rotate_dataset(dataset_0_train, 5)
            zoom_110 = zoom_dataset(dataset_0_train, 1.10)
            zoom_125 = zoom_dataset(dataset_0_train, 1.25)
            
            assert(rotate_m10.shape == dataset_0_train.shape)
            assert(zoom_125.shape == dataset_0_train.shape)
            
            dataset_0_train = np.concatenate((dataset_0_train, rotate_m10, rotate_m5,
                                                                           rotate_5,
                                                                           rotate_10,
                                                                           zoom_110,
                                                                           zoom_125), axis=0)
            dataset_0_lbl_train = np.repeat(dataset_0_lbl_train, 7, axis=0)
            
            
            all_test_sets = np.concatenate((all_test_sets, dataset_0_test), axis=0)
            all_test_sets_lbl = np.concatenate((all_test_sets_lbl, dataset_0_lbl_test), axis=0)
            train_test_sets.append((dataset_0_train, dataset_0_test, dataset_0_lbl_train, dataset_0_lbl_test, c_lbls))
    
    # Optimizer and scheduler initialization
    # TODO-yz: it's correct you only have 2 optimizers here. if you introduce a kernel function f, you'll need a third one though
    optimizer = optim.Adam(hnet.parameters(), lr=0.00005)
    optimizer_s = optim.Adam(kcnn.parameters(), lr=0.00005)
    scheduler = CosineAnnealingLR(optimizer, T_max=int(nepochs / 1), eta_min=0.000001)
    scheduler_s = CosineAnnealingLR(optimizer_s, T_max=int(nepochs / 1), eta_min=0.000001)
        
    # Main training loop
    for epoch in range(nepochs): # For each epoch.
        print("----------------------- Epoch", epoch, " -----------------------")
        # Stores the loss over all labels sets
        global_loss = 0.0
        global_loss_float = 0.0
        # We loop over all our sets at each epoch
        for l_set_id in range(n_sets):
            loss_dataset_l = 0.0
            (dataset_l_train, dataset_l_test, dataset_l_lbl_train, dataset_l_lbl_test, c_lbls) = train_test_sets[l_set_id]
            
            s_set_train, s_set_lbl_train, q_set_train, q_set_lbl_train, z_space, K, all_q_features, all_q_features_lbls = \
            compute_sets_and_features(dataset_l_train, dataset_l_lbl_train, c_lbls, kcnn, Ks, q_train)
            
            # Formward pass
            W_dataset_l = hnet(uncond_input=K.view(1, -1))
            dataset_l_P = mnet.forward(all_q_features, weights=W_dataset_l)
            prediction_extended = extend_pred_to_nclasses(dataset_l_P, n_classes, c_lbls)
            loss_dataset_l += criterion(prediction_extended, all_q_features_lbls.long())

            # Adversarial training
            # TODO-yz: think it's good to do adversarial querying, this gives a more meaningful baseline result
            if epoch == do_adv_train and l_set_id == 0:
                print("Adversarial training starts.")
            if epoch >= do_adv_train:
                mx_adv = pgd_attack_data(q_set_train, q_set_lbl_train, mnet, hnet, Ks, kcnn, K, z_space)
                
                all_q_features_adv = torch.zeros((q_set_train.shape[0], Ks)).to(device)
                for i in range(mx_adv.shape[0]):
                    mxx = mx_adv[i].view(-1, q_set_train.shape[1])
                    q_sample_features_adv = get_q_sample_features(mxx, kcnn, K, z_space)
                    all_q_features_adv[i] = q_sample_features_adv
                dataset_l_P_adv = mnet.forward(all_q_features_adv, weights=W_dataset_l)
                prediction_extended_adv = extend_pred_to_nclasses(dataset_l_P_adv, n_classes, c_lbls)
                loss_dataset_l_adv = criterion(prediction_extended_adv, all_q_features_lbls.long())
                loss_dataset_l += loss_dataset_l_adv
            
            global_loss += loss_dataset_l
            global_loss_float += loss_dataset_l.item()
            if l_set_id % 100 == 0:
                train_metrics = calc_accuracy(dataset_l_train, dataset_l_lbl_train, hnet, mnet, Ks, kcnn, n_classes, q_test)
                test_metrics = calc_accuracy(dataset_l_test, dataset_l_lbl_test, hnet, mnet, Ks, kcnn, n_classes, q_test)
                print("Local train acc and loss at the end of set:", l_set_id, "-->", train_metrics)
                print("Local valid acc and loss at the end of set:", l_set_id, "-->", test_metrics)
                if do_adv_train < nepochs:
                    s_set_test, s_set_lbl_test, q_set_test, q_set_lbl_test, z_space_tes, K_test, all_q_features_test, all_q_features_lbls_test = \
                    compute_sets_and_features(dataset_l_test, dataset_l_lbl_test, c_lbls, kcnn, Ks, 5)
                    mx_adv_test = pgd_attack_data(q_set_test, q_set_lbl_test, mnet, hnet, Ks, kcnn, K, z_space)
                    print("Local adv test acc and loss at the end of set:", l_set_id, "-->", calc_accuracy_lbls_adv(dataset_l_test, dataset_l_lbl_test, c_lbls, hnet, mnet,\
                                                                                           Ks, kcnn, mx_adv_test))


            loss_dataset_l.backward()
            optimizer.step()
            optimizer_s.step()
            optimizer.zero_grad()
            optimizer_s.zero_grad()
                
        scheduler.step()
        scheduler_s.step()
                
  
        print("Global loss at the end of epoch:", epoch, ":", global_loss_float)
        if (epoch+1) % 1 == 0:
            current_time = datetime.now().strftime("%Y%m%d%H%M%S")
            # Create a file name with the current time
            hnet_file = f'models/hnet_{current_time}_{epoch}.pth'
            torch.save(hnet.state_dict(), hnet_file)
            kcnn_file = f'models/kcnn_{current_time}_{epoch}.pth'
            torch.save(kcnn.state_dict(), kcnn_file)
            print("--> Global valid accuracy after epoch:", epoch, "-->", calc_accuracy(all_test_sets, all_test_sets_lbl,\
                                                                                       hnet, mnet, Ks, kcnn, n_classes, q_test))
            print("--> Global test accuracy after epoch:", epoch, "-->", calc_accuracy(dataset_1, dataset_1_lbl, \
                                                                                       hnet, mnet, Ks, kcnn, n_classes, q_test))
        print()

    print("END OF ITERATION:",o+1)

Iteration 1




Creating an MLP with 277 weights.
Created MLP Hypernet.
Hypernetwork with 11029 weights and 277 outputs (compression ratio: 39.82).
The network consists of 11029 unconditional weights (11029 internally maintained) and 0 conditional weights (0 internally maintained).
Generated train-test split for 100 / 220
Generated train-test split for 200 / 220
----------------------- Epoch 0  -----------------------


  all_q_features_lbls = torch.tensor(q_set_lbl).to(device)
  all_q_features_lbls = torch.tensor(q_set_lbl_test).to(device)
  out[i][lbls] = pred[i]


Local train acc and loss at the end of set: 0 --> (0.19999998807907104, 7.091139316558838)
Local valid acc and loss at the end of set: 0 --> (0.19999998807907104, 7.15891170501709)
Local train acc and loss at the end of set: 100 --> (0.19999998807907104, 3.2907145023345947)
Local valid acc and loss at the end of set: 100 --> (0.19999998807907104, 3.4086031913757324)
Local train acc and loss at the end of set: 200 --> (0.19999998807907104, 2.1442248821258545)
Local valid acc and loss at the end of set: 200 --> (0.19999998807907104, 2.176351547241211)
Global loss at the end of epoch: 0 : 797.3013560771942
--> Global valid accuracy after epoch: 0 --> (0.20145453492348844, 1.8870286638086493)
--> Global test accuracy after epoch: 0 --> (0.20423075843315858, 1.8861972506229694)

----------------------- Epoch 1  -----------------------
Local train acc and loss at the end of set: 0 --> (0.19999998807907104, 1.8302284479141235)
Local valid acc and loss at the end of set: 0 --> (0.1999999880790

--> Global valid accuracy after epoch: 9 --> (0.20981817411428147, 2.603711659799923)
--> Global test accuracy after epoch: 9 --> (0.2099999922256057, 2.485013173176692)

----------------------- Epoch 10  -----------------------
Local train acc and loss at the end of set: 0 --> (0.7999999523162842, 0.6919561624526978)
Local valid acc and loss at the end of set: 0 --> (0.19999998807907104, 3.280329704284668)
Local train acc and loss at the end of set: 100 --> (0.47999998927116394, 0.918348491191864)
Local valid acc and loss at the end of set: 100 --> (0.2800000011920929, 2.254089117050171)
Local train acc and loss at the end of set: 200 --> (0.7199999690055847, 0.6807985901832581)
Local valid acc and loss at the end of set: 200 --> (0.23999999463558197, 2.2492916584014893)
Global loss at the end of epoch: 10 : 179.7565712928772
--> Global valid accuracy after epoch: 10 --> (0.20745453824373808, 2.9557469237934457)
--> Global test accuracy after epoch: 10 --> (0.21153845394460055, 2.8324

Global loss at the end of epoch: 19 : 33.01240942813456
--> Global valid accuracy after epoch: 19 --> (0.20836362947117198, 5.884430684826591)
--> Global test accuracy after epoch: 19 --> (0.2157692237972067, 5.8602374746249275)

----------------------- Epoch 20  -----------------------
Local train acc and loss at the end of set: 0 --> (0.9599999785423279, 0.24018825590610504)
Local valid acc and loss at the end of set: 0 --> (0.19999998807907104, 6.402101993560791)
Local train acc and loss at the end of set: 100 --> (0.9199999570846558, 0.20209187269210815)
Local valid acc and loss at the end of set: 100 --> (0.35999998450279236, 5.139620304107666)
Local train acc and loss at the end of set: 200 --> (0.9599999785423279, 0.05970229580998421)
Local valid acc and loss at the end of set: 200 --> (0.19999998807907104, 8.552838325500488)
Global loss at the end of epoch: 20 : 25.58707440085709
--> Global valid accuracy after epoch: 20 --> (0.21290908418595791, 6.28347000208768)
--> Global te

Global loss at the end of epoch: 29 : 13.009663723409176
--> Global valid accuracy after epoch: 29 --> (0.21036363003606146, 6.4079818671399895)
--> Global test accuracy after epoch: 29 --> (0.20538460778502318, 6.376271639878933)

----------------------- Epoch 30  -----------------------
Local train acc and loss at the end of set: 0 --> (0.9599999785423279, 0.24106544256210327)
Local valid acc and loss at the end of set: 0 --> (0.23999999463558197, 2.734457015991211)
Local train acc and loss at the end of set: 100 --> (0.9599999785423279, 0.03665706515312195)
Local valid acc and loss at the end of set: 100 --> (0.19999998807907104, 5.6084184646606445)
Local train acc and loss at the end of set: 200 --> (1.0, 0.017248988151550293)
Local valid acc and loss at the end of set: 200 --> (0.1599999964237213, 4.162438869476318)
Global loss at the end of epoch: 30 : 12.245634670834988
--> Global valid accuracy after epoch: 30 --> (0.2158181754025546, 6.6395984877239576)
--> Global test accurac

Global loss at the end of epoch: 39 : 12.771385071682744
--> Global valid accuracy after epoch: 39 --> (0.21527272134341977, 6.44249848018993)
--> Global test accuracy after epoch: 39 --> (0.21115384000138596, 6.194011511710974)

----------------------- Epoch 40  -----------------------
Local train acc and loss at the end of set: 0 --> (1.0, 0.00828428752720356)
Local valid acc and loss at the end of set: 0 --> (0.19999998807907104, 5.661026477813721)
Local train acc and loss at the end of set: 100 --> (1.0, 0.029441656544804573)
Local valid acc and loss at the end of set: 100 --> (0.3999999761581421, 5.473513603210449)
Local train acc and loss at the end of set: 200 --> (1.0, 0.006433972157537937)
Local valid acc and loss at the end of set: 200 --> (0.19999998807907104, 5.747174263000488)
Global loss at the end of epoch: 40 : 13.1910816290183
--> Global valid accuracy after epoch: 40 --> (0.2150909027931365, 6.163813605091788)
--> Global test accuracy after epoch: 40 --> (0.2065384549

KeyboardInterrupt: 

In [None]:
current_time = datetime.now().strftime("%Y%m%d%H%M%S")

# Create a file name with the current time
hnet_file = f'models/hnet_{current_time}_best.pth'
torch.save(hnet.state_dict(), hnet_file)
kcnn_file = f'models/kcnn_{current_time}_best.pth'
torch.save(kcnn.state_dict(), kcnn_file)

In [None]:
# x_adv_dataset_1 = pgd_attack_data(dataset_1, dataset_1_lbl, mnet, hnet, z_space_1, K_1, 1)
# x_adv_dataset_1_np = x_adv_dataset_1.detach().cpu().numpy()
# x_adv_dataset_0_test = pgd_attack_data(dataset_0_test, dataset_0_lbl_test, mnet, hnet, z_space, K, 0)
# x_adv_dataset_0_test_np = x_adv_dataset_0_test.detach().cpu().numpy()

print(calc_accuracy(all_test_sets, all_test_sets_lbl, hnet, mnet, Ks, kcnn, n_classes, 5))
print(calc_accuracy(dataset_1, dataset_1_lbl, hnet, mnet, Ks, kcnn, n_classes, 5))
# accuracies_dataset_0_adv.append((calc_accuracy(x_adv_dataset_0_test_np, dataset_0_lbl_test, mnet, W_dataset_0)).detach().cpu())
# accuracies_dataset_1_adv.append((calc_accuracy(x_adv_dataset_1_np, dataset_1_lbl, mnet, W_dataset_1)).detach().cpu())

In [None]:
print("Mean:")
print("dataset 0 accuracy:", np.mean(np.array(accuracies_dataset_0)))
print("dataset 1 accuracy:", np.mean(np.array(accuracies_dataset_1)))
print("dataset 0 adv accuracy:", np.mean(np.array(accuracies_dataset_0_adv)))
print("dataset 1 adv accuracy:", np.mean(np.array(accuracies_dataset_1_adv)))
print()
print("Standard deviation:")
print("dataset 0 accuracy:", np.std(np.array(accuracies_dataset_0)))
print("dataset 1 accuracy:", np.std(np.array(accuracies_dataset_1)))
print("dataset 0 adv accuracy:", np.std(np.array(accuracies_dataset_0_adv)))
print("dataset 1 adv accuracy:", np.std(np.array(accuracies_dataset_1_adv)))