In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import torchvision
import random
from torchvision import datasets, transforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_dataset = datasets.MNIST(root="data", download=True, 
                               train=True, transform=transforms.Compose([
    transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,)),
     transforms.Resize((14,14))
]))

In [4]:
test_dataset = datasets.MNIST(root="data", download=True,train=False, transform=transforms.Compose([
    transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Resize((14,14))
]))

In [36]:
class ffmodel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(196,40)
        self.layernorm1 = nn.LayerNorm(40, elementwise_affine=False)

        self.linear2 = nn.Linear(40,40)
        self.layernorm2 = nn.LayerNorm(40, elementwise_affine=False)
        
        self.linear3 = nn.Linear(40,40)
        self.layernorm3 = nn.LayerNorm(40, elementwise_affine=False)
        
        self.linear4 = nn.Linear(40,40)
        self.layernorm4 = nn.LayerNorm(40, elementwise_affine=False)
        

    def forward(self, X, 
                #valency #-1 or 0 or 1 depending on goodness or badness of data or the other way around idk
                ):
        X = X.flatten(start_dim=1)
        
        X = self.linear1(X)
        X = self.layernorm1(X)
        X = torch.relu(X)
        ssqa1 = torch.sum(X**2)
        X = X.detach() # Detaching X resets the tree that autograd uses to compute gradients. Thus parameter updates will not proceed further down than the local layer

        X = self.linear2(X)
        X = self.layernorm2(X)
        X = torch.relu(X)
        ssqa2 = torch.sum(X**2)
        X = X.detach()
        
        X = self.linear3(X)
        X = self.layernorm3(X)
        X = torch.relu(X)
        ssqa3 = torch.sum(X**2)
        X = X.detach()

        X = self.linear4(X)
        X = self.layernorm4(X)
        X = torch.relu(X)
        ssqa4 = torch.sum(X**2)
        
        
        return X, (ssqa1, ssqa2, ssqa3, ssqa4)

In [37]:
model = ffmodel()
model.load_state_dict(torch.load("good_weights.pt"))
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

arr = np.zeros((10,9))
for row_index in range(10):
    arr[row_index] = [i for i in range(10) if i != row_index]

for epoch in range(100):
    print(f'starting epoch {epoch}')
    eval()
    for index, (image, label) in enumerate(train_dataset):
        
        image[:,0,:10] = torch.zeros_like(image[:,0,:10])

        if(random.uniform(0,1) < 0.5): # Positive example
            valency = -1
            image[:,0,label] = 1

        else: # Negative example
            valency = 1
            image[:,0, random.choice(arr[label]).astype(np.int16)] = 1
        if(index == 10000):
            break
        
       

        optimizer.zero_grad()
        out, ssqs = model(image)
        for ssq in ssqs:
            (ssq * valency).backward()
            
        optimizer.step()
        

starting epoch 0
Error rate is 0.929070929070929
starting epoch 1
Error rate is 0.5944055944055944
starting epoch 2
Error rate is 0.3546453546453546
starting epoch 3
Error rate is 0.3356643356643357
starting epoch 4
Error rate is 0.3646353646353646
starting epoch 5
Error rate is 0.3016983016983017


KeyboardInterrupt: 

In [26]:
# Eval
def eval():
    total = 0
    incorrect = 0
    with torch.no_grad():
        model.eval()

        for index, (image, label) in enumerate(test_dataset):
            
            activities_list = []
            for c in range(10):
                image[:,0,:10] = torch.zeros_like(image[:,0,:10])
                image[:,0,c] = 1

                out, ssqs = model(image)
                activities_list.append((c, ssqs[1] + ssqs[2] + ssqs[3]))
            predicted_label = max(activities_list, key=lambda tup: tup[1])[0]
            
            total += 1
            if(predicted_label != label):
                incorrect += 1
                #print(predicted_label, label)

            if(index == 1000):
                break
    model.train()

    print(f"Error rate is {incorrect/total}") 