# IMPORT REQUIRED LIBRARIES

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import torch.optim as optim
from PIL import Image
import time

# LOAD DATASET AND NORMALIZE

In [3]:
mean, std = 0.1307, 0.3081

train_dataset = MNIST('./Datasets/MNIST', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((mean,), (std,))
                             ]))
test_dataset = MNIST('./Datasets/MNIST', train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((mean,), (std,))
                            ]))

# TRIPLET LOSS FUNCTION

In [4]:
class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


# CUSTOMIZE MNIST TO BE TRIPLETS

In [5]:
class TripletMNIST(Dataset):
    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}

        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            # generate fixed triplets for testing
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            triplets = [[i,
                         random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                         random_state.choice(self.label_to_indices[
                                                 np.random.choice(
                                                     list(self.labels_set - set([self.test_labels[i].item()]))
                                                 )
                                             ])
                         ]
                        for i in range(len(self.test_data))]
            self.test_triplets = triplets

    def __getitem__(self, index):
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            img1 = self.test_data[self.test_triplets[index][0]]
            img2 = self.test_data[self.test_triplets[index][1]]
            img3 = self.test_data[self.test_triplets[index][2]]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        return len(self.mnist_dataset)


# NEWORK ARCHITECTURE

In [6]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 2)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

    def get_embedding(self, x):
        return self.forward(x)

# TRIPLET WRAPPER

In [7]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2=None, x3=None):
        if x2 is None and x3 is None:
            return self.embedding_net(x1)
        return self.embedding_net(x1),self.embedding_net(x2),self.embedding_net(x3)

    def get_embedding(self, x):
        return self.embedding_net(x)

# SETUPS

In [8]:
triplet_train_dataset = TripletMNIST(train_dataset) 
triplet_test_dataset = TripletMNIST(test_dataset)
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
triplet_train_loader = torch.utils.data.DataLoader(triplet_train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   **kwargs)
triplet_test_loader = torch.utils.data.DataLoader(triplet_test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  **kwargs)
margin = 1.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)
if cuda:
    model.cuda()
loss_fn = TripletLoss(margin)
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 20
log_interval = 100



# TRAIN

In [13]:
def fit(train_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=[],
        start_epoch=0):
    for epoch in range(0, start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, n_epochs):
        scheduler.step()
        train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)
        message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())
        print(message)


def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics):
    for metric in metrics:
        metric.reset()

    model.train()
    losses = []
    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        target = target if len(target) > 0 else None
        if not type(data) in (tuple, list):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if target is not None:
                target = target.cuda()


        optimizer.zero_grad()
        outputs = model(*data)

        if type(outputs) not in (tuple, list):
            outputs = (outputs,)

        loss_inputs = outputs
        if target is not None:
            target = (target,)
            loss_inputs += target

        loss_outputs = loss_fn(*loss_inputs)
        loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        for metric in metrics:
            metric(outputs, target, loss_outputs)

        if batch_idx % log_interval == 0:
            message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data[0]), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(losses))
            for metric in metrics:
                message += '\t{}: {}'.format(metric.name(), metric.value())

            print(message)
            losses = []

    total_loss /= (batch_idx + 1)
    return total_loss, metrics

In [14]:
fit(triplet_train_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval)

Epoch: 1/20. Train set: Average loss: 0.1284
Epoch: 2/20. Train set: Average loss: 0.0508
Epoch: 3/20. Train set: Average loss: 0.0315
Epoch: 4/20. Train set: Average loss: 0.0239
Epoch: 5/20. Train set: Average loss: 0.0221
Epoch: 6/20. Train set: Average loss: 0.0166
Epoch: 7/20. Train set: Average loss: 0.0088
Epoch: 8/20. Train set: Average loss: 0.0057
Epoch: 9/20. Train set: Average loss: 0.0049
Epoch: 10/20. Train set: Average loss: 0.0043
Epoch: 11/20. Train set: Average loss: 0.0033
Epoch: 12/20. Train set: Average loss: 0.0035
Epoch: 13/20. Train set: Average loss: 0.0035
Epoch: 14/20. Train set: Average loss: 0.0027
Epoch: 15/20. Train set: Average loss: 0.0024
Epoch: 16/20. Train set: Average loss: 0.0021
Epoch: 17/20. Train set: Average loss: 0.0022
Epoch: 18/20. Train set: Average loss: 0.0019
Epoch: 19/20. Train set: Average loss: 0.0020
Epoch: 20/20. Train set: Average loss: 0.0026


# VALIDATE

In [16]:
torch.save(model,"Models/tripletMNIST.pt")

In [17]:
model_loaded = torch.load("Models/tripletMNIST.pt")

In [18]:
import torch
import torch.nn.functional as F

def evaluate_model(model, triplet_test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        start = time.time()
        for (anchor, positive, negative),[] in triplet_test_loader:
            anchor_embedding,positive_embedding,negative_embedding = model(anchor,positive,negative)
            positive_distance = F.pairwise_distance(anchor_embedding, positive_embedding)
            negative_distance = F.pairwise_distance(anchor_embedding, negative_embedding)
            correct += torch.sum(positive_distance < negative_distance).item()
            total += anchor.size(0)
    accuracy = correct / total
    print('Accuracy : {:.2f}%\nTime : {:.2f} SECONDS'.format(accuracy * 100,time.time()-start))

In [19]:
evaluate_model(model_loaded,triplet_test_loader)

Accuracy : 99.67%
Time : 7.93 SECONDS


# WITH FACEBOOK AI SIMILARITY SEARCH

In [20]:
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=True, **kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **kwargs)

In [25]:
import faiss

In [26]:
embs1 = None
labels1 = []
for idx,i in enumerate(train_loader):
    if idx==10000: break
    I, L = i
    labels1.append(L)
    emb = model_loaded(I) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if embs1 is None:
        embs1 = emb
    else:
        embs1 = torch.cat((embs1, emb), dim=0)

In [27]:
embs2 = None
labels2 = []
for i in test_loader:
    I, L = i
    labels2.append(L)
    emb = model_loaded(I)
    if embs2 is None:
        embs2 = emb
    else:
        embs2 = torch.cat((embs2, emb), dim=0)

In [29]:
embs = embs1

In [44]:
index1 = faiss.IndexFlatL2(embs.shape[1])  # Assuming embs.shape[1] represents the dimensionality of the embeddings
index1.add(embs)

nlist = 100  # Number of cells/buckets
quantizer = faiss.IndexFlatL2(embs.shape[1])  # Quantizer index (same as IndexFlatL2)
index2 = faiss.IndexIVFFlat(quantizer, embs.shape[1], nlist)
index2.train(embs)
index2.add(embs)

index3 = faiss.IndexHNSWFlat(embs.shape[1], 32)  # M = 32 for the HNSW index
index3.add(embs)

nbits = 8  # Number of bits for the LSH hash
index4 = faiss.IndexLSH(embs.shape[1], nbits)
index4.add(embs)


# EVALUATE WITH FAISS

In [31]:
def evaluatewithfaiss(embs,index):
    TOTAL = len(embs)
    CORRECT = 0
    start = time.time()
    for idx,emb in enumerate(embs):
        label = index.search(emb.detach().reshape(1,-1),1)[1][0][0]
        CORRECT += labels1[label]==labels2[idx]
    return (CORRECT/TOTAL*100).item(),f'{time.time()-start} SECONDS'
        

In [32]:
print(f'IndexFlatL2 : {evaluatewithfaiss(embs2,index1)}')
print(f'IndexIVFFlat : {evaluatewithfaiss(embs2,index2)}')
print(f'IndexHNSWFlat : {evaluatewithfaiss(embs2,index3)}')
print(f'IndexLSH : {evaluatewithfaiss(embs2,index4)}')

IndexFlatL2 : (98.8499984741211, '0.8810124397277832 SECONDS')
IndexIVFFlat : (98.81999969482422, '0.41861629486083984 SECONDS')
IndexHNSWFlat : (98.8499984741211, '0.5943887233734131 SECONDS')
IndexLSH : (87.05000305175781, '0.948662281036377 SECONDS')
