<a href="https://colab.research.google.com/github/ThomasMiconi/HebbianCNNPyTorch/blob/main/HebbGrad_CheckHebb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code verifies that the PyTorch-computed gradient updates are equal to hand-computed Hebbian updates, for the various rules being studied.

We run the code for one single batch (and for a single layer), perform the backward pass, then (in the next cell) compare the computed gradient with the appropriate Hebbian update.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import pdb
import matplotlib.pyplot as plt

import scipy
from scipy import ndimage
from scipy import linalg

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time

import numpy as np
from numpy import fft 

from scipy import io as spio



NL=3; STRIDES=(12,1,1); POOLSTRIDES = (2, 2, 2); POOLDIAMS = (2, 2, 2) ; SIZES = (5, 3, 3) ; N = [100, 196, 400, 100]; K = [10, 1, 1]

RULE = "OJA"

# Should we use pruning / masking of weights? Probability for each layer.
#PROBAMASK =  .75 # .97 
#PROBAMASKS = (0, 0.5, 0.5, 0.5)
#PROBAMASKS = (0, 0.9, 0.9, 0.9)
#PROBAMASKS = (0, 0.95, 0.95, 0.95)
#PROBAMASKS = (0, 0.98, 0.98, 0.98)
#PROBAMASKS = (0, 0.999, 0.999, 0.999)
PROBAMASKS = (0, 0., 0., 0.)
# PROBAMASKS = (0, 0.99, 0.99)

# Do we apply the ZCA to the whole image (IMAGE), individually to each patch through the unfolding/flattening convolution ("method 2" in the preprint) (FLAT), as a spatial filter (highly experimental, FILTER), or not at all? 
ZCA = 'NO' # FLAT, IMAGE, FILTER, None 

BINARIZEDPLAST = True   # Binarized WTA for plasticity?
LEARNONLYL1 = False     # Learning only in Layer 1 (for tests)?
USEFLAT = False         # Use the flattening/unfolding convolution ("method 2" in the preprint)?
NORMW = True            # Constrain weights to norm 1?
USETHRES = False         # Use adaptive thresholds?

# Use Coates' "triangle" method? If TRAINTEST: Don't use it during Hebbian learning, but do use it during data collection for training/testing the linear classifier.
USETRIANGLE = 'YES' # 'NO', 'YES', 'TRAINTEST'
L1PEN =  [0, 0, 0] # ; L1PEN = [-0.0 * xx for xx in L1PEN] ; print(L1PEN[2])

#TARGETRATE = [1.0/36, 2.0/400, 0.01, 0.01] 
#TARGETRATE = [0.05, 0.003, 0.01, 0.01] 
#TARGETRATE = [2/N[0], 2/N[1], 5/N[2]] # [3/xx for xx in N] 
TARGETRATE = [float(K[ii] / N[ii]) for ii in range(NL)]     # When using adaptive threshold, the target "firing" (k-WTA winning) rate must be K/N, otherwise trouble happens.

BATCHSIZE=27 # 00

RGB = True  # RGB or Grayscale?
NBINPUTCHANNELS = 3 if RGB else 1


CSIZE =    32       #  Size of input images
FSIZE = 32          # Size of whitening filter. Should be CSIZE for ZCA=IMAGE, SIZES[0] for ZCA=FLAT.
MIXDOG = 1.0        # For the Difference of Gaussians filter below. No longer used
NBLEARNINGEPOCHS = 20   # Number of learning epochs. Add 2 epochs with frozen weights for collecting training/testing data for the linear classifier.
LR = [ 0.001 / BATCHSIZE,  0.001 / BATCHSIZE,  10 * .001 / BATCHSIZE]   # Learning rates for each layer. We learn faster in the last layer, which has little data (only 2x2 outputs, only 1 winner for 400 filters each time)
LR = [xx * 10.0 for xx in LR]
MUTHRES =  3.0      #  Adaptation rate for adaptive thresholds


# This computation of receptive field sizes may not be correct anymore.
rfL1size = SIZES[0]  # Assumes no dilation in L1 
rfL2size = STRIDES[0] * POOLSTRIDES[0]*(SIZES[1]-1)+SIZES[0]
rfL3size = STRIDES[1] * POOLSTRIDES[1]*STRIDES[0] * POOLSTRIDES[0]* (SIZES[2]-1)+rfL2size
print("RF sizes: L1:", rfL1size, "L2:", rfL2size, "L3:", rfL3size )


tic = time.time()

device = 'cuda' if torch.cuda.is_available() else 'cpu'



# PyTorch data loading boilerplate. Note the multiple conditions. "UNLAB" = unlabelled (Hebbian learning)

if RGB:    
    transform_zca = transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
    #torchvision.transforms.RandomRotation(45),
            transforms.RandomCrop(FSIZE),  
        transforms.ToTensor(),
    ])
    transform_unlab = transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
    #torchvision.transforms.RandomRotation(45),
            transforms.RandomCrop(CSIZE),  # Larger image
        transforms.ToTensor(),
    ])
    transform_train = transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(CSIZE),  # Larger image
        transforms.ToTensor(),
    ])
else:
    transform_zca = transforms.Compose([
    torchvision.transforms.Grayscale(),
    torchvision.transforms.RandomHorizontalFlip(),
    #torchvision.transforms.RandomRotation(45),
    transforms.RandomCrop(FSIZE),  # Larger image
    transforms.ToTensor(),
    ])
    transform_unlab = transforms.Compose([
    torchvision.transforms.Grayscale(),
    torchvision.transforms.RandomHorizontalFlip(),
    #torchvision.transforms.RandomRotation(45),
    transforms.RandomCrop(CSIZE),  # Larger image
    transforms.ToTensor(),
    ])
    transform_train = transforms.Compose([
    torchvision.transforms.Grayscale(),
    torchvision.transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(CSIZE),  # Larger image
    transforms.ToTensor(),
    ])
transform_test = transform_train
# trainset = torchvision.datasets.CIFAR10(
#    root='./data', train=True, download=True, transform=transform_train)
# trainloader = torch.utils.data.DataLoader(
#    trainset, batch_size=BATCHSIZE, shuffle=False)#, num_workers=2) # Also check out pin_memory if using GPU


# STL10 data. Potentially useful for unsupervised learning !  Requires adjusting the size parameters.
# unlabloader = trainloader = testloader = None
# trainset = torchvision.datasets.STL10(
#    root='./data', split='train', download=True, transform=transform_train)
# testset = torchvision.datasets.STL10(
#     root='./data', split='test', download=True, transform=transform_test)
# unlabset = torchvision.datasets.STL10(
#     root='./data', split='unlabeled', download=True, transform=transform_unlab)
# zcaset = torchvision.datasets.STL10(
#     root='./data', split='unlabeled', download=True, transform=transform_zca)

trainset = torchvision.datasets.CIFAR10(
   root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
unlabset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_unlab)
zcaset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_zca)


zcaloader = torch.utils.data.DataLoader(
    zcaset, batch_size=9000, shuffle=True, num_workers=2)
trainloader = torch.utils.data.DataLoader(
   trainset, batch_size=BATCHSIZE, shuffle=True, num_workers=2) # Also check out pin_memory if using GPU
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCHSIZE, shuffle=False, num_workers=2)
unlabloader = torch.utils.data.DataLoader(
    unlabset, batch_size=BATCHSIZE, shuffle=True, num_workers=2)


# The Olshausen & Field images. Must be downloaded from Bruno Olshausen's website.
#OP = spio.loadmat('IMAGES.mat')['IMAGES'] # Olshausen pictures, 512 * 512 * 10
#OP = spio.loadmat('IMAGES_RAW.mat')['IMAGESr'].astype(np.float32) # Olshausen pictures, 512 * 512 * 10
#OP = OP - np.min(OP); OP = OP / np.max(OP)




# Initializations. Some of that stuff is not used anymore.

# w has shape OutChannels, InChannels, H, W
w=[]; masks = []
wflat=[]
wlat=[]; wlatdecay=[]
meanfiringrates=[];meanfiringrates_p=[]; meanrealy=[]; meanabsprelimy=[]
allthres=[]; allfirings = []; allfirings_p = []
meancorry = [] ; meancorrp = []
optimizers=[]
DeccFilts = []
thres=[]; C = []
for numl in range(NL):

    # For the flattening convolution, wflat are the fixed weights:
    if numl == 0:
        NIC = NBINPUTCHANNELS
    else:
        NIC = N[numl-1]
    idx = 0
    wf =  torch.zeros((NIC * SIZES[numl] * SIZES[numl], NIC, SIZES[numl], SIZES[numl]), requires_grad=False, device=device) 
    for nin in range(NIC):
        for xin in range(SIZES[numl]):
            for yin in range(SIZES[numl]):
                wf.data[idx, nin, xin, yin] = 1
                idx += 1
    wflat.append(wf)

    # The actual (learned) weights:
    if  USEFLAT:
        # when using the flattening / unfolding convolution :
        wi = torch.randn((N[numl], NIC * SIZES[numl] * SIZES[numl], 1, 1), requires_grad=True, device=device) 
    else:
        # When *not* using the flattening convolution
        wi = torch.randn((N[numl], NIC, SIZES[numl], SIZES[numl]), requires_grad=True, device=device) 

    wi.data = wi.data * .01
    
    if True: # PROBAMASK > 0:
        mi = (torch.rand_like(wi, requires_grad=False) > PROBAMASKS[numl]).float().to(device)
        if numl == 0:
            mi.fill_(1.0)
        masks.append(mi)
        wi.data = wi.data * mi.data
    w.append(wi)
    if NORMW:
        # w has shape OutChannels, InChannels, H, W
        w[numl].data =    w[numl].data  / (1e-10 + torch.sqrt(torch.sum(w[numl].data ** 2, dim=[1,2,3], keepdim=True)))
    wlat.append( torch.zeros((N[numl], N[numl]), requires_grad=False, device=device) ) 
    wlatdecay.append( torch.zeros((N[numl], N[numl]), requires_grad=False, device=device) ) 

    C.append(0)
    meancorry.append(torch.zeros((N[numl], N[numl]), requires_grad=False, device=device))
    meancorrp.append(torch.zeros((N[numl], N[numl]), requires_grad=False, device=device))
    allthres.append([])
    allfirings.append([])
    allfirings_p.append([])
    meanfiringrates.append(torch.zeros((1,N[numl],1,1), requires_grad=False).to(device))
    meanfiringrates_p.append(torch.zeros((1,N[numl],1,1), requires_grad=False).to(device))
    meanrealy.append(torch.zeros((1,N[numl],1,1), requires_grad=False).to(device))
    meanabsprelimy.append(torch.zeros((1,N[numl],1,1), requires_grad=False).to(device))
    thres.append(torch.zeros_like(meanfiringrates[numl], requires_grad=False).to(device))
    optimizers.append(optim.SGD((w[numl],), lr=LR[numl], momentum=0.0))
    #optimizers.append(optim.Adam((w[numl],), lr=3e-4))

realys = [0] * NL; prelimys = [0] * NL; xs = [0] * NL




print("Init time:", time.time()-tic, "Device:", device)
tic = time.time()
firstpass=True
nbbatches = -1
wsav = []

testaccs=[]; trainaccs=[]
trainouts = []; trainouts_lin= []; trainouts_nomp = []; trainoutsq_l1 = []; trainoutsq_l2=[]; trainoutsq_l3=[]; traintargets = []; 
testouts = []; testouts_lin=[]; testouts_nomp=[]; testoutsq_l1 = []; testoutsq_l2=[]; testoutsq_l3=[]; testtargets = [] 



# Start the experiment !
epoch = 0

TESTING=False; TRAINING = False; UNLAB = True; 
zeloader = unlabloader

numbatch=0

x, targets = iter(zeloader).next()

with torch.no_grad():

    
    nbbatches += 1

    # First, load and prepare the image data

    x = x.to(device)
    if TRAINING or TESTING:
        targets = targets.to(device)

    xorig = x.detach().clone()

    x = x - torch.mean(x, dim=(1, 2,3),keepdim=True);  x = x / (1e-10 + torch.std(x, dim=(1, 2,3), keepdim=True)) 


    # Multiply whole images by full ZCA matrix, after whole-image normalization
    if ZCA == 'IMAGE':
        assert FSIZE == CSIZE
        xshape = x.shape
        t = x.reshape([xshape[0], -1])
        t = t - torch.mean(t, axis=1, keepdims=True)
        t = t / (1e-10 + torch.std(t, axis=1, keepdims=True))
        
        tzca = torch.matmul(zcamat, t.T).T
        x = tzca.reshape(xshape)
        # x=x.detach()
    
    xorig2 = x + 1e-10 # Just for debugging purposes


numl = 0  # 1st layer only

optimizers[numl].zero_grad()

#Prepare the input to the layer
with torch.no_grad():
    # Normalize each element in the batch
    x = x - torch.mean(x, dim=(1, 2,3),keepdim=True);  x = x / (1e-10 + torch.std(x, dim=(1, 2,3), keepdim=True)) 

    # If required, apply the "flattening"/unfolding convolution, separating each individual patch into a vector, and normalizing each of these.
    if USEFLAT:
        t = F.conv2d(x, wflat[numl]).detach()  # Should have size BS x nbinputchan*SIZES[0]*SIZES[0] x H_input x  W_input

        t = t - torch.mean(t, axis=1, keepdims=True)
        t = t / (1e-10 +  torch.std(t, axis=1, keepdims=True))
        # Also if specified, also ZCA-whiten each individual input patch (note: this *does* improve performance a little bit, it seems; individual patch-normalization alone doesn't.)
        if ZCA == 'FLAT' and numl == 0:
            t = t.moveaxis(1,3)

            t = torch.matmul(t[:,:, :, None, :], zcamat)
            t = t.squeeze(3).moveaxis(3, 1)

            t = t -  torch.mean(t, dim=1, keepdims=True)
            t = t / (1e-10 + torch.std(t, dim=1, keepdims=True))

        x = t

    xs[numl] = x.clone()


# Compute the FF input to the cells, which is also the first part of the computational graph ("w*x") and common to the "real" and "surrogate" outputs.
prelimy = F.conv2d(x, w[numl], stride=STRIDES[numl]) 


# Now compute the "real" y output
with torch.no_grad():
    
    prelimysav = prelimy.detach().clone()

    if not USETHRES:
        thres[numl].fill_(0)

    realy = (prelimy - thres[numl])

    # k-WTA
    # y output has shape BatchSize x NbOutChannels x H x 
    tk = torch.topk(realy.data, K[numl], dim=1, largest=True)[0]
    realy.data[realy.data < tk.data[:,-1,:,:][:, None, :, :]] = 0       

    if BINARIZEDPLAST:
        realy.data = (realy.data > 0).float()
    
    torch.clamp_(realy.data, max=50.0) # realy.data[realy.data>50.0] = 50.0 # Sould not happen.




# Then we compute an auxiliary output yforgrad, which will be used solely to make gradient computations produce the desired Hebbian output:
# w has shape OutChannels, InChannels, H, W
# y output has shape BatchSize x NbOutChannels x H x 

# Note: you must not include thresholds here, because this is only to build the appropriate computational graph. 
# The actual values will come from realy, which does include thresholding.

if RULE == "INSTAR":
    yforgrad = prelimy - 1/2 * torch.sum(w[numl] * w[numl], dim=(1,2,3))[None,:, None, None]
elif RULE == "OJA":
    yforgrad = prelimy - 1/2 * torch.sum(w[numl] * w[numl], dim=(1,2,3))[None,:, None, None] * realy.data
elif RULE == "PLAINHEBB":
    yforgrad = prelimy
else:
    raise ValueError("Which Rule?")



yforgrad.data = realy.data # We force the value of yforgrad to be the "correct" y. Note: if using plain Hebb (yforgrad = prelimy), that means we're modifying the value of prelimy and thus can't reuse it in the future !


# Compute the loss and perform the backward pass (if we are in the Hebbian / unlabeled phase)
loss = ( torch.sum( -1/2 * yforgrad * yforgrad) )
    # loss = ( torch.sum( yforgrad) )
    # loss = ( torch.sum( -1/2 * w[0] * w[0]) )


loss.backward()

print("Done!")
        
        




RF sizes: L1: 5 L2: 53 L3: 149
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Init time: 18.93275499343872 Device: cuda
Done!


In [2]:
w0 = w[0].data.cpu().numpy()
g0 = w[0].grad.cpu().numpy()
py = prelimysav.cpu().numpy()
ry = realy.cpu().numpy()
xx = x.cpu().numpy()
xshape = xx.shape

# First, other tests:

print("w.shape:", w0.shape, "grad.shape:", g0.shape, "x.shape:", xx.shape, "py.shape:", py.shape)
print(g0[0,0,0,0])
z = 0

print("Actual py (channel 0, pos 0,0, batch position 0):", py[0,0,0, 0])

for col in range(xshape[2]):
    for row in range(xshape[3]):
        if col < 5  and row < 5:
            for chan in range(1):
                z += w0[0, chan, row,col] * xx[0, chan, row, col]
print("Computed py:", z)
            
posmaxy = np.argmax(py[0, :, 0, 0])


print("Position max prelimy (batch pos 0, pos 0,0):", posmaxy, " - py at this pos:", py[0, posmaxy, 0, 0])
print("======")


# Now for the actual Hebbian update verification:

# We look at the PyTorch-computed gradient for the weight at kernel position 1,1, input channel 0, output channel numfilt (looping from 0 to 9).
# It should be equal to the proper Hebbian update for the rule considered, summed over all positions at which the filter is applied and all batch elements.

z=0

print("The two columns in the following should be roughly identical:")
for numfilt in range(10):
    z = 0
    for bpos in range(BATCHSIZE):
        for col in range(xshape[2]):
            for row in range(xshape[3]):
                if col % STRIDES[0] == 0 and row % STRIDES[0] == 0:
                    #print("adding:")
                    if RULE == "PLAINHEBB":
                        z += ry[bpos, numfilt, col // STRIDES[0], row // STRIDES[0]] * xx[bpos, 0, col+1, row+1] 
                    elif RULE == "INSTAR":
                        z += ry[bpos, numfilt, col // STRIDES[0], row // STRIDES[0]] * (xx[bpos, 0, col+1, row+1] - w0[numfilt,0,1,1])
                    elif RULE == "OJA":
                        z += ry[bpos, numfilt, col // STRIDES[0], row // STRIDES[0]] * (xx[bpos, 0, col+1, row+1] - w0[numfilt,0,1,1] * ry[bpos, numfilt, col // STRIDES[0], row // STRIDES[0]])
                    else:
                        raise ValueError("Which Rule?")


    print(g0[numfilt, 0, 1, 1], -z) #, w0[numfilt, 0, 0, 0], xx[0,0,0,0])
#for posmaxy in range(10):
#    z = -xx[0, 0, 0, 0] * py[0, posmaxy, 0, 0]
#    print(g0[posmaxy, 0, 0, 0], z, w0[posmaxy, 0, 0, 0], xx[0,0,0,0])

w.shape: (100, 3, 5, 5) grad.shape: (100, 3, 5, 5) x.shape: (27, 3, 32, 32) py.shape: (27, 100, 3, 3)
-49.826347
Actual py (channel 0, pos 0,0, batch position 0): 1.0887709
Computed py: 0.5310140330693685
Position max prelimy (batch pos 0, pos 0,0): 72  - py at this pos: 1.7010709
The two columns in the following should be roughly identical:
-47.784554 -47.78455396741629
-1.3501656 -1.3501656651496887
-1.2514424 -1.2514425963163376
-4.815675 -4.815674617886543
-17.923779 -17.923781782388687
9.121071 9.121071457862854
3.9969635 3.996963679790497
0.70757526 0.707575336098671
15.537903 15.537902865558863
0.273444 0.2734439969062805
