In [1]:
!pip install fairscale

Collecting fairscale
  Downloading fairscale-0.4.6.tar.gz (248 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m248.2/248.2 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: fairscale
  Building wheel for fairscale (pyproject.toml) ... [?25ldone
[?25h  Created wheel for fairscale: filename=fairscale-0.4.6-py3-none-any.whl size=307224 sha256=a3d7fe85da2bb39f237abdba3033dfd42ee12dcb2177778b1200f5afc5d14cfa
  Stored in directory: /root/.cache/pip/wheels/0b/8c/fa/a9e102632bcb86e919561cf25ca1e0dd2ec67476f3a5544653
Successfully built fairscale
Installing collected packages: fairscale
Successfully installed fairscale-0.4.6
[0m

#### 1. OSS (Optimizer state sharding)

In [5]:
%%writefile oss.py

import os
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.multiprocessing as mp

from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

from functools import partial
from tqdm import tqdm

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_loader():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    batch_size = 64
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                              shuffle=False, num_workers=2)
    return dataloader

def init_process(rank, size, epochs, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, epochs)
    
def train(
    rank: int,
    world_size: int,
    epochs: int):
    
    dataloader = get_loader()
    
    device = torch.device(f'cuda:{rank}')
    model = Net().to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    # optimizer specific arguments e.g. LR, momentum, etc...
    base_optimizer_arguments = {"lr": 1e-4}

    # Wrap a base optimizer into OSS
    base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)

    # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
    model = ShardedDDP(model, optimizer)

    print('-' * 50)
    print(f'Optim params in rank: {rank}')
    for elem in optimizer.partition_parameters()[rank]:
        for param in elem['params']:
            print(f'shape: {param.shape}')
    print('-' * 50)
    
    model.train()
    for e in range(epochs):
        for idx, (data, target) in enumerate(tqdm(dataloader)):
            data, target = data.to(device), target.to(device)
            # new
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
                                
size, epochs = 2, 1

if __name__ == '__main__':
    fn = partial(init_process, size=size, epochs=epochs, fn=train, backend='nccl')
    mp.spawn(
            fn,
            nprocs=size,
            join=True
        )

Overwriting oss.py


#### We can see that the optimizer has split the parameters between workers.

In [4]:
# self.conv1 = nn.Conv2d(3, 6, 5)
# self.pool = nn.MaxPool2d(2, 2)
# self.conv2 = nn.Conv2d(6, 16, 5)
# self.fc1 = nn.Linear(400, 120)
# self.fc2 = nn.Linear(120, 84)
# self.fc3 = nn.Linear(84, 10)
!python3 oss.py

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|███████████████████████| 170498071/170498071 [00:03<00:00, 45046121.72it/s]
 89%|████████████████████▍  | 151683072/170498071 [00:04<00:00, 48235582.26it/s]Extracting ./data/cifar-10-python.tar.gz to ./data
100%|███████████████████████| 170498071/170498071 [00:04<00:00, 38430142.82it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
--------------------------------------------------
Optim params in rank :0
shape: torch.Size([6, 3, 5, 5])
shape: torch.Size([16])
shape: torch.Size([120, 400])
--------------------------------------------------
  0%|                                                   | 0/157 [00:00<?, ?it/s]--------------------------------------------------
Optim params in rank :1
shape: torch.Size([6])
shape: torch.Size([16, 6, 5, 5])
shape: torch.Size([120])
shape:

#### 2. FSDP (Fully Sharded Data Parallel)
Interesting params:
* mixed_precision
* move_params_to_cpu
* move_grads_to_cpu

You can use mixed_precision, but with special scaler ShardedGradScaler.

In [52]:
%%writefile fsdp.py

import os
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.multiprocessing as mp

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler


from functools import partial
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_loader():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    batch_size = 64
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                              shuffle=False, num_workers=2)
    return dataloader

def init_process(rank, size, epochs, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '30000'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, epochs)
    
def train(
    rank: int,
    world_size: int,
    epochs: int):
    
    torch.cuda.set_device(rank)
    
    dataloader = get_loader()
    model = Net()
    base_optimizer_arguments = {"lr": 1e-4}
    
    model = FSDP(
        model,
        mixed_precision=True,
        reshard_after_forward=True,
        move_params_to_cpu=True,
        move_grads_to_cpu=True
    )
    
    optimizer = torch.optim.SGD(
        params=model.parameters(),
        **base_optimizer_arguments
    )
    
    loss_fn = torch.nn.CrossEntropyLoss()
    scaler = ShardedGradScaler()
    
    # uncomment if move_params_to_cpu=False
    # model = model.to(rank)
    
    model.train()
    for e in range(epochs):
        for idx, (data, target) in enumerate(tqdm(dataloader)):
            data, target = data.to(rank), target.to(rank)
            model.zero_grad(set_to_none=True)
            
            with torch.autocast(device_type='cuda'):
                outputs = model(data)
                loss = loss_fn(outputs, target)
                
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            if idx == 0:
                if rank == 1:
                    dist.barrier()
                print('-' * 50)
                print(f'rank: {rank}')
                print(f'Param after backward')
                for param in optimizer.param_groups[0]['params']:
                    print(f'Shape: {param.shape}')
                    print(param)
                print('-' * 50)
                if rank == 0:
                    dist.barrier()
                                
size, epochs = 2, 1

if __name__ == '__main__':
    fn = partial(init_process, size=size, epochs=epochs, fn=train, backend='nccl')
    mp.spawn(
            fn,
            nprocs=size,
            join=True
        )

Overwriting fsdp.py


#### Number of parameters: 62006, at each gpu: 31003

In [53]:
!python3 fsdp.py

Files already downloaded and verified
Files already downloaded and verified
  0%|                                                   | 0/157 [00:00<?, ?it/s]--------------------------------------------------
rank: 0
Param after backward
Shape: torch.Size([31003])
Parameter containing:
Parameter(FlatParameter([ 0.0477, -0.0293,  0.0357,  ..., -0.0013,  0.0012,  0.0489],
              requires_grad=True))
--------------------------------------------------
--------------------------------------------------
rank: 1
Param after backward
Shape: torch.Size([31003])
  1%|▎                                          | 1/157 [00:11<30:40, 11.80s/it]Parameter containing:
Parameter(FlatParameter([-0.0184,  0.0203,  0.0362,  ..., -0.0378, -0.0753, -0.1072],
              requires_grad=True))
--------------------------------------------------
100%|█████████████████████████████████████████| 157/157 [00:18<00:00,  8.27it/s]
100%|█████████████████████████████████████████| 157/157 [00:18<00:00,  8.26it/s]


#### 3. Wrap individual modules

In [81]:
%%writefile wrap.py

import os
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.multiprocessing as mp

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import wrap, enable_wrap, auto_wrap

from functools import partial
from tqdm import tqdm

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_loader():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    batch_size = 64
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                              shuffle=False, num_workers=2)
    return dataloader

def init_process(rank, size, epochs, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, epochs)
    
def train(
    rank: int,
    world_size: int,
    epochs: int):

    torch.cuda.set_device(rank)
    
    model = Net()
    model = model.to(rank)
    
    model.train()
    model = model.to(rank)
    fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
    with enable_wrap(**fsdp_params):
        
        # auto_wrap_policy=functools.partial(default_auto_wrap_policy, 
        # min_num_params=1e3)
        
        model.fc3 = wrap(model.fc3)
        if rank == 1:
            dist.barrier()
        print('-' * 50)
        print(f'RANK: {rank}')
        print(f'Wrapped: {model.fc3}')
        print(f'Unwrapped: {model.conv2}')
        print('-' * 50)
        if rank == 0:
            dist.barrier()
                                
size, epochs = 2, 1

if __name__ == '__main__':
    fn = partial(init_process, size=size, epochs=epochs, fn=train, backend='nccl')
    mp.spawn(
            fn,
            nprocs=size,
            join=True
        )

Overwriting wrap.py


In [82]:
!python3 wrap.py

--------------------------------------------------
RANK: 0
Wrapped: FullyShardedDataParallel(
  world_size=2, flatten_parameters=True, mixed_precision=True, 
  (_fsdp_wrapped_module): FlattenParamsWrapper(
    (_fpw_module): Linear(in_features=84, out_features=10, bias=True)
  )
)
Unwrapped: Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
--------------------------------------------------
--------------------------------------------------
RANK: 1
Wrapped: FullyShardedDataParallel(
  world_size=2, flatten_parameters=True, mixed_precision=True, 
  (_fsdp_wrapped_module): FlattenParamsWrapper(
    (_fpw_module): Linear(in_features=84, out_features=10, bias=True)
  )
)
Unwrapped: Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
--------------------------------------------------


4. #### Slowmo DDP

SlowMo Distributed Data Parallel reduces the communication between different nodes while performing data parallel training.

In [92]:
%%writefile slowmo.py

import os
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.multiprocessing as mp

from fairscale.experimental.nn.data_parallel \
        import SlowMoDistributedDataParallel as SlowMoDDP

from functools import partial
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def get_loader():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    batch_size = 64
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                              shuffle=False, num_workers=2)
    return dataloader

def init_process(rank, size, epochs, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, epochs)
    
def train(
    rank: int,
    world_size: int,
    epochs: int):
    
    torch.cuda.set_device(rank)
    
    dataloader = get_loader()
    
    model = Net().to(rank)
    model = SlowMoDDP(model, slowmo_momentum=0.5, nprocs_per_node=2)
    
    base_optimizer_arguments = {"lr": 1e-4}
    optimizer = torch.optim.SGD(
        params=model.parameters(),
        **base_optimizer_arguments
    )
    
    loss_fn = torch.nn.CrossEntropyLoss()
    
    model.train()
    for e in range(epochs):
        for idx, (data, target) in enumerate(tqdm(dataloader)):
            data, target = data.to(rank), target.to(rank)
            # new
            model.zero_grad(set_to_none=True)
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
            
            model.perform_slowmo(optimizer)
                                
size, epochs = 2, 1

if __name__ == '__main__':
    fn = partial(init_process, size=size, epochs=epochs, fn=train, backend='nccl')
    mp.spawn(
            fn,
            nprocs=size,
            join=True
        )

Overwriting slowmo.py


In [93]:
!python3 slowmo.py

Files already downloaded and verified
Files already downloaded and verified
100%|█████████████████████████████████████████| 157/157 [00:15<00:00, 10.08it/s]
100%|█████████████████████████████████████████| 157/157 [00:15<00:00,  9.97it/s]
