In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms as T
import numpy as np
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import torch.optim as optim
from PIL import Image
import faiss
import time
from transformers import AutoFeatureExtractor, AutoModel

In [9]:
model_ckpt = "google/vit-base-patch16-224-in21k"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

In [10]:
mean, std = 0.1307, 0.3081

train_dataset = MNIST('./Datasets/MNIST', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((mean,), (std,))
                             ]))
test_dataset = MNIST('./Datasets/MNIST', train=False, download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((mean,), (std,))
                            ]))

In [11]:
class TripletLoss(nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

In [70]:
class TripletMNIST(Dataset):
    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}

        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            # generate fixed triplets for testing
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            triplets = [[i,
                         random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                         random_state.choice(self.label_to_indices[
                                                 np.random.choice(
                                                     list(self.labels_set - set([self.test_labels[i].item()]))
                                                 )
                                             ])
                         ]
                        for i in range(len(self.test_data))]
            self.test_triplets = triplets

    def __getitem__(self, index):
        label1 = None
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            img1 = self.test_data[self.test_triplets[index][0]]
            img2 = self.test_data[self.test_triplets[index][1]]
            img3 = self.test_data[self.test_triplets[index][2]]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        return len(self.mnist_dataset)


In [134]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.model = model
    def forward(self, x):
        return self.model(x).last_hidden_state[:, 0].cpu()
    def get_embedding(self, x):
        return self.forward(x)

In [135]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net
    def forward(self, x1, x2=None, x3=None):
        if x2 is None and x3 is None:
            return self.embedding_net(x1)
        return self.embedding_net(x1),self.embedding_net(x2),self.embedding_net(x3)
    def get_embedding(self, x):
        return self.embedding_net(x)

In [136]:
transformation_chain = T.Compose([
    T.ToPILImage(),
    T.Resize((224,224)),
    T.Lambda(lambda img: img.convert('RGB')),  # Convert the image to RGB format
    T.ToTensor(),
    T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
])

In [137]:
def collate(batch):
    anchor_transformed = [transformation_chain(img) for img in batch[0][0]]
    pos_transformed = [transformation_chain(img) for img in batch[1][0]]
    neg_transformed = [transformation_chain(img) for img in batch[2][0]]
    anchor_tensors = torch.stack(anchor_transformed)
    positive_tensors = torch.stack(pos_transformed)
    negative_tensors = torch.stack(neg_transformed)

    return (anchor_tensors, positive_tensors, negative_tensors),[]

In [138]:
cuda = torch.cuda.is_available()
triplet_train_dataset = TripletMNIST(train_dataset) 
triplet_test_dataset = TripletMNIST(test_dataset)
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
margin = 1.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)
if cuda:
    model.cuda()
loss_fn = TripletLoss(margin)
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9,0.999), eps=1e-08)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=0.5, last_epoch=-1)
n_epochs = 20
log_interval = 100

In [139]:
triplet_test_loader = torch.utils.data.DataLoader(triplet_test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate, **kwargs)
triplet_train_loader = torch.utils.data.DataLoader(triplet_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate, **kwargs)


In [144]:
def fit(train_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval,start_epoch=0):
    for epoch in range(0, start_epoch):
        scheduler.step()
    for epoch in range(start_epoch, n_epochs):
        scheduler.step()
        train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval)
        message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)
        print(message)
def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval):
    model.train()
    losses = []
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        target = target if len(target) > 0 else None
        if not type(data) in (tuple, list):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if target is not None:
                target = target.cuda()
        optimizer.zero_grad()
        outputs = model(*data)
        if type(outputs) not in (tuple, list):
            outputs = (outputs,)
        loss_inputs = outputs
        if target is not None:
            target = (target,)
            loss_inputs += target
        loss_outputs = loss_fn(*loss_inputs)
        loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        print(loss.item())
        if batch_idx % log_interval == 0:
            message = 'Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data[0]), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(losses))
            print(message)
            losses = []

    total_loss /= (batch_idx + 1)
    return total_loss, metrics

In [145]:
fit(triplet_train_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval)

1.0000826120376587
0.9995934367179871
0.9996693730354309
0.999172031879425
1.0002089738845825


KeyboardInterrupt: 