# Train-and-Permute-MNIST-MLP

In [1]:
import os
import sys

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# setup

In [3]:
os.makedirs('./mlps3', exist_ok=True)
def save_model(model, i):
    sd = model.state_dict()
    torch.save(model.state_dict(), 'mlps3/%s.pt' % i)

def load_model(model, i):
    sd = torch.load('mlps3/%s.pt' % i)
    model.load_state_dict(sd)

In [4]:
# write ffcv files (only needs to be run once)
import torchvision
from ffcv.fields import IntField, RGBImageField
from ffcv.writer import DatasetWriter

transform = lambda img: img.convert('RGB')
train_dset = torchvision.datasets.MNIST(root='/tmp', download=True, train=True, transform=transform)
test_dset = torchvision.datasets.MNIST(root='/tmp', download=True, train=False, transform=transform)

datasets = { 
    'train': train_dset,
    'test': test_dset,
}

for (name, ds) in datasets.items():
    writer = DatasetWriter(f'/tmp/mnist_{name}.beton', {
        'image': RGBImageField(),
        'label': IntField()
    })  
    writer.from_indexed_dataset(ds)

In [5]:
from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze

MNIST_MEAN = [33.32]
MNIST_STD = [78.58]
normalize = T.Normalize(np.array(MNIST_MEAN), np.array(MNIST_STD))

## fast FFCV data loaders
device = 'cuda:0' 
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
pre_p = [SimpleRGBImageDecoder()]
post_p = [
    ToTensor(),
    ToDevice(device, non_blocking=True),
    ToTorchImage(),
    Convert(torch.float16),
    normalize,
]
train_loader = train_noaug_loader = Loader(f'/tmp/mnist_train.beton',
                     batch_size=1000,
                     num_workers=8,
                     order=OrderOption.RANDOM,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})
test_loader = Loader(f'/tmp/mnist_test.beton',
                     batch_size=1000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})

In [6]:
# evaluates accuracy
def evaluate(model, loader=test_loader):
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.cuda())
            pred = outputs.argmax(dim=1)
            correct += (labels.cuda() == pred).sum().item()
    return correct

# evaluates acc and loss
def evaluate2(model, loader=test_loader):
    model.eval()
    losses = []
    correct = 0
    total = 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.cuda())
            pred = outputs.argmax(dim=1)
            correct += (labels.cuda() == pred).sum().item()
            total += len(labels)
            loss = F.cross_entropy(outputs, labels.cuda())
            losses.append(loss.item())
    return correct / total, np.array(losses).mean()

def full_eval1(model):
    tr_acc, tr_loss = evaluate2(model, loader=train_noaug_loader)
    te_acc, te_loss = evaluate2(model, loader=test_loader)
    return '%.2f, %.3f, %.2f, %.3f' % (100*tr_acc, tr_loss, 100*te_acc, te_loss)

def full_eval(model):
    tr_acc, tr_loss = evaluate2(model, loader=train_noaug_loader)
    te_acc, te_loss = evaluate2(model, loader=test_loader)
    return (100*tr_acc, tr_loss, 100*te_acc, te_loss)

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class NormMLP(nn.Module):
    def __init__(self, h=128, layers=3):
        super().__init__()
        self.fc1 = nn.Linear(28*28, h, bias=False)
        self.bn1 = nn.BatchNorm1d(h)
        mid_layers = []
        for _ in range(layers):
            mid_layers.extend([
                nn.Linear(h, h, bias=False),
                nn.BatchNorm1d(h),
                nn.ReLU(),
            ])
        self.layers = nn.Sequential(*mid_layers)
        self.fc2 = nn.Linear(h, 10)
    def forward(self, x):
        if x.size(1) == 3:
            x = x.mean(1, keepdim=True)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.layers(x)
        x = self.fc2(x)
        return x

## Train and save two models

In [8]:
def train(save_key, layers=5, h=512):
    model = NormMLP(h=h, layers=layers).cuda()

    optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9)

    EPOCHS = 50
    ne_iters = len(train_loader)
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    for _ in tqdm(range(EPOCHS)):
        model.train()
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                outputs = model(inputs.cuda())
                loss = loss_fn(outputs, labels.cuda())
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
    
    print(evaluate(model))
    save_model(model, save_key)

In [9]:
# for layers in range(8):
for layers in [12]:
    h = 128
    train('mlp_e50_l%d_h%d_v1' % (layers, h), layers=layers, h=h)
    train('mlp_e50_l%d_h%d_v2' % (layers, h), layers=layers, h=h)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:21<00:00,  2.28it/s]


9817


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:20<00:00,  2.50it/s]

9827





### matching code

In [10]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, loader=train_loader):
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().cuda()
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2]).double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2]).double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                if i == 0:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

In [11]:
def get_layer_perm1(corr_mtx):
    corr_mtx_a = corr_mtx.cpu().numpy()
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
    assert (row_ind == np.arange(len(corr_mtx_a))).all()
    perm_map = torch.tensor(col_ind).long()
    return perm_map

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_layer_perm1(corr_mtx)

# Find neuron-permutation for each layer

In [12]:
v1, v2 = 1, 2
h = 128
layers = 12

model0 = NormMLP(h=h, layers=layers).cuda()
model1 = NormMLP(h=h, layers=layers).cuda()
load_model(model0, 'mlp_e50_l%d_h%d_v%d' % (layers, h, v1))
load_model(model1, 'mlp_e50_l%d_h%d_v%d'% (layers, h, v2))
print(evaluate(model0), evaluate(model1))

class Subnet(nn.Module):
    def __init__(self, model, layer_i):
        super().__init__()
        self.model = model
        self.layer_i = layer_i
    def forward(self, x):
        if x.size(1) == 3:
            x = x.mean(1, keepdim=True)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.model.fc1(x))
        x = self.model.layers[:3*self.layer_i](x)
        return x

perm_map = get_layer_perm(Subnet(model0, layer_i=0), Subnet(model1, layer_i=0))
fc = model1.fc1
w_list = [fc.weight]
bn = model1.bn1
w_list.extend([bn.weight,
                bn.bias,
                bn.running_mean,
                bn.running_var])
for w in w_list:
    w.data = w[perm_map]
for w in [model1.layers[0].weight]:
    w.data = w.data[:, perm_map]

########

for i in range(layers):
    perm_map = get_layer_perm(Subnet(model0, layer_i=i+1), Subnet(model1, layer_i=i+1))
    fc = model1.layers[3*i]
    w_list = [fc.weight]
    bn = model1.layers[3*i+1]
    w_list.extend([bn.weight,
                    bn.bias,
                    bn.running_mean,
                    bn.running_var])
    for w in w_list:
        w.data = w[perm_map]
    if i < layers-1:
        for w in [model1.layers[3*(i+1)].weight]:
            w.data = w[:, perm_map]
w = model1.fc2.weight
w.data = w[:, perm_map]

save_model(model1, 'mlp_e50_l%d_h%d_v%d_perm%d' % (layers, h, v2, v1))

9817 9827


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 252.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 253.17it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 277.98it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 283.46it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 352.81it/s]
100%|

## Evaluate the interpolated network

In [13]:
def mix_weights(model, alpha, key0, key1, prefuse=False, premodel=None):
    sd0 = torch.load('mlps3/%s.pt' % key0)
    sd1 = torch.load('mlps3/%s.pt' % key1)
    if prefuse:
        premodel.load_state_dict(sd0)
        sd0 = fuse_mlp(premodel).state_dict()
        premodel.load_state_dict(sd1)
        sd1 = fuse_mlp(premodel).state_dict()
    sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()
                for k in sd0.keys()}
    model.load_state_dict(sd_alpha)

# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.cuda())

In [15]:
layers = 9
pre = 'mlp_e50_l%d_h%d' % (layers, h)
v1, v2 = 1, 2

model0 = NormMLP(h=h, layers=layers).cuda()
model1 = NormMLP(h=h, layers=layers).cuda()
model_a = NormMLP(h=h, layers=layers).cuda()
load_model(model0, '%s_v%d' % (pre, v1))
load_model(model1, '%s_v%d_perm%d' % (pre, v2, v1))
mix_weights(model_a, 0.5, '%s_v%d' % (pre, v1), '%s_v%d_perm%d' % (pre, v2, v1))

print('(α=0.0)', full_eval(model0))
print('(α=1.0)', full_eval(model1))
print('(α=0.5 permuted)', full_eval(model_a))
reset_bn_stats(model_a)
print('(α=0.5 permuted+corrected)', full_eval(model_a))

(α=0.0) (100.0, 0.0001097355286522846, 98.32, 0.08815740388818086)
(α=1.0) (100.0, 0.00010460326666361652, 98.36, 0.0902980868704617)
(α=0.5 permuted) (87.33, 0.44000959893067676, 87.47, 0.4413963109254837)
(α=0.5 permuted+corrected) (95.90666666666667, 0.1948693387210369, 95.44, 0.23052519261837007)
