In [1]:
# Pin pytorch-metric-learning: the code below relies on this release's API
# (BaseTester.get_all_embeddings output, positional same-source flag of
# AccuracyCalculator.get_accuracy). %pip installs into the running kernel's env.
%pip install -q pytorch-metric-learning==0.9.94
# NOTE(review): faiss-gpu is left unpinned; pin it to the version this notebook
# was last run with once that is known.
%pip install -q faiss-gpu

Collecting pytorch-metric-learning
[?25l  Downloading https://files.pythonhosted.org/packages/90/3b/c44b7dc2743270b0e5536d685d3ec6d0437e1af22bee940c3560dd21207b/pytorch_metric_learning-0.9.94-py3-none-any.whl (96kB)
[K     |███▍                            | 10kB 21.3MB/s eta 0:00:01[K     |██████▊                         | 20kB 15.5MB/s eta 0:00:01[K     |██████████▏                     | 30kB 13.3MB/s eta 0:00:01[K     |█████████████▌                  | 40kB 12.3MB/s eta 0:00:01[K     |█████████████████               | 51kB 9.2MB/s eta 0:00:01[K     |████████████████████▎           | 61kB 8.3MB/s eta 0:00:01[K     |███████████████████████▊        | 71kB 9.3MB/s eta 0:00:01[K     |███████████████████████████     | 81kB 10.3MB/s eta 0:00:01[K     |██████████████████████████████▌ | 92kB 9.4MB/s eta 0:00:01[K     |████████████████████████████████| 102kB 6.1MB/s 
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-0.9.9

In [2]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
from torchvision import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
class Net(nn.Module):
    """Small CNN that maps single-channel images to 128-d embedding vectors.

    Pipeline: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> Dropout2d(p=0.25) -> flatten -> linear(9216 -> 128).
    ``fc1`` expects 9216 = 64 * 12 * 12 features, which corresponds to
    28x28 inputs (28 -> 26 -> 24 after the two convs, 12 after pooling).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel dropout; dropout2 is registered but not used by forward().
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        # Linear projection into the embedding space.
        self.fc1 = nn.Linear(9216, 128)

    def forward(self, x):
        """Return a (batch, 128) tensor of embeddings for the image batch ``x``."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        pooled = self.dropout1(F.max_pool2d(hidden, 2))
        return self.fc1(torch.flatten(pooled, 1))

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch, log_interval=20):
    """Run one epoch of metric-learning training.

    For every batch: embed the inputs, mine tuples from the batch with
    ``mining_func``, compute ``loss_func`` on the mined tuples, backpropagate,
    and take an optimizer step.

    Args:
        model: network mapping a batch of inputs to embeddings; put in train mode.
        loss_func: callable ``(embeddings, labels, indices_tuple) -> scalar tensor``.
        mining_func: callable ``(embeddings, labels) -> indices_tuple``; its
            ``num_triplets`` attribute is read for the progress log.
        device: device each ``(data, labels)`` batch is moved to.
        train_loader: iterable yielding ``(data, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the log message.
        log_interval: print a progress line every ``log_interval`` batches
            (default 20, the previously hard-coded value); ``<= 0`` disables logging.
    """
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        # Guard against a zero interval, which would raise ZeroDivisionError.
        if log_interval > 0 and batch_idx % log_interval == 0:
            print("Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}".format(
                epoch, batch_idx, loss.item(), mining_func.num_triplets))

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Embed every sample of ``dataset`` with ``model``.

    Thin wrapper around pytorch-metric-learning's
    ``BaseTester.get_all_embeddings`` using a default-configured tester;
    returns whatever that method returns (consumed in this notebook as an
    ``(embeddings, labels)`` pair).
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    """Score ``model`` by nearest-neighbour retrieval of test samples.

    Every ``test_set`` embedding is used as a query against the embeddings of
    ``train_set`` (the reference set), and ``accuracy_calculator`` computes
    the configured metrics. Precision@1 is printed.

    Args:
        train_set: dataset whose embeddings form the reference set.
        test_set: dataset whose embeddings are the queries.
        model: embedding network; switched to eval mode here.
        accuracy_calculator: pytorch-metric-learning ``AccuracyCalculator``
            configured to include ``"precision_at_1"``.

    Returns:
        The metrics dict from ``accuracy_calculator.get_accuracy`` (previously
        computed and discarded; callers that ignore the return are unaffected).
    """
    # train() leaves the model in train mode (dropout active); switch explicitly
    # instead of relying on the tester to do it.
    model.eval()
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    print("Computing accuracy")
    accuracies = accuracy_calculator.get_accuracy(
        test_embeddings,
        train_embeddings,
        np.squeeze(test_labels),
        np.squeeze(train_labels),
        # embeddings_come_from_same_source: False because queries (test split)
        # and references (train split) are different datasets.
        False,
    )
    print("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))
    return accuracies

# Use the GPU when one is available; fall back to CPU so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG so weight init, shuffling and dropout repeat across runs.
torch.manual_seed(0)

# Normalisation constants come from the PyTorch MNIST example referenced above.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

batch_size = 256

dataset1 = datasets.MNIST('.', train=True, download=True, transform=transform)
# download=True is a no-op when the files already exist; don't depend on the
# train-split call above having fetched them.
dataset2 = datasets.MNIST('.', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)
# NOTE(review): test_loader is not used below; test() embeds dataset2 directly.
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 1


### pytorch-metric-learning stuff ###
# The same cosine-similarity distance is shared by the miner and the loss,
# so both 0.2 margins are expressed in the same units.
distance = distances.CosineSimilarity()
# ThresholdReducer(low=0): reduce only over per-triplet losses above 0.
reducer = reducers.ThresholdReducer(low = 0)
loss_func = losses.TripletMarginLoss(margin = 0.2, distance = distance, reducer = reducer)
mining_func = miners.TripletMarginMiner(margin = 0.2, distance = distance, type_of_triplets = "semihard")
accuracy_calculator = AccuracyCalculator(include = ("precision_at_1",), k = 1)
### pytorch-metric-learning stuff ###


for epoch in range(1, num_epochs+1):
    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
    test(dataset1, dataset2, model, accuracy_calculator)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Epoch 1 Iteration 0: Loss = 0.1037939116358757, Number of mined triplets = 776687
Epoch 1 Iteration 20: Loss = 0.09549278020858765, Number of mined triplets = 123048
Epoch 1 Iteration 40: Loss = 0.08915364742279053, Number of mined triplets = 89530
Epoch 1 Iteration 60: Loss = 0.08617275953292847, Number of mined triplets = 50026
Epoch 1 Iteration 80: Loss = 0.08428513258695602, Number of mined triplets = 42704
Epoch 1 Iteration 100: Loss = 0.08503272384405136, Number of mined triplets = 44915
Epoch 1 Iteration 120: Loss = 0.08499237149953842, Number of mined triplets = 34030
Epoch 1 Iteration 140: Loss = 0.08733508735895157, Number of mined triplets = 45407
Epoch 1 Iteration 160: Loss = 0.08337089419364929, Number of mined triplets = 32317
Epoch 1 Iteration 180: Loss = 0.08752584457397461, Number of mined triplets = 40344
Epoch 1 Iteration 200: Loss = 0.08363424241542816, Number of mined triplets = 27893
Epoch 1 Iteration 220: Loss = 0.08256012946367264, Number of mined triplets = 191

100%|██████████| 1875/1875 [00:15<00:00, 119.05it/s]
100%|██████████| 313/313 [00:03<00:00, 78.37it/s] 


Computing accuracy
Test set accuracy (Precision@1) = 0.9789
