# Siamese Network Classifier

## Project description
I need to classify input images into a set of classes, but the number of classes may grow or shrink many times. A conventional classifier would have to be retrained after every such change, which is impractical. I will therefore try a different approach, a similarity (Siamese) network, which lets me add or remove classes on the fly: the target classes are stored as reference images in a database, the input image is compared with each class's image from the database, and images can be added to or removed from the database at any time.

In [None]:
"""
Created on Thu Apr 15 9:40:21 2021

@author: Mohab Mohamed
"""

import torchvision
import torch.utils.data as utils
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import time
import copy
from torch.optim import lr_scheduler
import os
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import random
import cv2

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# List the dataset folder to confirm it is visible from the runtime.
!ls 'drive/MyDrive/PlantVillage' #Image dataset path

In [None]:
# Print the working directory; the relative dataset paths used in this
# notebook are resolved against it.
!pwd

In [None]:
# Paths and hyper-parameters intended to be shared by the cells below.
# Dataset split folders (relative to the working directory).
train_path = './drive/MyDrive/PlantVillage/train'
valid_path = './drive/MyDrive/PlantVillage/validate'
test_path = './drive/MyDrive/PlantVillage/test'

# NOTE(review): the DataLoader cell passes batch_size=8 literally, not this value.
batch_size = 16
# NOTE(review): the optimizer cell passes lr=1e-4 literally, not this value.
lr = 0.001
# Square input side length in pixels. 224 is the size for which the network's
# conv stack yields the 256*27*27 features its first Linear layer expects.
dim = 224

In [None]:
# Some utility functions
def imshow(img,text=None,should_save=False):
    """Display an image tensor with matplotlib, optionally captioned.

    Args:
        img: Object exposing ``.numpy()``. Its axes are reordered with
            ``(1, 2, 0)`` before plotting, so a channel-first layout is
            presumably expected (TODO confirm against callers).
        text: Optional caption drawn at the fixed position x=75, y=8.
        should_save: Accepted but not used by this function.
    """
    pixels = img.numpy().transpose(1, 2, 0)
    plt.axis("off")
    if text:
        caption_box = {'facecolor':'white', 'alpha':0.8, 'pad':10}
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox=caption_box)
    plt.imshow(pixels)
    plt.show()

def show_plot(iteration,loss):
    """Draw ``loss`` against ``iteration`` as a line plot and display it."""
    plt.plot(iteration, loss)
    plt.show()

In [None]:
# Custom dataset that serves image pairs for training a Siamese network.
class SiameseNetworkDataset(Dataset):
    """Dataset of randomly sampled image pairs with a dissimilarity label.

    ``rootFolder`` must contain one sub-directory per class, each holding that
    class's image files. On construction, ``iterations`` same-class pairs and
    ``iterations`` different-class pairs are sampled and stored in
    ``self.data`` as ``[img1_path, img2_path, label]``.

    Label convention: ``0`` = both images come from the same class,
    ``1`` = the images come from different classes. This is the convention of
    the contrastive loss of Hadsell et al. (2006), where the squared distance
    of label-0 pairs is minimised and label-1 pairs are pushed apart.

    Args:
        rootFolder: Directory with one sub-directory per class.
        iterations: Number of pairs to sample for each of the two kinds.
        transform: Optional callable applied to each loaded image.

    Raises:
        ValueError: If ``iterations > 0`` and the folder cannot provide the
            requested pairs (no class with two images, or fewer than two
            non-empty classes).
    """

    def __init__(self,rootFolder,iterations,transform=None):
        self.data = []
        self.rootFolder = rootFolder
        self.iterations = iterations
        self.transform = transform
        self.make_pairwise_set()
        self.data_len = len(self.data)

    def _list_images_by_class(self):
        """Return ``{class_name: [image paths]}`` for every class directory."""
        images_by_class = {}
        for class_name in sorted(os.listdir(self.rootFolder)):
            class_dir = os.path.join(self.rootFolder, class_name)
            # Skip stray files in the root (e.g. notes or OS metadata).
            if not os.path.isdir(class_dir):
                continue
            images_by_class[class_name] = [
                os.path.join(class_dir, file_name)
                for file_name in sorted(os.listdir(class_dir))
            ]
        return images_by_class

    def make_pairwise_set(self):
        """Append the sampled same-class and different-class pairs to ``self.data``."""
        # List every directory once instead of once per sampled pair.
        images_by_class = self._list_images_by_class()
        pairable = [name for name, imgs in images_by_class.items() if len(imgs) >= 2]
        populated = [name for name, imgs in images_by_class.items() if imgs]

        if self.iterations > 0:
            # Fail loudly instead of resampling forever.
            if not pairable:
                raise ValueError(
                    "No class in {!r} has at least two images".format(self.rootFolder))
            if len(populated) < 2:
                raise ValueError(
                    "Need at least two non-empty classes in {!r}".format(self.rootFolder))

        # Same-class pairs: two distinct images of one random class.
        for _ in range(self.iterations):
            class_name = random.choice(pairable)
            img1, img2 = random.sample(images_by_class[class_name], 2)
            self.data.append([img1, img2, 0])

        # Different-class pairs: one image from each of two distinct classes.
        for _ in range(self.iterations):
            class1, class2 = random.sample(populated, 2)
            img1 = random.choice(images_by_class[class1])
            img2 = random.choice(images_by_class[class2])
            self.data.append([img1, img2, 1])
        # The list holds all same-class pairs followed by all different-class
        # pairs, so it should be consumed in shuffled order during training.

    @staticmethod
    def _load_image(imagePath):
        """Read one image with OpenCV, raising if it cannot be decoded."""
        # NOTE(review): cv2.imread returns channels in BGR order — confirm
        # that downstream transforms and display code expect that.
        image = cv2.imread(imagePath)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError("Could not read image: {}".format(imagePath))
        return image

    def __getitem__(self,index):
        """Return ``(img1, img2, label)`` for the pair stored at ``index``."""
        img1Path, img2Path, label = self.data[index]

        img1 = self._load_image(img1Path)
        img2 = self._load_image(img2Path)

        # Apply image transformations
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2 , label

    def __len__(self):
        """Return the number of stored pairs."""
        return self.data_len

In [None]:
class SiameseNetwork(nn.Module):
    """Twin-branch CNN mapping each input image to a 10-dimensional embedding.

    Both inputs pass through the same layers (shared weights). The first
    fully connected layer expects ``256 * 27 * 27`` features, which is what
    the convolutional stack produces for 3x224x224 inputs
    (224 -> 214 -> 107 -> 54 -> 27 per spatial side).
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()

        # Convolutional feature extractor.
        self.cnn1 = nn.Sequential(

            nn.Conv2d(3, 96, kernel_size=11,stride=1),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5,alpha=0.0001,beta=0.75,k=2),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(96, 256, kernel_size=3,stride=2,padding=1),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5,alpha=0.0001,beta=0.75,k=2),
            nn.MaxPool2d(2, stride=2),

        )

        # Fully connected head producing the embedding.
        self.fc1 = nn.Sequential(
            nn.Linear(256*27*27, 1024),
            nn.ReLU(inplace=True),
            # Element-wise dropout: the tensor here is 2-D (batch, features),
            # for which the channel-wise Dropout2d is deprecated.
            nn.Dropout(p=0.5),

            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),

            nn.Linear(128,10))

    def forward_once(self, x):
        """Embed one batch of images; returns a ``(batch, 10)`` tensor."""
        output = self.cnn1(x)
        # Flatten everything except the batch dimension for the linear head.
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights and return the pair."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss over pairs of embeddings.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For a pair at Euclidean distance ``d`` the per-pair loss is
    ``(1 - label) * d**2 + label * max(margin - d, 0)**2``, so pairs with
    label 0 are pulled together and pairs with label 1 are pushed apart
    until they are at least ``margin`` away. The batch mean is returned.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        distance = F.pairwise_distance(output1, output2)
        # Term active for label 0: penalise any separation.
        pull_term = (1 - label) * distance.pow(2)
        # Term active for label 1: penalise only distances inside the margin.
        shortfall = (self.margin - distance).clamp(min=0.0)
        push_term = label * shortfall.pow(2)
        return (pull_term + push_term).mean()

In [None]:
# Quick sanity check: prints 'Yes' when PyTorch reports a usable CUDA device.
if torch.cuda.is_available():
    print('Yes')

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
# is_available must be *called*: the bare function object is always truthy.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# NOTE(review): `path` is not used by the dataset below, which reads from train_path.
path = 'D:/NoteBook/G.P/PlantVillage/train'
# Sample 200 same-class and 200 different-class training pairs.
# Every image is resized to dim x dim so that all tensors in a batch share one
# shape and match the input size the network's first Linear layer expects.
train_dataset = SiameseNetworkDataset(
    rootFolder=train_path,
    iterations=200,
    transform=transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((dim, dim)),
        transforms.ToTensor(),
    ]),
)

In [None]:
# Build the shuffled training loader and preview its first batch.
train_dataloader = DataLoader(train_dataset, batch_size=8, num_workers=0, shuffle=True)
# Each batch is (first images, second images, labels).
batch = next(iter(train_dataloader))
# Stack both sides of every pair along dim 0 so one grid shows them all.
concatenated = torch.cat((batch[0], batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated))
# Labels of the previewed pairs.
print(batch[2].numpy())

In [None]:
# Declare the Siamese network and move its parameters to the selected device
net = SiameseNetwork().to(device) # device is chosen in the cell above
# Declare the loss function (default margin)
criterion = ContrastiveLoss()
# Declare the optimizer: RMSprop with momentum and L2 weight decay
optimizer = optim.RMSprop(net.parameters(), lr=1e-4, alpha=0.99, eps=1e-8, weight_decay=0.0005, momentum=0.9)

In [None]:
def train(num_epochs=25, log_every=50, plot_loss=False):
    """Train the global ``net`` on ``train_dataloader`` and return it.

    Uses the module-level ``net``, ``criterion``, ``optimizer`` and
    ``train_dataloader``.

    Args:
        num_epochs: Number of passes over ``train_dataloader``.
        log_every: Print and record the loss every ``log_every`` batches of
            each epoch (batch indices 0, log_every, 2*log_every, ...).
        plot_loss: When true, plot the recorded losses against the global
            iteration number with ``show_plot`` after training.

    Returns:
        The trained ``net`` (the same object as the global).
    """
    # Follow the device the model lives on instead of assuming CUDA, so the
    # loop also runs on CPU-only runtimes.
    target_device = next(net.parameters()).device
    net.train()

    counter = []
    loss_history = []
    iteration_number = 0

    for epoch in range(num_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader):
            img0 = img0.to(target_device)
            img1 = img1.to(target_device)
            label = label.to(target_device)

            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

            # Count every optimisation step so the recorded x-axis is the
            # real number of batches processed.
            iteration_number += 1
            if i % log_every == 0:
                current_loss = loss_contrastive.item()
                print("Epoch number {}\n Current loss {}\n".format(epoch, current_loss))
                counter.append(iteration_number)
                loss_history.append(current_loss)

    if plot_loss:
        show_plot(counter, loss_history)
    return net

In [None]:
# Run the training loop; train() returns the trained global `net`.
model = train()

In [None]:
# Scratch cell: builds a one-element float tensor; the result is not assigned.
torch.FloatTensor([0])