In [16]:
## Imports ##
#------------------------------------------------#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.autograd as autograd
#------------------------------------------------#
import matplotlib.pyplot as plt
%matplotlib inline
from scipy import ndimage, misc
import numpy as np
from time import time
import math
import os
import sys
from tqdm import tqdm_notebook as tqdm
#------------------------------------------------#
from data import Dataset
#------------------------------------------------#

In [50]:
## Hyperparamters ##
batch_size  = 256			# Number of images to load at a time
epochs 		= 100			# Number of interations throough training data
train_gpu 	= True
lr 		    = 1e-4
train 		= __name__ == "__main__"
## Setting up ##
torch.manual_seed(1)
use_cuda = torch.cuda.is_available() and train_gpu
device = torch.device("cuda" if use_cuda else "cpu")
print('Device mode: ', device)

Device mode:  cuda


In [51]:
# ==================Definition Start======================
class SiameseNetwork(nn.Module):
	def __init__(self):
		super(SiameseNetwork, self).__init__()
		self.cnn1 = nn.Sequential(
			nn.Conv2d(4, 8, kernel_size=3, padding=1),
			nn.ReLU(inplace=True),
			nn.BatchNorm2d(8),
			
			nn.Conv2d(8, 8, kernel_size=3, padding=1),
			nn.ReLU(inplace=True),
			nn.BatchNorm2d(8),


			nn.Conv2d(8, 8, kernel_size=3, padding=1),
			nn.ReLU(inplace=True),
			nn.BatchNorm2d(8),
		)
		self.fc1 = nn.Sequential(
			nn.Linear(2*8*15*15, 500),
			nn.LeakyReLU(inplace=True),

			nn.Linear(500, 250),
			nn.LeakyReLU(inplace=True),

			nn.Linear(250, 250),
			nn.LeakyReLU(inplace=True),

			nn.Linear(250, 1),
			nn.Sigmoid())

	def forward_once(self, x):
		output = self.cnn1(x)
		return output

	def forward(self, input1, input2):
		output1 = self.forward_once(input1)
		output2 = self.forward_once(input2)
		output1 = output1.view(output1.size()[0], -1)
		output2 = output2.view(output2.size()[0], -1)
		output = torch.cat([output1, output2], dim=1)
		output = self.fc1(output)
		return output
class ContrastiveLoss(torch.nn.Module):
	def __init__(self, margin=2.0):
		super(ContrastiveLoss, self).__init__()
		self.margin = margin

	def forward(self, output1, output2, label):
		euclidean_distance = F.pairwise_distance(output1, output2)
		loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
									  (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
		return loss_contrastive

In [52]:
if train == True:
    ## Training Data ##
    trainset = Dataset(train=True, split_size=.8)
    trainloader = DataLoader(dataset=trainset, shuffle=True, batch_size=batch_size)

    model = SiameseNetwork().to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.5, 0.9))
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        avg_loss = []
        for i, (a, b, c) in tqdm(enumerate(trainloader)):
            ap = torch.FloatTensor(ndimage.sobel(a[:,:,:,-1])).unsqueeze(3).to(device).type(torch.float32)
            bp = torch.FloatTensor(ndimage.sobel(b[:,:,:,-1])).unsqueeze(3).to(device).type(torch.float32)
            a, b, c = (a.permute(0,3,1,2).to(device).type(torch.float32), 
                       b.permute(0,3,1,2).to(device).type(torch.float32), c.to(device).type(torch.float)) #for contrastive loss it is 0 for matching pairs	
            model.zero_grad()
            ap, bp = ap.permute(0,3,1,2), bp.permute(0,3,1,2)
            a = torch.cat([a[:,:3,:,:], ap], dim=1)
            b = torch.cat([b[:,:3,:,:], bp], dim=1)
            pred_c = model(a, b)
            loss = criterion(pred_c, c.unsqueeze(1))
            loss.backward()
            optimizer.step()
            avg_loss.append(loss.item())
        print('epoch', epoch, "loss", np.mean(np.array(avg_loss)))

        torch.save(model, 'model_isola_graddepth')

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 0 loss 0.691396176815033


  "type " + obj.__name__ + ". It won't be checked "


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 1 loss 0.689012086391449


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 2 loss 0.6812306642532349


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 3 loss 0.667948305606842


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 4 loss 0.6505082130432129


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 5 loss 0.6358033180236816


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 6 loss 0.6312944889068604


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 7 loss 0.6091789603233337


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 8 loss 0.6204028844833374


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 9 loss 0.5938253164291382


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 10 loss 0.6093579292297363


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 11 loss 0.5908881664276123


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 12 loss 0.5967685580253601


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 13 loss 0.5848298668861389


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 14 loss 0.6197853446006775


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 15 loss 0.5984126091003418


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 16 loss 0.5832797408103942


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 17 loss 0.5878599882125854


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 18 loss 0.570594847202301


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 19 loss 0.5725236535072327


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 20 loss 0.578019642829895


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 21 loss 0.572899603843689


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 22 loss 0.5616136908531189


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 23 loss 0.5748984694480896


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 24 loss 0.5557126522064209


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 25 loss 0.5471686720848083


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 26 loss 0.5670375227928162


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 27 loss 0.5826513767242432


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 28 loss 0.5548216700553894


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 29 loss 0.5556536555290222


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 30 loss 0.5511350452899932


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 31 loss 0.5841058373451233


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 32 loss 0.585324227809906


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 33 loss 0.5423151254653931


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 34 loss 0.5374227106571198


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 35 loss 0.5490811586380004


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 36 loss 0.5514900803565979


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 37 loss 0.5486272692680358


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 38 loss 0.536355197429657


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 39 loss 0.5390798091888428


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 40 loss 0.5306796848773956


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 41 loss 0.5467957854270935


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 42 loss 0.5432241201400757


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 43 loss 0.5430342078208923


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 44 loss 0.5391807198524475


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 45 loss 0.540646243095398


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 46 loss 0.5394632458686829


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 47 loss 0.5445841073989868


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 48 loss 0.5332911849021912


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 49 loss 0.536809754371643


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 50 loss 0.5251121759414673


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 51 loss 0.53048255443573


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 52 loss 0.5125130712985992


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 53 loss 0.532764196395874


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 54 loss 0.5205612182617188


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 55 loss 0.5389031648635865


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 56 loss 0.5136647760868073


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 57 loss 0.5352842926979064


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 58 loss 0.5250120222568512


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 59 loss 0.5194798827171325


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 60 loss 0.5156126320362091


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 61 loss 0.5220927000045776


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 62 loss 0.5189361572265625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 63 loss 0.5009596049785614


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 64 loss 0.5109074950218201


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 65 loss 0.4991436779499054


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 66 loss 0.4893259644508362


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 67 loss 0.5247340083122254


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 68 loss 0.5280372977256775


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 69 loss 0.5094399094581604


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 70 loss 0.48632999658584597


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 71 loss 0.519056636095047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 72 loss 0.5046375632286072


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 73 loss 0.527327299118042


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 74 loss 0.5175553798675537


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 75 loss 0.5125125765800476


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 76 loss 0.4977217197418213


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 77 loss 0.5145191431045533


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 78 loss 0.5244516253471374


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 79 loss 0.49879767894744875


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 80 loss 0.5183002591133118


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 81 loss 0.5125202596187591


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 82 loss 0.5353392124176025


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 83 loss 0.5246192216873169


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 84 loss 0.5226037502288818


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 85 loss 0.5076332211494445


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 86 loss 0.5210785567760468


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 87 loss 0.5247372150421142


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 88 loss 0.5009110510349274


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 89 loss 0.5120339930057526


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 90 loss 0.5357869744300843


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 91 loss 0.5069378793239594


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 92 loss 0.5101182043552399


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 93 loss 0.5188575387001038


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 94 loss 0.512148916721344


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 95 loss 0.515039587020874


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 96 loss 0.5075254678726197


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 97 loss 0.5023817300796509


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 98 loss 0.5025190472602844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

epoch 99 loss 0.5010765790939331
