# Train-and-Permute-MNIST-MLP

In [5]:
import os
import sys

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import autocast
import torchvision
import torchvision.transforms as T

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze

In [6]:
# Expose only GPU index 2 to this process. Set before any CUDA work is done.
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})

# setup

In [7]:
def save_model(model, i):
    """Write ``model``'s state dict to the file ``'<i>.pt'``.

    Args:
        model: module whose ``state_dict()`` is saved.
        i: checkpoint key. It is formatted with ``'%s.pt'``, so it may carry a
            directory prefix (e.g. ``'mlps/...'``).
    """
    path = '%s.pt' % i
    # Create the target directory first so that keys with a directory prefix
    # do not fail when that directory has not been created yet.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)

def load_model(model, i):
    """Restore ``model`` in place from the checkpoint file ``'<i>.pt'``.

    Args:
        model: module that receives the stored state dict.
        i: checkpoint key; the file read is ``'%s.pt' % i``.
    """
    checkpoint_path = '%s.pt' % i
    model.load_state_dict(torch.load(checkpoint_path))

In [182]:
# Normalization constants for MNIST, repeated three times because the images
# are decoded as 3-channel RGB (presumably on the raw 0-255 pixel scale, since
# no rescaling happens before T.Normalize -- TODO confirm).
MNIST_MEAN = [33.3184]*3
MNIST_STD = [78.5675]*3

## fast FFCV data loaders
# NOTE: this cell binds the shared globals `device`, `train_loader`,
# `train_noaug_loader` and `test_loader`. The other dataset cells assign the
# same names, so whichever dataset cell ran last decides what later cells use.
device = 'cuda:0'
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
pre_p = [SimpleRGBImageDecoder()]
# Steps applied after decoding/augmentation: move to GPU, convert to fp16, normalize.
post_p = [
    ToTensor(),
    ToDevice(device, non_blocking=True),
    ToTorchImage(),
    Convert(torch.float16),
    T.Normalize(MNIST_MEAN, MNIST_STD),
]
# Training-time augmentation (no horizontal flip for this dataset).
aug_p = [
    RandomTranslate(padding=4),
]


# Augmented, randomly ordered training data.
train_loader = Loader('/tmp/mnist_train.beton',
                      batch_size=2000,
                      num_workers=8,
                      order=OrderOption.RANDOM,
                      drop_last=True,
                      pipelines={'image': pre_p+aug_p+post_p,
                                 'label': label_pipeline})
# Same training file without augmentation, in sequential order.
train_noaug_loader = Loader('/tmp/mnist_train.beton',
                            batch_size=2000,
                            num_workers=8,
                            order=OrderOption.SEQUENTIAL,
                            drop_last=True,
                            pipelines={'image': pre_p+post_p,
                                       'label': label_pipeline})
# Test data; drop_last=False so the final partial batch is kept.
test_loader = Loader('/tmp/mnist_test.beton',
                     batch_size=2000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})

In [221]:
# Normalization constants for this cell's dataset (the loaders below read the
# fashionmnist .beton files). NOTE(review): the names MNIST_MEAN / MNIST_STD are
# reused here with different values than in the MNIST cell, so running this cell
# silently rebinds them.
MNIST_MEAN = [72.9404]*3
MNIST_STD = [90.0212]*3

## fast FFCV data loaders
# NOTE: this cell binds the shared globals `device`, `train_loader`,
# `train_noaug_loader` and `test_loader`. The other dataset cells assign the
# same names, so whichever dataset cell ran last decides what later cells use.
device = 'cuda:0'
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
pre_p = [SimpleRGBImageDecoder()]
# Steps applied after decoding/augmentation: move to GPU, convert to fp16, normalize.
post_p = [
    ToTensor(),
    ToDevice(device, non_blocking=True),
    ToTorchImage(),
    Convert(torch.float16),
    T.Normalize(MNIST_MEAN, MNIST_STD),
]
# Training-time augmentation: random horizontal flip plus random translation.
aug_p = [
    RandomHorizontalFlip(),
    RandomTranslate(padding=4),
]


# Augmented, randomly ordered training data.
train_loader = Loader('/tmp/fashionmnist_train.beton',
                      batch_size=2000,
                      num_workers=8,
                      order=OrderOption.RANDOM,
                      drop_last=True,
                      pipelines={'image': pre_p+aug_p+post_p,
                                 'label': label_pipeline})
# Same training file without augmentation, in sequential order.
train_noaug_loader = Loader('/tmp/fashionmnist_train.beton',
                            batch_size=2000,
                            num_workers=8,
                            order=OrderOption.SEQUENTIAL,
                            drop_last=True,
                            pipelines={'image': pre_p+post_p,
                                       'label': label_pipeline})
# Test data; drop_last=False so the final partial batch is kept.
test_loader = Loader('/tmp/fashionmnist_test.beton',
                     batch_size=2000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})

In [207]:
# Per-channel normalization constants for SVHN (presumably on the raw 0-255
# pixel scale, since no rescaling happens before T.Normalize -- TODO confirm).
SVHN_MEAN = [111.6095, 113.1604, 120.5646]
SVHN_STD = [50.4977, 51.2590, 50.2442]

## fast FFCV data loaders
# NOTE: this cell binds the shared globals `device`, `train_loader`,
# `train_noaug_loader` and `test_loader`. The other dataset cells assign the
# same names, so whichever dataset cell ran last decides what later cells use.
device = 'cuda:0'
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
pre_p = [SimpleRGBImageDecoder()]
# Steps applied after decoding/augmentation: move to GPU, convert to fp16, normalize.
post_p = [
    ToTensor(),
    ToDevice(device, non_blocking=True),
    ToTorchImage(),
    Convert(torch.float16),
    T.Normalize(SVHN_MEAN, SVHN_STD),
]
# Training-time augmentation (no horizontal flip for this dataset).
aug_p = [
    RandomTranslate(padding=4),
]


# Augmented, randomly ordered training data.
train_loader = Loader('/tmp/svhn_train.beton',
                      batch_size=2000,
                      num_workers=8,
                      order=OrderOption.RANDOM,
                      drop_last=True,
                      pipelines={'image': pre_p+aug_p+post_p,
                                 'label': label_pipeline})
# Same training file without augmentation, in sequential order.
train_noaug_loader = Loader('/tmp/svhn_train.beton',
                            batch_size=2000,
                            num_workers=8,
                            order=OrderOption.SEQUENTIAL,
                            drop_last=True,
                            pipelines={'image': pre_p+post_p,
                                       'label': label_pipeline})
# Test data; drop_last=False so the final partial batch is kept.
test_loader = Loader('/tmp/svhn_test.beton',
                     batch_size=2000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})

In [202]:
# Per-channel normalization constants for CIFAR (presumably on the raw 0-255
# pixel scale, since no rescaling happens before T.Normalize -- TODO confirm).
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]

## fast FFCV data loaders
# NOTE: this cell binds the shared globals `device`, `train_loader`,
# `train_noaug_loader` and `test_loader`. The other dataset cells assign the
# same names, so whichever dataset cell ran last decides what later cells use.
device = 'cuda:0'
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
pre_p = [SimpleRGBImageDecoder()]
# Steps applied after decoding/augmentation: move to GPU, convert to fp16, normalize.
post_p = [
    ToTensor(),
    ToDevice(device, non_blocking=True),
    ToTorchImage(),
    Convert(torch.float16),
    T.Normalize(CIFAR_MEAN, CIFAR_STD),
]
# Training-time augmentation: random horizontal flip plus random translation.
aug_p = [
    RandomHorizontalFlip(),
    RandomTranslate(padding=4),
]


# Augmented, randomly ordered training data.
train_loader = Loader('/tmp/cifar_train.beton',
                      batch_size=2000,
                      num_workers=8,
                      order=OrderOption.RANDOM,
                      drop_last=True,
                      pipelines={'image': pre_p+aug_p+post_p,
                                 'label': label_pipeline})
# Same training file without augmentation, in sequential order.
train_noaug_loader = Loader('/tmp/cifar_train.beton',
                            batch_size=2000,
                            num_workers=8,
                            order=OrderOption.SEQUENTIAL,
                            drop_last=True,
                            pipelines={'image': pre_p+post_p,
                                       'label': label_pipeline})
# Test data; drop_last=False so the final partial batch is kept.
test_loader = Loader('/tmp/cifar_test.beton',
                     batch_size=2000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})

In [184]:
# evaluates accuracy
def evaluate(model):
    """Count the correct top-1 predictions of ``model`` on the global ``test_loader``.

    The model is switched to eval mode and run under ``no_grad`` and autocast.

    Returns:
        int: number of correctly classified samples (a raw count, not a ratio).
    """
    model.eval()
    n_correct = 0
    with torch.no_grad(), autocast():
        for batch_x, batch_y in test_loader:
            predictions = model(batch_x.cuda()).argmax(dim=1)
            hits = batch_y.cuda() == predictions
            n_correct += hits.sum().item()
    return n_correct

# evaluates acc and loss
def evaluate2(model, loader=None):
    """Compute top-1 accuracy and mean cross-entropy loss of ``model``.

    Args:
        model: classifier module returning logits of shape (batch, classes).
        loader: iterable yielding ``(inputs, labels)`` batches. ``None``
            (the default) means the global ``test_loader`` as it exists at
            call time. Binding ``test_loader`` as the default value would
            freeze whichever loader existed when this function was defined,
            which goes stale as soon as a dataset cell is re-run.

    Returns:
        tuple: ``(accuracy, loss)`` where accuracy is the fraction of correct
        predictions and loss is the cross-entropy averaged over samples.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    if loader is None:
        loader = test_loader
    # Run on whatever device the model lives on; fall back to CUDA for a
    # parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    model.eval()
    weighted_losses = []
    correct = 0
    total = 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            pred = outputs.argmax(dim=1)
            correct += (labels == pred).sum().item()
            batch_size = len(labels)
            total += batch_size
            loss = F.cross_entropy(outputs, labels)
            # Weight each batch mean by its size so that a short final batch
            # (drop_last=False) does not skew the overall mean.
            weighted_losses.append(loss.item() * batch_size)
    return correct / total, np.sum(weighted_losses) / total

def full_eval1(model):
    """Evaluate ``model`` on train (no augmentation) and test data as a string.

    Returns:
        str: ``'train_acc%, train_loss, test_acc%, test_loss'`` with accuracies
        as percentages (2 decimals) and losses with 3 decimals.
    """
    train_acc, train_loss = evaluate2(model, loader=train_noaug_loader)
    test_acc, test_loss = evaluate2(model, loader=test_loader)
    fields = (100*train_acc, train_loss, 100*test_acc, test_loss)
    return '%.2f, %.3f, %.2f, %.3f' % fields

def full_eval(model):
    """Evaluate ``model`` on train (no augmentation) and test data.

    Returns:
        tuple: ``(train_acc%, train_loss, test_acc%, test_loss)`` with
        accuracies expressed as percentages.
    """
    results = []
    # Train split first, then test, matching the order of the returned tuple.
    for split_loader in (train_noaug_loader, test_loader):
        acc, loss = evaluate2(model, loader=split_loader)
        results.extend((100*acc, loss))
    return tuple(results)

In [153]:
import torch.nn as nn
import torch.nn.functional as F
    
class MLP(nn.Module):
    """Fully connected ReLU network with 10 output logits.

    Structure: ``fc1`` -> ReLU -> ``layers`` x (Linear(h, h) -> ReLU) -> ``fc2``.

    Args:
        h: width of every hidden layer.
        layers: number of hidden (Linear, ReLU) pairs between ``fc1`` and ``fc2``.
        dset: dataset name. ``'mnist'`` and ``'fashionmnist'`` are treated as
            grayscale 28x28 inputs; anything else as 3x32x32 inputs.
    """

    def __init__(self, h=128, layers=3, dset='mnist'):
        super().__init__()
        self.dset = dset
        self.grayscale = dset in ('mnist', 'fashionmnist')
        in_features = 3*32*32 if not self.grayscale else 28*28
        self.fc1 = nn.Linear(in_features, h)
        # Each iteration builds one Linear followed by one ReLU, in that order.
        hidden_modules = [module
                          for _ in range(layers)
                          for module in (nn.Linear(h, h), nn.ReLU())]
        self.layers = nn.Sequential(*hidden_modules)
        self.fc2 = nn.Linear(h, 10)

    def forward(self, x):
        """Return logits for a batch ``x`` whose first dimension is the batch size."""
        # Grayscale data that arrives with 3 channels is collapsed to one
        # channel by averaging over the channel dimension.
        if self.grayscale and x.size(1) == 3:
            x = x.mean(1, keepdim=True)
        flat = x.reshape(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(self.layers(hidden))

### matching code

In [130]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1):
    """Estimate the channel-by-channel correlation between two networks' outputs.

    Both networks are put in eval mode and fed the same batches from the global
    ``train_loader``. Per-batch means, standard deviations and the cross moment
    E[out0^T out1] are averaged with equal weight over ``len(train_loader)``
    batches, in float64.

    Returns:
        Tensor of shape (C0, C1): ``cov / (std0 x std1 + 1e-4)``.
    """
    def _to_samples_by_channels(out):
        # (N, C, ...) -> (N * spatial, C), promoted to float64
        out = out.reshape(out.shape[0], out.shape[1], -1).permute(0, 2, 1)
        return out.reshape(-1, out.shape[2]).double()

    num_batches = len(train_loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for step, (images, _) in enumerate(tqdm(train_loader)):
            batch = images.float().cuda()
            acts0 = _to_samples_by_channels(net0(batch))
            acts1 = _to_samples_by_channels(net1(batch))

            batch_mean0 = acts0.mean(dim=0)
            batch_mean1 = acts1.mean(dim=0)
            batch_std0 = acts0.std(dim=0)
            batch_std1 = acts1.std(dim=0)
            batch_outer = (acts0.T @ acts1) / acts0.shape[0]

            # Accumulators are allocated on the first batch, once shapes are known.
            if step == 0:
                mean0 = torch.zeros_like(batch_mean0)
                mean1 = torch.zeros_like(batch_mean1)
                std0 = torch.zeros_like(batch_std0)
                std1 = torch.zeros_like(batch_std1)
                outer = torch.zeros_like(batch_outer)
            mean0 += batch_mean0 / num_batches
            mean1 += batch_mean1 / num_batches
            std0 += batch_std0 / num_batches
            std1 += batch_std1 / num_batches
            outer += batch_outer / num_batches

    # cov = E[xy] - E[x]E[y]; the 1e-4 guards against division by ~zero std.
    cov = outer - torch.outer(mean0, mean1)
    return cov / (torch.outer(std0, std1) + 1e-4)

In [131]:
def get_layer_perm1(corr_mtx):
    """Solve the assignment problem that maximizes total correlation.

    Args:
        corr_mtx: 2-D tensor of pairwise scores (rows: units of the first
            network, columns: units of the second).

    Returns:
        LongTensor ``perm`` where ``perm[i]`` is the column matched to row ``i``.
    """
    scores = corr_mtx.cpu().numpy()
    rows, cols = scipy.optimize.linear_sum_assignment(scores, maximize=True)
    # Rows must come back as 0..n-1 so that `cols` can be read as a permutation.
    assert (rows == np.arange(len(scores))).all()
    return torch.tensor(cols).long()

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    """Return the permutation of ``net1``'s output channels that best matches ``net0``.

    Computes the correlation matrix with ``run_corr_matrix`` and passes it to
    ``get_layer_perm1``.
    """
    return get_layer_perm1(run_corr_matrix(net0, net1))

# Find neuron-permutation for each layer

In [205]:
# Configuration of the pair of networks to align.
h = 4096          # hidden width
dset = 'cifar10'
layers = 10       # number of hidden (Linear, ReLU) pairs

# Checkpoint keys of the two endpoint networks (v1 / v2). The literal 100 fills
# the `e%d` field -- presumably the number of training epochs.
k0 = 'mlps/%s_e%d_l%d_h%d_v1' % (dset, 100, layers, h)
k1 = 'mlps/%s_e%d_l%d_h%d_v2' % (dset, 100, layers, h)

model0 = MLP(h=h, layers=layers, dset=dset).cuda()
model1 = MLP(h=h, layers=layers, dset=dset).cuda()
load_model(model0, k0)
load_model(model1, k1)
# Number of correct test predictions for each endpoint.
print(evaluate(model0), evaluate(model1))

class Subnet(nn.Module):
    """Prefix of an MLP: ``fc1`` + ReLU followed by the first ``2*k`` entries of ``model.layers``.

    ``fc1`` is the very same module object as ``model.fc1`` (assigned, not
    copied), so later in-place changes to the model's weights are visible here.

    Args:
        model: object exposing ``grayscale``, ``fc1`` and ``layers``.
        k: number of hidden (Linear, ReLU) pairs to keep; ``0`` keeps none.
    """

    def __init__(self, model, k):
        super().__init__()
        self.grayscale = model.grayscale
        self.fc1 = model.fc1
        self.layers = model.layers[:2*k]

    def forward(self, x):
        """Return the activations after the last kept ReLU."""
        # Collapse 3-channel input to one channel for grayscale datasets.
        if self.grayscale and x.size(1) == 3:
            x = x.mean(1, keepdim=True)
        flat = x.reshape(x.size(0), -1)
        return self.layers(F.relu(self.fc1(flat)))

# Align model1's hidden units with model0's, one layer at a time. For each
# hidden activation we (1) find the permutation that best matches model0,
# (2) reorder the rows/bias of the Linear producing that activation and
# (3) reorder the columns of the Linear consuming it, so that model1 keeps
# computing the same function.

def _permute_output_units(linear, perm):
    """Reorder the output units (weight rows and bias entries) of ``linear`` in place."""
    linear.weight.data = linear.weight.data[perm]
    linear.bias.data = linear.bias.data[perm]


def _permute_input_units(linear, perm):
    """Reorder the input units (weight columns) of ``linear`` in place."""
    linear.weight.data = linear.weight.data[:, perm]


# Stage 0: activation after fc1. Its consumer is the first hidden Linear, or
# fc2 when the network has no hidden pairs.
perm_map = get_layer_perm(Subnet(model0, 0), Subnet(model1, 0))
_permute_output_units(model1.fc1, perm_map)
_permute_input_units(model1.layers[0] if layers > 0 else model1.fc2, perm_map)

########

# Stage k: activation after hidden pair k (produced by model1.layers[2*k-2]).
# The bound follows `layers` rather than a hard-coded 10.
for k in range(1, layers+1):
    perm_map = get_layer_perm(Subnet(model0, k), Subnet(model1, k))
    _permute_output_units(model1.layers[2*k-2], perm_map)
    # The consumer is the next hidden Linear, or fc2 after the last hidden pair.
    consumer = model1.layers[2*k] if k < layers else model1.fc2
    _permute_input_units(consumer, perm_map)

save_model(model1, 'mlps/%s_e%d_l%d_h%d_v2_perm1' % (dset, 100, layers, h))

6111 5911


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 79.26it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 80.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 81.05it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 79.47it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 74.30it/s]
100%|

## Evaluate the interpolated network

In [138]:
def mix_weights(model, alpha, key0, key1):
    """Load into ``model`` the linear interpolation of two saved state dicts.

    Every entry becomes ``(1 - alpha) * sd0[k] + alpha * sd1[k]`` where ``sd0``
    and ``sd1`` are read from ``'<key0>.pt'`` and ``'<key1>.pt'``.

    Args:
        model: module that receives the interpolated weights (modified in place).
        alpha: interpolation coefficient; 0 gives ``key0``, 1 gives ``key1``.
        key0: checkpoint key of the first endpoint.
        key1: checkpoint key of the second endpoint.

    Raises:
        KeyError: if ``sd1`` lacks a key present in ``sd0``.
    """
    # Do the arithmetic on the device the target model lives on instead of
    # unconditionally on CUDA; fall back to CUDA for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    sd0 = torch.load('%s.pt' % key0, map_location=device)
    sd1 = torch.load('%s.pt' % key1, map_location=device)
    sd_alpha = {k: (1 - alpha) * sd0[k] + alpha * sd1[k]
                for k in sd0.keys()}
    model.load_state_dict(sd_alpha)

# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=2):
    """Recompute the BatchNorm running statistics of ``model`` from ``train_loader``.

    Every BatchNorm module has its running stats reset and its momentum set to
    ``None`` (simple average). The model is then run in train mode, without
    gradients, for ``epochs`` passes over the global ``train_loader`` and
    finally switched back to eval mode.

    Args:
        model: module whose BatchNorm layers are re-estimated (modified in place).
        epochs: number of passes over ``train_loader`` (default 2).
    """
    # resetting stats to baseline first is necessary for stability
    bn_modules = [mod for mod in model.modules()
                  if isinstance(mod, nn.modules.batchnorm._BatchNorm)]
    for bn in bn_modules:
        bn.momentum = None  # use simple average
        bn.reset_running_stats()

    # forward passes in train mode are what update the running statistics
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _labels in train_loader:
                model(images.cuda())
    model.eval()

In [154]:
class ResetLayer(nn.Module):
    """Wrap a layer with a BatchNorm1d that can track or rescale its output.

    The BatchNorm is always evaluated in ``forward`` (so in train mode its
    running statistics are updated), but its result is only returned when
    ``rescale`` is True; otherwise the wrapped layer's raw output is returned.

    Args:
        layer: module with a ``weight`` whose first dimension is the number of
            output features.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm1d(layer.weight.shape[0])
        self.rescale = False

    def set_stats(self, goal_mean, goal_var):
        """Set the BatchNorm affine parameters to the target mean and std.

        The bias becomes ``goal_mean`` and the weight ``sqrt(goal_var + 1e-7)``.
        """
        self.bn.bias.data = goal_mean
        self.bn.weight.data = (goal_var + 1e-7).sqrt()

    def forward(self, x):
        raw = self.layer(x)
        # Always run the BatchNorm, even when its output is discarded.
        normalized = self.bn(raw)
        if self.rescale:
            return normalized
        return raw

def make_tracked_net(net):
    """Return a copy of ``net`` whose hidden Linear layers are wrapped in ``ResetLayer``.

    A fresh ``MLP`` with the same depth, width and dataset is created, ``net``'s
    weights are copied into it, and ``fc1`` plus every ``nn.Linear`` inside
    ``layers`` is wrapped. The output layer ``fc2`` is left unwrapped.

    Args:
        net: an ``MLP`` (its ``fc1`` must still be a plain ``nn.Linear``).

    Returns:
        The wrapped copy, moved to CUDA and set to eval mode.
    """
    net1 = MLP(layers=len(net.layers)//2, h=net.fc1.out_features, dset=net.dset)
    net1.load_state_dict(net.state_dict())
    net1.fc1 = ResetLayer(net1.fc1)
    # Wrap net1's OWN copies of the hidden layers. Wrapping the modules of
    # `net` instead would make the tracked network share parameters with its
    # source, while fc1 above is a copy. Iterate over a snapshot because
    # entries are replaced during the loop.
    for i, layer in enumerate(list(net1.layers)):
        if isinstance(layer, nn.Linear):
            net1.layers[i] = ResetLayer(layer)
    return net1.cuda().eval()

def forward_pass_correction(wrap_a, model0, model1, alpha=0.5):
    """Rescale the hidden units of the interpolated network ``wrap_a``.

    Per-unit activation statistics of the two endpoint networks are measured
    with the BatchNorm1d layers inside their ``ResetLayer`` wrappers. Each
    ``ResetLayer`` of ``wrap_a`` is then given the interpolated mean and the
    square of the interpolated standard deviation as targets, rescaling is
    switched on, and the BatchNorm statistics of ``wrap_a`` are re-estimated.

    Args:
        wrap_a: tracked (``ResetLayer``-wrapped) interpolated network; modified in place.
        model0: first endpoint network.
        model1: second endpoint network. Its units must be aligned with those
            of ``wrap_a`` for the per-unit statistics to be meaningful.
        alpha: interpolation coefficient applied to the statistics.
    """
    ## measure the statistics of every hidden unit in the endpoint networks
    endpoint0 = make_tracked_net(model0)
    endpoint1 = make_tracked_net(model1)
    reset_bn_stats(endpoint0)
    reset_bn_stats(endpoint1)

    ## set the goal mean/std in the added bns of the interpolated network and
    ## turn rescaling on; the three networks are walked in lockstep
    lockstep = zip(wrap_a.modules(), endpoint0.modules(), endpoint1.modules())
    for mixed, ref0, ref1 in lockstep:
        if isinstance(ref0, ResetLayer):
            # goal statistics: interpolate the mean and the std of the parents
            goal_mean = (1 - alpha) * ref0.bn.running_mean + alpha * ref1.bn.running_mean
            parent_std0 = ref0.bn.running_var.sqrt()
            parent_std1 = ref1.bn.running_var.sqrt()
            goal_var = ((1 - alpha) * parent_std0 + alpha * parent_std1).square()
            mixed.set_stats(goal_mean, goal_var)
            mixed.rescale = True
    reset_bn_stats(wrap_a, epochs=3)

In [145]:
# Result accumulator keyed by run name; re-running this cell discards all entries.
stats = dict()

In [222]:
# Sweep over hidden widths: evaluate both endpoints, the naive midpoint, the
# midpoint after permutation alignment, and the aligned midpoint with
# forward-pass statistics correction. Results go into the global `stats` dict.
dset = 'fashionmnist'
layers = 10

# [-3:] restricts this run to the three largest widths (1024, 2048, 4096).
for h in tqdm([32, 64, 128, 256, 512, 1024, 2048, 4096][-3:]):
    s = {}

    k0 = 'mlps/%s_e%d_l%d_h%d_v1' % (dset, 100, layers, h)
    k1 = 'mlps/%s_e%d_l%d_h%d_v2' % (dset, 100, layers, h)
    model0 = MLP(h, layers, dset).cuda()
    model1 = MLP(h, layers, dset).cuda()
    load_model(model0, k0)
    load_model(model1, k1)
    s['v1'] = full_eval(model0)
    s['v2'] = full_eval(model1)

    # Midpoint of the two endpoints without any alignment.
    model_a = MLP(h, layers, dset).cuda()
    mix_weights(model_a, 0.5, k0, k1)
    s['vanilla'] = full_eval(model_a)

    # Midpoint after aligning v2 to v1 with the saved permutation.
    k1 = 'mlps/%s_e%d_l%d_h%d_v2_perm1' % (dset, 100, layers, h)
    # Reload model1 from the permuted checkpoint: the statistics correction
    # below matches units by index, so model1 must use the same unit ordering
    # as the interpolated network, not the original v2 ordering.
    load_model(model1, k1)
    model_a = MLP(h, layers, dset).cuda()
    mix_weights(model_a, 0.5, k0, k1)
    s['permute'] = full_eval(model_a)

    wrap_a = make_tracked_net(model_a)
    forward_pass_correction(wrap_a, model0, model1, 0.5)
    s['permute_correct'] = full_eval(wrap_a)

    stats['%s_e%d_l%d_h%d' % (dset, 100, layers, h)] = s

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:22<00:00,  7.48s/it]


In [223]:
# Persist the accumulated results; create the output directory first so the
# save does not fail when it is missing.
os.makedirs('figure_objects', exist_ok=True)
torch.save(stats, 'figure_objects/mlp_barriers_datasets.pt')