# IMPORT REQUIRED LIBRARIES

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import torch.optim as optim
from PIL import Image
import time

# LOAD DATASET AND NORMALIZE

In [2]:
mean, std = 0.1307, 0.3081

train_dataset = MNIST('./Datasets/MNIST', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((mean,), (std,))
                             ]))
test_dataset = MNIST('./Datasets/MNIST', train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((mean,), (std,))
                            ]))

# TRIPLET LOSS FUNCTION

In [3]:
class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()


# CUSTOMIZE MNIST TO BE TRIPLETS

In [4]:
class TripletMNIST(Dataset):
    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}

        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            # generate fixed triplets for testing
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            triplets = [[i,
                         random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                         random_state.choice(self.label_to_indices[
                                                 np.random.choice(
                                                     list(self.labels_set - set([self.test_labels[i].item()]))
                                                 )
                                             ])
                         ]
                        for i in range(len(self.test_data))]
            self.test_triplets = triplets

    def __getitem__(self, index):
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            img1 = self.test_data[self.test_triplets[index][0]]
            img2 = self.test_data[self.test_triplets[index][1]]
            img3 = self.test_data[self.test_triplets[index][2]]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        return len(self.mnist_dataset)


# NEWORK ARCHITECTURE

In [5]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, 2)
                                )

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

    def get_embedding(self, x):
        return self.forward(x)

# TRIPLET WRAPPER

In [6]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2=None, x3=None):
        if x2 is None and x3 is None:
            return self.embedding_net(x1)
        return self.embedding_net(x1),self.embedding_net(x2),self.embedding_net(x3)

    def get_embedding(self, x):
        return self.embedding_net(x)

# SETUPS

In [8]:
cuda = torch.cuda.is_available()
triplet_train_dataset = TripletMNIST(train_dataset) 
triplet_test_dataset = TripletMNIST(test_dataset)
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
triplet_train_loader = torch.utils.data.DataLoader(triplet_train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   **kwargs)
triplet_test_loader = torch.utils.data.DataLoader(triplet_test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  **kwargs)
margin = 1.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)
if cuda:
    model.cuda()
loss_fn = TripletLoss(margin)
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 20
log_interval = 100

# TRAIN

In [10]:
def fit(train_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=[],
        start_epoch=0):
    for epoch in range(0, start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, n_epochs):
        scheduler.step()
        train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)
        message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())
        print(message)


def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics):
    for metric in metrics:
        metric.reset()

    model.train()
    losses = []
    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        target = target if len(target) > 0 else None
        if not type(data) in (tuple, list):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if target is not None:
                target = target.cuda()


        optimizer.zero_grad()
        outputs = model(*data)
        print(outputs)
        if type(outputs) not in (tuple, list):
            outputs = (outputs,)

        loss_inputs = outputs
        if target is not None:
            target = (target,)
            loss_inputs += target

        loss_outputs = loss_fn(*loss_inputs)
        loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        for metric in metrics:
            metric(outputs, target, loss_outputs)

        if batch_idx % log_interval == 0:
            message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data[0]), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(losses))
            for metric in metrics:
                message += '\t{}: {}'.format(metric.name(), metric.value())

            print(message)
            losses = []

    total_loss /= (batch_idx + 1)
    return total_loss, metrics

In [11]:
fit(triplet_train_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval)



(tensor([[ 1.0531e-01,  3.4335e-03],
        [ 1.3939e-01, -2.1750e-02],
        [ 1.2765e-01, -7.6256e-03],
        [ 1.2166e-01,  2.2495e-02],
        [ 1.2138e-01, -2.0838e-02],
        [ 1.1317e-01, -4.8914e-03],
        [ 1.6028e-01,  1.8080e-02],
        [ 1.6362e-01, -3.2913e-04],
        [ 1.1883e-01,  7.8362e-03],
        [ 1.4764e-01, -8.3395e-03],
        [ 1.4234e-01, -2.3863e-02],
        [ 1.3170e-01, -1.4658e-02],
        [ 9.3261e-02, -2.7425e-02],
        [ 1.4065e-01,  1.9992e-02],
        [ 1.4056e-01, -1.2661e-02],
        [ 1.0236e-01, -7.7702e-03],
        [ 1.4352e-01, -6.7864e-03],
        [ 1.2343e-01,  1.9241e-03],
        [ 1.4363e-01, -6.5379e-03],
        [ 1.4508e-01, -1.8486e-02],
        [ 1.3632e-01, -2.0281e-02],
        [ 1.1683e-01, -2.2008e-02],
        [ 1.2362e-01, -7.3279e-03],
        [ 1.3669e-01, -1.5354e-02],
        [ 1.4627e-01,  2.1785e-02],
        [ 1.1324e-01,  2.2321e-03],
        [ 9.5048e-02, -1.6395e-02],
        [ 1.0324e-01, -1.45

(tensor([[ 0.4219,  0.1137],
        [ 0.1739,  0.0478],
        [ 0.9358, -0.5873],
        [ 0.3841,  0.2284],
        [ 0.9660, -0.6408],
        [ 0.7597, -0.6004],
        [ 0.9041, -0.7323],
        [ 0.4141,  0.2167],
        [ 0.1409,  0.3451],
        [ 0.5795,  0.0287],
        [ 0.1423,  0.2668],
        [ 0.6176, -0.2888],
        [ 0.6977, -0.5762],
        [ 0.4338, -0.2806],
        [ 0.3588,  0.0344],
        [ 0.7611, -0.1310],
        [ 0.4937, -0.0872],
        [ 0.9056, -0.8861],
        [ 0.5704, -0.1692],
        [ 0.3514, -0.0852],
        [ 0.5443, -0.0234],
        [ 0.4358,  0.1272],
        [ 0.7088, -0.5388],
        [ 0.4707,  0.0578],
        [ 0.4676,  0.0759],
        [ 0.5175,  0.1918],
        [ 0.9420, -0.4344],
        [ 0.5474,  0.1042],
        [ 0.4104,  0.2449],
        [ 0.3332,  0.1584],
        [ 0.2800,  0.0736],
        [ 0.8570, -0.7227],
        [ 0.2259, -0.0114],
        [ 0.1603,  0.0777],
        [ 0.7839, -0.2970],
        [ 0.3271,  

(tensor([[ 1.6842, -0.4428],
        [ 1.1760, -0.8451],
        [ 1.0182, -0.1995],
        [ 0.9178, -0.3345],
        [ 2.5919, -2.2376],
        [ 1.0611,  0.0719],
        [ 1.2087,  0.0727],
        [ 1.4727, -0.4693],
        [ 1.0140, -0.0666],
        [ 0.9339,  0.5172],
        [ 2.3276, -1.7123],
        [ 0.9057, -0.0308],
        [ 1.7843, -0.2610],
        [ 1.5741,  0.0357],
        [ 1.6130, -0.8861],
        [ 0.7157,  0.5916],
        [ 1.6190,  0.1802],
        [ 0.9479, -0.5785],
        [ 0.7080,  1.0160],
        [ 1.5141, -0.4500],
        [ 1.3960, -0.5766],
        [ 1.7631, -0.4539],
        [ 1.5943, -0.8757],
        [ 0.8331,  0.0503],
        [ 0.7943,  0.1530],
        [ 1.2576, -0.1999],
        [ 1.4271, -0.3063],
        [ 1.5109, -0.0616],
        [ 0.8560, -0.1363],
        [ 1.0282,  0.3048],
        [ 2.2441, -2.0940],
        [ 1.4895,  0.0102],
        [ 2.1205, -1.1175],
        [ 2.2115, -1.0056],
        [ 1.2050,  0.2276],
        [ 1.6072, -

(tensor([[ 1.0521e+00,  7.7424e-02],
        [ 2.1949e+00, -8.5956e-02],
        [ 2.3716e+00,  5.3217e-01],
        [ 1.6955e+00,  7.7503e-01],
        [ 2.5021e+00, -3.5015e-01],
        [ 2.1106e+00,  4.7449e-01],
        [ 1.6977e+00,  9.6784e-02],
        [ 2.4783e+00,  1.5747e-01],
        [ 9.5484e-01,  4.0203e-01],
        [ 1.9530e+00,  2.2946e-01],
        [ 1.0820e+00,  7.2085e-01],
        [ 2.1703e+00, -7.9241e-01],
        [ 1.5686e+00,  6.2575e-01],
        [ 2.7283e+00, -2.2294e-01],
        [ 1.6808e+00,  3.0098e-01],
        [ 2.6647e+00, -3.1108e-01],
        [ 1.5536e+00,  7.5742e-01],
        [ 2.2050e+00, -2.7334e-01],
        [ 2.4540e+00,  5.2346e-01],
        [ 1.2718e+00,  4.0579e-01],
        [ 1.2412e+00,  7.0682e-01],
        [ 2.4589e+00, -3.0087e-01],
        [ 1.4972e+00,  3.4243e-01],
        [ 2.4492e+00,  6.8682e-01],
        [ 1.1075e+00,  1.1419e+00],
        [ 1.0491e+00,  8.5051e-01],
        [ 2.2798e+00, -1.3726e-01],
        [ 1.8875e+00,  1.07

(tensor([[ 2.0756,  1.6792],
        [ 4.3001,  1.0085],
        [ 3.5782,  0.3328],
        [ 3.6480,  0.3457],
        [ 2.0100,  1.3731],
        [ 3.2726,  0.7211],
        [ 1.8765,  2.0961],
        [ 3.9337,  0.9713],
        [ 2.2278,  0.6726],
        [ 4.3181,  0.7359],
        [ 2.3720,  1.0366],
        [ 1.9920,  0.5713],
        [ 2.8561,  0.8398],
        [ 3.6885,  0.4396],
        [ 3.2919,  0.7752],
        [ 3.7837,  1.7323],
        [ 2.5451,  1.2537],
        [ 2.3155,  0.9277],
        [ 3.9601,  0.9269],
        [ 3.9784,  1.2342],
        [ 2.3726,  2.2505],
        [ 3.3685,  0.7838],
        [ 2.0280,  0.6718],
        [ 1.4389,  0.0503],
        [ 3.9818,  1.7662],
        [ 3.9600, -0.0784],
        [ 3.4865,  0.3577],
        [ 3.2393, -0.1631],
        [ 3.4714, -0.4125],
        [ 4.1146,  0.7518],
        [ 1.2256, -0.1379],
        [ 3.8885,  1.0407],
        [ 2.3133,  1.3499],
        [ 4.5005,  1.1215],
        [ 3.5210, -0.2452],
        [ 3.9463, -

(tensor([[ 5.6951e+00,  1.4124e+00],
        [ 3.6366e+00,  2.9411e+00],
        [ 3.0811e+00,  2.3624e+00],
        [ 2.9600e+00,  2.0473e+00],
        [ 4.8081e+00,  9.4403e-01],
        [ 5.6448e+00,  2.1451e+00],
        [ 5.1367e+00,  2.3489e+00],
        [ 3.3640e+00,  1.4003e-01],
        [ 4.0194e+00,  5.2490e-01],
        [ 5.0994e+00,  2.2743e+00],
        [ 3.5667e+00,  2.3763e+00],
        [ 3.1303e+00,  1.3445e+00],
        [ 5.3049e+00,  3.0367e+00],
        [ 4.4606e+00,  3.1484e-01],
        [ 3.1420e+00,  1.1586e+00],
        [ 5.3763e+00,  1.6375e+00],
        [ 2.4698e+00,  8.4123e-01],
        [ 4.9573e+00,  5.3271e-01],
        [ 3.3466e+00,  1.7762e+00],
        [ 4.4877e+00,  1.5913e+00],
        [ 5.8681e+00,  1.8688e+00],
        [ 3.3219e+00,  7.3058e-01],
        [ 2.5127e+00,  6.6574e-01],
        [ 4.3700e+00,  1.1671e+00],
        [ 3.6395e+00,  5.1203e-01],
        [ 4.2728e+00,  2.2216e+00],
        [ 3.7373e+00,  1.6478e+00],
        [ 3.2911e+00,  1.20

(tensor([[ 4.7529,  0.5019],
        [ 2.0263,  0.2024],
        [ 2.0333, -0.6352],
        [ 2.7725, -0.6603],
        [ 3.0230,  0.2723],
        [ 4.7322,  1.0282],
        [ 3.1204,  2.3469],
        [ 2.2555,  0.2147],
        [ 3.7537, -0.4404],
        [ 2.8102,  0.5400],
        [ 2.4929,  0.6525],
        [ 5.0296,  0.4796],
        [ 3.1493,  0.2453],
        [ 4.4331,  0.8396],
        [ 4.1154,  1.5654],
        [ 4.2104,  0.6166],
        [ 2.0331,  0.4127],
        [ 3.3787,  0.4287],
        [ 2.5652,  1.2420],
        [ 3.9881,  0.4630],
        [ 3.9089, -0.7265],
        [ 3.2641,  0.5883],
        [ 3.8796,  1.4433],
        [ 1.8866, -0.2538],
        [ 3.1549,  1.3168],
        [ 1.8900,  0.3079],
        [ 2.7979, -0.3730],
        [ 2.6430,  1.0402],
        [ 3.2111, -0.7018],
        [ 2.8784,  1.0205],
        [ 3.4692, -0.3092],
        [ 2.2049,  0.5177],
        [ 3.5009,  0.1690],
        [ 5.0721,  1.7204],
        [ 3.6531,  0.6205],
        [ 4.2164,  

(tensor([[ 3.5645, -0.6240],
        [ 2.8318,  1.1191],
        [ 3.0951, -1.0801],
        [ 3.4403, -0.1423],
        [ 3.9394, -0.8257],
        [ 1.7809,  0.0847],
        [ 2.5373, -0.9233],
        [ 3.8168, -0.5196],
        [ 3.1905, -0.6660],
        [ 3.9209,  0.4927],
        [ 3.1947, -0.3367],
        [ 3.3142, -1.3071],
        [ 3.0932,  0.3850],
        [ 3.6931, -0.6784],
        [ 3.3932, -0.6105],
        [ 3.8300, -0.2008],
        [ 3.8876, -0.4356],
        [ 2.5173,  0.1389],
        [ 2.6557,  0.5606],
        [ 3.4135, -0.9802],
        [ 3.6040, -1.3170],
        [ 3.2151, -0.7048],
        [ 2.9253, -0.3556],
        [ 2.0336,  0.2379],
        [ 4.0062, -0.1253],
        [ 1.9735,  0.1809],
        [ 4.0661, -0.2064],
        [ 2.9665,  1.4572],
        [ 3.1431, -0.6488],
        [ 3.9727,  0.8612],
        [ 4.0443, -0.1665],
        [ 2.5670,  0.2792],
        [ 4.1398, -0.5663],
        [ 4.7551, -0.1908],
        [ 4.1691, -0.3520],
        [ 2.1561, -

(tensor([[ 1.7598e+00,  5.4398e-01],
        [ 3.0882e+00, -4.2061e-01],
        [ 2.6660e+00, -6.4831e-01],
        [ 3.1064e+00, -1.0815e+00],
        [ 3.3068e+00, -1.4078e+00],
        [ 3.4147e+00, -9.9312e-01],
        [ 2.2224e+00,  5.6695e-01],
        [ 3.1338e+00, -1.7033e+00],
        [ 1.7988e+00, -4.1713e-01],
        [ 3.4872e+00, -1.5196e+00],
        [ 3.8294e+00, -1.1612e-01],
        [ 3.8180e+00, -7.4138e-01],
        [ 1.6940e+00, -1.1810e+00],
        [ 2.9208e+00, -9.2526e-01],
        [ 2.3601e+00,  1.7978e-01],
        [ 4.6640e+00, -6.4115e-01],
        [ 3.6438e+00,  7.3244e-01],
        [ 1.6813e+00, -1.3613e+00],
        [ 2.7863e+00, -1.3122e+00],
        [ 3.4453e+00, -9.2731e-01],
        [ 3.6225e+00, -9.3938e-01],
        [ 3.1552e+00, -4.7798e-01],
        [ 3.5087e+00, -1.1899e+00],
        [ 1.8402e+00,  5.8151e-01],
        [ 1.4448e+00, -1.4960e+00],
        [ 1.5745e+00, -1.0550e-01],
        [ 4.2432e+00, -5.3562e-01],
        [ 3.6892e+00, -4.40

(tensor([[ 3.3919e+00, -2.7001e+00],
        [ 2.3856e+00, -2.4748e+00],
        [ 9.9154e-01, -2.1219e+00],
        [ 1.3433e+00, -6.7020e-01],
        [ 2.1545e+00, -3.2335e-01],
        [ 3.0170e+00, -1.5022e+00],
        [ 3.3308e+00, -8.8497e-01],
        [ 3.6516e+00, -9.4413e-01],
        [ 1.7700e+00, -6.8559e-01],
        [ 3.4961e+00, -1.4402e+00],
        [ 2.7393e+00, -4.8829e-03],
        [ 1.3002e+00, -4.0810e-01],
        [ 3.0867e+00, -1.5289e+00],
        [ 3.3594e+00, -1.8652e+00],
        [ 3.5540e+00, -1.6282e+00],
        [ 3.8998e+00, -1.0818e+00],
        [ 1.6590e+00, -2.4221e-01],
        [ 3.5958e+00, -1.8220e+00],
        [ 1.2881e+00, -5.1472e-01],
        [ 2.9255e+00, -1.2675e+00],
        [ 3.2502e+00, -7.3605e-01],
        [ 1.2680e+00, -2.1169e+00],
        [ 1.2028e+00, -2.1956e+00],
        [ 3.5382e+00, -1.2672e+00],
        [ 3.1117e+00, -1.8632e-01],
        [ 1.8859e+00, -1.5115e+00],
        [ 2.8967e+00, -1.3282e+00],
        [ 2.8778e+00, -2.48

(tensor([[ 0.9819, -2.4980],
        [ 1.4780, -1.5197],
        [ 1.6548, -0.4351],
        [ 1.1318, -0.5025],
        [ 1.5329, -2.8845],
        [ 2.9231, -1.7118],
        [ 1.5471, -2.6856],
        [ 2.7951,  0.3365],
        [ 3.3740, -1.6584],
        [ 1.8217, -1.5045],
        [ 2.9395, -1.2362],
        [ 2.7946, -2.5216],
        [ 1.7388, -1.2440],
        [ 3.1995, -1.9923],
        [ 2.9878,  0.2511],
        [ 2.4383, -1.9581],
        [ 2.9464, -2.2017],
        [ 2.7064,  0.5753],
        [ 1.8065, -2.0111],
        [ 2.9328, -1.2987],
        [ 2.9363,  0.1071],
        [ 2.2974, -3.4364],
        [ 0.9581, -0.7885],
        [ 3.6555, -2.3680],
        [ 1.5939, -1.9176],
        [ 3.0583,  0.2784],
        [ 1.0766, -2.2311],
        [ 2.2021, -2.8039],
        [ 3.6226, -1.2744],
        [ 0.9513, -0.5875],
        [ 1.1579, -2.8427],
        [ 2.9337, -1.4569],
        [ 2.0379, -1.5386],
        [ 2.4201, -0.9728],
        [ 2.7930, -0.3732],
        [ 3.3201, -

(tensor([[ 1.4787, -0.3460],
        [ 1.4215, -1.6596],
        [ 1.4608, -0.7222],
        [ 1.6095, -1.9871],
        [ 1.3100, -0.7343],
        [ 2.3879, -1.5201],
        [ 1.1910, -2.2026],
        [ 1.1000, -2.8987],
        [ 1.1319, -1.0067],
        [ 2.5178, -2.3097],
        [ 1.2157, -2.6905],
        [ 2.7057, -3.6529],
        [ 2.9194, -2.6605],
        [ 1.5469, -1.0781],
        [ 1.6865, -2.2403],
        [ 2.5685, -3.2191],
        [ 1.0485, -2.5661],
        [ 1.3055, -0.5442],
        [ 1.3891, -0.9023],
        [ 2.7898, -2.4877],
        [ 2.4775, -2.3305],
        [ 2.9033, -3.4798],
        [ 2.5202, -2.8881],
        [ 2.3769, -2.2930],
        [ 1.6910, -2.1947],
        [ 3.2451, -1.9515],
        [ 1.3781, -1.2314],
        [ 1.6909, -2.9758],
        [ 1.7883, -1.1868],
        [ 2.4953, -1.5253],
        [ 2.7443, -0.8547],
        [ 2.1694, -2.0164],
        [ 2.3457, -2.4161],
        [ 2.6838, -1.7888],
        [ 2.2309, -2.8707],
        [ 2.2423, -

(tensor([[ 1.0660, -2.5377],
        [ 2.0588, -3.0192],
        [ 1.1627, -1.5440],
        [ 2.0566, -1.1079],
        [ 2.6551, -2.1728],
        [ 2.0292, -1.9998],
        [ 1.4084, -0.5085],
        [ 1.2669, -2.4269],
        [ 2.8491, -2.2168],
        [ 1.1866, -0.4643],
        [ 2.8095, -1.5597],
        [ 1.3226, -2.0849],
        [ 2.5075, -2.3461],
        [ 0.7614, -2.3726],
        [ 2.4551, -3.3253],
        [ 2.2699, -2.9554],
        [ 2.2845, -2.6515],
        [ 2.5653, -2.6346],
        [ 1.0755, -0.7230],
        [ 0.8026, -1.9826],
        [ 2.8413, -1.4483],
        [ 1.0863, -2.4202],
        [ 2.2481, -2.0919],
        [ 1.1100, -2.1304],
        [ 2.0936, -1.0094],
        [ 2.3968, -2.6872],
        [ 0.7280, -1.8592],
        [ 2.3643, -1.3519],
        [ 2.9316, -1.8040],
        [ 0.7452, -0.6607],
        [ 0.7780, -2.4035],
        [ 2.0716, -2.4706],
        [ 0.9053, -0.5979],
        [ 1.9323, -1.7726],
        [ 1.6348, -1.8964],
        [ 0.9993, -

(tensor([[ 1.6968e+00, -2.1798e+00],
        [ 1.8448e+00, -1.9671e+00],
        [ 2.2342e+00, -1.2476e+00],
        [ 1.9150e+00, -1.3516e+00],
        [ 1.8682e+00, -1.6837e+00],
        [ 5.7498e-01, -1.5410e+00],
        [ 8.3213e-01, -1.4838e-01],
        [ 1.9047e+00, -1.1732e+00],
        [ 1.8025e+00, -1.3741e+00],
        [ 1.4406e+00, -7.8370e-01],
        [ 1.5484e-01, -1.6150e+00],
        [ 5.1191e-01, -3.9763e-01],
        [ 4.7506e-01, -1.0889e+00],
        [ 5.5154e-02, -2.0440e-01],
        [ 5.8719e-01, -1.8397e+00],
        [ 9.6932e-02, -4.9774e-01],
        [ 2.3321e+00, -1.1315e+00],
        [ 1.1017e-02, -9.1249e-02],
        [ 4.1538e-01, -1.3154e+00],
        [ 2.2209e+00, -1.8289e+00],
        [ 7.8025e-01, -1.1340e+00],
        [ 1.9503e+00, -1.1959e+00],
        [ 3.7343e-01, -1.8280e+00],
        [ 1.0641e+00, -2.7229e-02],
        [ 1.8626e+00, -2.2633e+00],
        [ 1.1789e+00, -1.4850e+00],
        [ 5.8370e-01, -1.7143e+00],
        [ 5.4422e-01, -1.15

(tensor([[ 2.3574, -0.3813],
        [ 1.4028, -1.2630],
        [ 2.3609,  0.4770],
        [ 1.5547, -1.3866],
        [ 0.4953, -1.1564],
        [ 1.3705, -1.8698],
        [ 1.1825, -0.7161],
        [ 0.4466, -0.7742],
        [ 1.6240, -1.2101],
        [ 0.9446,  0.0379],
        [ 0.2804, -0.8726],
        [ 1.8828, -1.4247],
        [ 2.0307,  0.1046],
        [ 0.8328, -0.0352],
        [ 1.8440, -1.5573],
        [ 2.0583, -0.0625],
        [ 1.9413, -1.3427],
        [ 1.3785, -0.9673],
        [ 1.8424, -1.6451],
        [ 1.8739, -1.3339],
        [ 2.2218, -0.0190],
        [ 1.4300, -1.6041],
        [ 0.2286,  0.1409],
        [ 2.2319,  0.4859],
        [ 1.7106, -0.5925],
        [ 0.7316, -1.0466],
        [-0.3842, -1.3514],
        [ 0.3287, -0.2856],
        [ 0.9518, -1.2394],
        [ 1.2756,  0.0790],
        [ 0.3087, -1.0091],
        [ 1.8049, -0.4145],
        [ 0.4390,  0.1126],
        [ 2.6933, -0.6079],
        [ 1.0169, -0.1009],
        [ 0.8751, -

(tensor([[ 6.8019e-01,  1.6902e-01],
        [ 1.8506e+00,  1.1358e-01],
        [-2.3361e-01, -1.4301e-01],
        [ 7.0239e-01,  3.5294e-01],
        [ 1.5942e+00, -1.6036e+00],
        [ 1.9643e+00, -1.3666e+00],
        [ 2.5509e+00, -3.8919e-01],
        [ 4.3707e-01, -4.6340e-01],
        [ 8.3271e-03, -1.6359e+00],
        [ 1.7739e+00, -9.8943e-01],
        [ 2.8813e-01, -6.7075e-01],
        [ 1.9556e+00, -7.8542e-01],
        [ 1.4473e+00, -1.9485e+00],
        [ 2.2276e+00, -3.1265e-01],
        [ 1.3529e+00, -1.3799e+00],
        [-1.4985e-02, -1.4975e+00],
        [ 1.8071e+00, -1.9594e+00],
        [ 2.2463e-01, -7.2298e-01],
        [ 1.7102e+00, -9.6507e-01],
        [ 1.6291e+00, -1.9305e+00],
        [ 1.7292e+00, -8.1051e-01],
        [ 1.3209e+00, -4.3591e-01],
        [ 1.6109e+00, -1.3060e+00],
        [ 9.8361e-01, -2.4966e+00],
        [ 1.2899e-01, -1.2672e+00],
        [ 8.7970e-01,  2.9510e-01],
        [-3.2972e-02, -6.8591e-01],
        [ 1.4624e+00, -1.62

(tensor([[ 2.3043,  1.0075],
        [ 1.0290, -2.3840],
        [-0.1464, -0.6436],
        [ 0.4285,  0.1718],
        [-0.2293, -1.6947],
        [ 0.1420, -1.4414],
        [ 1.3333, -1.7958],
        [ 1.1912, -1.1661],
        [ 1.3551, -1.7814],
        [ 0.3225, -0.8095],
        [ 1.7474, -0.6803],
        [-0.2711, -1.4477],
        [ 1.8245, -1.9401],
        [ 0.9105, -1.5116],
        [ 0.7704,  0.3960],
        [ 0.2306, -0.2175],
        [ 0.3473, -1.5731],
        [ 1.2323, -1.5730],
        [ 1.6805, -1.9590],
        [ 0.9069, -0.6365],
        [ 1.5095,  1.2121],
        [ 1.0275, -0.3201],
        [ 1.0902, -1.2233],
        [ 1.4328, -1.6168],
        [ 1.6282, -1.0722],
        [ 2.0660,  0.8417],
        [ 0.6477,  0.4969],
        [-0.5892, -0.5622],
        [ 1.4469, -1.5931],
        [ 2.4721, -1.5157],
        [ 0.9929, -2.4045],
        [ 1.2272, -0.9854],
        [ 0.2869,  0.0840],
        [ 1.5592, -1.4528],
        [-0.3570, -1.3460],
        [ 1.8876, -

(tensor([[ 0.3977, -0.2344],
        [ 0.6150,  0.7454],
        [-0.0572, -1.8378],
        [ 1.5923, -0.3597],
        [ 0.3479, -1.2971],
        [ 0.9241, -0.7674],
        [ 1.7849,  1.3076],
        [-0.6224, -0.0873],
        [ 0.9623, -0.7888],
        [ 2.0887,  1.2820],
        [ 0.4333,  0.5338],
        [ 1.3345, -0.7315],
        [ 1.9351, -2.1474],
        [-0.3060, -0.8339],
        [ 1.7315, -2.0696],
        [ 2.1699,  0.2890],
        [ 0.0074, -0.9124],
        [ 0.7806,  0.9080],
        [ 2.0056, -1.7428],
        [ 1.3265, -1.5154],
        [ 1.2034, -1.4687],
        [ 1.1034, -2.1592],
        [ 2.0568,  0.0503],
        [ 1.6548, -1.3688],
        [ 1.4870,  1.1313],
        [ 1.8255, -0.4183],
        [ 2.5511, -1.3646],
        [ 1.2648, -0.9134],
        [ 1.2002,  0.9800],
        [ 2.0936, -0.7459],
        [ 1.0993, -1.1107],
        [ 2.0237,  0.3416],
        [ 1.9347, -0.8933],
        [ 1.0503, -1.3525],
        [ 2.0767,  0.9711],
        [ 0.4721,  

(tensor([[ 2.2425e+00, -4.3799e-01],
        [ 1.9981e+00,  1.0355e+00],
        [ 5.3301e-01,  7.8859e-01],
        [ 1.4701e+00, -2.1494e+00],
        [ 1.6096e+00,  8.0964e-01],
        [ 1.2758e+00, -1.3004e+00],
        [-8.5370e-01,  7.8188e-02],
        [-1.6991e-01, -1.0342e+00],
        [ 1.9776e+00, -8.9803e-01],
        [-5.6204e-01,  4.0321e-01],
        [ 2.2614e+00, -1.2301e+00],
        [-5.2128e-01,  5.9416e-01],
        [ 1.9712e+00, -1.4322e+00],
        [ 1.0821e-01, -7.9821e-01],
        [-4.9353e-01, -3.8455e-01],
        [ 2.0712e+00,  7.8431e-01],
        [ 2.0542e+00, -1.5009e+00],
        [ 1.2976e+00, -1.6028e+00],
        [ 2.3972e+00, -9.4976e-01],
        [ 6.6922e-02,  6.5988e-01],
        [ 1.3098e+00, -8.9754e-01],
        [-7.4117e-01, -1.4304e-01],
        [ 2.3564e+00,  6.8734e-01],
        [ 1.2137e+00,  4.0939e-02],
        [-7.7562e-01, -3.9350e-01],
        [ 1.5269e+00, -2.0444e+00],
        [ 1.2437e+00, -5.1926e-01],
        [-4.0220e-01, -9.02

(tensor([[-0.3355,  0.0823],
        [ 0.2408,  1.2201],
        [ 0.4723, -0.6423],
        [ 1.1809, -0.1854],
        [ 0.5190,  0.9322],
        [ 1.7602, -0.5395],
        [-0.3323, -0.1856],
        [ 2.0867, -0.5527],
        [ 2.5440,  0.3365],
        [ 2.3366,  1.0321],
        [ 0.7393, -0.2269],
        [ 1.7596,  1.0342],
        [ 0.5880, -0.5875],
        [ 1.3554, -1.7825],
        [ 1.9043, -1.7380],
        [ 1.6905,  0.3364],
        [ 1.7295, -1.3783],
        [ 2.0225, -0.2616],
        [ 2.0963,  0.9044],
        [ 0.4066,  0.6271],
        [ 1.9959, -0.2634],
        [-0.6818, -0.3900],
        [-0.0461, -0.9392],
        [ 0.2813, -1.5575],
        [ 0.1739,  0.8461],
        [ 1.9848, -0.2885],
        [ 2.3693, -1.9294],
        [ 0.0693,  0.9493],
        [ 0.5078,  1.1499],
        [ 2.2238, -1.5105],
        [ 0.0162, -1.2592],
        [ 1.6422, -0.8027],
        [ 0.2551,  0.5479],
        [ 2.1310, -1.6244],
        [ 2.5917, -1.2022],
        [ 1.1069, -

(tensor([[ 3.7054e-01, -8.2926e-01],
        [ 1.7904e+00, -2.7079e-01],
        [ 2.2984e+00, -6.9704e-01],
        [ 1.8186e+00, -1.2638e+00],
        [ 2.2955e+00, -1.1611e-01],
        [ 2.0329e+00,  1.1865e+00],
        [ 4.1030e-01, -1.2887e-02],
        [ 2.1820e-01,  7.9352e-01],
        [ 2.6144e+00, -1.2025e+00],
        [-3.9703e-01,  9.0147e-01],
        [ 5.0096e-01,  7.8050e-01],
        [ 2.5495e+00,  1.1467e-01],
        [ 1.9205e+00, -7.6130e-01],
        [ 1.8207e+00,  9.4704e-01],
        [ 2.7362e-01,  9.5988e-01],
        [ 2.4999e+00, -1.1858e+00],
        [ 2.3049e+00,  4.1624e-01],
        [ 8.4617e-01,  2.1809e-03],
        [ 2.3801e+00, -9.8960e-01],
        [ 4.2806e-01,  1.0414e+00],
        [ 1.8704e+00,  3.4143e-01],
        [ 1.6476e+00, -7.5741e-02],
        [ 2.4113e+00, -1.0450e+00],
        [ 2.2814e+00, -1.3127e+00],
        [ 1.6142e-01,  9.4283e-01],
        [ 1.8117e+00, -1.7509e+00],
        [ 1.1190e+00,  9.9519e-01],
        [ 4.4909e-01, -1.28

(tensor([[ 1.6971, -0.0721],
        [ 2.2320,  0.8501],
        [ 2.6541, -1.0534],
        [ 1.8965, -0.1247],
        [-0.6255,  0.4286],
        [ 1.8044, -2.4048],
        [ 0.2972,  0.0568],
        [ 1.9777,  0.1151],
        [ 1.8089, -0.1948],
        [ 2.1723,  0.2008],
        [ 1.1444,  0.6734],
        [ 2.6027, -0.7140],
        [ 0.5910, -1.0931],
        [ 1.7342, -1.1855],
        [ 0.7235, -0.8300],
        [ 0.6754,  0.4390],
        [ 2.1456, -1.0722],
        [ 2.4118,  0.1242],
        [ 2.6827, -1.3795],
        [ 2.0163, -0.1078],
        [ 2.4254, -1.5769],
        [ 0.8783, -1.1269],
        [ 2.4956, -0.4822],
        [ 0.1911, -0.4756],
        [ 0.3026, -0.5858],
        [ 2.1960, -0.3037],
        [ 2.4895, -0.6839],
        [ 2.2286, -1.7249],
        [ 1.9675, -1.5174],
        [-0.4644,  0.5046],
        [ 2.6032,  0.4807],
        [ 1.7200, -2.1687],
        [ 1.7097, -2.3376],
        [ 1.3616, -0.9247],
        [ 2.4109, -1.5288],
        [ 0.0455, -

(tensor([[ 2.0864e-01,  1.1860e+00],
        [ 2.5154e+00, -1.0901e+00],
        [ 5.5314e-01, -1.1331e+00],
        [ 6.2220e-01, -4.7782e-01],
        [ 2.4490e+00, -1.6719e+00],
        [ 2.6363e+00,  1.4262e-01],
        [ 3.3772e-01, -5.9532e-01],
        [ 2.4280e+00, -1.1827e+00],
        [ 3.2330e-01, -4.0318e-01],
        [ 7.6061e-01, -1.0700e+00],
        [ 2.5772e-01,  8.8459e-01],
        [ 1.8897e+00, -9.4944e-01],
        [ 1.7240e+00,  1.2446e+00],
        [ 1.8516e+00, -8.8086e-01],
        [ 4.7876e-01, -3.7107e-01],
        [ 9.3151e-02, -4.7989e-01],
        [ 2.5067e+00, -2.6099e-01],
        [ 1.7390e+00, -2.1434e+00],
        [ 1.7617e+00, -1.3191e+00],
        [ 1.6914e+00, -2.3089e+00],
        [ 2.2084e+00, -1.4715e+00],
        [ 8.2293e-01, -8.8050e-01],
        [ 1.2314e-01,  1.0440e+00],
        [ 1.9531e+00,  1.1680e+00],
        [ 2.0909e+00, -1.5344e+00],
        [ 1.8964e+00, -1.1710e+00],
        [ 5.2993e-01, -4.5527e-02],
        [ 1.5281e+00, -1.88

(tensor([[ 1.5923,  1.1107],
        [ 1.1063, -0.8102],
        [ 1.6751,  0.2726],
        [ 1.7482, -0.1573],
        [ 2.1313,  0.4980],
        [ 1.4255,  0.3448],
        [ 2.2033, -0.0848],
        [ 2.0447,  0.2037],
        [ 1.9428, -1.0426],
        [ 2.3481, -1.0788],
        [ 0.4875, -1.2744],
        [ 0.4508, -1.1236],
        [ 1.1631, -1.5351],
        [ 2.1945, -1.0636],
        [ 0.6182,  0.0146],
        [ 0.5432, -1.1648],
        [ 2.2395, -0.9402],
        [ 1.8718,  0.7931],
        [ 0.8252, -1.1362],
        [ 0.9852, -1.4580],
        [ 0.3727,  0.7791],
        [ 0.6166, -1.3070],
        [ 1.3668, -1.5729],
        [ 1.8191, -0.2461],
        [ 0.4827,  0.5796],
        [ 1.8896, -0.3817],
        [ 1.5008, -0.2994],
        [ 1.5290, -0.5935],
        [ 0.1291,  0.2440],
        [ 1.1522, -1.7716],
        [ 2.1499, -0.7839],
        [ 1.3268, -1.8429],
        [ 2.2907, -0.6346],
        [ 2.1890,  0.4590],
        [ 2.0830,  0.8959],
        [ 0.4477,  

(tensor([[ 0.4190,  0.5936],
        [ 0.5502,  0.9676],
        [ 1.1329, -1.5026],
        [ 0.2449, -1.3379],
        [ 1.0879,  0.1586],
        [ 1.8706,  0.2621],
        [ 2.1509, -0.9721],
        [ 0.3675, -0.9906],
        [ 0.2409,  0.6701],
        [ 1.5494, -1.4348],
        [ 1.7126, -0.2364],
        [ 0.3911, -0.8008],
        [ 0.2135, -0.6917],
        [ 1.4200, -1.7369],
        [ 1.9180, -1.0675],
        [ 1.0235, -1.3192],
        [ 0.4460, -1.1820],
        [ 1.5335, -0.7309],
        [ 2.0171, -1.3008],
        [ 0.4200,  0.4123],
        [ 0.4844,  0.8419],
        [ 0.2478, -1.0268],
        [ 1.6328, -1.4942],
        [ 1.3708, -1.7056],
        [ 0.5262, -0.8167],
        [ 1.8643,  0.3623],
        [ 0.4961, -0.4019],
        [ 0.5655, -1.1073],
        [ 0.2841,  0.9940],
        [ 0.6079,  0.4449],
        [ 2.1121, -1.2817],
        [ 0.6617, -1.2405],
        [ 1.2959,  0.6549],
        [ 2.0817, -0.2863],
        [ 1.3820, -1.3781],
        [ 1.2966, -

(tensor([[ 1.9534, -1.6639],
        [ 0.0435, -1.4404],
        [ 0.4034, -1.1369],
        [ 0.4568,  1.0272],
        [ 1.4799, -1.5235],
        [ 1.8295, -0.1124],
        [ 1.0356,  1.1701],
        [ 0.9484,  0.5103],
        [-0.1372, -0.5681],
        [ 0.3729, -0.8871],
        [ 1.4782, -0.5571],
        [ 2.1403, -0.9153],
        [ 0.3287, -1.3836],
        [ 2.2803,  0.1523],
        [ 1.9397, -1.3376],
        [ 1.6979,  1.2130],
        [ 0.5158, -0.1694],
        [ 1.8725, -0.0640],
        [ 0.5017,  0.9977],
        [-0.3668,  0.2258],
        [ 1.9849,  0.0549],
        [ 0.4962, -1.0107],
        [ 0.6338,  1.0144],
        [ 0.1850, -0.9955],
        [ 2.0982,  1.1607],
        [ 1.9554, -0.6668],
        [ 1.5867, -0.6433],
        [ 2.0594, -1.2422],
        [ 2.5544, -0.2544],
        [ 2.0634,  0.5239],
        [ 0.4484,  1.1018],
        [ 0.9184,  1.0567],
        [ 2.1501, -0.0236],
        [ 0.1818, -0.9307],
        [ 2.0386, -1.6787],
        [ 0.6114, -

(tensor([[ 0.5339, -0.3637],
        [ 0.3203, -1.4952],
        [ 1.2020, -1.8033],
        [ 2.7260, -0.4378],
        [ 2.1839, -0.9705],
        [ 0.6322,  1.0494],
        [ 2.0671, -0.1986],
        [ 2.5053, -1.1869],
        [ 2.2214, -0.7990],
        [ 1.6093, -0.0587],
        [ 1.8878,  0.8744],
        [ 1.3885,  0.1457],
        [ 0.5603, -0.0065],
        [ 2.3400, -0.0785],
        [ 1.6314,  0.5770],
        [ 0.3578, -1.1418],
        [ 2.0870,  1.2265],
        [ 1.9000, -1.2762],
        [ 2.7636, -1.3947],
        [ 2.0415,  0.7290],
        [ 2.2330, -1.5954],
        [ 0.4777, -1.1057],
        [ 0.8087, -0.9689],
        [ 2.0485, -1.5161],
        [ 0.5990, -0.6399],
        [ 2.4363, -1.3988],
        [-0.3009,  0.6130],
        [ 0.2574, -0.7231],
        [ 0.2077, -1.4301],
        [ 2.2587,  1.2189],
        [ 1.3948, -0.3669],
        [ 1.2441, -1.3235],
        [ 2.3040, -0.6692],
        [ 2.5536, -0.9780],
        [ 2.4008, -1.2531],
        [ 2.6575, -

(tensor([[ 2.4580e+00, -1.1639e+00],
        [ 3.3780e+00, -5.3666e-01],
        [ 1.0682e+00, -1.8969e-01],
        [ 2.7238e+00,  1.3522e-01],
        [ 2.3026e+00,  8.3053e-01],
        [ 2.2919e+00,  6.0558e-01],
        [-1.4835e-01,  5.3925e-01],
        [ 3.2674e+00, -1.0907e+00],
        [ 1.0262e+00, -1.8002e-02],
        [ 5.0992e-01, -1.3256e+00],
        [ 3.2809e+00,  1.9750e-01],
        [ 1.0310e+00,  1.6827e+00],
        [ 2.6538e+00, -8.3259e-01],
        [ 1.7627e+00,  7.7613e-01],
        [ 6.8809e-01,  1.6612e+00],
        [ 2.6687e+00, -1.7382e+00],
        [ 1.3021e+00,  1.3574e-01],
        [ 2.9341e+00, -1.4967e+00],
        [ 6.1596e-01, -9.1449e-01],
        [ 9.3570e-01, -1.4867e+00],
        [ 1.9495e+00, -2.3573e+00],
        [ 6.4194e-01, -8.6588e-01],
        [ 3.1602e+00, -1.5296e+00],
        [ 2.2047e+00, -1.9275e+00],
        [ 2.3847e+00, -2.6862e-01],
        [ 1.8566e+00, -2.2723e+00],
        [ 1.1623e+00, -7.8665e-02],
        [ 3.0985e+00, -8.10

(tensor([[ 2.1011, -1.7867],
        [ 2.0053,  0.4502],
        [ 0.4450,  1.4733],
        [ 0.3726,  1.3988],
        [ 1.5348, -0.0269],
        [ 1.9382, -1.3673],
        [ 2.7852, -0.3479],
        [ 0.6624,  1.4720],
        [ 2.6057,  0.9637],
        [ 0.5213,  1.4556],
        [ 3.1896, -1.2383],
        [ 0.3394,  1.0999],
        [ 2.5479,  1.2244],
        [ 1.7972,  1.1757],
        [ 1.2359,  0.9229],
        [ 2.4302,  0.3122],
        [ 2.6119,  0.5973],
        [ 1.9174,  0.0037],
        [ 2.9346,  0.1589],
        [ 2.3494,  1.1937],
        [ 0.6500,  1.4863],
        [ 0.9800, -0.0977],
        [ 2.3885,  0.6949],
        [ 0.5793,  0.8093],
        [ 0.4605, -0.4131],
        [ 1.8748,  1.1987],
        [ 0.3325, -0.5202],
        [ 2.1880,  0.1087],
        [ 0.5186, -0.5590],
        [ 2.9620, -0.4595],
        [ 2.5305, -1.4700],
        [ 0.6416,  1.5043],
        [ 2.9788, -1.0980],
        [ 0.9821, -0.5050],
        [ 0.5748, -1.1809],
        [ 0.2827, -

(tensor([[ 2.6825, -0.1832],
        [ 2.4162, -0.6671],
        [ 1.3640, -0.5449],
        [ 0.2128, -0.5601],
        [ 0.8208,  1.2865],
        [ 0.0261,  0.9113],
        [ 0.3933,  0.8321],
        [ 1.6793, -1.4587],
        [ 0.6741,  1.0117],
        [ 2.1211, -0.8637],
        [ 2.8360,  0.4537],
        [-0.2421,  0.7878],
        [ 0.1962, -0.3326],
        [ 2.1985,  1.4189],
        [-0.1611,  0.4353],
        [ 0.7845,  1.4339],
        [ 0.4790, -0.8145],
        [ 0.5286,  1.1863],
        [ 1.8453, -0.6832],
        [ 2.6148, -0.8206],
        [ 0.5951,  1.1276],
        [ 0.6339,  1.4122],
        [ 0.5884,  0.9645],
        [ 1.6697,  1.2452],
        [ 0.4006,  1.1574],
        [ 1.5860,  1.2959],
        [ 1.2912, -1.1673],
        [ 2.9273, -0.5599],
        [ 1.5286,  0.1128],
        [ 2.1645, -1.3250],
        [ 0.3553,  1.4220],
        [ 0.8647,  1.1829],
        [ 1.3352, -0.5993],
        [ 2.4323, -0.0993],
        [ 1.7774, -0.0039],
        [ 2.1548, -

(tensor([[ 1.9958,  0.5165],
        [ 1.4787,  0.9307],
        [ 0.8305,  1.8505],
        [ 1.6286, -0.2247],
        [-0.3127,  0.9216],
        [ 0.2561, -0.1968],
        [ 1.5423,  0.1441],
        [ 1.8580,  1.0258],
        [ 1.1446, -0.8938],
        [ 1.1948,  0.9262],
        [ 0.1216, -0.8681],
        [ 2.2236, -0.2727],
        [ 1.6205,  1.0603],
        [ 0.0751,  0.8015],
        [ 0.8629, -1.1598],
        [ 1.7923,  1.5127],
        [ 1.3631,  1.0869],
        [ 1.9397, -0.0921],
        [ 0.8843,  1.6809],
        [ 1.2667, -0.5400],
        [ 0.2437,  0.0138],
        [ 1.0847, -1.0187],
        [ 1.6504,  1.4107],
        [ 2.0963, -0.3995],
        [ 1.6208,  0.1869],
        [ 2.2596, -0.3253],
        [ 0.0286, -0.1177],
        [ 1.9082,  1.0436],
        [-0.0262, -0.8623],
        [ 2.1765, -0.6169],
        [ 0.2982,  0.3073],
        [ 2.0715,  1.6814],
        [-0.0489, -0.5177],
        [ 0.3772,  1.0469],
        [ 0.4493,  1.4106],
        [ 0.3116,  

(tensor([[ 2.4263e+00, -2.8039e-01],
        [ 8.3783e-02,  9.7112e-01],
        [-1.8726e-01, -3.8331e-01],
        [ 2.0244e+00,  8.5435e-01],
        [ 3.0017e+00,  4.0628e-01],
        [ 1.8828e+00,  8.6124e-01],
        [ 1.9577e+00,  1.0716e+00],
        [ 1.7981e+00,  1.6016e+00],
        [ 3.6372e-02, -1.9039e-01],
        [-6.7455e-02,  3.9639e-01],
        [-1.4353e-01, -6.3640e-01],
        [ 1.2994e+00, -5.0503e-01],
        [ 2.3291e+00,  5.8271e-01],
        [ 4.2463e-01,  1.2022e+00],
        [ 1.2349e-01, -6.6169e-01],
        [ 1.5844e+00,  4.6509e-01],
        [ 1.8961e+00,  7.1297e-01],
        [ 1.9954e+00, -2.5529e-01],
        [ 1.1866e+00,  1.6789e-01],
        [ 1.3446e+00,  1.3766e+00],
        [ 2.0921e+00,  1.1569e+00],
        [ 6.5689e-01,  1.8756e+00],
        [ 2.1083e+00,  1.4734e+00],
        [ 2.6092e-01, -3.2184e-01],
        [ 2.0244e+00, -2.3268e-02],
        [ 1.3310e+00,  1.0162e+00],
        [ 1.2958e+00,  3.6268e-01],
        [ 1.4078e+00,  2.16

(tensor([[-0.0414,  0.9221],
        [ 0.2461, -0.0084],
        [ 0.1714,  1.4997],
        [ 1.7162,  1.0479],
        [-0.4121,  1.0611],
        [-0.0485, -0.5274],
        [ 1.1911,  0.7421],
        [ 1.8598,  1.0834],
        [ 0.9820,  1.1537],
        [-0.0171, -0.7650],
        [ 2.0487,  1.2036],
        [ 1.0167, -0.8623],
        [ 0.3291,  1.6315],
        [ 1.5775,  0.7781],
        [ 1.8849,  0.6413],
        [ 2.0988,  1.9183],
        [ 0.1580, -0.2836],
        [ 2.3501, -0.4012],
        [ 1.5670, -0.5297],
        [ 1.9806,  0.4589],
        [ 1.9702,  0.8658],
        [ 1.0109,  1.3847],
        [ 0.6711,  0.8958],
        [ 1.8156,  0.4682],
        [ 1.8651,  0.4117],
        [ 0.0902, -0.5676],
        [-0.1127,  0.6727],
        [ 1.3995,  1.2762],
        [ 2.6486,  0.8737],
        [ 1.9664, -0.2468],
        [-0.4598,  0.5332],
        [ 0.2656, -0.5154],
        [ 0.1703,  0.1897],
        [ 0.3086, -0.2110],
        [ 1.7213,  0.6958],
        [ 1.6154,  

(tensor([[ 1.0302e-01,  1.8668e+00],
        [ 1.8138e+00,  1.3197e+00],
        [ 1.8346e+00, -2.9727e-02],
        [ 1.6159e+00,  1.4331e+00],
        [ 2.2608e+00,  6.9882e-01],
        [ 1.4748e+00,  1.0449e+00],
        [ 1.1994e+00, -9.5435e-01],
        [-4.2451e-02, -2.8970e-01],
        [ 1.6919e+00,  2.1615e-01],
        [ 3.4426e-01,  1.2866e+00],
        [ 1.7109e+00, -1.3080e-01],
        [ 1.4151e+00, -8.7081e-01],
        [-5.2855e-01,  7.7745e-01],
        [ 2.3068e+00, -7.6796e-01],
        [ 1.5089e+00, -6.2531e-01],
        [ 9.5240e-01, -7.1316e-01],
        [ 3.0807e-01,  1.7237e+00],
        [ 1.4979e+00,  3.5975e-01],
        [ 1.0096e+00, -8.8180e-02],
        [ 5.1257e-01, -2.5606e-01],
        [ 2.1641e+00,  1.1827e+00],
        [-1.1920e-01, -5.8705e-01],
        [ 2.1286e+00,  2.6879e-01],
        [ 2.6215e-02,  1.8702e-03],
        [ 1.6272e+00,  8.6978e-01],
        [ 9.8873e-01,  1.1417e+00],
        [ 2.4681e+00, -4.8886e-01],
        [ 1.0736e-02,  1.85

(tensor([[ 1.4003e+00,  4.0661e-01],
        [ 4.7262e-02, -1.1069e+00],
        [ 1.8193e+00, -6.5243e-02],
        [ 1.2101e+00,  5.8365e-01],
        [ 5.4504e-01, -7.9223e-01],
        [ 9.0970e-01, -6.7662e-01],
        [-1.6758e-01,  1.3198e+00],
        [-2.2457e-01,  1.3667e+00],
        [ 5.2266e-01,  1.5608e+00],
        [ 1.1222e+00, -1.2942e+00],
        [ 1.8369e+00,  6.4123e-03],
        [ 4.1618e-01,  1.7080e+00],
        [ 1.4236e+00,  5.9172e-01],
        [ 1.7444e+00,  1.2944e+00],
        [ 2.8839e-01, -2.4603e-01],
        [ 2.2643e-01,  8.4461e-01],
        [ 1.6243e+00,  1.5367e+00],
        [ 1.1688e+00,  6.7557e-01],
        [ 1.9989e+00,  5.1620e-01],
        [ 8.4201e-01,  2.6607e-01],
        [ 1.7498e+00, -1.8763e-02],
        [-9.7491e-02,  1.7348e+00],
        [ 1.8334e+00,  1.6655e-02],
        [ 4.5977e-01,  1.4963e+00],
        [-1.2035e-01,  9.3127e-01],
        [ 5.0060e-03, -9.2007e-01],
        [ 1.5012e+00,  4.7330e-01],
        [ 1.9056e+00, -8.70

(tensor([[ 0.0233,  1.0811],
        [-0.2558,  1.7010],
        [-0.2604,  1.5783],
        [ 1.4038,  0.5785],
        [ 0.7118, -1.3997],
        [ 0.9829,  1.6960],
        [-0.1252,  1.7875],
        [-0.6377,  1.6717],
        [ 1.3069, -0.7336],
        [ 1.2582, -1.1119],
        [ 1.0780,  1.8667],
        [-0.1118,  0.9011],
        [ 1.3406,  0.5843],
        [-0.8716,  1.1090],
        [-0.2247,  2.0351],
        [ 0.2024, -1.3534],
        [-0.3683, -0.7832],
        [-0.2285,  0.6499],
        [ 0.1940,  0.5142],
        [-0.4167,  1.1419],
        [ 1.3639,  0.4624],
        [-0.0899, -0.1974],
        [-0.3457,  1.6488],
        [-0.1612,  1.8618],
        [-0.7190,  1.3734],
        [-0.9577,  1.5314],
        [-0.3403,  0.0995],
        [ 0.5699, -1.0557],
        [ 1.1582,  0.6232],
        [-0.7701,  1.6438],
        [-0.3904,  1.5165],
        [ 0.6811,  0.7019],
        [ 1.6746,  0.6362],
        [ 0.8747,  0.6946],
        [-0.6779,  1.4187],
        [ 1.9643,  

KeyboardInterrupt: 

# VALIDATE

In [16]:
torch.save(model,"Models/tripletMNIST.pt")

In [17]:
model_loaded = torch.load("Models/tripletMNIST.pt")

In [18]:
import torch
import torch.nn.functional as F

def evaluate_model(model, triplet_test_loader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        start = time.time()
        for (anchor, positive, negative),[] in triplet_test_loader:
            anchor_embedding,positive_embedding,negative_embedding = model(anchor,positive,negative)
            positive_distance = F.pairwise_distance(anchor_embedding, positive_embedding)
            negative_distance = F.pairwise_distance(anchor_embedding, negative_embedding)
            correct += torch.sum(positive_distance < negative_distance).item()
            total += anchor.size(0)
    accuracy = correct / total
    print('Accuracy : {:.2f}%\nTime : {:.2f} SECONDS'.format(accuracy * 100,time.time()-start))

In [19]:
evaluate_model(model_loaded,triplet_test_loader)

Accuracy : 99.67%
Time : 7.93 SECONDS


# WITH FACEBOOK AI SIMILARITY SEARCH

In [20]:
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=True, **kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **kwargs)

In [25]:
import faiss

In [26]:
embs1 = None
labels1 = []
for idx,i in enumerate(train_loader):
    if idx==10000: break
    I, L = i
    labels1.append(L)
    emb = model_loaded(I) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if embs1 is None:
        embs1 = emb
    else:
        embs1 = torch.cat((embs1, emb), dim=0)

In [27]:
embs2 = None
labels2 = []
for i in test_loader:
    I, L = i
    labels2.append(L)
    emb = model_loaded(I)
    if embs2 is None:
        embs2 = emb
    else:
        embs2 = torch.cat((embs2, emb), dim=0)

In [29]:
embs = embs1

In [44]:
index1 = faiss.IndexFlatL2(embs.shape[1])  # Assuming embs.shape[1] represents the dimensionality of the embeddings
index1.add(embs)

nlist = 100  # Number of cells/buckets
quantizer = faiss.IndexFlatL2(embs.shape[1])  # Quantizer index (same as IndexFlatL2)
index2 = faiss.IndexIVFFlat(quantizer, embs.shape[1], nlist)
index2.train(embs)
index2.add(embs)

index3 = faiss.IndexHNSWFlat(embs.shape[1], 32)  # M = 32 for the HNSW index
index3.add(embs)

nbits = 8  # Number of bits for the LSH hash
index4 = faiss.IndexLSH(embs.shape[1], nbits)
index4.add(embs)


# EVALUATE WITH FAISS

In [31]:
def evaluatewithfaiss(embs,index):
    TOTAL = len(embs)
    CORRECT = 0
    start = time.time()
    for idx,emb in enumerate(embs):
        label = index.search(emb.detach().reshape(1,-1),1)[1][0][0]
        CORRECT += labels1[label]==labels2[idx]
    return (CORRECT/TOTAL*100).item(),f'{time.time()-start} SECONDS'
        

In [32]:
print(f'IndexFlatL2 : {evaluatewithfaiss(embs2,index1)}')
print(f'IndexIVFFlat : {evaluatewithfaiss(embs2,index2)}')
print(f'IndexHNSWFlat : {evaluatewithfaiss(embs2,index3)}')
print(f'IndexLSH : {evaluatewithfaiss(embs2,index4)}')

IndexFlatL2 : (98.8499984741211, '0.8810124397277832 SECONDS')
IndexIVFFlat : (98.81999969482422, '0.41861629486083984 SECONDS')
IndexHNSWFlat : (98.8499984741211, '0.5943887233734131 SECONDS')
IndexLSH : (87.05000305175781, '0.948662281036377 SECONDS')


# WITHOUT FAISS
Accuracy : 99.67%   ||     Time : 7.93 SECONDS