# Yang et al., 'Mean Field Theory of Batch Norm', ICLR 2019
Reproduces the experiments in Figures 2 and 6.
https://openreview.net/pdf?id=SyMDXnCcF7

### Slicing explained

In [None]:
import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

# to normalize without for-loop
from torch import FloatTensor
from torch.nn.functional import normalize


N = int(1e8)
d = 3
save_dir = './'

# Generate datapoints distributed uniformly on a (d-1)-sphere with radius 1:
# an isotropic Gaussian sample divided by its Euclidean norm is uniform on the sphere.
# For identity covariance the coordinates are i.i.d. standard normal, so sample them
# directly; multivariate_normal would add an SVD, a matrix multiply and an extra
# N x d temporary array, which matters at N = 1e8.
x = np.random.standard_normal((N, d))
x_normed = normalize(FloatTensor(x), dim=1, eps=1e-16).numpy()

A $(d-1)$-sphere of radius $R$ is the set of points whose $d$ coordinates satisfy
$$\sum_{i=1}^{d} x_i^2 = R^2.$$

A slice of a $(d-1)$-sphere in the $x_1$–$x_2$ plane is obtained by fixing the remaining coordinates, so the first two coordinates are constrained by
$$x_1^2 + x_2^2 = R^2 - \sum_{i=3}^{d} x_i^2.$$

Thus, the points of the sphere that belong to a slice in the $x_1$–$x_2$ plane obey
$$\sum_{i=3}^{d} x_i^2 = \text{const}.$$

In [None]:
eff = 1e-6        # tolerance on the slice constraint sum_{i>=3} x_i^2 == set_const
set_const = 0.5   # slice level; on the unit sphere the slice is a circle of radius sqrt(1 - set_const)
mask = np.abs(np.sum(x_normed[:, 2:]**2, axis=1) - set_const) < eff

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(x_normed[mask, 0], x_normed[mask, 1])
ax.set_aspect('equal')  # keep the circle circular
ax.set(xlabel='$x_1$', ylabel='$x_2$',
       title=r'Slice at $\sum_{i \geq 3} x_i^2 = %g$ (%d points)' % (set_const, mask.sum()))
fig.savefig(save_dir + 'slice.png')

In [None]:
import os
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

# for mutual information
#from mnist.ft_utils import mi
from numpy.linalg import det, inv

# for advex
from torch.autograd import Variable

# for 3D viz of decision boundary
from mpl_toolkits.mplot3d import Axes3D

# for linear model
import torch
import torch.nn as nn
import torch.utils.data

from sklearn.decomposition import PCA

In [None]:
R = 1.3                 # radius of the outer sphere (the inner sphere has radius 1)
LABEL_OUTER_SPHERE = 1  # label of outer-sphere points; inner-sphere points are labelled 0


def generate_clean_batch(batch_size, d):
    """
    Returns a mini-batch of points drawn uniformly from two concentric spheres.

    The first ``batch_size // 2`` rows lie on the outer sphere of radius ``R`` and are
    labelled ``LABEL_OUTER_SPHERE``; the remaining rows lie on the inner sphere of
    radius 1 and are labelled 0.

    @param batch_size: number of points in the mini-batch
    @param d: dimensionality of the ambient space (the spheres are (d-1)-spheres)
    @return: tuple ``(x_train, y_train)`` of float64 arrays with shapes
             ``(batch_size, d)`` and ``(batch_size, 1)``
    """
    x = np.random.multivariate_normal(np.zeros(d), np.eye(d), batch_size)
    n_outer = batch_size // 2

    # Dividing an isotropic Gaussian sample by its norm gives a uniform point on the sphere.
    euclidean_norm = np.linalg.norm(x, axis=1, keepdims=True)

    x_train = np.empty((batch_size, d))
    # outer sphere, radius R
    x_train[:n_outer] = R * x[:n_outer] / euclidean_norm[:n_outer]
    # inner sphere, radius 1
    x_train[n_outer:] = x[n_outer:] / euclidean_norm[n_outer:]

    y_train = np.zeros((batch_size, 1))
    y_train[:n_outer] = LABEL_OUTER_SPHERE

    return (x_train, y_train)

In [None]:
N = 10000        # train on 10k pts
test_N = 100000  # eval on 100k pts
d = 30
np.random.seed(1234)  # make the generated sphere datasets reproducible across runs
train_x, train_y = generate_clean_batch(N, d)
val_x, val_y = generate_clean_batch(N, d)
test_x, test_y = generate_clean_batch(test_N, d)

In [None]:
# Flatten the (n, 1) label columns to 1-D so they can be used directly as boolean masks.
train_y, val_y, test_y = (label_col.reshape(-1) for label_col in (train_y, val_y, test_y))

In [None]:
# Index of the first inner-sphere training point (label 0). np.argmin returns the first
# occurrence of the minimum, and generate_clean_batch stores the outer-sphere points
# (label 1) first, so this is expected to be N // 2.
np.argmin(train_y)

In particular, when we sampled 10,000 random points on the inner sphere, the nearest pair was 1.25 apart.

In [None]:
# Per-point L2 norms of the inner-sphere (label 0) training points; all should be 1.
np.linalg.norm(train_x[train_y==0], 2, 1)

In [None]:
# Smallest distance between inner-sphere points k and k + n // 4 (for the val and test
# splits). Only these disjoint pairs are compared, not all pairs, so the printed value
# is an upper bound on the true nearest-pair distance.
for xs, ys, n_total in ((val_x, val_y, N), (test_x, test_y, test_N)):
    inner_pts = xs[ys == 0]
    idx = n_total // 4
    print(np.linalg.norm(inner_pts[:idx] - inner_pts[idx:], 2, axis=1).min())

In [None]:
#idx = N // 4
#np.linalg.norm(X_val_inner[:idx] - X_val_inner[idx:], 2, axis=1).min()

In [None]:
#adv_x[adv_y==0].shape
#test_N // 100

Visualize dataset

In [None]:
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))
colors_txt = ['red', 'blue']
ax1.scatter(train_x[:, 0], train_x[:, 1], c=train_y, alpha=0.5, 
            cmap=mpl.colors.ListedColormap(colors_txt))
ax1.set_title("Train")
ax2.scatter(test_x[:, 0], test_x[:, 1], c=test_y, alpha=0.5, 
            cmap=mpl.colors.ListedColormap(colors_txt))
ax2.set_title("Test")
"""
#fig.savefig('AdversarialSpheres.png')

In [None]:
# Softmax over a single logit vector (dim 0) and over a batch of logit rows (dim 1).
softmax, batch_softmax = torch.nn.Softmax(dim=0), torch.nn.Softmax(dim=1)

# Fix the torch RNG so later torch randomness (e.g. weight initialisation) is reproducible.
seed = 1234
torch.manual_seed(seed)

In [None]:
# Train
X_train = torch.tensor(train_x, dtype=torch.float)
Y_train = torch.tensor(train_y, dtype=torch.long)

train_batch_size = 50
train_dataset = torch.utils.data.TensorDataset(X_train, Y_train)
train_loader_noshuffle = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=1000, shuffle=False)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=train_batch_size, shuffle=True)

# Val
X_val = torch.tensor(val_x, dtype=torch.float)
Y_val = torch.tensor(val_y, dtype=torch.long)

val_dataset = torch.utils.data.TensorDataset(X_val, Y_val)
val_loader_noshuffle = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=1000, shuffle=False)

# Val inner sphere only
X_val_inner = torch.tensor(val_x[N // 2:], dtype=torch.float)
Y_val_inner = torch.tensor(val_y[N // 2:], dtype=torch.long)
val_inner_dataset = torch.utils.data.TensorDataset(X_val_inner, Y_val_inner)
val_inner_loader_noshuffle = torch.utils.data.DataLoader(
    dataset=val_inner_dataset, batch_size=N // 2, shuffle=False)

# Test
X_test = torch.tensor(test_x, dtype=torch.float)
Y_test = torch.tensor(test_y, dtype=torch.long)

# create datasets
test_dataset = torch.utils.data.TensorDataset(X_test, Y_test)
test_loader_noshuffle = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=N // 10, shuffle=False) # 100k / 10 = 10k

# Tst inner sphere only
X_tst_inner = torch.tensor(test_x[test_N // 2:], dtype=torch.float)
Y_tst_inner = torch.tensor(test_y[test_N // 2:], dtype=torch.long)
tst_inner_dataset = torch.utils.data.TensorDataset(X_tst_inner, Y_tst_inner)
tst_inner_loader_noshuffle = torch.utils.data.DataLoader(
    dataset=tst_inner_dataset, batch_size=N // 20, shuffle=False)

In [None]:
# Layer outputs captured by `hook`; later cells reset this list before each probe.
activations = []


def hook(module, input, output):
    """Forward hook that records a layer's output in the global ``activations`` list.

    The output is detached so the stored tensor does not keep the autograd graph (and
    every intermediate buffer of that forward pass) alive when the model runs with
    gradients enabled, e.g. during training or adversarial attacks.
    """
    activations.append(output.detach())

In [None]:
class ReLUNetwork(nn.Module):
    """Deep fully-connected ReLU network, optionally with batch norm in every hidden layer.

    Architecture: ``input_layer`` (Linear) -> ``num_layers`` hidden layers -> ``classifier``
    (Linear). Each hidden layer is ``[BatchNorm1d, Linear, ReLU]`` when ``do_batch_norm``
    is True and ``[Linear, ReLU]`` otherwise, so ``features`` holds 3 or 2 modules per layer.
    """

    def __init__(self, input_size, num_layers, num_units, num_classes, do_batch_norm=False):
        super().__init__()
        self.do_batch_norm = do_batch_norm
        self.input_layer = nn.Linear(input_size, num_units, bias=True)
        self.features = self._make_layers(num_layers, num_units)
        self.classifier = nn.Linear(num_units, num_classes, bias=True)

    def forward(self, inputs):
        """Return the classifier logits (last dimension ``num_classes``); use them for the loss."""
        out = self.input_layer(inputs)
        out = self.features(out)
        return self.classifier(out)

    def _make_layers(self, num_layers, num_units):
        """Build the stack of hidden layers as a single ``nn.Sequential``."""
        layers = []
        for _ in range(num_layers):
            if self.do_batch_norm:
                # momentum=None: running statistics are a cumulative (simple) average.
                layers.append(nn.BatchNorm1d(num_units, momentum=None))
            layers += [nn.Linear(num_units, num_units, bias=True), nn.ReLU()]
        return nn.Sequential(*layers)

In [None]:
num_units = 1000
num_layers = 60
do_batch_norm = True
N_CLASSES = 2
# Use the first GPU when available, otherwise fall back to the CPU so the notebook still runs.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ReLUNetwork(train_x.shape[1], num_layers, num_units, N_CLASSES,
                    do_batch_norm=do_batch_norm).to(device)
loss_fnct = nn.CrossEntropyLoss()

In [None]:
# Display the module structure: input_layer, the hidden `features` stack, classifier.
model

In [None]:
# Register `hook` on the first module of every hidden layer: the BatchNorm1d when batch
# norm is enabled, otherwise the Linear. Each hidden layer occupies 3 entries of
# `model.features` with batch norm ([BN, Linear, ReLU]) and 2 without ([Linear, ReLU]).
modules_per_layer = 3 if do_batch_norm else 2

# Remove hooks registered by an earlier run of this cell so re-running it does not
# record every activation twice.
for handle in globals().get('hook_handles', []):
    handle.remove()
hook_handles = []

for j, module in enumerate(model.features[::modules_per_layer], start=1):
    print('%d %s' % (j, module))
    hook_handles.append(module.register_forward_hook(hook))

In [None]:
# Push one shuffled training mini-batch through the model to fill `activations`.
activations = []
with torch.no_grad():
    inputs, labels = next(iter(train_loader))
    inputs, labels = inputs.to(device), labels.to(device)
    out = model(inputs)
print(len(activations))

In [None]:
# Per-sample L2 norms of the captured mini-batch: R for outer-sphere points, 1 for inner ones.
inputs.norm(p=2, dim=1)

In [None]:
# Distribution of per-sample activation norms at one hooked layer (an index into the
# `activations` list filled by the forward pass above).
layer_to_plot = 38
fig, ax = plt.subplots()
ax.hist(activations[layer_to_plot].norm(p=2, dim=1).detach().cpu().numpy())
ax.set(xlabel='L2 norm of activation', ylabel='count',
       title='Activation norms at hooked layer %d' % layer_to_plot);

# Train

In [None]:
# Plain SGD with learning rate 1e-2 (the commented-out Adam line is an alternative).
#optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

In [None]:
dataset_type = ['trn', 'tst']
lss = {}    # cross-entropy loss
acc = {}    # prediction accuracy
lss['trnmb'] = []  # one entry per minibatch

for dst in dataset_type:
    lss[dst] = []  # one entry per epoch
    acc[dst] = []  # one entry per epoch

total_step = len(train_loader)


def _evaluate(loader):
    """Evaluate `model` on `loader` without tracking gradients.

    Returns (mean of the per-batch `loss_fnct` values, fraction of correct predictions).
    `activations` is cleared after every forward pass so the registered forward hooks
    do not accumulate layer outputs.
    """
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            activations.clear()
            loss_sum += loss_fnct(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return loss_sum / len(loader), float(correct) / total


max_epochs = 20
for epoch in range(max_epochs):
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fnct(outputs, labels)
        lss['trnmb'].append(loss.item())  # record training loss
        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The forward hooks append every hooked layer's output to `activations`;
        # clear it each step so memory stays bounded during training.
        activations.clear()

    # record train loss (avg value over this epoch's mini-batches)
    lss['trn'].append(np.mean(lss['trnmb'][-total_step:]))

    model.eval()
    _, train_acc = _evaluate(train_loader_noshuffle)
    val_loss, test_acc = _evaluate(val_loader_noshuffle)
    lss['tst'].append(val_loss)
    acc['trn'].append(train_acc)
    acc['tst'].append(test_acc)

    print('Epoch [{}/{}], Loss: {:.5f} (Train), {:.5f} (Val); Acc: {:.5f} (Train), {:.5f} (Val)'
          .format(epoch + 1, max_epochs, lss['trn'][-1], lss['tst'][-1], train_acc, test_acc))

In [None]:
# Clean accuracy on the full test set (both spheres).
total = 0
correct = 0
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for inputs, labels in test_loader_noshuffle:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        activations.clear()  # the forward hooks would otherwise keep every layer output
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct_idx = predicted == labels
        correct += correct_idx.sum().item()
test_acc = float(correct) / total
print(test_acc)

In [None]:
# Iterative L2 attack on the inner-sphere validation points: 100 FGM steps of size 0.1,
# each followed by projection back onto the data manifold. Prints the adversarial
# accuracy and the smallest clean-to-adversarial distance among misclassified points.
# NOTE: `fgsm_l2` is defined in a later cell; run that cell first on a fresh kernel.
total = 0
correct = 0
d_err_min = 10
model.eval()
for inputs, labels in val_inner_loader_noshuffle:

    labels = labels.to(device)
    inputs = inputs.to(device)
    advex = inputs.clone()

    for i in range(100):
        advex = fgsm_l2(model, advex, labels, 0.1)
        activations.clear()  # drop layer outputs recorded by the forward hooks
        # project back on to manifold
        advex /= torch.unsqueeze(advex.norm(p=2, dim=1), dim=1)
        advex[labels == LABEL_OUTER_SPHERE] *= R

    with torch.no_grad():
        outputs = model(advex)
    activations.clear()
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct_idx = predicted == labels
    correct += correct_idx.sum().item()

    # compute nearest error on inner sphere; skip when the attack produced no error,
    # since .min() of an empty tensor raises
    wrong = ~correct_idx
    if wrong.any():
        d_err = (inputs[wrong] - advex[wrong]).norm(p=2, dim=1).min().item()
        d_err_min = min(d_err_min, d_err)

test_acc = float(correct) / total
print(test_acc, d_err_min)

In [None]:
# Number of batches in the inner-sphere test loader (batch size N // 20).
len(tst_inner_loader_noshuffle)

In [None]:
# Clean accuracy on the inner-sphere test points only.
total = 0
correct = 0
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for inputs, labels in tst_inner_loader_noshuffle:
        labels = labels.to(device)
        inputs = inputs.to(device)
        outputs = model(inputs)
        activations.clear()  # the forward hooks would otherwise keep every layer output
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct_idx = predicted == labels
        correct += correct_idx.sum().item()
test_acc = float(correct) / total
print(test_acc)

In [None]:
#correct_idx[10] = 0

In [None]:
# Iterative L2 attack with early stopping on the inner-sphere validation points: up to
# 1000 FGM steps of size 0.01, each projected back onto the unit sphere, stopping at the
# first step where any point is misclassified. Prints the smallest clean-to-adversarial
# distance among the misclassified points, then the accuracy on the final examples.
# NOTE: `fgsm_l2` is defined in a later cell; run that cell first on a fresh kernel.
total = 0
correct = 0
d_err_min = 10
model.eval()
for i, (inputs, labels) in enumerate(val_inner_loader_noshuffle):

    labels = labels.to(device)
    inputs = inputs.to(device)
    advex = inputs.clone()

    for j in range(1000):
        advex = fgsm_l2(model, advex, labels, 0.01)
        # project back on to manifold (every point in this loader is on the unit sphere)
        advex /= torch.unsqueeze(advex.norm(p=2, dim=1), dim=1)

        with torch.no_grad():
            outputs = model(advex)
        activations.clear()  # drop layer outputs recorded by the forward hooks
        _, predicted = torch.max(outputs, 1)
        correct_idx = predicted == labels

        wrong = ~correct_idx
        if wrong.any():
            # compute nearest error on inner sphere
            d_err = (inputs[wrong] - advex[wrong]).norm(p=2, dim=1).min().item()
            if d_err < d_err_min:
                print('itr %d got d = %f' % (j, d_err))
                d_err_min = d_err
            break

    # Accuracy is counted once per batch on the final adversarial examples,
    # not once per attack step.
    total += labels.size(0)
    correct += correct_idx.sum().item()
    print('%d of %d' % (i, len(val_inner_loader_noshuffle)))

test_acc = float(correct) / total
print(test_acc, d_err_min)

In [None]:
# Iterative L2 attack on the whole test set (both spheres): 100 FGM steps of size 0.1,
# each projected back onto the sphere of the point's true class. `inputs`/`labels` are
# left holding the last adversarial batch, which later cells inspect.
# NOTE: `fgsm_l2` is defined in a later cell; run that cell first on a fresh kernel.
total = 0
correct = 0
model.eval()
for inputs, labels in test_loader_noshuffle:
    inputs = inputs.to(device)
    labels = labels.to(device)
    for i in range(100):
        inputs = fgsm_l2(model, inputs, labels, 0.1)
        activations.clear()  # drop layer outputs recorded by the forward hooks
        # project back on to manifold
        inputs /= torch.unsqueeze(inputs.norm(p=2, dim=1), dim=1)
        inputs[labels == LABEL_OUTER_SPHERE] *= R
    with torch.no_grad():
        outputs = model(inputs)
    activations.clear()
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
test_acc = float(correct) / total
print(test_acc)

In [None]:
# Recover the last clean test batch: walk the fixed-order loader, keep the final
# (inputs, labels) pair and move only that pair to the device.
for i, (cln_inputs, labels) in enumerate(test_loader_noshuffle):
    pass
cln_inputs = cln_inputs.to(device)
labels = labels.to(device)

## Compute $d(E)$: the distance to the nearest error

In [None]:
# Smallest L2 distance between clean and adversarial points in the last test batch,
# per sphere. A batch may contain only one class, and .min() of an empty tensor raises,
# so empty selections are reported instead.
for sphere_label, sphere_name in ((0, 'inner'), (1, 'outer')):
    in_sphere = labels == sphere_label
    if in_sphere.any():
        dist = (cln_inputs[in_sphere] - inputs[in_sphere]).norm(p=2, dim=1).min().item()
        print('%s sphere: %f' % (sphere_name, dist))
    else:
        print('%s sphere: no points in this batch' % sphere_name)

In [None]:
criterion = nn.CrossEntropyLoss(reduction='sum')


def fgsm_l2(model, im, labels, eps):
    """Take one L2 fast-gradient (FGM) step that increases the classification loss.

    Each sample is moved by ``eps`` in L2 norm along its own normalized input-gradient
    of ``criterion`` (unless that gradient is numerically zero).

    @param model: classifier mapping ``im`` to logits
    @param im: batch of inputs; dimension 0 indexes samples
    @param labels: true class indices for ``im``
    @param eps: L2 step size per sample
    @return: detached adversarial batch with the same shape and device as ``im``

    The gradient is taken with ``torch.autograd.grad`` w.r.t. the input only, so the
    model's parameter ``.grad`` buffers are left untouched. Earlier attack cells in this
    notebook call this function, so run this cell first on a fresh kernel.
    """
    x_ = im.detach().requires_grad_(True)
    red_dims = tuple(range(1, x_.dim()))  # all non-batch dimensions
    with torch.enable_grad():  # still works when called inside torch.no_grad()
        loss = criterion(model(x_), labels)
        loss_grad, = torch.autograd.grad(loss, x_)
    # clamp prevents division by zero for samples with a vanishing gradient
    square = torch.sum(loss_grad ** 2, dim=red_dims, keepdim=True).clamp_min(1e-12)
    normalized_loss_grad = loss_grad / torch.sqrt(square)
    return x_.detach() + eps * normalized_loss_grad

In [None]:
# Check the projection step on the last adversarial batch: each point should have
# per-sample L2 norm R (outer sphere, LABEL_OUTER_SPHERE) or 1 (inner sphere, label 0).
for sphere_label in (LABEL_OUTER_SPHERE, 0):
    norms = inputs[labels == sphere_label].norm(p=2, dim=1)
    if norms.numel():
        print('label %d: per-sample norm min %.6f, max %.6f'
              % (sphere_label, norms.min().item(), norms.max().item()))
    else:
        print('label %d: no points in this batch' % sphere_label)

In [None]:
# Scatter the first two coordinates of the training set, the test set and the last
# adversarial test batch (`inputs`/`labels` left by the attack cell).
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 3))

colors_txt = ['red', 'blue']  # label 0 (inner sphere) -> red, label 1 (outer sphere) -> blue
cmap = mpl.colors.ListedColormap(colors_txt)

inputs_np = inputs.detach().cpu().numpy()
labels_np = labels.detach().cpu().numpy()

for ax, points, point_labels, title in ((ax1, train_x, train_y, "Train"),
                                        (ax2, test_x, test_y, "Test"),
                                        (ax3, inputs_np, labels_np, "Adversarial")):
    ax.scatter(points[:, 0], points[:, 1], c=point_labels, alpha=0.5, cmap=cmap)
    ax.set(title=title, xlabel='$x_1$', ylabel='$x_2$')
fig.tight_layout()

# Mean Field Theory - Figure 2.

Batch norm leads to a chaotic input-output map with increasing depth. A linear network with batch norm is shown acting on two minibatches of size 64 after random orthogonal initialization. The datapoints in each minibatch are chosen to form a 2d circle in input space, except for one datapoint that is perturbed separately in each minibatch (leftmost datapoint at input layer 0). Because the network is linear, for a given minibatch it performs an affine transformation on its inputs: a circle in input space remains an ellipse throughout the network. However, due to batch norm, the coefficients of that affine transformation change nonlinearly as the datapoints in the minibatch are changed. (a) Each pane shows a scatterplot of activations at a given layer for all datapoints in the minibatch, projected onto the top two PCA directions. PCA directions are computed using the concatenation of the two minibatches. Due to the batch norm nonlinearity, minibatches that are nearly identical in input space grow increasingly dissimilar with depth. Intuitively, this chaotic input-output map can be understood as the source of exploding gradients when batch norm is applied to very deep networks, since very small changes in an input correspond to very large movements in network outputs. (Note: the cells below probe the `ReLUNetwork` model defined above, not a linear network.)

In [None]:
# Forward the first (fixed-order) training batch to refill `activations`.
activations = []
with torch.no_grad():
    first_batch = next(iter(train_loader_noshuffle))
    inputs, labels = (t.to(device) for t in first_batch)
    out = model(inputs)
print(len(activations))

In [None]:
# Hook the Linear module of every hidden layer so `activations` records each layer's
# pre-activation output: index 1 of [BN, Linear, ReLU] with batch norm, index 0 of
# [Linear, ReLU] without.
if do_batch_norm:
    modules_per_layer, linear_offset = 3, 1
else:
    modules_per_layer, linear_offset = 2, 0

# Remove hooks tracked in `hook_handles` (from an earlier hook-registration cell or a
# previous run of this one) so only the layers selected here are recorded.
for handle in globals().get('hook_handles', []):
    handle.remove()
hook_handles = []

for j, module in enumerate(model.features[linear_offset::modules_per_layer], start=1):
    print('%d %s' % (j, module))
    hook_handles.append(module.register_forward_hook(hook))

In [None]:
activations = []

# model.train() would make BatchNorm normalise with each batch's own statistics instead
# of the running averages; eval mode is used here.
#model.train() #
model.eval()

# Forward the first two fixed-order training batches; the forward hooks append each
# hooked layer's output to `activations`. No gradients are needed, so run under no_grad
# to avoid building (and retaining) an autograd graph.
with torch.no_grad():
    for i, (inputs, labels) in enumerate(train_loader_noshuffle):
        inputs = inputs.to(device)
        labels = labels.to(device)
        if i == 0:
            x_b0 = inputs.detach().cpu().numpy()
            y_b0 = labels.detach().cpu().numpy()
            pred = model(inputs)
        elif i == 1:
            x_b1 = inputs.detach().cpu().numpy()
            y_b1 = labels.detach().cpu().numpy()
            pred = model(inputs)
        else:
            break
print(len(activations))