In [4]:
import torch
import torch.nn as nn
from biopytorch import BioConv2d, BioLinear
from torchsummaryX import summary
import pandas as pd
import numpy as np

In [3]:
from data import CIFAR10DataModule

# Batch size handed to the CIFAR-10 data module.
BATCH_SIZE = 64

# Build the data module and run its setup step; `cifar10` is the handle the
# evaluation code below uses to obtain the validation/test dataloaders.
cifar10 = CIFAR10DataModule(batch_size=BATCH_SIZE)
cifar10.setup()

Files already downloaded and verified
Files already downloaded and verified


In [5]:
from tqdm.notebook import tqdm

def test_model(model, dataloader) -> float:
    """
    Evaluate the accuracy of `model` on the dataset given by `dataloader`
    """
    
    model.eval()
    
    test_acc = 0.
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            x, y = x.to("cuda"), y.to("cuda")
            
            out = model(x)
            acc = (y == out.argmax(1)).sum().item() / out.size(0)
            
            test_acc += acc
            
    return test_acc / len(dataloader)

In [14]:
def retrieve_stats(model) -> pd.Series:
    """
    Retrieve the main hyperparameters of `model`, and measure its val/test accuracy on CIFAR-10.

    Parameters
    ----------
    model : torch.nn.Module
        Network built from `BioConv2d` / `BioLinear` layers.

    Returns
    -------
    pd.Series
        Indexed by 'val_acc', 'test_acc' (both in percent), 'p', 'k', 'Delta'
        (hyperparameters read from the first `BioConv2d` layer), 'dropout'
        (probability of the first `nn.Dropout`) and 'n_params' (total parameter
        count). Hyperparameters whose layer type is absent are reported as NaN.
    """

    bioconv_layers = [layer for layer in model.modules() if isinstance(layer, BioConv2d)]
    biolinear_layers = [layer for layer in model.modules() if isinstance(layer, BioLinear)]

    # Hyperparameters are taken from the first BioConv2d layer only
    if bioconv_layers:
        first_conv = bioconv_layers[0]
        delta = first_conv.delta
        p = first_conv.lebesgue_p
        k = first_conv.ranking_param
    else:
        delta = p = k = np.nan  # No Hebbian conv layer in this model

    # Dropout probability of the first nn.Dropout, NaN if the model has none
    dropout = next((layer.p for layer in model.modules() if isinstance(layer, nn.Dropout)), np.nan)

    val_acc = test_model(model, cifar10.val_dataloader())
    test_acc = test_model(model, cifar10.test_dataloader())

    #Compute total number of learnable parameters (both by Krotov learning rule or SGD)
    conv_params = sum(int(np.prod(layer.weight.shape)) for layer in bioconv_layers)  #They do not support bias, so no need to add it
    lin_params = sum(int(np.prod(layer.weight.shape)) for layer in biolinear_layers)

    # Check the bias layer by layer, so one bias-free layer does not discard
    # the bias counts of all the others
    for layer in biolinear_layers:
        bias = getattr(layer, "bias", None)
        if bias is not None:
            lin_params += int(np.prod(bias.shape))

    # Build the dummy input on the model's own device (CUDA fallback for
    # parameter-free models, as in the original hard-coded behaviour)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    model_summary = summary(model, torch.rand((128, 3, 32, 32), device=device)) #This counts only the final SGD layer
    sgd_params = model_summary['Params'].dropna().sum()

    total_params = conv_params + lin_params + sgd_params

    return pd.Series(
        data=[val_acc * 100, test_acc * 100, p, k, delta, dropout, total_params],
        index=['val_acc', 'test_acc', 'p', 'k', 'Delta', 'dropout', 'n_params'],
    )

In [15]:
#Load the saved models (files layers1.pt ... layers5.pt), then evaluate each one.
#NOTE(review): torch.load unpickles the files; only load checkpoints from a trusted source.
checkpoint_paths = [f"SavedModels/layers{i+1}.pt" for i in range(5)]
checkpoints      = list(map(torch.load, checkpoint_paths))
stats            = list(map(retrieve_stats, checkpoints))

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

                             Kernel Shape       Output Shape   Params  \
Layer                                                                   
0_0.0.BatchNorm2d_batch_norm            -   [128, 3, 32, 32]        -   
1_0.ReLU_1                              -  [128, 96, 28, 28]        -   
2_0.MaxPool2d_2                         -  [128, 96, 14, 14]        -   
3_0.BatchNorm2d_3                       -  [128, 96, 14, 14]        -   
4_1                                     -       [128, 18816]        -   
5_2                                     -       [128, 18816]        -   
6_3                           [18816, 10]          [128, 10]  188.17k   

                             Mult-Adds  
Layer                                   
0_0.0.BatchNorm2d_batch_norm         -  
1_0.ReLU_1                           -  
2_0.MaxPool2d_2                      -  
3_0.BatchNorm2d_3                    -  
4_1                                  -  
5_2                                  -  
6_3           

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

                               Kernel Shape        Output Shape   Params  \
Layer                                                                      
0_0.0.0.BatchNorm2d_batch_norm            -    [128, 3, 32, 32]        -   
1_0.0.ReLU_1                              -   [128, 96, 28, 28]        -   
2_0.0.MaxPool2d_2                         -   [128, 96, 14, 14]        -   
3_0.1.0.BatchNorm2d_batch_norm            -   [128, 96, 14, 14]        -   
4_0.1.ReLU_1                              -  [128, 128, 12, 12]        -   
5_1                                       -  [128, 128, 12, 12]        -   
6_2                                       -        [128, 18432]        -   
7_3                                       -        [128, 18432]        -   
8_4                             [18432, 10]           [128, 10]  184.33k   

                               Mult-Adds  
Layer                                     
0_0.0.0.BatchNorm2d_batch_norm         -  
1_0.0.ReLU_1                      

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

                               Kernel Shape        Output Shape  Params  \
Layer                                                                     
0_0.0.0.BatchNorm2d_batch_norm            -    [128, 3, 32, 32]       -   
1_0.0.ReLU_1                              -   [128, 96, 28, 28]       -   
2_0.0.MaxPool2d_2                         -   [128, 96, 14, 14]       -   
3_0.1.0.BatchNorm2d_batch_norm            -   [128, 96, 14, 14]       -   
4_0.1.ReLU_1                              -  [128, 128, 12, 12]       -   
5_0.2.0.BatchNorm2d_batch_norm            -  [128, 128, 12, 12]       -   
6_0.2.ReLU_1                              -  [128, 192, 10, 10]       -   
7_0.2.MaxPool2d_2                         -    [128, 192, 5, 5]       -   
8_1                                       -    [128, 192, 5, 5]       -   
9_2                                       -         [128, 4800]       -   
10_3                                      -         [128, 4800]       -   
11_4                     

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

                               Kernel Shape        Output Shape  Params  \
Layer                                                                     
0_0.0.0.BatchNorm2d_batch_norm            -    [128, 3, 32, 32]       -   
1_0.0.ReLU_1                              -   [128, 96, 28, 28]       -   
2_0.0.MaxPool2d_2                         -   [128, 96, 14, 14]       -   
3_0.1.0.BatchNorm2d_batch_norm            -   [128, 96, 14, 14]       -   
4_0.1.ReLU_1                              -  [128, 128, 12, 12]       -   
5_0.2.0.BatchNorm2d_batch_norm            -  [128, 128, 12, 12]       -   
6_0.2.ReLU_1                              -  [128, 192, 10, 10]       -   
7_0.2.MaxPool2d_2                         -    [128, 192, 5, 5]       -   
8_0.3.0.BatchNorm2d_batch_norm            -    [128, 192, 5, 5]       -   
9_0.3.ReLU_1                              -    [128, 256, 3, 3]       -   
10_1                                      -    [128, 256, 3, 3]       -   
11_2                     

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

                                Kernel Shape        Output Shape Params  \
Layer                                                                     
0_0.0.0.BatchNorm2d_batch_norm             -    [128, 3, 32, 32]      -   
1_0.0.ReLU_1                               -   [128, 96, 28, 28]      -   
2_0.0.MaxPool2d_2                          -   [128, 96, 14, 14]      -   
3_0.1.0.BatchNorm2d_batch_norm             -   [128, 96, 14, 14]      -   
4_0.1.ReLU_1                               -  [128, 128, 12, 12]      -   
5_0.2.0.BatchNorm2d_batch_norm             -  [128, 128, 12, 12]      -   
6_0.2.ReLU_1                               -  [128, 192, 10, 10]      -   
7_0.2.MaxPool2d_2                          -    [128, 192, 5, 5]      -   
8_0.3.0.BatchNorm2d_batch_norm             -    [128, 192, 5, 5]      -   
9_0.3.ReLU_1                               -    [128, 256, 3, 3]      -   
10_0.4.0.BatchNorm1d_batch_norm            -         [128, 2304]      -   
11_0.4.ReLU_1            

In [20]:
#Hyperparameters of the first BioLinear layer found in the last loaded checkpoint
last_model = checkpoints[-1]
biolinear_layer = list(filter(lambda module: isinstance(module, BioLinear), last_model.modules()))[0]
print(f"p: {biolinear_layer.lebesgue_p}, k: {biolinear_layer.ranking_param}, Delta: {biolinear_layer.delta}")


p: 8, k: 2, Delta: 0.335


In [16]:
#Gather all the stats: one column per checkpoint, labelled 1..N in loading order
#(N follows len(stats) instead of repeating the hard-coded checkpoint count)
df = pd.DataFrame({n: model_stats for n, model_stats in enumerate(stats, start=1)})
df

Unnamed: 0,1,2,3,4,5
val_acc,69.19785,67.127787,64.908439,59.832803,46.24801
test_acc,67.058121,65.226911,63.077229,58.857484,45.45183
p,2.0,8.0,8.0,8.0,8.0
k,9.0,3.0,5.0,7.0,2.0
Delta,0.08,0.34,0.25,0.235,0.335
dropout,0.2,0.25,0.05,0.1,0.1
n_params,195370.0,302122.0,386986.0,804394.0,1475554.0


## Comparison

In [17]:
#Benchmark of Hebbian Conv layers on CIFAR-10, taken from Amato et al., "Hebbian Learning Meets Deep Convolutional Neural Networks", 2019
hebbian_paper = pd.DataFrame({
    1: 63.92,
    2: 63.81,
    3: 58.28,
    4: 52.99,
    5: 41.78}, index=['test_ref'])

#DataFrame.append was removed in pandas 2.0: use pd.concat instead.
#Dropping any existing 'test_ref' row first keeps the cell idempotent, so
#re-running it does not stack duplicate reference rows onto `df`.
df = pd.concat([df.drop(index='test_ref', errors='ignore'), hebbian_paper])
df

Unnamed: 0,1,2,3,4,5
val_acc,69.19785,67.127787,64.908439,59.832803,46.24801
test_acc,67.058121,65.226911,63.077229,58.857484,45.45183
p,2.0,8.0,8.0,8.0,8.0
k,9.0,3.0,5.0,7.0,2.0
Delta,0.08,0.34,0.25,0.235,0.335
dropout,0.2,0.25,0.05,0.1,0.1
n_params,195370.0,302122.0,386986.0,804394.0,1475554.0
test_ref,63.92,63.81,58.28,52.99,41.78
