<a href="https://colab.research.google.com/github/arnavvats/pytorch-cnns/blob/master/siamese_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
import os
os.environ['KAGGLE_USERNAME'] = "your-kaggle-username"
os.environ['KAGGLE_KEY'] = "your-kaggle-key"
from google.colab import drive
drive.mount('/drive', force_remount = True)

Mounted at /drive


In [0]:
# !kaggle datasets download -d watesoyan/omniglot
# !unzip -q omniglot.zip
# !ls

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.image import imread
import random

In [3]:
drive_dir = './drive/My Drive'
train_im_dir = './images_background'
test_im_dir = './images_evaluation'
train_categories = [[os.path.join(train_im_dir,i,j,k) for k in os.listdir(os.path.join(train_im_dir, i ,j))] 
                    for i in os.listdir(train_im_dir)
                    for j in os.listdir(os.path.join(train_im_dir, i))]
test_categories = [[os.path.join(test_im_dir,i,j,k) for k in os.listdir(os.path.join(test_im_dir, i ,j))] 
                    for i in os.listdir(test_im_dir)
                    for j in os.listdir(os.path.join(test_im_dir, i))]
train_num_cat = len(train_categories)
test_num_cat = len(test_categories)
train_num_cat

964

In [0]:
def get_batch(batch_size, step_no):
  input_list = np.empty((batch_size,2,105,105))
  label_list = np.ones((batch_size,1))
  for i, cat in enumerate(train_categories[step_no * batch_size: (step_no + 1) * batch_size]):
    cat_len = len(cat)
    rand_idx = np.random.randint(cat_len)
    im_1 = imread(cat[rand_idx]).reshape((1,105,105))
    op = 1
    if i % 2:
      
      rand_idx2 = (i + np.random.randint(train_num_cat)) % train_num_cat
      im_2 = imread(train_categories[rand_idx2][rand_idx]).reshape((1,105,105))
      label_list[i] = 0
    else:
      rand_idx = (rand_idx + np.random.randint(cat_len)) % cat_len
      im_2 = imread(cat[rand_idx]).reshape((1,105,105))
    input_list[i] = np.concatenate([im_1, im_2], axis = 0)
  return torch.from_numpy(input_list).float().cuda(), torch.from_numpy(label_list).float().cuda()

def get_test_batch(batch_size):
  input_list = np.empty((batch_size,2,105,105))
  label_list = np.ones((batch_size,1))
  start = np.random.randint(test_num_cat - batch_size)
  for i, cat in enumerate(test_categories[start: start + batch_size]):
    cat_len = len(cat)
    rand_idx = np.random.randint(cat_len)
    im_1 = imread(cat[rand_idx]).reshape((1,105,105))
    op = 1
    if i % 2:
      
      rand_idx2 = (i + np.random.randint(test_num_cat)) % test_num_cat
      im_2 = imread(test_categories[rand_idx2][rand_idx]).reshape((1,105,105))
      label_list[i] = 0
    else:
      rand_idx = (rand_idx + np.random.randint(cat_len)) % cat_len
      im_2 = imread(cat[rand_idx]).reshape((1,105,105))
    input_list[i] = np.concatenate([im_1, im_2], axis = 0)
  return torch.from_numpy(input_list).float().cuda(), torch.from_numpy(label_list).float().cuda()

In [0]:
class SiameseTwin(nn.Module):
  def __init__(self):
    super(SiameseTwin, self).__init__()
    self.conv_1 = nn.Conv2d(1,64,10)
    self.bn_1 = nn.BatchNorm2d(64)
    self.conv_2 = nn.Conv2d(64,128,7)
    self.bn_2 = nn.BatchNorm2d(128)
    self.conv_3 = nn.Conv2d(128, 128, 4)
    self.bn_3 = nn.BatchNorm2d(128)
    self.conv_4 = nn.Conv2d(128, 256, 4)
    self.bn_4 = nn.BatchNorm2d(256)
    self.fc_6 = nn.Linear(9216,4096)
    
  def forward(self, x):
    x = F.relu(self.conv_1(x))
    x = self.bn_1(x)
    x = F.max_pool2d(x, 2, 2)
    x = F.relu(self.conv_2(x))
    x = self.bn_2(x)
    x = F.max_pool2d(x, 2, 2 )
    x = F.relu(self.conv_3(x))
    x = self.bn_3(x)
    x = F.max_pool2d(x,2,2)
    x = F.relu(self.conv_4(x))
    x = self.bn_4(x)
    x = x.view(x.shape[0],-1)
    x = F.sigmoid(self.fc_6(x))
    return x

class SiameseNet(nn.Module):
  def __init__(self):
    super(SiameseNet, self).__init__()
    self.siamese_twin = SiameseTwin()
    self.fc_layer = nn.Linear(4096, 1)
    
  def forward(self, twin_ims):
    x_1, x_2 = twin_ims[:,0:1,:,:], twin_ims[:,1:2,:,:]
    x_1 = self.siamese_twin(x_1)
    x_2 = self.siamese_twin(x_2)
    x = torch.abs(x_1 - x_2)
    x = F.sigmoid(self.fc_layer(x))
    return x

In [10]:
snet = SiameseNet().cuda()
optimizer = optim.Adam(snet.parameters(), lr = 0.0001)
criterion = nn.BCELoss()
epochs = 1000
batch_size = 480
steps = 2
loss_threshold = 0.05
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn_c=1
        for s in list(p.size()):
            nn_c = nn_c*s
        pp += nn_c
    return pp
get_n_params(snet)

38952897

In [0]:
# x = np.random.randn(40,2,105,105)
# x = torch.from_numpy(x).float().cuda()
# y =torch.from_numpy(np.random.rand(40)).float().cuda()
# snet.train()
# z = snet(x)
# l_oss = criterion(z,y)
# l_oss.backward()
# optimizer.step()
# optimizer.zero_grad()

In [31]:
%%time
for epoch in range(epochs):
  random.seed(epoch)
  random.shuffle(train_categories)
  snet.train()
  epoch_loss = 0
  acc = 0
  for step in range(steps):
    optimizer.zero_grad()
    ip_list, op_list = get_batch(batch_size, step)
    op_pred = snet(ip_list)
    loss = criterion(op_pred.view(-1), op_list.view(-1))
    epoch_loss += loss.item()
    loss.backward()
    optimizer.step()
  ip_test, op_test = get_test_batch(batch_size)
  with torch.no_grad():
    op_pred = snet(ip_test)
    op_pred = (op_pred >= 0.5).float()
    acc = (op_pred == op_test).sum().item() * 100 / batch_size
  epoch_avg_loss = epoch_loss / steps
  print('Epoch - {} , loss - {:.6f}, acc - {:.6f}'.format(epoch + 1 - 700, epoch_avg_loss, acc))
  if (acc >= acc_threshold and epoch_avg_loss < loss_threshold):
    break



Epoch - 401 , loss - 0.053400, acc - 87.916667
Epoch - 402 , loss - 0.041312, acc - 89.375000
Epoch - 403 , loss - 0.047595, acc - 86.250000
Epoch - 404 , loss - 0.051490, acc - 90.000000
Epoch - 405 , loss - 0.038701, acc - 87.291667
Epoch - 406 , loss - 0.039975, acc - 92.708333
CPU times: user 11.2 s, sys: 6.19 s, total: 17.4 s
Wall time: 17.5 s


In [43]:
torch.save(snet,os.path.join(drive_dir,'siamese_network.pt'))

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [44]:
torch.save(snet, 'network.pt')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
