# Toy example of two-stream metric learning

In [4]:
# The testing module requires faiss
# So if you don't have that, then this import will break 
from pytorch_metric_learning import losses, miners, samplers, trainers, testers
import pytorch_metric_learning.utils.logging_presets as logging_presets
import numpy as np
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import logging
from PIL import Image
import pytorch_metric_learning

# Let INFO-level records through the root logger so progress messages are shown,
# then record which pytorch_metric_learning version produced this run.
logging.root.setLevel(logging.INFO)
logging.info("VERSION %s", pytorch_metric_learning.__version__)


INFO:root:VERSION 0.9.84


In [5]:
class CIFAR100TwoStreamDataset(torch.utils.data.Dataset):
    """Two-stream view of a labelled dataset for two-stream metric learning.

    The wrapped dataset is randomly split into an "anchor" stream (80% of the
    samples, rounded down) and a "posneg" stream (all remaining samples).
    Indexing returns an anchor together with a randomly chosen posneg-stream
    sample that has the same class as the anchor.

    Args:
        dataset: map-style dataset whose items are ``(sample, target)`` pairs
            and which exposes a ``targets`` sequence aligned with its indices.
        anchor_transform: callable applied to every anchor sample, or ``None``.
        posneg_transform: callable applied to every posneg sample, or ``None``.
    """

    def __init__(self, dataset, anchor_transform, posneg_transform):
        # Derive the second length from the first so the two always add up to
        # len(dataset). Truncating 0.8*n and 0.2*n independently can lose a
        # sample, in which case random_split refuses the lengths.
        num_anchors = int(len(dataset) * 0.8)
        lengths = [num_anchors, len(dataset) - num_anchors]
        self.anchors, self.posnegs = torch.utils.data.random_split(dataset, lengths)

        self.anchor_transform = anchor_transform
        self.posneg_transform = posneg_transform

        # Group the posneg stream by class once, storing positions *within*
        # self.posnegs. This avoids rescanning every target of the underlying
        # dataset on each __getitem__ call.
        posneg_indices = np.asarray(self.posnegs.indices, dtype=np.int64)
        posneg_targets = np.asarray(dataset.targets)[posneg_indices]
        self._posneg_positions = {
            int(label): np.flatnonzero(posneg_targets == label)
            for label in np.unique(posneg_targets)
        }

    def __len__(self):
        """Return the number of anchors; the posneg stream only supplies partners."""
        return len(self.anchors)

    def __getitem__(self, index):
        """Return ``(anchor, posneg, target)`` for the anchor at ``index``.

        Raises:
            ValueError: if the posneg stream holds no sample of the anchor's class.
        """
        anchor, target = self.anchors[index]
        if self.anchor_transform is not None:
            anchor = self.anchor_transform(anchor)

        # Pair the anchor with a random same-class sample from the second stream.
        candidates = self._posneg_positions.get(int(target))
        if candidates is None:
            raise ValueError(
                "no sample of class %r in the posneg stream" % (target,)
            )
        posneg, _ = self.posnegs[int(np.random.choice(candidates))]

        if self.posneg_transform is not None:
            posneg = self.posneg_transform(posneg)
        return anchor, posneg, target


In [6]:
# A plain multilayer perceptron.
# Adapted from https://github.com/KevinMusgrave/powerful_benchmarker
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with in-place ``nn.ReLU`` placed before some of them.

    ``layer_sizes[0]`` is the input dimension and ``layer_sizes[-1]`` the output
    dimension; each consecutive pair of sizes defines one linear layer.
    A ReLU is inserted in front of every linear layer except the last one, or
    in front of every linear layer when ``final_relu`` is true. The last module
    of the stack is always a linear layer, also exposed as ``last_linear``.
    """

    def __init__(self, layer_sizes, final_relu=False):
        super().__init__()
        sizes = list(map(int, layer_sizes))
        num_linear = len(sizes) - 1
        # Linear layers whose position is below this cutoff get a ReLU in front.
        relu_cutoff = num_linear if final_relu else num_linear - 1

        modules = []
        for position, (in_dim, out_dim) in enumerate(zip(sizes[:-1], sizes[1:])):
            if position < relu_cutoff:
                modules.append(nn.ReLU(inplace=True))
            modules.append(nn.Linear(in_dim, out_dim))

        self.net = nn.Sequential(*modules)
        self.last_linear = self.net[-1]

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.net(x)


# Drop-in stand-in for the last layer of a pretrained network.
# Adapted from https://github.com/KevinMusgrave/powerful_benchmarker
class Identity(nn.Module):
    """Parameter-free module whose forward pass returns its input unchanged."""

    def forward(self, x):
        return x

# Adapted from https://github.com/KevinMusgrave/powerful_benchmarker
class ListOfModels(nn.Module):
    """Run several models and concatenate their outputs along the last dimension.

    Args:
        list_of_models: models to wrap; they are registered in an ``nn.ModuleList``.
        input_sizes: if ``None``, every model receives the full input. Otherwise
            a sequence of widths: model ``i`` receives the next ``input_sizes[i]``
            columns of the input (sliced along dimension 1).
        operation_before_concat: optional callable applied to each model's
            output before concatenation; defaults to a pass-through.

    The attributes ``mean``, ``std``, ``input_space`` and ``input_range`` are
    copied from the first model, or set to ``None`` when it lacks them.
    """

    def __init__(self, list_of_models, input_sizes=None, operation_before_concat=None):
        super().__init__()
        self.list_of_models = nn.ModuleList(list_of_models)
        self.input_sizes = input_sizes
        if operation_before_concat:
            self.operation_before_concat = operation_before_concat
        else:
            self.operation_before_concat = lambda x: x
        first_model = list_of_models[0]
        for attr_name in ("mean", "std", "input_space", "input_range"):
            setattr(self, attr_name, getattr(first_model, attr_name, None))

    def forward(self, x):
        """Return the concatenated (post-processed) outputs of all wrapped models."""
        if self.input_sizes is None:
            # Every model sees the whole input.
            pieces = [self.operation_before_concat(model(x)) for model in self.list_of_models]
        else:
            # Each model sees its own contiguous column slice of the input.
            pieces = []
            start = 0
            for model_idx, width in enumerate(self.input_sizes):
                stop = start + width
                model_out = self.list_of_models[model_idx](x[:, start:stop])
                pieces.append(self.operation_before_concat(model_out))
                start = stop
        return torch.cat(pieces, dim=-1)

In [7]:
# Use the GPU when one is available, otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Trunk: pretrained ResNet-18 whose final `fc` layer is replaced by an identity,
# so the trunk outputs the features that used to feed that layer
trunk = models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features
trunk.fc = Identity()
trunk = trunk.to(device)
trunk = torch.nn.DataParallel(trunk)

# Embedder: takes the trunk output and produces 128-dimensional embeddings
embedder = MLP([trunk_output_size, 128])
embedder = torch.nn.DataParallel(embedder.to(device))

# One Adam optimizer per model, both with the same hyperparameters
adam_settings = {"lr": 0.00004, "weight_decay": 0.00005}
trunk_optimizer = torch.optim.Adam(trunk.parameters(), **adam_settings)
embedder_optimizer = torch.optim.Adam(embedder.parameters(), **adam_settings)


In [9]:
# Per-channel normalization statistics used by both streams
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Both streams use the same preprocessing: convert to a tensor, then normalize
posneg_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])
anchor_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# Load CIFAR100 without a transform; each stream applies its own transform later
cifar_root = "../CIFAR100_Dataset"
original_train = datasets.CIFAR100(root=cifar_root, train=True, transform=None, download=True)
original_val = datasets.CIFAR100(root=cifar_root, train=False, transform=None, download=True)

# Split each CIFAR100 set into two streams:
#   - 80% of the images are used as anchors
#   - the remaining 20% supply the positives and negatives
stream_transforms = {"anchor_transform": anchor_transform, "posneg_transform": posneg_transform}
train_dataset = CIFAR100TwoStreamDataset(original_train, **stream_transforms)
val_dataset = CIFAR100TwoStreamDataset(original_val, **stream_transforms)


Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Set the loss function
loss = losses.TripletMarginLoss(margin=0.2)

# Set the mining function
miner = miners.TripletMarginMiner(margin=0.2)

# Set the dataloader sampler.
# The sampler needs one label per sample of the dataset it draws from, in that
# dataset's index order. `original_train.classes` is only the list of class
# names, so passing it would restrict sampling to the first len(classes)
# anchors. Use the labels of the anchor stream of train_dataset instead.
anchor_labels = [original_train.targets[i] for i in train_dataset.anchors.indices]
sampler = samplers.MPerClassSampler(anchor_labels, m=1)

# Set other training parameters
batch_size = 128
num_epochs = 10
iterations_per_epoch = 50

# Package the above stuff into dictionaries.
# NOTE(review): `models` shadows the `torchvision.models` import from the first
# cell from here on; re-running the model setup cell after this one will fail
# unless the import cell is re-run first.
models = {"trunk": trunk, "embedder": embedder}
optimizers = {"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer}
loss_funcs = {"metric_loss": loss}
mining_funcs = {"tuple_miner": miner}

# Logging: records go to "example_logs" and "example_tensorboard"
record_keeper, _, _ = logging_presets.get_record_keeper("example_logs", "example_tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)
dataset_dict = {"val": val_dataset}
model_folder = "example_saved_models"

# Create the tester
tester = testers.GlobalTwoStreamEmbeddingSpaceTester(end_of_testing_hook=hooks.end_of_testing_hook, dataloader_num_workers=2)
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)

# Create the trainer; the end-of-epoch hook evaluates on dataset_dict with the tester
trainer = trainers.TwoStreamMetricLoss(models=models,
                                optimizers=optimizers,
                                batch_size=batch_size,
                                loss_funcs=loss_funcs,
                                mining_funcs=mining_funcs,
                                iterations_per_epoch=iterations_per_epoch,
                                dataset=train_dataset,
                                sampler=sampler,
                                dataloader_num_workers=2,
                                end_of_iteration_hook=hooks.end_of_iteration_hook,
                                end_of_epoch_hook=end_of_epoch_hook
                                )

trainer.train(num_epochs=num_epochs)

INFO:root:Initializing dataloader
INFO:root:Initializing dataloader iterator
INFO:root:Done creating dataloader iterator
INFO:root:TRAINING EPOCH 1
total_loss=0.17232: 100%|██████████| 50/50 [00:30<00:00,  1.66it/s]
INFO:root:Evaluating epoch 1
INFO:root:Getting embeddings for the val split
100%|██████████| 250/250 [00:11<00:00, 22.65it/s]
INFO:root:Computing accuracy for the val split
INFO:root:running k-nn with k=88
INFO:root:embedding dimensionality is 128
INFO:root:running k-means clustering with k=100
INFO:root:embedding dimensionality is 128
INFO:root:New best accuracy!
INFO:root:TRAINING EPOCH 2
total_loss=0.16593: 100%|██████████| 50/50 [00:30<00:00,  1.68it/s]
INFO:root:Evaluating epoch 2
INFO:root:Getting embeddings for the val split
100%|██████████| 250/250 [00:11<00:00, 21.31it/s]
INFO:root:Computing accuracy for the val split
INFO:root:running k-nn with k=88
INFO:root:embedding dimensionality is 128
INFO:root:running k-means clustering with k=100
INFO:root:embedding dimens

INFO:root:New best accuracy!
INFO:root:TRAINING EPOCH 9
total_loss=0.14181: 100%|██████████| 50/50 [00:31<00:00,  1.48it/s]
INFO:root:Evaluating epoch 9
INFO:root:Getting embeddings for the val split
100%|██████████| 250/250 [00:13<00:00, 18.32it/s]
INFO:root:Computing accuracy for the val split
INFO:root:running k-nn with k=88
INFO:root:embedding dimensionality is 128
INFO:root:running k-means clustering with k=100
INFO:root:embedding dimensionality is 128
INFO:root:New best accuracy!
INFO:root:TRAINING EPOCH 10
total_loss=0.13159: 100%|██████████| 50/50 [00:31<00:00,  1.36it/s]
INFO:root:Evaluating epoch 10
INFO:root:Getting embeddings for the val split
100%|██████████| 250/250 [00:17<00:00, 14.58it/s]
INFO:root:Computing accuracy for the val split
INFO:root:running k-nn with k=88
INFO:root:embedding dimensionality is 128
INFO:root:running k-means clustering with k=100
INFO:root:embedding dimensionality is 128
INFO:root:New best accuracy!
