In [7]:
import torch 
from torchvision import datasets, transforms
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import os
import random
from PIL import Image
import numpy as np
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

In [8]:
class MNISTpairs(data.Dataset):
    """
    Wrap an MNIST-style dataset so that every item is a pair of images with a
    binary similarity label: 1 for a positive pair (same digit) and 0 for a
    negative pair (different digits).

    mnist_dataset: dataset object exposing `transform`, the per-sample labels
        and the raw image tensor. Labels/images are read from `targets`/`data`
        when present, falling back to the older `train_*`/`test_*` attribute
        names otherwise.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.transform = self.mnist_dataset.transform
        # prefer the current attribute names; the train_*/test_* spellings are
        # only kept as a fallback for datasets that do not provide them
        self.labels = self._first_attr(mnist_dataset, ('targets', 'train_labels', 'test_labels'))
        self.data = self._first_attr(mnist_dataset, ('data', 'train_data', 'test_data'))
        # indices of images for each class (digits 0..9)
        labels_np = np.asarray(self.labels)
        self.class_idx = [np.where(labels_np == x)[0] for x in range(0, 10)]
        # classes that actually contain samples; negatives are only drawn from these
        self.nonempty_classes = np.array(
            [c for c, members in enumerate(self.class_idx) if len(members) > 0], dtype=int)

    @staticmethod
    def _first_attr(obj, names):
        """Return the value of the first attribute in `names` that `obj` has."""
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        raise AttributeError('dataset exposes none of: {}'.format(', '.join(names)))

    def _positive_index(self, index, label):
        """Draw the index of another image of class `label`, avoiding `index` when possible."""
        candidates = self.class_idx[label]
        if len(candidates) < 2:
            # the anchor is the only member of its class: rejection sampling
            # would never terminate, so pair the anchor with itself
            return index
        index2 = index
        while index2 == index:
            index2 = np.random.choice(candidates)
        return index2

    def _negative_index(self, label):
        """Draw the index of an image whose class differs from `label`."""
        other_classes = self.nonempty_classes[self.nonempty_classes != label]
        if len(other_classes) == 0:
            raise ValueError('cannot draw a negative pair: no other class has samples')
        return np.random.choice(self.class_idx[np.random.choice(other_classes)])

    def __getitem__(self, index):
        """
        Return ((img1, img2), pair_label) where img1 is the image at `index`
        and img2 is a randomly drawn positive (pair_label 1) or negative
        (pair_label 0) partner. Both images go through `self.transform`
        when it is set.
        """
        # anchor image
        img1, label1 = self.data[index], int(self.labels[index])
        # draw another positive (1) or negative (0) image
        pair_label = np.random.randint(0, 2)

        if pair_label == 1:
            index2 = self._positive_index(index, label1)
        else:
            index2 = self._negative_index(label1)
        img2 = self.data[index2]

        # convert to PIL images so the transform pipeline can be applied
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return (img1, img2), pair_label

    def __len__(self):
        """Number of anchor images, i.e. the size of the wrapped dataset."""
        return len(self.data)
        
    
# location of the MNIST files on disk, shared by both splits
_MNIST_ROOT = 'vs3ex1data/mnist_data'


def _mnist_transform():
    """Build the preprocessing applied to every image: ToTensor, then Normalize((0.1307,), (0.3081,))."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])


# test split (train=False), loaded as single images
mnist_dataset_test = datasets.MNIST(_MNIST_ROOT, train=False, transform=_mnist_transform())

# train split (train=True), loaded as single images
mnist_dataset_train = datasets.MNIST(_MNIST_ROOT, train=True, transform=_mnist_transform())

# train split re-exposed as positive/negative image pairs
mnist_dataset_train_pairs = MNISTpairs(mnist_dataset_train)



In [9]:
def l2n(x, eps=1e-6):
    """
    L2-normalize the vectors stored along dim 1 of `x`.

    x: tensor whose dim 1 holds the vector components
    eps: constant added to each norm before dividing, so that all-zero
        vectors do not cause a division by zero
    Returns a tensor of the same shape as `x`.
    """
    # keepdim=True leaves a size-1 dim 1, so the division broadcasts over it
    row_norms = x.norm(p=2, dim=1, keepdim=True)
    return x / (row_norms + eps)

class MnistNetEmb(nn.Module):
    """
    Lightweight network for MNIST digits that extracts descriptors/embeddings.

    Two conv(5x5) -> max-pool(2) -> ReLU stages with 4 channels each feed a
    single linear layer producing 16-dimensional descriptors, which are
    L2-normalized before being returned.
    """

    def __init__(self):
        super().__init__()

        # fully convolutional part
        conv_stages = [
            nn.Conv2d(1, 4, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 4, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*conv_stages)

        # embedding network, FC layers: 64 flattened features
        # (4 channels x 4 x 4, e.g. for 28x28 inputs) -> 16-d descriptor
        self.embedder = nn.Sequential(nn.Linear(16*4, 16))

    def forward(self, x):
        """Return one L2-normalized 16-d descriptor per input image."""
        feature_maps = self.features(x)
        # flatten channels x height x width into a single feature axis
        flat_dim = feature_maps.size(-3) * feature_maps.size(-2) * feature_maps.size(-1)
        descriptors = self.embedder(feature_maps.view(-1, flat_dim))
        return l2n(descriptors)
    

In [10]:
def train2(model, train_loader, optimizer, margin = 0.9):
    """
    Training of an epoch with Contrastive loss and training pairs
    model: network mapping a batch of images to one descriptor per image
    train_loader: train_loader loading pairs of positive/negative images and pair-label in batches,
        i.e. yielding ((img1, img2), pair_label) with pair_label 1 for positive and 0 for negative pairs.
        It must support len() and expose a `.dataset` attribute.
    optimizer: optimizer to use in the training
    margin: loss margin; negative pairs closer than this distance are penalized

    Prints the batch loss every 100 batches and the epoch average loss, then shows a
    histogram of the descriptor distances of the positive and negative pairs.
    """

    model.train()
    pos_dist_chunks = []  # per-batch distances of positive pairs, for statistics
    neg_dist_chunks = []  # per-batch distances of negative pairs, for statistics
    total_loss = 0.0
    sqrt_eps = 1e-12  # lower bound on the squared distance before taking the square root

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        # extract descriptor for anchor and the corresponding pos/neg images
        v1, v2 = model(data[0]), model(data[1])

        # compute the contrastive loss
        is_pos = target.float()
        sq_distances = (v2 - v1).pow(2).sum(1)
        # the derivative of sqrt at 0 is infinite and turns the gradients into NaN
        # as soon as one pair has identical descriptors, hence the clamp
        distances = sq_distances.clamp(min=sqrt_eps).sqrt()
        # the positive term uses the squared distance directly, so it never goes through the sqrt
        loss = 0.5 * (is_pos * sq_distances + (1 - is_pos) * F.relu(margin - distances).pow(2))

        batch_loss = loss.sum()
        batch_loss.backward()
        optimizer.step()

        # detach so the stored statistics do not keep the autograd graph of every batch alive
        batch_distances = distances.detach()
        neg_dist_chunks.append(batch_distances[target == 0].cpu())
        pos_dist_chunks.append(batch_distances[target != 0].cpu())
        total_loss = total_loss + batch_loss.item()

        if batch_idx % 100 == 0:
            # `data` is the (img1, img2) pair, so the batch size has to come from `target`
            print('[{}/{} ({:.0f}%)]\tBatch loss: {:.6f}'.format(
                batch_idx * len(target), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.mean().item()))
    print('Epoch average loss {:.6f}'.format(total_loss/max(len(train_loader.dataset), 1)))

    # concatenate once at the end instead of growing a tensor on every batch
    all_pos_dist = torch.cat(pos_dist_chunks) if pos_dist_chunks else torch.Tensor()
    all_neg_dist = torch.cat(neg_dist_chunks) if neg_dist_chunks else torch.Tensor()

    plt.hist(all_pos_dist.numpy(), 20, alpha = 0.5, label = 'pos')
    plt.hist(all_neg_dist.numpy(), 20, alpha = 0.5, label = 'neg')
    plt.title('Distribution of distances for positive and negative pairs')
    plt.legend(loc='upper right')
    plt.show()

    
def test(model, test_loader, step = 100):
    """
    Compute accuracy on the test set
    model: network
    test_loader: test_loader loading images and labels in batches
    step: step to iterate over images (for faster evaluation)
    """

    model.eval()
    des = torch.Tensor()
    labels = torch.LongTensor()
    for batch_idx, (data, target) in enumerate(test_loader):
        des = torch.cat((des, model(data)))
        labels = torch.cat((labels, target))
    
    # compute all pair-wise distances
    cdistances = cdist(des.data.numpy(), des.data.numpy(), 'euclidean')
        
    # find rank of closest positive image (using each descriptor as a query)
    minrank_positive = []
    for i in range(0,len(cdistances),step):
        idx = np.argsort(cdistances[i])
        minrank_positive.append( np.min([j for (j,x) in enumerate(labels[idx[1:-1]]) if x==labels[i]]) )
    
    print('Validation: At-least-1-pos@1 {:.3f}'.format((np.array(minrank_positive) <1).mean()))
    print('Validation: At-least-1-pos@3 {:.3f}'.format((np.array(minrank_positive) <3).mean()))


In [11]:
# training batches: 64 image pairs each, reshuffled on every pass
train_loader_pairs = torch.utils.data.DataLoader(
    dataset=mnist_dataset_train_pairs,
    batch_size=64,
    shuffle=True,
)
# test batches: 512 single images each (no pairs here), kept in dataset order
test_loader = torch.utils.data.DataLoader(
    dataset=mnist_dataset_test,
    batch_size=512,
    shuffle=False,
)

# initialize the network
model = MnistNetEmb()

# baseline: evaluate the randomly initialized network before any training
print('Validating the randomly initialized network')
test(model, test_loader)


Validating the randomly initialized network


KeyboardInterrupt: 

In [None]:
# SGD with momentum over all network parameters
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

# margin of the contrastive loss and number of passes over the training pairs
contrastive_margin = 0.9
num_epochs = 10

print('Training with Contrastive loss and training pairs')
# train with contrastive loss, validating on the test set after every epoch
for epoch in range(1, num_epochs + 1):
    print('Epoch {}'.format(epoch))
    train2(model, train_loader_pairs, optimizer, contrastive_margin)
    test(model, test_loader)