In [1]:
!pip install pytorch-metric-learning
!pip install faiss-gpu

Collecting pytorch-metric-learning
[?25l  Downloading https://files.pythonhosted.org/packages/5d/4c/cf04389670cc5168ef8edffd09d90ce61754ed438c07304e0ed5d6647616/pytorch_metric_learning-0.9.90-py3-none-any.whl (90kB)
[K     |███▋                            | 10kB 29.5MB/s eta 0:00:01[K     |███████▎                        | 20kB 2.5MB/s eta 0:00:01[K     |███████████                     | 30kB 3.0MB/s eta 0:00:01[K     |██████████████▌                 | 40kB 3.3MB/s eta 0:00:01[K     |██████████████████▏             | 51kB 3.1MB/s eta 0:00:01[K     |█████████████████████▉          | 61kB 3.5MB/s eta 0:00:01[K     |█████████████████████████▌      | 71kB 3.8MB/s eta 0:00:01[K     |█████████████████████████████   | 81kB 4.2MB/s eta 0:00:01[K     |████████████████████████████████| 92kB 3.4MB/s 
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-0.9.90
Collecting faiss-gpu
[?25l  Downloading https://files.pythonhosted.org

In [2]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
from torchvision import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
class Net(nn.Module):
    """Small convolutional network for single-channel images.

    Two 3x3 convolutions, a 2x2 max-pool and two linear layers. ``forward``
    returns a ``(batch, 10)`` tensor of log-softmax values; the training loop
    in this notebook feeds that tensor to the metric-learning loss as the
    embedding.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout, applied to the 4-D convolutional feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: this layer runs on the flattened 2-D output of
        # fc1. Dropout2d is meant for 3-D/4-D inputs and recent PyTorch
        # releases warn when it is given 2-D input.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12, the flattened size for a 28x28 input.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Run the network on ``x`` and return log-softmax over the 10 outputs.

        ``x`` is passed straight to ``conv1`` (1 input channel); with 28x28
        images the flattened feature size matches ``fc1``.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    For every batch the inputs and labels are moved to ``device``, the model
    output is handed to ``mining_func(embeddings, labels)``, and the result is
    passed on to ``loss_func(embeddings, labels, indices_tuple)``. The loss is
    back-propagated and ``optimizer`` takes one step.

    Every 20th batch (starting with batch 0) a progress line is printed with
    ``epoch``, the batch index, the loss and ``mining_func.num_triplets``.
    Returns ``None``.
    """
    report_every = 20
    report_line = "Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}"

    model.train()
    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        embeddings = model(inputs)
        mined = mining_func(embeddings, targets)
        batch_loss = loss_func(embeddings, targets, mined)
        batch_loss.backward()
        optimizer.step()

        if step % report_every != 0:
            continue
        print(report_line.format(epoch, step, batch_loss, mining_func.num_triplets))

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Return the result of ``BaseTester.get_all_embeddings(dataset, model)``.

    A fresh ``testers.BaseTester`` is created on every call and its return
    value is passed through unchanged; the caller in this notebook unpacks it
    as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(dataset, model, accuracy_calculator):
    embeddings, labels = get_all_embeddings(dataset, model)
    print("Computing accuracy")
    accuracies = accuracy_calculator.get_accuracy(embeddings, 
                                                embeddings,
                                                np.squeeze(labels),
                                                np.squeeze(labels),
                                                True)
    print("Test set accuracy (Adjusted Mutual Information) = {}".format(accuracies["AMI"]))

# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook does not fail on the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert images to tensors and normalize the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std taken from
# the upstream PyTorch example -- confirm before reusing on other data.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

# Single source of truth for the batch size used by both loaders below.
batch_size = 256

# dataset1 = training split (downloaded if missing), dataset2 = test split.
dataset1 = datasets.MNIST('.', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('.', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)
# test_loader is not used by the loop below, which passes dataset2 directly to test().
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 2


### pytorch-metric-learning stuff ###
# The same distance object is shared by the loss and the miner so that both
# measure triplets the same way; both also use the same 0.2 margin.
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low = 0)
loss_func = losses.TripletMarginLoss(margin = 0.2, distance = distance, reducer = reducer)
mining_func = miners.TripletMarginMiner(margin = 0.2, distance = distance, type_of_triplets = "semihard")
# Only the "AMI" metric is requested; test() reads accuracies["AMI"].
accuracy_calculator = AccuracyCalculator(include = ("AMI",))
### pytorch-metric-learning stuff ###


# Train for num_epochs epochs, evaluating on the test split after each one.
for epoch in range(1, num_epochs+1):
    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
    test(dataset2, model, accuracy_calculator)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)




Epoch 1 Iteration 0: Loss = 0.19946202635765076, Number of mined triplets = 782808
Epoch 1 Iteration 20: Loss = 0.10993602126836777, Number of mined triplets = 414856
Epoch 1 Iteration 40: Loss = 0.10145941376686096, Number of mined triplets = 251300
Epoch 1 Iteration 60: Loss = 0.09394049644470215, Number of mined triplets = 180018
Epoch 1 Iteration 80: Loss = 0.0911107212305069, Number of mined triplets = 134150
Epoch 1 Iteration 100: Loss = 0.08349557220935822, Number of mined triplets = 76062
Epoch 1 Iteration 120: Loss = 0.08468102663755417, Number of mined triplets = 52069
Epoch 1 Iteration 140: Loss = 0.08442310243844986, Number of mined triplets = 71048
Epoch 1 Iteration 160: Loss = 0.08457087725400925, Number of mined triplets = 66220
Epoch 1 Iteration 180: Loss = 0.08869604766368866, Number of mined triplets = 50983
Epoch 1 Iteration 200: Loss = 0.08236870914697647, Number of mined triplets = 50825
Epoch 1 Iteration 220: Loss = 0.08612465113401413, Number of mined triplets 

100%|██████████| 313/313 [00:03<00:00, 94.31it/s]


Computing accuracy
Test set accuracy (Adjusted Mutual Information) = 0.856050684506773
Epoch 2 Iteration 0: Loss = 0.07726449519395828, Number of mined triplets = 18995
Epoch 2 Iteration 20: Loss = 0.08266301453113556, Number of mined triplets = 42861
Epoch 2 Iteration 40: Loss = 0.09535423666238785, Number of mined triplets = 57466
Epoch 2 Iteration 60: Loss = 0.08160889893770218, Number of mined triplets = 31130
Epoch 2 Iteration 80: Loss = 0.08814632892608643, Number of mined triplets = 36959
Epoch 2 Iteration 100: Loss = 0.07569822669029236, Number of mined triplets = 19386
Epoch 2 Iteration 120: Loss = 0.1012580394744873, Number of mined triplets = 21309
Epoch 2 Iteration 140: Loss = 0.09831924736499786, Number of mined triplets = 44847
Epoch 2 Iteration 160: Loss = 0.08502580970525742, Number of mined triplets = 32722
Epoch 2 Iteration 180: Loss = 0.09057745337486267, Number of mined triplets = 41723
Epoch 2 Iteration 200: Loss = 0.07309754937887192, Number of mined triplets = 22

100%|██████████| 313/313 [00:03<00:00, 89.31it/s]

Computing accuracy
Test set accuracy (Adjusted Mutual Information) = 0.9111903949345276



