# 25 August - DDP

# Analysis of Batch Normalization Methods

1. Function Breakdown:

```python
def forward(self, input, pad_mask=None):
    # ... (existing preprocessing) ...
    
    need_sync = self.training and torch.distributed.is_initialized()
    if need_sync:
        process_group = torch.distributed.group.WORLD
        if self.process_group:
            process_group = self.process_group
        world_size = torch.distributed.get_world_size(process_group)
        need_sync = world_size > 1

    if need_sync:
        # Implement synchronization logic here, similar to SyncBatchNorm
        # This would involve gathering statistics from all GPUs and computing global statistics
        pass
    else:
        # Existing PowerNorm logic
        pass
```

Purpose:
This function determines whether synchronization across GPUs is necessary and prepares for it if needed.

Key steps:
a. Check if synchronization is needed:
   - The model is in training mode (`self.training`)
   - Distributed training is initialized (`torch.distributed.is_initialized()`)

b. Set up the process group:
   - Use the default world group or a custom group if specified

c. Get the world size (number of processes/GPUs)

d. Determine if synchronization is actually needed (more than one GPU)

e. If synchronization is needed:
   - Implement logic to gather statistics from all GPUs
   - Compute global statistics
   - Apply these global statistics in the normalization process

f. If synchronization is not needed:
   - Proceed with the standard PowerNorm logic

2. Comparison: F.batch_norm vs sync_batch_norm.apply()

a. F.batch_norm:
   - Standard batch normalization function
   - Operates independently on each GPU in a multi-GPU setup
   - Computes mean and variance using only the local batch on each GPU
   - Faster, because it needs no inter-GPU communication
   - Each GPU normalizes with its own statistics, which become noisy (and differ between GPUs) when the per-GPU batch is small

Example:
```python
output = F.batch_norm(input, running_mean, running_var, weight, bias,
                      training, momentum, eps)
```

b. sync_batch_norm.apply():
   - Synchronized version of batch normalization
   - Coordinates computation across all GPUs in a distributed setup
   - Computes global mean and variance by aggregating statistics from all GPUs
   - Ensures consistent normalization across the entire model, regardless of data distribution across GPUs
   - More computationally expensive due to inter-GPU communication
   - Most beneficial when the per-GPU batch is small, so that local statistics alone would be unreliable

Example:
```python
output = sync_batch_norm.apply(input, weight, bias, running_mean, running_var,
                               eps, momentum, process_group, world_size)
```

Key Differences:
1. Consistency: sync_batch_norm ensures consistent statistics across all GPUs, while F.batch_norm does not.
2. Communication: sync_batch_norm involves inter-GPU communication, F.batch_norm does not.
3. Computational cost: sync_batch_norm is more expensive due to synchronization overhead.
4. When it helps: sync_batch_norm pays off mainly when the per-GPU batch is small; the number of GPUs alone is not the deciding factor.

When to use which:
- Use F.batch_norm for single-GPU training, or when each GPU's batch is large enough that its statistics are representative of the whole dataset.
- Use sync_batch_norm.apply() for large-scale distributed training where maintaining consistent statistics across GPUs is crucial for model stability and performance.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm

class GroupScaling1D(nn.Module):
    """Scale each channel group to unit second moment at every (time, batch) position.

    Input is (T, B, C). The C channels are split into ``group_num`` contiguous
    groups of ``C // group_num`` channels, and each element is divided by
    ``sqrt(mean(group ** 2) + eps)`` of its own group.
    """

    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def forward(self, input):
        """Return ``input`` (T, B, C) rescaled per channel group; same shape as ``input``.

        Raises:
            ValueError: if C is not divisible by ``group_num``.
        """
        T, B, C = input.shape
        if C % self.group_num != 0:
            raise ValueError(
                f"number of channels ({C}) must be divisible by group_num ({self.group_num})"
            )
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        # Mean of squares inside each group; keepdim lets it broadcast over the group's channels.
        moment2 = torch.mean(gn_input * gn_input, dim=3, keepdim=True)
        return (gn_input / torch.sqrt(moment2 + self.eps)).reshape(T, B, C)


class SyncPowerNorm(nn.Module):
    """PowerNorm (second-moment normalization) with optional cross-process statistics.

    Expects input of shape (T, B, C), or (B, C), which is treated as T == 1. After
    group scaling, each channel is divided by ``sqrt(phi + eps)``, where ``phi`` is
    the mean of squares over time and batch. In training, the first
    ``warmup_iters`` batches use the current batch's ``phi``; later batches use
    ``running_phi``. ``running_phi`` is updated every training step as an
    exponential moving average with weight ``alpha_fwd`` on the old value.

    When training with an initialized ``torch.distributed`` group of more than
    one rank, the batch ``phi`` is computed over all ranks' data. This also works
    when batch and sequence sizes differ between ranks.

    Notes:
        ``momentum``, ``alpha_bkw`` and the ``ema_gz`` buffer are stored but not
        used by this implementation (presumably kept for interface/state-dict
        compatibility). ``pad_mask`` is accepted but ignored: padded positions
        contribute to the statistics.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, process_group=None,
                 alpha_fwd=0.9, alpha_bkw=0.9, warmup_iters=10000, group_num=1):
        super(SyncPowerNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.process_group = process_group

        self.alpha_fwd = alpha_fwd
        self.alpha_bkw = alpha_bkw
        self.warmup_iters = warmup_iters
        self.group_num = group_num

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_phi', torch.ones(1, num_features, 1, 1))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_phi', None)
            self.register_parameter('num_batches_tracked', None)

        self.register_buffer('ema_gz', torch.zeros(1, num_features, 1, 1))
        self.gp = GroupScaling1D(group_num=group_num)

    def _sync_process_group(self):
        """Return the process group to synchronize over, or None when no sync is needed."""
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            return None
        group = self.process_group if self.process_group is not None else torch.distributed.group.WORLD
        return group if torch.distributed.get_world_size(group) > 1 else None

    def _batch_phi(self, input, process_group):
        """Return the per-channel mean of squares of ``input`` (B x C x T x 1) as 1 x C x 1 x 1.

        With ``process_group`` set, the local sums of squares and element counts are
        all-reduced, so the statistic covers every rank's batch even when sizes differ.
        """
        if process_group is None:
            return (input * input).mean(dim=(0, 2, 3), keepdim=True)

        # Local import: only needed when actually running distributed.
        from torch.distributed.nn.functional import all_reduce

        C = input.size(1)
        # Accumulate in float32 so large sums/counts don't overflow in half precision.
        local_sumsq = input.float().pow(2).sum(dim=(0, 2, 3))
        local_count = local_sumsq.new_full((1,), float(input.numel() // C))
        # The autograd-aware all_reduce lets gradients flow through the global
        # statistic (its backward all-reduces the incoming gradients).
        stats = all_reduce(torch.cat([local_sumsq, local_count]), group=process_group)
        x2 = stats[:C] / stats[C:].clamp_min(1.0)
        return x2.view(1, C, 1, 1).to(input.dtype)

    def forward(self, input, pad_mask=None):
        """Normalize ``input`` and return a tensor of the same shape.

        Args:
            input: (T, B, C) tensor, or (B, C), which is treated as one time step.
            pad_mask: accepted for API compatibility; currently ignored.

        Raises:
            ValueError: if ``input`` has fewer than 2 dimensions.
        """
        if input.dim() < 2:
            raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")

        # A 2-D input is treated as a single time step: (B, C) -> (1, B, C).
        shaped_input = input.dim() == 2
        if shaped_input:
            input = input.unsqueeze(0)

        T, B, C = input.shape
        input = self.gp(input)
        input = input.permute(1, 2, 0).contiguous()  # B x C x T
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)  # B x C x T x 1

        if self.training and self.track_running_stats and self.num_batches_tracked is not None:
            self.num_batches_tracked += 1

        has_running = self.running_phi is not None
        if self.training:
            x2 = self._batch_phi(input, self._sync_process_group())
            in_warmup = (
                self.num_batches_tracked is None
                or bool(self.num_batches_tracked <= self.warmup_iters)
            )
            if has_running and not in_warmup:
                z = input / (self.running_phi + self.eps).sqrt()
            else:
                z = input / (x2 + self.eps).sqrt()
            if has_running:
                # Update in place and without autograd history. Re-binding the buffer
                # to a graph-attached tensor would chain every batch's graph together
                # and break backward once the running estimate is used after warmup.
                with torch.no_grad():
                    self.running_phi.mul_(self.alpha_fwd).add_(
                        x2.detach().to(self.running_phi.dtype), alpha=1 - self.alpha_fwd
                    )
        elif has_running:
            z = input / (self.running_phi + self.eps).sqrt()
        else:
            # No running estimate (track_running_stats=False): use local batch statistics.
            z = input / (self._batch_phi(input, None) + self.eps).sqrt()

        if self.affine:
            z = self.weight.view(1, C, 1, 1) * z + self.bias.view(1, C, 1, 1)

        output = z.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()  # T x B x C

        if shaped_input:
            output = output.squeeze(0)

        return output

    @staticmethod
    def convert_sync_powernorm(module, process_group=None):
        """Recursively replace ``MaskPowerNorm`` layers in ``module`` with ``SyncPowerNorm``.

        Weights and running statistics are shared (not copied) with the original
        layers. ``MaskPowerNorm`` is not defined in this cell and must be available
        in the module's global namespace when this is called.
        """
        module_output = module
        if isinstance(module, MaskPowerNorm):
            # NOTE(review): the original conversion read both `afwd` and `alpha_fwd`
            # and passed `afwd` positionally into `momentum`. Accept either attribute
            # name and pass everything by keyword so no value lands in the wrong slot.
            alpha_fwd = getattr(module, "alpha_fwd", None)
            if alpha_fwd is None:
                alpha_fwd = getattr(module, "afwd", 0.9)
            module_output = SyncPowerNorm(
                module.num_features,
                eps=module.eps,
                momentum=getattr(module, "momentum", 0.1),
                affine=module.affine,
                track_running_stats=getattr(module, "track_running_stats", True),
                process_group=process_group,
                alpha_fwd=alpha_fwd,
                alpha_bkw=getattr(module, "alpha_bkw", 0.9),
                warmup_iters=module.warmup_iters,
                group_num=module.group_num,
            )
            if module.affine:
                module_output.weight = module.weight
                module_output.bias = module.bias
            # Only copy into slots that are registered as buffers; assigning a tensor
            # into a None *parameter* slot would raise a TypeError.
            if module_output.running_phi is not None:
                module_output.running_phi = module.running_phi
                tracked = getattr(module, "num_batches_tracked", None)
                if tracked is not None:
                    module_output.num_batches_tracked = tracked
            if getattr(module, "ema_gz", None) is not None:
                module_output.ema_gz = module.ema_gz
        for name, child in module.named_children():
            module_output.add_module(name, SyncPowerNorm.convert_sync_powernorm(child, process_group))
        return module_output

# sync batch norm implementation

In [1]:
import torch
import torch.distributed as dist
from torch.autograd.function import Function


class SyncBatchNorm(Function):
    """Autograd Function for batch normalization with statistics synchronized
    across the ranks of a ``torch.distributed`` process group.

    Forward: each rank computes its local per-channel mean, inverse std and
    element count. These are all-gathered and combined into global statistics by
    ``torch.batch_norm_gather_stats_with_counts``, which is also given
    ``running_mean`` / ``running_var`` and ``momentum`` to update the running
    estimates. The local input is then normalized with the global statistics.

    Backward: the per-channel gradient reductions (``sum_dy``, ``sum_dy_xmu``)
    are all-reduced so every rank forms its input gradient from global sums.

    All collectives are blocking (``async_op=False``), and ranks with empty input
    still take part in them. Every rank in the group must therefore run forward
    and backward in lockstep.
    """

    @staticmethod
    def forward(
        self,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        eps,
        momentum,
        process_group,
        world_size,
    ):
        """Normalize ``input`` with statistics gathered from all ranks.

        Args:
            self: the autograd context (conventionally named ``ctx``), used to
                save tensors and the process group for ``backward``.
            input: tensor whose dimension 1 is the channel dimension.
            weight, bias: optional per-channel affine parameters.
            running_mean, running_var: running statistics (may be None).
            eps: value added to the variance for numerical stability.
            momentum: factor used when updating the running statistics.
            process_group: group to synchronize over.
            world_size: number of ranks in ``process_group``.

        Returns:
            The normalized tensor, with the same shape as ``input``.

        Raises:
            ValueError: if there is only one value per channel and
                ``world_size < 2``.
        """
        # Keep channels_last / channels_last_3d layouts as they are; any other
        # layout is made (standard) contiguous before the batch-norm kernels run.
        if not (
            input.is_contiguous(memory_format=torch.channels_last)
            or input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        # Values per channel on this rank. A single value gives no usable variance,
        # so reject it unless other ranks can contribute values as well.
        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(
                f"Expected more than 1 value per channel when training, got input size {size}"
            )

        num_channels = input.shape[1]
        if input.numel() > 0:
            # Local per-channel mean and inverse std. Together with the local element
            # count, they are packed into one tensor that is exchanged with all ranks.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device,
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # For empty input, set the stats and the count to zero. Stats with a
            # zero count are filtered out later when computing the global mean and
            # invstd, but they still need to take part in the all_gather
            # collective so that peer processes are not blocked.
            combined = torch.zeros(
                2 * num_channels + 1, dtype=input.dtype, device=input.device
            )

        # Use all_gather instead of all_reduce because the count can differ across
        # ranks, so a simple all_reduce cannot give correct results.
        # batch_norm_gather_stats_with_counts computes the global mean and invstd
        # from the gathered mean, invstd and count of every rank.
        # Non-gloo backends (e.g. NCCL) use the flat all_gather_into_tensor;
        # the Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(
                1,
                combined_size * world_size,
                dtype=combined.dtype,
                device=combined.device,
            )
            dist.all_gather_into_tensor(
                combined_flat, combined, process_group, async_op=False
            )
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [torch.empty_like(combined) for _ in range(world_size)]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # Remove stats from ranks whose input was empty (count < 1).
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # Calculate the global mean and invstd. running_mean / running_var and
        # momentum are passed through so the running estimates can be updated.
        # Counts are cast to the running-stat dtype when one is provided.
        counts = count_all.view(-1)
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        # Saved for backward; per-rank counts are stored as int32 and later passed
        # to batch_norm_backward_elemt.
        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # Apply element-wise normalization with the global statistics.
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
        """Compute gradients for ``input``, ``weight`` and ``bias``.

        ``sum_dy`` and ``sum_dy_xmu`` are all-reduced (SUM) across the process
        group before the input gradient is formed. Weight/bias gradients stay
        local. Returns one entry per forward argument; the six non-tensor or
        statistics arguments get None.
        """
        # Same layout rule as forward: keep channels_last(_3d), otherwise make contiguous.
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last)
            or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # Calculate local per-channel reductions as well as grad_weight / grad_bias.
            (
                sum_dy,
                sum_dy_xmu,
                grad_weight,
                grad_bias,
            ) = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2],
            )

            if self.needs_input_grad[0]:
                # Synchronize the reductions used to calculate the input gradient;
                # both are packed into one tensor so a single all_reduce suffices.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # Element-wise input gradient from the global reductions.
                if weight is not None and weight.dtype != mean.dtype:
                    weight = weight.to(mean.dtype)
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor,
                )
            # Synchronizing grad_weight / grad_bias is not needed here, as distributed
            # training would handle the all-reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # Launch a zero-filled all_reduce matching the peers' 2C-sized one.
                combined = torch.zeros(
                    2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined,
                    torch.distributed.ReduceOp.SUM,
                    process_group,
                    async_op=False,
                )

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
