In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd
from omni import *
from mixup import * 

In [2]:
"""
Creates a MobileNetV3 Model as defined in:
Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
Searching for MobileNetV3
arXiv preprint arXiv:1905.02244.
"""

import torch.nn as nn
import math


__all__ = ['mobilenetv3_large', 'mobilenetv3_small']


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, _make_divisible(channel // reduction, 8)),
                nn.ReLU(inplace=True),
                nn.Linear(_make_divisible(channel // reduction, 8), channel),
                h_sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size = 3, stride = stride,
                            padding = 1, bias=False),
        nn.BatchNorm2d(oup),
        h_swish()
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size = 1, stride = 1,
                            padding = 0, bias=False),
        nn.BatchNorm2d(oup),
        h_swish()
    )


def od_conv_1x1(inp, oup, kernel_num = 1):
    return nn.Sequential(
        ODConvBN(inp, oup, kernel_size = 1, 
                    stride = 1, kernel_num = kernel_num),
        h_swish()
    )

class InvertedResidual(nn.Module):
    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs, kernel_num = 4):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup

        if inp == hidden_dim:
            self.conv = nn.Sequential(
                # dw
                ODConvBN(hidden_dim, hidden_dim, kernel_size, stride, groups=hidden_dim, kernel_num = 1),
                h_swish() if use_hs else nn.ReLU(inplace=True),
                # Squeeze-and-Excite
                SELayer(hidden_dim) if use_se else nn.Identity(),
                # pw-linear
                ODConv2d(hidden_dim, oup, 1, 1, kernel_num = kernel_num),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                ODConvBN(inp, hidden_dim, kernel_size = 1, stride = 1, kernel_num = 1),
                h_swish() if use_hs else nn.ReLU(inplace=True),
                # dw
                ODConvBN(hidden_dim, hidden_dim, kernel_size, stride, groups=hidden_dim, kernel_num = 1),
                # Squeeze-and-Excite
                SELayer(hidden_dim) if use_se else nn.Identity(),
                h_swish() if use_hs else nn.ReLU(inplace=True),
                # pw-linear
                ODConv2d(hidden_dim, oup, 1, 1, kernel_num = kernel_num),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV3(nn.Module):
    def __init__(self, cfgs, mode, num_classes=1000, width_mult=1.):
        super(MobileNetV3, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs
        assert mode in ['large', 'small']

        # building first layer
        input_channel = _make_divisible(16 * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        # building inverted residual blocks
        block = InvertedResidual
        for k, t, c, use_se, use_hs, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(input_channel * t, 8)
            layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs, kernel_num = 4))
            input_channel = output_channel
        self.features = nn.Sequential(*layers)

        # building last several layers
        self.conv = od_conv_1x1(input_channel, exp_size, kernel_num = 4)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        output_channel = {'large': 1280, 'small': 1024}

        output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode]
        self.classifier = nn.Sequential(
            nn.Linear(exp_size, output_channel),
            h_swish(),
            nn.Dropout(0.2),
            nn.Linear(output_channel, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def net_update_temperature(self, temperature):
        for modules in self.modules():
            if hasattr(modules, "update_temperature"):
                modules.update_temperature(temperature)

    def display_temperature(self):
        for modules in self.modules():
            if hasattr(modules, "get_temperature"):
                return modules.get_temperature()

def mobilenetv3_large(**kwargs):
    """
    Constructs a MobileNetV3-Large model
    """
    cfgs = [
        # k, t,   c,  SE, HS, s
        [3,   1,  16, 0, 0, 1],
        [3,   4,  24, 0, 0, 2],
        [3,   3,  24, 0, 0, 1],
        [5,   3,  40, 1, 0, 2],
        [5,   3,  40, 1, 0, 1],
        [5,   3,  40, 1, 0, 1],
        [3,   6,  80, 0, 1, 2],
        [3, 2.5,  80, 0, 1, 1],
        [3, 2.3,  80, 0, 1, 1],
        [3, 2.3,  80, 0, 1, 1],
        [3,   6, 112, 1, 1, 1],
        [3,   6, 112, 1, 1, 1],
        [5,   6, 160, 1, 1, 2],
        [5,   6, 160, 1, 1, 1],
        [5,   6, 160, 1, 1, 1]
    ]
    return MobileNetV3(cfgs, mode='large', **kwargs)


def mobilenetv3_small(**kwargs):
    """
    Constructs a MobileNetV3-Small model
    """
    cfgs = [
        # k,   t,  c, SE, HS, s
        [3,    1,  16, 1, 0, 2],
        [3,  4.5,  24, 0, 0, 2],
        [3, 3.67,  24, 0, 0, 1],
        [5,    4,  40, 1, 1, 2],
        [5,    6,  40, 1, 1, 1],
        [5,    6,  40, 1, 1, 1],
        [5,    3,  48, 1, 1, 1],
        [5,    3,  48, 1, 1, 1],
        [5,    6,  96, 1, 1, 2],
        [5,    6,  96, 1, 1, 1],
        [5,    6,  96, 1, 1, 1],
    ]

    return MobileNetV3(cfgs, mode='small', **kwargs)


In [3]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Training

In [4]:
import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import numpy as np

In [5]:
def load_data(data_dir, download = True):

  transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

  train_data = datasets.CIFAR10(
      root = data_dir, train = True,
      download = download, transform = transform
  )

  test_data = datasets.CIFAR10(
      root = data_dir, train = False,
      download = download, transform = transform
  )

  return (train_data, test_data)

train_data, test_data = load_data('./data/cifar10')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
batch_size = 64
num_workers = 4

train_loader = DataLoader(train_data, batch_size = batch_size,
                          shuffle = True, num_workers = num_workers)
test_loader = DataLoader(test_data, batch_size = batch_size,
                         shuffle = True, num_workers = num_workers)

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
import logging
import os
from tqdm.notebook import tqdm

def check_logging_directory(path):
  parent_directory = os.path.dirname(path)
  if not os.path.exists(parent_directory):
    os.makedirs(parent_directory)
    print("Create new directory")

logging_path = './logging/mixup_omni_mobilenetv3_normallarge_cifar10.log'
check_logging_directory(logging_path)

logging.basicConfig(filename=logging_path, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


In [40]:
from gradient_descent_the_ultimate_optimizer import gdtuo

criterion = nn.CrossEntropyLoss().to(device)

num_epochs = 5

mobile_v3 = mobilenetv3_large(num_classes = 10).to(device)
optim = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5))
mw = gdtuo.ModuleWrapper(mobile_v3, optimizer=optim)
mw.initialize()

print(f"The number of parameters: {count_parameters(mobile_v3)}")


The number of parameters: 7579763


In [41]:
from thop import profile

input_size = (1, 3, 224, 224)

flops, params = profile(mobile_v3, inputs=(torch.randn(*input_size).to(device),))
print(f"FLOPs: {flops / 1e9} billion")
print(f"Parameters: {params / 1e6} million")

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU6'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
FLOPs: 0.031794834 billion
Parameters: 3.810971 million


In [42]:
def get_temperature(iteration, epoch, iter_per_epoch, temp_epoch=10, temp_init=30.0):
    total_temp_iter = iter_per_epoch * temp_epoch
    current_iter = iteration + epoch * iter_per_epoch
    # print(current_iter)
    temperature = 1.0 + max(0, (temp_init - 1.0) * ((total_temp_iter - current_iter) / total_temp_iter))
    return temperature

In [43]:
# Huấn luyện mô hình
train_loss, val_loss = [], []
train_acc, val_acc = [], []

epoch_bar = tqdm(desc = 'Epoch',
                 total = num_epochs, position = 1)
train_bar = tqdm(desc = 'Training', total = len(train_loader),
                 position = 1, leave = True)
val_bar = tqdm(desc = 'Validation', total = len(test_loader),
               position = 1, leave = True)

print("🚀 Training MobileNetV3 - Omni Dimensional Dynamic Convolution 🚀")
logging.info("🚀 Training MobileNetV3 - Omni Dimensional Dynamic Convolution 🚀")

num_epochs = 10
for epoch in range(num_epochs):

    epoch_bar.set_description(f'Epoch {epoch + 1}/{num_epochs}')

    running_loss = 0.0
    running_acc = 0.0
    total_loss = 0.0
    total_acc = 0.0

    total = 0
    for i, (X, y) in enumerate(train_loader):

        if epoch < 50:
            temp = get_temperature(i + 1, epoch, len(train_loader), temp_epoch = 50, temp_init = 48.79)
            mw.module.net_update_temperature(temp)
            
        mw.begin()
        mw.zero_grad()
        X, y = X.to(device), y.to(device)
        X, y_origin, y_sampled, lam = mixup_data(X, y, device, alpha = 0.4)
        
        # Forward pass
        output = mw.forward(X)
        loss = mixup_criterion(criterion, output, y_origin, y_sampled, lam)
        
        # Backward pass
        loss.backward(create_graph = True)
        mw.step()
    
        loss_t = loss.item()
        running_loss += (loss_t - running_loss) / (i + 1)
        total_loss += loss_t
        
        # Calculating the accuracy
        _, predicted = torch.max(output.data, 1)
        n_correct = (lam * predicted.eq(y_origin.data).cpu().sum().float()
                    + (1 - lam) * predicted.eq(y_sampled.data).cpu().sum().float())

        acc_t = n_correct / len(predicted) * 100
        running_acc += (acc_t - running_acc) / (i + 1)

        total_acc += n_correct
        total += y.shape[0]


        train_bar.set_postfix(loss = running_loss,
                              acc = f"{running_acc:.2f}%",
                              epoch = epoch + 1)
        train_bar.update()

    # mw.begin()
    current_loss = total_loss / len(train_loader)
    current_acc = total_acc / total * 100
    train_loss.append(current_loss)
    train_acc.append(current_acc)

    print("========================================")
    print("\033[1;34m" + f"Epoch {epoch + 1}/{num_epochs}" + "\033[0m")
    print(f"Train Loss: {current_loss:.2f}\t|\tTrain Acc: {current_acc:.2f}%")

    logging.info("========================================")
    logging.info("\033[1;34m" + f"Epoch {epoch + 1}/{num_epochs}" + "\033[0m")
    logging.info(f"Train Loss: {current_loss:.2f}  -   Train Acc: {current_acc:.2f}%")


    # Eval trên valid set
    running_loss = 0.0
    running_acc = 0.0
    total_loss = 0.0
    total_acc = 0.0

    total = 0
    # mw.end()
    # mw.eval()
    with torch.no_grad():
        for i, (X, y) in enumerate(test_loader):

            X, y = X.to(device), y.to(device)
            # Forward pass
            output = mw.forward(X)

            # Calculate Loss
            loss = criterion(output, y)
            loss_t = loss.item()
            running_loss += (loss_t - running_loss) / (i + 1)
            total_loss += loss_t

            # Calculate Accuracies
            _, predicted = torch.max(output.data, 1)
            n_correct = (predicted == y).sum().item()
            acc_t = n_correct / len(predicted) * 100
            running_acc += (acc_t - running_acc) / (i + 1)
            total_acc += n_correct

            total += y.shape[0]

            val_bar.set_postfix(loss = running_loss,
                                acc = f"{running_acc:.2f}%",
                                epoch = epoch + 1)
            val_bar.update()

    current_loss = total_loss / len(test_loader)
    current_acc = total_acc / total * 100

    val_loss.append(current_loss)
    val_acc.append(current_acc)

    print(f"Val Loss: {current_loss:.2f}\t|\tVal Acc: {current_acc:.2f}%")
    logging.info(f"Val Loss: {current_loss:.2f}  -  Val Acc: {current_acc:.2f}%")

    train_bar.n = 0
    val_bar.n = 0
    epoch_bar.update()

    if epoch < 50:
        temperature = mw.module.display_temperature()
        print(f"The current temperature is: {temperature}")

print("========================================")
print("Training Completed! 😀")
logging.info("========================================")
logging.info("Training Completed! 😀")

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/782 [00:00<?, ?it/s]

Validation:   0%|          | 0/157 [00:00<?, ?it/s]

🚀 Training MobileNetV3 - Omni Dimensional Dynamic Convolution 🚀


OutOfMemoryError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 23.65 GiB total capacity; 21.99 GiB already allocated; 13.44 MiB free; 22.10 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
print(mw.module.display_temperature())

In [32]:
# save_model = './saved_model/'

# if not os.path.exists(save_model):
#     os.makedirs(save_model, exist_ok = True)
#     print("Creating new directory")

# model_path = os.path.join(save_model, 'omnimobilenetV3_CIFAR10kokLarge.pth')
# torch.save(mw.module.state_dict(), model_path)


In [37]:

model_path = './saved_model/omnimobilenetV3_CIFAR10Large.pth'

abc_model = mobilenetv3_large(num_classes = 10).to(device)

optim = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5))
bs = gdtuo.ModuleWrapper(abc_model, optimizer=optim)
bs.initialize()


bs.module.load_state_dict(torch.load(model_path))
bs.module

RuntimeError: Error(s) in loading state_dict for MobileNetV3:
	Unexpected key(s) in state_dict: "total_ops", "total_params", "features.0.2.total_ops", "features.0.2.total_params", "features.0.2.sigmoid.total_ops", "features.0.2.sigmoid.total_params", "features.1.total_ops", "features.1.total_params", "features.1.conv.0.total_ops", "features.1.conv.0.total_params", "features.1.conv.0.0.total_ops", "features.1.conv.0.0.total_params", "features.1.conv.0.0.attention.total_ops", "features.1.conv.0.0.attention.total_params", "features.1.conv.2.total_ops", "features.1.conv.2.total_params", "features.1.conv.3.total_ops", "features.1.conv.3.total_params", "features.1.conv.3.attention.total_ops", "features.1.conv.3.attention.total_params", "features.2.total_ops", "features.2.total_params", "features.2.conv.0.total_ops", "features.2.conv.0.total_params", "features.2.conv.0.0.total_ops", "features.2.conv.0.0.total_params", "features.2.conv.0.0.attention.total_ops", "features.2.conv.0.0.attention.total_params", "features.2.conv.2.total_ops", "features.2.conv.2.total_params", "features.2.conv.2.0.total_ops", "features.2.conv.2.0.total_params", "features.2.conv.2.0.attention.total_ops", "features.2.conv.2.0.attention.total_params", "features.2.conv.3.total_ops", "features.2.conv.3.total_params", "features.2.conv.5.total_ops", "features.2.conv.5.total_params", "features.2.conv.5.attention.total_ops", "features.2.conv.5.attention.total_params", "features.3.total_ops", "features.3.total_params", "features.3.conv.0.total_ops", "features.3.conv.0.total_params", "features.3.conv.0.0.total_ops", "features.3.conv.0.0.total_params", "features.3.conv.0.0.attention.total_ops", "features.3.conv.0.0.attention.total_params", "features.3.conv.2.total_ops", "features.3.conv.2.total_params", "features.3.conv.2.0.total_ops", "features.3.conv.2.0.total_params", "features.3.conv.2.0.attention.total_ops", "features.3.conv.2.0.attention.total_params", "features.3.conv.3.total_ops", "features.3.conv.3.total_params", "features.3.conv.5.total_ops", "features.3.conv.5.total_params", "features.3.conv.5.attention.total_ops", "features.3.conv.5.attention.total_params", "features.4.total_ops", "features.4.total_params", "features.4.conv.0.total_ops", "features.4.conv.0.total_params", "features.4.conv.0.0.total_ops", "features.4.conv.0.0.total_params", "features.4.conv.0.0.attention.total_ops", "features.4.conv.0.0.attention.total_params", "features.4.conv.2.total_ops", "features.4.conv.2.total_params", "features.4.conv.2.0.total_ops", "features.4.conv.2.0.total_params", "features.4.conv.2.0.attention.total_ops", "features.4.conv.2.0.attention.total_params", "features.4.conv.3.total_ops", "features.4.conv.3.total_params", "features.4.conv.3.fc.3.total_ops", "features.4.conv.3.fc.3.total_params", "features.4.conv.5.total_ops", "features.4.conv.5.total_params", "features.4.conv.5.attention.total_ops", "features.4.conv.5.attention.total_params", "features.5.total_ops", "features.5.total_params", "features.5.conv.0.total_ops", "features.5.conv.0.total_params", "features.5.conv.0.0.total_ops", "features.5.conv.0.0.total_params", "features.5.conv.0.0.attention.total_ops", "features.5.conv.0.0.attention.total_params", "features.5.conv.2.total_ops", "features.5.conv.2.total_params", "features.5.conv.2.0.total_ops", "features.5.conv.2.0.total_params", "features.5.conv.2.0.attention.total_ops", "features.5.conv.2.0.attention.total_params", "features.5.conv.3.total_ops", "features.5.conv.3.total_params", "features.5.conv.3.fc.3.total_ops", "features.5.conv.3.fc.3.total_params", "features.5.conv.5.total_ops", "features.5.conv.5.total_params", "features.5.conv.5.attention.total_ops", "features.5.conv.5.attention.total_params", "features.6.total_ops", "features.6.total_params", "features.6.conv.0.total_ops", "features.6.conv.0.total_params", "features.6.conv.0.0.total_ops", "features.6.conv.0.0.total_params", "features.6.conv.0.0.attention.total_ops", "features.6.conv.0.0.attention.total_params", "features.6.conv.2.total_ops", "features.6.conv.2.total_params", "features.6.conv.2.0.total_ops", "features.6.conv.2.0.total_params", "features.6.conv.2.0.attention.total_ops", "features.6.conv.2.0.attention.total_params", "features.6.conv.3.total_ops", "features.6.conv.3.total_params", "features.6.conv.3.fc.3.total_ops", "features.6.conv.3.fc.3.total_params", "features.6.conv.5.total_ops", "features.6.conv.5.total_params", "features.6.conv.5.attention.total_ops", "features.6.conv.5.attention.total_params", "features.7.total_ops", "features.7.total_params", "features.7.conv.0.total_ops", "features.7.conv.0.total_params", "features.7.conv.0.0.total_ops", "features.7.conv.0.0.total_params", "features.7.conv.0.0.attention.total_ops", "features.7.conv.0.0.attention.total_params", "features.7.conv.1.total_ops", "features.7.conv.1.total_params", "features.7.conv.1.sigmoid.total_ops", "features.7.conv.1.sigmoid.total_params", "features.7.conv.2.total_ops", "features.7.conv.2.total_params", "features.7.conv.2.0.total_ops", "features.7.conv.2.0.total_params", "features.7.conv.2.0.attention.total_ops", "features.7.conv.2.0.attention.total_params", "features.7.conv.3.total_ops", "features.7.conv.3.total_params", "features.7.conv.4.total_ops", "features.7.conv.4.total_params", "features.7.conv.4.sigmoid.total_ops", "features.7.conv.4.sigmoid.total_params", "features.7.conv.5.total_ops", "features.7.conv.5.total_params", "features.7.conv.5.attention.total_ops", "features.7.conv.5.attention.total_params", "features.8.total_ops", "features.8.total_params", "features.8.conv.0.total_ops", "features.8.conv.0.total_params", "features.8.conv.0.0.total_ops", "features.8.conv.0.0.total_params", "features.8.conv.0.0.attention.total_ops", "features.8.conv.0.0.attention.total_params", "features.8.conv.1.total_ops", "features.8.conv.1.total_params", "features.8.conv.1.sigmoid.total_ops", "features.8.conv.1.sigmoid.total_params", "features.8.conv.2.total_ops", "features.8.conv.2.total_params", "features.8.conv.2.0.total_ops", "features.8.conv.2.0.total_params", "features.8.conv.2.0.attention.total_ops", "features.8.conv.2.0.attention.total_params", "features.8.conv.3.total_ops", "features.8.conv.3.total_params", "features.8.conv.4.total_ops", "features.8.conv.4.total_params", "features.8.conv.4.sigmoid.total_ops", "features.8.conv.4.sigmoid.total_params", "features.8.conv.5.total_ops", "features.8.conv.5.total_params", "features.8.conv.5.attention.total_ops", "features.8.conv.5.attention.total_params", "features.9.total_ops", "features.9.total_params", "features.9.conv.0.total_ops", "features.9.conv.0.total_params", "features.9.conv.0.0.total_ops", "features.9.conv.0.0.total_params", "features.9.conv.0.0.attention.total_ops", "features.9.conv.0.0.attention.total_params", "features.9.conv.1.total_ops", "features.9.conv.1.total_params", "features.9.conv.1.sigmoid.total_ops", "features.9.conv.1.sigmoid.total_params", "features.9.conv.2.total_ops", "features.9.conv.2.total_params", "features.9.conv.2.0.total_ops", "features.9.conv.2.0.total_params", "features.9.conv.2.0.attention.total_ops", "features.9.conv.2.0.attention.total_params", "features.9.conv.3.total_ops", "features.9.conv.3.total_params", "features.9.conv.4.total_ops", "features.9.conv.4.total_params", "features.9.conv.4.sigmoid.total_ops", "features.9.conv.4.sigmoid.total_params", "features.9.conv.5.total_ops", "features.9.conv.5.total_params", "features.9.conv.5.attention.total_ops", "features.9.conv.5.attention.total_params", "features.10.total_ops", "features.10.total_params", "features.10.conv.0.total_ops", "features.10.conv.0.total_params", "features.10.conv.0.0.total_ops", "features.10.conv.0.0.total_params", "features.10.conv.0.0.attention.total_ops", "features.10.conv.0.0.attention.total_params", "features.10.conv.1.total_ops", "features.10.conv.1.total_params", "features.10.conv.1.sigmoid.total_ops", "features.10.conv.1.sigmoid.total_params", "features.10.conv.2.total_ops", "features.10.conv.2.total_params", "features.10.conv.2.0.total_ops", "features.10.conv.2.0.total_params", "features.10.conv.2.0.attention.total_ops", "features.10.conv.2.0.attention.total_params", "features.10.conv.3.total_ops", "features.10.conv.3.total_params", "features.10.conv.4.total_ops", "features.10.conv.4.total_params", "features.10.conv.4.sigmoid.total_ops", "features.10.conv.4.sigmoid.total_params", "features.10.conv.5.total_ops", "features.10.conv.5.total_params", "features.10.conv.5.attention.total_ops", "features.10.conv.5.attention.total_params", "features.11.total_ops", "features.11.total_params", "features.11.conv.0.total_ops", "features.11.conv.0.total_params", "features.11.conv.0.0.total_ops", "features.11.conv.0.0.total_params", "features.11.conv.0.0.attention.total_ops", "features.11.conv.0.0.attention.total_params", "features.11.conv.1.total_ops", "features.11.conv.1.total_params", "features.11.conv.1.sigmoid.total_ops", "features.11.conv.1.sigmoid.total_params", "features.11.conv.2.total_ops", "features.11.conv.2.total_params", "features.11.conv.2.0.total_ops", "features.11.conv.2.0.total_params", "features.11.conv.2.0.attention.total_ops", "features.11.conv.2.0.attention.total_params", "features.11.conv.3.total_ops", "features.11.conv.3.total_params", "features.11.conv.3.fc.3.total_ops", "features.11.conv.3.fc.3.total_params", "features.11.conv.4.total_ops", "features.11.conv.4.total_params", "features.11.conv.4.sigmoid.total_ops", "features.11.conv.4.sigmoid.total_params", "features.11.conv.5.total_ops", "features.11.conv.5.total_params", "features.11.conv.5.attention.total_ops", "features.11.conv.5.attention.total_params", "features.12.total_ops", "features.12.total_params", "features.12.conv.0.total_ops", "features.12.conv.0.total_params", "features.12.conv.0.0.total_ops", "features.12.conv.0.0.total_params", "features.12.conv.0.0.attention.total_ops", "features.12.conv.0.0.attention.total_params", "features.12.conv.1.total_ops", "features.12.conv.1.total_params", "features.12.conv.1.sigmoid.total_ops", "features.12.conv.1.sigmoid.total_params", "features.12.conv.2.total_ops", "features.12.conv.2.total_params", "features.12.conv.2.0.total_ops", "features.12.conv.2.0.total_params", "features.12.conv.2.0.attention.total_ops", "features.12.conv.2.0.attention.total_params", "features.12.conv.3.total_ops", "features.12.conv.3.total_params", "features.12.conv.3.fc.3.total_ops", "features.12.conv.3.fc.3.total_params", "features.12.conv.4.total_ops", "features.12.conv.4.total_params", "features.12.conv.4.sigmoid.total_ops", "features.12.conv.4.sigmoid.total_params", "features.12.conv.5.total_ops", "features.12.conv.5.total_params", "features.12.conv.5.attention.total_ops", "features.12.conv.5.attention.total_params", "features.13.total_ops", "features.13.total_params", "features.13.conv.0.total_ops", "features.13.conv.0.total_params", "features.13.conv.0.0.total_ops", "features.13.conv.0.0.total_params", "features.13.conv.0.0.attention.total_ops", "features.13.conv.0.0.attention.total_params", "features.13.conv.1.total_ops", "features.13.conv.1.total_params", "features.13.conv.1.sigmoid.total_ops", "features.13.conv.1.sigmoid.total_params", "features.13.conv.2.total_ops", "features.13.conv.2.total_params", "features.13.conv.2.0.total_ops", "features.13.conv.2.0.total_params", "features.13.conv.2.0.attention.total_ops", "features.13.conv.2.0.attention.total_params", "features.13.conv.3.total_ops", "features.13.conv.3.total_params", "features.13.conv.3.fc.3.total_ops", "features.13.conv.3.fc.3.total_params", "features.13.conv.4.total_ops", "features.13.conv.4.total_params", "features.13.conv.4.sigmoid.total_ops", "features.13.conv.4.sigmoid.total_params", "features.13.conv.5.total_ops", "features.13.conv.5.total_params", "features.13.conv.5.attention.total_ops", "features.13.conv.5.attention.total_params", "features.14.total_ops", "features.14.total_params", "features.14.conv.0.total_ops", "features.14.conv.0.total_params", "features.14.conv.0.0.total_ops", "features.14.conv.0.0.total_params", "features.14.conv.0.0.attention.total_ops", "features.14.conv.0.0.attention.total_params", "features.14.conv.1.total_ops", "features.14.conv.1.total_params", "features.14.conv.1.sigmoid.total_ops", "features.14.conv.1.sigmoid.total_params", "features.14.conv.2.total_ops", "features.14.conv.2.total_params", "features.14.conv.2.0.total_ops", "features.14.conv.2.0.total_params", "features.14.conv.2.0.attention.total_ops", "features.14.conv.2.0.attention.total_params", "features.14.conv.3.total_ops", "features.14.conv.3.total_params", "features.14.conv.3.fc.3.total_ops", "features.14.conv.3.fc.3.total_params", "features.14.conv.4.total_ops", "features.14.conv.4.total_params", "features.14.conv.4.sigmoid.total_ops", "features.14.conv.4.sigmoid.total_params", "features.14.conv.5.total_ops", "features.14.conv.5.total_params", "features.14.conv.5.attention.total_ops", "features.14.conv.5.attention.total_params", "features.15.total_ops", "features.15.total_params", "features.15.conv.0.total_ops", "features.15.conv.0.total_params", "features.15.conv.0.0.total_ops", "features.15.conv.0.0.total_params", "features.15.conv.0.0.attention.total_ops", "features.15.conv.0.0.attention.total_params", "features.15.conv.1.total_ops", "features.15.conv.1.total_params", "features.15.conv.1.sigmoid.total_ops", "features.15.conv.1.sigmoid.total_params", "features.15.conv.2.total_ops", "features.15.conv.2.total_params", "features.15.conv.2.0.total_ops", "features.15.conv.2.0.total_params", "features.15.conv.2.0.attention.total_ops", "features.15.conv.2.0.attention.total_params", "features.15.conv.3.total_ops", "features.15.conv.3.total_params", "features.15.conv.3.fc.3.total_ops", "features.15.conv.3.fc.3.total_params", "features.15.conv.4.total_ops", "features.15.conv.4.total_params", "features.15.conv.4.sigmoid.total_ops", "features.15.conv.4.sigmoid.total_params", "features.15.conv.5.total_ops", "features.15.conv.5.total_params", "features.15.conv.5.attention.total_ops", "features.15.conv.5.attention.total_params", "conv.0.total_ops", "conv.0.total_params", "conv.0.0.total_ops", "conv.0.0.total_params", "conv.0.0.attention.total_ops", "conv.0.0.attention.total_params", "conv.1.total_ops", "conv.1.total_params", "conv.1.sigmoid.total_ops", "conv.1.sigmoid.total_params", "classifier.1.total_ops", "classifier.1.total_params", "classifier.1.sigmoid.total_ops", "classifier.1.sigmoid.total_params". 

In [38]:
a = torch.load(model_path)
a

OrderedDict([('total_ops', tensor([0.], dtype=torch.float64)),
             ('total_params', tensor([0.], dtype=torch.float64)),
             ('features.0.0.weight',
              tensor([[[[ 0.1486, -0.0617, -0.0471],
                        [ 0.1117, -0.0409, -0.0465],
                        [-0.0394, -0.3110,  0.0519]],
              
                       [[ 0.0780, -0.0069,  0.0244],
                        [ 0.0869, -0.0207, -0.0104],
                        [ 0.0322,  0.1527, -0.0520]],
              
                       [[-0.1337, -0.0260,  0.0186],
                        [-0.0505, -0.2336, -0.3106],
                        [-0.1419, -0.2132, -0.1861]]],
              
              
                      [[[-0.0981, -0.0997, -0.0966],
                        [ 0.0808, -0.0327,  0.0070],
                        [ 0.1549, -0.1021,  0.2263]],
              
                       [[-0.0759,  0.1428,  0.1049],
                        [ 0.1698,  0.2342,  0.1261],
            

In [46]:
mw.module.state_dict()

OrderedDict([('total_ops', tensor([0.], dtype=torch.float64)),
             ('total_params', tensor([0.], dtype=torch.float64)),
             ('features.0.0.weight',
              tensor([[[[ 4.2397e-02,  4.7373e-02, -1.2767e-01],
                        [ 2.7492e-01,  1.5265e-01, -6.4039e-02],
                        [-2.2401e-01,  2.8021e-02,  1.5997e-01]],
              
                       [[ 2.1404e-03,  7.2182e-03,  3.6672e-02],
                        [-1.7405e-01,  5.0692e-02,  1.2708e-01],
                        [ 2.6666e-02,  5.2613e-02,  1.4837e-01]],
              
                       [[-1.6538e-01,  1.7506e-02,  3.4247e-03],
                        [-6.6497e-02,  1.7464e-01, -7.0528e-02],
                        [ 1.0093e-01,  3.6057e-02, -1.2070e-01]]],
              
              
                      [[[ 7.5169e-03, -2.5916e-01, -2.0091e-01],
                        [ 7.8611e-02,  1.0876e-02,  6.5900e-02],
                        [ 1.6697e-02,  1.1010e-01, -1.2

In [47]:
mw.module.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [48]:
mw.module.state_dict()

OrderedDict([('total_ops', tensor([0.], dtype=torch.float64)),
             ('total_params', tensor([0.], dtype=torch.float64)),
             ('features.0.0.weight',
              tensor([[[[ 0.1486, -0.0617, -0.0471],
                        [ 0.1117, -0.0409, -0.0465],
                        [-0.0394, -0.3110,  0.0519]],
              
                       [[ 0.0780, -0.0069,  0.0244],
                        [ 0.0869, -0.0207, -0.0104],
                        [ 0.0322,  0.1527, -0.0520]],
              
                       [[-0.1337, -0.0260,  0.0186],
                        [-0.0505, -0.2336, -0.3106],
                        [-0.1419, -0.2132, -0.1861]]],
              
              
                      [[[-0.0981, -0.0997, -0.0966],
                        [ 0.0808, -0.0327,  0.0070],
                        [ 0.1549, -0.1021,  0.2263]],
              
                       [[-0.0759,  0.1428,  0.1049],
                        [ 0.1698,  0.2342,  0.1261],
            