
Mounting Google Drive in Colab and importing libraries

In [None]:
# Mount Google Drive at /content/drive (only works inside a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [None]:
from os import walk
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
import cv2
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import random
import pickle

Loading the category dictionary (keys are class labels, values are lists of image file names)

In [None]:
data_dir = '/content/drive/My Drive/Colab Notebooks/WaDaBa/'
# Load the mapping class label -> list of image file names (relative to data_dir).
# The `with` block closes the file handle once loading is done.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a category.pkl you created yourself.
with open(data_dir + "category.pkl", "rb") as a_file:
    categories = pickle.load(a_file)
# categories.pop(7)  # optional: exclude class 7 from training and evaluation
categories.keys()

dict_keys([1, 5, 2, 6, 7])


Dataset class that returns random (anchor, positive, negative) image triplets

In [None]:
class WadabaDataset(Dataset):
    """Map-style dataset that yields random (anchor, positive, negative) triplets.

    ``idx`` is ignored: every ``__getitem__`` call samples a fresh triplet from
    the module-level ``categories`` mapping (class -> list of image file names
    relative to ``data_dir``). Anchor and positive are two different images of
    one class; the negative is an image of a different class.

    Args:
        setSize: value reported by ``__len__`` (triplets per epoch).
        transform: optional callable applied to each PIL image.
    """

    def __init__(self, setSize, transform=None):
        self.transform = transform
        self.setSize = setSize

    def __len__(self):
        return self.setSize

    def _load(self, name):
        """Open ``data_dir + name`` as RGB, close the file, and apply the transform."""
        # PIL opens files lazily and would otherwise keep the handle open until
        # garbage collection. convert('RGB') forces the pixel load and guarantees
        # the 3 channels the network and the Normalize transform expect.
        with Image.open(data_dir + name) as img:
            img = img.convert('RGB')
        return self.transform(img) if self.transform else img

    def __getitem__(self, idx):
        # random.sample draws without replacement, so the two classes and the
        # two same-class images are always distinct. With fewer than two
        # classes, or a class with a single image, it raises ValueError instead
        # of looping forever.
        category1, category2 = random.sample(list(categories), 2)
        anchor_name, pos_name = random.sample(categories[category1], 2)
        neg_name = random.choice(categories[category2])
        return self._load(anchor_name), self._load(pos_name), self._load(neg_name)

Dataset class for N-way one-shot evaluation

In [None]:
class NWayOneShotEvalSet(Dataset):
    """Dataset of N-way one-shot episodes, N = number of classes in ``categories``.

    ``idx`` is ignored: every access draws a fresh episode. An episode has:
      * a random query image (``mainImg``);
      * one candidate image per class, in ``categories`` key order (``testSet``);
      * a label that is the index of the candidate sharing the query's class.

    Args:
        setSize: value reported by ``__len__`` (number of episodes).
        transform: optional callable applied to each PIL image.

    Returns (per item):
        ``(mainImg, testSet, label)``, where ``label`` is a length-1 integer tensor.
    """

    def __init__(self, setSize, transform=None):
        self.setSize = setSize
        self.transform = transform

    def __len__(self):
        return self.setSize

    def _load(self, name):
        """Open ``data_dir + name`` as RGB, close the file, and apply the transform."""
        # Close the lazily-opened file right away; convert('RGB') loads the
        # pixels and guarantees 3 channels.
        with Image.open(data_dir + name) as img:
            img = img.convert('RGB')
        return self.transform(img) if self.transform else img

    def __getitem__(self, idx):
        classes = list(categories)

        # Pick the query image.
        category = random.choice(classes)
        imgName = random.choice(categories[category])
        mainImg = self._load(imgName)

        # One candidate per class, in key order.
        testSet = []
        for j in classes:
            pool = categories[j]
            if j == category:
                # The matching candidate must be a *different* image of the
                # query's class. Otherwise the query could be compared with
                # itself (distance 0), which inflates accuracy. The full pool is
                # used only when the class has no other image.
                pool = [name for name in pool if name != imgName] or pool
            testSet.append(self._load(random.choice(pool)))

        label = classes.index(category)
        return mainImg, testSet, torch.from_numpy(np.array([label], dtype=int))

Custom triplet loss function

In [None]:
def loss_fn(anchor_emb, pos_emb, neg_emb, margin):
    """Return the batch-mean triplet margin loss on squared Euclidean distances.

    For each row i the loss is ``max(0, |a_i - p_i|^2 - |a_i - n_i|^2 + margin)``.
    The squared differences are summed over dimension 1.
    """
    def squared_distance(other):
        return (anchor_emb - other).pow(2).sum(1)

    hinge = torch.relu(squared_distance(pos_emb) - squared_distance(neg_emb) + margin)
    return hinge.mean()

In [None]:
import torch.nn as nn
import torch.nn.functional as F


Creating the network

In [None]:
class Net(nn.Module):
    """Embedding CNN based on the Koch et al. convolutional stack.

    Maps a ``(batch, 3, 105, 105)`` image tensor to a ``(batch, 128)`` embedding.
    Every entry lies in (0, 1) because of the final sigmoid.
    """

    def __init__(self):
        super(Net, self).__init__()

        # Koch et al.
        # Conv2d(input_channels, output_channels, kernel_size)
        # NOTE: keep the registration order below unchanged. It fixes the order
        # of model.parameters(), which saved optimizer state depends on.
        self.conv1 = nn.Conv2d(3, 64, 10)
        self.conv2 = nn.Conv2d(64, 128, 7)
        self.conv3 = nn.Conv2d(128, 128, 4)
        self.conv4 = nn.Conv2d(128, 256, 4)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        # Registered but not used in forward(); they hold no parameters.
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 128)
        self.sigmoid = nn.Sigmoid()

    def convs(self, x):
        """Apply the four conv+BN+ReLU stages (max-pooling after the first three)."""
        # Koch et al.
        # out_dim = in_dim - kernel_size + 1
        # 3, 105, 105
        x = F.relu(self.bn1(self.conv1(x)))
        # 64, 96, 96
        x = F.max_pool2d(x, (2, 2))
        # 64, 48, 48
        x = F.relu(self.bn2(self.conv2(x)))
        # 128, 42, 42
        x = F.max_pool2d(x, (2, 2))
        # 128, 21, 21
        x = F.relu(self.bn3(self.conv3(x)))
        # 128, 18, 18
        x = F.max_pool2d(x, (2, 2))
        # 128, 9, 9
        x = F.relu(self.bn4(self.conv4(x)))
        # 256, 6, 6
        return x

    def forward(self, x1):
        """Return the sigmoid-activated 128-d embedding of the image batch ``x1``."""
        x1 = self.convs(x1)

        # Koch et al.
        # Flatten each sample separately. view(-1, 9216) would silently re-batch
        # inputs of the wrong spatial size; this makes them fail in fc1 instead.
        x1 = x1.flatten(1)
        x1 = self.sigmoid(self.fc1(x1))
        x1 = self.sigmoid(self.fc2(x1))

        return x1


Creating the network and counting its parameters

In [None]:
# Use the first GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# Baseline triplet embedding network, moved to the selected device.
Triplet_Baseline = Net()
Triplet_Baseline = Triplet_Baseline.to(device)

def count_parameters(model):
    """Print ``model``'s architecture, then the number of parameters that require gradients."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    print('The model architecture:\n\n', model)
    print('\nThe model has {:,} trainable parameters'.format(n_trainable))
    
# Print the baseline model's architecture and trainable-parameter count.
count_parameters(Triplet_Baseline)

cuda:0
The model architecture:

 Net(
  (conv1): Conv2d(3, 64, kernel_size=(10, 10), stride=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(7, 7), stride=(1, 1))
  (conv3): Conv2d(128, 128, kernel_size=(4, 4), stride=(1, 1))
  (conv4): Conv2d(128, 256, kernel_size=(4, 4), stride=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=128, bias=True)
  (sigmoid): Sigmoid()
)

The model has 39,486,016 trainable parameters



Saving and loading checkpoints

In [None]:
def save_checkpoint(save_path, model, optimizer, val_loss):
    """Save model and optimizer state plus the validation loss to ``save_path``.

    Does nothing when ``save_path`` is None, so callers can disable saving.
    The file holds a dict with keys 'model_state_dict', 'optimizer_state_dict'
    and 'val_loss'.
    """
    if save_path is None:
        return
    state_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')

def load_checkpoint(model, optimizer, save_path=None):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: module whose state_dict is overwritten in place.
        optimizer: optimizer whose state_dict is overwritten in place.
        save_path: checkpoint file. Defaults to
            ``data_dir + 'Weights/temp-TripletNet2-batchnorm50.pt'``.

    Returns:
        The 'val_loss' value stored in the checkpoint.
    """
    if save_path is None:
        save_path = data_dir + 'Weights/temp-TripletNet2-batchnorm50.pt'
    if torch.cuda.is_available():
        state_dict = torch.load(save_path)
    else:
        # Remap any GPU-saved tensors so the file loads on a CPU-only machine.
        state_dict = torch.load(save_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict['model_state_dict'])
    optimizer.load_state_dict(state_dict['optimizer_state_dict'])
    val_loss = state_dict['val_loss']
    print(f'Model loaded from <== {save_path}')

    return val_loss

Initializing the training and validation sets

In [None]:
# choose a training dataset size and further divide it into train and validation set 80:20
dataSize = 4000 # self-defined dataset size (number of random triplets per epoch)
TRAIN_PCT = 0.8 # fraction of the dataset used for training
train_size = int(dataSize * TRAIN_PCT)
val_size = dataSize - train_size

# Resize to the 105x105 input size Net expects, convert to a tensor, and
# normalize with the standard ImageNet per-channel mean/std.
transformations = transforms.Compose([
        transforms.Resize((105,105)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

wadabadataset = WadabaDataset(dataSize, transformations)
# NOTE(review): WadabaDataset ignores idx and samples from every image in
# `categories`, so random_split only splits the index range. Train and
# validation triplets come from the same images, so the validation loss is not
# measured on held-out data.
train_set, val_set = random_split(wadabadataset, [train_size, val_size])
# Shuffling is irrelevant here because every item is already randomly sampled.
# NOTE(review): 16 workers may exceed the CPUs of the runtime; check os.cpu_count().
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, num_workers=16)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=20, num_workers=16, shuffle=True)

Training, with validation after every epoch

In [None]:
# training and validation after every epoch
import time

def _triplet_batch_loss(model, batch, margin):
    """Move one (anchor, positive, negative) batch to `device`, embed it, and return the triplet loss."""
    anchor_img, pos_img, neg_img = (t.to(device) for t in batch)
    # Embeddings are computed in anchor, positive, negative order.
    return loss_fn(model(anchor_img), model(pos_img), model(neg_img), margin)


def train(model, train_loader, val_loader, num_epochs, save_name,
          margin=0.8, best_val_loss=float("Inf")):
    """Train `model` with the triplet loss, validating after every epoch.

    Uses the module-level `optimizer`, `device`, `loss_fn` and `save_checkpoint`.
    After each epoch it prints the mean train and validation losses and the
    epoch time. It writes a checkpoint to `save_name` whenever the validation
    loss beats the best seen so far.

    Args:
        model: embedding network applied to each image batch.
        train_loader: loader yielding (anchor, positive, negative) batches.
        val_loader: loader yielding (anchor, positive, negative) batches.
        num_epochs: number of epochs to run.
        save_name: checkpoint path passed to save_checkpoint (None disables saving).
        margin: triplet-loss margin.
        best_val_loss: validation loss that must be beaten before the first save.
            When resuming, pass the value returned by load_checkpoint. Otherwise
            the first epoch always overwrites the loaded checkpoint, even if it
            was better.

    Returns:
        (train_losses, val_losses): lists of per-epoch mean losses. The loss
        history is only returned, not written to any global.
    """
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        start_time = time.time()
        print("Starting epoch " + str(epoch+1))

        # Training pass.
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            loss = _triplet_batch_loss(model, batch, margin)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation pass (no gradients, BatchNorm uses running statistics).
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                val_running_loss += _triplet_batch_loss(model, batch, margin).item()
        avg_val_loss = val_running_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        print('Epoch [{}/{}],Train Loss: {:.4f}, Valid Loss: {:.8f}'
              .format(epoch+1, num_epochs, avg_train_loss, avg_val_loss))
        print("Time taken for epoch = ",(time.time()-start_time))
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_checkpoint(save_name, model, optimizer, best_val_loss)

    print("Finished Training")
    return train_losses, val_losses

Training the network

In [None]:
import torch.optim as optim

num_epochs = 30
save_path = data_dir+'Weights/temp-TripletNet2-batchnorm50.pt'
# Training from scratch (kept for reference):
#optimizer = optim.SGD(Triplet_Baseline.parameters(), lr=0.01, momentum=0.9)
#train_losses, val_losses = train(Triplet_Baseline, train_loader, val_loader, num_epochs, save_path)
# Resume: build a fresh Net, restore its weights and optimizer state from the
# checkpoint, then keep training. train() reads this global `optimizer`.
load_model = Net().to(device)
# NOTE(review): optimizer.load_state_dict restores the saved param_groups, so
# lr/momentum probably come from the checkpoint rather than the values below. Confirm.
optimizer = optim.SGD(load_model.parameters(), lr=0.001,momentum=0.9)
# NOTE(review): the checkpoint's val loss returned here is discarded. train()
# therefore starts from best = inf, and its first epoch overwrites this
# checkpoint even if the new val loss is worse.
load_checkpoint(load_model,optimizer)
train_losses, val_losses = train(load_model, train_loader, val_loader, num_epochs,save_path)

Model loaded from <== /content/drive/My Drive/Colab Notebooks/WaDaBa/Weights/temp-TripletNet2-batchnorm50.pt
Starting epoch 1
Epoch [1/30],Train Loss: 0.0062, Valid Loss: 0.00267426
Time taken for epoch =  579.0906186103821
Model saved to ==> /content/drive/My Drive/Colab Notebooks/WaDaBa/Weights/temp-TripletNet2-batchnorm50.pt
Starting epoch 2
Epoch [2/30],Train Loss: 0.0085, Valid Loss: 0.00582511
Time taken for epoch =  543.7220079898834
Starting epoch 3
Epoch [3/30],Train Loss: 0.0080, Valid Loss: 0.00500970
Time taken for epoch =  536.4187712669373
Starting epoch 4
Epoch [4/30],Train Loss: 0.0107, Valid Loss: 0.02155743
Time taken for epoch =  534.584451675415
Starting epoch 5
Epoch [5/30],Train Loss: 0.0066, Valid Loss: 0.00427334
Time taken for epoch =  533.8446831703186
Starting epoch 6
Epoch [6/30],Train Loss: 0.0072, Valid Loss: 0.01187497
Time taken for epoch =  535.5806725025177
Starting epoch 7
Epoch [7/30],Train Loss: 0.0077, Valid Loss: 0.00648856
Time taken for epoch = 

Evaluating the Model

In [None]:
# evaluation metrics
def eval(model, test_loader):
    """Measure the N-way one-shot accuracy of an embedding model.

    For every (query, candidates, label) sample from `test_loader`:
      * the query is embedded once;
      * it is compared with each candidate by squared Euclidean distance;
      * the prediction is the index of the closest candidate (the first one on ties).

    Works for any batch size. Running accuracy is printed every 20 samples, and
    the final accuracy is printed at the end. Uses the module-level `device`.

    NOTE: this shadows the built-in ``eval``; the name is kept for existing callers.

    Returns:
        Accuracy as a float in [0, 1] (0.0 if the loader is empty).
    """
    correct = 0
    count = 0
    with torch.no_grad():
        model.eval()
        print('Starting Iteration')
        for mainImg, imgSets, label in test_loader:
            # The query embedding does not depend on the candidate, so compute it once.
            main_emb = model(mainImg.to(device))
            # dists[k, b] = squared distance between query b and its k-th candidate.
            dists = torch.stack([
                torch.pow(main_emb - model(testImg.to(device)), 2).sum(1)
                for testImg in imgSets
            ])
            preds = dists.argmin(dim=0)
            targets = label.to(device).view(-1)
            # Count per sample so progress reporting is independent of batch size.
            for hit in (preds == targets).tolist():
                correct += int(hit)
                count += 1
                if count % 20 == 0:
                    print("Current Count is: {}".format(count))
                    print('Accuracy on n way: {}'.format(correct/count))
    accuracy = correct / count if count else 0.0
    print('Final accuracy on n way: {}'.format(accuracy))
    return accuracy

In [None]:
# 200 random N-way one-shot episodes, one episode per batch.
testSize = 200
test_set = NWayOneShotEvalSet(testSize,transformations)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1, num_workers = 2, shuffle=True)

In [None]:
import torch.optim as optim
load_model = Net().to(device)
# Needed only because load_checkpoint also restores optimizer state; eval()
# does not use it.
load_optimizer = optim.SGD(load_model.parameters(), lr=0.0005)


# Restore the weights from load_checkpoint's default checkpoint path.
best_val_loss = load_checkpoint(load_model, load_optimizer)

print(best_val_loss)
eval(load_model, test_loader)