In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
matplotlib.use('Agg')

import glob
from PIL import Image
import os
from datetime import datetime
import time
import math
import sys
from model_utils import *
from data_utils import *

In [5]:
device = torch.device('cuda:1')
device

device(type='cuda', index=1)

# Train a basic energy model CNN on MNIST or CIFAR with EP

In [104]:
class ModelArgs():
    def __init__(self, preset):
        if preset == 'cifar_cnn':
            ## using the model architecture defined in check/train_cifar10.sh :
            # python main.py --model 'CNN' --task 'CIFAR10' --channels 128 256 512 --kernels 3 3 3 --pools 'mmm'
            # --strides 1 1 1 --fc 10 --optim 'adam' --lrs 5e-5 5e-5 1e-5 7e-6 --epochs 1 --act 'hard_sigmoid
            # --todo 'train' --betas 0.0 0.5 --T1 200 --T2 20 --mbs 128 --check-thm --random-sign --loss 'mse' --save
            # --save-nrn  --device 0
            self.pools ='mmm'
            self.channels = [128, 256, 512]
            self.kernels = [3,3,3]
            self.strides = [1,1,1]
            self.paddings = [0,0,0]
            self.fc = [10]
            self.lrs = [5e-5, 5e-5, 1e-5, 7e-6] # per-layer learning rates
            self.wds = None # weight decay
            self.softmax = False
            self.act = 'hard_sigmoid'
            self.betas = [0.0, 0.5] # free phase, nudged phase
            self.T1 = 200 # free phase
            self.T2 = 20 # nudged phase
            self.loss = 'mse' # or 'cel'
            self.alg = 'EP' # options: 'CEP', 'EP', 'BPTT'
            self.mbs = 128  # minibatch size
        elif preset == 'mnist_cnn':
            # python main.py --model 'CNN' --task 'MNIST' --channels 32 64 --kernels 5 5 --pools 'mm'
            # --strides 1 1 --fc 10 --optim 'adam' --lrs 5e-5 1e-5 8e-6 --epochs 1 --act 'hard_sigmoid'
            # --todo 'train' --betas 0.0 0.4 --T1 200 --T2 10 --mbs 100 --device 0
            self.pools ='mm'
            self.channels = [32, 64]
            self.kernels = [5,5]
            self.strides = [1,1]
            self.paddings = [0,0]
            self.fc = [10]
            self.lrs = [5e-5, 1e-5, 8e-6] # per-layer learning rates
            self.wds = None # weight decay
            self.softmax = False
            self.act = 'hard_sigmoid'
            self.betas = [0.0, 0.5] # free phase, nudged phase
            self.T1 = 200 # free phase
            self.T2 = 10 # nudged phase
            self.loss = 'mse' # or 'cel'
            self.alg = 'EP' # options: 'CEP', 'EP', 'BPTT'
            self.mbs = 100  # minibatch size
        else:
            print('unrecognized preset {}!'.format(preset,))
        
                    
        self.device = device
        self.save = False
        self.load_path = ''
args = ModelArgs('mnist_cnn')
print('\nargs\tmbs\tT1\tT2\tepochs\tactivation\tbetas')
print('\t',args.mbs,'\t',args.T1,'\t',args.T2,'\t','_','\t',args.act, '\t', args.betas)


args	mbs	T1	T2	epochs	activation	betas
	 100 	 200 	 10 	 _ 	 hard_sigmoid 	 [0.0, 0.5]


In [100]:
if args.act=='mysig':
    activation = my_sigmoid
elif args.act=='sigmoid':
    activation = torch.sigmoid
elif args.act=='tanh':
    activation = torch.tanh
elif args.act=='hard_sigmoid':
    activation = hard_sigmoid
elif args.act=='my_hard_sig':
    activation = my_hard_sig
elif args.act=='ctrd_hard_sig':
    activation = ctrd_hard_sig
    

if args.loss=='mse':
    criterion = torch.nn.MSELoss(reduction='none').to(device)
elif args.loss=='cel':
    criterion = torch.nn.CrossEntropyLoss(reduction='none').to(device)

In [106]:
args.save = True

In [107]:
if args.save:
    date = datetime.now().strftime('%Y-%m-%d')
    time = datetime.now().strftime('%H-%M-%S')
    if args.load_path=='':
        path = 'results/'+args.alg+'/'+args.loss+'/'+date+'/'+time+'_gpu'+str(args.device)
    else:
        path = args.load_path
    if not(os.path.exists(path)):
        os.makedirs(path)
else:
    path = ''
path

'results/EP/mse/2023-06-28/09-44-52_gpucuda:1'

In [74]:
args.strides, args.channels, args.paddings

([1, 1], [32, 64], [0, 0])

In [76]:
if args.load_path == '':
    pools = make_pools(args.pools)
    channels = [3]+args.channels
    model = P_CNN(32, channels, args.kernels, args.strides, args.fc, pools, args.paddings,
                  activation=activation, softmax=args.softmax)
else:
    model = torch.load(load_path + '/model.pt', map_location=device)

In [77]:
model

P_CNN(
  (synapses): ModuleList(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (2): Linear(in_features=1600, out_features=10, bias=True)
  )
)

## set up CIFAR dataset

In [78]:
transform_train = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 
                                                          torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 
                                                                                           std=(3*0.2023, 3*0.1994, 3*0.2010)) ])   

transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 
                                                 torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 
                                                                                  std=(3*0.2023, 3*0.1994, 3*0.2010)) ]) 

cifar10_train_dset = torchvision.datasets.CIFAR10('./cifar10_pytorch', train=True, transform=transform_train, download=True)
cifar10_test_dset = torchvision.datasets.CIFAR10('./cifar10_pytorch', train=False, transform=transform_test, download=True)

# For Validation set
val_index = np.random.randint(10)
val_samples = list(range( 5000 * val_index, 5000 * (val_index + 1) ))

#train_loader = torch.utils.data.DataLoader(cifar10_train_dset, batch_size=mbs, sampler = torch.utils.data.SubsetRandomSampler(val_samples), shuffle=False, num_workers=1)
train_loader = torch.utils.data.DataLoader(cifar10_train_dset, batch_size=mbs, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(cifar10_test_dset, batch_size=200, shuffle=False, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


## training

### set up

In [79]:
optim_params = []
for idx in range(len(model.synapses)):
    if args.wds is None:
        optim_params.append(  {'params': model.synapses[idx].parameters(), 'lr': args.lrs[idx]}  )
    else:
        optim_params.append(  {'params': model.synapses[idx].parameters(), 'lr': args.lrs[idx], 'weight_decay': args.wds[idx]}  )
    if hasattr(model, 'B_syn'):
        for idx in range(len(model.B_syn)):
            if args.wds is None:
                optim_params.append( {'params': model.B_syn[idx].parameters(), 'lr': args.lrs[idx+1]} )
            else:
                optim_params.append( {'params': model.B_syn[idx].parameters(), 'lr': args.lrs[idx+1], 'weight_decay': args.wds[idx+1]} )
#         if hasattr(model, 'lat_syn'):
#             for idx in range(len(model.lat_syn)):
#                 if args.wds is None:
#                     optim_params.append( {'params': model.lat_syn[idx].parameters(), 'lr': args.lrs[idx]} )
#                 else:
#                     optim_params.append( {'params': model.lat_syn[idx].parameters(), 'lr': args.lrs[idx], 'weight_decay': args.wds[idx+1]} )

In [80]:
optimizer = torch.optim.Adam( optim_params ) # we will still use EP to calculate the gradient of the weights, but torch will update the values for us

In [81]:
if save and load_path=='':
        createHyperparametersFile(path, args, model, command_line)

In [82]:
### do the training!

In [84]:
model = model.to(device)
train(model, optimizer, train_loader, test_loader, args.T1, args.T2, args.betas, device, 10, criterion, alg=args.alg, 
                 random_sign=True, check_thm=False, save=False, path=path, checkpoint=None, 
                 thirdphase=True, scheduler=None, cep_debug=False)

Epoch : 0.0 	Run train acc : 0.25 	(5/20)	 elapsed time : 0m 0s 	 (will finish in 406m 58s)
Epoch : 0.1 	Run train acc : 0.235 	(1182/5020)	 elapsed time : 3m 23s 	 (will finish in 334m 15s)
Epoch : 0.2 	Run train acc : 0.236 	(2360/10020)	 elapsed time : 6m 42s 	 (will finish in 328m 22s)
Epoch : 0.3 	Run train acc : 0.233 	(3507/15020)	 elapsed time : 10m 28s 	 (will finish in 338m 21s)
Epoch : 0.4 	Run train acc : 0.234 	(4676/20020)	 elapsed time : 13m 55s 	 (will finish in 333m 55s)
Epoch : 0.5 	Run train acc : 0.236 	(5901/25020)	 elapsed time : 17m 3s 	 (will finish in 323m 51s)
Epoch : 0.6 	Run train acc : 0.237 	(7106/30020)	 elapsed time : 20m 26s 	 (will finish in 319m 54s)
Epoch : 0.7 	Run train acc : 0.239 	(8362/35020)	 elapsed time : 23m 12s 	 (will finish in 308m 11s)
Epoch : 0.8 	Run train acc : 0.241 	(9635/40020)	 elapsed time : 27m 0s 	 (will finish in 310m 30s)
Epoch : 0.9 	Run train acc : 0.243 	(10945/45020)	 elapsed time : 30m 41s 	 (will finish in 310m 14s)
Epo

Epoch : 7.3 	Run train acc : 0.391 	(5878/15020)	 elapsed time : 269m 32s 	 (will finish in 99m 40s)
Epoch : 7.4 	Run train acc : 0.389 	(7794/20020)	 elapsed time : 272m 54s 	 (will finish in 95m 51s)
Epoch : 7.5 	Run train acc : 0.392 	(9801/25020)	 elapsed time : 276m 43s 	 (will finish in 92m 13s)
Epoch : 7.6 	Run train acc : 0.391 	(11738/30020)	 elapsed time : 280m 30s 	 (will finish in 88m 33s)
Epoch : 7.7 	Run train acc : 0.392 	(13732/35020)	 elapsed time : 284m 21s 	 (will finish in 84m 55s)
Epoch : 7.8 	Run train acc : 0.392 	(15692/40020)	 elapsed time : 288m 12s 	 (will finish in 81m 16s)
Epoch : 7.9 	Run train acc : 0.394 	(17720/45020)	 elapsed time : 291m 23s 	 (will finish in 77m 26s)
Epoch : 8.0 	Run train acc : 0.394 	(19721/50000)	 elapsed time : 294m 57s 	 (will finish in 73m 44s)
Test accuracy :	 0.3989
Epoch : 8.0 	Run train acc : 0.45 	(9/20)	 elapsed time : 295m 46s 	 (will finish in 73m 55s)
Epoch : 8.1 	Run train acc : 0.419 	(2104/5020)	 elapsed time : 299m 

In [85]:
model = model.to(device)
train(model, optimizer, train_loader, test_loader, args.T1, args.T2, args.betas, device, 20, criterion, alg=args.alg, 
                 random_sign=True, check_thm=False, save=False, path=path, checkpoint=None, 
                 thirdphase=True, scheduler=None, cep_debug=False)

Epoch : 0.0 	Run train acc : 0.45 	(9/20)	 elapsed time : 0m 0s 	 (will finish in 705m 12s)
Epoch : 0.1 	Run train acc : 0.409 	(2055/5020)	 elapsed time : 3m 52s 	 (will finish in 766m 33s)
Epoch : 0.2 	Run train acc : 0.409 	(4097/10020)	 elapsed time : 7m 44s 	 (will finish in 765m 31s)
Epoch : 0.3 	Run train acc : 0.41 	(6152/15020)	 elapsed time : 11m 33s 	 (will finish in 758m 0s)
Epoch : 0.4 	Run train acc : 0.41 	(8208/20020)	 elapsed time : 15m 10s 	 (will finish in 742m 55s)
Epoch : 0.5 	Run train acc : 0.412 	(10303/25020)	 elapsed time : 18m 53s 	 (will finish in 736m 20s)
Epoch : 0.6 	Run train acc : 0.413 	(12405/30020)	 elapsed time : 22m 44s 	 (will finish in 735m 3s)
Epoch : 0.7 	Run train acc : 0.415 	(14529/35020)	 elapsed time : 26m 34s 	 (will finish in 732m 27s)
Epoch : 0.8 	Run train acc : 0.416 	(16641/40020)	 elapsed time : 30m 0s 	 (will finish in 719m 58s)
Epoch : 0.9 	Run train acc : 0.417 	(18770/45020)	 elapsed time : 33m 33s 	 (will finish in 711m 55s)
Ep

Epoch : 7.3 	Run train acc : 0.472 	(7096/15020)	 elapsed time : 275m 6s 	 (will finish in 478m 33s)
Epoch : 7.4 	Run train acc : 0.47 	(9403/20020)	 elapsed time : 278m 22s 	 (will finish in 473m 57s)
Epoch : 7.5 	Run train acc : 0.471 	(11793/25020)	 elapsed time : 281m 31s 	 (will finish in 469m 10s)
Epoch : 7.6 	Run train acc : 0.473 	(14187/30020)	 elapsed time : 285m 13s 	 (will finish in 465m 19s)
Epoch : 7.7 	Run train acc : 0.472 	(16515/35020)	 elapsed time : 289m 11s 	 (will finish in 461m 54s)
Epoch : 7.8 	Run train acc : 0.473 	(18936/40020)	 elapsed time : 292m 45s 	 (will finish in 457m 51s)
Epoch : 7.9 	Run train acc : 0.472 	(21229/45020)	 elapsed time : 296m 33s 	 (will finish in 454m 10s)
Epoch : 8.0 	Run train acc : 0.471 	(23553/50000)	 elapsed time : 300m 19s 	 (will finish in 450m 28s)
Test accuracy :	 0.4708
Epoch : 8.0 	Run train acc : 0.4 	(8/20)	 elapsed time : 301m 7s 	 (will finish in 451m 39s)
Epoch : 8.1 	Run train acc : 0.476 	(2390/5020)	 elapsed time :

Epoch : 14.5 	Run train acc : 0.497 	(12446/25020)	 elapsed time : 542m 44s 	 (will finish in 205m 50s)
Epoch : 14.6 	Run train acc : 0.499 	(14967/30020)	 elapsed time : 546m 30s 	 (will finish in 202m 6s)
Epoch : 14.7 	Run train acc : 0.498 	(17430/35020)	 elapsed time : 550m 10s 	 (will finish in 198m 20s)
Epoch : 14.8 	Run train acc : 0.499 	(19966/40020)	 elapsed time : 553m 58s 	 (will finish in 194m 37s)
Epoch : 14.9 	Run train acc : 0.499 	(22453/45020)	 elapsed time : 557m 11s 	 (will finish in 190m 41s)
Epoch : 15.0 	Run train acc : 0.5 	(24997/50000)	 elapsed time : 560m 52s 	 (will finish in 186m 57s)
Test accuracy :	 0.5002
Epoch : 15.0 	Run train acc : 0.4 	(8/20)	 elapsed time : 561m 39s 	 (will finish in 187m 12s)
Epoch : 15.1 	Run train acc : 0.502 	(2521/5020)	 elapsed time : 565m 15s 	 (will finish in 183m 24s)
Epoch : 15.2 	Run train acc : 0.504 	(5055/10020)	 elapsed time : 569m 1s 	 (will finish in 179m 40s)
Epoch : 15.3 	Run train acc : 0.503 	(7560/15020)	 elaps

In [108]:
model = model.to(device)
train(model, optimizer, train_loader, test_loader, args.T1, args.T2, args.betas, device, 1, criterion, alg=args.alg, 
                 random_sign=True, check_thm=False, save=True, path=path, checkpoint=None, 
                 thirdphase=True, scheduler=None, cep_debug=False)

Epoch : 0.0 	Run train acc : 0.65 	(13/20)	 elapsed time : 0m 1s 	 (will finish in 54m 27s)


KeyboardInterrupt: 

In [115]:
scheduler = None
train_acc = 0.52
test_acc = 0.518
best = 0.518
epoch_sofar = 41
epoch = 1
save_dic = {'model_state_dict': model.state_dict(), 'opt': optimizer.state_dict(),
                'train_acc': train_acc, 'test_acc': test_acc, 
                'best': best, 'epoch': epoch_sofar+epoch+1}
# save_dic['angles'] = angles
save_dic['scheduler'] = scheduler.state_dict() if scheduler is not None else None
torch.save(save_dic,  path + '/checkpoint.tar')
torch.save(model, path + '/model.pt')

## save this version of the model before adding to it for comparison

# add lateral connections + attention heads to trained model

## split into multiple attention heads
inhibitory connections creating a sparse layer may require 10x more neurons in that layer (like 10 attention heads)

In [None]:
class Lat_MH_CNN(P_CNN): # lateral-connectivity, mulit-head CNN
    def __init__(self, lat_heads, in_size, channels, kernels, strides, fc, pools, paddings, activation=hard_sigmoid, softmax=False):
        # initialize default P_CNN structure, without the final fully connected layer
        super(Lat_MH_CNN, self).__init__(in_size, channels, kernels, strides, [], pools, paddings, activation=activation, softmax=softmax)
        
        self.lat_heads = lat_heads
        
        size = in_size
        
        or idx in range(len(channels)-1): 
            # layers have already been added by super(), just compute the size
#             self.synapses.append(torch.nn.Conv2d(channels[idx], channels[idx+1], kernels[idx], 
#                                                  stride=strides[idx], padding=paddings[idx], bias=True))
                
            size = int( (size + 2*paddings[idx] - kernels[idx])/strides[idx] + 1 )          # size after conv
            if self.pools[idx].__class__.__name__.find('Pool')!=-1:
                size = int( (size - pools[idx].kernel_size)/pools[idx].stride + 1 )   # size after Pool

        size = size * size * channels[-1]        
        
        
        # split into multiple linear projections, use lateral connections to form attention
        self.head_encoders = torch.nn.ModuleList()
        self.head_hopfield = torch.nn.ModuleList()
#         self.head_decodes = torch.nn.ModuleList()
        for idx in range(len(lat_heads)):
            # projects from the last convolutional layer, uses lateral connection, then projects back
            self.head_encoders.append(torch.nn.Linear(size, lat_heads[idx], False))
            self.head_hopfield.append(torch.nn.Linear(lat_heads[idx], lat_heads[idx], bias=False))
            # should head encoder and hopfield have bias?
        size = torch.sum(lat_heads)
        
        # fully connect it back down to output dimension
        fc_layers = [size] + fc
        for idx in range(len(fc)):
            self.synapses.append(torch.nn.Linear(fc_layers[idx], fc_layers[idx+1], bias=True))
        
            
    def Phi(self, x, y, neurons, head_neurons, beta, criterion):

        mbs = x.size(0)       
        conv_len = len(self.kernels)
        heads = len(self.lat_heads)
        tot_len = len(self.synapses)

        layers = [x] + neurons        
        phi = 0.0

        #Phi computation changes depending on softmax == True or not
        for idx in range(conv_len):    
            phi += torch.sum( self.pools[idx](self.synapses[idx](layers[idx])) * layers[idx+1], dim=(1,2,3)).squeeze()  
        
        for j in range(heads):
            phi += self.head_encoders[j](layers[conv_len]) * layers[conv_len+j]
            phi += self.head_hopfield[j](layers[conv_len+j]) * layers[conv_len+j]
            
        if not self.softmax:
            for idx in range(conv_len, tot_len):
                layeridx = idx + heads
                phi += torch.sum( self.synapses[idx](layers[layeridx].view(mbs,-1)) * layers[layeridx+1], dim=1).squeeze()
             
            if beta!=0.0:
                if criterion.__class__.__name__.find('MSE')!=-1:
                    y = F.one_hot(y, num_classes=self.nc)
                    L = 0.5*criterion(layers[-1].float(), y.float()).sum(dim=1).squeeze()   
                else:
                    L = criterion(layers[-1].float(), y).squeeze()             
                phi -= beta*L

        else:
            # the output layer used for the prediction is no longer part of the system ! Summing until len(self.synapses) - 1 only
            for idx in range(conv_len, tot_len-1):
                layeridx = idx + heads
                phi += torch.sum( self.synapses[idx](layers[layeridx].view(mbs,-1)) * layers[layeridx+1], dim=1).squeeze()
             
            # the prediction is made with softmax[last weights[penultimate layer]]
            if beta!=0.0:
                L = criterion(self.synapses[-1](layers[-1].view(mbs,-1)).float(), y).squeeze()             
                phi -= beta*L            
        
        return phi
    
#     def init_neurons(self, mbs, device):
        
#         neurons = []
#         append = neurons.append
#         size = self.in_size
#         for idx in range(len(self.channels)-1): 
#             size = int( (size + 2*self.paddings[idx] - self.kernels[idx])/self.strides[idx] + 1 )   # size after conv
#             if self.pools[idx].__class__.__name__.find('Pool')!=-1:
#                 size = int( (size - self.pools[idx].kernel_size)/self.pools[idx].stride + 1 )  # size after Pool
#             append(torch.zeros((mbs, self.channels[idx+1], size, size), requires_grad=True, device=device))

#         size = size * size * self.channels[-1]
        
# #         head_neurons = []
#         for j in range(len(self.lat_heads)):
#             neurons.append(torch.zeros((mbs, self.lat_heads[j])), requires_grad=True, device=device)
        
#         if not self.softmax:
#             for idx in range(len(self.fc)):
#                 append(torch.zeros((mbs, self.fc[idx]), requires_grad=True, device=device))
#         else:
#             # we *REMOVE* the output layer from the system
#             for idx in range(len(self.fc) - 1):
#                 append(torch.zeros((mbs, self.fc[idx]), requires_grad=True, device=device))            
            
#         return neurons
    
    def compute_syn_grads(self, x, y, neurons_1, neurons_2, betas, criterion, check_thm=False):
        
        beta_1, beta_2 = betas
        
        self.zero_grad()            # p.grad is zero
        if not(check_thm):
            phi_1 = self.Phi(x, y, neurons_1, beta_1, criterion)
        else:
            phi_1 = self.Phi(x, y, neurons_1, beta_2, criterion)
        phi_1 = phi_1.mean()
        
        phi_2 = self.Phi(x, y, neurons_2, beta_2, criterion)
        phi_2 = phi_2.mean()
        
        delta_phi = (phi_2 - phi_1)/(beta_1 - beta_2)        
        delta_phi.backward() # p.grad = -(d_Phi_2/dp - d_Phi_1/dp)/(beta_2 - beta_1) ----> dL/dp  by the theorem
        
        # force hopfield / lateral connections to be symmetric (backwards and forward weights the same)
        for j in range(len(self.lat_heads)):
            self.head_hopfield[j] = 0.5 * (self.head_hopfield[j] + self.head_hopfield[j].T)

## initiate with zeros

## initiate with inhibitory connections to enforce sparsity

## copy of the original model, with multiplied heads

# train these three models

# compare pre-lateral connections, initially zero lateral connections, and inhibitory lateral connection model

# visualize learned features

## clamp feature neuron to 1, display fixed point input layer