In [None]:
#hide
#skip
# Install/upgrade the `self-supervised` package only when a `/content` directory exists
# (presumably a Google Colab runtime); elsewhere the shell test fails and pip is skipped.
! [ -e /content ] && pip install -Uqq self-supervised

In [None]:
#default_exp dist

# Dist

> Utilities for distributed training.

In [None]:
#export
import torch
import torch.distributed as dist

An `all_gather` layer with a backward pass, useful for collecting model output embeddings from multiple GPUs so that the loss can be computed over the full (large) global batch, e.g. for InfoNCE (SimCLR, CLIP).


In [None]:
#export
class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.

    Forward: `all_gather`s `input` from every process in the default process group and
    returns a tuple with one tensor per rank (ordered by rank), each shaped like `input`.
    Typical use: `all_z = torch.cat(GatherLayer.apply(z), dim=0)`.

    Backward: the loss on each rank may depend on *every* rank's gathered slice, so each
    rank receives gradients w.r.t. all slices. These are summed across processes with
    `all_reduce`, and the slice belonging to the current rank is returned. Without this
    reduction, the contributions that other ranks' losses make to this rank's input
    would be dropped.

    Requires `torch.distributed` to be initialized.

    Adapted from
    https://github.com/open-mmlab/OpenSelfSup/blob/696d04950e55d504cf33bc83cfadbb4ece10fbae/openselfsup/models/utils/gather_layer.py
    '''

    @staticmethod
    def forward(ctx, input):
        # Collectives require contiguous buffers; `zeros_like` would otherwise copy the
        # (possibly non-contiguous) strides of `input`.
        input = input.contiguous()
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # grads[r] is d(local loss)/d(output[r]); stack them so a single all_reduce
        # sums every rank's gradient for every slice.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
        # Keep only the summed gradient for this rank's own input.
        return all_grads[dist.get_rank()]

## Export -

In [None]:
#hide
# Export the `#export` cells of the project's notebooks into the library's Python
# modules via nbdev (it reports each notebook it converts).
from nbdev.export import notebook2script
notebook2script()

Converted 01-augmentations.ipynb.
Converted 02-layers.ipynb.
Converted 03-distributed.ipynb.
Converted 10-simclr.ipynb.
Converted 11-byol.ipynb.
Converted 12-swav.ipynb.
Converted 13-moco.ipynb.
Converted index.ipynb.
