# In [12]:
#importing libraries
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size, backend="gloo"):
    """Initialise the default torch.distributed process group for one worker.

    Args:
        rank: Index of the calling process within the group (0..world_size-1).
        world_size: Total number of processes taking part.
        backend: torch.distributed backend name; defaults to ``"gloo"``.
    """
    # All workers rendezvous through these environment variables (the default
    # env:// init method of init_process_group).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialising the process group.
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    """Tear down this worker's default process group (counterpart of ``setup``)."""
    teardown = dist.destroy_process_group
    teardown()
    


# Creating a toy model

class ToyModel(nn.Module):
    """Two-layer MLP mapping 10 input features to 5 outputs (Linear -> ReLU -> Linear)."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        # The activation module is spelled ``nn.ReLU``.
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        """Apply net1, ReLU, then net2 to ``x`` (last dimension must be 10)."""
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    """Run one DDP training step of ToyModel on the GPU with id ``rank``.

    Args:
        rank: Process index; also used as the CUDA device id for this worker.
        world_size: Total number of processes in the group.
    """
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    try:
        # Create a model and move it to the GPU with id ``rank``.
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 5).to(rank)
        # DDP synchronizes gradients across processes during backward().
        loss_fn(outputs, labels).backward()
        optimizer.step()
    finally:
        # Release the process group even if the training step raised.
        cleanup()

def run_demo(demo_fn, world_size):
    """Launch ``world_size`` worker processes that each run ``demo_fn``.

    torch.multiprocessing.spawn invokes ``demo_fn(rank, world_size)`` in every
    worker; with ``join=True`` this call blocks until all workers exit.
    """
    spawn_options = {
        "args": (world_size,),
        "nprocs": world_size,
        "join": True,
    }
    mp.spawn(demo_fn, **spawn_options)



# Saving and loading checkpoints

def demo_checkpoint(rank, world_size):
    """Save a DDP model on rank 0, reload it on every rank, and train one step.

    Args:
        rank: Process index; also used as the CUDA device id for this worker.
        world_size: Total number of processes in the group.
    """
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)
    try:
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "model.checkpoint")
        if rank == 0:
            # All processes should see the same parameters, as they all start
            # from the same random parameters and gradients are synchronized in
            # backward passes. Therefore, saving in one process is sufficient.
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

        # Every rank must reach barrier(): it makes the other processes load
        # the checkpoint only after rank 0 has saved it. If only rank 0 called
        # it, rank 0 would wait forever.
        dist.barrier()
        # Rank 0 saved from cuda:0; remap those tensors onto this rank's GPU.
        map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
        ddp_model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=map_location))

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 5).to(rank)

        loss_fn(outputs, labels).backward()
        optimizer.step()

        # Not necessary to use a dist.barrier() to guard the file deletion,
        # as the AllReduce ops in the backward pass of DDP already served as
        # a synchronization.
        if rank == 0:
            os.remove(CHECKPOINT_PATH)
    finally:
        # Release the process group even if a step above raised.
        cleanup()


# Combining DDP with model parallelism

class ToyMpModel(nn.Module):
    """ToyModel split across two devices for model parallelism.

    ``net1`` is placed on ``dev0`` and ``net2`` on ``dev1``; ``forward`` moves
    activations between them, so the output ends up on ``dev1``.

    Args:
        dev0: Device (index, string, or torch.device) for the first layer.
        dev1: Device for the second layer.
    """

    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        # ReLU has no parameters, so it needs no device placement.
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        """Run net1 and ReLU on dev0, then net2 on dev1; returns a tensor on dev1."""
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)
    
    
def demo_model_parallel(rank, world_size):
    """Wrap a two-device ToyMpModel in DDP and run one training step.

    Args:
        rank: Process index within the group.
        world_size: Total number of processes; device ids are taken modulo it.
    """
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)
    try:
        # Set up mp_model and devices for this process.
        # NOTE(review): the modulo wraps device ids, so different ranks can
        # share devices (e.g. with world_size == 2, ranks 0 and 1 both get
        # devices 0 and 1) — confirm this sharing is intended.
        dev0 = (rank * 2) % world_size
        dev1 = (rank * 2 + 1) % world_size
        mp_model = ToyMpModel(dev0, dev1)
        # A module spanning several devices is wrapped without device_ids.
        ddp_mp_model = DDP(mp_model)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # Outputs will be on dev1 (ToyMpModel.forward moves inputs to dev0 first).
        outputs = ddp_mp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(dev1)
        loss_fn(outputs, labels).backward()
        optimizer.step()
    finally:
        # Release the process group even if the training step raised.
        cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    # Explicit exit rather than ``assert`` so the check survives ``python -O``.
    if n_gpus < 2:
        sys.exit(f"Requires at least 2 GPUs to run, but got {n_gpus}")
    world_size = n_gpus
    # Run each demo in its own freshly spawned group of worker processes.
    for demo in (demo_basic, demo_checkpoint, demo_model_parallel):
        run_demo(demo, world_size)

# Recorded cell output: ProcessExitedException: process 0 terminated with exit code 1