In [1]:
import os
import time
import argparse
import shutil
import math

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from thop import profile
import matplotlib.pyplot as plt

from Utils import *
from Models import *

In [2]:
class HSwish_dev(nn.Module):
    """
    Derivative of the Swish activation, d/dx [x * sigmoid(x)].

    Despite the class name, this module does not evaluate H-Swish from
    'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. It returns

        sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))

    which is algebraically equal to sigmoid(x) + x * exp(-x) * sigmoid(x)**2,
    but avoids forming exp(-x): for strongly negative inputs exp(-x) overflows
    to inf while sigmoid(x)**2 underflows to 0, and inf * 0 yields NaN.

    Parameters:
    ----------
    inplace : bool
        Stored on the module for interface compatibility; the forward pass
        does not use it.
    """
    def __init__(self, inplace=True):
        super(HSwish_dev, self).__init__()
        self.inplace = inplace
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Evaluate the sigmoid once and reuse it for every term.
        s = self.sigmoid(x)
        # s * (1 - s) == exp(-x) * s**2, but stays finite for any finite x.
        return s + x * s * (1.0 - s)

In [2]:
# Build the network under inspection: 100 output classes, `add_se` enabled,
# 'HSwish' selected as the activation.
net = MicroNet(
    num_classes=100,
    add_se=True,
    Activation='HSwish',
)

In [11]:
# Print every direct child of `net` whose attribute name contains 'last'.
for child_name, child in net.named_children():
    if 'last' not in child_name:
        continue
    print(child)

Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [3]:
# Prepare the CIFAR100 loaders (batch size 128, 'FastAuto' augmentation).
train_loader, test_loader, num_classes = transform_data_set(
    'CIFAR100',
    batch_size=128,
    augmentation='FastAuto',
)

|| Prepare dataset with FastAutoAugmentation ||
Files already downloaded and verified


In [4]:
# Load the saved training checkpoint from disk.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. No map_location is given, so tensors are restored to the
# device they were saved from.
checkpoint = torch.load('./Checkpoint/micronet_ver6.t7')

In [5]:
# List the entries stored in the checkpoint dictionary.
checkpoint.keys()

dict_keys(['net_init', 'net1', 'net2', 'net3', 'net4', 'train_losses', 'train_accuracy', 'test_losses', 'test_accuracy', 'flops_params', 'Score'])

In [19]:
# Inspect recorded test-accuracy entries at indices 350 through 399.
checkpoint['test_accuracy'][350:400]

[7338.0,
 7395.999999999999,
 7320.0,
 7402.0,
 7348.0,
 7504.000000000001,
 7498.999999999999,
 7570.0,
 7517.0,
 7528.0,
 7506.0,
 7633.0,
 7580.0,
 7575.0,
 7589.0,
 7703.0,
 7653.0,
 7711.0,
 7600.0,
 7628.0,
 7653.0,
 7718.000000000001,
 7736.0,
 7698.999999999999,
 7761.0,
 7723.0,
 7889.0,
 7798.999999999999,
 7828.0,
 7834.0,
 7886.0,
 7869.0,
 7865.000000000001,
 7886.0,
 7895.999999999999,
 7900.0,
 7948.999999999999,
 7870.0,
 7941.0,
 7938.0,
 7929.000000000001,
 7945.0,
 7979.000000000001,
 7969.0,
 7950.0,
 7950.0,
 7969.0,
 7973.999999999999,
 7951.000000000001,
 7968.000000000001]

In [15]:
# Restore the weights stored under 'net4'. With strict = False, mismatched keys
# are reported in the returned value rather than raising, so check that both
# lists in the output are empty.
net.load_state_dict(checkpoint['net4'], strict = False)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [16]:
# Move the model to its device and cast it to half precision, then cast every
# BatchNorm2d layer back to float32.
net.to(net.device)
net.half()
for bn_layer in (m for m in net.modules() if isinstance(m, nn.BatchNorm2d)):
    bn_layer.float()

In [17]:
# Evaluate the half-precision model on the test loader.
# NOTE(review): eval_16bit is not defined in this notebook; presumably it comes
# from one of the star imports (Utils / Models) — confirm.
eval_16bit(net, test_loader)

100%|██████████| 200/200 [00:06<00:00, 28.57it/s]

Loss: 0.758 | Acc1: 79.220% | Acc5: 95.660%





(7922.0, 0.757601318359375)

In [6]:
class HSwish(nn.Module):
    """
    Hard-Swish activation, x * ReLU6(x + 3) / 6, as described in
    'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244.

    Parameters:
    ----------
    inplace : bool
        Forwarded to the internal ReLU6. The ReLU6 is applied to the temporary
        tensor `x + 3.0`, so the tensor passed to `forward` is not modified.
    """
    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        # Piecewise-linear gate in [0, 6]; divided by 6 after multiplying by x.
        gate = self.relu(x + 3.0)
        return x * gate / 6.0

In [9]:
def hswish_ops(m, x, y):
    """
    Operation-counting hook: charges 4 operations per input element.

    The (module, inputs, output) signature and the `total_ops` accumulator
    match the shape of a custom-op counter for `thop.profile`
    (presumably registered through its `custom_ops` argument — confirm).

    Parameters:
    ----------
    m : object exposing a `total_ops` tensor accumulator, updated in place.
    x : sequence whose first element is the input tensor.
    y : output of the module; unused.
    """
    num_inputs = int(x[0].numel())
    m.total_ops += torch.Tensor([num_inputs]) * 4

In [24]:
def sigmoid_ops(m, x, y):
    """
    Operation-counting hook: charges 3 operations per input element.

    The (module, inputs, output) signature and the `total_ops` accumulator
    match the shape of a custom-op counter for `thop.profile`
    (presumably registered through its `custom_ops` argument — confirm).

    Parameters:
    ----------
    m : object exposing a `total_ops` tensor accumulator, updated in place.
    x : sequence whose first element is the input tensor.
    y : output of the module; unused.
    """
    num_inputs = int(x[0].numel())
    m.total_ops += torch.Tensor([num_inputs]) * 3

In [9]:
def count_nonzero(net):
    """
    Count the parameter elements of `net` that are not exactly zero.

    Parameters:
    ----------
    net : torch.nn.Module
        Model whose `parameters()` are inspected.

    Returns:
    -------
    int
        Number of non-zero parameter elements; 0 when the model has no
        parameters at all.
    """
    # `p != 0` already ignores sign, so no abs()/flatten() copies are needed.
    # Summing Python ints keeps the result well-defined for a parameter-free
    # model, where there is no tensor to call .item() on.
    return sum(int((param != 0).sum().item()) for param in net.parameters())

In [10]:
# Total number of parameter elements in `net` (zeros included). numel() reads
# the element count directly instead of materialising flattened/abs copies.
num = sum(param.numel() for param in net.parameters())
num

342300

In [12]:
def micro_score(net, precision = 'FP16'):
    """
    Compute a size/compute score for `net` and print its components.

    The network is profiled on one random 1x3x32x32 half-precision input.
    The reported FLOPs and parameter count are scaled by the fraction of
    non-zero parameters, halved when `precision` is 'FP16', and combined as

        score = params / 36500000 + flops / 10490000000

    NOTE(review): the two divisors look like reference-model parameter and
    operation counts — confirm their origin.

    Parameters:
    ----------
    net : model exposing a `device` attribute; it must accept a half-precision
        input, since the probe tensor is always cast to half.
    precision : str
        'FP16' halves both terms; any other value leaves them unscaled.

    Returns:
    -------
    The computed score.
    """
    probe = torch.randn(1, 3, 32, 32).type(torch.HalfTensor).to(net.device)
    total_flops, total_params = profile(net, inputs=(probe, ))
    non_zero_ratio = count_nonzero(net) / total_params

    # Only non-zero weights are charged.
    flops = total_flops * non_zero_ratio
    params = total_params * non_zero_ratio

    # 16-bit values are charged at half cost.
    if precision == 'FP16':
        flops, params = flops / 2, params / 2

    score = params/36500000 + flops/10490000000
    print('Non zero ratio: {}'.format(non_zero_ratio))
    print('Score: {}, flops: {}, params: {}'.format(score, flops, params))
    return score

In [13]:
# Put the model in the mixed layout used for scoring: half precision
# everywhere except BatchNorm2d layers, which are cast back to float32.
net.to(net.device)
net.half()
for bn_layer in (m for m in net.modules() if isinstance(m, nn.BatchNorm2d)):
    bn_layer.float()
score = micro_score(net)

THOP has not implemented counting method for  CrossEntropyLoss()
Register FLOP counter for module Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Register FLOP counter for module BatchNorm2d(32, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
Register FLOP counter for module Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
Register FLOP counter for module BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
Register FLOP counter for module Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
Register FLOP counter for module BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
Register FLOP counter for module Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
Register FLOP counter for module BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
Register FLOP counter for module ReLU6(inplace)
Register 

In [14]:
# Display the value currently bound to `score`.
score

0.006158700702320247

In [7]:
# Standard normal density exp(-x**2 / 2) / sqrt(2*pi) evaluated at x = 3
# (so x**2 / 2 = 9/2).
np.exp(-9/2) / np.sqrt(2*np.pi)

0.0044318484119380075

3.141592653589793