<img align="center" style='max-width: 1000px' src="https://github.com/HSG-AIML-Teaching/DL2025-Lab/blob/main/lab_3/images/banner.png?raw=1">

<img align="left" style='max-width: 150px; height: auto' src="https://github.com/HSG-AIML-Teaching/DL2025-Lab/blob/main/lab_3/images/hsg_logo.png?raw=1">

# Lab 03 - "Hypernetworks"


## Objective

After learning the concepts in this lab, you should be able to:

- Understand the basic tools and methods needed for the implementation of Hypernetworks
- Implement basic Hypernetworks
- Apply two different types of slicing techniques to reduce the size of the Hypernetwork


## Outline


1. **A Simple Hypernetwork**: How Hypernetworks can be implemented in PyTorch.
2. **Slicing Technique 1**: A slicing technique that treats all parameters as a single vector.
2. **Slicing Technique 2**: A layer-wise slicing technique.



<img align='center' style='max-width: 700px' src='https://github.com/HSG-AIML-Teaching/DL2025-Lab/blob/main/lab_3/images/hypernet_forward.gif?raw=1'>

*Animation: The forward and backward propagation steps of a Hypernetwork. First, the Hypernetwork (the blue network) generates the weights of the main model (the white network) using some context information $t$. Then, it makes prediction on input $x$ using generated weights $w$ in a stateless manner. Finally, in the backpropagation step, the gradiants of the Hypernetwork are obtained by backpropagating through the main model to the Hypernetwork.*

### Install Required Packages

In [1]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np


## <font color='red'>1. A Simple Hypernetwork</font>



In this section, we implement a simple hypernetwork that generates the weights of an MLP. The weights are generated as a single vector of weights.

First, let's start with the definition of the MLP model:

In [None]:
class MLP(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out):
        super().__init__()
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden, n_hidden)
        self.linear_3 = nn.Linear(n_hidden, n_hidden)
        self.classifier = nn.Linear(n_hidden, n_out)

        self.activ = nn.ReLU()

    def forward(self, x):
        x = self.activ(self.linear_1(x))
        x = self.activ(self.linear_2(x))
        x = self.classifier(x)

        return x

Typically, we initialize an instance of the model, and feed it with some input to get the output:

In [None]:
main_model = MLP(10, 50, 5)
x = torch.randn(32, 10)
out = main_model(x)
print(out.shape)

torch.Size([32, 5])


In the example, the weights of the model are stored inside the model. Therefore, when we call `main_model(x)`, it uses the weights stored in the model's `state_dict` to do the forward propagation. What if the weights are provided from outside the model?

Now, let's first generate the weights of the model with another neural network called the Hypernetwork:

### 1.1 Hypernetwork

To generate the weights of another model, we first need to know the number of parameters and the parameter shapes in the main model:

In [None]:
# Shape of each parameter as a dictionary of name: shape
param_shapes = {n: p.shape for (n, p) in main_model.named_parameters()}

# Total number of parameters in the model
num_params = sum([p.numel() for p in main_model.parameters()])
print("Number of parameters: ", num_params)

Number of parameters:  5905


Then, we need to define the architecture of the Hypernetwork. The Hypernetwork is also an MLP that maps the input to the vector space of the main model's weights.

The most basic version of a Hypernetwork treats all weights as a single vector as shown in the following animation:

<img align='center' style='max-width: 700px' src='https://github.com/HSG-AIML-Teaching/DL2025-Lab/blob/main/lab_3/images/no_slice.gif?raw=1'>


In [None]:
class Hypernetwork(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out):
        super().__init__()
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden, n_out)

        self.actv = nn.ReLU()

    def forward(self, x):
        x = self.actv(self.linear_1(x))
        x = self.linear_2(x)

        return x

In [None]:
hypernetwork = Hypernetwork(10, 20, num_params)

The input to the hypernetwork is a 10-dimensional tensor, which is then mapped to a to 20-dimenstional hidden state. In the final linear layer, the hidden state is mapped to the vector space of the main model's weights. Let's forward a random tensor to the  Hypernetwork:

In [None]:
hn_out = hypernetwork(torch.randn(1, 10))
print("Hypernetwork output size: ", hn_out.shape)

Hypernetwork output size:  torch.Size([1, 5905])


The output size is equal to the number of parameters in the main model. Now, we need to reshape this tensor back to the original tensor shapes of the main model.

To reshape the output of the hypernetwork, we can start from the index zero of the Hypernetwork's output tensor, and slice it according to the original number of parameters in each layer of the main model. In the end, we need to reshape the tensor to the original size. We can store the reshaped results in a dictionary.

In [None]:
# Dictionary to store the reshaped parameters
reshaped_params = {}

# Start with an offset of 0
offset = 0
for (n, p) in param_shapes.items():
    sliced_parameter = hn_out[0][offset:offset+p.numel()]
    reshaped_params[n] = sliced_parameter.view(p)
    offset += p.numel()

Let's print the shape of reshaped parameters:

In [None]:
for n, p in reshaped_params.items():
    print(n, p.shape)

linear_1.weight torch.Size([50, 10])
linear_1.bias torch.Size([50])
linear_2.weight torch.Size([50, 50])
linear_2.bias torch.Size([50])
linear_3.weight torch.Size([50, 50])
linear_3.bias torch.Size([50])
classifier.weight torch.Size([5, 50])
classifier.bias torch.Size([5])


### 1.2 Forwarding with Parameters

Now, an important question to answer is: how to use these generated weights to make prediction with the main model?

<font color='darkgreen'>[Q] Can we just copy these weights to the `state_dict` dictionary of the model?</font>


In general, we have two ways to forward input with parameters:

1. Defining the function `forward_with_parameters()`
2. Calling the main model ina stateless way

#### Method 1: Defining a new forward function that accepts external parameters

We can add a new forward function that receivs both $x$ and $w$:

In [None]:
# Same model with a different forward function
class ModelV2(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out):
        super().__init__()
        # ! These layers are not used during the forward pass
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden, n_hidden)
        self.linear_3 = nn.Linear(n_hidden, n_hidden)
        self.classifier = nn.Linear(n_hidden, n_out)

        self.activ = nn.ReLU()

    def forward_with_params(self, x, params):
        # Params is a dictionary of name: tensor
        x = F.linear(x, params["linear_1.weight"], params["linear_1.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_2.weight"], params["linear_2.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_3.weight"], params["linear_3.bias"])
        x = F.relu(x)
        x = F.linear(x, params["classifier.weight"], params["classifier.bias"])

        return x

<font color='darkgreen'> [Q] Why are the opeations inside the new forward function performed as functionals instead of using the layers?</font>

Now, we create an instance of the model with the forward-with-parameter pass, and feed it with the same random tensor used to  generate the weights:

In [None]:
model = ModelV2(10, 20, 5)
out = model.forward_with_params(torch.randn(1, 10), reshaped_params)

# Print the output shape
print(out.shape)

torch.Size([1, 5])


#### Method 1: Stateless call

To make stateless calls from a stateful model, we can the use following function from PyTorch (available since version 2.0):

In [None]:
from torch.nn.utils.stateless import functional_call

We can directly use the main model without adding a new forward function. The only thing we need to do is to call it as below:

In [None]:
out = functional_call(main_model, reshaped_params, torch.randn(1, 10))
print(out.shape)

torch.Size([1, 5])


  out = functional_call(main_model, reshaped_params, torch.randn(1, 10))


It's that simple! So far, we have learned to use an external model called the Hypernetwork to generate th weights a main model and make prediction with the generated weights.

#### <font color='darkred'>**BUT**, there is a big problem!</font>

The number of parameters in the hypernetwork can easily "explode" this way. The Hypernetwork employs a linear layer in its final layer to map the hidden state of the Hypernetwork to the vectors space of the main model's weights. This essentially means that, if the size of the hidden state is $S$, and the total number of parameters is $N$, the total number of parameters in the hypernetwork will be $N \times S$:

In [None]:
n_params_main_model = sum([p.numel() for p in main_model.parameters()])
n_params_hypernetwork = sum([p.numel() for p in hypernetwork.parameters()])

print("Number of parameters in main model: ", n_params_main_model)
print("Number of parameters in hypernetwork: ", n_params_hypernetwork)

# Ratio of parameters in hypernetwork to main model
print("Ratio: ", n_params_hypernetwork / n_params_main_model)

Number of parameters in main model:  5905
Number of parameters in hypernetwork:  124225
Ratio:  21.037256562235395


This is super inefficient. The number of parameters in the Hypernetwork is ~21 times more than the number of parameters in the main model. We need to find better ways to generate the weights.

## <font color='red'>2. Slicing Technique 1</font>



In this part, we design a specific slicing technique that splits the entire network parameters with $N$ parameters into $K$ chunks, where $N \mod K = 0$.

The Hypernetwork then generates the weight of each chunk separately, conditioned on the chunk ID:

<img align='center' style='max-width: 700px' src='https://github.com/HSG-AIML-Teaching/DL2025-Lab/blob/main/lab_3/images/slice_1.gif?raw=1'>

In this example, we want to implement an MLP to train an MNIST classifier:

In [None]:
class MLP(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out):
        super().__init__()
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden, n_hidden)
        self.linear_3 = nn.Linear(n_hidden, n_hidden)
        self.classifier = nn.Linear(n_hidden, n_out)

        self.activ = nn.ReLU()

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.activ(self.linear_1(x))
        x = self.activ(self.linear_2(x))
        x = self.classifier(x)

        return x

    def forward_with_params(self, x, params):
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["linear_1.weight"], params["linear_1.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_2.weight"], params["linear_2.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_3.weight"], params["linear_3.bias"])
        x = F.relu(x)
        x = F.linear(x, params["classifier.weight"], params["classifier.bias"])
        return x


Similar to the previous example, we need to implement a Hypernetwork that generates the weights of this MLP. The important point here is to slice the weights in the output as explained above.

In order to avoid "parameter explosion" in the Hypernetwork, we need to use a single linear mapping from the hidden state of the Hypernetwork to each chunk of the main model's weight. Using the same mapping, requires conditioning the mapping on the chunk ID. Therefore, we define an **embedding layer** that maps chunk ID to a vector which is then concatenated to the hidden state of the Hypernetwork:

In [None]:
class Hypernetwork(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out, chunk_size, dim_emb):
        super().__init__()
        # Embedding layer for each state
        self.n_chunks = n_out // chunk_size
        self.emb = nn.Embedding(self.n_chunks, dim_emb)

        # Initialize emb weights with uniform distribution
        nn.init.uniform_(self.emb.weight, -1.0, 1.0)

        # Hypernetwork's layers
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden + dim_emb, chunk_size)

        # Activation function
        self.actv = nn.ReLU()

    def forward(self, x):
        # Retrieve embedding for all layers
        emb_inp = torch.arange(self.n_chunks).to(x.device)
        emb = self.emb(emb_inp)

        # Flatten x and apply the first linear layer
        x = x.view(x.shape[0], -1) # [B, C, H, W] -> [B, C * H * W]
        x = self.actv(self.linear_1(x)) # [B, n_hidden]

        # Unsqueeze x in the second dimension and replicate it for n times
        x = x.unsqueeze(1).repeat(1, self.n_chunks, 1) # [B, n_chunks, n_hidden]

        # Unsqueeze emb in the first dimension and replicate it for n times
        emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1) # [B, n_chunks, dim_embed]

        # Concatenate x and emb along the last dimension
        x = torch.cat([x, emb], dim=-1) # [B, n_chunks, dim_embed + n_hidden]

        # Apply the second linear layer on the conditioned x
        x = self.linear_2(x) # [B, n_chunks, chunk_size]

        # Flatten the output and return it
        x = x.view(x.shape[0], -1) # [B, n_chunks * chunk_size]

        return x

**The next question to answer is: what is a good chunk size?**

Since the number of parameters in the main model can vary, we define a function that takes the number of parameters $N$ , and returns the biggest divisor of $N$ that is smaller than $\sqrt{N}$.

In [None]:
def biggest_divisor(n):
    # Find the biggest divisor of n that is smaller than the square root of n
    for i in range(int(n**0.5), 0, -1):
        if n % i == 0:
            return i

Good, the `biggest_divisor` function finds the chunks size for us.

Now, we need to define an instance of the main model and its corresponding Hypernetwork:

In [None]:
# Initiliaze random seeds n PyTorch and Numpy for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

np.random.seed(0)

In [None]:
# Main model for the MNIST dataset
main_model = MLP(28*28, 50, 10)

# Define parameters shapes and number of parameters
param_shapes = {n: p.shape for (n, p) in main_model.named_parameters()}
num_params = sum([p.numel() for p in main_model.parameters()])

# Chunk size
chunk_size = biggest_divisor(num_params)
print("Chunk size:", chunk_size)

# Hypernetwork with sliced output    n_inp, n_hidden, n_out, chunk_size, dim_emb
hypernetwork = Hypernetwork(n_inp=28*28, n_hidden=5, n_out=num_params, chunk_size=chunk_size, dim_emb=2)

Chunk size: 20


Let's compare the number of parameters:

In [None]:
n_params_main_model = sum([p.numel() for p in main_model.parameters()])
n_params_hypernetwork = sum([p.numel() for p in hypernetwork.parameters()])

print("Number of parameters in main model: ", n_params_main_model)
print("Number of parameters in hypernetwork: ", n_params_hypernetwork)

# Ratio of parameters in hypernetwork to main model
print("Ratio: ", n_params_hypernetwork / n_params_main_model)

Number of parameters in main model:  44860
Number of parameters in hypernetwork:  8571
Ratio:  0.1910610789121712


Great! The number of parameters in the Hypernetwork is now much smaller than the number of parmaeters in the main model.

One last step before training the model is: to define the function that reshapes the generated parameters. The reshape function can be different for each slicing technique.


In [None]:
def reshape_generated_parameters(hn_out, param_shapes):
    reshaped_params = {}
    offset = 0
    for (n, p) in param_shapes.items():
        sliced_parameter = hn_out[0][offset:offset+p.numel()]
        reshaped_params[n] = sliced_parameter.view(p)
        offset += p.numel()

    return reshaped_params

Now, we train the model on the MNIST dataset to see how the final performance will be:

In [None]:
# Load the MNIST dataset
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]
)
train_set = datasets.MNIST(root="./data", train=True,
                           download=True, transform=mnist_transform)
test_set = datasets.MNIST(root="./data", train=False,
                          download=True, transform=mnist_transform)
train_loader = DataLoader(train_set, batch_size=64,
                          shuffle=True)
test_loader = DataLoader(test_set, batch_size=64)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 483kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.44MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.3MB/s]


In [None]:
# Define the optimizer and the loss function
optimizer = torch.optim.Adam(hypernetwork.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Set the device
device = torch.device("cpu")

# Move the model and the hypernetwork to the device
main_model.to(device)
hypernetwork.to(device)

# Standard PyTorch training loop
n_epochs = 10
for epoch in range(n_epochs):
    pbar = tqdm(train_loader)
    for batch in pbar:
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Zero-grad optimizer for the hypernetwork
        optimizer.zero_grad()

        # Generate weights with the hypernetwork
        hn_out = hypernetwork(x)

        # Reshape generated weights
        reshaped_params = reshape_generated_parameters(hn_out, param_shapes)

        # Make prediction
        pred = main_model.forward_with_params(x, reshaped_params)

        # Compute loss and backpropagate
        loss = criterion(pred, y)
        loss.backward()

        # Optimizer step update
        optimizer.step()

        # Set progress bar description
        pbar.set_description(f"Loss value: {loss.item():.4f}")

    with torch.no_grad():
        # Evaluate model after each epoch
        batch_accuracies = []
        pbar_test = tqdm(test_loader)
        for batch in test_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Generate weights with the hypernetwork
            hn_out = hypernetwork(x)
            reshaped_params = reshape_generated_parameters(
                hn_out, param_shapes)

            # Make prediction
            pred = main_model.forward_with_params(x, reshaped_params)
            n_corrects = sum(pred.argmax(dim=1) == y).item()
            acc_batch = n_corrects / len(x)
            batch_accuracies.append(acc_batch)
            pbar_test.update()

    print(f"Average accuracy for epoch {epoch}: {sum(batch_accuracies)/len(batch_accuracies):.4f} \n")


Loss value: 1.7168: 100%|██████████| 938/938 [00:57<00:00, 16.33it/s]
 99%|█████████▉| 156/157 [00:02<00:00, 53.29it/s]

Average accuracy for epoch 0: 0.5809 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 1.3066:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 1.3258:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 1.3258:   0%|          | 2/938 [00:00<00:54, 17.27it/s][A
Loss value: 0.9669:   0%|          | 2/938 [00:00<00:54, 17.27it/s][A
Loss value: 1.5583:   0%|          | 2/938 [00:00<00:54, 17.27it/s][A
Loss value: 1.5583:   0%|          | 4/938 [00:00<00:57, 16.38it/s][A
Loss value: 1.3484:   0%|          | 4/938 [00:00<00:57, 16.38it/s][A
Loss value: 1.3601:   0%|          | 4/938 [00:00<00:57, 16.38it/s][A
Loss value: 1.3601:   1%|          | 6/938 [00:00<00:57, 16.20it/s][A
Loss value: 1.3231:   1%|          | 6/938 [00:00<00:57, 16.20it/s][A
Loss value: 1.4583:   1%|          | 6/938 [00:00<00:57, 16.20it/s][A
Loss value: 1.4583:   1%|          | 8/938 [00:00<01:03, 14.76it/s][A
Loss value: 1.1484:   1%|          | 8/938 [00:00<01:03, 14.76it/s][A
Loss value: 1.3992:   1%|          | 8/938 [00:00

Average accuracy for epoch 1: 0.8274 



Loss value: 0.6146:  30%|███       | 284/938 [00:17<00:38, 17.05it/s]
Loss value: 0.4703: 100%|██████████| 938/938 [00:57<00:00, 16.26it/s]
100%|██████████| 157/157 [01:01<00:00,  2.56it/s]
 97%|█████████▋| 153/157 [00:03<00:00, 56.44it/s]

Average accuracy for epoch 2: 0.8680 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3659:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2923:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2923:   0%|          | 2/938 [00:00<00:51, 18.07it/s][A
Loss value: 0.5444:   0%|          | 2/938 [00:00<00:51, 18.07it/s][A
Loss value: 0.5575:   0%|          | 2/938 [00:00<00:51, 18.07it/s][A
Loss value: 0.5575:   0%|          | 4/938 [00:00<01:00, 15.56it/s][A
Loss value: 0.4486:   0%|          | 4/938 [00:00<01:00, 15.56it/s][A
Loss value: 0.4976:   0%|          | 4/938 [00:00<01:00, 15.56it/s][A
Loss value: 0.4976:   1%|          | 6/938 [00:00<00:57, 16.35it/s][A
Loss value: 0.5092:   1%|          | 6/938 [00:00<00:57, 16.35it/s][A
Loss value: 0.5917:   1%|          | 6/938 [00:00<00:57, 16.35it/s][A
Loss value: 0.5917:   1%|          | 8/938 [00:00<00:58, 16.01it/s][A
Loss value: 0.5927:   1%|          | 8/938 [00:00<00:58, 16.01it/s][A
Loss value: 0.3815:   1%|          | 8/938 [00:00

Average accuracy for epoch 3: 0.8930 



Loss value: 0.2335:  24%|██▍       | 229/938 [00:14<00:51, 13.88it/s]
Loss value: 0.1861: 100%|██████████| 938/938 [01:00<00:00, 15.59it/s]
100%|██████████| 157/157 [01:03<00:00,  2.47it/s]
 99%|█████████▊| 155/157 [00:02<00:00, 55.10it/s]

Average accuracy for epoch 4: 0.9016 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1747:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2757:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2757:   0%|          | 2/938 [00:00<00:59, 15.84it/s][A
Loss value: 0.3727:   0%|          | 2/938 [00:00<00:59, 15.84it/s][A
Loss value: 0.3508:   0%|          | 2/938 [00:00<00:59, 15.84it/s][A
Loss value: 0.3508:   0%|          | 4/938 [00:00<00:58, 16.06it/s][A
Loss value: 0.1500:   0%|          | 4/938 [00:00<00:58, 16.06it/s][A
Loss value: 0.3492:   0%|          | 4/938 [00:00<00:58, 16.06it/s][A
Loss value: 0.3492:   1%|          | 6/938 [00:00<01:02, 14.84it/s][A
Loss value: 0.5648:   1%|          | 6/938 [00:00<01:02, 14.84it/s][A
Loss value: 0.3687:   1%|          | 6/938 [00:00<01:02, 14.84it/s][A
Loss value: 0.3687:   1%|          | 8/938 [00:00<01:12, 12.89it/s][A
Loss value: 0.2962:   1%|          | 8/938 [00:00<01:12, 12.89it/s][A
Loss value: 0.4891:   1%|          | 8/938 [00:00

Average accuracy for epoch 5: 0.9130 



Loss value: 0.2243:  29%|██▉       | 272/938 [00:17<00:40, 16.26it/s]
Loss value: 0.0978: 100%|██████████| 938/938 [01:00<00:00, 15.44it/s]
100%|██████████| 157/157 [01:04<00:00,  2.42it/s]
 98%|█████████▊| 154/157 [00:02<00:00, 54.19it/s]

Average accuracy for epoch 6: 0.9178 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2257:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3065:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3065:   0%|          | 2/938 [00:00<01:02, 15.02it/s][A
Loss value: 0.1065:   0%|          | 2/938 [00:00<01:02, 15.02it/s][A
Loss value: 0.2282:   0%|          | 2/938 [00:00<01:02, 15.02it/s][A
Loss value: 0.2282:   0%|          | 4/938 [00:00<01:02, 14.86it/s][A
Loss value: 0.3354:   0%|          | 4/938 [00:00<01:02, 14.86it/s][A
Loss value: 0.4008:   0%|          | 4/938 [00:00<01:02, 14.86it/s][A
Loss value: 0.4008:   1%|          | 6/938 [00:00<01:03, 14.75it/s][A
Loss value: 0.3924:   1%|          | 6/938 [00:00<01:03, 14.75it/s][A
Loss value: 0.2493:   1%|          | 6/938 [00:00<01:03, 14.75it/s][A
Loss value: 0.2493:   1%|          | 8/938 [00:00<01:02, 14.95it/s][A
Loss value: 0.2132:   1%|          | 8/938 [00:00<01:02, 14.95it/s][A
Loss value: 0.3308:   1%|          | 8/938 [00:00

Average accuracy for epoch 7: 0.9243 



Loss value: 0.1339:  29%|██▉       | 272/938 [00:17<00:42, 15.76it/s]
Loss value: 0.0323: 100%|██████████| 938/938 [01:01<00:00, 15.29it/s]
100%|██████████| 157/157 [01:04<00:00,  2.42it/s]
 97%|█████████▋| 153/157 [00:03<00:00, 35.27it/s]

Average accuracy for epoch 8: 0.9245 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2277:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.4488:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.4488:   0%|          | 2/938 [00:00<01:04, 14.46it/s][A
Loss value: 0.3666:   0%|          | 2/938 [00:00<01:04, 14.46it/s][A
Loss value: 0.2842:   0%|          | 2/938 [00:00<01:04, 14.46it/s][A
Loss value: 0.2842:   0%|          | 4/938 [00:00<01:06, 14.13it/s][A
Loss value: 0.1565:   0%|          | 4/938 [00:00<01:06, 14.13it/s][A
Loss value: 0.3128:   0%|          | 4/938 [00:00<01:06, 14.13it/s][A
Loss value: 0.3128:   1%|          | 6/938 [00:00<01:02, 14.88it/s][A
Loss value: 0.3225:   1%|          | 6/938 [00:00<01:02, 14.88it/s][A
Loss value: 0.1709:   1%|          | 6/938 [00:00<01:02, 14.88it/s][A
Loss value: 0.1709:   1%|          | 8/938 [00:00<01:06, 13.96it/s][A
Loss value: 0.1658:   1%|          | 8/938 [00:00<01:06, 13.96it/s][A
Loss value: 0.2903:   1%|          | 8/938 [00:00

Average accuracy for epoch 9: 0.9309 



We can see that we got a not so bad performance for 80% less parameters in the model.

## <font color='red'>3. Slicing Technique 2</font>



In the second slicing technique, we have separate heads for each layer of the main model.

In each "HyperHead", the weights of the correponding layers are sliced and then generated conditioned on the chunk ID:

<img align='center' style='max-width: 700px' src='https://github.com/HSG-AIML-Teaching/DL2025-Lab/blob/main/lab_3/images/slice_2.gif?raw=1'>

We want to use the same model as in the first slicing technique:

In [None]:
class MLP(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out):
        super().__init__()
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden, n_hidden)
        self.linear_3 = nn.Linear(n_hidden, n_hidden)
        self.classifier = nn.Linear(n_hidden, n_out)

        self.activ = nn.ReLU()

    def forward_with_params(self, x, params):
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["linear_1.weight"], params["linear_1.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_2.weight"], params["linear_2.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_3.weight"], params["linear_3.bias"])
        x = F.relu(x)
        x = F.linear(x, params["classifier.weight"], params["classifier.bias"])
        return x


The first module of the Hypernetwork that need to implement is called the "HyperHead". Each HyperHead has its own embedding layer, and the number of chuncks in the head can be determined by the original size of the weights vector in the correponding layer:

In [None]:
class HyperHead(nn.Module):
    def __init__(self, n_hidden, n_out, chunk_size, dim_emb):
        super().__init__()
        n_chunks = n_out // chunk_size
        # Embedding layer for each head
        self.emb = nn.Embedding(n_chunks, dim_emb)
        self.n_chunks = n_chunks

        # Initialize emb with uniform distribution
        nn.init.uniform_(self.emb.weight, -1.0, 1.0)

        # Output head linear mapping
        self.linear_1 = nn.Linear(n_hidden + dim_emb, chunk_size)

    def forward(self, x):
        # Retrieve embedding for all layers
        emb_inp = torch.arange(self.n_chunks).to(x.device)
        emb = self.emb(emb_inp)

        # Unsqueeze x in the second dimension and replicate it for n times
        x = x.unsqueeze(1).repeat(1, self.n_chunks, 1)

        # Unsqueeze emb in the first dimension and replicate it for n times
        emb = emb.unsqueeze(0).repeat(x.shape[0], 1, 1)

        # Concatenate x and emb along the last dimension
        x = torch.cat([x, emb], dim=-1)

        x = F.relu(x)
        x = self.linear_1(x)
        x = x.view(x.shape[0], -1)

        return x

As mentioned before, the function that finds the biggest divisor of a number (smaller than its square root) can be modified according to the slicing method. For example, here want the chunk size of the layers whose number of weights is less than 50 to be equal to number of their weights. Therefore, if `n<50`, `chunk_size=n`:

In [None]:
def biggest_divisor(n):
    if n < 50:
        return n
    # Find the biggest divisor of n that is smaller than the square root of n
    for i in range(int((n)**0.5), 0, -1):
        if n % i == 0:
            return i

The implementation of the Hypernetwork's class also need to change accordingly. The Hypernetwork needs to keep a list of HyperHeads for each layer in the model.

<font color='darkgreen'>[Q] Why not use a list to store the HyperHead instead of using `nn.ModuleList`?</font>

In [None]:
class Hypernetwork(nn.Module):
    def __init__(self, n_inp, n_hidden, param_shapes, dim_emb):
        super().__init__()
        self.linear_1 = nn.Linear(n_inp, n_hidden)

        self.heads = nn.ModuleList(
            [
                HyperHead(n_hidden,
                          pshape.numel(),
                          biggest_divisor(pshape.numel()),
                          dim_emb)
                for pshape in param_shapes]
        )
        self.actv = nn.ReLU()

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.actv(self.linear_1(x))

        # Loop over all heads and generate weights
        head_outs = [self.heads[i](x) for i in range(len(self.heads))]
        head_outs = torch.concat(head_outs, dim=1)

        return head_outs

Now, let's initialize the main model and its corresponding Hypernetwork:

In [None]:
# Initiliaze random seeds n PyTorch and Numpy for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

np.random.seed(0)

In [None]:
# Initialize the main model
main_model = MLP(28*28, 50, 10)

# Define parameters shapes and number of parameters
param_shapes = {n: p.shape for (n, p) in main_model.named_parameters()}
num_params = sum([p.numel() for p in main_model.parameters()])
hn_inp_shape = 28 * 28

# Initialize the Hypernetwork
hypernetwork = Hypernetwork(hn_inp_shape, 4, list(param_shapes.values()), 4)

We are also interested in knowing how much compression does the current method make in the end with the set values:

In [None]:
n_params_main_model = sum([p.numel() for p in main_model.parameters()])
n_params_hypernetwork = sum([p.numel() for p in hypernetwork.parameters()])

print("Number of parameters in main model: ", n_params_main_model)
print("Number of parameters in hypernetwork: ", n_params_hypernetwork)

# Ratio of parameters in hypernetwork to main model
print("Ratio: ", n_params_hypernetwork / n_params_main_model)

Number of parameters in main model:  44860
Number of parameters in hypernetwork:  7633
Ratio:  0.17015158270173875


That's good, very similar to the previous method!

It's time to train the Hypernetwork on the MNIST dataset:

In [None]:
# Define the optimizer and the loss function
optimizer = torch.optim.Adam(hypernetwork.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Set the device
device = torch.device("cpu")

# Move the model and the hypernetwork to the device
main_model.to(device)
hypernetwork.to(device)

# Standard PyTorch training loop
n_epochs = 10
for epoch in range(n_epochs):
    pbar = tqdm(train_loader)
    for batch in pbar:
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Zero-grad optimizer for the hypernetwork
        optimizer.zero_grad()

        # Generate weights with the hypernetwork
        hn_out = hypernetwork(x)

        # Reshape generated weights
        reshaped_params = reshape_generated_parameters(hn_out, param_shapes)

        # Make prediction
        pred = main_model.forward_with_params(x, reshaped_params)

        # Compute loss and backpropagate
        loss = criterion(pred, y)
        loss.backward()

        # Optimizer step update
        optimizer.step()

        # Set progress bar description
        pbar.set_description(f"Loss value: {loss.item():.4f}")

    with torch.no_grad():
        # Evaluate model after each epoch
        batch_accuracies = []
        pbar_test = tqdm(test_loader)
        for batch in test_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Generate weights with the hypernetwork
            hn_out = hypernetwork(x)
            reshaped_params = reshape_generated_parameters(
                hn_out, param_shapes)

            # Make prediction
            pred = main_model.forward_with_params(x, reshaped_params)
            n_corrects = sum(pred.argmax(dim=1) == y).item()
            acc_batch = n_corrects / len(x)
            batch_accuracies.append(acc_batch)
            pbar_test.update()

    print(f"Average accuracy for epoch {epoch}: {sum(batch_accuracies)/len(batch_accuracies):.4f} \n")


Loss value: 2.3127:  27%|██▋       | 254/938 [00:15<00:43, 15.77it/s]
Loss value: 2.2843: 100%|██████████| 938/938 [00:58<00:00, 16.13it/s]
100%|██████████| 157/157 [01:02<00:00,  2.51it/s]
 99%|█████████▊| 155/157 [00:02<00:00, 56.77it/s]

Average accuracy for epoch 0: 0.1059 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 2.3098:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 2.3520:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 2.3520:   0%|          | 2/938 [00:00<00:54, 17.24it/s][A
Loss value: 2.2982:   0%|          | 2/938 [00:00<00:54, 17.24it/s][A
Loss value: 2.2920:   0%|          | 2/938 [00:00<00:54, 17.24it/s][A
Loss value: 2.2920:   0%|          | 4/938 [00:00<00:54, 17.22it/s][A
Loss value: 2.2964:   0%|          | 4/938 [00:00<00:54, 17.22it/s][A
Loss value: 2.3057:   0%|          | 4/938 [00:00<00:54, 17.22it/s][A
Loss value: 2.3057:   1%|          | 6/938 [00:00<00:55, 16.72it/s][A
Loss value: 2.3230:   1%|          | 6/938 [00:00<00:55, 16.72it/s][A
Loss value: 2.2988:   1%|          | 6/938 [00:00<00:55, 16.72it/s][A
Loss value: 2.2988:   1%|          | 8/938 [00:00<00:53, 17.54it/s][A
Loss value: 2.3214:   1%|          | 8/938 [00:00<00:53, 17.54it/s][A
Loss value: 2.3059:   1%|          | 8/938 [00:00

Average accuracy for epoch 1: 0.7174 



Loss value: 0.6226:  31%|███▏      | 295/938 [00:17<00:38, 16.60it/s]
Loss value: 0.7985: 100%|██████████| 938/938 [00:56<00:00, 16.72it/s]
100%|██████████| 157/157 [00:59<00:00,  2.66it/s]
 97%|█████████▋| 153/157 [00:02<00:00, 59.35it/s]

Average accuracy for epoch 2: 0.8246 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.6526:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.5442:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.5442:   0%|          | 2/938 [00:00<01:01, 15.25it/s][A
Loss value: 0.4364:   0%|          | 2/938 [00:00<01:01, 15.25it/s][A
Loss value: 0.9753:   0%|          | 2/938 [00:00<01:01, 15.25it/s][A
Loss value: 0.9753:   0%|          | 4/938 [00:00<01:04, 14.53it/s][A
Loss value: 0.5729:   0%|          | 4/938 [00:00<01:04, 14.53it/s][A
Loss value: 0.4734:   0%|          | 4/938 [00:00<01:04, 14.53it/s][A
Loss value: 0.4734:   1%|          | 6/938 [00:00<01:00, 15.52it/s][A
Loss value: 0.6495:   1%|          | 6/938 [00:00<01:00, 15.52it/s][A
Loss value: 0.5939:   1%|          | 6/938 [00:00<01:00, 15.52it/s][A
Loss value: 0.5939:   1%|          | 8/938 [00:00<01:00, 15.43it/s][A
Loss value: 0.8219:   1%|          | 8/938 [00:00<01:00, 15.43it/s][A
Loss value: 0.4165:   1%|          | 8/938 [00:00

Average accuracy for epoch 3: 0.8336 



Loss value: 0.2597:  33%|███▎      | 314/938 [00:19<00:38, 16.15it/s]
Loss value: 0.4651: 100%|██████████| 938/938 [00:58<00:00, 16.01it/s]
100%|██████████| 157/157 [01:01<00:00,  2.54it/s]
 99%|█████████▊| 155/157 [00:02<00:00, 54.81it/s]

Average accuracy for epoch 4: 0.8506 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.4649:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.5714:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.5714:   0%|          | 2/938 [00:00<00:59, 15.84it/s][A
Loss value: 0.6603:   0%|          | 2/938 [00:00<00:59, 15.84it/s][A
Loss value: 0.6990:   0%|          | 2/938 [00:00<00:59, 15.84it/s][A
Loss value: 0.6990:   0%|          | 4/938 [00:00<01:02, 14.86it/s][A
Loss value: 0.5751:   0%|          | 4/938 [00:00<01:02, 14.86it/s][A
Loss value: 0.5343:   0%|          | 4/938 [00:00<01:02, 14.86it/s][A
Loss value: 0.5343:   1%|          | 6/938 [00:00<01:10, 13.13it/s][A
Loss value: 0.4489:   1%|          | 6/938 [00:00<01:10, 13.13it/s][A
Loss value: 0.6997:   1%|          | 6/938 [00:00<01:10, 13.13it/s][A
Loss value: 0.6997:   1%|          | 8/938 [00:00<01:16, 12.21it/s][A
Loss value: 0.5452:   1%|          | 8/938 [00:00<01:16, 12.21it/s][A
Loss value: 0.3819:   1%|          | 8/938 [00:00

Average accuracy for epoch 5: 0.8727 



Loss value: 0.6171:  24%|██▍       | 224/938 [00:14<00:45, 15.58it/s]
Loss value: 0.8639: 100%|██████████| 938/938 [01:00<00:00, 15.57it/s]
100%|██████████| 157/157 [01:04<00:00,  2.43it/s]
 97%|█████████▋| 152/157 [00:03<00:00, 46.92it/s]

Average accuracy for epoch 6: 0.8809 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2694:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.4947:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.4947:   0%|          | 2/938 [00:00<01:18, 11.86it/s][A
Loss value: 0.4544:   0%|          | 2/938 [00:00<01:18, 11.86it/s][A
Loss value: 0.4781:   0%|          | 2/938 [00:00<01:18, 11.86it/s][A
Loss value: 0.4781:   0%|          | 4/938 [00:00<01:07, 13.91it/s][A
Loss value: 0.3746:   0%|          | 4/938 [00:00<01:07, 13.91it/s][A
Loss value: 0.4157:   0%|          | 4/938 [00:00<01:07, 13.91it/s][A
Loss value: 0.4157:   1%|          | 6/938 [00:00<01:06, 14.00it/s][A
Loss value: 0.4695:   1%|          | 6/938 [00:00<01:06, 14.00it/s][A
Loss value: 0.6361:   1%|          | 6/938 [00:00<01:06, 14.00it/s][A
Loss value: 0.6361:   1%|          | 8/938 [00:00<01:05, 14.29it/s][A
Loss value: 0.4359:   1%|          | 8/938 [00:00<01:05, 14.29it/s][A
Loss value: 0.5968:   1%|          | 8/938 [00:00

Average accuracy for epoch 7: 0.8853 



Loss value: 0.3823:  30%|██▉       | 278/938 [00:18<00:58, 11.38it/s]
Loss value: 0.1335: 100%|██████████| 938/938 [01:02<00:00, 14.99it/s]
100%|██████████| 157/157 [01:06<00:00,  2.37it/s]
 97%|█████████▋| 153/157 [00:03<00:00, 49.43it/s]

Average accuracy for epoch 8: 0.8954 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2891:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3468:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3468:   0%|          | 2/938 [00:00<01:08, 13.61it/s][A
Loss value: 0.3055:   0%|          | 2/938 [00:00<01:08, 13.61it/s][A
Loss value: 0.3823:   0%|          | 2/938 [00:00<01:08, 13.61it/s][A
Loss value: 0.3823:   0%|          | 4/938 [00:00<01:03, 14.71it/s][A
Loss value: 0.2114:   0%|          | 4/938 [00:00<01:03, 14.71it/s][A
Loss value: 0.4306:   0%|          | 4/938 [00:00<01:03, 14.71it/s][A
Loss value: 0.4306:   1%|          | 6/938 [00:00<01:07, 13.86it/s][A
Loss value: 0.2186:   1%|          | 6/938 [00:00<01:07, 13.86it/s][A
Loss value: 0.2439:   1%|          | 6/938 [00:00<01:07, 13.86it/s][A
Loss value: 0.2439:   1%|          | 8/938 [00:00<01:13, 12.70it/s][A
Loss value: 0.3526:   1%|          | 8/938 [00:00<01:13, 12.70it/s][A
Loss value: 0.3566:   1%|          | 8/938 [00:00

Average accuracy for epoch 9: 0.9000 



As we can see, the speed of convergence can be different in the two methods!

The choice of slicing technique can depend on the architecure of the main model and complexity of the problem. However, the principles remain the same!

Finally, we need to compare the performance of the Hypernetwork with the same main model trained in a statefull way:

In [None]:
# Initiliaze random seeds n PyTorch and Numpy for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

np.random.seed(0)

In [None]:
# Define model

class MLP(nn.Module):
    def __init__(self, n_inp, n_hidden, n_out):
        super().__init__()
        self.linear_1 = nn.Linear(n_inp, n_hidden)
        self.linear_2 = nn.Linear(n_hidden, n_hidden)
        self.linear_3 = nn.Linear(n_hidden, n_hidden)
        self.classifier = nn.Linear(n_hidden, n_out)

        self.activ = nn.ReLU()

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.activ(self.linear_1(x))
        x = self.activ(self.linear_2(x))
        x = self.classifier(x)

        return x

    def forward_with_params(self, x, params):
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["linear_1.weight"], params["linear_1.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_2.weight"], params["linear_2.bias"])
        x = F.relu(x)
        x = F.linear(x, params["linear_3.weight"], params["linear_3.bias"])
        x = F.relu(x)
        x = F.linear(x, params["classifier.weight"], params["classifier.bias"])
        return x

model = MLP(28*28, 50, 10)

# Define the optimizer and the loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Set the device
device = torch.device("cpu")

# Move the model to the device
model.to(device)

# Standard PyTorch training loop
n_epochs = 10
for epoch in range(n_epochs):
    pbar = tqdm(train_loader)
    for batch in pbar:
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Zero-grad optimizer for the hypernetwork
        optimizer.zero_grad()

        # Generate weights with the hypernetwork
        pred = model(x)

        # Compute loss and backpropagate
        loss = criterion(pred, y)
        loss.backward()

        # Optimizer step update
        optimizer.step()

        # Set progress bar description
        pbar.set_description(f"Loss value: {loss.item():.4f}")

    with torch.no_grad():
        # Evaluate model after each epoch
        batch_accuracies = []
        pbar_test = tqdm(test_loader)
        for batch in test_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            # Make prediction
            pred = model(x)
            n_corrects = sum(pred.argmax(dim=1) == y).item()
            acc_batch = n_corrects / len(x)
            batch_accuracies.append(acc_batch)
            pbar_test.update()

    print(f"Average accuracy for epoch {epoch}: {sum(batch_accuracies)/len(batch_accuracies):.4f} \n")


Loss value: 0.0334: 100%|██████████| 938/938 [00:15<00:00, 62.05it/s]
100%|██████████| 157/157 [00:19<00:00,  8.05it/s]
 97%|█████████▋| 153/157 [00:01<00:00, 83.22it/s]

Average accuracy for epoch 0: 0.9360 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3282:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2826:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2428:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1368:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2269:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3000:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.3000:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.0844:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.1777:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.2656:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.0970:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.2674:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.2832:   1%|          | 6/938 [00:00<00:16, 57.61it/s][A
Loss value: 0.2832:   1%|▏         | 12/938 [00:00<00:16, 57.76it/s][A
Loss valu

Average accuracy for epoch 1: 0.9525 



Loss value: 0.2412: 100%|██████████| 938/938 [00:15<00:00, 60.63it/s]
100%|██████████| 157/157 [00:18<00:00,  8.56it/s]
 95%|█████████▍| 149/157 [00:01<00:00, 82.76it/s]

Average accuracy for epoch 2: 0.9531 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0578:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1951:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1784:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1115:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0538:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0284:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0284:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.2689:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.2241:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.1672:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.0926:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.0359:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.0925:   1%|          | 6/938 [00:00<00:17, 52.49it/s][A
Loss value: 0.0925:   1%|▏         | 12/938 [00:00<00:17, 52.12it/s][A
Loss valu

Average accuracy for epoch 3: 0.9636 



Loss value: 0.0444:  94%|█████████▍| 885/938 [00:15<00:00, 59.15it/s]
Loss value: 0.0765: 100%|██████████| 938/938 [00:16<00:00, 56.69it/s]
100%|██████████| 157/157 [00:19<00:00,  8.03it/s]
 95%|█████████▍| 149/157 [00:01<00:00, 82.47it/s]

Average accuracy for epoch 4: 0.9677 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0539:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0462:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0294:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0540:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1660:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1660:   1%|          | 5/938 [00:00<00:19, 49.08it/s][A
Loss value: 0.0495:   1%|          | 5/938 [00:00<00:19, 49.08it/s][A
Loss value: 0.1000:   1%|          | 5/938 [00:00<00:19, 49.08it/s][A
Loss value: 0.1138:   1%|          | 5/938 [00:00<00:19, 49.08it/s][A
Loss value: 0.1338:   1%|          | 5/938 [00:00<00:19, 49.08it/s][A
Loss value: 0.0527:   1%|          | 5/938 [00:00<00:19, 49.08it/s][A
Loss value: 0.0527:   1%|          | 10/938 [00:00<00:19, 47.80it/s][A
Loss value: 0.0623:   1%|          | 10/938 [00:00<00:19, 47.80it/s][A
Loss value: 0.1963:   1%|          | 10/938 [00:00<00:19, 47.80it/s][A

Average accuracy for epoch 5: 0.9669 



Loss value: 0.0311:  69%|██████▊   | 644/938 [00:12<00:06, 42.78it/s]
Loss value: 0.0370: 100%|██████████| 938/938 [00:18<00:00, 50.84it/s]
100%|██████████| 157/157 [00:21<00:00,  7.29it/s]
 96%|█████████▌| 151/157 [00:01<00:00, 83.15it/s]

Average accuracy for epoch 6: 0.9672 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0366:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0459:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.2075:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0773:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0909:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0909:   1%|          | 5/938 [00:00<00:22, 41.02it/s][A
Loss value: 0.1382:   1%|          | 5/938 [00:00<00:22, 41.02it/s][A
Loss value: 0.0406:   1%|          | 5/938 [00:00<00:22, 41.02it/s][A
Loss value: 0.2749:   1%|          | 5/938 [00:00<00:22, 41.02it/s][A
Loss value: 0.0589:   1%|          | 5/938 [00:00<00:22, 41.02it/s][A
Loss value: 0.0181:   1%|          | 5/938 [00:00<00:22, 41.02it/s][A
Loss value: 0.0181:   1%|          | 10/938 [00:00<00:22, 41.12it/s][A
Loss value: 0.0426:   1%|          | 10/938 [00:00<00:22, 41.12it/s][A
Loss value: 0.0303:   1%|          | 10/938 [00:00<00:22, 41.12it/s][A

Average accuracy for epoch 7: 0.9682 



Loss value: 0.0036: 100%|██████████| 938/938 [00:18<00:00, 50.59it/s]
100%|██████████| 157/157 [00:21<00:00,  7.37it/s]
 97%|█████████▋| 153/157 [00:02<00:00, 60.79it/s]

Average accuracy for epoch 8: 0.9699 




  0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0943:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0171:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0538:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.0233:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1315:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1010:   0%|          | 0/938 [00:00<?, ?it/s][A
Loss value: 0.1010:   1%|          | 6/938 [00:00<00:18, 49.62it/s][A
Loss value: 0.0319:   1%|          | 6/938 [00:00<00:18, 49.62it/s][A
Loss value: 0.1371:   1%|          | 6/938 [00:00<00:18, 49.62it/s][A
Loss value: 0.0562:   1%|          | 6/938 [00:00<00:18, 49.62it/s][A
Loss value: 0.0199:   1%|          | 6/938 [00:00<00:18, 49.62it/s][A
Loss value: 0.1239:   1%|          | 6/938 [00:00<00:18, 49.62it/s][A
Loss value: 0.1239:   1%|          | 11/938 [00:00<00:20, 44.80it/s][A
Loss value: 0.0606:   1%|          | 11/938 [00:00<00:20, 44.80it/s][A
Loss val

Average accuracy for epoch 9: 0.9693 



Conclusion: To be discussed during the session

<font color='darkgreen'> [Q] Why bother using Hypernetworks with all additional complexities?</font>

| Technique                          | Ratio | Accuracy |
|-----------------------------------|-------|----------|
| Slicing technique 1                | 0.19  | 0.93     |
| Slicing technique 2                | 0.17  | 0.90     |
| Standard supervised baseline       | 1.00  | 0.97     |
