In [1]:
!pip install pytorch-metric-learning
!pip install faiss-gpu

Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-1.0.0-py3-none-any.whl (102 kB)
[?25l[K     |███▏                            | 10 kB 17.1 MB/s eta 0:00:01[K     |██████▍                         | 20 kB 13.1 MB/s eta 0:00:01[K     |█████████▋                      | 30 kB 9.3 MB/s eta 0:00:01[K     |████████████▊                   | 40 kB 8.6 MB/s eta 0:00:01[K     |████████████████                | 51 kB 5.4 MB/s eta 0:00:01[K     |███████████████████▏            | 61 kB 5.9 MB/s eta 0:00:01[K     |██████████████████████▎         | 71 kB 5.8 MB/s eta 0:00:01[K     |█████████████████████████▌      | 81 kB 6.4 MB/s eta 0:00:01[K     |████████████████████████████▊   | 92 kB 5.0 MB/s eta 0:00:01[K     |███████████████████████████████▉| 102 kB 5.5 MB/s eta 0:00:01[K     |████████████████████████████████| 102 kB 5.5 MB/s 
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-1.0.0
Collecting faiss-

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
from torchvision import datasets, transforms

from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

class SubCenterArcFaceLoss(losses.ArcFaceLoss):
    """
    Sub-center ArcFace loss.

    Implementation of https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf

    Each real class is given ``sub_centers`` weight vectors instead of one. The parent
    ArcFaceLoss is constructed with ``num_classes * sub_centers`` weight columns, and
    ``get_cosine`` collapses them back to one score per real class by taking the maximum
    cosine similarity over that class's sub-centers.

    Args:
        num_classes: Number of real classes.
        embedding_size: Dimensionality of the embeddings passed to the loss.
        margin: Forwarded unchanged to ArcFaceLoss.
        scale: Forwarded unchanged to ArcFaceLoss.
        sub_centers: Number of weight vectors per class; must be >= 1.
        **kwargs: Forwarded unchanged to ArcFaceLoss.

    Raises:
        ValueError: If ``sub_centers`` is less than 1.
    """

    def __init__(self, num_classes, embedding_size, margin=28.6, scale=64, sub_centers=3, **kwargs):
        if sub_centers < 1:
            raise ValueError(f"sub_centers must be >= 1, got {sub_centers!r}")
        super().__init__(num_classes * sub_centers, embedding_size, margin=margin, scale=scale, **kwargs)
        self.sub_centers = sub_centers
        # The parent was given num_classes * sub_centers; restore the real class count so it
        # agrees with the (batch, num_classes) cosine returned by get_cosine.
        # NOTE(review): assumes the parent derives its target mask / logits width from
        # self.num_classes — confirm against the installed pytorch-metric-learning version.
        self.num_classes = num_classes

    def get_cosine(self, embeddings):
        """Return per-class cosine similarities, maxed over each class's sub-centers."""
        # One similarity per weight column; expected shape (batch, num_classes * sub_centers).
        cosine = self.distance(embeddings, self.W.t())
        # Group consecutive columns: columns c*sub_centers .. c*sub_centers + sub_centers - 1
        # become the sub-centers of class c.
        cosine = cosine.reshape(-1, self.num_classes, self.sub_centers)
        cosine, _ = cosine.max(dim=2)
        return cosine

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
class Net(nn.Module):
    """Convolutional embedding network for MNIST-sized images.

    Based on the PyTorch MNIST example, but truncated after ``fc1``. ``forward`` returns a
    raw embedding (no classifier head, no log-softmax) for use with a metric-learning loss.

    Args:
        embedding_size: Width of the returned embedding. Defaults to 128, the value that
            was previously hard-coded.
    """

    def __init__(self, embedding_size=128):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # 9216 = 64 channels * 12 * 12. A 1x28x28 input shrinks to 26x26, then 24x24,
        # through the two unpadded 3x3 convolutions, and to 12x12 after the 2x2 max-pool.
        # The example's second dropout was never used in forward() and has been removed.
        # It has no parameters, so state_dict keys are unchanged.
        self.fc1 = nn.Linear(9216, embedding_size)

    def forward(self, x):
        """Map a batch of images (assumed N x 1 x 28 x 28) to (N, embedding_size) embeddings."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        return self.fc1(x)


### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, device, train_loader, optimizer, loss_optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    ``optimizer`` updates the model. ``loss_optimizer`` is zeroed and stepped alongside it
    on every batch, for loss functions that carry trainable parameters of their own.
    A progress line is printed every 20 batches.
    """
    model.train()
    optimizers = (optimizer, loss_optimizer)
    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        for opt in optimizers:
            opt.zero_grad()
        batch_loss = loss_func(model(inputs), targets)
        batch_loss.backward()
        for opt in optimizers:
            opt.step()
        if step % 20 == 0:
            print(f"Epoch {epoch} Iteration {step}: Loss = {batch_loss}")


### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Embed every item of ``dataset`` with ``model`` using pytorch-metric-learning's BaseTester.

    Returns the result of ``BaseTester.get_all_embeddings`` unchanged. In this notebook it
    is unpacked as an ``(embeddings, labels)`` pair.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)


### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    """Report Precision@1 of ``test_set`` queries retrieved against ``train_set`` references.

    Both datasets are embedded with ``model``. Each test embedding is then scored against
    the training embeddings by ``accuracy_calculator``.

    Returns:
        The dict produced by ``accuracy_calculator.get_accuracy`` (previously discarded),
        so callers can log or assert on it.
    """
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    # Drop the singleton label dimension at index 1 so labels are 1-D.
    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    print("Computing accuracy")
    # Bind every argument by keyword (pytorch-metric-learning 1.x names). Later releases
    # reordered these parameters; with keywords an API mismatch raises a TypeError instead
    # of silently passing embeddings where labels are expected.
    accuracies = accuracy_calculator.get_accuracy(
        query=test_embeddings,
        reference=train_embeddings,
        query_labels=test_labels,
        reference_labels=train_labels,
        embeddings_come_from_same_source=False,
    )
    print("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))
    return accuracies


# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed weight init, shuffling and dropout so a Restart & Run All is reproducible.
seed = 0
torch.manual_seed(seed)

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

batch_size = 64

dataset1 = datasets.MNIST(".", train=True, download=True, transform=transform)
# download=True is a no-op when the files already exist; it keeps this line from
# depending on the train split having fetched them first.
dataset2 = datasets.MNIST(".", train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)
# NOTE(review): test_loader is not used below (test() embeds dataset2 directly); kept in
# case other cells rely on it.
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 1


### pytorch-metric-learning stuff ###
num_classes = 10  # MNIST digits 0-9
embedding_size = 128  # must equal the output width of the embedding network
assert model.fc1.out_features == embedding_size, "Net output width must match the loss's embedding_size"
loss_func = SubCenterArcFaceLoss(num_classes, embedding_size).to(device)
loss_optimizer = torch.optim.SGD(loss_func.parameters(), lr=0.01)
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)
### pytorch-metric-learning stuff ###


for epoch in range(1, num_epochs + 1):
    train(model, loss_func, device, train_loader, optimizer, loss_optimizer, epoch)
    test(dataset1, dataset2, model, accuracy_calculator)

Epoch 1 Iteration 0: Loss = 37.22161102294922
Epoch 1 Iteration 20: Loss = 30.873565673828125
Epoch 1 Iteration 40: Loss = 25.4113826751709
Epoch 1 Iteration 60: Loss = 6.920309066772461
Epoch 1 Iteration 80: Loss = 8.881182670593262
Epoch 1 Iteration 100: Loss = 6.854680061340332
Epoch 1 Iteration 120: Loss = 7.7223615646362305
Epoch 1 Iteration 140: Loss = 5.3887529373168945
Epoch 1 Iteration 160: Loss = 3.7892096042633057
Epoch 1 Iteration 180: Loss = 5.4152913093566895
Epoch 1 Iteration 200: Loss = 3.5456528663635254
Epoch 1 Iteration 220: Loss = 7.2787957191467285
Epoch 1 Iteration 240: Loss = 5.1593732833862305
Epoch 1 Iteration 260: Loss = 5.803065776824951
Epoch 1 Iteration 280: Loss = 2.1638660430908203
Epoch 1 Iteration 300: Loss = 3.2827391624450684
Epoch 1 Iteration 320: Loss = 3.6944427490234375
Epoch 1 Iteration 340: Loss = 1.6409906148910522
Epoch 1 Iteration 360: Loss = 4.501714706420898
Epoch 1 Iteration 380: Loss = 1.9407941102981567
Epoch 1 Iteration 400: Loss = 2.01

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:03<00:00, 474.80it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:00<00:00, 437.68it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.9818
