# Permute-CIFAR10-ResNet20

In [1]:
import os

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import autocast
import torchvision.transforms as T

In [2]:
# Expose only physical GPU 2 to this process; this must run before CUDA is
# initialised for it to take effect.
os.environ.update(CUDA_VISIBLE_DEVICES='2')

# setup

In [3]:
def save_model(model, i):
    """Save ``model.state_dict()`` to the file ``'<i>.pt'``.

    Args:
        model: module whose state dict is written.
        i: path stem (may contain directories); ``'.pt'`` is appended.
            The parent directory is not created here, so it must already exist.
    """
    torch.save(model.state_dict(), '%s.pt' % i)

def load_model(model, i):
    """Load the state dict stored in ``'<i>.pt'`` into ``model`` (in place)."""
    path = '%s.pt' % i
    model.load_state_dict(torch.load(path))

In [4]:
from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze

# Per-channel CIFAR-10 mean/std on the raw 0-255 pixel scale: the image
# pipeline below has no /255 rescaling step before T.Normalize.
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]

## fast FFCV data loaders
# The loaders move batches to this device themselves (ToDevice). With
# CUDA_VISIBLE_DEVICES restricted to one GPU, 'cuda:0' is that visible GPU.
device = 'cuda:0' 
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
# Image pipeline = decode (pre_p) -> optional augmentation (aug_p) -> post_p:
# tensor on the GPU, converted to NCHW (ToTorchImage), float16, normalised.
# Note: these transform objects are shared by all three loaders below
# (list concatenation copies the lists, not the transforms).
pre_p = [SimpleRGBImageDecoder()]
post_p = [
    ToTensor(),
    ToDevice(device, non_blocking=True),
    ToTorchImage(),
    Convert(torch.float16),
    T.Normalize(CIFAR_MEAN, CIFAR_STD),
]
# Training augmentation on decoded images: random horizontal flip, random
# translation with 4px padding, and a 12px cutout whose fill value is the
# (integer-truncated) mean colour.
aug_p = [
    RandomHorizontalFlip(),
    RandomTranslate(padding=4),
    Cutout(12, tuple(map(int, CIFAR_MEAN))),
]


# Augmented, shuffled training set; drop_last=True keeps every batch at 500.
# Default loader for run_corr_matrix and reset_bn_stats.
train_aug_loader = Loader(f'/tmp/cifar_train.beton',
                      batch_size=500,
                      num_workers=8,
                      order=OrderOption.RANDOM,
                      drop_last=True,
                      pipelines={'image': pre_p+aug_p+post_p,
                                 'label': label_pipeline})
# Un-augmented training set in a fixed order (used for train-set evaluation).
train_noaug_loader = Loader(f'/tmp/cifar_train.beton',
                     batch_size=1000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})
# Test set in a fixed order, no augmentation.
test_loader = Loader(f'/tmp/cifar_test.beton',
                     batch_size=1000,
                     num_workers=8,
                     order=OrderOption.SEQUENTIAL,
                     drop_last=False,
                     pipelines={'image': pre_p+post_p,
                                'label': label_pipeline})

In [5]:
def evaluate(model, loader=test_loader):
    """Count how many examples from ``loader`` ``model`` classifies correctly.

    Returns the raw number of correct top-1 predictions (an int, not a
    fraction); divide by the dataset size to get accuracy. The model is put
    in eval mode and the forward passes run under no_grad + autocast.
    """
    model.eval()
    n_correct = 0
    with torch.no_grad(), autocast():
        for batch_x, batch_y in loader:
            logits = model(batch_x.cuda())
            hits = batch_y.cuda() == logits.argmax(dim=1)
            n_correct += hits.sum().item()
    return n_correct

def evaluate2(model, loader=test_loader):
    """Return ``(accuracy, mean_loss)`` of ``model`` over every example in ``loader``.

    ``accuracy`` is a fraction in [0, 1]. ``mean_loss`` is the cross-entropy
    averaged per example over the whole loader: per-batch losses are summed
    and divided by the total example count, so a smaller final batch
    (``drop_last=False`` loaders) is weighted correctly.

    The model is put in eval mode; forward passes run under no_grad + autocast.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    # F is otherwise only imported in a later cell; import it here so this
    # function does not depend on cell execution order.
    import torch.nn.functional as F

    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            labels = labels.cuda()
            outputs = model(inputs.cuda())
            correct += (labels == outputs.argmax(dim=1)).sum().item()
            total += len(labels)
            loss_sum += F.cross_entropy(outputs, labels, reduction='sum').item()
    if total == 0:
        raise ValueError('evaluate2: loader yielded no examples')
    return correct / total, loss_sum / total

def full_eval(model):
    """Evaluate ``model`` on the un-augmented train set and on the test set.

    Returns:
        ``(train_acc_percent, train_loss, test_acc_percent, test_loss)``.
    """
    results = []
    for split_loader in (train_noaug_loader, test_loader):
        acc, loss = evaluate2(model, loader=split_loader)
        results.extend((100 * acc, loss))
    return tuple(results)

In [6]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    """Kaiming-normal initialise the weight of Linear/Conv2d modules.

    Any other module type is left untouched. Intended for ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    init.kaiming_normal_(m.weight)

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BN layers plus a shortcut connection.

    If the block changes stride or width, the shortcut is a projection
    (3x3 conv with padding 1, then BN); otherwise it is an identity
    (an empty ``nn.Sequential``). The projection's 3x3 kernel is part of the
    parameter shapes (``shortcut.0.weight``), so changing it would make
    previously saved state dicts fail to load.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """CIFAR-style ResNet with channel-width multiplier ``w``.

    Layout: 3x3 conv stem with 16w channels + BN + ReLU, then three stages
    built from ``block`` with 16w / 32w / 64w channels (the second and third
    stages start with stride 2), global average pooling and a linear
    classifier. Linear and Conv2d weights are re-initialised by
    ``_weights_init``.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages.
        w: channel-width multiplier.
        num_classes: classifier output size.
    """

    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super(ResNet, self).__init__()
        # Running input width for the next block; _make_layer advances it.
        self.in_planes = w*16

        self.conv1 = nn.Conv2d(3, w*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w*16)
        self.layer1 = self._make_layer(block, w*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(w*64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the others stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool to (N, C, 1, 1); adaptive pooling covers the
        # whole feature map even when it is not square.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def resnet20(w=1):
    """Build a width-``w`` ResNet-20 (three stages of three BasicBlocks), moved to the GPU and set to eval mode."""
    net = ResNet(BasicBlock, [3, 3, 3], w)
    return net.cuda().eval()

def get_blocks(net):
    """Flatten ``net`` into an ``nn.Sequential`` of residual-stream blocks.

    Element 0 is the stem (conv1 -> bn1 -> a fresh ReLU); the remaining
    elements are the blocks of layer1, layer2 and layer3 in order. The
    modules are shared with ``net``, not copied.
    """
    stem = nn.Sequential(net.conv1, net.bn1, nn.ReLU())
    residual_blocks = [*net.layer1, *net.layer2, *net.layer3]
    return nn.Sequential(stem, *residual_blocks)

### matching code

In [7]:
def _flatten_channels(out):
    """Reshape activations (N, C, *spatial) into a float64 (N * prod(spatial), C) matrix.

    Each row is the C-channel activation vector at one (image, position).
    A 2-D (N, C) input keeps its shape and only changes dtype.
    """
    n, c = out.shape[0], out.shape[1]
    return out.reshape(n, c, -1).permute(0, 2, 1).reshape(-1, c).double()


def run_corr_matrix(net0, net1, loader=train_aug_loader, eps=1e-4):
    """Estimate the channel-by-channel correlation between two networks' outputs.

    Both networks are put in eval mode and run without gradients on every
    batch of ``loader``. Each output of shape (N, C, *spatial) is reshaped
    to (N * prod(spatial), C), so every spatial position of every image is
    one observation. Raw sums are pooled over all observations, so the
    result is the Pearson correlation over the whole loader (up to ``eps``)
    regardless of batch sizes.

    Args:
        net0, net1: modules mapping an image batch to activations with the
            channel axis on dim 1 (C0 and C1 channels respectively).
        loader: yields ``(images, labels)``; labels are ignored.
        eps: added to the product of standard deviations in the denominator
            so constant (dead) channels do not divide by zero.

    Returns:
        float64 tensor of shape (C0, C1); entry [i, j] is the correlation
        between channel i of net0 and channel j of net1.

    Raises:
        ValueError: if ``loader`` yields no activations.
    """
    count = 0
    sum0 = sum1 = sq0 = sq1 = cross = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for images, _ in tqdm(loader):
            img_t = images.float().cuda()
            out0 = _flatten_channels(net0(img_t))
            out1 = _flatten_channels(net1(img_t))

            # Accumulate raw sums rather than per-batch means/stds: averaging
            # per-batch stds does not give the pooled std, and a 1/len(loader)
            # weight ignores batch sizes.
            b_sum0 = out0.sum(dim=0)
            b_sum1 = out1.sum(dim=0)
            b_sq0 = out0.square().sum(dim=0)
            b_sq1 = out1.square().sum(dim=0)
            b_cross = out0.T @ out1
            if sum0 is None:
                sum0, sum1, sq0, sq1, cross = b_sum0, b_sum1, b_sq0, b_sq1, b_cross
            else:
                sum0 += b_sum0
                sum1 += b_sum1
                sq0 += b_sq0
                sq1 += b_sq1
                cross += b_cross
            count += out0.shape[0]

    if count == 0:
        raise ValueError('run_corr_matrix: loader yielded no activations')

    mean0 = sum0 / count
    mean1 = sum1 / count
    # clamp_min guards against tiny negative variances from rounding.
    std0 = (sq0 / count - mean0 ** 2).clamp_min(0).sqrt()
    std1 = (sq1 / count - mean1 ** 2).clamp_min(0).sqrt()
    cov = cross / count - torch.outer(mean0, mean1)
    return cov / (torch.outer(std0, std1) + eps)

In [8]:
def get_layer_perm1(corr_mtx):
    """Turn a (C, C) correlation matrix into the max-total-correlation permutation.

    Solves the linear assignment problem (maximising the summed correlation)
    and returns a LongTensor ``perm`` in which row i (a net0 channel) is
    matched to column ``perm[i]`` (a net1 channel).
    """
    scores = corr_mtx.cpu().numpy()
    rows, cols = scipy.optimize.linear_sum_assignment(scores, maximize=True)
    # Every row must be assigned, in order (holds for a square matrix).
    assert (rows == np.arange(len(scores))).all()
    return torch.tensor(cols).long()

def get_layer_perm(net0, net1):
    """Return the output-channel permutation that best aligns net1 with net0.

    The result ``perm`` maps net0 channel i to net1 channel ``perm[i]``, so
    indexing net1's output-channel weights with ``perm`` reorders its
    channels to match net0's. Correlations come from ``run_corr_matrix``
    using its default loader.
    """
    return get_layer_perm1(run_corr_matrix(net0, net1))

In [9]:
def permute_output(perm_map, conv, bn=None):
    """Reorder the output channels of ``conv`` (and an optional following ``bn``) in place.

    New channel i takes the parameters of old channel ``perm_map[i]``. The
    conv's weight and bias (if any) are permuted, plus the BN's weight,
    bias, running_mean and running_var when ``bn`` is given.
    """
    targets = [conv.weight] + ([conv.bias] if conv.bias is not None else [])
    if bn is not None:
        targets.extend((bn.weight, bn.bias, bn.running_mean, bn.running_var))
    for tensor in targets:
        tensor.data = tensor[perm_map]

def permute_input(perm_map, conv):
    """Reorder the input channels of ``conv`` in place.

    Works for any layer whose ``weight`` stores input features on dim 1:
    Conv2d ``(out, in, kh, kw)`` as well as Linear ``(out, in)``, so it
    also covers a classifier head. New input i uses the weights that
    previously belonged to input ``perm_map[i]``.
    """
    weight = conv.weight
    # Indexing only dim 1 leaves any trailing (kernel) dims untouched.
    weight.data = weight[:, perm_map]

# Find the neuron permutation for each layer

In [10]:
# Default directory holding the trained batchnorm ResNet-20 checkpoints.
_CKPT_DIR = '/persist/kjordan/Network-Permutations/train/resnet-batchnorm-cifar/checkpoints/'

def load3(model, w, key, *, root=_CKPT_DIR):
    """Load the checkpoint ``batchnorm_resnet20x<w>_e250_<key>.pt`` into ``model``.

    Args:
        model: module to load into (strict ``load_state_dict``).
        w: width multiplier used in the checkpoint filename.
        key: run identifier used in the checkpoint filename.
        root: directory containing the checkpoints (defaults to the
            original hard-coded location).
    """
    path = os.path.join(root, 'batchnorm_resnet20x%d_e250_%s.pt' % (w, key))
    sd = torch.load(path)
    model.load_state_dict(sd)

# Checkpoint run IDs per width multiplier w: two (presumably independently
# trained) runs, loaded below as model0 ("v1") and model1 ("v2").
w_map = {
    1: ('3dbaad51-0b9c-48a4-a7dc-f676b503e352', '7b552086-160c-4862-92f2-28b7545b8178'),
    2: ('08f5216e-09c0-469a-9fbb-14197861c15b', '67535203-d0f4-4198-aac1-3c7f601a3936'),
    4: ('9c0810b8-9c57-4be0-887c-02bc69ae041f', 'a8f9dc58-9539-4070-8fe5-b466c1f6f4e3'),
    8: ('385580db-1317-4594-a10a-cddbb4661cfb', '65576a5b-eb87-4a51-baf1-bb4fa1434746'),
    16: ('e52e0379-f115-48b6-8086-deb016abd0c8', '0432d15c-7888-4276-b9a7-17d60045a54c'),
    32: ('376a3ce1-4c9b-4537-8a7e-2e038e461947', '2ea547fe-0459-4b7f-8a1d-eae08b9a4257'),
}

w = 1
model0, model1 = resnet20(w), resnet20(w)
key_v1, key_v2 = w_map[w]
load3(model0, w, key_v1)
load3(model1, w, key_v2)

# Keep local copies so later cells can reload the endpoints via load_model.
save_model(model0, 'batchnorm/resnet20x%d_v1' % w)
save_model(model1, 'batchnorm/resnet20x%d_v2' % w)

# Number of correctly classified test images for each endpoint.
evaluate(model0), evaluate(model1)

(9347, 9369)

### intrablock

In [11]:
# Intra-block alignment: for every residual block k >= 1 (index 0 is the
# stem), align the hidden channels between conv1 and conv2 of model1's block
# with model0's. Those channels are private to the block, so permuting
# conv1/bn1 outputs together with conv2 inputs leaves model1's function
# unchanged.
blocks0 = get_blocks(model0)
blocks1 = get_blocks(model1)

for k in range(1, len(blocks1)):
    block0 = blocks0[k]
    block1 = blocks1[k]
    # Sub-networks ending at block k's first ReLU: blocks[:k] produces the
    # block's input, then conv1 -> bn1 -> ReLU gives the hidden activations.
    subnet0 = nn.Sequential(blocks0[:k], block0.conv1, block0.bn1, nn.ReLU())
    subnet1 = nn.Sequential(blocks1[:k], block1.conv1, block1.bn1, nn.ReLU())
    perm_map = get_layer_perm(subnet0, subnet1)
    # get_blocks and Sequential slicing share modules rather than copying
    # them, so these edits modify model1 in place.
    permute_output(perm_map, block1.conv1, block1.bn1)
    permute_input(perm_map, block1.conv2)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 29.07it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 253.75it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 179.59it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 151.77it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 124.98it/s]
100%|

In [12]:
# save_model(model1, 'batchnorm/resnet20x%d_v2_perm1a' % w)

### interblock

In [13]:
# load_model(model1, 'batchnorm/resnet20x%d_v2_perm1a' % w)

In [14]:
# Inter-block alignment: within a stage, identity shortcuts add channel i of
# a block's input to channel i of its output, so the stage's residual stream
# has a single channel ordering. One permutation per stage is estimated at a
# block output in that stage, then applied to every layer that writes to or
# reads from the stream.
# kk indexes get_blocks(): 0 = stem, 1-3 = layer1, 4-6 = layer2, 7-9 = layer3.
# NOTE(review): kk[2] = 8 stops at layer3[1], while kk[0]/kk[1] use the last
# block of their stage; 9 may have been intended. Both lie on the same
# residual stream, so the permutation is still valid, but the correlation
# estimate differs. Confirm.
kk = [3, 6, 8]

# Stage 1 (stem + layer1): the stem conv/bn and each layer1 conv2/bn2 write
# the stream, and each layer1 conv1 reads it.
perm_map = get_layer_perm(blocks0[:kk[0]+1], blocks1[:kk[0]+1])
permute_output(perm_map, model1.conv1, model1.bn1)
for block in model1.layer1:
    permute_input(perm_map, block.conv1)
    permute_output(perm_map, block.conv2, block.bn2)
# layer2[0] reads the stage-1 stream via conv1 and its projection shortcut.
block = model1.layer2[0]
permute_input(perm_map, block.conv1)
permute_input(perm_map, block.shortcut[0])

# Stage 2 (layer2): the first block writes the new stream through its
# projection shortcut; later blocks read it through conv1. Every conv2/bn2
# writes it.
perm_map = get_layer_perm(blocks0[:kk[1]+1], blocks1[:kk[1]+1])
for i, block in enumerate(model1.layer2):
    if i > 0:
        permute_input(perm_map, block.conv1)
    else:
        permute_output(perm_map, block.shortcut[0], block.shortcut[1])
    permute_output(perm_map, block.conv2, block.bn2)
block = model1.layer3[0]
permute_input(perm_map, block.conv1)
permute_input(perm_map, block.shortcut[0])

# Stage 3 (layer3): same pattern. The classifier reads the pooled stage-3
# channels, so its weight columns (input features) are permuted last.
perm_map = get_layer_perm(blocks0[:kk[2]+1], blocks1[:kk[2]+1])
for i, block in enumerate(model1.layer3):
    if i > 0:
        permute_input(perm_map, block.conv1)
    else:
        permute_output(perm_map, block.shortcut[0], block.shortcut[1])
    permute_output(perm_map, block.conv2, block.bn2)
model1.linear.weight.data = model1.linear.weight[:, perm_map]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 159.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 115.19it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 108.59it/s]


### done

In [15]:
# Permutations only reorder channels, so model1's correct-count should be
# unchanged from before alignment.
evaluate(model=model1)

9369

In [16]:
# Persist the aligned (permuted) second model for the interpolation cells.
save_model(model1, i='batchnorm/resnet20x%d_v2_perm1' % w)

## Evaluate the interpolated network

In [17]:
def mix_weights(model, alpha, key0, key1):
    """Load the linear interpolation of two saved state dicts into ``model``.

    Every entry becomes ``(1 - alpha) * sd0[k] + alpha * sd1[k]``, where
    ``sd0``/``sd1`` are read from ``'<key0>.pt'`` / ``'<key1>.pt'``. Keys
    come from ``sd0``; a key missing from ``sd1`` raises KeyError.
    Checkpoints are loaded directly onto the device of ``model``'s
    parameters (CUDA if it has none), so CPU models work too.

    Args:
        model: module that receives the interpolated weights.
        alpha: interpolation coefficient (0 -> key0, 1 -> key1).
        key0, key1: checkpoint path stems (``'.pt'`` is appended).
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    sd0 = torch.load('%s.pt' % key0, map_location=device)
    sd1 = torch.load('%s.pt' % key1, map_location=device)
    sd_alpha = {k: (1 - alpha) * sd0[k] + alpha * sd1[k]
                for k in sd0.keys()}
    model.load_state_dict(sd_alpha)

# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, loader=train_aug_loader):
    """Re-estimate the running statistics of every BatchNorm2d in ``model``.

    Each BatchNorm2d's running stats are reset and its ``momentum`` is set
    to None, so the running mean/var become a cumulative average over all
    batches of one pass through ``loader``. Forward passes run in train mode
    under no_grad + autocast; only BN buffers change, not weights.

    Side effects that persist after the call: the model is left in train
    mode, and ``momentum=None`` stays set on every BatchNorm2d.
    """
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = None  # use simple (cumulative) average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    with torch.no_grad(), autocast():
        for images, _ in loader:
            model(images.cuda())

In [18]:
# Results per architecture: stats['resnet20x<w>'] -> dict of full_eval tuples.
stats = dict()

In [19]:
# Evaluate both endpoints (v1 and the permuted v2), then their 50/50 weight
# average before and after re-estimating BatchNorm statistics.
ss = {}

w = 1
model0 = resnet20(w)
model1 = resnet20(w)

k0 = 'batchnorm/resnet20x%d_v1' % w
k1 = 'batchnorm/resnet20x%d_v2_perm1' % w

for net, key in ((model0, k0), (model1, k1)):
    load_model(net, key)
ss['model_v1'] = full_eval(model0)
ss['model_v2'] = full_eval(model1)
for name in ('model_v1', 'model_v2'):
    print(ss[name])

# Midpoint of the two (aligned) networks, using the saved BN stats as-is.
model_a = resnet20(w)
mix_weights(model_a, 0.5, k0, k1)
ss['permute'] = full_eval(model_a)
print(ss['permute'])

# Same midpoint after recomputing BN running statistics.
reset_bn_stats(model_a)
ss['permute_renorm'] = full_eval(model_a)
print(ss['permute_renorm'])
stats['resnet20x%d' % w] = ss

(99.688, 0.012446842743083835, 93.47, 0.22479863911867143)
(99.616, 0.012966994484886528, 93.69, 0.22910738438367845)
(12.293999999999999, 2.8969798851013184, 12.53, 2.889247918128967)
(68.726, 1.1290484881401062, 67.0, 1.2428919911384582)


In [20]:
# torch.save(stats, 'batchnorm_resnet20_barriers.pt')