# IMPORT LIBRARIES

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Dataset, random_split
from torch.nn.functional import normalize
from transformers import ViTImageProcessor, ViTModel
from torchvision import transforms as T,transforms
import time
import torch.nn.functional as func
import torchvision.transforms.functional as F
import numpy as np

# MODEL

In [11]:
model_name='google/vit-base-patch16-224-in21k'
extractor = ViTImageProcessor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name)

In [12]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.feature_extractor = extractor
        self.model = model

    def forward(self, x):
        x = (x + 1) / 2
        inputs = self.feature_extractor(images=x, return_tensors='pt', do_normalize=True, image_mean=0.5, image_std=0.5)
        with torch.no_grad():
            outputs = self.model(**inputs).last_hidden_state[:, 0].cpu().requires_grad_()
        return outputs

    def get_embedding(self, x):
        return self.forward(x)

class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net
        
    def forward(self, x1, x2=None, x3=None):
        if x2 is None and x3 is None:
            return self.embedding_net(x1)
        return self.embedding_net(x1),self.embedding_net(x2),self.embedding_net(x3)

    def get_embedding(self, x):
        return self.embedding_net(x)

# CUSTOMIZE MNIST INTO TRIPLETS

In [13]:
class TripletMNISTDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __getitem__(self, index):
        anchor = self.data[index]
        anchor_target = self.targets[index]
        positive_indices = torch.where(self.targets == anchor_target)[0]
        positive_indices = positive_indices[positive_indices != index]
        positive_index = torch.randint(0, len(positive_indices), (1,))
        positive = self.data[positive_indices[positive_index]]

        negative_indices = torch.where(self.targets != anchor_target)[0]
        negative_index = torch.randint(0, len(negative_indices), (1,))
        negative = self.data[negative_indices[negative_index]]

        return anchor, positive, negative

    def __len__(self):
        return len(self.data)

dataset = MNIST(root="./Datasets/MNIST", train=True, download=True)
train_ratio = 0.9
val_ratio = 1 - train_ratio
train_size = int(train_ratio * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
test_dataset = MNIST(root="./Datasets/MNIST", train=False, download=True)
train_triplet_dataset = TripletMNISTDataset(train_dataset.dataset.data, train_dataset.dataset.targets)
val_triplet_dataset = TripletMNISTDataset(val_dataset.dataset.data, train_dataset.dataset.targets)
test_triplet_dataset = TripletMNISTDataset(test_dataset.data, test_dataset.targets)

# MODEL SPECIFIC TRANSFORMATIONS

In [14]:
transformation_chain = T.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(int((256 / 224) * extractor.size["height"])),
        transforms.Lambda(lambda x: x.convert("RGB")),  # Convert grayscale to RGB
        transforms.ToTensor(),
        transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]

)

def collate(batch):
    anchor_transformed = [transformation_chain(img) for img in batch[0]]
    pos_transformed = [transformation_chain(img) for img in batch[1]]
    neg_transformed = [transformation_chain(img) for img in batch[2]]
    anchor_tensors = torch.stack(anchor_transformed)
    positive_tensors = torch.stack(pos_transformed)
    negative_tensors = torch.stack(neg_transformed)
    return anchor_tensors, positive_tensors, negative_tensors

# TRIPLET LOSS

In [15]:
class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = func.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

# PARAMETERS

In [16]:
bs = 1280
train_loader = DataLoader(train_triplet_dataset, batch_size=bs,collate_fn=collate)
val_loader = DataLoader(val_triplet_dataset, batch_size=bs,collate_fn=collate)
triplet_loss = TripletLoss(margin=0.2)
num_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1 = EmbeddingNet()
model = TripletNet(model1)
model.requires_grad_(True)
lr = 0.001
optimizer = optim.Adam(model.parameters(), lr=lr)
model.to(device)

TripletNet(
  (embedding_net): EmbeddingNet(
    (model): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-11): 12 x ViTLayer(
            (attention): ViTAttention(
              (attention): ViTSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (interme

# TRAIN

In [None]:
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for idx,batch in enumerate(train_loader):
        anchor, positive, negative = batch
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        anchor_embedding, positive_embedding, negative_embedding = model(anchor,positive,negative)
        anchor_embedding.requires_grad_(True)
        positive_embedding.requires_grad_(True)
        negative_embedding.requires_grad_(True)
        optimizer.zero_grad()
        loss = triplet_loss(anchor_embedding, positive_embedding, negative_embedding)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if idx%10==0:
            print(f'({idx}).  LOSS : {loss.item()}  SEEN : {bs*(idx+1)}/{60000}')
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}")
    evaluate_model(model,val_loader)

In [308]:
torch.save(model,'Models/tripletMNISTrans.pt')

In [18]:
model_loaded = torch.load('Models/tripletMNISTrans.pt')

# EVALUATE

In [24]:
def evaluate_model(model, triplet_test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        start = time.time()
        for (anchor, positive, negative) in triplet_test_loader:
            anchor_embedding, positive_embedding, negative_embedding = model(anchor, positive, negative)
            positive_distance = func.pairwise_distance(anchor_embedding, positive_embedding)
            negative_distance = func.pairwise_distance(anchor_embedding, negative_embedding)
            correct += torch.sum(positive_distance < negative_distance).item()
            total += anchor.size(0)  # Increment by the batch size
    accuracy = correct / total
    print('Validation Set Accuracy: {:.2f}% |||| Time: {:.2f} SECONDS'.format(accuracy * 100, time.time() - start))


In [182]:
test_loader = DataLoader(test_triplet_dataset, batch_size=5000, collate_fn=collate)
evaluate_model(model,test_loader)

Accuracy: 33.33%
Time: 4.11 SECONDS


In [306]:
evaluate_model(model,train_loader)

Accuracy : 66.67%
Time : 37.82 SECONDS


# WITH FAISS

In [247]:
import faiss

In [254]:
def collate_(batch):
    return torch.stack([transformation_chain(img) for img in batch[0][0]]),batch[0][1]

test_loader_ = torch.utils.data.DataLoader(test_dataset, shuffle=True,collate_fn=collate_)
train_loader_ = torch.utils.data.DataLoader(train_dataset, shuffle=True,collate_fn = collate_)

In [256]:
embs1 = None
labels1 = []
for idx,i in enumerate(train_loader_):
    if idx==100: break
    I, L = i
    labels1.append(L)
    print(idx)
    emb = model_loaded(I) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if embs1 is None:
        embs1 = emb
    else:
        embs1 = torch.cat((embs1, emb), dim=0)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [257]:
embs2 = None
labels2 = []
for idx,i in enumerate(test_loader_):
    if idx==100: break
    I, L = i
    labels2.append(L)
    print(idx)
    emb = model_loaded(I) # Assuming `model_loaded(I)` returns a PyTorch tensor
    emb = emb.detach()
    if embs2 is None:
        embs2 = emb
    else:
        embs2 = torch.cat((embs2, emb), dim=0)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [286]:
import faiss
embs = embs2

index1 = faiss.IndexFlatL2(embs.shape[1])  # Assuming embs.shape[1] represents the dimensionality of the embeddings
index1.add(embs)

nlist = 100  # Number of cells/buckets
quantizer = faiss.IndexFlatL2(embs.shape[1])  # Quantizer index (same as IndexFlatL2)
index2 = faiss.IndexIVFFlat(quantizer, embs.shape[1], nlist)
index2.train(embs)
index2.add(embs)

index3 = faiss.IndexHNSWFlat(embs.shape[1], 32)  # M = 32 for the HNSW index
index3.add(embs)

nbits = 8  # Number of bits for the LSH hash
index4 = faiss.IndexLSH(embs.shape[1], nbits)
index4.add(embs)



In [261]:
labels2[index1.search(embs1[0].detach().reshape(1,-1),1)[1][0][0]]


5

In [287]:
def evaluatewithfaiss(embs,index):
    TOTAL = len(embs)
    CORRECT = 0
    start = time.time()
    for idx,emb in enumerate(embs):
        label = index.search(emb.detach().reshape(1,-1),2)[1][0][0]
        CORRECT += labels1[label]==labels1[idx]
    return (CORRECT/TOTAL*100),f'{time.time()-start} SECONDS'
        

In [None]:
print(f'IndexFlatL2 : {evaluatewithfaiss(embs2,index1)}')
print(f'IndexIVFFlat : {evaluatewithfaiss(embs2,index2)}')
print(f'IndexHNSWFlat : {evaluatewithfaiss(embs2,index3)}')
print(f'IndexLSH : {evaluatewithfaiss(embs2,index4)}')

In [None]:
evaluate_model(model,triplet_train_loader)