In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from utils import HopfieldEnergy, HopfieldUpdate
from torch.utils.data import Subset, DataLoader
import argparse
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torchvision
import torchvision.transforms as transforms
import numpy as np
import csv
import os
import time
from tabulate import tabulate

np.set_printoptions(precision=2, suppress=True)

In [2]:
# Parameters

input_size = 784
hidden1_size = 256
hidden2_size = 256
output_size = 11
free_steps = 40
nudge_steps = 5
learning_rate = 1.0
beta = 1.0
# batch_dim = 5
batch_size = 5
n_iters = 4000
mr = 0.5
lam = 1.0
print_frequency = 25
n_epochs = 20
n_steps=10000



In [3]:
def minimizeEnergy(model,steps,optimizer,x,h1,h2,y,target=None,beta=None,print_energy=False):
    energies = []  # List to store the energy values
    for step in range(steps):
        # optimizer.zero_grad()
        # energy = model(x, h1, h2, y, target=target,beta=beta)
        # energy.backward()
        # optimizer.step()

        W_x_h1,W_h1_h2,W_h2_y = model
        dh1 = -h1 + x @ W_x_h1 + h2 @ W_h1_h2.t()
        dh2 = -h2 + h1 @ W_h1_h2 + y @ W_h2_y.t()
        if beta is None:
            dy = -y + h2 @ W_h2_y
        else:
            dy = -y + h2 @ W_h2_y + beta*(target-y)

        h1 = h1 + mr*dh1
        h2 = h2 + mr*dh2
        y = y + mr*dy

        # Restrict values between 0 and 1
        h1.data = torch.clamp(h1.data, 0, 1)
        h2.data = torch.clamp(h2.data, 0, 1)
        y.data = torch.clamp(y.data, 0, 1)

        # energies.append(energy.item())  # Save the energy value

    # Save copy of the internal state variables
    h1_free = h1.detach().clone()
    h2_free = h2.detach().clone()
    y_free = y.detach().clone()

    return h1_free, h2_free, y_free, energies

In [4]:
# Data

images = torch.rand(3, input_size)
states = torch.eye(output_size)


# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),transforms.Lambda(torch.flatten)])
trainset = torchvision.datasets.MNIST(root='~/datasets', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='~/datasets', train=False, download=True, transform=transform)

# Create split train loader
indices = [[] for label in range(10)]
for idx, target in enumerate(trainset.targets):
    indices[target].append(idx)

trainloader = [iter(DataLoader(Subset(trainset,indices[label]), batch_size=batch_size, shuffle=True, num_workers=2)) for label in range(10)]

# Create split test loader
indices = [[] for label in range(10)]
for idx, target in enumerate(testset.targets):
    indices[target].append(idx)

testloader = [iter(DataLoader(Subset(testset,indices[label]), batch_size=batch_size, shuffle=True, num_workers=2)) for label in range(10)]

# # Plot samples
# plt.figure()
# for idx,(image,label) in enumerate(trainloader[4]):
#     if idx==9:
#         break
#     plt.subplot(3,3,idx+1)
#     plt.imshow(image[0].reshape(28,28).numpy(), cmap=cm.gray)
# plt.show()




#### Experiment 1 ####

Description:

1) Clamp first image and let settle 40 steps
2) Clamp intermediate label of image for relevant half of the output layer. Mask the rest. Let settle 5 steps. Intermediate label consists of 1-hot in the "intermediate half" of the target, and 0's in the "final half".
3) EP update
4) Clamp second image and let settle for 40 steps (output free)
5) Clamp final label on second half of output layer. Mask the other half. Settle for 5 steps
6) EP update

#### Experiment 2 ####

Description:

1) Clamp first image and let settle 40 steps (e.g. image #2 )
2) Clamp intermediate label of image 1 and let settle 5 steps. Intermediate label consists of 1-hot in the "intermediate half" of the target, and 0's in the "final half". 
3) EP update
4) Clamp second image and let settle for 40 steps (output free)
5) Clamp final label, which is a 2-hot with the intermediate label spot AND final label spot as ones. Settle for 5 steps
6) EP update




In [5]:
n_classes = 3

# output_size = n_classes # 12

output_size=12

# Define the model
# model = HopfieldEnergy(input_size, hidden1_size, hidden2_size, output_size, beta=beta, lam=lam)

W_x_h1 = torch.nn.Linear(input_size, hidden1_size).weight.data.t()
W_h1_h2 = torch.nn.Linear(hidden1_size, hidden2_size).weight.data.t()
W_h2_y = torch.nn.Linear(hidden2_size, output_size).weight.data.t()
model = [W_x_h1, W_h1_h2, W_h2_y]
# W_h2_h1 = W
# W_y_h2 = torch.nn.Linear(output_size, hidden2_size)
# W_h2_h1.weight.data = W_h1_h2.weight.data.t()
# W_y_h2.weight.data = W_h2_y.weight.data.t()


# Labels
intermediate_labels = torch.zeros(n_classes,output_size)
for i in range(n_classes):
    intermediate_labels[i,i] = 1

final_labels = torch.zeros(3,3,output_size)
for i in range(3):
    for j in range(3):
        final_labels[i,j,i] = 1
        final_labels[i,j,3+3*i+2] = 1

# Initialize the internal state variables
x = torch.zeros(batch_size,input_size)
h1 = torch.zeros(batch_size, hidden1_size)
h2 = torch.zeros(batch_size, hidden2_size)
y = torch.zeros(batch_size, output_size)
optimizer = optim.SGD([h1, h2, y], lr=mr)

###############
# Training loop
###############

lr1 = 0.01
for epoch in range(n_epochs):

    print("Epoch: ",epoch)
    for itr in range(n_steps):
        w1_update, w2_update, w3_update, b1_update, b2_update, b3_update = [], [], [], [], [], []

        # Pick which input to use as the trigger
        i = itr%n_classes

        # Phase 1 learning
        x.data, _ = next(trainloader[i]) # get MNIST image from class i

        if x.data.shape[0] < batch_size:
            continue

        # Zero out state variables
        h1.data.zero_()
        h2.data.zero_()
        y.data.zero_()
        
        label = intermediate_labels[i]
        target = torch.tile(label,(batch_size,1)).clone().detach()

        # Find where state currently settles to. Nudge towards target
        h1_free, h2_free, y_free, energies = minimizeEnergy(model,40,optimizer,x,h1,h2,y,print_energy=False)
        w1_update.append(x.t()@h1_free)
        w2_update.append(h1_free.t()@h2_free)
        w3_update.append(h2_free.t()@y_free)

        # # Plot energies
        # if itr%10==0:
        #     figure = plt.figure()
        #     plt.plot(energies)
        #     plt.xlabel('Iteration')
        #     plt.ylabel('Energy')
        #     plt.title('Energy vs. Iteration')
        #     plt.show()

        # Print some info
        if itr%100==0:
            print("Iteration: ",itr)
            print("i: ",i)
            # print("target: ",target.detach().numpy())
            # print("y_free: ",y_free.detach().numpy())
            print("Error = ",torch.norm(target-y_free).item())
            print("Fraction correct: ",torch.sum(torch.argmax(y_free,dim=1)==torch.argmax(target,dim=1)).item()/batch_size)

        beta = 1.0*torch.ones(1,output_size)
        # beta = torch.zeros(1,output_size)
        # beta[:,:3] = 0.5
        h1_nudge, h2_nudge, y_nudge, energies2 = minimizeEnergy(model,5,optimizer,x,h1,h2,y,target=target,beta=beta,print_energy=False)
        w1_update.append(x.t()@h1_nudge)
        w2_update.append(h1_nudge.t()@h2_nudge)
        w3_update.append(h2_nudge.t()@y_nudge)

        # Get total update
        w1_update = lr1*(-w1_update[0] + w1_update[1]) 
        w2_update = lr1*(-w2_update[0] + w2_update[1]) 
        w3_update = lr1*(-w3_update[0] + w3_update[1]) 

        W_x_h1 += w1_update
        W_h1_h2 += w2_update
        W_h2_y += w3_update
        W_x_h1 += w1_update
        W_h1_h2 += w2_update
        W_h2_y += w3_update
        # Phase 2 learning
        if itr>1000: # Start phase 2 learning
            # assert(0)
            lr2=0.01

            w1_update, w2_update, w3_update, b1_update, b2_update, b3_update = [], [], [], [], [], []
            # Get second image index
            if i==0:
                j = np.random.choice([1,2])
            elif i==1:
                j = np.random.choice([0,2])
            else:
                j = np.random.choice([0,1])

            x.data = next(trainloader[j])[0].clone().detach()
        
            # Experiment: try continuing to learn intermediate label down phase 2
            label = final_labels[i,j].clone().detach()
            target = torch.tile(label,(batch_size,1)).clone().detach()

            # Find where state currently settles to previous state
            h1_free, h2_free, y_free, energies = minimizeEnergy(model,40,optimizer,x,h1,h2,y,print_energy=False)
            w1_update.append(x.t()@h1_free)
            w2_update.append(h1_free.t()@h2_free)
            w3_update.append(h2_free.t()@y_free)

            if itr%100==0:
                print("i,j: ",i,j)    
                print("Target: ",target.detach().numpy())
                print("y: ",y_nudge.detach().numpy())
                print("Error = ",torch.norm(target-y_free).item())
                print("Fraction correct: ",torch.sum(torch.argmax(y_free[3:],dim=1)==torch.argmax(target[3:],dim=1)).item()/batch_size)


            # Experiment: try continuing to learn intermediate label down phase 2
            beta = 0.5*torch.ones(1,output_size)
            # beta = torch.zeros(1,output_size)
            # beta[:,3:] = 0.5
    
            h1_nudge, h2_nudge, y_nudge, energies2 = minimizeEnergy(model,5,optimizer,x,h1,h2,y,target=target,beta=beta,print_energy=False)
            w1_update.append(x.t()@h1_nudge)
            w2_update.append(h1_nudge.t()@h2_nudge)
            w3_update.append(h2_nudge.t()@y_nudge)

            w1_update = lr2*(-w1_update[0] + w1_update[1]) 
            w2_update = lr2*(-w2_update[0] + w2_update[1])
            w3_update = lr2*(-w3_update[0] + w3_update[1]) 
            
            W_x_h1 += w1_update
            W_h1_h2 += w2_update
            W_h2_y += w3_update
        

    # Testing loop
    for itr in range(1000):
        # Pick which input to use as the trigger
        i = itr%n_classes

        # Phase 1 learning
        x.data, _ = next(testloader[i]) # get MNIST image from class i

        # Zero out state variables
        h1.data.zero_()
        h2.data.zero_()
        y.data.zero_()

        label = intermediate_labels[i]
        target = torch.tile(label,(batch_size,1)).clone().detach()

        # Find where state currently settles to. Nudge towards target
        h1_free, h2_free, y_free, energies = minimizeEnergy(model,40,optimizer,x,h1,h2,y,print_energy=False)

        # Print some info
        if itr%100==0:
            print("Iteration: ",itr)
            print("i: ",i)
            print("Error = ",torch.norm(target-y_free).item())
            print("Fraction correct: ",torch.sum(torch.argmax(y_free,dim=1)==torch.argmax(target,dim=1)).item()/batch_size)

        # Get second image
        if i==0:
            j = np.random.choice([1,2])
        elif i==1:
            j = np.random.choice([0,2])
        else:
            j = np.random.choice([0,1])

        x.data = next(testloader[j])[0].clone().detach()

        label = final_labels[i,j].clone().detach()
        target = torch.tile(label,(batch_size,1)).clone().detach()

        h1_free, h2_free, y_free, energies = minimizeEnergy(model,40,optimizer,x,h1,h2,y,print_energy=False)

        if itr%100==0:
            print("i,j: ",i,j)    
            print("Target: ",target.detach().numpy())
            print("y: ",y_nudge.detach().numpy())
            print("Error = ",torch.norm(target-y_free).item())
            print("Fraction correct: ",torch.sum(torch.argmax(y_free[3:],dim=1)==torch.argmax(target[3:],dim=1)).item()/batch_size)




Iteration:  0
i:  0
Error =  2.2454142570495605
Fraction correct:  0.0
Iteration:  100
i:  1
Error =  1.2729988098144531
Fraction correct:  1.0
Iteration:  200
i:  2
Error =  3.1630144119262695
Fraction correct:  0.0
Iteration:  300
i:  0
Error =  0.6522217988967896
Fraction correct:  1.0
Iteration:  400
i:  1
Error =  1.118826985359192
Fraction correct:  0.8
Iteration:  500
i:  2
Error =  1.166476845741272
Fraction correct:  1.0
Iteration:  600
i:  0
Error =  1.4142135381698608
Fraction correct:  0.8
Iteration:  700
i:  1
Error =  1.7353230714797974
Fraction correct:  0.6
Iteration:  800
i:  2
Error =  1.414214015007019
Fraction correct:  0.8
Iteration:  900
i:  0
Error =  2.0
Fraction correct:  0.6
Iteration:  1000
i:  1
Error =  1.4142135381698608
Fraction correct:  0.8
Iteration:  1100
i:  2
Error =  2.0380115509033203
Fraction correct:  0.6
i,j:  2 1
Target:  [[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0

RuntimeError: The size of tensor a (5) must match the size of tensor b (3) at non-singleton dimension 0