In [9]:
### MOFA: Maxim's Once For All ###
### This version is using the Kernel Transition Matrix by default. ###

In [11]:
import copy
import random
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

from datasets.cifar100 import cifar100_get_datasets


# <ins>Classes</ins>

In [12]:
class Clamp(nn.Module):
    """
    Post-Activation Clamping Module
    Clamp the output to the given range (typically, [-128, +127])
    """
    def __init__(self, min_val=None, max_val=None):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, x):  # pylint: disable=arguments-differ
        """Forward prop"""
        return x.clamp(min=self.min_val, max=self.max_val)

In [13]:
class MOFAnet(nn.Module):
    # Maxim OFA Net
    def __init__(self, param_dict):
        super(MOFAnet, self).__init__()
        self.param_dict = param_dict
        self.in_ch = param_dict['in_ch']
        self.out_class = param_dict['out_class']
        self.n_units = param_dict['n_units']
        self.width_list = param_dict['width_list']
        self.kernel_list = param_dict['kernel_list']
        self.bias_list = param_dict['bias_list']
        self.bn = param_dict['bn']
        self.last_width = self.in_ch
        self.units = nn.ModuleList([])
        if 'depth_list' in param_dict:
            self.depth_list = param_dict['depth_list']
        else:
            self.depth_list = []
            for i in range(self.n_units):
                self.depth_list.append(len(self.kernel_list[i]))
        for i in range(self.n_units):
            self.units.append(Unit(self.depth_list[i], 
                                   self.kernel_list[i],
                                   self.width_list[i], 
                                   self.last_width, 
                                   self.bias_list[i],
                                   self.bn))
            self.last_width = self.width_list[i][-1]
        self.flatten = nn.Flatten()
        self.max_pool = nn.MaxPool2d(kernel_size=2)

        self.classifier = nn.Linear(512, self.out_class) 
    def forward(self, x):
        for i, unit in enumerate(self.units[:-1]):
            x = unit(x)
            x = self.max_pool(x)
        x = self.units[-1](x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [14]:
class Unit(nn.Module):
    def __init__(self, depth, kernel_list, 
                 width_list, init_width, bias_list, bn=True):
        super(Unit, self).__init__()
        self.depth = depth
        self.kernel_list = kernel_list
        self.width_list = width_list
        self.bias_list = bias_list
        self.bn = bn
        self._width_list = [init_width] + width_list
        self.layers = nn.ModuleList([])
        for i in range(depth):
            self.layers.append(
                FusedConv2dReLU(self._width_list[i],
                                self._width_list[i+1],
                                self.kernel_list[i],
                                self.bias_list[i],
                                self.bn))
    def forward(self, x):
        for i in range(self.depth):
            x = self.layers[i](x)
        return x

In [15]:
class FusedConv2dReLU(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            bias=True,
            bn=True):
        super(FusedConv2dReLU, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
            
        ktm_core = torch.zeros((9, 1))
        ktm_core[4] = 1
        self.ktm = nn.Parameter(data=ktm_core, requires_grad=True)
        
        if kernel_size == 1:
            self.pad = 0
        elif kernel_size == 3:
            self.pad = 1
        else:
            raise ValueError
        self.func = F.conv2d
        self.conv2d = nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, stride=1,
                                padding=1, bias=bias)
        self.bn = bn
        if self.bn:
            self.batchnorm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()
        self.clamp = Clamp(min_val=-1, max_val=1)
    def forward(self, x):        
        weight = self.conv2d.weight[:self.out_channels, :self.in_channels, :, :]
        bias = self.conv2d.bias[:self.out_channels]
        if self.kernel_size == 1:
            flattened_weight = weight.view(weight.size(0), weight.size(1), -1, 9)
            weight = flattened_weight.to(device) @ self.ktm.to(device)
                    
        x = self.func(x, weight, bias, self.conv2d.stride, self.pad)
        if self.bn:
            x = self.batchnorm(x)
            x = x / 4
        x = self.activation(x)
        x = self.clamp(x)
        return x

# <ins>Functions</ins>

## Batchnorm Related 

In [16]:
def make_bn_stats_false(model):
    for u_ind, unit in enumerate(model.units):
        for l_ind, layer in enumerate(unit.layers):
            model.units[u_ind].layers[l_ind].batchnorm.track_running_stats = False

    return model

def make_bn_stats_true(model):
    for u_ind, unit in enumerate(model.units):
        for l_ind, layer in enumerate(unit.layers):
            model.units[u_ind].layers[l_ind].batchnorm.track_running_stats = True

    return model

def fuse_bn(conv, bn):
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    beta = bn.weight
    beta = beta * 0.25
    gamma = bn.bias
    gamma = gamma * 0.25
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * beta + gamma
    fused_conv = nn.Conv2d(conv.in_channels,
                         conv.out_channels,
                         conv.kernel_size,
                         conv.stride,
                         conv.padding,
                         bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_bn_mofa(mofa_net):
    param_dict = copy.deepcopy(mofa_net.param_dict)
    if param_dict['bn'] == False:
        return mofa_net
    else:
        param_dict['bn'] = False
        fused_model = MOFAnet(param_dict)
        with torch.no_grad():
            fused_model.classifier.weight.copy_(mofa_net.classifier.weight)
            fused_model.classifier.bias.copy_(mofa_net.classifier.bias)
        for u_ind, unit in enumerate(mofa_net.units):
            for l_ind, layer in enumerate(unit.layers):
                fused_conv = fuse_bn(layer.conv2d, layer.batchnorm)
                fused_conv = fused_conv.to(device)
                with torch.no_grad():
                    fused_model.units[u_ind].layers[l_ind].conv2d.weight.copy_(fused_conv.weight)
                    fused_model.units[u_ind].layers[l_ind].conv2d.bias.copy_(fused_conv.bias)
        return fused_model

## Elastic Kernel - Depth -Width 

In [17]:
# def sample_subnet_kernel(mofa):
#     param_dict = copy.deepcopy(mofa.param_dict)
#     for u_ind, unit in enumerate(mofa.units):
#         for l_ind, layer in enumerate(unit.layers):
#             param_dict['kernel_list'][u_ind][l_ind] = random.choice([1, 3])
# #     param_dict['kernel_list'][0][0] = random.choice([1, 3])
# #     param_dict['kernel_list'][0][1] = random.choice([1, 3])
# #     param_dict['kernel_list'][0][2] = random.choice([1, 3])
#     param_dict['bn'] = False
#     subnet = MOFAnet(param_dict)
#     with torch.no_grad():
#         subnet.classifier.weight.copy_(mofa.classifier.weight)
#         subnet.classifier.bias.copy_(mofa.classifier.bias)
#         for u_ind, unit in enumerate(mofa.units):
#             for l_ind, layer in enumerate(unit.layers):
#                 subnet.units[u_ind].layers[l_ind].conv2d.weight.copy_(mofa.units[u_ind].layers[l_ind].conv2d.weight)
#                 subnet.units[u_ind].layers[l_ind].ktm.copy_(mofa.units[u_ind].layers[l_ind].ktm)
#                 if mofa.bias_list[u_ind][l_ind] is True:
#                     subnet.units[u_ind].layers[l_ind].conv2d.bias.copy_(mofa.units[u_ind].layers[l_ind].conv2d.bias)
#     return subnet


# def update_mofa_from_subnet_kernel(mofa, subnet):
#     with torch.no_grad():
#         mofa.classifier.weight.copy_(subnet.classifier.weight)
#         mofa.classifier.bias.copy_(subnet.classifier.bias)
#         for u_ind, unit in enumerate(mofa.units):
#             for l_ind, layer in enumerate(unit.layers):
#                 mofa.units[u_ind].layers[l_ind].conv2d.weight.copy_(subnet.units[u_ind].layers[l_ind].conv2d.weight)
#                 mofa.units[u_ind].layers[l_ind].ktm.copy_(subnet.units[u_ind].layers[l_ind].ktm)
#                 if mofa.bias_list[u_ind][l_ind] is True:
#                     mofa.units[u_ind].layers[l_ind].conv2d.bias.copy_(subnet.units[u_ind].layers[l_ind].conv2d.bias)
#     return mofa


# def sample_subnet_depth(mofa, sample_kernel=True):
#     param_dict = copy.deepcopy(mofa.param_dict)
#     depth_list = []
#     for u_ind in range(len(param_dict['width_list'])):
#         max_depth = len(param_dict['width_list'][u_ind])
#         min_depth = 1
#         depth_list.append(random.randint(min_depth, max_depth))
    
#     if sample_kernel:
#         subnet = sample_subnet_kernel(mofa) # This is confirmed by Ji
#     else:
#         subnet = copy.deepcopy(mofa)
    
#     param_dict = copy.deepcopy(subnet.param_dict)
#     param_dict['bn'] = False
#     param_dict['width_list'] = [lst[:depth_list[ind]] for ind, lst in enumerate(subnet.param_dict['width_list'])]
#     param_dict['kernel_list'] = [lst[:depth_list[ind]] for ind, lst in enumerate(subnet.param_dict['kernel_list'])]
#     param_dict['bias_list'] = [lst[:depth_list[ind]] for ind, lst in enumerate(subnet.param_dict['bias_list'])]
    
#     subnet2 = MOFAnet(param_dict)
    
#     with torch.no_grad():
#         subnet2.classifier.weight.copy_(subnet.classifier.weight)
#         subnet2.classifier.bias.copy_(subnet.classifier.bias)
#         for u_ind, unit in enumerate(subnet2.units):
#             for l_ind, layer in enumerate(unit.layers):
#                 subnet2.units[u_ind].layers[l_ind].conv2d.weight.copy_(subnet.units[u_ind].layers[l_ind].conv2d.weight)
#                 subnet2.units[u_ind].layers[l_ind].conv2d.bias.copy_(subnet.units[u_ind].layers[l_ind].conv2d.bias)
        
#     return subnet2, param_dict, depth_list

# def update_mofa_from_subnet_depth(mofa, subnet):
#     subnet_params = subnet.param_dict
#     mofa_params = mofa.param_dict
    
#     with torch.no_grad():
#         mofa.classifier.weight.copy_(subnet.classifier.weight)
#         mofa.classifier.bias.copy_(subnet.classifier.bias)
#         for u_ind, unit in enumerate(subnet.units):
#                 for l_ind, layer in enumerate(unit.layers):
#                     mofa.units[u_ind].layers[l_ind].conv2d.weight.copy_(subnet.units[u_ind].layers[l_ind].conv2d.weight)
#                     mofa.units[u_ind].layers[l_ind].ktm.copy_(subnet.units[u_ind].layers[l_ind].ktm)
#                     if mofa.bias_list[u_ind][l_ind] is True:
#                         mofa.units[u_ind].layers[l_ind].conv2d.bias.copy_(subnet.units[u_ind].layers[l_ind].conv2d.bias)
#     return mofa


# def sample_subnet_width(mofa, possible_width_list, sample_kernel_depth=True):
#     if sample_kernel_depth:
#         mofa = sample_subnet_depth(mofa, sample_kernel=True)
#     param_dict = mofa.param_dict
#     for u_ind, unit in enumerate(mofa.units):
#         for l_ind, layer in enumerate(unit.layers):
#             if u_ind == 0 and l_ind == 0:
#                     last_out_ch = layer.in_channels
#             layer.in_channels = last_out_ch        
#             if u_ind == (param_dict['n_units']-1) and l_ind == (len(param_dict['kernel_list'][u_ind])-1):
#                 continue
#             else:
#                 param_dict['width_list'][u_ind][l_ind] = random.choice(possible_width_list)
#                 layer.out_channels = param_dict['width_list'][u_ind][l_ind]
#                 last_out_ch = layer.out_channels
#                 # this comes from elastic depth
#                 if l_ind == param_dict['depth_list'][u_ind] - 1:
#                     break
#     return mofa
            
            
# def update_mofa_from_subnet_width(mofa, possible_width_list):
#     mofa = update_mofa_from_subnet_depth(_, mofa)
#     param_dict = mofa.param_dict
#     max_width = np.max(possible_width_list)
#     param_dict = mofa.param_dict
#     for u_ind, unit in enumerate(mofa.units):
#         for l_ind, layer in enumerate(unit.layers):
#             if u_ind == (param_dict['n_units']-1) and l_ind == (len(param_dict['kernel_list'][u_ind])-1):
#                 layer.in_channels = max_width
#             else:
#                 param_dict['width_list'][u_ind][l_ind] = max_width
#                 if u_ind == 0 and l_ind == 0:
#                     pass
#                 else:
#                     layer.in_channels = max_width
#                 layer.out_channels = max_width
#     return mofa
            


def sample_subnet_kernel(mofa):
    param_dict = mofa.param_dict
    for u_ind, unit in enumerate(mofa.units):
        for l_ind, layer in enumerate(unit.layers):
            param_dict['kernel_list'][u_ind][l_ind] = random.choice([1, 3])
            layer.kernel_size = param_dict['kernel_list'][u_ind][l_ind]
            if layer.kernel_size == 1:
                layer.pad = 0
    return mofa

def update_mofa_from_subnet_kernel(mofa):
    param_dict = mofa.param_dict
    param_dict['kernel_list'] = []
    for u_ind, unit in enumerate(mofa.units):
        param_dict['kernel_list'].append([])
        for l_ind, layer in enumerate(unit.layers):
            param_dict['kernel_list'][u_ind].append(layer.conv2d.kernel_size[0])
            layer.kernel_size = layer.conv2d.kernel_size
            layer.pad = layer.conv2d.padding
    return mofa

def sample_subnet_depth(mofa, sample_kernel=True):
    if sample_kernel:
        mofa = sample_subnet_kernel(mofa)
    param_dict = mofa.param_dict
    for u_ind, unit in enumerate(mofa.units):
        max_depth = param_dict['depth_list'][u_ind]
        min_depth = 1
        random_depth = random.randint(min_depth, max_depth)
        param_dict['depth_list'][u_ind] = random_depth
        param_dict['kernel_list'][u_ind] = param_dict['kernel_list'][u_ind][:random_depth]
        param_dict['width_list'][u_ind] = param_dict['width_list'][u_ind][:random_depth]
        param_dict['bias_list'][u_ind] = param_dict['bias_list'][u_ind][:random_depth]
        unit.depth = param_dict['depth_list'][u_ind]
    return mofa

def update_mofa_from_subnet_depth(mofa):
    mofa = update_mofa_from_subnet_kernel(mofa)
    param_dict = mofa.param_dict
    param_dict['width_list'] = []
    param_dict['bias_list'] = []
    for u_ind, unit in enumerate(mofa.units):
        max_depth = len(mofa.kernel_list[u_ind])
        param_dict['depth_list'][u_ind] = max_depth
        unit.depth = max_depth
        param_dict['width_list'].append([])
        param_dict['bias_list'].append([])
        for l_ind, layer in enumerate(unit.layers):
            param_dict['width_list'][u_ind].append(layer.conv2d.out_channels)
            param_dict['bias_list'][u_ind].append(layer.conv2d.bias is not None)
    return mofa


def sample_subnet_width(mofa, possible_width_list, sample_kernel_depth=True):
    if sample_kernel_depth:
        mofa = sample_subnet_depth(mofa, sample_kernel=True)
    param_dict = mofa.param_dict
    for u_ind, unit in enumerate(mofa.units):
        for l_ind in range(param_dict['depth_list'][u_ind]):
            layer = mofa.units[u_ind].layers[l_ind]
            if not(u_ind == 0 and l_ind == 0):
                    layer.in_channels = last_out_ch
            if u_ind == (param_dict['n_units'] - 1) and l_ind == (param_dict['depth_list'][u_ind] - 1):
                param_dict['width_list'][u_ind][l_ind] = mofa.units[-1].layers[-1].conv2d.out_channels
                layer.out_channels = mofa.units[-1].layers[-1].conv2d.out_channels
            else:
                possible_width_list = np.array(possible_width_list)
                valid_inds = possible_width_list <= mofa.units[u_ind].layers[l_ind].conv2d.out_channels
                possible_width_list = possible_width_list[valid_inds]
                random_width = random.choice(possible_width_list)
                param_dict['width_list'][u_ind][l_ind] = random_width
                layer.out_channels = random_width
                last_out_ch = layer.out_channels
    return mofa       


def update_mofa_from_subnet_width(mofa):
    mofa = update_mofa_from_subnet_depth(mofa)
    param_dict = mofa.param_dict
    for u_ind, unit in enumerate(mofa.units):
        for l_ind, layer in enumerate(unit.layers):
            param_dict['width_list'][u_ind][l_ind] = layer.conv2d.out_channels
            layer.out_channels = layer.conv2d.out_channels
            layer.in_channels = layer.conv2d.in_channels
    return mofa        
            
            
def sort_channels(mofa):
    for ind in range((mofa.n_units*mofa.param_dict['depth_list'][0])-1):
        u_ind = ind // n_layers
        l_ind = ind % n_layers
        layer = mofa.units[u_ind].layers[l_ind]
        
        importance = torch.sum(torch.abs(layer.conv2d.weight.data), dim=(1, 2, 3))
        _, inds = torch.sort(importance, descending=True)
        layer.conv2d.weight.data = layer.conv2d.weight.data[inds, :, :, :]
        layer.conv2d.bias.data = layer.conv2d.bias.data[inds]
        
        ind_new = ind + 1
        u_ind = ind_new // n_layers
        l_ind = ind_new % n_layers
        mofa.units[u_ind].layers[l_ind].conv2d.weight.data = mofa.units[u_ind].layers[l_ind].conv2d.weight.data[:, inds, :, :]
    return mofa


In [18]:
# import pprint
# pp = pprint.PrettyPrinter(indent=4)
# class bcolors:
#     HEADER = '\033[95m'
#     OKBLUE = '\033[94m'
#     OKCYAN = '\033[96m'
#     OKGREEN = '\033[92m'
#     WARNING = '\033[93m'
#     FAIL = '\033[91m'
#     ENDC = '\033[0m'
#     BOLD = '\033[1m'
#     UNDERLINE = '\033[4m'



# # model = torch.load('mofa_models/noclamp_mofa_acc71%.pth.tar')
# model = torch.load('mofa_models/EK_over4_clamp_mofa_acc64%_ep3000.pth.tar')
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")

# fused_model = fuse_bn_mofa(model)
# # fused_model.param_dict = param_dict
# mofa = copy.deepcopy(fused_model)
# fused_model = fused_model.to(device)
# mofa = mofa.to(device)

# possible_width_list = [256, 128, 64]

# pp.pprint(fused_model.param_dict)
# print('\n############SAMPLE WIDTH############\n')

# mofa = sample_subnet_width(mofa, possible_width_list)
# pp.pprint(mofa.param_dict)
# for u_ind, unit in enumerate(mofa.units):
#         for l_ind, layer in enumerate(unit.layers):
#             if l_ind < mofa.param_dict['depth_list'][u_ind]:
#                 print(f'{bcolors.FAIL}Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
#                       f' - Out channel: {layer.out_channels}{bcolors.ENDC}')
#             else:
#                 print(f'Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
#                       f' - Out channel: {layer.out_channels}')
#         print('-------------------')

# print('\n############UPDATE WIDTH############\n')

# mofa = update_mofa_from_subnet_width(mofa)
# pp.pprint(mofa.param_dict)
# for u_ind, unit in enumerate(mofa.units):
#         for l_ind, layer in enumerate(unit.layers):
#             if l_ind < mofa.param_dict['depth_list'][u_ind]:
#                 print(f'{bcolors.FAIL}Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
#                       f' - Out channel: {layer.out_channels}{bcolors.ENDC}')
#             else:
#                 print(f'Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
#                       f' - Out channel: {layer.out_channels}')
#         print('-------------------')

        
# # print('\n############SAMPLE DEPTH############\n')


# # mofa = sample_subnet_depth(mofa)
# # pp.pprint(mofa.param_dict)
# # for u_ind, unit in enumerate(mofa.units):
# #         for l_ind, layer in enumerate(unit.layers):
# #             if l_ind < mofa.param_dict['depth_list'][u_ind]:
# #                 print(f'{bcolors.FAIL}Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
# #                       f' - Out channel: {layer.out_channels}{bcolors.ENDC}')
# #             else:
# #                 print(f'Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
# #                       f' - Out channel: {layer.out_channels}')
# #         print('-------------------')
        

# # print('\n############UPDATE DEPTH############\n')

# # mofa = update_mofa_from_subnet_depth(mofa)
# # pp.pprint(mofa.param_dict)
# # for u_ind, unit in enumerate(mofa.units):
# #         for l_ind, layer in enumerate(unit.layers):
# #             if l_ind < mofa.param_dict['depth_list'][u_ind]:
# #                 print(f'{bcolors.FAIL}Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
# #                       f' - Out channel: {layer.out_channels}{bcolors.ENDC}')
# #             else:
# #                 print(f'Unit {u_ind} - Layer {l_ind} In channel: {layer.in_channels}'+ 
# #                       f' - Out channel: {layer.out_channels}')
# #         print('-------------------')


## Others

In [19]:
def cross_entropy_loss_with_soft_target(pred, soft_target):
    logsoftmax = nn.LogSoftmax()
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))

def see_channel_importances(mofa):
    for u_ind, unit in enumerate(mofa.units):
                for l_ind, layer in enumerate(unit.layers):
                    importance = torch.sum(torch.abs(layer.conv2d.weight.data), dim=(1, 2, 3))
                    print(importance)


# <ins> MOFA Training </ins>

In [20]:
n_units = 5
# n_layers = 3

# param_dict = {}
# param_dict['n_units']     = n_units
# param_dict['in_ch']       = 3
# param_dict['out_class']   = 100
# param_dict['depth_list']  = [3, 3, 3, 3, 3]
# param_dict['width_list']  = [[256]*n_layers for _ in range(n_units)]
# param_dict['kernel_list'] = [[3]*n_layers for _ in range(n_units)]
# param_dict['bias_list']   = [[True]*n_layers for _ in range(n_units)]
# param_dict['bn']          = True


param_dict = {}
param_dict['n_units']     = n_units
param_dict['in_ch']       = 3
param_dict['out_class']   = 100
param_dict['depth_list']  = [4, 3, 3, 3, 2]
param_dict['width_list']  = [[64, 64, 64, 64], [64, 64, 64], [128, 128, 128], [128, 128, 128], [128, 128]]
param_dict['kernel_list'] = [[3, 3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3]]
param_dict['bias_list']   = [[True, True, True, True], [True, True, True], [True, True, True], [True, True, True], [True, True]]
param_dict['bn']          = True



In [21]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

class Args():
    def __init__(self):
        super(Args, self).__init__()
        self.truncate_testset = False
        self.act_mode_8bit = False
        
args = Args()
train_dataset, test_dataset = cifar100_get_datasets(('data', args))

trainset = DataLoader(dataset=train_dataset,
                      batch_size=100,
                      shuffle=True,
                      num_workers=0)

valset = DataLoader(dataset=test_dataset,
                      batch_size=1000,
                      shuffle=False,
                      num_workers=0)

mofa = MOFAnet(param_dict)
mofa = mofa.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mofa.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=35, gamma=0.5)

best_val_accuracy = 0
max_epochs = 150
for epoch in range(max_epochs):
    t0 = time.time()
    mofa.train()
    for batch, labels in trainset:
        batch, labels = batch.to(device), labels.to(device)
        
        y_pred = mofa(batch)
        loss = criterion(y_pred, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
    
    print(f'Epoch {epoch+1}')
    print(f'\tTraining loss:{loss.item()}')
    t1 = time.time()
    print(f'\tTraining time:{t1-t0:.2f} s - {(t1-t0)/60:.2f} mins ')
    # Validation
    correct = 0
    total = 0
    mofa.eval()
    with torch.no_grad():
        for data in valset:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = mofa(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = correct / total
    if val_accuracy > best_val_accuracy:
        if epoch != 0:
            os.remove(f'mofa_models/arch_1/over4_clamp_mofa_acc{100*best_val_accuracy:.0f}%.pth.tar')
        torch.save(mofa, f'mofa_models/arch_1/over4_clamp_mofa_acc{100*val_accuracy:.0f}%.pth.tar')
        best_val_accuracy = val_accuracy
    print('\tAccuracy of the mofa on the test images: %d %%' % (
        100 * correct / total))
    

Files already downloaded and verified
Files already downloaded and verified
Epoch 1
	Training loss:3.943730115890503
	Training time:15.64 s - 0.26 mins 
	Accuracy of the mofa on the test images: 12 %
Epoch 2
	Training loss:3.50026798248291
	Training time:15.52 s - 0.26 mins 
	Accuracy of the mofa on the test images: 17 %
Epoch 3
	Training loss:3.1036670207977295
	Training time:15.56 s - 0.26 mins 
	Accuracy of the mofa on the test images: 23 %
Epoch 4
	Training loss:2.8937342166900635
	Training time:15.59 s - 0.26 mins 
	Accuracy of the mofa on the test images: 30 %
Epoch 5
	Training loss:2.6998748779296875
	Training time:15.61 s - 0.26 mins 
	Accuracy of the mofa on the test images: 32 %
Epoch 6
	Training loss:2.419478416442871
	Training time:15.62 s - 0.26 mins 
	Accuracy of the mofa on the test images: 37 %
Epoch 7
	Training loss:2.520132064819336
	Training time:15.63 s - 0.26 mins 
	Accuracy of the mofa on the test images: 38 %
Epoch 8
	Training loss:1.8047605752944946
	Training ti

# <ins> Elastic Kernel Training </ins>

In [61]:
# model = torch.load('mofa_models/noclamp_mofa_acc71%.pth.tar')
model = torch.load('mofa_models/over4_clamp_mofa_acc64%.pth.tar')
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [62]:
fused_model = fuse_bn_mofa(model)
# fused_model.param_dict = param_dict
mofa = copy.deepcopy(fused_model)
fused_model = fused_model.to(device)
mofa = mofa.to(device)
# mofa.param_dict = param_dict

In [None]:
kd_ratio = 0.5

class Args():
    def __init__(self):
        super(Args, self).__init__()
        self.truncate_testset = False
        self.act_mode_8bit = False
        
args = Args()
train_dataset, test_dataset = cifar100_get_datasets(('data', args))

trainset = DataLoader(dataset=train_dataset,
                      batch_size=100,
                      shuffle=True,
                      num_workers=0)

valset = DataLoader(dataset=test_dataset,
                      batch_size=1000,
                      shuffle=False,
                      num_workers=0)

criterion = torch.nn.CrossEntropyLoss()

best_val_accuracy = 0
max_epochs = 1401
# optimizer = torch.optim.SGD(mofa.parameters(), lr=1e-3, momentum=0.9)
for epoch in range(max_epochs):
    t0 = time.time()
    mofa.train()
    for batch, labels in trainset:
        batch, labels = batch.to(device), labels.to(device)
        
#         mofa = make_bn_stats_false(mofa)
        subnet = sample_subnet_kernel(mofa)
        subnet = subnet.to(device)
        optimizer = torch.optim.SGD(subnet.parameters(), lr=1e-3, momentum=0.9)
      
        y_pred = subnet(batch)
        
        if kd_ratio > 0:
            fused_model.train()
            with torch.no_grad():
                soft_logits = fused_model(batch).detach()
                soft_label = F.softmax(soft_logits, dim=1)
            kd_loss = cross_entropy_loss_with_soft_target(y_pred, soft_label)
            loss = kd_ratio * kd_loss + criterion(y_pred, labels)
        else:
            loss = criterion(y_pred, labels)     
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        mofa = update_mofa_from_subnet_kernel(mofa)
        
    print(f'Epoch {epoch+1}')
    print(f'\tTraining loss:{loss.item()}')
    t1 = time.time()
    print(f'\tTraining time:{t1-t0:.2f} s - {(t1-t0)/60:.2f} mins ')
    
    # Validation
    correct = 0
    total = 0
    mofa.eval()
    with torch.no_grad():
#         mofa = make_bn_stats_true(mofa)
#         mofa.train()
#         for data in valset:
#             images, labels = data
#             images, labels = images.to(device), labels.to(device)
#             outputs = mofa(images)
#         mofa.eval()
        for data in valset:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = mofa(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = correct / total
#     if val_accuracy > best_val_accuracy:
#         if epoch is not 0:
#             os.remove(f'mofa_models/ofa_acc{100*best_val_accuracy:.0f}%.pth.tar')
#         torch.save(mofa, f'mofa_models/ofa_acc{100*val_accuracy:.0f}%.pth.tar')
#         best_val_accuracy = val_accuracy
    print('\tAccuracy of the mofa on the test images: %d %%' % (
        100 * correct / total))
    print(f'\tFirst ktm: {mofa.units[0].layers[0].ktm[4].item()}')
    print(f'\tLast ktm: {mofa.units[-1].layers[-1].ktm[4].item()}')
    if epoch % 200 == 0:
        torch.save(mofa, f'mofa_models/EK_noclamp_mofa_acc{100*val_accuracy:.0f}%_ep{epoch}.pth.tar')

Files already downloaded and verified
Files already downloaded and verified


  return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))


Epoch 1
	Training loss:7.023433685302734
	Training time:15.34 s - 0.26 mins 
	Accuracy of the mofa on the test images: 17 %
	First ktm: 0.9999492168426514
	Last ktm: 0.9742915630340576
Epoch 2
	Training loss:6.945072174072266
	Training time:15.32 s - 0.26 mins 
	Accuracy of the mofa on the test images: 14 %
	First ktm: 0.9999555349349976
	Last ktm: 0.9574437737464905
Epoch 3
	Training loss:6.931275844573975
	Training time:15.34 s - 0.26 mins 
	Accuracy of the mofa on the test images: 13 %
	First ktm: 0.9999498128890991
	Last ktm: 0.9497938752174377
Epoch 4
	Training loss:6.909799575805664
	Training time:15.39 s - 0.26 mins 
	Accuracy of the mofa on the test images: 12 %
	First ktm: 0.9999447464942932
	Last ktm: 0.944175124168396
Epoch 5
	Training loss:6.87863826751709
	Training time:15.39 s - 0.26 mins 
	Accuracy of the mofa on the test images: 12 %
	First ktm: 0.9999633431434631
	Last ktm: 0.9412333965301514
Epoch 6
	Training loss:6.9111552238464355
	Training time:15.36 s - 0.26 mins 

# <ins> Elastic Depth Training </ins>

In [17]:
# model = torch.load('mofa_models/EK_noclamp_mofa_acc71%_ep1400.pth.tar')
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")

In [18]:
# fused_model = fuse_bn_mofa(model)
# fused_model.param_dict = param_dict
# mofa = copy.deepcopy(fused_model)
# fused_model = fused_model.to(device)
# mofa = mofa.to(device)

In [19]:
# kd_ratio = 0.5

# class Args():
#     def __init__(self):
#         super(Args, self).__init__()
#         self.truncate_testset = False
#         self.act_mode_8bit = False
        
# args = Args()
# train_dataset, test_dataset = cifar100_get_datasets(('data', args))

# trainset = DataLoader(dataset=train_dataset,
#                       batch_size=100,
#                       shuffle=True,
#                       num_workers=0)

# valset = DataLoader(dataset=test_dataset,
#                       batch_size=1000,
#                       shuffle=False,
#                       num_workers=0)

# criterion = torch.nn.CrossEntropyLoss()

# best_val_accuracy = 0
# max_epochs = 3001
# for epoch in range(max_epochs):
#     t0 = time.time()
#     mofa.train()
#     for batch, labels in trainset:
#         batch, labels = batch.to(device), labels.to(device)
        
# #         mofa = make_bn_stats_false(mofa)
#         subnet = sample_subnet_depth(mofa)
#         subnet = subnet.to(device)
#         optimizer = torch.optim.SGD(subnet.parameters(), lr=1e-3, momentum=0.9)
      
#         y_pred = subnet(batch)
        
#         if kd_ratio > 0:
#             fused_model.train()
#             with torch.no_grad():
#                 soft_logits = fused_model(batch).detach()
#                 soft_label = F.softmax(soft_logits, dim=1)
#             kd_loss = cross_entropy_loss_with_soft_target(y_pred, soft_label)
#             loss = kd_ratio * kd_loss + criterion(y_pred, labels)
#         else:
#             loss = criterion(y_pred, labels)     
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         mofa = update_mofa_from_subnet_depth(mofa)
        
#     print(f'Epoch {epoch+1}')
#     print(f'\tTraining loss:{loss.item()}')
#     t1 = time.time()
#     print(f'\tTraining time:{t1-t0:.2f} s - {(t1-t0)/60:.2f} mins ')
    
#     # Validation
#     correct = 0
#     total = 0
#     mofa.eval()
#     with torch.no_grad():
# #         mofa = make_bn_stats_true(mofa)
# #         mofa.train()
# #         for data in valset:
# #             images, labels = data
# #             images, labels = images.to(device), labels.to(device)
# #             outputs = mofa(images)
# #         mofa.eval()
#         for data in valset:
#             images, labels = data
#             images, labels = images.to(device), labels.to(device)
#             outputs = mofa(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#     val_accuracy = correct / total
# #     if val_accuracy > best_val_accuracy:
# #         if epoch is not 0:
# #             os.remove(f'mofa_models/ofa_acc{100*best_val_accuracy:.0f}%.pth.tar')
# #         torch.save(mofa, f'mofa_models/ofa_acc{100*val_accuracy:.0f}%.pth.tar')
# #         best_val_accuracy = val_accuracy
#     print('\tAccuracy of the mofa on the test images: %d %%' % (
#         100 * correct / total))
#     print(f'\tFirst ktm: {mofa.units[0].layers[0].ktm[4].item()}')
#     print(f'\tLast ktm: {mofa.units[4].layers[2].ktm[4].item()}')
#     if epoch % 200 == 0:
#         torch.save(mofa, f'mofa_models/ED_noclamp_mofa_acc{100*val_accuracy:.0f}%_ep{epoch}.pth.tar')

# <ins> Elastic Width Training </ins>

In [64]:
model = torch.load('mofa_models/ED_noclamp_mofa_acc64%_ep3000.pth.tar')
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [65]:
fused_model = model
fused_model.param_dict = param_dict
mofa = copy.deepcopy(fused_model)
fused_model = fused_model.to(device)
mofa = mofa.to(device)

In [None]:
kd_ratio = 0.5

class Args():
    def __init__(self):
        super(Args, self).__init__()
        self.truncate_testset = False
        self.act_mode_8bit = False
        
args = Args()
train_dataset, test_dataset = cifar100_get_datasets(('data', args))

trainset = DataLoader(dataset=train_dataset,
                      batch_size=100,
                      shuffle=True,
                      num_workers=0)

valset = DataLoader(dataset=test_dataset,
                      batch_size=1000,
                      shuffle=False,
                      num_workers=0)

criterion = torch.nn.CrossEntropyLoss()

best_val_accuracy = 0
max_epochs = 30000
for epoch in range(max_epochs):
    t0 = time.time()
    mofa.train()
    for batch, labels in trainset:
        batch, labels = batch.to(device), labels.to(device)
        
#         mofa = make_bn_stats_false(mofa)
        mofa = sort_channels(mofa)
        subnet = sample_subnet_width(mofa, [256, 128, 64], True)
        subnet = subnet.to(device)
        optimizer = torch.optim.SGD(subnet.parameters(), lr=1e-3, momentum=0.9)
      
        y_pred = subnet(batch)
        
        if kd_ratio > 0:
            fused_model.train()
            with torch.no_grad():
                soft_logits = fused_model(batch).detach()
                soft_label = F.softmax(soft_logits, dim=1)
            kd_loss = cross_entropy_loss_with_soft_target(y_pred, soft_label)
            loss = kd_ratio * kd_loss + criterion(y_pred, labels)
        else:
            loss = criterion(y_pred, labels)     
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        mofa = update_mofa_from_subnet_width(mofa)
        
    print(f'Epoch {epoch+1}')
    print(f'\tTraining loss:{loss.item()}')
    t1 = time.time()
    print(f'\tTraining time:{t1-t0:.2f} s - {(t1-t0)/60:.2f} mins ')
    
    # Validation
    correct = 0
    total = 0
    mofa.eval()
    with torch.no_grad():
#         mofa = make_bn_stats_true(mofa)
#         mofa.train()
#         for data in valset:
#             images, labels = data
#             images, labels = images.to(device), labels.to(device)
#             outputs = mofa(images)
#         mofa.eval()
        for data in valset:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = mofa(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = correct / total
#     if val_accuracy > best_val_accuracy:
#         if epoch is not 0:
#             os.remove(f'mofa_models/ofa_acc{100*best_val_accuracy:.0f}%.pth.tar')
#         torch.save(mofa, f'mofa_models/ofa_acc{100*val_accuracy:.0f}%.pth.tar')
#         best_val_accuracy = val_accuracy
    print('\tAccuracy of the mofa on the test images: %d %%' % (
        100 * correct / total))
    print(f'\tFirst ktm: {mofa.units[0].layers[0].ktm[4].item()}')
    print(f'\tLast ktm: {mofa.units[4].layers[2].ktm[4].item()}')
    if epoch % 200 == 0:
        torch.save(mofa, f'mofa_models/EW_noclamp_mofa_acc{100*val_accuracy:.0f}%_ep{epoch}.pth.tar')

Files already downloaded and verified
Files already downloaded and verified


  return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))


Epoch 1
	Training loss:7.553444862365723
	Training time:18.08 s - 0.30 mins 
	Accuracy of the mofa on the test images: 52 %
	First ktm: 1.7284648418426514
	Last ktm: 1.0730682611465454
Epoch 2
	Training loss:7.455113410949707
	Training time:18.27 s - 0.30 mins 
	Accuracy of the mofa on the test images: 48 %
	First ktm: 1.7239234447479248
	Last ktm: 1.0207839012145996
Epoch 3
	Training loss:7.234837055206299
	Training time:18.44 s - 0.31 mins 
	Accuracy of the mofa on the test images: 47 %
	First ktm: 1.721926212310791
	Last ktm: 0.9984390735626221
Epoch 4
	Training loss:7.382931709289551
	Training time:18.19 s - 0.30 mins 
	Accuracy of the mofa on the test images: 46 %
	First ktm: 1.719703197479248
	Last ktm: 0.982803463935852
Epoch 5
	Training loss:7.08633279800415
	Training time:18.49 s - 0.31 mins 
	Accuracy of the mofa on the test images: 45 %
	First ktm: 1.717010736465454
	Last ktm: 0.9671750068664551
Epoch 6
	Training loss:7.237740516662598
	Training time:18.19 s - 0.30 mins 
	Ac

In [46]:
mofa.param_dict

{'n_units': 5,
 'in_ch': 3,
 'out_class': 100,
 'depth_list': [3, 3, 2, 1, 1],
 'width_list': [[256, 64, 256], [128, 64, 64], [256, 64], [128], [256]],
 'kernel_list': [[1, 1, 3], [3, 1, 3], [1, 1], [3], [1]],
 'bias_list': [[True, True, True],
  [True, True, True],
  [True, True],
  [True],
  [True]],
 'bn': False}

In [60]:
# for i in range(256):
#     print(str(i)+':')
#     print(fused_model.units[0].layers[1].conv2d.weight.data[i, 0, :, :] == mofa.units[0].layers[1].conv2d.weight.data[i, 0, :, :])

# <ins> Testing <ins/>

**MOFA Test**

In [39]:
class Args():
    def __init__(self):
        super(Args, self).__init__()
        self.truncate_testset = False
        self.act_mode_8bit = False
        
args = Args()
train_dataset, test_dataset = cifar100_get_datasets(('data', args))

trainset = DataLoader(dataset=train_dataset,
                      batch_size=100,
                      shuffle=True,
                      num_workers=0)

valset = DataLoader(dataset=test_dataset,
                      batch_size=100,
                      shuffle=False,
                      num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [54]:
correct = 0
total = 0
mofa.eval()
with torch.no_grad():
    for data in valset:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = mofa(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(correct / total)

0.6428


**Random Subnet Test**

In [44]:
subnet = sample_subnet_kernel(mofa)

In [45]:
print(subnet.param_dict)

{'n_units': 5,
 'in_ch': 3,
 'out_class': 100,
 'width_list': [[256, 256, 256],
  [256, 256, 256],
  [256, 256, 256],
  [256, 256, 256],
  [256, 256, 256]],
 'kernel_list': [[1, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
 'bias_list': [[True, True, True],
  [True, True, True],
  [True, True, True],
  [True, True, True],
  [True, True, True]],
 'bn': False}

In [46]:
correct = 0
total = 0
subnet = subnet.to(device)
subnet.eval()
with torch.no_grad():
    for data in valset:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = subnet(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        break
#     val_accuracy = correct / total
#     if val_accuracy > best_val_accuracy:
#         if epoch is not 0:
#             os.remove(f'mofa_models/ofa_acc{100*best_val_accuracy:.0f}%.pth.tar')
#         torch.save(mofa, f'mofa_models/ofa_acc{100*val_accuracy:.0f}%.pth.tar')
#         best_val_accuracy = val_accuracy
print(correct / total)

0.674


In [32]:
mofa.param_dict

{'n_units': 5,
 'in_ch': 3,
 'out_class': 100,
 'depth_list': [3, 3, 3, 3, 3],
 'width_list': [[256, 256, 256],
  [256, 256, 256],
  [256, 256, 256],
  [256, 256, 256],
  [256, 256, 256]],
 'kernel_list': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
 'bias_list': [[True, True, True],
  [True, True, True],
  [True, True, True],
  [True, True, True],
  [True, True, True]],
 'bn': False}