In [1]:
import logging
import random

import numpy as np
import torch
import torchvision
from torchvision import datasets, models, transforms

# The testing module requires faiss
# So if you don't have that, then this import will break
from pytorch_metric_learning import losses, miners, samplers, trainers, testers
import pytorch_metric_learning.utils.logging_presets as logging_presets

from utils_for_examples import MLP, Identity
# Emit INFO-level records from the root logger (the training/testing progress
# messages shown in the outputs below are logged at INFO on the root logger).
logging.getLogger().setLevel(logging.INFO)

import pytorch_metric_learning
# Pass the version as a logging argument rather than pre-formatting with `%`,
# so string formatting is deferred to the logging framework.
logging.info("VERSION %s", pytorch_metric_learning.__version__)

INFO:root:VERSION 0.9.76


In [2]:
##############################
########## Training ##########
##############################

# Hyperparameters shared by several objects below.
EMBEDDING_DIM = 64  # size of the embeddings produced by the embedder
NUM_CLASSES = 100  # there are 100 classes in CIFAR100, the dataset used here
LEARNING_RATE = 0.00001
WEIGHT_DECAY = 0.00005

# Seed the Python, numpy and torch RNGs so that stochastic parts of this example
# (presumably new-layer initialisation, augmentation and batch sampling) are
# more repeatable across runs. GPU kernels may still be nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the trunk model and replace its final fully-connected (classification)
# layer with an identity function, so the trunk outputs the features that used
# to feed that layer.
# NOTE: torchvision's model zoo is reached through the package name because
# `models` is rebound to a dict further down in this cell; using the bare name
# here would make re-running this cell fail.
trunk = torchvision.models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features
trunk.fc = Identity()
trunk = torch.nn.DataParallel(trunk.to(device))

# Set embedder model. This takes in the output of the trunk and outputs
# EMBEDDING_DIM-dimensional embeddings.
embedder = torch.nn.DataParallel(MLP([trunk_output_size, EMBEDDING_DIM]).to(device))

# Set the classifier. It takes the embeddings and outputs one logit per class.
# The classification loss is specified further down. Like the trunk and the
# embedder, the module is moved to the device before being wrapped.
classifier = torch.nn.DataParallel(MLP([EMBEDDING_DIM, NUM_CLASSES]).to(device))

# Set optimizers (all three share the same hyperparameters)
trunk_optimizer = torch.optim.Adam(trunk.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
embedder_optimizer = torch.optim.Adam(embedder.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Set the image transforms. Both pipelines share one normalisation step
# (presumably the ImageNet statistics the pretrained trunk expects).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([transforms.Resize(256),
                                    transforms.RandomResizedCrop(scale=(0.16, 1), ratio=(0.75, 1.33), size=227),
                                    transforms.RandomHorizontalFlip(0.5),
                                    transforms.ToTensor(),
                                    normalize])

val_transform = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(227),
                                    transforms.ToTensor(),
                                    normalize])

# Set the datasets
train_dataset = datasets.CIFAR100(root="CIFAR100_Dataset", train=True, transform=train_transform, download=True)
val_dataset = datasets.CIFAR100(root="CIFAR100_Dataset", train=False, transform=val_transform, download=True)

# Set the loss function
loss = losses.TripletMarginLoss(margin=0.01)

# Set the classification loss:
classification_loss = torch.nn.CrossEntropyLoss()

# Set the mining function
miner = miners.MultiSimilarityMiner(epsilon=0.1)

# Set the dataloader sampler
sampler = samplers.MPerClassSampler(train_dataset.targets, m=4)

# Set other training parameters
batch_size = 32
num_epochs = 2
iterations_per_epoch = 100

# Package the above stuff into dictionaries.
# NOTE: this `models` rebinds the name imported from torchvision. The trunk above
# is built via `torchvision.models`, so re-running this cell still works. The
# name is kept because the trainer cell consumes it.
models = {"trunk": trunk, "embedder": embedder, "classifier": classifier}
optimizers = {"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer, "classifier_optimizer": classifier_optimizer}
loss_funcs = {"metric_loss": loss, "classifier_loss": classification_loss}
mining_funcs = {"post_gradient_miner": miner}

# We can specify loss weights if we want to. This is optional
loss_weights = {"metric_loss": 1, "classifier_loss": 0.5}

record_keeper, _, _ = logging_presets.get_record_keeper("example_logs", "example_tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)
dataset_dict = {"val": val_dataset}
model_folder = "example_saved_models"

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Create the tester and the hook that runs it after each epoch. Judging by the
# log output below, it embeds each split in `dataset_dict` and runs k-NN on it.
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=hooks.end_of_testing_hook)
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)

# NOTE: these are passed positionally, so their order must match the
# TrainWithClassifier signature of the installed library version.
trainer_positional_args = (
    models,
    optimizers,
    batch_size,
    loss_funcs,
    mining_funcs,
    iterations_per_epoch,
    train_dataset,
)
trainer = trainers.TrainWithClassifier(
    *trainer_positional_args,
    loss_weights=loss_weights,
    sampler=sampler,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook,
)

trainer.train(num_epochs=num_epochs)

INFO:root:Initializing dataloader
INFO:root:Initializing dataloader iterator
INFO:root:Done creating dataloader iterator
INFO:root:TRAINING EPOCH 1
total_loss=2.44327: 100%|██████████| 100/100 [00:09<00:00, 11.08it/s]
INFO:root:Evaluating epoch 1
INFO:root:Getting embeddings for the val split
100%|██████████| 313/313 [00:08<00:00, 36.24it/s]
INFO:root:Computing accuracy for the val split
INFO:root:running k-nn with k=100
INFO:root:embedding dimensionality is 64
INFO:root:New best accuracy!
INFO:root:TRAINING EPOCH 2
total_loss=2.33122: 100%|██████████| 100/100 [00:07<00:00, 14.07it/s]
INFO:root:Evaluating epoch 2
INFO:root:Getting embeddings for the val split
100%|██████████| 313/313 [00:08<00:00, 35.95it/s]
INFO:root:Computing accuracy for the val split
INFO:root:running k-nn with k=100
INFO:root:embedding dimensionality is 64
INFO:root:New best accuracy!
