In [3]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

class SimplifiedSyncBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, input):
        if not input.is_contiguous():
            input = input.contiguous()

        # Calculate local mean and variance
        batch_size = input.size(0)
        mean = input.mean([0, 2, 3])
        var = input.var([0, 2, 3], unbiased=False)

        # Synchronize mean and variance across processes
        if dist.is_initialized() and self.training:
            mean_var = torch.stack([mean, var])
            dist.all_reduce(mean_var, op=dist.ReduceOp.SUM)
            mean, var = mean_var / dist.get_world_size()

        # Update running statistics
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var

        # Normalize
        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        
        # Apply affine transform
        return self.weight[None, :, None, None] * input + self.bias[None, :, None, None]

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class TestModel(nn.Module):
    def __init__(self, norm_layer):
        super(TestModel, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

def run_test(rank, world_size):
    print(f"Running on rank {rank}.")
    setup(rank, world_size)

    torch.manual_seed(42 + rank)
    torch.cuda.set_device(rank)
    
    # Create models
    model_builtin = TestModel(nn.SyncBatchNorm).to(rank)
    model_custom = TestModel(SimplifiedSyncBatchNorm).to(rank)

    # Wrap models with DDP
    model_builtin = DDP(model_builtin, device_ids=[rank])
    model_custom = DDP(model_custom, device_ids=[rank])

    # Create some dummy data
    batch_size = 32
    data = torch.randn(batch_size, 3, 64, 64).to(rank)

    # Forward pass
    out_builtin = model_builtin(data)
    out_custom = model_custom(data)

    # Compare outputs
    diff = (out_builtin - out_custom).abs().mean().item()
    print(f"Rank {rank}: Mean absolute difference between outputs: {diff}")

    # Backward pass
    loss_builtin = out_builtin.sum()
    loss_custom = out_custom.sum()

    loss_builtin.backward()
    loss_custom.backward()

    # Compare gradients
    for (name_b, param_b), (name_c, param_c) in zip(model_builtin.named_parameters(), model_custom.named_parameters()):
        if param_b.grad is not None and param_c.grad is not None:
            grad_diff = (param_b.grad - param_c.grad).abs().mean().item()
            print(f"Rank {rank}: Mean absolute difference in gradients for {name_b}: {grad_diff}")

    cleanup()

def main():
    world_size = torch.cuda.device_count()
    print(f"Using {world_size} GPUs")
    mp.spawn(run_test, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()

Using 4 GPUs


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'run_test' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'run_test' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/p

ProcessExitedException: process 0 terminated with exit code 1

In [4]:
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

CUDA available: True
Number of GPUs: 4


CUDA available: True
Number of GPUs: 4


KeyError: 'LOCAL_RANK'