In [70]:
import random
from PIL import Image
import PIL.ImageOps
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms, datasets, models
import torch.nn.functional as F
from collections import OrderedDict

%matplotlib inline

In [71]:
mnist = datasets.MNIST(root='./data',
                       train=True,
                       transform = transforms.ToTensor(),
                       download=True)

mnist_test = datasets.MNIST(root='./data',
                         train=False,
                         transform = transforms.ToTensor(),
                         download=True)

In [48]:
batch_size = 16
trainloader = torch.utils.data.DataLoader(dataset = mnist, batch_size = batch_size, shuffle=True)
validloader = torch.utils.data.DataLoader(dataset = mnist_test, batch_size = batch_size, shuffle=False)

In [49]:
iterator = iter(trainloader)

In [50]:
img = next(iterator)

In [51]:
img[0].shape

torch.Size([16, 1, 28, 28])

In [73]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        
        self.conv1 = nn.Conv2d(1, 8, 3,padding=1) 
        self.conv2 = nn.Conv2d(8,16, 3, padding=1) 
        self.conv3 = nn.Conv2d(16,32,3, padding=1)
        
        self.pool = nn.MaxPool2d(2, stride=2)

        self.fc1 = nn.Linear(32*7*7, 10)
        
    def flatten(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x))) 
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = self.fc1(x)
        return x

In [90]:
cnn = Net()
print(cnn)

Net(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1568, out_features=10, bias=True)
)


In [75]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr = 0.001)

In [76]:
img = next(iter(trainloader))
img[0].shape

torch.Size([16, 1, 28, 28])

In [77]:
valid_loss_min = np.Inf
trainloss = []
validloss = []
acc_list = []

for e in range(5):
    train_loss = 0.0
    valid_loss = 0.0
    
    cnn.train()
    
    for images,labels in trainloader:
        optimizer.zero_grad()
        out = cnn(images)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*images.size(0)
            
    cnn.eval()
    correct = 0.0
    total = 0.0
    
    for images, labels in validloader:
        out = cnn(images)
        loss = criterion(out, labels)
        valid_loss += loss.item()*images.size(0)
        _, pred = torch.max(out.data, 1)
        correct += (pred == labels).sum()
        total += labels.size(0)
        
    
    train_loss = train_loss/len(trainloader.sampler)
    valid_loss = valid_loss/len(validloader.sampler)
    accuracy = 100.00 * float(correct)/float(total)
    
    trainloss.append(train_loss)
    validloss.append(valid_loss)
    acc_list.append(accuracy)
        
    print("Epoch : {}, Train Loss : {} Test Loss : {}, Accuracy : {}".format(e, train_loss, valid_loss, accuracy))
    
    if valid_loss<valid_loss_min:
        torch.save(cnn.state_dict(), 'cnn_mnist.pt')
        valid_loss_min = valid_loss
        print("Model Saved")

Epoch : 0, Train Loss : 0.14263708613465229 Test Loss : 0.04635742404162884, Accuracy : 98.5
Model Saved
Epoch : 1, Train Loss : 0.05039366958836714 Test Loss : 0.04814526725038886, Accuracy : 98.44
Epoch : 2, Train Loss : 0.03681088288128376 Test Loss : 0.031637274439260364, Accuracy : 99.01
Model Saved
Epoch : 3, Train Loss : 0.028308930354130766 Test Loss : 0.029638546760380267, Accuracy : 99.08
Model Saved
Epoch : 4, Train Loss : 0.022287394273281098 Test Loss : 0.03178395670354366, Accuracy : 98.97


In [91]:
cnn.load_state_dict(torch.load('cnn_mnist.pt'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [94]:
cnn.fc1 = nn.Linear(32*26*26,256)

In [95]:
print(cnn)

Net(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=21632, out_features=256, bias=True)
)


In [96]:
class SiameseNetwork(nn.Module):
    def __init__(self, model1):
        super(SiameseNetwork, self).__init__()
        
        self.model1 = model1
        self.fc = nn.Linear(256,16)
    
    def forward_once(self, x):
        out = self.model1(x)
        out = F.relu(self.fc(out))
        return out
    
    def forward(self, input1, input2):
        out1 = self.forward_once(input1)
        out2 = self.forward_once(input2)
        return out1,out2

In [97]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    From - https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


        return loss_contrastive

In [98]:
class SiameseNetworkDataset(Dataset):
    '''
    Generates paires of images and a similarity label
    From - https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb
    '''
    def __init__(self,imageFolderDataset,transform=None,should_invert=True):
        self.imageFolderDataset = imageFolderDataset    
        self.transform = transform
        self.should_invert = should_invert
        
    def __getitem__(self,index):
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        #we need to make sure approx 50% of images are in the same class
        should_get_same_class = random.randint(0,1) 
        if should_get_same_class:
            while True:
                #keep looping till the same class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs) 
                if img0_tuple[1]==img1_tuple[1]:
                    break
        else:
            img1_tuple = random.choice(self.imageFolderDataset.imgs)

        img0 = Image.open(img0_tuple[0])
        img1 = Image.open(img1_tuple[0])
        img0 = img0.convert("L")
        img1 = img1.convert("L")
        
        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        
        return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))
    
    def __len__(self):
        return len(self.imageFolderDataset.imgs)

In [99]:
traindatafolder = datasets.ImageFolder(root='./omniglot/python/images_background/')
testdatafolder = datasets.ImageFolder(root='./omniglot/python/images_evaluation/')

In [100]:
omniglot_background = SiameseNetworkDataset(imageFolderDataset = traindatafolder, 
                                            transform = transforms.ToTensor(), 
                                            should_invert=False)

In [101]:
train_loader = torch.utils.data.DataLoader(omniglot_background,
                                         shuffle=True,
                                         batch_size = 16)

In [102]:
img = next(iter(train_loader))

In [103]:
img[0].shape

torch.Size([16, 1, 105, 105])

In [104]:
Net = SiameseNetwork(cnn)
criterion = ContrastiveLoss()
optimizer = optim.Adam(Net.parameters(),lr = 0.001)

In [115]:
#the model has already been trained for 10 epochs

In [106]:
losses = []

for e in range(20):
    train_loss = 0.0
    
    for img1, img2, label in train_loader:
        optimizer.zero_grad()
        out1, out2 = Net(img1, img2)
        CL = criterion(out1, out2, label)
        CL.backward()
        optimizer.step()
        losses.append(CL.item())
    
    print("Epoch : {}, Loss : {}".format(e+1, CL.item()))

Epoch : 1, Loss : 1.0098159313201904
Epoch : 2, Loss : 1.0020575523376465
Epoch : 3, Loss : 1.033799648284912
Epoch : 4, Loss : 1.1034233570098877
Epoch : 5, Loss : 1.0151464939117432
Epoch : 6, Loss : 1.020507574081421
Epoch : 7, Loss : 1.0159090757369995
Epoch : 8, Loss : 1.0273383855819702
Epoch : 9, Loss : 0.8139676451683044
Epoch : 10, Loss : 1.0430740118026733
Epoch : 11, Loss : 0.9940325617790222
Epoch : 12, Loss : 1.0299056768417358
Epoch : 13, Loss : 1.0219473838806152
Epoch : 14, Loss : 1.0035247802734375
Epoch : 15, Loss : 1.0327041149139404
Epoch : 16, Loss : 1.0033955574035645
Epoch : 17, Loss : 0.9590462446212769
Epoch : 18, Loss : 1.0147136449813843
Epoch : 19, Loss : 0.9390965700149536
Epoch : 20, Loss : 1.096144199371338


In [107]:
torch.save(Net.state_dict(), 'siamese_network.pt')

In [108]:
Net.load_state_dict(torch.load('siamese_network.pt'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [109]:
omniglot_eval = SiameseNetworkDataset(imageFolderDataset=testdatafolder,
                                        transform=transforms.ToTensor()
                                       ,should_invert=False)

testloader = torch.utils.data.DataLoader(omniglot_eval,
                                         batch_size=16,
                                         shuffle=False)

In [110]:
img = next(iter(testloader))

In [111]:
img[0].shape

torch.Size([16, 1, 105, 105])

In [112]:
test_losses = []
for img1, img2, label in testloader:
        out1, out2 = Net(img1, img2)
        CL = criterion(out1, out2, label)
        test_losses.append(CL.item())

In [113]:
print(sum(test_losses)/len(test_losses))

1.026609688924933
