In [1]:
import math
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

<br>

# Synchronizing batch norm
---

In distributed training, remember to synchronize batch normalization across processes. Otherwise each replica normalizes with statistics computed only from its own shard of the data, and the network may learn something specific to that shard:

In [7]:
# SyncBatchNorm computes batch statistics across the processes of its
# `process_group` (default: the whole world) when run in training mode under an
# initialized distributed setup such as DDP; standalone it behaves like BatchNorm.
sbn = nn.SyncBatchNorm(num_features=32)
bn = nn.BatchNorm2d(num_features=32)  # statistics computed per process only

# convert_sync_batchnorm walks a module recursively and swaps every BatchNorm*D
# layer for a SyncBatchNorm. It RETURNS the converted module: for a bare
# BatchNorm layer like `bn` it builds a new SyncBatchNorm and leaves `bn`
# untouched, so always keep the return value (e.g. `model = ...(model)`).
sbn_converted = nn.SyncBatchNorm.convert_sync_batchnorm(bn)
sbn_converted

SyncBatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

Links:

* https://pytorch.org/docs/master/generated/torch.nn.SyncBatchNorm.html
* https://pytorch.org/docs/master/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm

<br>

# Calling eval()
---

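Remember to call `eval()` at inference time. Otherwise the batch normalization module stays in training mode: it normalizes with the statistics of the current batch and keeps updating its running statistics with the incoming data, which is almost never what you want outside of training: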

In [23]:
# Fixed-seed generator so the printed statistics are reproducible on re-run
# (and the global RNG used by other cells is left alone).
gen = torch.Generator().manual_seed(0)
bn = nn.BatchNorm1d(num_features=2)

# Training mode: each forward pass folds the batch statistics into
# running_mean / running_var, so the printed values move.
bn.train()
for _ in range(3):
    x = torch.randn(100, 2, generator=gen)
    bn(x)
    print(bn.running_mean)

# Eval mode: forward passes use the stored running statistics and leave them
# untouched, whatever data is fed in.
frozen_mean = bn.running_mean.clone()
frozen_var = bn.running_var.clone()
bn.eval()
for _ in range(2):
    x = torch.randn(100, 2, generator=gen)
    bn(x)
    print(bn.running_mean)

# Prove the claim instead of relying on eyeballing the printout.
assert torch.equal(bn.running_mean, frozen_mean)
assert torch.equal(bn.running_var, frozen_var)

tensor([-0.0067,  0.0124])
tensor([ 0.0052, -0.0009])
tensor([ 0.0072, -0.0063])
tensor([ 0.0072, -0.0063])
tensor([ 0.0072, -0.0063])
tensor([ 0.0072, -0.0063])
