In [1]:
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn
import torch.nn.functional as F
from torch.multiprocessing import Process
from torch.autograd import Variable
from torchvision import datasets, transforms
from math import ceil
from random import Random
from torch.nn.parallel import DistributedDataParallel
import os

def init_process_group(rank, size, backend='nccl'):
    """Initialise the default ``torch.distributed`` process group.

    Args:
        rank: Rank of the calling process within the group.
        size: Expected number of processes in the group (world size).
        backend: Backend name forwarded to ``dist.init_process_group``.

    Returns:
        The ``torch.distributed`` module (``dist``), so callers can use it as
        a handle for rank / world-size queries.

    Raises:
        AssertionError: If the world size reported after initialisation
            differs from ``size``.
    """
    # Use setdefault so rendezvous settings already exported by a launcher
    # are respected; the local values below are only a single-node fallback.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend, rank=rank, world_size=size)
    assert (size==dist.get_world_size()),f"{size} and {dist.get_world_size()} does not match"
    return dist

def all_reduce(tensor, op):
    """Forward ``tensor`` and the reduction ``op`` to ``dist.all_reduce``.

    The return value of ``dist.all_reduce`` is discarded; this wrapper
    returns None.
    """
    dist.all_reduce(tensor, op=op)

def reduce(tensor, op, dst=0):
    """Reduce ``tensor`` across the process group onto rank ``dst``.

    Args:
        tensor: Tensor handed to ``dist.reduce``.
        op: Reduction operator (e.g. ``dist.ReduceOp.SUM``).
        dst: Destination rank that receives the reduced result. Defaults
            to 0.

    The second positional parameter of ``dist.reduce`` is the destination
    rank, not the operator, so both are passed by keyword here to avoid
    sending ``op`` into the ``dst`` slot.
    """
    dist.reduce(tensor, dst=dst, op=op)

def average_grads_allreduce(model):
    """Average the gradients of ``model`` across all ranks, in place.

    Each gradient is summed over the process group with an all-reduce and
    then divided by the world size, so every rank ends up with the same
    averaged gradient. (A plain ``reduce`` would deliver the sum to a single
    destination rank only.)

    Args:
        model: Module whose ``parameters()`` gradients are averaged.

    Parameters whose ``grad`` is None are skipped.
    NOTE(review): this assumes the same parameters carry gradients on every
    rank; otherwise the collectives would not line up. Confirm for models
    with conditionally-used parameters.
    """
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]
    if not grads:
        # Nothing to synchronise; avoid touching the process group at all.
        return
    world_size = float(dist.get_world_size())
    for grad in grads:
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= world_size

def set_seed(val):
    """Seed torch's RNG with ``val``, and all CUDA RNGs when CUDA is available."""
    torch.manual_seed(val)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(val)

def set_cuda_device(dist):
    """Select the CUDA device for this process based on its rank.

    Args:
        dist: Object exposing ``get_rank()`` (here, the ``torch.distributed``
            module).

    Does nothing when CUDA is unavailable. The rank is wrapped modulo the
    number of visible devices so that a global rank larger than the local
    GPU count still maps to a valid ordinal.
    NOTE(review): this assumes one process per GPU with ranks laid out
    consecutively per node; confirm against the launcher used.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Single-process setup: this process is rank 0 in a group of size 1.
dist = init_process_group(rank=0, size=1)
print(f"Initialized process group with {dist.get_world_size()} ranks")


class Net(nn.Module):
    """Small convolutional classifier: two conv layers, then two linear layers.

    The flatten step hard-codes 320 features per sample, which is what a
    1x28x28 input yields after two 5x5 convolutions each followed by 2x2
    max-pooling (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly; it does not
        # follow module.train()/eval() on its own.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit class dimension: the implicit-dim form is deprecated and
        # emits a warning on every call.
        return F.log_softmax(x, dim=1)


class Partition:
    """View of ``data`` restricted to the positions listed in ``index``."""

    def __init__(self, data, index):
        self.data, self.index = data, index

    def __len__(self):
        # Length is the number of selected positions, not the size of data.
        return len(self.index)

    def __getitem__(self, index):
        # Translate the partition-local position into a position in data.
        return self.data[self.index[index]]

class DataPartitioner(object):
    """Shuffle the indices of ``data`` once and cut them into partitions.

    Args:
        data: Any object supporting ``len()``; only its indices are shuffled.
        sizes: Fraction of the dataset assigned to each partition, in order.
        seed: Seed for the private ``random.Random`` used for the shuffle, so
            the same seed yields the same partitions.

    Each partition length is ``int(frac * len(data))``, i.e. truncated, so a
    few trailing indices may end up in no partition.
    """

    def __init__(self, data, sizes=(0.7, 0.2, 0.1), seed=1234):
        self.data = data
        self.partitions = []
        rng = Random(seed)

        data_len = len(data)
        indexes = list(range(data_len))
        rng.shuffle(indexes)

        # Walk an offset through the shuffled list instead of re-slicing the
        # remaining tail after every partition.
        start = 0
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[start:start + part_len])
            start += part_len

    def use(self, partition):
        """Return a ``Partition`` over ``data`` for partition number ``partition``."""
        return Partition(self.data, self.partitions[partition])
def partition_dataset():
    """Build this rank's MNIST training loader.

    The training split is divided into ``world_size`` equal fractions via
    ``DataPartitioner`` and the partition matching this process's rank is
    wrapped in a shuffling ``DataLoader``.

    Returns:
        A ``(loader, batch_size)`` tuple, where ``batch_size`` is
        ``128 // world_size``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    mnist = datasets.MNIST(
        './data', train=True, download=True, transform=preprocessing)

    world_size = dist.get_world_size()
    per_rank_batch = 128 // world_size
    fractions = [1.0 / world_size] * world_size

    shard = DataPartitioner(mnist, fractions).use(dist.get_rank())
    loader = torch.utils.data.DataLoader(
        shard, batch_size=per_rank_batch, shuffle=True)
    return loader, per_rank_batch

def run(dist):
    """Train ``Net`` on this rank's MNIST partition for 10 epochs.

    Args:
        dist: Initialised ``torch.distributed`` module, used for rank and
            world-size queries.

    Prints the mean per-batch loss for every epoch.
    """
    set_seed(1234)
    train_set, bsz = partition_dataset()

    # Pick the device once and use it for both the model and every batch.
    if torch.cuda.is_available():
        set_cuda_device(dist)
        device = torch.device('cuda', torch.cuda.current_device())
    else:
        device = torch.device('cpu')

    # The model must already live on its target device before DDP wraps it.
    model = Net().to(device)

    use_ddp = dist.get_world_size() > 1
    if use_ddp:
        # device_ids must be a list of ordinals (or None for CPU modules).
        device_ids = [device.index] if device.type == 'cuda' else None
        model = DistributedDataParallel(
            model, device_ids=device_ids, output_device=device.index)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01, momentum=0.5)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            # DDP already averages gradients during backward(); only the
            # unwrapped model needs the manual averaging step.
            if not use_ddp:
                average_grads_allreduce(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ',
              epoch, ': ', epoch_loss / num_batches)
# Start training in this process, using the process group initialised earlier in the cell.
run(dist)

Initialized process group with 1 ranks


  return F.log_softmax(x)


Rank  0 , epoch  0 :  1.3033309768257872
Rank  0 , epoch  1 :  0.5454973846610421
Rank  0 , epoch  2 :  0.4244839232931259
Rank  0 , epoch  3 :  0.3598236723113924
Rank  0 , epoch  4 :  0.3233929915405286
Rank  0 , epoch  5 :  0.2905611474948651
Rank  0 , epoch  6 :  0.26836184588576684
Rank  0 , epoch  7 :  0.250192261819265
Rank  0 , epoch  8 :  0.23438266896680474
Rank  0 , epoch  9 :  0.22443408579396795
