In [None]:
# Author: Arjun Viswanathan
# Date Created: 5/26/22
# Creating a Siamese Neural Network (SNN) architecture using PyTorch. The jupyter notebook version.
# Network takes in 2 sets of inputs, and trains on them to give 2 sets of outputs.
# These outputs are then used to compute a distance, and this is passed into a Dense layer to give the output of the SNN

'''
Log:
5/24: Created file and started the SNN construction. Added in data loading and training to test out with MNIST
database but there were errors in setting up the training.
5/26: After consulting some sample code from a friend, fixed the training code and data loaders. Tested it and it
works, just not very well. Current train accuracy is 10.5% and test accuracy is 11.37%. Will need to adjust parameters.
5/31: Changed up the SNN model and made some big changes to how the MNIST data is made. Currently unable to set up the 
DataLoader, which I will do later. Once that is working, training can be done. 
6/1: Copied the code from the Jupyter notebook to Google Colab, ran it, and got the dataset to build. The model now has a
problem in the convolutional filters that needs to be fixed.
6/2: Updated the SNN model and removed the MNIST dataloading. Tried making optimizer from paper, but hit a roadblock.
Currently just have the model template and with training and testing. Once data is available, can import and test.
TODO: Update the optimizer, import in data and test everything out to see how well it works. 
'''

In [None]:
import torch
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import random
import math
from sklearn.model_selection import train_test_split
from torch.nn import Module, Conv2d, MaxPool2d, Linear
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
# Siamese network: one shared encoder for both inputs, then a linear head on their distance
class SiameseNeuralNetwork(Module):
    """Fully connected Siamese network with a single shared encoder.

    Both inputs are passed through the same stack of Linear layers (fc1-fc4),
    giving one 256-dimensional, sigmoid-activated feature vector each. The
    element-wise absolute difference of the two vectors is then mapped by fc5
    to a single value per pair. No activation is applied to that final value.

    Args:
        start_features: size of the last dimension of each input tensor.
    """

    def __init__(self, start_features):
        super().__init__()

        # shared encoder: start_features -> 16 -> 32 -> 64
        self.fc1 = Linear(start_features, 16)
        self.fc2 = Linear(16, 32)
        self.fc3 = Linear(32, 64)

        # projects the encoding to the 256-dimensional feature vector
        self.fc4 = Linear(64, 256)

        # prediction head applied to the distance between two feature vectors
        self.fc5 = Linear(256, 1)

    def forward_on_input(self, x):
        """Encode a single input into its sigmoid-activated feature vector."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        hidden = F.relu(self.fc3(hidden))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x1, x2):
        """Return fc5 applied to the L1 distance between the encodings of x1 and x2."""
        embedding_a = self.forward_on_input(x1)
        embedding_b = self.forward_on_input(x2)
        l1_distance = (embedding_a - embedding_b).abs()  # L1 siamese distance metric
        return self.fc5(l1_distance)

In [None]:
# Hyperparameters and target device shared by the cells below
starting_features = 4   # input width passed to SiameseNeuralNetwork
batchsize = 16          # batch size passed to the DataLoader
num_epochs = 10         # number of passes made by the training loop
# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Show which device was selected (bare expression, so the notebook displays its repr)
device

In [None]:
class CreateSiameseDataset(Dataset):
    """Map-style dataset that yields random ``(x1, x2, label)`` pairs.

    Every ``__getitem__`` call draws a fresh random pair and ignores the index
    it is given. Label 0 means both samples share a class label, label 1 means
    they do not.

    Args:
        data: indexable collection of samples (``data[i]`` is one sample).
            Defaults to None until a real data source is wired in.
        labels: indexable collection of class labels aligned with ``data``.
            Defaults to None.

    Raises:
        ValueError: if ``data`` and ``labels`` are both given but differ in length.
    """

    # Random draws attempted before falling back to an exhaustive scan.
    _MAX_RANDOM_TRIES = 64

    def __init__(self, data=None, labels=None):
        super().__init__()
        if data is not None and labels is not None and len(data) != len(labels):
            raise ValueError(
                "data and labels must have the same length, got %d and %d"
                % (len(data), len(labels))
            )
        # TODO: Get the data and set the variables accordingly
        self.dataset = None
        self.data = data
        self.labels = labels

    def _is_partner(self, anchor, candidate, want_different):
        """Return True if ``candidate`` is a distinct sample of the requested kind."""
        if candidate == anchor:
            return False
        differs = bool(self.labels[anchor] != self.labels[candidate])
        return differs == want_different

    def _pick_partner(self, anchor, want_different):
        """Pick a partner index for ``anchor``, or None when no such sample exists."""
        count = len(self.labels)
        # Cheap path: rejection sampling succeeds quickly on balanced data.
        for _ in range(self._MAX_RANDOM_TRIES):
            candidate = random.randrange(count)
            if self._is_partner(anchor, candidate, want_different):
                return candidate
        # Slow path: scan everything so an impossible request ends instead of looping forever.
        matches = [j for j in range(count) if self._is_partner(anchor, j, want_different)]
        return random.choice(matches) if matches else None

    def __getitem__(self, index):
        """Return a random pair ``(x1, x2, label)``; ``index`` is not used.

        Raises:
            ValueError: if fewer than two labelled samples are available.
        """
        count = 0 if self.labels is None else len(self.labels)
        if count < 2:
            raise ValueError(
                "CreateSiameseDataset needs at least two labelled samples to form a pair"
            )

        want_different = bool(random.randint(0, 1))  # False: same class, True: different class
        # randrange excludes the upper bound, so the index is always valid
        anchor = random.randrange(count)

        partner = self._pick_partner(anchor, want_different)
        if partner is None:
            # This anchor has no partner of the requested kind (e.g. its label is
            # unique, or every label is identical). Use the other kind instead;
            # with two or more samples one of the two kinds always exists.
            want_different = not want_different
            partner = self._pick_partner(anchor, want_different)

        label = torch.tensor(int(want_different))
        return self.data[anchor], self.data[partner], label

    def __len__(self):
        """Number of samples in ``data`` (0 while no data has been set)."""
        return 0 if self.data is None else len(self.data)

In [None]:
# Build the pair dataset and wrap it in a shuffling, batching DataLoader
siamesedata = CreateSiameseDataset()
siamesedataloader = DataLoader(dataset=siamesedata, batch_size=batchsize, shuffle=True)

In [None]:
# Instantiate the network, move it to the selected device, and print its layers
model = SiameseNeuralNetwork(starting_features)
model = model.to(device)
print(model)

In [None]:
# Adam is a placeholder for now; it still has to be replaced by the optimizer from the SNN paper
learning_rate = 0.005
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Compute the accuracy of the model on one batch
def accuracy(output, target, batch_size, threshold=0.0):
    """Return the percentage of correct predictions in one batch.

    Args:
        output: model output. With more than one column (shape (N, C), C > 1)
            the predicted class is the argmax over dim 1. With a single column
            or a 1-D tensor, each value is treated as a binary score and
            compared against ``threshold``.
        target: ground-truth labels with N elements.
        batch_size: divisor for the percentage; pass the number of samples
            actually present in this batch.
        threshold: decision boundary for single-score outputs. The default 0.0
            suits raw logits; use 0.5 when ``output`` holds probabilities.

    Returns:
        Accuracy in percent as a Python float.
    """
    scores = output.detach()
    if scores.dim() > 1 and scores.size(1) > 1:
        predicted = scores.argmax(dim=1)
    else:
        # argmax over a single column is always 0, so threshold the score instead
        predicted = (scores > threshold).long()
    corrects = (predicted.reshape(target.size()) == target.detach()).sum()
    percent = 100.0 * corrects / batch_size
    return percent.item()

In [None]:
# training
for epoch in range(num_epochs):
    train_running_loss = 0.0
    train_accuracy = 0.0
    num_batches = 0
    model.train()

    # training step: iterate through the batches of (x1, x2, label) pairs
    for x1, x2, pair_labels in siamesedataloader:
        # sending inputs and labels to device (GPU or CPU)
        x1 = x1.to(device)
        x2 = x2.to(device)
        pair_labels = pair_labels.to(device)

        # pass 2 sets of inputs into the snn; it returns one raw (pre-sigmoid) score per pair
        output = model(x1, x2)

        # Binary cross-entropy computed directly on the raw scores. This applies the
        # sigmoid internally, carries the leading minus sign of the negative
        # log-likelihood, works on whole batches, and stays inside autograd
        # (math.log on a tensor does none of these).
        loss = F.binary_cross_entropy_with_logits(
            output.view(-1), pair_labels.float().view(-1)
        )

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_running_loss += loss.item()
        # use the real size of this batch so a smaller final batch is not under-counted
        train_accuracy += accuracy(output, pair_labels, pair_labels.size(0))
        num_batches += 1

    model.eval()
    # average over the number of batches seen (guarded so an empty loader cannot divide by zero)
    num_batches = max(num_batches, 1)
    print('Epoch %d | Loss: %.4f | Train Accuracy: %.2f'%(epoch+1, train_running_loss / num_batches, train_accuracy / num_batches))

In [None]:
# Evaluation pass.
# NOTE(review): this iterates the same loader that was used for training; no separate test split exists yet.
test_accuracy = 0.0
num_test_batches = 0
model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for y1, y2, l in siamesedataloader:
        y1 = y1.to(device)
        y2 = y2.to(device)
        l = l.to(device)

        outputs = model(y1, y2)
        # use the real size of this batch so a smaller final batch is not under-counted
        test_accuracy += accuracy(outputs, l, l.size(0))
        num_test_batches += 1
# average over the number of batches seen (guarded so an empty loader cannot divide by zero)
print('Test Accuracy: %.2f'%(test_accuracy / max(num_test_batches, 1)))