In [63]:
### MOFA: Maxim's Once For All ###
### This version is using the Kernel Transition Matrix by default. ###

In [1]:
import copy
import random
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

from datasets.cifar100 import cifar100_get_datasets


# <ins>Classes</ins>

In [2]:
class Clamp(nn.Module):
    """
    Post-activation clamping layer.

    Saturates every element of its input to the closed range
    ``[min_val, max_val]`` (typically ``[-128, +127]``). The bounds are
    forwarded unchanged to ``torch.clamp``.
    """
    def __init__(self, min_val=None, max_val=None):
        super().__init__()
        self.min_val, self.max_val = min_val, max_val

    def forward(self, x):  # pylint: disable=arguments-differ
        """Return ``x`` with its values limited to the configured range."""
        return torch.clamp(x, min=self.min_val, max=self.max_val)

In [3]:
class MOFAnet(nn.Module):
    """
    Maxim OFA Net: a stack of convolutional ``Unit`` blocks followed by a
    linear classifier.

    ``param_dict`` must provide the keys:
        in_ch        -- number of input channels
        out_class    -- number of classifier outputs
        n_units      -- number of ``Unit`` blocks
        width_list   -- per unit, list of output widths (one entry per layer)
        kernel_list  -- per unit, list of kernel sizes (one entry per layer)
        bias_list    -- per unit, list of bias flags (one entry per layer)
        bn           -- whether every layer uses batch normalisation

    A 2x2 max-pool is applied after every unit except the last one.
    """
    def __init__(self, param_dict):
        super(MOFAnet, self).__init__()
        self.param_dict = param_dict
        self.in_ch = param_dict['in_ch']
        self.out_class = param_dict['out_class']
        self.n_units = param_dict['n_units']
        self.width_list = param_dict['width_list']
        self.kernel_list = param_dict['kernel_list']
        self.bias_list = param_dict['bias_list']
        self.bn = param_dict['bn']
        self.last_width = self.in_ch
        self.units = nn.ModuleList([])
        # Use the unit count from param_dict, not a same-named notebook
        # global, so the network can be built from any configuration.
        for i in range(self.n_units):
            self.units.append(Unit(len(self.kernel_list[i]),
                                   self.kernel_list[i],
                                   self.width_list[i],
                                   self.last_width,
                                   self.bias_list[i],
                                   self.bn))
            # The next unit consumes the output width of this unit's last layer.
            self.last_width = self.width_list[i][-1]
        self.flatten = nn.Flatten()
        self.max_pool = nn.MaxPool2d(kernel_size=2)

        # NOTE(review): the 1024 input features are hard-coded; presumably
        # 256 channels x 2 x 2 for 32x32 inputs with five units -- confirm
        # before changing widths, unit count or input resolution.
        self.classifier = nn.Linear(1024, self.out_class)

    def forward(self, x):
        """Run all units (max-pooling between them) and classify the result."""
        for unit in self.units[:-1]:
            x = unit(x)
            x = self.max_pool(x)
        x = self.units[-1](x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [4]:
class Unit(nn.Module):
    """
    A sequential stack of ``depth`` ``FusedConv2dReLU`` layers.

    Layer ``k`` maps ``([init_width] + width_list)[k]`` channels to
    ``width_list[k]`` channels with kernel size ``kernel_list[k]`` and bias
    flag ``bias_list[k]``; ``bn`` is shared by all layers.
    """
    def __init__(self, depth, kernel_list,
                 width_list, init_width, bias_list, bn=True):
        super().__init__()
        self.depth = depth
        self.kernel_list = kernel_list
        self.width_list = width_list
        self.bias_list = bias_list
        self.bn = bn
        # Channel count at every layer boundary: input width first.
        self._width_list = [init_width] + width_list
        self.layers = nn.ModuleList([
            FusedConv2dReLU(self._width_list[idx],
                            self._width_list[idx + 1],
                            self.kernel_list[idx],
                            self.bias_list[idx],
                            self.bn)
            for idx in range(depth)
        ])

    def forward(self, x):
        """Apply the first ``self.depth`` layers in order."""
        out = x
        for idx in range(self.depth):
            out = self.layers[idx](out)
        return out

In [5]:
class FusedConv2dReLU(nn.Module):
    """
    Conv2d (+ optional BatchNorm2d) + ReLU layer with an elastic kernel size.

    The layer always stores a full 3x3 convolution weight. With
    ``kernel_size == 3`` that weight is used directly (padding 1). With
    ``kernel_size == 1`` the nine taps of every 3x3 kernel are projected to
    a single tap through the learnable kernel transition matrix ``ktm``
    (shape ``(9, 1)``, initialised to select the centre tap) and the
    convolution runs with padding 0.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: active kernel size, 1 or 3.
        bias: whether the convolution has a bias term.
        bn: whether a BatchNorm2d follows the convolution.

    Raises:
        ValueError: if ``kernel_size`` is neither 1 nor 3.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            bias=True,
            bn=True):
        super(FusedConv2dReLU, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # Kernel transition matrix: starts as "pick the centre tap" (index 4
        # of the flattened 3x3 kernel).
        ktm_core = torch.zeros((9, 1))
        ktm_core[4] = 1
        self.ktm = nn.Parameter(data=ktm_core, requires_grad=True)

        if kernel_size == 1:
            self.pad = 0
        elif kernel_size == 3:
            self.pad = 1
        else:
            raise ValueError(
                f'kernel_size must be 1 or 3, got {kernel_size!r}')
        self.func = F.conv2d
        # The stored weight is always 3x3; self.pad (not this module's
        # padding) is what forward() passes to the functional convolution.
        self.conv2d = nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, stride=1,
                                padding=1, bias=bias)
        self.bn = bn
        if self.bn:
            self.batchnorm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()
        # Kept for compatibility; clamping is currently disabled in forward().
        self.clamp = Clamp(min_val=-1, max_val=1)

    def forward(self, x):
        """Convolve with the active kernel size, then BN (optional) and ReLU."""
        weight = self.conv2d.weight
        bias = self.conv2d.bias
        if self.kernel_size == 1:
            # (out, in, 3, 3) -> (out, in, 1, 9) @ (9, 1) -> (out, in, 1, 1).
            # Follow the weight's own device instead of a notebook-level
            # `device` global, so the layer works wherever the module lives.
            flattened_weight = weight.reshape(
                weight.size(0), weight.size(1), 1, 9)
            weight = flattened_weight @ self.ktm.to(weight.device)

        x = self.func(x, weight, bias, self.conv2d.stride, self.pad)
        if self.bn:
            x = self.batchnorm(x)
        x = self.activation(x)
        return x

# <ins>Functions</ins>

## Batchnorm Related 

In [6]:
def make_bn_stats_false(model):
    """
    Set ``track_running_stats = False`` on the ``batchnorm`` module of every
    layer in every unit of ``model`` (in place) and return ``model``.
    """
    for unit in model.units:
        for layer in unit.layers:
            layer.batchnorm.track_running_stats = False

    return model

def make_bn_stats_true(model):
    """
    Set ``track_running_stats = True`` on the ``batchnorm`` module of every
    layer in every unit of ``model`` (in place) and return ``model``.
    """
    for unit in model.units:
        for layer in unit.layers:
            layer.batchnorm.track_running_stats = True

    return model

def fuse_bn(conv, bn):
    """
    Fold a batch-norm layer into the convolution that precedes it.

    Uses the batch norm's running statistics and affine parameters to build
    a new ``nn.Conv2d`` (always with bias) whose weight and bias are

        w' = w * bn.weight / sqrt(running_var + eps)
        b' = (b - running_mean) / sqrt(running_var + eps) * bn.weight + bn.bias

    where ``b`` is the convolution bias, or zeros when the convolution has
    none. Neither input module is modified.
    """
    running_mean = bn.running_mean
    std = torch.sqrt(bn.running_var + bn.eps)
    bn_scale = bn.weight
    bn_shift = bn.bias

    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(running_mean)

    # One scale factor per output channel, broadcast over (in, kh, kw).
    fused_weight = conv.weight * (bn_scale / std).reshape(conv.out_channels, 1, 1, 1)
    fused_bias = (conv_bias - running_mean) / std * bn_scale + bn_shift

    fused_conv = nn.Conv2d(conv.in_channels, conv.out_channels,
                           conv.kernel_size, conv.stride, conv.padding,
                           bias=True)
    fused_conv.weight = nn.Parameter(fused_weight)
    fused_conv.bias = nn.Parameter(fused_bias)
    return fused_conv


def fuse_bn_mofa(mofa_net):
    """
    Build a batch-norm-free copy of ``mofa_net``.

    A new ``MOFAnet`` is created from a copy of ``mofa_net.param_dict`` with
    ``bn=False``; every layer's BatchNorm is folded into its convolution via
    ``fuse_bn`` and the classifier and kernel transition matrices are copied
    over. ``mofa_net`` itself is not modified.

    Because a fused convolution always has a bias, the returned model's
    ``bias_list`` is all ``True`` regardless of the source configuration.

    Returns:
        The fused ``MOFAnet``. It is freshly constructed and not moved to
        any device here; callers move it where they need it.
    """
    param_dict = copy.deepcopy(mofa_net.param_dict)
    param_dict['bn'] = False
    # fuse_bn always produces a bias term, so the target layers need one
    # even where the source convolution was bias-free.
    param_dict['bias_list'] = [[True] * len(unit_biases)
                               for unit_biases in param_dict['bias_list']]
    fused_model = MOFAnet(param_dict)
    with torch.no_grad():
        fused_model.classifier.weight.copy_(mofa_net.classifier.weight)
        fused_model.classifier.bias.copy_(mofa_net.classifier.bias)
        for u_ind, unit in enumerate(mofa_net.units):
            for l_ind, layer in enumerate(unit.layers):
                target = fused_model.units[u_ind].layers[l_ind]
                if layer.bn:
                    fused_conv = fuse_bn(layer.conv2d, layer.batchnorm)
                    # copy_ handles any device mismatch, so no notebook-level
                    # `device` global is needed here.
                    target.conv2d.weight.copy_(fused_conv.weight)
                    target.conv2d.bias.copy_(fused_conv.bias)
                else:
                    # Nothing to fold: carry the convolution over unchanged.
                    target.conv2d.weight.copy_(layer.conv2d.weight)
                    if layer.conv2d.bias is not None:
                        target.conv2d.bias.copy_(layer.conv2d.bias)
                    else:
                        target.conv2d.bias.zero_()
                # Preserve the kernel transition matrix as well.
                target.ktm.copy_(layer.ktm)
    return fused_model

## Elastic Kernel - Depth - Width

In [7]:
def sample_subnet_kernel(mofa):
    """
    Draw a sub-network of ``mofa`` with a random kernel size per layer.

    Every layer's kernel size is chosen uniformly from ``{1, 3}``. The
    sub-network is a new ``MOFAnet`` built with ``bn=False`` that receives
    copies of the classifier, the 3x3 convolution weights, the kernel
    transition matrices and (where ``mofa.bias_list`` is ``True``) the
    convolution biases of ``mofa``.
    """
    sub_params = copy.deepcopy(mofa.param_dict)
    for u_ind, unit in enumerate(mofa.units):
        for l_ind in range(len(unit.layers)):
            sub_params['kernel_list'][u_ind][l_ind] = random.choice([1, 3])
    sub_params['bn'] = False

    subnet = MOFAnet(sub_params)
    with torch.no_grad():
        subnet.classifier.weight.copy_(mofa.classifier.weight)
        subnet.classifier.bias.copy_(mofa.classifier.bias)
        for u_ind, src_unit in enumerate(mofa.units):
            dst_unit = subnet.units[u_ind]
            for l_ind, src_layer in enumerate(src_unit.layers):
                dst_layer = dst_unit.layers[l_ind]
                dst_layer.conv2d.weight.copy_(src_layer.conv2d.weight)
                dst_layer.ktm.copy_(src_layer.ktm)
                if mofa.bias_list[u_ind][l_ind] is True:
                    dst_layer.conv2d.bias.copy_(src_layer.conv2d.bias)
    return subnet

def update_mofa_from_subnet_kernel(mofa, subnet):
    """
    Write the parameters of a kernel-sampled ``subnet`` back into ``mofa``.

    Copies (in place, without gradient tracking) the classifier and, for
    every layer of ``mofa``, the convolution weight, the kernel transition
    matrix and -- where ``mofa.bias_list`` is ``True`` -- the convolution
    bias. Returns ``mofa``.
    """
    with torch.no_grad():
        mofa.classifier.weight.copy_(subnet.classifier.weight)
        mofa.classifier.bias.copy_(subnet.classifier.bias)
        for u_ind, dst_unit in enumerate(mofa.units):
            src_unit = subnet.units[u_ind]
            for l_ind, dst_layer in enumerate(dst_unit.layers):
                src_layer = src_unit.layers[l_ind]
                dst_layer.conv2d.weight.copy_(src_layer.conv2d.weight)
                dst_layer.ktm.copy_(src_layer.ktm)
                if mofa.bias_list[u_ind][l_ind] is True:
                    dst_layer.conv2d.bias.copy_(src_layer.conv2d.bias)
    return mofa


def sample_subnet_depth(mofa, sample_kernel=True):
    """
    Draw a sub-network of ``mofa`` with a random depth per unit.

    For every unit a depth is sampled uniformly from 1 up to the unit's full
    number of layers, and only the first ``depth`` layers are kept. When
    ``sample_kernel`` is true the kernel sizes are first re-sampled with
    ``sample_subnet_kernel``; otherwise the kernel configuration of ``mofa``
    is kept.

    The returned network is a new ``MOFAnet`` built with ``bn=False`` that
    receives copies of the classifier and of each kept layer's convolution
    weight, bias (when present) and kernel transition matrix.
    NOTE(review): BatchNorm parameters are never copied, so this presumably
    expects an already BN-fused ``mofa`` -- confirm against callers.

    Returns:
        (subnet, param_dict, depth_list): the sampled network, the parameter
        dictionary it was built from, and the sampled depth of every unit.
    """
    min_depth = 1
    # Depths are drawn before the kernel sizes, from the full network's layout.
    depth_list = [random.randint(min_depth, len(unit_widths))
                  for unit_widths in mofa.param_dict['width_list']]

    if sample_kernel:
        subnet = sample_subnet_kernel(mofa) # This is confirmed by Ji
    else:
        subnet = copy.deepcopy(mofa)

    param_dict = copy.deepcopy(subnet.param_dict)
    param_dict['bn'] = False
    # Keep only the first `depth` entries of every per-unit list.
    for key in ('width_list', 'kernel_list', 'bias_list'):
        param_dict[key] = [unit_values[:depth]
                           for unit_values, depth
                           in zip(subnet.param_dict[key], depth_list)]

    subnet2 = MOFAnet(param_dict)

    with torch.no_grad():
        subnet2.classifier.weight.copy_(subnet.classifier.weight)
        subnet2.classifier.bias.copy_(subnet.classifier.bias)
        for u_ind, unit in enumerate(subnet2.units):
            for l_ind, layer in enumerate(unit.layers):
                source = subnet.units[u_ind].layers[l_ind]
                layer.conv2d.weight.copy_(source.conv2d.weight)
                # Carry the learned kernel transition matrix along; without
                # this the 1x1 layers would fall back to the initial
                # centre-tap projection.
                layer.ktm.copy_(source.ktm)
                # Bias-free layers have conv2d.bias == None on both sides.
                if layer.conv2d.bias is not None and source.conv2d.bias is not None:
                    layer.conv2d.bias.copy_(source.conv2d.bias)

    return subnet2, param_dict, depth_list

def update_mofa_from_subnet_depth(mofa, subnet):
    """
    Write the parameters of a depth-sampled ``subnet`` back into ``mofa``.

    Copies (in place, without gradient tracking) the classifier and, for
    every layer that exists in ``subnet``, the convolution weight, the
    kernel transition matrix and -- where ``mofa.bias_list`` is ``True`` --
    the convolution bias. Layers of ``mofa`` beyond the sub-network's depth
    are left untouched. Returns ``mofa``.
    """
    with torch.no_grad():
        mofa.classifier.weight.copy_(subnet.classifier.weight)
        mofa.classifier.bias.copy_(subnet.classifier.bias)
        # Iterate over the (possibly shallower) sub-network so that only the
        # layers it actually contains are written back.
        for u_ind, sub_unit in enumerate(subnet.units):
            for l_ind, sub_layer in enumerate(sub_unit.layers):
                target = mofa.units[u_ind].layers[l_ind]
                target.conv2d.weight.copy_(sub_layer.conv2d.weight)
                target.ktm.copy_(sub_layer.ktm)
                if mofa.bias_list[u_ind][l_ind] is True:
                    target.conv2d.bias.copy_(sub_layer.conv2d.bias)
    return mofa
                
def order_channels(mofa):
    """
    Placeholder for channel ordering ahead of elastic-width sampling.

    Not implemented yet: the model is returned unchanged so that callers
    which rebind their model to the result keep a usable object instead of
    ``None``.
    """
    # TODO: reorder channels by importance before width sampling.
    return mofa


def sample_subnet_width(mofa):
    """
    Placeholder for elastic-width sub-network sampling.

    Currently only runs ``order_channels`` and returns the resulting model;
    no width reduction is performed yet.
    """
    mofa = order_channels(mofa)
    return mofa

    
# def update_mofa_from_subnet_width(mofa, subnet):
 #TODO

## Others

In [8]:
def cross_entropy_loss_with_soft_target(pred, soft_target):
    """
    Cross-entropy between logits and a soft (probability) target.

    Args:
        pred: logits; the class dimension is dimension 1.
        soft_target: target distribution with the same shape as ``pred``.

    Returns:
        Scalar tensor: ``sum(-soft_target * log_softmax(pred))`` over
        dimension 1, averaged over the remaining elements.
    """
    # The class dimension is given explicitly: the implicit-dim form of
    # LogSoftmax is deprecated and picks the dimension from the input rank.
    log_probs = F.log_softmax(pred, dim=1)
    return torch.mean(torch.sum(-soft_target * log_probs, dim=1))

# <ins> MOFA Training </ins>

In [9]:
# Architecture of the full (largest) network: every unit has `n_layers`
# layers, each 256 channels wide with a 3x3 kernel, a bias and batch norm.
n_units = 5
n_layers = 3

param_dict = {
    'n_units': n_units,
    'in_ch': 3,
    'out_class': 100,
    # One independent inner list per unit, one entry per layer.
    'width_list': [[256] * n_layers for _ in range(n_units)],
    'kernel_list': [[3] * n_layers for _ in range(n_units)],
    'bias_list': [[True] * n_layers for _ in range(n_units)],
    'bn': True,
}


In [10]:
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")

# class Args():
#     def __init__(self):
#         super(Args, self).__init__()
#         self.truncate_testset = False
#         self.act_mode_8bit = False
        
# args = Args()
# train_dataset, test_dataset = cifar100_get_datasets(('data', args))

# trainset = DataLoader(dataset=train_dataset,
#                       batch_size=100,
#                       shuffle=True,
#                       num_workers=0)

# valset = DataLoader(dataset=test_dataset,
#                       batch_size=1000,
#                       shuffle=False,
#                       num_workers=0)

# mofa = MOFAnet(param_dict)
# mofa = mofa.to(device)

# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(mofa.parameters(), lr=1e-4)
# scheduler = StepLR(optimizer, step_size=35, gamma=0.4)

# best_val_accuracy = 0
# max_epochs = 150
# for epoch in range(max_epochs):
#     t0 = time.time()
#     mofa.train()
#     for batch, labels in trainset:
#         batch, labels = batch.to(device), labels.to(device)
        
#         y_pred = mofa(batch)
#         loss = criterion(y_pred, labels)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#     scheduler.step()
    
#     print(f'Epoch {epoch+1}')
#     print(f'\tTraining loss:{loss.item()}')
#     t1 = time.time()
#     print(f'\tTraining time:{t1-t0:.2f} s - {(t1-t0)/60:.2f} mins ')
#     # Validation
#     correct = 0
#     total = 0
#     mofa.eval()
#     with torch.no_grad():
#         for data in valset:
#             images, labels = data
#             images, labels = images.to(device), labels.to(device)
#             outputs = mofa(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#     val_accuracy = correct / total
#     if val_accuracy > best_val_accuracy:
#         if epoch != 0:
#             os.remove(f'mofa_models/noclamp_mofa_acc{100*best_val_accuracy:.0f}%.pth.tar')
#         torch.save(mofa, f'mofa_models/noclamp_mofa_acc{100*val_accuracy:.0f}%.pth.tar')
#         best_val_accuracy = val_accuracy
#     print('\tAccuracy of the mofa on the test images: %d %%' % (
#         100 * correct / total))
    

# <ins> Elastic Kernel Training </ins>

In [84]:
# Pick the compute device first so the checkpoint can be mapped onto it.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# The checkpoint is a whole pickled model object. map_location lets a
# checkpoint that was saved from a GPU load on a CPU-only machine too.
# SECURITY NOTE: unpickling executes arbitrary code -- only load checkpoints
# from a trusted source. (Recent PyTorch releases may additionally require
# `weights_only=False` to load a full model object -- check the torch.load
# documentation for the installed version.)
model = torch.load('mofa_models/noclamp_mofa_acc71%.pth.tar', map_location=device)

In [85]:
# Fold batch norm into the convolutions of the pretrained network. The fused
# network is used as the distillation teacher (`fused_model`); an independent
# deep copy of it (`mofa`) is the network that gets trained further.
fused_model = fuse_bn_mofa(model)
mofa = copy.deepcopy(fused_model).to(device)
fused_model = fused_model.to(device)

In [13]:
# Weight of the knowledge-distillation term in the loss; a value <= 0
# disables distillation and trains on the hard labels only.
kd_ratio = 0.5

class Args():
    """Options object passed to cifar100_get_datasets."""
    def __init__(self):
        super().__init__()
        self.truncate_testset = False
        self.act_mode_8bit = False

args = Args()
train_dataset, test_dataset = cifar100_get_datasets(('data', args))

trainset = DataLoader(dataset=train_dataset,
                      batch_size=100,
                      shuffle=True,
                      num_workers=0)

valset = DataLoader(dataset=test_dataset,
                    batch_size=1000,
                    shuffle=False,
                    num_workers=0)

criterion = torch.nn.CrossEntropyLoss()

# The teacher is never updated, so put it in inference mode once instead of
# switching it to train mode on every batch.
fused_model.eval()

best_val_accuracy = 0  # kept for checkpointing, which is currently disabled
max_epochs = 250
for epoch in range(max_epochs):
    t0 = time.time()
    mofa.train()
    for batch, labels in trainset:
        batch, labels = batch.to(device), labels.to(device)

        # Each step trains a freshly sampled kernel configuration; the
        # sub-network (and hence its optimizer) is rebuilt for every batch.
        subnet = sample_subnet_kernel(mofa)
        subnet = subnet.to(device)
        optimizer = torch.optim.SGD(subnet.parameters(), lr=1e-3)

        y_pred = subnet(batch)

        if kd_ratio > 0:
            # Soft targets from the frozen teacher.
            with torch.no_grad():
                soft_logits = fused_model(batch)
                soft_label = F.softmax(soft_logits, dim=1)
            kd_loss = cross_entropy_loss_with_soft_target(y_pred, soft_label)
            loss = kd_ratio * kd_loss + criterion(y_pred, labels)
        else:
            loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Write the updated sub-network parameters back into the shared net.
        mofa = update_mofa_from_subnet_kernel(mofa, subnet)

    print(f'Epoch {epoch+1}')
    # Loss of the last batch of the epoch.
    print(f'\tTraining loss:{loss.item()}')
    t1 = time.time()
    print(f'\tTraining time:{t1-t0:.2f} s - {(t1-t0)/60:.2f} mins ')

    # Validation of the full network (all kernels at their stored size).
    correct = 0
    total = 0
    mofa.eval()
    with torch.no_grad():
        for images, labels in valset:
            images, labels = images.to(device), labels.to(device)
            outputs = mofa(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = correct / total
    print('\tAccuracy of the mofa on the test images: %d %%' % (
        100 * correct / total))
    # Centre tap of the kernel transition matrix of the first and last layer;
    # negative indices pick the last unit/layer whatever the network layout.
    print(f'\tFirst ktm: {mofa.units[0].layers[0].ktm[4].item()}')
    print(f'\tLast ktm: {mofa.units[-1].layers[-1].ktm[4].item()}')
    

Files already downloaded and verified
Files already downloaded and verified


RuntimeError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 10.76 GiB total capacity; 961.57 MiB already allocated; 11.44 MiB free; 986.00 MiB reserved in total by PyTorch)

# <ins> Testing </ins>

**MOFA Test**

In [13]:
class Args():
    """Options object passed to cifar100_get_datasets."""
    def __init__(self):
        super().__init__()
        self.truncate_testset = False
        self.act_mode_8bit = False

args = Args()
train_dataset, test_dataset = cifar100_get_datasets(('data', args))

# Training batches are shuffled; evaluation batches keep the dataset order.
trainset = DataLoader(dataset=train_dataset, batch_size=100,
                      shuffle=True, num_workers=0)
valset = DataLoader(dataset=test_dataset, batch_size=100,
                    shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [88]:
# Top-1 accuracy of `mofa` over the whole of `valset`.
mofa.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in valset:
        images = images.to(device)
        labels = labels.to(device)
        outputs = mofa(images)
        predicted = torch.max(outputs.data, 1)[1]
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(correct / total)

0.7051


In [86]:
def sort_channels(mofa):
    """
    Reorder the channels of ``mofa`` by importance without changing its output.

    For every layer except the very last one, the output channels are sorted
    by descending L1 norm of their convolution weights. The same permutation
    is applied to the layer's bias and batch-norm parameters (when present)
    and to the input channels of the following layer, so the network
    computes the same function. The last layer is left as is because the
    classifier that consumes it is not permuted here.

    The layer sequence is read from ``mofa.units`` itself, so units may have
    different depths. The model is modified in place and returned.
    """
    # Flat, execution-ordered list of all layers across units.
    layers = [layer for unit in mofa.units for layer in unit.layers]

    with torch.no_grad():
        for layer, next_layer in zip(layers[:-1], layers[1:]):
            conv = layer.conv2d
            importance = torch.sum(torch.abs(conv.weight.data), dim=(1, 2, 3))
            _, inds = torch.sort(importance, descending=True)

            conv.weight.data = conv.weight.data[inds, :, :, :]
            if conv.bias is not None:
                conv.bias.data = conv.bias.data[inds]

            # Batch norm is per-channel, so it must follow the permutation.
            if getattr(layer, 'bn', False):
                batchnorm = layer.batchnorm
                for attr in ('weight', 'bias', 'running_mean', 'running_var'):
                    tensor = getattr(batchnorm, attr, None)
                    if tensor is not None:
                        tensor.data = tensor.data[inds]

            # The consumer of these channels sees them in the new order.
            next_conv = next_layer.conv2d
            next_conv.weight.data = next_conv.weight.data[:, inds, :, :]

    return mofa

    
def see_channel_importances(mofa):
    """
    Print, for every layer of every unit, the per-output-channel importance:
    the sum of absolute convolution weights of each output channel.
    """
    for unit in mofa.units:
        for layer in unit.layers:
            channel_l1 = layer.conv2d.weight.data.abs().sum(dim=(1, 2, 3))
            print(channel_l1)

In [87]:
# Reorder the channels of `mofa` with sort_channels and keep the returned model.
mofa = sort_channels(mofa)

In [89]:
# Print the per-channel importance of every layer of `mofa`.
see_channel_importances(mofa)

tensor([52.2155, 49.9688, 48.7137, 47.7936, 43.0445, 41.2706, 41.1886, 39.0327,
        37.9445, 37.4216, 36.3574, 34.3126, 32.8328, 32.7713, 31.9814, 31.5286,
        30.7265, 29.8918, 29.8626, 29.7444, 29.4663, 29.3193, 29.0832, 28.8199,
        28.7962, 28.4378, 27.9413, 27.8418, 27.0945, 27.0935, 26.6975, 26.3430,
        26.1570, 26.1327, 26.1167, 26.0389, 25.8284, 25.3495, 25.2784, 25.2092,
        25.1281, 24.8885, 24.6520, 24.3132, 24.0459, 23.9084, 23.4436, 23.2229,
        23.1052, 22.8207, 22.7445, 22.6872, 22.5672, 22.5346, 22.3246, 22.2832,
        22.2546, 22.1327, 22.0078, 21.8441, 21.8223, 21.8218, 21.4960, 21.4390,
        21.3385, 20.9553, 20.9107, 20.8229, 20.7871, 20.6867, 20.5277, 20.5038,
        20.2273, 20.0780, 20.0688, 19.8455, 19.6425, 19.6403, 19.6170, 19.5640,
        19.4884, 19.4858, 19.4147, 19.2329, 19.1154, 19.1020, 18.8236, 18.6821,
        18.6073, 18.5862, 18.5680, 18.1918, 18.1833, 18.1732, 17.9338, 17.9270,
        17.8077, 17.6107, 17.5430, 17.40

**Random Subnet Test**

In [44]:
# Draw one sub-network of `mofa` with randomly chosen kernel sizes.
subnet = sample_subnet_kernel(mofa)

In [45]:
# Show the configuration (including the sampled kernel sizes) of the subnet.
print(subnet.param_dict)

{'n_units': 5,
 'in_ch': 3,
 'out_class': 100,
 'width_list': [[256, 256, 256],
  [256, 256, 256],
  [256, 256, 256],
  [256, 256, 256],
  [256, 256, 256]],
 'kernel_list': [[1, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
 'bias_list': [[True, True, True],
  [True, True, True],
  [True, True, True],
  [True, True, True],
  [True, True, True]],
 'bn': False}

In [46]:
# Quick top-1 accuracy check of the sampled sub-network.
subnet = subnet.to(device)
subnet.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in valset:
        images = images.to(device)
        labels = labels.to(device)
        outputs = subnet(images)
        predicted = torch.max(outputs.data, 1)[1]
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Only the first batch of `valset` is evaluated.
        break
print(correct / total)

0.674
