In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

In [2]:
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
class Net(nn.Module):
    """Small convolutional encoder producing a 128-dimensional embedding.

    Two 3x3 convolutions (each followed by ReLU), a 2x2 max-pool, dropout and
    one linear projection. ``fc1`` expects 9216 flattened features, which is
    what a 1x28x28 input produces (64 channels x 12 x 12 after the two
    un-padded convolutions and the pool).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout2d(p=0.25)
        # Defined but never called in forward(); kept so the module's
        # attributes stay the same.
        self.dropout2 = nn.Dropout2d(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)

    def forward(self, x):
        """Return the embedding for a batch of images ``x``."""
        feats = F.relu(self.conv1(x))
        feats = F.relu(self.conv2(feats))
        feats = self.dropout1(F.max_pool2d(feats, 2))
        return self.fc1(torch.flatten(feats, 1))

In [3]:
### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ###
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, taking one optimizer step per batch.

    For each batch the data and labels are moved to ``device``, embedded by
    ``model``, passed to ``mining_func`` to obtain an index tuple, and scored
    by ``loss_func``. A progress line is printed on every 20th batch
    (batch indices 0, 20, 40, ...).

    Args:
        model: module that maps a batch of data to embeddings.
        loss_func: callable ``(embeddings, labels, indices_tuple) -> loss``.
        mining_func: callable ``(embeddings, labels) -> indices_tuple``; its
            ``num_triplets`` attribute is read for the progress line.
        device: device the batch tensors are moved to.
        train_loader: iterable yielding ``(data, labels)`` pairs.
        optimizer: optimizer whose ``zero_grad``/``step`` drive the update.
        epoch: epoch number, used only in the progress line.
    """
    progress_line = "Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}"
    model.train()
    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        embeddings = model(inputs)
        mined = mining_func(embeddings, targets)
        batch_loss = loss_func(embeddings, targets, mined)
        batch_loss.backward()
        optimizer.step()

        # Only report on every 20th batch.
        if step % 20 != 0:
            continue
        print(progress_line.format(epoch, step, batch_loss, mining_func.num_triplets))

In [4]:
### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Embed every item of ``dataset`` with ``model`` using a fresh ``BaseTester``.

    Returns whatever ``BaseTester.get_all_embeddings`` returns; callers in
    this notebook unpack it as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)

In [5]:
### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator, accuracy_str: str = "precision_at_1"):
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    print("Computing accuracy")
    
    accuracies = accuracy_calculator.get_accuracy(
        test_embeddings, test_labels, train_embeddings, train_labels, False
    )
    print(accuracies)
    
    for k, v in accuracies.items():
        print(f"Test set  = {k}: {v}")

In [6]:
# Use the GPU when CUDA is available and fall back to the CPU otherwise, so
# the cell no longer has to be edited by hand when moving between machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Run-level hyper-parameters used by the loaders and the training loop.
batch_size = 256
num_epochs = 1

# Convert each image to a tensor, then normalise it with a fixed mean / std
# (presumably the MNIST pixel statistics -- confirm before reusing elsewhere).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [8]:
# dataset1 is the training split (downloaded into the current directory when
# missing); dataset2 is the test split read from the same root.
dataset1 = datasets.MNIST(root=".", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(root=".", train=False, transform=transform)

# Only the training loader shuffles its batches.
train_loader = torch.utils.data.DataLoader(
    dataset=dataset1,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=dataset2,
    batch_size=batch_size,
)

In [9]:
# Sanity check: show the shape of the image tensor stored in each split.
print(f"Train set size: {dataset1.data.shape}")
print(f"Test set size: {dataset2.data.shape}")

Train set size: torch.Size([60000, 28, 28])
Test set size: torch.Size([10000, 28, 28])


In [10]:
# Build the embedding network, place it on the selected device, and let Adam
# optimise all of its parameters.
model = Net()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [11]:
### pytorch-metric-learning stuff ###
# A single similarity object is shared by the loss and the miner so both
# compare embeddings the same way.
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)

# The loss and the miner use the same margin (0.2).
loss_func = losses.TripletMarginLoss(
    margin=0.2,
    distance=distance,
    reducer=reducer,
)
mining_func = miners.TripletMarginMiner(
    margin=0.2,
    distance=distance,
    type_of_triplets="semihard",
)

# Calculator restricted to precision@1 with k=1.
accuracy_calculator_precision_at_1 = AccuracyCalculator(
    include=("precision_at_1",),
    k=1,
)

In [18]:
# Training loop: one training pass followed by an evaluation, per epoch.
for epoch in range(1, num_epochs + 1):
    train(
        model=model,
        loss_func=loss_func,
        mining_func=mining_func,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
    )
    test(
        train_set=dataset1,
        test_set=dataset2,
        model=model,
        accuracy_calculator=accuracy_calculator_precision_at_1,
    )

Epoch 1 Iteration 0: Loss = 0.10245107114315033, Number of mined triplets = 736921
Epoch 1 Iteration 20: Loss = 0.09349484741687775, Number of mined triplets = 143794
Epoch 1 Iteration 40: Loss = 0.09178034961223602, Number of mined triplets = 89165
Epoch 1 Iteration 60: Loss = 0.08623408526182175, Number of mined triplets = 62666
Epoch 1 Iteration 80: Loss = 0.08743933588266373, Number of mined triplets = 60605
Epoch 1 Iteration 100: Loss = 0.08344611525535583, Number of mined triplets = 31143
Epoch 1 Iteration 120: Loss = 0.08250976353883743, Number of mined triplets = 37088
Epoch 1 Iteration 140: Loss = 0.0842210203409195, Number of mined triplets = 39999
Epoch 1 Iteration 160: Loss = 0.08521860837936401, Number of mined triplets = 32405
Epoch 1 Iteration 180: Loss = 0.08421389013528824, Number of mined triplets = 28084
Epoch 1 Iteration 200: Loss = 0.08335979282855988, Number of mined triplets = 24940
Epoch 1 Iteration 220: Loss = 0.08223602920770645, Number of mined triplets = 272

100%|███████████████████████████████| 1875/1875 [00:21<00:00, 85.23it/s]
100%|█████████████████████████████████| 313/313 [00:04<00:00, 74.77it/s]


Computing accuracy
{'precision_at_1': 0.9816}
Test set  = precision_at_1: 0.9816


In [12]:
from pytorch_metric_learning.utils import accuracy_calculator

class MyCalculator(accuracy_calculator.AccuracyCalculator):
    """AccuracyCalculator extended with precision@k for k in {2, 3, 5, 10}.

    Each ``calculate_precision_at_<k>`` method delegates to one shared helper
    so the call into ``accuracy_calculator.precision_at_k`` is written once.
    The helper forwards this calculator's own ``avg_of_avgs``,
    ``return_per_class`` and ``label_comparison_fn`` settings, so options
    given to the constructor apply to the extra metrics too instead of being
    silently replaced by hard-coded values.
    """

    # k values for which an extra precision_at_<k> metric is defined below.
    EXTRA_PRECISION_KS = (2, 3, 5, 10)

    def _precision_at(self, k, knn_labels, query_labels):
        """Return precision@k of ``knn_labels`` against ``query_labels``.

        ``query_labels[:, None]`` adds a trailing axis to the query labels
        before they are passed on as ``gt_labels``.
        """
        return accuracy_calculator.precision_at_k(
            knn_labels=knn_labels,
            gt_labels=query_labels[:, None],
            k=k,
            avg_of_avgs=self.avg_of_avgs,
            return_per_class=self.return_per_class,
            label_comparison_fn=self.label_comparison_fn,
        )

    def calculate_precision_at_2(self, knn_labels, query_labels, **kwargs):
        """Precision over the 2 nearest neighbours of each query."""
        return self._precision_at(2, knn_labels, query_labels)

    def calculate_precision_at_3(self, knn_labels, query_labels, **kwargs):
        """Precision over the 3 nearest neighbours of each query."""
        return self._precision_at(3, knn_labels, query_labels)

    def calculate_precision_at_5(self, knn_labels, query_labels, **kwargs):
        """Precision over the 5 nearest neighbours of each query."""
        return self._precision_at(5, knn_labels, query_labels)

    def calculate_precision_at_10(self, knn_labels, query_labels, **kwargs):
        """Precision over the 10 nearest neighbours of each query."""
        return self._precision_at(10, knn_labels, query_labels)

    def requires_knn(self):
        """Return the parent's knn-based metric names plus the extra precision@k ones."""
        extra = [f"precision_at_{k}" for k in self.EXTRA_PRECISION_KS]
        return super().requires_knn() + extra

In [13]:
def sample_embeddings(embeddings, sample_size: int = 100, seed=None, replace: bool = True):
    """Draw a random subset of rows from an ``(embeddings, labels)`` pair.

    The same row indices are applied to both arrays, so every sampled
    embedding stays aligned with its label.

    Args:
        embeddings: sequence whose first item holds the embedding rows and
            whose second item holds the matching label rows; both are indexed
            as two-dimensional arrays.
        sample_size: number of rows to draw.
        seed: optional integer seed. ``None`` (default) draws from NumPy's
            global random state, exactly as before; an integer uses a private
            ``RandomState`` so the draw is reproducible and leaves the global
            state untouched.
        replace: when True (default, the previous behaviour) rows are drawn
            with replacement and may repeat; when False every drawn row is
            distinct and ``sample_size`` must not exceed the number of rows.

    Returns:
        Tuple ``(sampled_embeddings, sampled_labels)``.
    """
    features, labels = embeddings[0], embeddings[1]
    num_rows = features.shape[0]
    # The np.random module and a RandomState instance expose the same
    # randint / choice API, so either can act as the generator here.
    rng = np.random if seed is None else np.random.RandomState(seed)
    if replace:
        sample_idx = rng.randint(num_rows, size=sample_size)
    else:
        sample_idx = rng.choice(num_rows, size=sample_size, replace=False)
    return features[sample_idx, :], labels[sample_idx, :]

In [14]:
# Embed the full test split first, then the full train split, with one tester.
tester = testers.BaseTester()
embs_test, embs_train = (
    tester.get_all_embeddings(split, model) for split in (dataset2, dataset1)
)

100%|█████████████████████████████████| 313/313 [00:04<00:00, 78.21it/s]
100%|███████████████████████████████| 1875/1875 [00:22<00:00, 84.66it/s]


In [15]:
# Sub-sample both splits (test first, then train) to keep the evaluation fast.
embs_test_sampled = sample_embeddings(embs_test, sample_size=1000)
embs_train_sampled = sample_embeddings(embs_train, sample_size=10000)

# Split each sample into embeddings and labels, dropping the labels' second axis.
test_embeddings, test_labels = embs_test_sampled[0], embs_test_sampled[1].squeeze(1)
train_embeddings, train_labels = embs_train_sampled[0], embs_train_sampled[1].squeeze(1)

In [16]:
# Baseline on the sampled data: precision@1 only, with k=1.
acc_calc = AccuracyCalculator(include=("precision_at_1",), k=1)

# Sampled test embeddings are scored against the sampled train embeddings.
# NOTE(review): the trailing False is presumably the "reference set includes
# the queries" flag -- confirm for the installed library version.
accuracies = acc_calc.get_accuracy(
    test_embeddings, test_labels, train_embeddings, train_labels, False
)

accuracies

  x.storage().data_ptr() + x.storage_offset() * 4)


{'precision_at_1': 0.952}

In [17]:
# Same evaluation with the extended calculator. No include filter is given, so
# the result also carries the extra precision_at_<k> metrics it defines.
acc_calc =  MyCalculator()

# NOTE(review): the trailing False is presumably the "reference set includes
# the queries" flag -- confirm for the installed library version.
accuracies = acc_calc.get_accuracy(
    test_embeddings, test_labels, train_embeddings, train_labels, False
)

accuracies

{'AMI': 0.5091641785263621,
 'NMI': 0.518064520237727,
 'mean_average_precision': 0.4474036158878568,
 'mean_average_precision_at_r': 0.31849579869519007,
 'mean_reciprocal_rank': 0.9676808714866638,
 'precision_at_1': 0.952,
 'precision_at_10': 0.9078999999999999,
 'precision_at_2': 0.9425,
 'precision_at_3': 0.9356666666666666,
 'precision_at_5': 0.9266,
 'r_precision': 0.43060780334557375}

```
{'AMI': 0.5091641785263621,
 'NMI': 0.518064520237727,
 'mean_average_precision': 0.4474036158878568,
 'mean_average_precision_at_r': 0.31849579869519007,
 'mean_reciprocal_rank': 0.9676808714866638,
 'precision_at_1': 0.952,
 'precision_at_10': 0.9078999999999999,
 'precision_at_2': 0.9425,
 'precision_at_3': 0.9356666666666666,
 'precision_at_5': 0.9266,
 'r_precision': 0.43060780334557375}
```