In [3]:
!pip install pytorch-metric-learning
!pip install faiss-gpu



In [8]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
from torchvision import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
class Net(nn.Module):
    """Small CNN mapping single-channel 28x28 images to an embedding vector.

    Adapted from the PyTorch MNIST example with the classification head
    removed, so ``forward`` returns raw embeddings for metric learning.

    Args:
        embedding_size: Dimensionality of the output embedding (default 128).
    """

    def __init__(self, embedding_size=128):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # 9216 = 64 channels * 12 * 12: two unpadded 3x3 convs take 28 -> 26 -> 24,
        # then a 2x2 max-pool takes 24 -> 12. So inputs must be 28x28.
        self.fc1 = nn.Linear(9216, embedding_size)

    def forward(self, x):
        """Return embeddings of shape (batch, embedding_size).

        Expects ``x`` of shape (batch, 1, 28, 28), e.g. normalized MNIST digits.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        return self.fc1(x)

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch, log_interval=20):
    """Run one training epoch of metric learning with per-batch mining.

    For each batch: compute embeddings, let ``mining_func`` pick tuples from the
    batch, evaluate ``loss_func`` on them, and take an optimizer step.

    Args:
        model: Network mapping a batch of inputs to embeddings.
        loss_func: Callable ``(embeddings, labels, indices_tuple) -> loss tensor``.
        mining_func: Callable ``(embeddings, labels) -> indices_tuple``. Its
            ``num_triplets`` attribute is read for the progress log.
        device: Device each batch is moved to.
        train_loader: Iterable of ``(data, labels)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Epoch number; used only in log messages.
        log_interval: Print progress every ``log_interval`` batches (default 20).
    """
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # .item() logs a plain Python float rather than the tensor repr
            # (which carries device / grad_fn noise).
            print("Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}".format(
                epoch, batch_idx, loss.item(), mining_func.num_triplets))

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Embed every item of ``dataset`` with ``model``.

    Thin convenience wrapper: builds a pytorch-metric-learning ``BaseTester``
    and returns the result of its ``get_all_embeddings`` unchanged (callers in
    this notebook unpack it as ``embeddings, labels``).
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(dataset, model, accuracy_calculator):
    """Evaluate retrieval quality of ``model`` on ``dataset`` and print MAP@R.

    Every item is embedded, and the embeddings are used as both the query set
    and the reference set.

    Args:
        dataset: Dataset to embed (e.g. the MNIST test split).
        model: Embedding network.
        accuracy_calculator: A pytorch-metric-learning ``AccuracyCalculator``
            that computes ``"mean_average_precision_at_r"``.

    Returns:
        dict: The metrics returned by ``accuracy_calculator.get_accuracy``, so
        callers can log or check them as well as reading the printout.
    """
    embeddings, labels = get_all_embeddings(dataset, model)
    print("Computing accuracy")
    labels = np.squeeze(labels)
    # Positional order is the pre-2.0 pytorch-metric-learning signature:
    # (query, reference, query_labels, reference_labels,
    #  embeddings_come_from_same_source). The final True presumably tells the
    # calculator each query is also in the reference set, so its self-match
    # should be ignored. NOTE(review): 2.x reorders these arguments.
    accuracies = accuracy_calculator.get_accuracy(embeddings,
                                                  embeddings,
                                                  labels,
                                                  labels,
                                                  True)
    # NOTE(review): the metric is MAP@R; the "@10" in the label reflects the
    # calculator being built with k=10 neighbours — confirm this is intended.
    print("Test set accuracy (MAP@10) = {}".format(accuracies["mean_average_precision_at_r"]))
    return accuracies

# Use the GPU when present, but fall back to CPU so the notebook still runs
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (weight init, DataLoader shuffling, dropout) so runs are
# repeatable; cuDNN kernels may still introduce small nondeterminism on GPU.
seed = 0
torch.manual_seed(seed)

# (0.1307, 0.3081) are the MNIST mean/std used by the upstream PyTorch example.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

batch_size = 256

dataset1 = datasets.MNIST('.', train=True, download=True, transform=transform)
# download=True here too, so this line does not rely on the call above having
# already fetched the files.
dataset2 = datasets.MNIST('.', train=False, download=True, transform=transform)
# Use batch_size instead of repeating the literal, so editing it takes effect.
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)
# NOTE(review): test_loader is not used by test(), which embeds dataset2 directly.
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 1


### pytorch-metric-learning stuff ###
# The miner and the loss share one margin: "semihard" mining selects triplets
# relative to the margin, so the two values must not drift apart.
margin = 0.2
distance = distances.CosineSimilarity()
# presumably averages only the strictly positive per-triplet losses (low=0) —
# see ThresholdReducer docs
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=margin, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(margin=margin, distance=distance, type_of_triplets="semihard")
accuracy_calculator = AccuracyCalculator(include=("mean_average_precision_at_r",), k=10)
### pytorch-metric-learning stuff ###


# Alternate one training epoch with a full retrieval evaluation on the test split.
for epoch in range(1, num_epochs + 1):
    train(model=model, loss_func=loss_func, mining_func=mining_func, device=device,
          train_loader=train_loader, optimizer=optimizer, epoch=epoch)
    test(dataset=dataset2, model=model, accuracy_calculator=accuracy_calculator)



Epoch 1 Iteration 0: Loss = 0.10368159413337708, Number of mined triplets = 828800
Epoch 1 Iteration 20: Loss = 0.09451039880514145, Number of mined triplets = 117172
Epoch 1 Iteration 40: Loss = 0.0870940089225769, Number of mined triplets = 75505
Epoch 1 Iteration 60: Loss = 0.08687727153301239, Number of mined triplets = 63229
Epoch 1 Iteration 80: Loss = 0.08603870868682861, Number of mined triplets = 39559
Epoch 1 Iteration 100: Loss = 0.0871090441942215, Number of mined triplets = 49315
Epoch 1 Iteration 120: Loss = 0.08383794873952866, Number of mined triplets = 38464
Epoch 1 Iteration 140: Loss = 0.08661019802093506, Number of mined triplets = 28360
Epoch 1 Iteration 160: Loss = 0.07955317199230194, Number of mined triplets = 17407
Epoch 1 Iteration 180: Loss = 0.0819648876786232, Number of mined triplets = 26285
Epoch 1 Iteration 200: Loss = 0.08189083635807037, Number of mined triplets = 22461
Epoch 1 Iteration 220: Loss = 0.0827690064907074, Number of mined triplets = 23761


100%|██████████| 313/313 [00:03<00:00, 88.29it/s]

Computing accuracy
Test set accuracy (MAP@10) = 0.9738574404761905



