In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
# `dev` keeps the plain string form; `device` is the torch.device built from it.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [2]:
# --- optimisation / data loading ---
batch_size_train = 40  # batch size of the training DataLoader
batch_size_test = 40  # batch size of the test DataLoader
learning_rate = 0.01  # learning rate handed to the Adam optimizer
log_interval = 10  # NOTE(review): not referenced by any cell shown in this notebook

# --- state layout: (batch, img_height, img_width, num_vectors, len_vectors) ---
num_vectors = 4  # number of stacked vectors (levels) held at every pixel
len_vectors = 10  # length of each vector; the last level is compared against a 10-class one-hot target
img_height = 28  # image height the input batch is reshaped to
img_width = 28  # image width the input batch is reshaped to
win_size = 3  # NOTE(review): not referenced by any cell shown in this notebook
epsilon = .7  # base of the power series epsilon**1 .. epsilon**num_vectors built in compute_all
epochs = 4000  # number of training-loop iterations; each iteration consumes ONE batch, not a full pass
steps = 20  # number of state updates applied per batch before the loss is read

## Data Functions

In [3]:
def _mnist_loader(train, batch_size):
    """Build a shuffled MNIST DataLoader for one split.

    Args:
        train: True for the training split, False for the test split.
        batch_size: number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over MNIST stored under ``/files/``
        (downloaded on first use), with every image converted to a tensor and
        normalised with mean 0.1307 and standard deviation 0.3081.
    """
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST('/files/', train=train, download=True,
                                         transform=preprocess)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Both splits are shuffled and share the same preprocessing.
train_loader = _mnist_loader(True, batch_size_train)
test_loader = _mnist_loader(False, batch_size_test)

In [4]:
def data_to_state(example_data,batch_size):
    """Turn a batch of images into the bottom-level slice of the state.

    Each pixel intensity is copied ``len_vectors`` times so that the whole
    bottom vector of that pixel is filled with the same value.

    Args:
        example_data: image batch holding ``batch_size*img_height*img_width``
            values (any shape that can be reshaped to that).
        batch_size: number of images in ``example_data``.

    Returns:
        Tensor of shape ``(batch_size, img_height, img_width, len_vectors)``
        on ``device``.
    """
    images = torch.reshape(example_data, (batch_size, img_height, img_width))
    # Repeat along a new trailing axis. The width comes from `len_vectors`
    # (not a literal 10) so it always matches the state slice it is written to.
    tiled = images.unsqueeze(3).repeat(1, 1, 1, len_vectors)
    return tiled.to(device)

In [5]:
def targets_to_state(example_targets,batch_size):
    """Expand class labels into a per-pixel one-hot target.

    Args:
        example_targets: integer class labels, one per sample.
        batch_size: number of samples in ``example_targets``.

    Returns:
        Float tensor of shape ``(batch_size, img_height, img_width, 10)`` on
        ``device`` in which every pixel of a sample carries that sample's
        10-class one-hot vector.
    """
    labels = F.one_hot(example_targets, num_classes=10).float()
    # (batch, 10) -> (batch, H*W, 10): the same one-hot vector for every pixel
    per_pixel = labels.unsqueeze(1).repeat(1, img_height * img_width, 1)
    return per_pixel.reshape((batch_size, img_height, img_width, 10)).to(device)

## CA Functions

In [6]:
def init_state(batch_size,img_height,img_width,num_vectors,len_vectors):
    """Create a fresh state filled with small uniform noise.

    Returns:
        Tensor of shape ``(batch_size, img_height, img_width, num_vectors,
        len_vectors)`` on ``device`` with values drawn uniformly from
        ``[0, 0.1)``.
    """
    shape = (batch_size, img_height, img_width, num_vectors, len_vectors)
    # sampled on the CPU first, then moved, exactly like the values it replaces
    noise = torch.rand(shape)
    return (noise * .1).to(device)

In [7]:
class model(nn.Module):
    """Three linear layers with tanh activations, dropout and additive skips.

    Args:
        num_inp: width of the input vectors.
        num_out: width of every hidden layer and of the output.
    """
    def __init__(self, num_inp,num_out):
        super().__init__()
        # Layers are created in the order Q1, K1, V1 so that seeded
        # initialisation and state_dict keys stay the same.
        self.Q1 = nn.Linear(num_inp, num_out)
        self.K1 = nn.Linear(num_out, num_out)
        self.V1 = nn.Linear(num_out, num_out)

        # dropout applied to the input of every linear layer
        self.m = nn.Dropout(p=0.1)

        # only act1 (tanh) is used by forward(); act and act3 are kept as attributes
        self.act = nn.LeakyReLU()
        self.act1 = nn.Tanh()
        self.act3 = nn.GELU()

    def forward(self,x):
        """Return the network output scaled by 0.01.

        Each stage is ``tanh(linear(dropout(previous)))`` plus the outputs of
        all earlier stages.
        """
        first = self.act1(self.Q1(self.m(x)))
        second = self.act1(self.K1(self.m(first))) + first
        third = self.act1(self.V1(self.m(second))) + first + second

        # small output scale keeps each state update tiny
        return third * .01

In [8]:
# Disabled alternative definition of `model`, kept only as a string literal:
# nothing in it is executed, the cell merely echoes the text as its output.
# The variant feeds the same input to three linear layers and returns
# softmax(Q*K) * V * .01 instead of the residual tanh stack defined above.
'''class model(nn.Module):
    def __init__(self, num_inp,num_out):
        super(model, self).__init__()
        self.Q1 = nn.Linear(num_inp, num_out)
        self.K1 = nn.Linear(num_inp, num_out)
        self.V1 = nn.Linear(num_inp, num_out)
        
        self.act = nn.LeakyReLU()
        self.act1 = nn.Sigmoid()
    def forward(self,x):
        
        Q = self.act(self.Q1(x))
        K = self.act(self.K1(x))
        V = self.act1(self.V1(x))
        
        relevance = F.softmax((Q*K),dim=1)
        out = relevance*V*.01
        return out'''

'class model(nn.Module):\n    def __init__(self, num_inp,num_out):\n        super(model, self).__init__()\n        self.Q1 = nn.Linear(num_inp, num_out)\n        self.K1 = nn.Linear(num_inp, num_out)\n        self.V1 = nn.Linear(num_inp, num_out)\n        \n        self.act = nn.LeakyReLU()\n        self.act1 = nn.Sigmoid()\n    def forward(self,x):\n        \n        Q = self.act(self.Q1(x))\n        K = self.act(self.K1(x))\n        V = self.act1(self.V1(x))\n        \n        relevance = F.softmax((Q*K),dim=1)\n        out = relevance*V*.01\n        return out'

In [9]:
def get_layer_attention(center_matrix,roll_matrix,after_tri):
    """Re-weight every vector of ``center_matrix`` by its similarity to ``roll_matrix``.

    Args:
        center_matrix: tensor of shape ``(..., N, L)``, i.e. ``N`` vectors of
            length ``L`` per position.
        roll_matrix: tensor of the same shape as ``center_matrix``.
        after_tri: ``(N, N)`` mixing matrix; column ``j`` holds the weights with
            which the ``N`` per-level similarities are combined into the
            attention number of level ``j``.

    Returns:
        Tensor of the same shape as ``center_matrix``: each vector multiplied
        by its scalar attention number. The attention numbers are detached, so
        gradients flow only through ``center_matrix``.
    """
    # Dot product of the matching vectors. This equals the diagonal of
    # center @ roll^T without materialising the full N x N product.
    similarity = (center_matrix * roll_matrix).sum(dim=-1)

    # mix the per-level similarities into one attention number per level
    attention = torch.matmul(similarity, after_tri)

    # Broadcast each attention number over its vector. The width is taken from
    # the tensor itself rather than from the global `len_vectors`.
    return center_matrix * attention.detach().unsqueeze(-1)

In [10]:
def compute_all(bottom_up_model_list,top_down_model_list,layer_att_model_list,state,len_vectors,num_vectors,batch_size,use_attention=False):
    """Compute one additive update (``delta``) for the whole state.

    For every pixel, the vectors of its 3x3 neighbourhood are gathered per
    level (``torch.roll`` is circular, so the neighbourhood wraps around the
    image border) and concatenated into a ``9*len_vectors`` input. Three
    families of networks read these inputs:

    * bottom-up network ``i`` reads level ``i`` and writes to level ``i+1``
    * same-level network ``i`` reads level ``i+1`` and writes to level ``i+1``
    * top-down network ``i`` reads level ``i+2`` and writes to level ``i+1``

    Level 0 (the input image) never receives an update.

    Args:
        bottom_up_model_list: ``num_vectors-1`` callables.
        top_down_model_list: ``num_vectors-2`` callables.
        layer_att_model_list: ``num_vectors-1`` callables.
        state: tensor of shape ``(batch_size, H, W, num_vectors, len_vectors)``.
        len_vectors: length of each vector.
        num_vectors: number of levels.
        batch_size: size of the first axis of ``state``.
        use_attention: when True, every shifted copy is first re-weighted with
            ``get_layer_attention`` against the unshifted state before
            concatenation. Defaults to False.
            NOTE(review): the previous implementation computed these attended
            tensors on every call and then discarded them, so the networks
            have only ever seen the raw neighbourhood. The default keeps that
            behaviour and skips the wasted work; confirm whether attention
            was meant to be applied before enabling it.

    Returns:
        Tensor with the same shape as ``state``; its level-0 slice is all zeros.
    """
    height, width = state.shape[1], state.shape[2]

    # nine shifted copies of the state, ordered row by row; index 4 is the unshifted state
    neighbours = [torch.roll(state, shifts=(dy, dx), dims=(1, 2))
                  for dy in (-1, 0, 1) for dx in (-1, 0, 1)]

    if use_attention:
        # mixing matrix built from epsilon**1 .. epsilon**num_vectors
        powers = epsilon**torch.arange(start = 1,end = num_vectors+1)
        shifted_rows = torch.stack([torch.roll(powers, shifts=(k), dims=(0))
                                    for k in range(powers.shape[0])])
        after_tri = torch.triu(shifted_rows, diagonal=0).T.to(state.device)
        centre = neighbours[4]
        neighbours = [get_layer_attention(shifted, centre, after_tri)
                      for shifted in neighbours]

    # (batch, H, W, num_vectors, 9*len_vectors): each level sees its 3x3 neighbourhood
    stacked = torch.cat(neighbours, dim=4)

    def level_input(level):
        # one row per pixel, holding the concatenated neighbourhood of `level`
        return torch.reshape(stacked[:, :, :, level, :], (-1, len_vectors*9))

    # zeros are allocated directly with the state's dtype and device
    delta = [state.new_zeros((batch_size*height*width, len_vectors))
             for _ in range(num_vectors)]
    for i in range(num_vectors):
        # call order (top-down, bottom-up, same-level) is kept so that dropout
        # draws its masks in the same sequence as before
        if i < num_vectors-2:
            top_down_temp = top_down_model_list[i](level_input(i+2))
            delta[i+1] = delta[i+1] + top_down_temp
        if i < num_vectors-1:
            bottom_up_temp = bottom_up_model_list[i](level_input(i))
            att_layer_temp = layer_att_model_list[i](level_input(i+1))
            delta[i+1] = delta[i+1] + bottom_up_temp + att_layer_temp

    # (pixels, num_vectors, len_vectors) -> same layout as the state
    delta = torch.stack(delta, dim=1)
    return torch.reshape(delta, (batch_size, height, width, num_vectors, len_vectors))

## Model Initializations

In [11]:
# Every network maps a concatenated 3x3 neighbourhood (9*len_vectors values)
# to one vector of len_vectors values.
# They are moved with .to(device) rather than .cuda() so that the CPU fallback
# chosen in the first cell also works on machines without a GPU.
bottom_up_model_list = [model(9*len_vectors,len_vectors).to(device) for _ in range(num_vectors-1)]
top_down_model_list= [model(9*len_vectors,len_vectors).to(device) for _ in range(num_vectors-2)]
layer_att_model_list = [model(9*len_vectors,len_vectors).to(device) for _ in range(num_vectors-1)]

## Optimizers

In [12]:
# Collect the parameters of every network into one flat list so that a single
# optimizer updates all of them.
param_list = []
for i in range(num_vectors):
    nets_at_level = []
    if i < num_vectors-2:
        nets_at_level.append(top_down_model_list[i])
    if i < num_vectors-1:
        nets_at_level.append(bottom_up_model_list[i])
        nets_at_level.append(layer_att_model_list[i])
    for net in nets_at_level:
        param_list.extend(net.parameters())

optimizer = torch.optim.Adam(param_list, lr=learning_rate)
# mean squared error between the last state level and the one-hot target
mse = nn.MSELoss()

## Train

In [13]:
# Dropout must be active while training. Set the mode explicitly so this cell
# also behaves correctly when it is re-run after an evaluation pass.
for net in bottom_up_model_list + top_down_model_list + layer_att_model_list:
    net.train()

# Step intervals of the two residual re-injections. They are clamped to 1 so
# that a small `steps` cannot turn the modulo below into a division by zero.
half_interval = max(1, steps // 2)
quarter_interval = max(1, steps // 4)

# No batch is fetched before the loop: the first batch drawn is the first batch trained on.
examples = enumerate(train_loader)
for epoch in range(epochs):
    optimizer.zero_grad()

    # restart the loader once it is exhausted
    try:
        batch_idx, (example_data, example_targets) = next(examples)
    except StopIteration:
        examples = enumerate(train_loader)
        batch_idx, (example_data, example_targets) = next(examples)

    # size of this particular batch (a final short batch is handled too)
    batch = example_data.shape[0]

    #initialize state
    state = init_state(batch,img_height,img_width,num_vectors,len_vectors)

    #put current batches into state
    state[:,:,:,0,:] = data_to_state(example_data,batch)
    state1 = torch.clone(state)
    state2 = torch.clone(state)
    for step in range(steps):

        delta = compute_all(bottom_up_model_list,top_down_model_list,layer_att_model_list,state,len_vectors,num_vectors,batch)

        # apply the update plus a little uniform noise
        state = state + delta + .0001*torch.rand((state.shape)).to(device)

        #add first state to state in the middle of the steps (allows for RESNET type gradient backprop)
        if(step%half_interval==0):
            state = state + state1*.1
            state1 = torch.clone(state)

        if(step%quarter_interval==0):
            state = state + state2*.1
            state2 = torch.clone(state)

    state = state + state1*.1
    #get loss: the last level of every pixel is compared against the one-hot label
    pred_out = state[:,:,:,-1]
    targ_out = targets_to_state(example_targets,batch)
    loss = mse(pred_out,targ_out)
    loss.backward()
    optimizer.step()
    # .item() logs a plain float instead of formatting the tensor object
    print("Epoch: {}/{}  Loss: {}".format(epoch,epochs,loss.item()))

Epoch: 0/4000  Loss: 0.10321268439292908
Epoch: 1/4000  Loss: 0.10136249661445618
Epoch: 2/4000  Loss: 0.09457285702228546
Epoch: 3/4000  Loss: 0.09660812467336655
Epoch: 4/4000  Loss: 0.09551694244146347
Epoch: 5/4000  Loss: 0.09583022445440292
Epoch: 6/4000  Loss: 0.0946187749505043
Epoch: 7/4000  Loss: 0.09446649253368378
Epoch: 8/4000  Loss: 0.09359311312437057
Epoch: 9/4000  Loss: 0.09175699949264526
Epoch: 10/4000  Loss: 0.09468197077512741
Epoch: 11/4000  Loss: 0.09278279542922974
Epoch: 12/4000  Loss: 0.09359242767095566
Epoch: 13/4000  Loss: 0.09432879090309143
Epoch: 14/4000  Loss: 0.0947674810886383
Epoch: 15/4000  Loss: 0.09292217344045639
Epoch: 16/4000  Loss: 0.09323766082525253
Epoch: 17/4000  Loss: 0.09382543712854385
Epoch: 18/4000  Loss: 0.09198030084371567
Epoch: 19/4000  Loss: 0.0932554379105568
Epoch: 20/4000  Loss: 0.09323876351118088
Epoch: 21/4000  Loss: 0.09369466453790665
Epoch: 22/4000  Loss: 0.09437286108732224
Epoch: 23/4000  Loss: 0.09256337583065033
Epoch

Epoch: 194/4000  Loss: 0.09329531341791153
Epoch: 195/4000  Loss: 0.09310983121395111
Epoch: 196/4000  Loss: 0.0919378399848938
Epoch: 197/4000  Loss: 0.09090103209018707
Epoch: 198/4000  Loss: 0.09204366058111191
Epoch: 199/4000  Loss: 0.09274604916572571
Epoch: 200/4000  Loss: 0.09165804088115692
Epoch: 201/4000  Loss: 0.09287750720977783
Epoch: 202/4000  Loss: 0.09141170233488083
Epoch: 203/4000  Loss: 0.09204485267400742
Epoch: 204/4000  Loss: 0.09228608757257462
Epoch: 205/4000  Loss: 0.09161413460969925
Epoch: 206/4000  Loss: 0.0915072038769722
Epoch: 207/4000  Loss: 0.09190233796834946
Epoch: 208/4000  Loss: 0.0918222963809967
Epoch: 209/4000  Loss: 0.09180464595556259
Epoch: 210/4000  Loss: 0.09125807881355286
Epoch: 211/4000  Loss: 0.0932232216000557
Epoch: 212/4000  Loss: 0.09154403954744339
Epoch: 213/4000  Loss: 0.09078962355852127
Epoch: 214/4000  Loss: 0.09183607995510101
Epoch: 215/4000  Loss: 0.09084717929363251
Epoch: 216/4000  Loss: 0.09262114763259888
Epoch: 217/4000

Epoch: 386/4000  Loss: 0.09088332206010818
Epoch: 387/4000  Loss: 0.09022027254104614
Epoch: 388/4000  Loss: 0.0918322429060936
Epoch: 389/4000  Loss: 0.09049263596534729
Epoch: 390/4000  Loss: 0.0902155414223671
Epoch: 391/4000  Loss: 0.09047416597604752
Epoch: 392/4000  Loss: 0.08935123682022095
Epoch: 393/4000  Loss: 0.08935222029685974
Epoch: 394/4000  Loss: 0.09106779098510742
Epoch: 395/4000  Loss: 0.09067178517580032
Epoch: 396/4000  Loss: 0.09154273569583893
Epoch: 397/4000  Loss: 0.09192746877670288
Epoch: 398/4000  Loss: 0.09026152640581131
Epoch: 399/4000  Loss: 0.0898633822798729
Epoch: 400/4000  Loss: 0.08936546742916107
Epoch: 401/4000  Loss: 0.08986923843622208
Epoch: 402/4000  Loss: 0.08981212973594666
Epoch: 403/4000  Loss: 0.09035404026508331
Epoch: 404/4000  Loss: 0.08991740643978119
Epoch: 405/4000  Loss: 0.08912413567304611
Epoch: 406/4000  Loss: 0.08910933881998062
Epoch: 407/4000  Loss: 0.09047206491231918
Epoch: 408/4000  Loss: 0.08891571313142776
Epoch: 409/400

Epoch: 577/4000  Loss: 0.08729908615350723
Epoch: 578/4000  Loss: 0.08591903746128082
Epoch: 579/4000  Loss: 0.08621564507484436
Epoch: 580/4000  Loss: 0.08882558345794678
Epoch: 581/4000  Loss: 0.08676157146692276
Epoch: 582/4000  Loss: 0.08838438987731934
Epoch: 583/4000  Loss: 0.08655644208192825
Epoch: 584/4000  Loss: 0.08537385612726212
Epoch: 585/4000  Loss: 0.08812476694583893
Epoch: 586/4000  Loss: 0.08381310850381851
Epoch: 587/4000  Loss: 0.0849883034825325
Epoch: 588/4000  Loss: 0.08695357292890549
Epoch: 589/4000  Loss: 0.08671139925718307
Epoch: 590/4000  Loss: 0.08066825568675995
Epoch: 591/4000  Loss: 0.08351542055606842
Epoch: 592/4000  Loss: 0.0781075730919838
Epoch: 593/4000  Loss: 0.08812496811151505
Epoch: 594/4000  Loss: 0.08369216322898865
Epoch: 595/4000  Loss: 0.08723274618387222
Epoch: 596/4000  Loss: 0.0885172188282013
Epoch: 597/4000  Loss: 0.08545143902301788
Epoch: 598/4000  Loss: 0.08501023799180984
Epoch: 599/4000  Loss: 0.08612746745347977
Epoch: 600/400

Epoch: 769/4000  Loss: 0.07947947084903717
Epoch: 770/4000  Loss: 0.08158915489912033
Epoch: 771/4000  Loss: 0.08436916023492813
Epoch: 772/4000  Loss: 0.08602634817361832
Epoch: 773/4000  Loss: 0.08259988576173782
Epoch: 774/4000  Loss: 0.08057041466236115
Epoch: 775/4000  Loss: 0.08187002688646317
Epoch: 776/4000  Loss: 0.08148850500583649
Epoch: 777/4000  Loss: 0.07307077199220657
Epoch: 778/4000  Loss: 0.07960382848978043
Epoch: 779/4000  Loss: 0.08045174926519394
Epoch: 780/4000  Loss: 0.0765957459807396
Epoch: 781/4000  Loss: 0.08131542056798935
Epoch: 782/4000  Loss: 0.08258761465549469
Epoch: 783/4000  Loss: 0.08476648479700089
Epoch: 784/4000  Loss: 0.0765707939863205
Epoch: 785/4000  Loss: 0.07909763604402542
Epoch: 786/4000  Loss: 0.08143530786037445
Epoch: 787/4000  Loss: 0.08322245627641678
Epoch: 788/4000  Loss: 0.08433188498020172
Epoch: 789/4000  Loss: 0.07898516952991486
Epoch: 790/4000  Loss: 0.08147455006837845
Epoch: 791/4000  Loss: 0.08719731122255325
Epoch: 792/40

Epoch: 961/4000  Loss: 0.07591858506202698
Epoch: 962/4000  Loss: 0.08046063780784607
Epoch: 963/4000  Loss: 0.07827375829219818
Epoch: 964/4000  Loss: 0.08202394098043442
Epoch: 965/4000  Loss: 0.07100838422775269
Epoch: 966/4000  Loss: 0.07819759100675583
Epoch: 967/4000  Loss: 0.07789712399244308
Epoch: 968/4000  Loss: 0.0815587043762207
Epoch: 969/4000  Loss: 0.07426569610834122
Epoch: 970/4000  Loss: 0.07975751906633377
Epoch: 971/4000  Loss: 0.07817047834396362
Epoch: 972/4000  Loss: 0.07843761146068573
Epoch: 973/4000  Loss: 0.07566546648740768
Epoch: 974/4000  Loss: 0.07878945767879486
Epoch: 975/4000  Loss: 0.08003821223974228
Epoch: 976/4000  Loss: 0.0795222818851471
Epoch: 977/4000  Loss: 0.08487842231988907
Epoch: 978/4000  Loss: 0.07276003807783127
Epoch: 979/4000  Loss: 0.07298888266086578
Epoch: 980/4000  Loss: 0.07393545657396317
Epoch: 981/4000  Loss: 0.07737789303064346
Epoch: 982/4000  Loss: 0.07727126032114029
Epoch: 983/4000  Loss: 0.06815146654844284
Epoch: 984/40

Epoch: 1149/4000  Loss: 0.06830207258462906
Epoch: 1150/4000  Loss: 0.07320572435855865
Epoch: 1151/4000  Loss: 0.0736960917711258
Epoch: 1152/4000  Loss: 0.07423801720142365
Epoch: 1153/4000  Loss: 0.0810517966747284
Epoch: 1154/4000  Loss: 0.07753532379865646
Epoch: 1155/4000  Loss: 0.071119025349617
Epoch: 1156/4000  Loss: 0.06732376664876938
Epoch: 1157/4000  Loss: 0.07231616973876953
Epoch: 1158/4000  Loss: 0.07813926041126251
Epoch: 1159/4000  Loss: 0.07423152029514313
Epoch: 1160/4000  Loss: 0.07091126590967178
Epoch: 1161/4000  Loss: 0.07855209708213806
Epoch: 1162/4000  Loss: 0.07397174090147018
Epoch: 1163/4000  Loss: 0.06978953629732132
Epoch: 1164/4000  Loss: 0.07583308964967728
Epoch: 1165/4000  Loss: 0.07918842881917953
Epoch: 1166/4000  Loss: 0.0642746090888977
Epoch: 1167/4000  Loss: 0.07656363397836685
Epoch: 1168/4000  Loss: 0.08177110552787781
Epoch: 1169/4000  Loss: 0.07761993259191513
Epoch: 1170/4000  Loss: 0.0751468688249588
Epoch: 1171/4000  Loss: 0.074629463255

Epoch: 1337/4000  Loss: 0.07901620864868164
Epoch: 1338/4000  Loss: 0.0712910145521164
Epoch: 1339/4000  Loss: 0.07628428190946579
Epoch: 1340/4000  Loss: 0.07457837462425232
Epoch: 1341/4000  Loss: 0.07676943391561508
Epoch: 1342/4000  Loss: 0.07197993248701096
Epoch: 1343/4000  Loss: 0.0772627592086792
Epoch: 1344/4000  Loss: 0.07630008459091187
Epoch: 1345/4000  Loss: 0.08005514740943909
Epoch: 1346/4000  Loss: 0.07300898432731628
Epoch: 1347/4000  Loss: 0.07179111987352371
Epoch: 1348/4000  Loss: 0.07397327572107315
Epoch: 1349/4000  Loss: 0.06996393203735352
Epoch: 1350/4000  Loss: 0.07710861414670944
Epoch: 1351/4000  Loss: 0.07734660059213638
Epoch: 1352/4000  Loss: 0.07413861155509949
Epoch: 1353/4000  Loss: 0.0757332593202591
Epoch: 1354/4000  Loss: 0.0740388035774231
Epoch: 1355/4000  Loss: 0.07011960446834564
Epoch: 1356/4000  Loss: 0.07333242893218994
Epoch: 1357/4000  Loss: 0.07041393220424652
Epoch: 1358/4000  Loss: 0.07522780448198318
Epoch: 1359/4000  Loss: 0.0730437487

Epoch: 1524/4000  Loss: 0.07020752131938934
Epoch: 1525/4000  Loss: 0.07604548335075378
Epoch: 1526/4000  Loss: 0.06451593339443207
Epoch: 1527/4000  Loss: 0.08005904406309128
Epoch: 1528/4000  Loss: 0.074042409658432
Epoch: 1529/4000  Loss: 0.06885313242673874
Epoch: 1530/4000  Loss: 0.07139123976230621
Epoch: 1531/4000  Loss: 0.07628243416547775
Epoch: 1532/4000  Loss: 0.07008298486471176
Epoch: 1533/4000  Loss: 0.06985438615083694
Epoch: 1534/4000  Loss: 0.07814318686723709
Epoch: 1535/4000  Loss: 0.06790624558925629
Epoch: 1536/4000  Loss: 0.07255350053310394
Epoch: 1537/4000  Loss: 0.07296613603830338
Epoch: 1538/4000  Loss: 0.07038705050945282
Epoch: 1539/4000  Loss: 0.0682181790471077
Epoch: 1540/4000  Loss: 0.07567550987005234
Epoch: 1541/4000  Loss: 0.07757914811372757
Epoch: 1542/4000  Loss: 0.06975187361240387
Epoch: 1543/4000  Loss: 0.07773090153932571
Epoch: 1544/4000  Loss: 0.06684625148773193
Epoch: 1545/4000  Loss: 0.07240377366542816
Epoch: 1546/4000  Loss: 0.079595655

Epoch: 1711/4000  Loss: 0.07623446732759476
Epoch: 1712/4000  Loss: 0.07662879675626755
Epoch: 1713/4000  Loss: 0.06468161940574646
Epoch: 1714/4000  Loss: 0.07142443209886551
Epoch: 1715/4000  Loss: 0.07690116763114929
Epoch: 1716/4000  Loss: 0.07363688945770264
Epoch: 1717/4000  Loss: 0.07043996453285217
Epoch: 1718/4000  Loss: 0.07578761130571365
Epoch: 1719/4000  Loss: 0.0746845155954361
Epoch: 1720/4000  Loss: 0.07346243411302567
Epoch: 1721/4000  Loss: 0.06842707097530365
Epoch: 1722/4000  Loss: 0.06777752935886383
Epoch: 1723/4000  Loss: 0.0772550106048584
Epoch: 1724/4000  Loss: 0.07782977819442749
Epoch: 1725/4000  Loss: 0.07485732436180115
Epoch: 1726/4000  Loss: 0.07090433686971664
Epoch: 1727/4000  Loss: 0.07702258229255676
Epoch: 1728/4000  Loss: 0.07795606553554535
Epoch: 1729/4000  Loss: 0.0657481923699379
Epoch: 1730/4000  Loss: 0.07650886476039886
Epoch: 1731/4000  Loss: 0.06936831027269363
Epoch: 1732/4000  Loss: 0.06768736243247986
Epoch: 1733/4000  Loss: 0.073421292

Epoch: 1898/4000  Loss: 0.07467687129974365
Epoch: 1899/4000  Loss: 0.0755477249622345
Epoch: 1900/4000  Loss: 0.07094127684831619
Epoch: 1901/4000  Loss: 0.07296208292245865
Epoch: 1902/4000  Loss: 0.07780623435974121
Epoch: 1903/4000  Loss: 0.06705640256404877
Epoch: 1904/4000  Loss: 0.06933819502592087
Epoch: 1905/4000  Loss: 0.0701659768819809
Epoch: 1906/4000  Loss: 0.07651611417531967
Epoch: 1907/4000  Loss: 0.076727494597435
Epoch: 1908/4000  Loss: 0.0722038596868515
Epoch: 1909/4000  Loss: 0.06436270475387573
Epoch: 1910/4000  Loss: 0.06669064611196518
Epoch: 1911/4000  Loss: 0.07343043386936188
Epoch: 1912/4000  Loss: 0.07058174163103104
Epoch: 1913/4000  Loss: 0.06396966427564621
Epoch: 1914/4000  Loss: 0.06893692910671234
Epoch: 1915/4000  Loss: 0.06139763072133064
Epoch: 1916/4000  Loss: 0.07422707974910736
Epoch: 1917/4000  Loss: 0.06680520623922348
Epoch: 1918/4000  Loss: 0.06593760848045349
Epoch: 1919/4000  Loss: 0.070194773375988
Epoch: 1920/4000  Loss: 0.0636430755257

Epoch: 2085/4000  Loss: 0.0727337896823883
Epoch: 2086/4000  Loss: 0.07242204248905182
Epoch: 2087/4000  Loss: 0.07311418652534485
Epoch: 2088/4000  Loss: 0.07026442885398865
Epoch: 2089/4000  Loss: 0.06846913695335388
Epoch: 2090/4000  Loss: 0.07148254662752151
Epoch: 2091/4000  Loss: 0.06386708468198776
Epoch: 2092/4000  Loss: 0.07323165237903595
Epoch: 2093/4000  Loss: 0.06715308129787445
Epoch: 2094/4000  Loss: 0.06587245315313339
Epoch: 2095/4000  Loss: 0.07224493473768234
Epoch: 2096/4000  Loss: 0.07196290045976639
Epoch: 2097/4000  Loss: 0.06458552181720734
Epoch: 2098/4000  Loss: 0.06442558765411377
Epoch: 2099/4000  Loss: 0.0742039754986763
Epoch: 2100/4000  Loss: 0.0717552974820137
Epoch: 2101/4000  Loss: 0.06641125679016113
Epoch: 2102/4000  Loss: 0.07635568082332611
Epoch: 2103/4000  Loss: 0.06940298527479172
Epoch: 2104/4000  Loss: 0.07137645781040192
Epoch: 2105/4000  Loss: 0.06233539804816246
Epoch: 2106/4000  Loss: 0.07141536474227905
Epoch: 2107/4000  Loss: 0.070391252

Epoch: 2272/4000  Loss: 0.06488823890686035
Epoch: 2273/4000  Loss: 0.08069884032011032
Epoch: 2274/4000  Loss: 0.07563440501689911
Epoch: 2275/4000  Loss: 0.06672883778810501
Epoch: 2276/4000  Loss: 0.07716654986143112
Epoch: 2277/4000  Loss: 0.0708020031452179
Epoch: 2278/4000  Loss: 0.06655461341142654
Epoch: 2279/4000  Loss: 0.06758612394332886
Epoch: 2280/4000  Loss: 0.06988424062728882
Epoch: 2281/4000  Loss: 0.06913954019546509
Epoch: 2282/4000  Loss: 0.07115544378757477
Epoch: 2283/4000  Loss: 0.07237748056650162
Epoch: 2284/4000  Loss: 0.06518969684839249
Epoch: 2285/4000  Loss: 0.0657176747918129
Epoch: 2286/4000  Loss: 0.07334768027067184
Epoch: 2287/4000  Loss: 0.06955043226480484
Epoch: 2288/4000  Loss: 0.05968274176120758
Epoch: 2289/4000  Loss: 0.07062317430973053
Epoch: 2290/4000  Loss: 0.06677638739347458
Epoch: 2291/4000  Loss: 0.06798321008682251
Epoch: 2292/4000  Loss: 0.07295054942369461
Epoch: 2293/4000  Loss: 0.06629840284585953
Epoch: 2294/4000  Loss: 0.06775254

Epoch: 2459/4000  Loss: 0.06668081134557724
Epoch: 2460/4000  Loss: 0.06810262054204941
Epoch: 2461/4000  Loss: 0.07303255051374435
Epoch: 2462/4000  Loss: 0.06809356063604355
Epoch: 2463/4000  Loss: 0.07173208147287369
Epoch: 2464/4000  Loss: 0.0719727948307991
Epoch: 2465/4000  Loss: 0.07017198950052261
Epoch: 2466/4000  Loss: 0.0641447901725769
Epoch: 2467/4000  Loss: 0.06741342693567276
Epoch: 2468/4000  Loss: 0.0654260665178299
Epoch: 2469/4000  Loss: 0.06782197207212448
Epoch: 2470/4000  Loss: 0.07469362765550613
Epoch: 2471/4000  Loss: 0.07052794098854065
Epoch: 2472/4000  Loss: 0.0671553909778595
Epoch: 2473/4000  Loss: 0.06631594151258469
Epoch: 2474/4000  Loss: 0.07022303342819214
Epoch: 2475/4000  Loss: 0.0680738165974617
Epoch: 2476/4000  Loss: 0.06734590232372284
Epoch: 2477/4000  Loss: 0.07454539090394974
Epoch: 2478/4000  Loss: 0.06590185314416885
Epoch: 2479/4000  Loss: 0.07012983411550522
Epoch: 2480/4000  Loss: 0.06685515493154526
Epoch: 2481/4000  Loss: 0.06376840919

Epoch: 2646/4000  Loss: 0.0650564506649971
Epoch: 2647/4000  Loss: 0.06974335014820099
Epoch: 2648/4000  Loss: 0.06616483628749847
Epoch: 2649/4000  Loss: 0.06635895371437073
Epoch: 2650/4000  Loss: 0.06145063787698746
Epoch: 2651/4000  Loss: 0.06950217485427856
Epoch: 2652/4000  Loss: 0.061681050807237625
Epoch: 2653/4000  Loss: 0.06335914880037308
Epoch: 2654/4000  Loss: 0.06537897884845734
Epoch: 2655/4000  Loss: 0.06565651297569275
Epoch: 2656/4000  Loss: 0.06484347581863403
Epoch: 2657/4000  Loss: 0.06164619326591492
Epoch: 2658/4000  Loss: 0.06867964565753937
Epoch: 2659/4000  Loss: 0.06564883142709732
Epoch: 2660/4000  Loss: 0.0697576031088829
Epoch: 2661/4000  Loss: 0.05849063768982887
Epoch: 2662/4000  Loss: 0.058773916214704514
Epoch: 2663/4000  Loss: 0.07069063931703568
Epoch: 2664/4000  Loss: 0.06812994182109833
Epoch: 2665/4000  Loss: 0.07094935327768326
Epoch: 2666/4000  Loss: 0.06430641561746597
Epoch: 2667/4000  Loss: 0.07187747955322266
Epoch: 2668/4000  Loss: 0.070051

Epoch: 2833/4000  Loss: 0.06164524704217911
Epoch: 2834/4000  Loss: 0.07001731544733047
Epoch: 2835/4000  Loss: 0.0696522518992424
Epoch: 2836/4000  Loss: 0.07034541666507721
Epoch: 2837/4000  Loss: 0.0679258331656456
Epoch: 2838/4000  Loss: 0.06666864454746246
Epoch: 2839/4000  Loss: 0.06696581095457077
Epoch: 2840/4000  Loss: 0.06577656418085098
Epoch: 2841/4000  Loss: 0.06249977648258209
Epoch: 2842/4000  Loss: 0.07021261751651764
Epoch: 2843/4000  Loss: 0.06913068145513535
Epoch: 2844/4000  Loss: 0.06940396130084991
Epoch: 2845/4000  Loss: 0.06741585582494736
Epoch: 2846/4000  Loss: 0.061328623443841934
Epoch: 2847/4000  Loss: 0.066268190741539
Epoch: 2848/4000  Loss: 0.06118873506784439
Epoch: 2849/4000  Loss: 0.06909457594156265
Epoch: 2850/4000  Loss: 0.06587783247232437
Epoch: 2851/4000  Loss: 0.06775001436471939
Epoch: 2852/4000  Loss: 0.07005221396684647
Epoch: 2853/4000  Loss: 0.06503928452730179
Epoch: 2854/4000  Loss: 0.060693930834531784
Epoch: 2855/4000  Loss: 0.07391534

Epoch: 3020/4000  Loss: 0.05766933411359787
Epoch: 3021/4000  Loss: 0.07201765477657318
Epoch: 3022/4000  Loss: 0.06390584260225296
Epoch: 3023/4000  Loss: 0.0655197724699974
Epoch: 3024/4000  Loss: 0.06475019454956055
Epoch: 3025/4000  Loss: 0.07402034848928452
Epoch: 3026/4000  Loss: 0.06784839928150177
Epoch: 3027/4000  Loss: 0.06316623091697693
Epoch: 3028/4000  Loss: 0.06928691267967224
Epoch: 3029/4000  Loss: 0.06987304985523224
Epoch: 3030/4000  Loss: 0.07382519543170929
Epoch: 3031/4000  Loss: 0.07249805331230164
Epoch: 3032/4000  Loss: 0.06740865856409073
Epoch: 3033/4000  Loss: 0.06657862663269043
Epoch: 3034/4000  Loss: 0.061383362859487534
Epoch: 3035/4000  Loss: 0.07554277777671814
Epoch: 3036/4000  Loss: 0.06254354864358902
Epoch: 3037/4000  Loss: 0.07312685251235962
Epoch: 3038/4000  Loss: 0.059766024351119995
Epoch: 3039/4000  Loss: 0.06521061062812805
Epoch: 3040/4000  Loss: 0.059512920677661896
Epoch: 3041/4000  Loss: 0.06607599556446075
Epoch: 3042/4000  Loss: 0.0683

Epoch: 3207/4000  Loss: 0.05927978456020355
Epoch: 3208/4000  Loss: 0.0654987171292305
Epoch: 3209/4000  Loss: 0.06410887837409973
Epoch: 3210/4000  Loss: 0.06148678436875343
Epoch: 3211/4000  Loss: 0.06989294290542603
Epoch: 3212/4000  Loss: 0.06584673374891281
Epoch: 3213/4000  Loss: 0.06404870748519897
Epoch: 3214/4000  Loss: 0.060174934566020966
Epoch: 3215/4000  Loss: 0.06443342566490173
Epoch: 3216/4000  Loss: 0.06681175529956818
Epoch: 3217/4000  Loss: 0.06195351108908653
Epoch: 3218/4000  Loss: 0.07531113922595978
Epoch: 3219/4000  Loss: 0.061600614339113235
Epoch: 3220/4000  Loss: 0.05898717790842056
Epoch: 3221/4000  Loss: 0.059912409633398056
Epoch: 3222/4000  Loss: 0.058980975300073624
Epoch: 3223/4000  Loss: 0.06303148716688156
Epoch: 3224/4000  Loss: 0.06163798272609711
Epoch: 3225/4000  Loss: 0.0626315325498581
Epoch: 3226/4000  Loss: 0.07030140608549118
Epoch: 3227/4000  Loss: 0.06475839018821716
Epoch: 3228/4000  Loss: 0.0584048330783844
Epoch: 3229/4000  Loss: 0.06238

Epoch: 3394/4000  Loss: 0.06571153551340103
Epoch: 3395/4000  Loss: 0.07116924971342087
Epoch: 3396/4000  Loss: 0.06152695044875145
Epoch: 3397/4000  Loss: 0.06889533996582031
Epoch: 3398/4000  Loss: 0.0665552169084549
Epoch: 3399/4000  Loss: 0.05888380482792854
Epoch: 3400/4000  Loss: 0.06258944422006607
Epoch: 3401/4000  Loss: 0.06571535766124725
Epoch: 3402/4000  Loss: 0.06394007802009583
Epoch: 3403/4000  Loss: 0.07340330630540848
Epoch: 3404/4000  Loss: 0.07047102600336075
Epoch: 3405/4000  Loss: 0.0708736851811409
Epoch: 3406/4000  Loss: 0.0630941241979599
Epoch: 3407/4000  Loss: 0.06288161128759384
Epoch: 3408/4000  Loss: 0.05998174473643303
Epoch: 3409/4000  Loss: 0.06126205623149872
Epoch: 3410/4000  Loss: 0.07447715103626251
Epoch: 3411/4000  Loss: 0.06330674141645432
Epoch: 3412/4000  Loss: 0.06578841805458069
Epoch: 3413/4000  Loss: 0.06899488717317581
Epoch: 3414/4000  Loss: 0.06552809476852417
Epoch: 3415/4000  Loss: 0.0657578706741333
Epoch: 3416/4000  Loss: 0.0659280493

Epoch: 3581/4000  Loss: 0.06713538616895676
Epoch: 3582/4000  Loss: 0.06799549609422684
Epoch: 3583/4000  Loss: 0.0657721534371376
Epoch: 3584/4000  Loss: 0.06163011118769646
Epoch: 3585/4000  Loss: 0.06507756561040878
Epoch: 3586/4000  Loss: 0.06868917495012283
Epoch: 3587/4000  Loss: 0.06724993884563446
Epoch: 3588/4000  Loss: 0.06481121480464935
Epoch: 3589/4000  Loss: 0.06316833198070526
Epoch: 3590/4000  Loss: 0.06689245253801346
Epoch: 3591/4000  Loss: 0.07220029085874557
Epoch: 3592/4000  Loss: 0.0712134912610054
Epoch: 3593/4000  Loss: 0.06113286316394806
Epoch: 3594/4000  Loss: 0.05968332663178444
Epoch: 3595/4000  Loss: 0.06988837569952011
Epoch: 3596/4000  Loss: 0.0746484026312828
Epoch: 3597/4000  Loss: 0.06568508595228195
Epoch: 3598/4000  Loss: 0.06588908284902573
Epoch: 3599/4000  Loss: 0.06492936611175537
Epoch: 3600/4000  Loss: 0.05942518636584282
Epoch: 3601/4000  Loss: 0.06038407236337662
Epoch: 3602/4000  Loss: 0.06814278662204742
Epoch: 3603/4000  Loss: 0.054816808

Epoch: 3768/4000  Loss: 0.06077013909816742
Epoch: 3769/4000  Loss: 0.06632696837186813
Epoch: 3770/4000  Loss: 0.06578480452299118
Epoch: 3771/4000  Loss: 0.06558914482593536
Epoch: 3772/4000  Loss: 0.06818516552448273
Epoch: 3773/4000  Loss: 0.07096659392118454
Epoch: 3774/4000  Loss: 0.06619960069656372
Epoch: 3775/4000  Loss: 0.06417538970708847
Epoch: 3776/4000  Loss: 0.06780359148979187
Epoch: 3777/4000  Loss: 0.06327719986438751
Epoch: 3778/4000  Loss: 0.06574267148971558
Epoch: 3779/4000  Loss: 0.0640116035938263
Epoch: 3780/4000  Loss: 0.0656091570854187
Epoch: 3781/4000  Loss: 0.0684729591012001
Epoch: 3782/4000  Loss: 0.054033976048231125
Epoch: 3783/4000  Loss: 0.06088325381278992
Epoch: 3784/4000  Loss: 0.06335524469614029
Epoch: 3785/4000  Loss: 0.07046666741371155
Epoch: 3786/4000  Loss: 0.06852425634860992
Epoch: 3787/4000  Loss: 0.06154949590563774
Epoch: 3788/4000  Loss: 0.06364157795906067
Epoch: 3789/4000  Loss: 0.06343720108270645
Epoch: 3790/4000  Loss: 0.06778678

Epoch: 3955/4000  Loss: 0.06845325976610184
Epoch: 3956/4000  Loss: 0.06350281834602356
Epoch: 3957/4000  Loss: 0.056838616728782654
Epoch: 3958/4000  Loss: 0.05672081932425499
Epoch: 3959/4000  Loss: 0.06076483428478241
Epoch: 3960/4000  Loss: 0.06445345282554626
Epoch: 3961/4000  Loss: 0.060202497988939285
Epoch: 3962/4000  Loss: 0.0662613958120346
Epoch: 3963/4000  Loss: 0.060260284692049026
Epoch: 3964/4000  Loss: 0.06723660230636597
Epoch: 3965/4000  Loss: 0.059104111045598984
Epoch: 3966/4000  Loss: 0.06747552752494812
Epoch: 3967/4000  Loss: 0.06714053452014923
Epoch: 3968/4000  Loss: 0.06934776157140732
Epoch: 3969/4000  Loss: 0.06489871442317963
Epoch: 3970/4000  Loss: 0.06649025529623032
Epoch: 3971/4000  Loss: 0.06655577570199966
Epoch: 3972/4000  Loss: 0.05788418650627136
Epoch: 3973/4000  Loss: 0.06642498821020126
Epoch: 3974/4000  Loss: 0.06664076447486877
Epoch: 3975/4000  Loss: 0.05738156661391258
Epoch: 3976/4000  Loss: 0.06453266739845276
Epoch: 3977/4000  Loss: 0.066

## Test

In [None]:
# Evaluate with dropout disabled and without recording an autograd graph.
# The previous modes are restored afterwards so that a later training run is
# not silently left in eval mode.
all_models = bottom_up_model_list + top_down_model_list + layer_att_model_list
was_training = [net.training for net in all_models]
for net in all_models:
    net.eval()

tot_corr = 0
tot_batches = 0  # counts samples seen so far (name kept from the original cell)
try:
    with torch.no_grad():
        for example_data, target in test_loader:
            batch = example_data.shape[0]
            tot_batches += batch

            #initialize state
            state = init_state(batch,img_height,img_width,num_vectors,len_vectors)

            #put current batches into state
            state[:,:,:,0,:] = data_to_state(example_data,batch)

            # NOTE(review): the training loop also adds noise and re-injects
            # earlier states (state1/state2) during this rollout; evaluation
            # applies only `delta`. Confirm that the mismatch is intended.
            for step in range(steps):
                delta = compute_all(bottom_up_model_list,top_down_model_list,layer_att_model_list,state,len_vectors,num_vectors,batch)

                #update state
                state = state + delta

            # Every pixel votes for the class with the largest entry in its
            # last-level vector; the most voted class is the prediction
            # (ties resolve to the lowest class index).
            pixel_votes = torch.argmax(state[:,:,:,-1], dim=-1).reshape(batch, -1)
            vote_counts = F.one_hot(pixel_votes, num_classes=state.shape[-1]).sum(dim=1)
            predictions = torch.argmax(vote_counts, dim=1).cpu()
            tot_corr += int((predictions == target).sum())

            # running accuracy over all samples seen so far
            print("Acc: {}".format(tot_corr/tot_batches))
finally:
    for net, mode in zip(all_models, was_training):
        net.train(mode)
print("Final Accuracy: {}".format(tot_corr/tot_batches))

Acc: 0.425
Acc: 0.425
Acc: 0.43333333333333335
Acc: 0.4375
Acc: 0.485
Acc: 0.5041666666666667
Acc: 0.48928571428571427
Acc: 0.4625
Acc: 0.4722222222222222
Acc: 0.475
Acc: 0.4818181818181818
Acc: 0.4791666666666667
Acc: 0.4846153846153846
Acc: 0.48035714285714287
Acc: 0.4766666666666667
Acc: 0.4734375
Acc: 0.47205882352941175
Acc: 0.46944444444444444
Acc: 0.47368421052631576
Acc: 0.475
Acc: 0.4773809523809524
Acc: 0.48068181818181815
Acc: 0.48586956521739133
Acc: 0.48541666666666666
Acc: 0.488
Acc: 0.4855769230769231
Acc: 0.4824074074074074
Acc: 0.475
Acc: 0.47844827586206895
Acc: 0.47833333333333333
Acc: 0.47580645161290325
Acc: 0.47890625
Acc: 0.4772727272727273
Acc: 0.4772058823529412
Acc: 0.47714285714285715
Acc: 0.47638888888888886
Acc: 0.4756756756756757
Acc: 0.47960526315789476
Acc: 0.47628205128205126
Acc: 0.47625
Acc: 0.47621951219512193
Acc: 0.4791666666666667
Acc: 0.4808139534883721
Acc: 0.48011363636363635
Acc: 0.4811111111111111
Acc: 0.48043478260869565
Acc: 0.4803191489361

In [None]:
def save_models(bottom_up_model_list,top_down_model_list,layer_att_model_list):
    """Write every network to its own file under ``best_models/``.

    Files are named ``top_down_model<i>``, ``bottom_up_model<i>`` and
    ``layer_att_model<i>``, where ``<i>`` is the network's index in its list.
    The directory is created when it does not exist yet.

    Each network is passed to ``torch.save`` as a whole object (not as a
    ``state_dict``), so loading the files requires the network's class to be
    importable.

    Args:
        bottom_up_model_list: networks saved as ``bottom_up_model<i>``.
        top_down_model_list: networks saved as ``top_down_model<i>``.
        layer_att_model_list: networks saved as ``layer_att_model<i>``.
    """
    import os  # local import keeps this cell self-contained

    PATH = "best_models/"
    os.makedirs(PATH, exist_ok=True)

    # Iterate over the lists themselves, not the global `num_vectors`, so that
    # exactly the networks handed in are written.
    groups = (
        ("top_down_model{}", top_down_model_list),
        ("bottom_up_model{}", bottom_up_model_list),
        ("layer_att_model{}", layer_att_model_list),
    )
    for name_format, nets in groups:
        for i, net in enumerate(nets):
            torch.save(net, PATH+name_format.format(i))

In [None]:
# Persist all three families of networks via save_models (files are written under "best_models/").
save_models(bottom_up_model_list,top_down_model_list,layer_att_model_list)