<a href="https://colab.research.google.com/github/Amaljayaranga/ContrastiveLoss/blob/master/HardMining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from itertools import combinations
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from argparse import ArgumentParser
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F

parser = ArgumentParser(description='Simase Network')
parser.add_argument('--learning_batch_size', type=int, default=16)
parser.add_argument('--fc_in_features', type=int, default=512)
parser.add_argument('--fc_out_features', type=int, default=64)
parser.add_argument('--constractive_loss_margin', type=float, default=1.0)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--weight_decay', type=float, default=1e-5)
parser.add_argument('--validation_split', type=float, default=0.2)
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--device', type=str, default='cpu')


args, unknown = parser.parse_known_args()

DEVICE = args.device
if not torch.cuda.is_available():
    DEVICE = 'cpu'

train_dataset = MNIST('../data/MNIST', train=True, download=True,
                             transform = transforms.Compose([
                                 transforms.ToTensor()
                             ]))

test_dataset = MNIST('../data/MNIST', train=False, download=True,
                            transform = transforms.Compose([
                                transforms.ToTensor()
                            ]))


split = int(np.floor(args.validation_split * len(train_dataset)))
indices = list(range(len(train_dataset)))
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

dataloaders =[]
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = args.learning_batch_size,
                                                   sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(train_dataset, batch_size = args.learning_batch_size,
                                                  sampler=valid_sampler)
dataloaders.append(train_loader)
dataloaders.append(val_loader)


def pdist(vectors):
    distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
        dim=1).view(-1, 1)
    return distance_matrix


class HardNegativePairSelector():

    def __init__(self):
        super(HardNegativePairSelector, self).__init__()

    def get_pairs(self, embeddings, labels):

        distance_matrix = pdist(embeddings)

        labels = labels.cpu().data.numpy()
        all_pairs = np.array(list(combinations(range(len(labels)), 2)))
        all_pairs = torch.LongTensor(all_pairs)
        positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
        negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]

        negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
        negative_distances = negative_distances.cpu().data.numpy()
        top_negatives = np.argpartition(negative_distances, len(positive_pairs))[:len(positive_pairs)]
        top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)]

        return positive_pairs, top_negative_pairs

class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 2)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

    def get_embedding(self, x):
        return self.forward(x)


class SiameseNet(nn.Module):
    def __init__(self, embedding_net):
        super(SiameseNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2):
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        return output1, output2

    def get_embedding(self,x):
        return self.embedding_net(x)

class ContrastiveLoss(nn.Module):

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, target):
        eq_distance = F.pairwise_distance(output1, output2)
        loss = 0.5 * (1 - target.float()) * torch.pow(eq_distance, 2) + \
               0.5 * target.float() * torch.pow(torch.clamp(self.margin - eq_distance, min=0.00), 2)
        return loss.mean()

embedding_net = EmbeddingNet()
hardS = HardNegativePairSelector()

model = SiameseNet(embedding_net)
model = model.to(DEVICE)

if args.mode == 'train':

    criterion = ContrastiveLoss(margin=args.constractive_loss_margin)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    print('training started')
    minimum_loss = float('Inf')

    training_losses = []
    validation_losses = []
    epochs = []

    for epoch in range(1, args.num_epochs + 1):
        epochs.append(epoch)

        for dataloader in enumerate(dataloaders):
            stage = ''

            if dataloader == val_loader:
                model = model.eval()
            else:
                model = model.train()

            epoch_loss = []

            for dataloader_idx, dataloader in enumerate(dataloaders):

                if dataloader == train_loader:
                    for batch_idx, batch in enumerate(dataloader):
                        images, labels = batch[0], batch[1]
                        embeddings = model.get_embedding(images)
                        positive_pairs, top_negative_pairs = hardS.get_pairs(embeddings, labels)

                        for pair in positive_pairs:
                            print(embeddings[pair[0]])
                            loss = criterion(embeddings[pair[0]], embeddings[pair[1]], 0)
                            print(loss)


               

                    #loss.backward()
                    #optimizer.step()
                    #optimizer.zero_grad()

                #epoch_loss.append(loss.to('cpu').item())

            if dataloader == train_loader:
                training_losses.append(np.mean(epoch_loss))
                stage = 'train'
            else:
                validation_losses.append(np.mean(epoch_loss))
                stage = 'eval'

            if dataloader == train_loader:
                print(
                    f'epoch: {epoch} stage: {stage} loss: {np.mean(epoch_loss)}')
            else:
                print(f'epoch: {epoch} stage: {stage} loss: {np.mean(epoch_loss)}')

            if dataloader == val_loader:
                if minimum_loss > np.mean(validation_losses):
                    minimum_loss = np.mean(validation_losses)
                    torch.save(model.to('cpu'), 'simase_best.pt')
                    model = model.to('cuda')
                    print('Model is saving')

    plt.plot(epochs, training_losses, label="train")
    plt.plot(epochs, validation_losses, label="eval")
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.title('training vs validation loss')
    plt.legend()
    plt.show()

















Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ../data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/MNIST/raw
Processing...
Done!
training started
tensor([ 0.0228, -0.0902], grad_fn=<SelectBackward>)


IndexError: ignored