In [1]:
# siamese network for image similarity
# https://hackernoon.com/facial-similarity-with-siamese-networks-in-pytorch-9642aa9db2f7
# dense 5, divide image by 255

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import pandas as pd
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

In [3]:
print(torch.__version__)
print(torchvision.__version__)
print(np.__version__)
print(pd.__version__)

0.4.1.post2
0.2.1
1.14.3
0.22.0


In [4]:
print(os.listdir('../data'))

['test.zip', 'train.zip', 'train_combined.csv', 'train_combined', 'cropping.txt', 'rotate.txt', 'train_play.zip', 'sample_submission.csv', 'train_play', 'test', 'train', 'train_play.csv', 'train.csv', 'exclude.txt']


In [5]:
from torchsummary import summary

In [6]:
model = models.resnet18(pretrained=True)
model = model.cuda()
summary(model, (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [7]:
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [8]:
list(model.children())[:-2]

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
 Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (1): BasicBlock(
     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bi

In [9]:
class AdaptiveConcatPool2d(nn.Module):
    def __init__(self, sz=None):
        super().__init__()
        sz = sz or (1,1)
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)

class Lambda(nn.Module):
    def __init__(self, f): super().__init__(); self.f=f
    def forward(self, x): return self.f(x)

class Flatten(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, x): return x.view(x.size(0), -1)


In [10]:
# model = models.resnet18(pretrained=True)
# features = list(model.children())[:-2]
# features += [AdaptiveConcatPool2d(), Flatten()]
# features += [nn.BatchNorm1d(num_features=1024), nn.Dropout(p=0.25)]
# features += [nn.Linear(in_features=1024, out_features=512, bias=True), nn.ReLU()]
# features += [nn.BatchNorm1d(num_features=512), nn.Dropout(p=0.25)]
# features += [nn.Linear(in_features=512, out_features=32, bias=True), nn.ReLU()]

In [11]:
# print(features)

In [12]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        model = models.resnet18(pretrained=True)
        features = list(model.children())[:-2]
        features += [AdaptiveConcatPool2d(), Flatten()]
        self.cnn1 = nn.Sequential(*features)
        self.fc1 = nn.Sequential(
                nn.BatchNorm1d(num_features=1024), 
                nn.Dropout(p=0.25),
                nn.Linear(in_features=1024, out_features=512, bias=True), 
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(num_features=512), 
                nn.Dropout(p=0.25),
                nn.Linear(in_features=512, out_features=5, bias=True)
            )
    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In [13]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
#         print(output1, output2, label, euclidean_distance)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


        return loss_contrastive

In [14]:
exclude = pd.read_csv('../data/exclude.txt', header=None)
exclude.columns = ['excl']

In [15]:
exclude = list(exclude.excl.values)
exclude

['0b1e39ff.jpg',
 '0c11fa0c.jpg',
 '1b089ea6.jpg',
 '2a2ecd4b.jpg',
 '2c824757.jpg',
 '3e550c8a.jpg',
 '56893b19.jpg',
 '613539b4.jpg',
 '6530809b.jpg',
 '6b753246.jpg',
 '6b9f5632.jpg',
 '75c94986.jpg',
 '7f048f21.jpg',
 '7f7702dc.jpg',
 '806cf583.jpg',
 '95226283.jpg',
 'a3e9070d.jpg',
 'ade8176b.jpg',
 'b1cfda8a.jpg',
 'b24c8170.jpg',
 'b7ea8be4.jpg',
 'b9315c19.jpg',
 'b985ae1e.jpg',
 'baf56258.jpg',
 'c4ad67d8.jpg',
 'c5da34e7.jpg',
 'c5e3df74.jpg',
 'ced4a25c.jpg',
 'd14f0126.jpg',
 'e0b00a14.jpg',
 'e6ce415f.jpg',
 'e9bd2e9c.jpg',
 'f4063698.jpg',
 'f9ba7040.jpg']

In [16]:
rotate = pd.read_csv('../data/rotate.txt', header=None)
rotate.columns = ['rot']
rotate = list(rotate.rot.values)
rotate

['0a5216ef5.jpg',
 '0afe8c93a.jpg',
 '0b1e39ff.jpg',
 '0b39dab59.jpg',
 '0ca67fe7.jpg',
 '0dee817aa.jpg',
 '0df714e16.jpg',
 '2b792814.jpg',
 '2bc459eb.jpg',
 '3401bafe.jpg',
 '56fafc52.jpg',
 'a492ab72.jpg',
 'd1502267.jpg',
 'e53d2b96.jpg',
 'ed4f0cd5.jpg',
 'f2ec136c.jpg',
 'f966c073.jpg']

In [17]:
import pickle
if os.path.isfile('../cache/p2h.pickle'):
    with open('../cache/p2h.pickle', 'rb') as f:
        p2h = pickle.load(f)

In [18]:
tagged = dict([(p,w) for _,p,w in pd.read_csv('../data/train_combined.csv').to_records()])

In [19]:
# Find all the whales associated with an image id. It can be ambiguous as duplicate images may have different whale ids.
h2ws = {}
new_whale = 'new_whale'
for p,w in tagged.items():
    if w != new_whale: # Use only identified whales
        h = p2h[p]
        if h not in h2ws: h2ws[h] = []
        if w not in h2ws[h]: h2ws[h].append(w)
for h,ws in h2ws.items():
    if len(ws) > 1:
        h2ws[h] = sorted(ws)
len(h2ws)


19553

In [20]:
# For each image id, determine the list of pictures
h2ps = {}
for p,h in p2h.items():
    if h not in h2ps: h2ps[h] = []
    if p not in h2ps[h]: h2ps[h].append(p)
# Notice how 33321 images use only 33317 distinct image ids.
len(h2ps),list(h2ps.items())[:5]

(35790,
 [('d26698c3271c757c', ['0000e88ab.jpg']),
  ('ba8cc231ad489b77', ['0001f9222.jpg']),
  ('bbcad234a52d0f0b', ['00029d126.jpg', '28142a91.jpg']),
  ('c09ae7dc09f33a29', ['00050a15a.jpg']),
  ('d02f65ba9f74a08a', ['0005c1ef8.jpg'])])

In [21]:
submit = [p for _,p,_ in pd.read_csv('../data/sample_submission.csv').to_records()]
join   = list(tagged.keys()) + submit
len(tagged),len(submit),len(join),list(tagged.items())[:5],submit[:5]

(35211,
 7960,
 43171,
 [('0000e88ab.jpg', 'w_f48451c'),
  ('0001f9222.jpg', 'w_c3d896a'),
  ('00029d126.jpg', 'w_20df2c5'),
  ('00050a15a.jpg', 'new_whale'),
  ('0005c1ef8.jpg', 'new_whale')],
 ['00028a005.jpg',
  '000dcf7d8.jpg',
  '000e7c7df.jpg',
  '0019c34f4.jpg',
  '001a4d292.jpg'])

In [22]:
def expand_path(p):
    if os.path.isfile('../data/train_combined/' + p): return '../data/train_combined/' + p
    if os.path.isfile('../data/test/' + p): return '../data/test/' + p
    return p

In [23]:

from tqdm import tqdm_notebook
from PIL import Image as pil_image
# p2size = {}
# for p in tqdm_notebook(join):
# #     print(p)
#     size      = pil_image.open(expand_path(p)).size
#     p2size[p] = size
# len(p2size), list(p2size.items())[:5]

In [24]:
# with open('../cache/p2size.pickle', 'wb') as f:
#     pickle.dump(p2size, f)

In [25]:
with open('../cache/p2size.pickle', 'rb') as f:
    p2size = pickle.load(f)
len(p2size), list(p2size.items())[:5]

(43171,
 [('0000e88ab.jpg', (1050, 700)),
  ('0001f9222.jpg', (758, 325)),
  ('00029d126.jpg', (1050, 497)),
  ('00050a15a.jpg', (1050, 525)),
  ('0005c1ef8.jpg', (1050, 525))])

In [26]:
# For each images id, select the prefered image
def prefer(ps):
    if len(ps) == 1: return ps[0]
    best_p = ps[0]
    best_s = p2size[best_p]
    for i in range(1, len(ps)):
        p = ps[i]
        s = p2size[p]
        if s[0]*s[1] > best_s[0]*best_s[1]: # Select the image with highest resolution
            best_p = p
            best_s = s
    return best_p

h2p = {}
for h,ps in h2ps.items(): h2p[h] = prefer(ps)
len(h2p),list(h2p.items())[:5]

(35790,
 [('d26698c3271c757c', '0000e88ab.jpg'),
  ('ba8cc231ad489b77', '0001f9222.jpg'),
  ('bbcad234a52d0f0b', '00029d126.jpg'),
  ('c09ae7dc09f33a29', '00050a15a.jpg'),
  ('d02f65ba9f74a08a', '0005c1ef8.jpg')])

In [27]:
# For each whale, find the unambiguous images ids.
w2hs = {}
for h,ws in h2ws.items():
    if len(ws) == 1: # Use only unambiguous pictures
        if h2p[h] in exclude:
            print(h) # Skip excluded images
        else:
            w = ws[0]
            if w not in w2hs: w2hs[w] = []
            if h not in w2hs[w]: w2hs[w].append(h)
for w,hs in w2hs.items():
    if len(hs) > 1:
        w2hs[w] = sorted(hs)
len(w2hs)

ebf094854a2bb1d6
f86bcf9487653848
807b19b6766d09ce
bc984f67a31b48a6
c5313ec3c0343bcf
9dc39bb4833cc3c8
afa994d4416b2ab6
c0352f2194b7fcca
c4cc196f46bc8cce
c0c0753e9fcf4368
d8dc91b13fae8a18
f3ad8c8cb2b38c8c
f18c966fb836c90c
f8908e4ee223f758
a2d5d5eae64f2a01
96c949b632e90d3e
e8d1960f60b13fca
d08f9d61729e8d61
90376cc843f6b9a3
e92d90d2616f0f9c
c96d1296e96b16b4
c7882c3359ec7327
f43183739a8f53c4


6818

In [28]:
w2hs

{'w_f48451c': ['96b1699e46397c92',
  '96d9ac9bc62dc1c2',
  'a2ddce3324c8bc0f',
  'afcbd0302fcd9033',
  'ba9899aa9a597a26',
  'c1363f8d27b8986c',
  'c79a69c31e638c3c',
  'd0be5cb26790c83d',
  'd26698c3271c757c'],
 'w_c3d896a': ['b33cced372343131',
  'ba8cc231ad489b77',
  'bacac4b53d484a73',
  'ead4952568da7634'],
 'w_20df2c5': ['bbcad234a52d0f0b', 'fa17c07078391e97'],
 'w_581ba42': ['a99e87a1789c9743',
  'aac791788dd32791',
  'aecad035812c8dfb',
  'af91f02f79184f82',
  'b848cbf23625f929',
  'bad48020fd2dde36',
  'bbced2b4ad300c8b',
  'be92c161785896e7',
  'da85e07a1f8760f1',
  'db85a07a5f85a07a',
  'ea178560bc6d598e',
  'ea9e8160fc6d5896',
  'ea9e85e0781f0f4a'],
 'w_cb622a2': ['b038cec76361599e', 'e7949266b148dcd3', 'eb0b94f0290edee1'],
 'w_8cad422': ['acc9c3b63c4dcb12',
  'adc2d03d0b52f635',
  'e9ded220dc0fc3c2',
  'ec1b85f030ec1d4f'],
 'w_3de579a': ['809be6111b6cf2af',
  '81702ccf93f4b17c',
  '836c7cc75c0fa338',
  '86d965f892a439bc',
  '8f50a572a5b849af',
  '922669591ec7f171',
  '9315

In [29]:
sum([1 if i != 'new_whale' else 0 for i in list(tagged.values())])

24737

In [30]:
sum([1 if i != 'new_whale' else 0 for i in set(tagged.values())])

9254

In [31]:
sum([1,0,1,0,1])

3

In [32]:
# Find the list of training images, keep only whales with at least two images.
train = [] # A list of training image ids
for hs in w2hs.values():
    if len(hs) > 1:
        train += hs
np.random.shuffle(train)
train_set = set(train)

w2ts = {} # Associate the image ids from train to each whale id.
for w,hs in w2hs.items():
    for h in hs:
        if h in train_set:
            if w not in w2ts: w2ts[w] = []
            if h not in w2ts[w]: w2ts[w].append(h)
for w,ts in w2ts.items(): w2ts[w] = np.array(ts)
    
t2i = {} # The position in train of each training image id
for i,t in enumerate(train): t2i[t] = i

len(train),len(w2ts)

(10795, 2645)

In [33]:
w2ts

{'w_f48451c': array(['96b1699e46397c92', '96d9ac9bc62dc1c2', 'a2ddce3324c8bc0f',
        'afcbd0302fcd9033', 'ba9899aa9a597a26', 'c1363f8d27b8986c',
        'c79a69c31e638c3c', 'd0be5cb26790c83d', 'd26698c3271c757c'],
       dtype='<U16'),
 'w_c3d896a': array(['b33cced372343131', 'ba8cc231ad489b77', 'bacac4b53d484a73',
        'ead4952568da7634'], dtype='<U16'),
 'w_20df2c5': array(['bbcad234a52d0f0b', 'fa17c07078391e97'], dtype='<U16'),
 'w_581ba42': array(['a99e87a1789c9743', 'aac791788dd32791', 'aecad035812c8dfb',
        'af91f02f79184f82', 'b848cbf23625f929', 'bad48020fd2dde36',
        'bbced2b4ad300c8b', 'be92c161785896e7', 'da85e07a1f8760f1',
        'db85a07a5f85a07a', 'ea178560bc6d598e', 'ea9e8160fc6d5896',
        'ea9e85e0781f0f4a'], dtype='<U16'),
 'w_cb622a2': array(['b038cec76361599e', 'e7949266b148dcd3', 'eb0b94f0290edee1'],
       dtype='<U16'),
 'w_8cad422': array(['acc9c3b63c4dcb12', 'adc2d03d0b52f635', 'e9ded220dc0fc3c2',
        'ec1b85f030ec1d4f'], dtype='<U16'),


In [34]:
h2ps['96b1699e46397c92']

['6f7abb1be.jpg']

In [35]:
h2ps

{'d26698c3271c757c': ['0000e88ab.jpg'],
 'ba8cc231ad489b77': ['0001f9222.jpg'],
 'bbcad234a52d0f0b': ['00029d126.jpg', '28142a91.jpg'],
 'c09ae7dc09f33a29': ['00050a15a.jpg'],
 'd02f65ba9f74a08a': ['0005c1ef8.jpg'],
 'be0fc2b0e6b4a98c': ['0006e997e.jpg'],
 'e91b8cc2f270723d': ['000a6daec.jpg', 'b35cb9b8.jpg'],
 'ab9ab5e1701e2d19': ['000f0f2bf.jpg'],
 'e99a96243f89711e': ['0016b897a.jpg', '78d229fb.jpg'],
 'a508dbe7709d0d47': ['001c1ac5f.jpg', '483f0c58.jpg'],
 'aecad035812c8dfb': ['001cae55b.jpg'],
 'f289d1a6790ac7a9': ['001d7450c.jpg'],
 'f6d99ad60625e521': ['00200e115.jpg'],
 'ba9ad5236a9d2c0e': ['00245a598.jpg'],
 'afd09425ca1ab17e': ['002b4615d.jpg', '7d8b6794.jpg'],
 'f6ccc99336a4c50d': ['002f99f01.jpg'],
 'e7949266b148dcd3': ['00355ff28.jpg'],
 'a6194952d9d67c27': ['00357e37a.jpg', '47df88a5.jpg'],
 'fb1e81e0b665b8a1': ['003795857.jpg'],
 'b136cc4fbf3e6080': ['0041880bf.jpg'],
 'a71c5d32c13c1f53': ['0043da555.jpg'],
 'e9ded220dc0fc3c2': ['00442c882.jpg'],
 'efc8d0250cdaf721': ['0

In [36]:
w2ts_list = list(w2ts.keys())

In [37]:
len(w2ts_list), w2ts_list

(2645,
 ['w_f48451c',
  'w_c3d896a',
  'w_20df2c5',
  'w_581ba42',
  'w_cb622a2',
  'w_8cad422',
  'w_3de579a',
  'w_1d0830e',
  'w_8dddbee',
  'w_2365d55',
  'w_9c506f6',
  'w_c0d11da',
  'w_3881f28',
  'w_cee684e',
  'w_41d24c6',
  'w_8a235b6',
  'w_2e231f4',
  'w_6822dbc',
  'w_df86a42',
  'w_700ebb4',
  'w_d892cd9',
  'w_bc285a6',
  'w_f3252ff',
  'w_c7cda47',
  'w_c6c89db',
  'w_1531bf5',
  'w_5650932',
  'w_dd944b7',
  'w_6f0cbe3',
  'w_6cfa650',
  'w_6e209a8',
  'w_914b110',
  'w_0369a5c',
  'w_3ff114c',
  'w_a9304b9',
  'w_b44e89d',
  'w_b7e3a9f',
  'w_691b684',
  'w_4516d31',
  'w_396c12b',
  'w_03270e3',
  'w_fbc7895',
  'w_8a71ca4',
  'w_a8c2847',
  'w_9739508',
  'w_ac97190',
  'w_ddb62c2',
  'w_5a9abcf',
  'w_40a8585',
  'w_f74a89e',
  'w_688590b',
  'w_77150d0',
  'w_08630fd',
  'w_89f521e',
  'w_0ee9ae5',
  'w_c8bbb43',
  'w_191bce3',
  'w_f793b1f',
  'w_e6450bc',
  'w_4846b27',
  'w_3e4b155',
  'w_7e56d66',
  'w_dd88965',
  'w_91b96c1',
  'w_d9a3c6e',
  'w_d0bfef3',
  '

In [38]:
train

['af5f91e42e1a1ac4',
 'a1c5df1e4cf180a7',
 'fa0f8370a530c97e',
 'bf8e8070685e97a3',
 'f248d8bfc3b2a5a0',
 'ce83a03cd93e635c',
 'dbc8c6b5adb68141',
 'f0dec3a038cd9336',
 'a7d9d2340dabc9c2',
 'aac7c0b5e9385276',
 '879998a6e5493a9e',
 'a01bde242d63f3d8',
 'ba9f85605069d8f6',
 'e9959662f11cfc22',
 'f30cd0b18e523ecd',
 'e89b87607a17e19a',
 'ea9dd5327887a158',
 'b68dcae23c4e83b1',
 'bcc3c33409d837c7',
 'ea8d82e27d99836a',
 'fdcfc0340f93e0a8',
 'ed9d9687e36018d8',
 'eb1a81ea96e7e0a4',
 'bbc4deb1a1384c1b',
 'eb9395e472354b0a',
 'a7cd207ec7a1384b',
 'abc6f02e87b06c39',
 'fd9f87228069da62',
 'e6dccbc36193c322',
 'bf4ec0b035095ade',
 'eb9d926078b4599c',
 'ed9ad2a663380cc7',
 'fa8d85706e9fc121',
 'bcc9d81ec3b226b1',
 'ab9ec2a1f42c0f9a',
 'd4a1dab1c49f6352',
 'b96ccfd3b0342529',
 'fb9685781e8748a5',
 'b038cec76361599e',
 'b4ed8b12b62c4db1',
 'ea96843c0fc1f23e',
 '9a83e4789eb725e0',
 'ea9e8562f13d9e02',
 'b11ccff324c83c0f',
 'b38c8cb38cce715c',
 'e81d96e2b4a8e9f0',
 'bf22c2b807f8390f',
 'bfc0d0b5285

In [39]:
import random, cv2
class SiameseNetworkDataset(Dataset):
    def __init__(self, h2ps, w2hs, train, transform = None):
        self.h2ps = h2ps
        self.w2hs = w2hs
        self.train = train
        self.transform = transform
    def __getitem__(self,index):
        h = self.train[index]
        r = index
        p = self.h2ps[h]
        should_get_same_class = random.randint(0,1)
        img0=0
        img1=0
#         print(should_get_same_class)
        if should_get_same_class:
            img0 = np.random.choice(p, 1)
            img1 = np.random.choice(p, 1)
            while img0 == img1:
                img1 = np.random.choice(p, 1)
#             print(imgs[0], imgs[1])
            img0 = pil_image.open(expand_path(img0))
            img1 = pil_image.open(expand_path(img1))
        else:
            while True:
                r = random.randint(0, len(self.train)-1)
                if r != index: break
            h1 = self.train[r]
            p1 = self.h2ps[h1]
            img0 = np.random.choice(p, 1)
#             print(imgs[0])
            img0 = pil_image.open(expand_path(img0))
            img1 = np.random.choice(p1, 1)
#             print(imgs[0])
            img1 = pil_image.open(expand_path(img1))
        img0=img0.convert('L')
        img1=img1.convert('L')
        
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
            
        img0 = np.asarray(img0).reshape(100,100)/255.
        img1 = np.asarray(img1).reshape(100,100)/255.
#         print(img0.shape)
        img0 = np.stack([img0, img0, img0], axis=0)
        img1 = np.stack([img1, img1, img1], axis=0)
        
        
        return img0, img1, torch.from_numpy(np.array([1. if r==index else 0.], dtype=np.float32))
    def __len__(self):
        return len(train)

In [40]:
siamese_dataset = SiameseNetworkDataset(
    h2ps,w2hs,train,
    transform=transforms.Compose(
        [transforms.Resize((100,100)),
         transforms.ToTensor()
        ])
)

train_dataloader = DataLoader(siamese_dataset,
                        shuffle=True,
                        num_workers=8,
                        batch_size=32)


In [41]:
def show_plot(iteration,loss):
    plt.plot(iteration,loss)
    plt.show()

In [42]:
print(len(train))

10795


In [43]:
10795//32

337

In [44]:
net = SiameseNetwork().cuda()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.005 )

counter = []
loss_history = [] 
iteration_number= 0

for epoch in range(2000):
    for i, data in enumerate(train_dataloader, 0):
        img0, img1 , label = data
#         print(img0.shape, label)
        img0, img1 , label = Variable(img0).cuda(), Variable(img1).cuda() , Variable(label).cuda()
        output1,output2 = net(img0,img1)
        optimizer.zero_grad()
        loss_contrastive = criterion(output1,output2,label)
        loss_contrastive.backward()
        optimizer.step()
        if (i != 0) & (i % 330 == 0) :
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.data[0]))
            iteration_number +=330
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.data[0])
show_plot(counter,loss_history)



Epoch number 0
 Current loss 1.0994707345962524

Epoch number 1
 Current loss 1.1632108688354492

Epoch number 2
 Current loss 1.1189537048339844

Epoch number 3
 Current loss 1.1624059677124023

Epoch number 4
 Current loss 1.0727860927581787

Epoch number 5
 Current loss 1.065853238105774

Epoch number 6
 Current loss 1.1064931154251099

Epoch number 7
 Current loss 1.1335704326629639

Epoch number 8
 Current loss 1.0442146062850952

Epoch number 9
 Current loss 1.1505390405654907

Epoch number 10
 Current loss 1.0315697193145752



Process Process-92:
Process Process-95:
Process Process-91:
Process Process-89:
Process Process-94:
Process Process-90:
Process Process-96:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/process.py", line 93, in 

Traceback (most recent call last):
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-44-96c244bcb4b3>", line 10, in <module>
    for i, data in enumerate(train_dataloader, 0):
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 330, in __next__
    idx, batch = self._get_batch()
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 309, in _get_batch
    return self.data_queue.get()
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/multiprocessing/co

  File "<ipython-input-39-fbd8a929ae10>", line 34, in __getitem__
    img1=img1.convert('L')


KeyboardInterrupt: 

  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/PIL/ImageFile.py", line 236, in load
    n, err_code = decoder.decode(b)
KeyboardInterrupt
KeyboardInterrupt
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/PIL/Image.py", line 877, in convert
    self.load()
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/PIL/ImageFile.py", line 236, in load
    n, err_code = decoder.decode(b)
KeyboardInterrupt
  File "/home/watts/anaconda3/envs/wcat/lib/python3.6/site-packages/PIL/ImageFile.py", line 236, in load
    n, err_code = decoder.decode(b)
KeyboardInterrupt
KeyboardInterrupt
