# Image matching with Siamese networks
## Dependencies

In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader

import utilities as utils 
from siameseDataloader import readDataFolder,dataSplits,SiameseDataset
from siameseModel import SimpleSiameseNetwork
from siameseLosses import ContrastiveLoss

In [2]:
config = utils.load_config()
datadir = config['DATASET']['root']
numclasses = config['DATASET']['numclasses']
sameprob = config['DATASET']['sameprob']

seed = config['seed']
np.random.seed(seed)

## Dataset

In [3]:
dataset = readDataFolder(datadir,numclasses)
train_split,val_split,test_split = dataSplits(dataset,0.7,0.2,0.1,)


In [4]:
train_transforms = transforms.Compose([
    transforms.Grayscale(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
    ])
test_transforms = train_transforms = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()
    ])

train_dataset = SiameseDataset(train_split,train_transforms,sameprob)
val_dataset = SiameseDataset(val_split,test_transforms,0.5)
test_dataset = SiameseDataset(test_split,test_transforms,0.5)

datasets = {'train':train_dataset,'val':val_dataset,'test':test_dataset}

In [5]:
#dataloaders
dataloaders = {
    'train': DataLoader(train_dataset,shuffle=True,
                        num_workers=config['TRAIN']['numworkers'],
                        batch_size=config['TRAIN']['batchsize']),
    'val': DataLoader(val_dataset,shuffle=False,
                        num_workers=config['TRAIN']['numworkers'],
                        batch_size=config['TRAIN']['batchsize']),
    'test': DataLoader(test_dataset,shuffle=False,
                        num_workers=config['TRAIN']['numworkers'],
                        batch_size=config['TRAIN']['batchsize'])
}

## Model

In [6]:
if config['MODEL']['model'] == 'simple':
    model = SimpleSiameseNetwork(distance_type=config['MODEL']['distance'],pretrained=config['MODEL']['pretrained'])
else:
    model = SimpleSiameseNetwork()

In [7]:
data = train_dataset[0]
out = model(data[0].reshape(1,1,64,64),data[1].reshape(1,1,64,64))
print(out)
print(data[2])


tensor([1.0000], grad_fn=<DivBackward0>)
tensor([0], dtype=torch.int32)


## Train

In [8]:
criterion = ContrastiveLoss(margin=config['TRAIN']['lossmargin'])
optimizer = optim.Adam(model.parameters(),lr = config['TRAIN']['lr'] )

In [9]:
device = config['TRAIN']['device']
model = model.to(device)

In [10]:
loss_history ={'train':[],'val':[]}

In [11]:
for epoch in range(config['TRAIN']['numepochs']):
    for mode in ['train','val']:
        loss_epoch=0
        count=0
        for i,(img1,img2,label) in enumerate(dataloaders[mode]):
            img1 = img1.to(device)
            img2 = img2.to(device)
            label = label.to(device)
            
            if mode=='train':
                model.train()
                
                optimizer.zero_grad()
                out = model(img1,img2)
                loss = criterion(out,label)
                loss.backward()
                optimizer.step()
                
            else:
                model.eval()
                with torch.no_grad():
                    out = model(img1,img2)
                    loss = criterion(out,label)
            
            # track total loss
            loss_epoch = loss_epoch+loss
            count = count + len(label)
        
        loss_history[mode].append(loss_epoch.item()/count)
    print('Epoch: {},\t train loss:{:.5},\t val loss:{:.5}'.format(epoch,loss_history['train'][-1],loss_history['val'][-1]))

            
            

Epoch: 0,	 train loss:2.8219,	 val loss:2.8589
Epoch: 1,	 train loss:2.8826,	 val loss:2.9691
Epoch: 2,	 train loss:2.8348,	 val loss:2.9444
Epoch: 3,	 train loss:2.9302,	 val loss:3.0141
Epoch: 4,	 train loss:2.7921,	 val loss:2.8622
Epoch: 5,	 train loss:2.7991,	 val loss:2.9275
Epoch: 6,	 train loss:2.8441,	 val loss:2.8127
Epoch: 7,	 train loss:2.767,	 val loss:2.8791
Epoch: 8,	 train loss:2.8818,	 val loss:2.968
Epoch: 9,	 train loss:2.8238,	 val loss:3.004
Epoch: 10,	 train loss:2.811,	 val loss:2.7666
Epoch: 11,	 train loss:2.8279,	 val loss:2.9714
Epoch: 12,	 train loss:2.8642,	 val loss:2.9444
Epoch: 13,	 train loss:2.8433,	 val loss:2.8724
Epoch: 14,	 train loss:2.8035,	 val loss:2.9151
Epoch: 15,	 train loss:2.8269,	 val loss:2.7599
Epoch: 16,	 train loss:2.9116,	 val loss:2.8083
Epoch: 17,	 train loss:2.8536,	 val loss:2.7711
Epoch: 18,	 train loss:2.9029,	 val loss:2.7565
Epoch: 19,	 train loss:2.7793,	 val loss:2.6429
Epoch: 20,	 train loss:2.9587,	 val loss:3.0411
Epoch:

KeyboardInterrupt: 

In [None]:
plt.figure()
plt.plot(loss_history['train'],'b-')
plt.plot(loss_history['val'],'r-')