In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
# --- data loading --------------------------------------------------------
batch_size_train = 24    # images per batch from the training loader
batch_size_test = 1000   # images per batch from the test loader
batch_size = batch_size_train  # batch size the state helpers reshape to

# --- image / state geometry ----------------------------------------------
img_height = 28
img_width = 28
num_vectors = 4          # vectors (levels) stored per pixel
len_vectors = 10         # length of each per-pixel vector

# --- optimisation schedule -----------------------------------------------
epochs = 500             # training iterations; each one consumes a single batch
steps = 15               # state updates applied to every batch

# --- not referenced by the cells visible in this notebook -----------------
learning_rate = 0.01     # NOTE: the Adadelta optimizer below is built with lr=1.0
log_interval = 10
win_size = 3
epsilon = .7

## Data Functions

In [None]:
def _mnist_loader(train, batch_size):
    """Build a shuffled DataLoader over one MNIST split.

    Args:
        train: True for the training split, False for the test split.
        batch_size: number of images per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding (images, labels) batches;
        images are converted to tensors and normalised with mean 0.1307 and
        standard deviation 0.3081.
    """
    preprocessing = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # The dataset is downloaded into '/files/' if it is not already there.
    dataset = torchvision.datasets.MNIST(
        '/files/', train=train, download=True, transform=preprocessing)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True)


train_loader = _mnist_loader(train=True, batch_size=batch_size_train)

test_loader = _mnist_loader(train=False, batch_size=batch_size_test)

In [None]:
def data_to_state(example_data):
    """Turn a batch of images into bottom-level state vectors.

    Every pixel intensity is copied into all ``len_vectors`` components of
    that pixel's vector, so the result can be written straight into level 0
    of the state tensor.

    Args:
        example_data: image batch whose elements can be reshaped to
            (batch, img_height, img_width); presumably the
            (batch, 1, 28, 28) tensors produced by the MNIST loaders.

    Returns:
        Tensor of shape (batch, img_height, img_width, len_vectors) placed
        on ``device``.
    """
    # -1 lets the batch dimension follow the input instead of the global
    # ``batch_size``, so batches of any size are accepted.
    images = torch.reshape(example_data, (-1, img_height, img_width))
    # Replicate along a new trailing axis; ``len_vectors`` replaces the
    # previously hard-coded 10 so the result always matches the state layout.
    return images.unsqueeze(3).repeat(1, 1, 1, len_vectors).to(device)

In [None]:
def targets_to_state(example_targets):
    """Expand class labels into a per-pixel one-hot target grid.

    Args:
        example_targets: integer tensor of shape (batch,) holding class
            indices in ``[0, len_vectors)``.

    Returns:
        Float tensor of shape (batch, img_height, img_width, len_vectors) in
        which every pixel of an image carries that image's one-hot label.
        The result stays on the device of ``example_targets``.
    """
    # ``len_vectors`` (instead of a literal 10) keeps the target the same
    # width as the top-level state vector it is compared against.
    one_hot = torch.nn.functional.one_hot(example_targets, num_classes=len_vectors).float()
    # Broadcasting from the input's own batch dimension removes the
    # dependency on the global ``batch_size``.
    return one_hot[:, None, None, :].repeat(1, img_height, img_width, 1)

## CA Functions

In [None]:
def init_state(batch_size,img_height,img_width,num_vectors,len_vectors):
    """Create a fresh state tensor filled with small random values.

    Returns:
        Tensor of shape (batch_size, img_height, img_width, num_vectors,
        len_vectors) with entries drawn uniformly from [0, 0.1), moved to
        ``device``.
    """
    shape = (batch_size, img_height, img_width, num_vectors, len_vectors)
    noise = torch.rand(shape)
    return (noise * .1).to(device)

In [None]:
class model(nn.Module):
    """Small attention-style update network.

    Three parallel linear projections of the input are combined as
    ``softmax(Q * K, dim=1) * V``, where Q and K pass through a LeakyReLU
    and V through a sigmoid.
    """

    def __init__(self, num_inp,num_out):
        """Create the three ``num_inp -> num_out`` projections."""
        super().__init__()
        # Creation order (Q1, K1, V1) is kept so seeded initialisation is unchanged.
        self.Q1 = nn.Linear(num_inp, num_out)
        self.K1 = nn.Linear(num_inp, num_out)
        self.V1 = nn.Linear(num_inp, num_out)

        self.act = nn.LeakyReLU()
        self.act1 = nn.Sigmoid()

    def forward(self,x):
        """Return the relevance-weighted values for ``x``.

        The softmax is taken over dimension 1 of the element-wise product of
        the query and key projections.
        """
        queries = self.act(self.Q1(x))
        keys = self.act(self.K1(x))
        values = self.act1(self.V1(x))

        weights = F.softmax(queries * keys, dim=1)
        return weights * values

In [None]:
def compute_all(bottom_up_model_list,top_down_model_list,layer_att_model_list,state,len_vectors,num_vectors,batch_size):
    """Compute the additive update for every vector in the state.

    Each cell sees its 3x3 spatial neighbourhood (with wrap-around at the
    borders, as produced by ``torch.roll``) on a given level. For level
    ``i + 1`` the update is the sum of:

    * ``top_down_model_list[i]`` applied to level ``i + 2`` (when it exists),
    * ``bottom_up_model_list[i]`` applied to level ``i``,
    * ``layer_att_model_list[i]`` applied to level ``i + 1`` itself.

    Level 0 never receives an update.

    Args:
        bottom_up_model_list: ``num_vectors - 1`` callables mapping
            (N, 9 * len_vectors) to (N, len_vectors).
        top_down_model_list: ``num_vectors - 2`` callables of the same kind.
        layer_att_model_list: ``num_vectors - 1`` callables of the same kind.
        state: tensor of shape
            (batch_size, height, width, num_vectors, len_vectors).
        len_vectors: length of each per-cell vector.
        num_vectors: number of levels per cell.
        batch_size: size of the first dimension of ``state``.

    Returns:
        Tensor shaped like ``state`` that can be added to it directly.
    """
    # Take the grid size and device from the state itself rather than from
    # module-level globals, so any spatial size / device works.
    height, width = state.shape[1], state.shape[2]
    in_features = 9 * len_vectors

    # Shift the state in all 9 directions of the (row, column) plane and
    # concatenate along the vector axis; the (dy, dx) order fixes the layout
    # of the 9 * len_vectors input features.
    att_list = torch.cat(
        [torch.roll(state, shifts=(dy, dx), dims=(1, 2))
         for dy in (-1, 0, 1)
         for dx in (-1, 0, 1)],
        dim=4,
    )

    def level_input(level):
        """Flatten one level of the neighbourhood tensor to (cells, features)."""
        return torch.reshape(att_list[:, :, :, level, :], (-1, in_features))

    delta = [torch.zeros((batch_size * height * width, len_vectors), device=state.device)
             for _ in range(num_vectors)]
    for i in range(num_vectors):
        if i < num_vectors - 2:
            # top-down: level i+2 informs level i+1
            delta[i+1] = delta[i+1] + top_down_model_list[i](level_input(i + 2))
        if i < num_vectors - 1:
            # bottom-up: level i informs level i+1; same-level: i+1 informs itself
            bottom_up_temp = bottom_up_model_list[i](level_input(i))
            att_layer_temp = layer_att_model_list[i](level_input(i + 1))
            delta[i+1] = delta[i+1] + bottom_up_temp + att_layer_temp

    # Stack levels back next to each other and restore the state layout.
    delta = torch.stack(delta, dim=1)
    return torch.reshape(delta, (batch_size, height, width, num_vectors, len_vectors))

## Model Initializations

In [None]:
# One network per level transition. Every network reads a 3x3 neighbourhood
# (9 * len_vectors features) and emits one len_vectors-long update.
# The networks are moved to ``device`` (not an unconditional .cuda()) so the
# notebook also runs on machines without a GPU.
bottom_up_model_list = [model(9*len_vectors,len_vectors).to(device) for _ in range(num_vectors-1)]
top_down_model_list = [model(9*len_vectors,len_vectors).to(device) for _ in range(num_vectors-2)]
layer_att_model_list = [model(9*len_vectors,len_vectors).to(device) for _ in range(num_vectors-1)]

## Optimizers

In [None]:
# Gather the trainable parameters of every network, level by level, so a
# single optimizer updates all of them.
param_list = []
for level in range(num_vectors):
    if level < num_vectors - 2:
        param_list.extend(top_down_model_list[level].parameters())
    if level < num_vectors - 1:
        param_list.extend(bottom_up_model_list[level].parameters())
        param_list.extend(layer_att_model_list[level].parameters())

optimizer = torch.optim.Adadelta(param_list, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
# Mean-squared error between the top-level state and the one-hot targets.
mse = nn.MSELoss()

## Train

In [None]:
# Each "epoch" below is one optimisation step on a single training batch.
examples = enumerate(train_loader)
for epoch in range(epochs):
    optimizer.zero_grad()

    #get batch of images; restart the loader once it is exhausted so that
    #``epochs`` may exceed the number of batches without raising StopIteration
    try:
        batch_idx, (example_data, example_targets) = next(examples)
    except StopIteration:
        examples = enumerate(train_loader)
        batch_idx, (example_data, example_targets) = next(examples)

    #initialize state
    state = init_state(batch_size,img_height,img_width,num_vectors,len_vectors)

    #put current batches into the bottom level of the state
    state[:,:,:,0,:] = data_to_state(example_data)

    #iterate the update rule; gradients flow through all steps
    for step in range(steps):
        delta = compute_all(bottom_up_model_list,top_down_model_list,layer_att_model_list,state,len_vectors,num_vectors,batch_size)
        state = state + delta

    #get loss: the top-level vector of every pixel should be the one-hot label
    pred_out = state[:,:,:,-1]
    targ_out = targets_to_state(example_targets).to(device)
    loss = mse(pred_out,targ_out)
    loss.backward()
    optimizer.step()
    #.item() logs the plain number instead of the tensor repr with grad_fn
    print("Epoch: {}/{}  Loss: {}".format(epoch,epochs,loss.item()))

## Test

In [None]:
# Evaluate on the held-out test split. Each image is classified by majority
# vote: every pixel votes for the arg-max of its top-level vector.
tot_corr = 0
tot_batches = 0  # number of test images scored so far
with torch.no_grad():  # no autograd graph is needed for evaluation
    for data, target in test_loader:
        # Score the test batch in pieces of at most ``batch_size`` images so
        # memory use matches training; the last piece may be smaller.
        for chunk_data, chunk_target in zip(data.split(batch_size), target.split(batch_size)):
            n = chunk_data.shape[0]

            #initialize state
            state = init_state(n,img_height,img_width,num_vectors,len_vectors)

            #put the current test images into the bottom level of the state;
            #the trailing singleton axis broadcasts over the vector length
            state[:,:,:,0,:] = chunk_data.reshape(n,img_height,img_width,1).to(device)

            for step in range(steps):
                delta = compute_all(bottom_up_model_list,top_down_model_list,layer_att_model_list,state,len_vectors,num_vectors,n)

                #update state
                state = state + delta

            #per-pixel class votes, shape (n, img_height, img_width)
            pixel_votes = torch.argmax(state[:,:,:,-1], dim=-1)
            #count the votes per class for every image, shape (n, len_vectors)
            vote_counts = F.one_hot(pixel_votes.reshape(n,-1), num_classes=len_vectors).sum(dim=1)
            predicted = torch.argmax(vote_counts, dim=1).cpu()

            tot_corr += int((predicted == chunk_target).sum())
            tot_batches += n
        print("Acc: {}".format(tot_corr/tot_batches))