In [1]:
!pip install pytorch-metric-learning==0.9.92.dev1
!pip install faiss-cpu # we're using cpu for this example

Collecting pytorch-metric-learning==0.9.92.dev1
[?25l  Downloading https://files.pythonhosted.org/packages/97/20/24474216408287b645b33caf52125b6e10b548e9b8345a0f5146105c801d/pytorch_metric_learning-0.9.92.dev1-py3-none-any.whl (92kB)
[K     |███▌                            | 10kB 16.9MB/s eta 0:00:01[K     |███████                         | 20kB 2.0MB/s eta 0:00:01[K     |██████████▋                     | 30kB 2.7MB/s eta 0:00:01[K     |██████████████▏                 | 40kB 3.0MB/s eta 0:00:01[K     |█████████████████▋              | 51kB 2.4MB/s eta 0:00:01[K     |█████████████████████▏          | 61kB 2.6MB/s eta 0:00:01[K     |████████████████████████▊       | 71kB 2.9MB/s eta 0:00:01[K     |████████████████████████████▎   | 81kB 3.2MB/s eta 0:00:01[K     |███████████████████████████████▉| 92kB 3.5MB/s eta 0:00:01[K     |████████████████████████████████| 102kB 3.0MB/s 
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-

In [2]:
######################################################################################################################
### This script is modified from the guide on pytorch distributed training https://github.com/seba-1511/dist_tuto.pth/
### https://pytorch.org/tutorials/intermediate/dist_tuto.html
######################################################################################################################
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from math import ceil
from random import Random
from torch.multiprocessing import Process
from torchvision import datasets, transforms

from pytorch_metric_learning.utils import distributed as pml_dist
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning import losses, miners, testers
import numpy as np
import logging
logging.getLogger().setLevel(logging.INFO)


class Partition(object):
    """ Dataset-like object, but only access a subset of it. """

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]


class DataPartitioner(object):
    """ Partitions a dataset into different chuncks. """

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)

        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])


class Net(nn.Module):
    """ Network architecture. """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        return self.fc1(x)

def get_MNIST(train):
    return datasets.MNIST(
            './data',
            train=train,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
    

def partition_dataset(dataset):
    """ Partitioning MNIST """
    size = dist.get_world_size()
    bsz = 512 // size
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(dataset, partition_sizes)
    partition = partition.use(dist.get_rank())
    train_set = torch.utils.data.DataLoader(
        partition, batch_size=bsz, shuffle=True)
    return train_set, bsz


### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model, data_device):
    # dataloader_num_workers has to be 0 to avoid pid error
    # This only happens when within multiprocessing
    tester = testers.BaseTester(dataloader_num_workers=0, data_device=data_device)
    return tester.get_all_embeddings(dataset, model)


### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(dataset, model, accuracy_calculator, data_device):
    embeddings, labels = get_all_embeddings(dataset, model, data_device)
    print("Computing accuracy")
    accuracies = accuracy_calculator.get_accuracy(embeddings, 
                                                embeddings,
                                                np.squeeze(labels),
                                                np.squeeze(labels),
                                                True)
    print("Validation set accuracy (MAP@10) = {}".format(accuracies["mean_average_precision_at_r"]))

def test_model(rank, dataset, model, epoch, data_device):
    if rank == 0:
        print("Computing validation set accuracy for epoch {}".format(epoch))
        accuracy_calculator = AccuracyCalculator(include = ("mean_average_precision_at_r",), k = 10)
        test(dataset, model, accuracy_calculator, data_device)
    dist.barrier()


def run(rank, size, train_dataset, val_dataset):
    """ Distributed Synchronous SGD Example """
    print("Rank {} entering the 'run' function".format(rank))
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset(train_dataset)
    dist.barrier()
    ### use this if you have multiple GPUs ###
    # device = torch.device("cuda:{}".format(rank))
    device = torch.device("cpu")
    model = Net()
    ### if you have multiple GPUs, set this to DDP(model.to(device), device_ids=[rank])
    model = DDP(model.to(device))
    test_model(rank, val_dataset, model, "untrained", device)

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    #####################################
    ### pytorch-metric-learning stuff ###
    loss_fn = losses.TripletMarginLoss()
    ### if you have multiple GPUs, set this to pml_dist.DistributedLossWrapper(loss=loss_fn, device_ids=[rank])
    loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn)
    miner = miners.MultiSimilarityMiner()
    miner = pml_dist.DistributedMinerWrapper(miner=miner)
    ### pytorch-metric-learning stuff ###
    #####################################

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(1):
        epoch_loss = 0.0
        epoch_average_gradient = 0.0
        print("Rank {} starting epoch {}".format(rank, epoch))
        for i, (data, target) in enumerate(train_set):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            hard_pairs = miner(output, target)
            loss = loss_fn(output, target, hard_pairs)
            epoch_loss += loss.item()
            loss.backward()
            ##########################################################
            ### these two lines are just for illustration purposes ###
            average_gradient = np.mean([torch.mean(p.grad).item() for p in model.parameters()])
            epoch_average_gradient += average_gradient
            ##########################################################
            optimizer.step()
            ####################################################
            ### The loss value is the same for each process. ###
            ### But if you are using multiple GPUs, then the ###
            ### gradients will be different for each process ###
            ####################################################
            if i % 10 == 0:
                print('Rank {}, iteration {}, loss {}, num pos pairs {}, num neg pairs {}, average gradient {}'.\
                      format(rank, i, loss.item(), miner.miner.num_pos_pairs, miner.miner.num_neg_pairs, average_gradient))
            dist.barrier()

        print('Rank {}, epoch {}, average loss {}, average gradient {}'.format(rank, epoch, epoch_loss/num_batches, epoch_average_gradient/num_batches))
        test_model(rank, val_dataset, model, epoch, device)



#######################################
### Set backend='nccl' if using GPU ###
#######################################
def init_processes(rank, size, fn, train_dataset, val_dataset, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, train_dataset, val_dataset)


if __name__ == "__main__":
    train_dataset = get_MNIST(True)
    val_dataset = get_MNIST(False)

    size = 4
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run, train_dataset, val_dataset))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Rank 1 entering the 'run' function
Rank 0 entering the 'run' function
Rank 2 entering the 'run' function
Rank 3 entering the 'run' function








Computing validation set accuracy for epoch untrained


100%|██████████| 313/313 [00:03<00:00, 89.57it/s]


Computing accuracy


INFO:root:running k-nn with k=10
INFO:root:embedding dimensionality is 50


Validation set accuracy (MAP@10) = 0.7912241984126984
Rank 3 starting epoch 0
Rank 1 starting epoch 0
Rank 0 starting epoch 0
Rank 2 starting epoch 0
Rank 2, iteration 0, loss 0.22585304081439972, num pos pairs 26208, num neg pairs 234298, average gradient 0.00043545858325918135
Rank 3, iteration 0, loss 0.22585304081439972, num pos pairs 26208, num neg pairs 234298, average gradient 0.00043545858325918135
Rank 1, iteration 0, loss 0.22585304081439972, num pos pairs 26208, num neg pairs 234298, average gradient 0.00043545858325918135
Rank 0, iteration 0, loss 0.22585304081439972, num pos pairs 26208, num neg pairs 234298, average gradient 0.00043545858325918135








Rank 3, iteration 10, loss 0.20190657675266266, num pos pairs 26239, num neg pairs 233600, average gradient 0.0005712817248119487
Rank 1, iteration 10, loss 0.20190657675266266, num pos pairs 26239, num neg pairs 233600, average gradient 0.0005712817248119487
Rank 2, iteration 10, loss 0.20190657675266266, num pos pairs 

100%|██████████| 313/313 [00:03<00:00, 92.84it/s]


Computing accuracy


INFO:root:running k-nn with k=10
INFO:root:embedding dimensionality is 50


Validation set accuracy (MAP@10) = 0.8194116587301588
