In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split
import torch.nn as nn
import torchvision
import torch.optim as optim


import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.metrics import top_k_accuracy_score

import timm
import json, os

In [None]:
# HYPERPARAMETERS

# Fractions of the full dataset handed to random_split for the train / test split.
train_size = 0.8
test_size=0.2

# Use the GPU when torch can see one, otherwise fall back to the CPU.
cuda = 'cuda'
cpu = 'cpu' 
device = cuda if torch.cuda.is_available() else cpu

# Locations of the downloaded Xception weights. They are resolved against the
# current user's home directory instead of a hard-coded per-user absolute path,
# so the notebook also runs on other machines / accounts.
_home_dir = os.path.expanduser("~")
xception_weight_path = os.path.join(
    _home_dir, ".cache", "torch", "hub", "checkpoints", "xception-43020ad28.pth"
)
xception_tensorflow_weights = os.path.join(
    _home_dir, ".keras", "models", "xception_weights_tf_dim_ordering_tf_kernels.h5"
)


In [None]:
# Load the class-name lookup table from mapping.json (read from the working directory).
# The visualization cell further down indexes it with str(label + 1), so the keys are
# expected to be 1-based class numbers stored as strings.
with open("mapping.json", 'r') as f:
    mapping = json.load(f)
    

# Training-time augmentation pipeline. ToTensor runs first, so every later
# transform in the list operates on a tensor rather than on a PIL image.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    # Random geometric augmentation: horizontal flip plus a small rotation / shift / zoom.
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=15, translate= (.15, .15), scale = (0.85, 1.15)),

    # Resize to 324, then take a random 299x299 window (the crop size also used
    # by the ten-crop test pipeline).
    transforms.Resize(324), 
    transforms.RandomCrop(299),

    # Per-channel (x - 0.5) / 0.5.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Evaluation pipeline with ten-crop test-time augmentation.
# ToTensor has to come first: Normalize only accepts tensors, and the samples
# arrive here as PIL images (the ImageFolder is created with transform=None).
# This mirrors train_transform, where every step after ToTensor also runs on tensors.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((324, 324)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    # Four corner crops + centre crop, and the horizontally flipped version of each.
    transforms.TenCrop((299, 299)),
    # TenCrop yields a tuple of 10 tensors; stack them into one (10, C, 299, 299) tensor.
    transforms.Lambda(lambda crops: torch.stack(crops)),

    ])

In [None]:

class myDataset(torch.utils.data.Dataset):
    """Wrap an indexable collection of (sample, label) pairs and transform samples lazily.

    Args:
        data: any object supporting len() and integer indexing that yields
            (sample, label) pairs.
        transform: optional callable applied to the sample (never to the label)
            each time an item is fetched. A falsy value leaves the sample untouched.
    """

    def __init__(self, data, transform=None):
        self.data, self.transform = data, transform

    def __len__(self):
        """Number of pairs in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return (sample, label) at `index`, with the transform applied to the sample."""
        sample, target = self.data[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target
   


In [None]:

# Index the image folder without any transform, so that each split can attach
# its own pipeline through myDataset.
full_dataset = torchvision.datasets.ImageFolder(root='./dataset', transform=None)

# A generator with a fixed seed keeps the train / test partition identical across runs.
split_generator = torch.Generator().manual_seed(0)
train, test = random_split(full_dataset, [train_size, test_size], generator=split_generator)

train_dataset = myDataset(train, transform=train_transform)
test_dataset = myDataset(test, transform=test_transform)

batch_size=32

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# One test sample per batch, in a fixed order.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Visualize one augmented image from train_dataset.

i = 1135

# Index the dataset exactly once: every train_dataset[i] access reloads the
# image and re-runs the random augmentation pipeline.
sample_tensor, sample_class = train_dataset[i]

# train_transform ends with Normalize(mean=0.5, std=0.5); undo it so that imshow
# receives values in [0, 1] instead of clipping everything below zero.
image = (sample_tensor * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()
label = mapping[str(sample_class+1)] #+1 is required, because the "class 1" in folders, is class 0 when torch loads it

print(f"label: {label} (index - {sample_class+1})")
plt.imshow(image)

In [None]:
class BNConv2D(nn.Module):
    """Bias-free 2-D convolution immediately followed by batch normalization.

    No activation is applied. The sub-modules are exposed as `conv` and `bn`;
    those attribute names are part of the state-dict keys used elsewhere in
    this notebook.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution (and normalized by `bn`).
        kernel_size: convolution kernel size.
        stride, padding, dilation, groups: forwarded unchanged to nn.Conv2d.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride = 1,
            padding= 0,
            dilation=1,
            groups=1,
    ):
        super().__init__()

        # The convolution carries no bias term (bias=False).
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        """Return bn(conv(X))."""
        return self.bn(self.conv(X))
    
class BN_DepthwiseSeparableConv2D(nn.Module):
    """Depthwise-separable convolution followed by batch normalization.

    A per-channel (groups=in_channels) spatial convolution is followed by a
    1x1 pointwise convolution that mixes channels, and the result is batch
    normalized. Neither convolution has a bias and no activation is applied.
    The sub-module names `depthwise_conv`, `pointwise_conv` and `bn` are part
    of the state-dict keys used elsewhere in this notebook.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution.
        stride, padding, dilation: applied to the depthwise convolution only.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride = 1,
            padding= 0,
            dilation=1,
    ):
        super().__init__()

        # Spatial filtering: one filter per input channel.
        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=False,
        )
        # Channel mixing: 1x1 convolution from in_channels to out_channels.
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        """Return bn(pointwise_conv(depthwise_conv(X)))."""
        spatial = self.depthwise_conv(X)
        mixed = self.pointwise_conv(spatial)
        return self.bn(mixed)


In [None]:
class MiddleBlock(nn.Module):
    """Residual unit: three (ReLU -> 3x3 separable conv + BN) stages plus an identity skip.

    Args:
        channels: number of input and output channels; unchanged by the block
            so that the input can be added back to the output.
    """

    def __init__(self, channels:int=728):
        super().__init__()

        # Keep the layout ReLU at even / conv at odd indices: the checkpoint
        # key table in this notebook addresses the convs as main.1, main.3, main.5.
        stages = []
        for _ in range(3):
            stages.append(nn.ReLU())
            stages.append(BN_DepthwiseSeparableConv2D(channels, channels, 3, padding=1))
        self.main = nn.Sequential(*stages)

    def forward(self, X):
        """Return main(X) + X."""
        transformed = self.main(X)
        transformed += X
        return transformed

class Xception(nn.Module):
    """Xception classifier: entry flow, eight residual middle blocks, exit flow.

    The network is built from BNConv2D / BN_DepthwiseSeparableConv2D units and
    ends with global average pooling and a linear classifier on 2048 features.
    The first convolution expects a 3-channel input.

    Args:
        num_classes: output size of the classifier. The model is always built
            with a 1000-way `fc` first (so that a 1000-class state dict can be
            loaded); when `num_classes` differs, `fc` is then replaced by a
            freshly initialised Linear(2048, num_classes).
        pretrained: when True, load the state dict stored in "xception_base.pt"
            (resolved against the working directory). Otherwise convolution
            weights get Kaiming-normal initialisation and batch-norm layers are
            set to weight=1, bias=0.
        dropout: dropout probability applied to the pooled features right
            before `fc`. It is only active in training mode; the default 0.0
            disables it.
    """

    def __init__(self, num_classes:int=1000, pretrained=False, dropout=0.0):
        super().__init__()
        # Kept as a plain float; forward() applies it functionally.
        self.dropout = dropout

        ## ENTRY FLOW
        self.block1_conv1 = BNConv2D(in_channels=3, out_channels=32, stride=2, kernel_size=3, padding=0)
        self.block1_conv2 = BNConv2D(in_channels=32, out_channels=64, kernel_size=3, padding=0)

        # Each residualN is the strided 1x1 shortcut matching the max-pooled main branch.
        self.residual1 = BNConv2D(in_channels=64, out_channels=128, kernel_size=1, padding=0, stride=2)
        self.block2_conv1 = BN_DepthwiseSeparableConv2D(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.block2_conv2 = BN_DepthwiseSeparableConv2D(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        self.residual2 = BNConv2D(in_channels=128, out_channels=256, kernel_size=1, padding=0, stride=2)
        self.block3_conv1 = BN_DepthwiseSeparableConv2D(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.block3_conv2 = BN_DepthwiseSeparableConv2D(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.residual3 = BNConv2D(in_channels=256, out_channels=728, kernel_size=1, padding=0, stride=2)
        self.block4_conv1 = BN_DepthwiseSeparableConv2D(in_channels=256, out_channels=728, kernel_size=3, padding=1)
        self.block4_conv2 = BN_DepthwiseSeparableConv2D(in_channels=728, out_channels=728, kernel_size=3, padding=1)

        ## MIDDLE FLOW
        middle_layers = []

        for _ in range(8):
            middle_layers.append(MiddleBlock())

        self.middle_layers = nn.Sequential(*middle_layers)

        ## EXIT FLOW

        # Positional arguments: kernel_size=1, stride=2.
        self.exit_residual = BNConv2D(728, 1024, 1, 2)

        self.exitblock1_conv1 = BN_DepthwiseSeparableConv2D(728, 728, 3, padding=1)
        self.exitblock1_conv2 = BN_DepthwiseSeparableConv2D(728, 1024, 3, padding=1)

        self.exitblock2_conv1 = BN_DepthwiseSeparableConv2D(1024, 1536, 3, padding=1)
        self.exitblock2_conv2 = BN_DepthwiseSeparableConv2D(1536, 2048, 3, padding=1)

        self.fc = nn.Linear(2048, 1000)

        # Parameter-free layers shared by all call sites. The ReLU is NOT
        # in-place, so the tensor fed to a shortcut branch is never overwritten.
        self.relu = nn.ReLU()
        self.maxpooling = nn.MaxPool2d(3, 2, 1)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)

        if pretrained:
            # map_location="cpu" lets a checkpoint saved from a GPU load on a
            # CPU-only machine; the caller moves the model to its device afterwards.
            self.load_state_dict(torch.load("xception_base.pt", map_location="cpu"))
        else:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        # Swap the head only after loading / initialising, so a 1000-class
        # checkpoint still matches the module while it is being loaded.
        if num_classes!=1000:
            self.fc = nn.Linear(2048, num_classes)

    def forward(self, X):
        """Compute class logits of shape (batch, fc.out_features) for an image batch X."""

        ## ENTRY FLOW
        out = self.block1_conv1(X)
        out = self.relu(out)
        out = self.block1_conv2(out)
        out = self.relu(out)

        # Block 2 directly follows an activation, so its main branch has no leading ReLU.
        residue = self.residual1(out)
        out = self.block2_conv1(out)
        out = self.relu(out)
        out = self.block2_conv2(out)
        out = self.maxpooling(out)
        out += residue

        # Blocks 3 and 4 start their main branch with a ReLU; the shortcut is
        # taken from the tensor before that activation.
        residue = self.residual2(out)
        out = self.relu(out)
        out = self.block3_conv1(out)
        out = self.relu(out)
        out = self.block3_conv2(out)
        out = self.maxpooling(out)
        out += residue

        residue = self.residual3(out)
        out = self.relu(out)
        out = self.block4_conv1(out)
        out = self.relu(out)
        out = self.block4_conv2(out)
        out = self.maxpooling(out)
        out += residue

        ## MIDDLE FLOW
        out = self.middle_layers(out)

        ## EXIT FLOW

        residue = self.exit_residual(out)
        out = self.relu(out)
        out = self.exitblock1_conv1(out)
        out = self.relu(out)
        out = self.exitblock1_conv2(out)
        out = self.maxpooling(out)
        out += residue

        # Each of the two final separable convolutions is followed by a ReLU;
        # without them the two layers collapse into one linear map.
        out = self.exitblock2_conv1(out)
        out = self.relu(out)
        out = self.exitblock2_conv2(out)
        out = self.relu(out)
        out = self.global_avgpool(out)

        out = torch.flatten(out, 1)
        # No-op when self.dropout == 0.0 or the module is in eval mode.
        out = nn.functional.dropout(out, p=self.dropout, training=self.training)
        out = self.fc(out)

        return out

 

In [None]:
# Build the backbone with the constructor defaults (pretrained=False, i.e.
# randomly initialised weights).
# NOTE(review): the following cell freezes everything except fc - confirm
# whether pretrained weights were meant to be loaded into the backbone here.
model = Xception()

# Replace the classifier with a 256-way head: Kaiming-normal weights, zero bias.
classifier_head = nn.Linear(2048, 256)
nn.init.kaiming_normal_(classifier_head.weight, mode='fan_out', nonlinearity='relu')
nn.init.zeros_(classifier_head.bias)
model.fc = classifier_head

model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the head's parameters are handed to the optimizer.
optimizer = optim.Adam(classifier_head.parameters())



In [None]:
# Freeze the backbone: only the parameters of the classifier head (model.fc)
# keep requires_grad=True.
head_param_ids = {id(p) for p in model.fc.parameters()}
for p in model.parameters():
    p.requires_grad = id(p) in head_param_ids

# Sanity check: 2048*256 weights + 256 biases of the Linear(2048, 256) head.
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
assert trainable_param_count==524544


In [None]:
# Train for 5 epochs. The optimizer was built from model.fc.parameters(), so
# only the classifier head is updated.
for epoch in range(5):

    start = time.time()
    # Running sums over the epoch; the loss is weighted by batch size so that
    # the reported value is a per-sample mean.
    total_training_loss = 0
    total_training_samples = 0
    total_training_correct_classfied = 0

    # Whole model in eval mode, then only the head switched back to train mode.
    model.eval()
    model.fc.train()

    for iteration, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        output_predictions = model(images)
        loss_value = criterion(output_predictions, labels)

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()
        optimizer.zero_grad()

        samples_in_batch = labels.shape[0]
        total_training_loss += loss_value.item()*samples_in_batch
        total_training_samples += samples_in_batch

        # Predicted class = index of the largest logit in each row.
        predicted = torch.max(output_predictions, dim=1).indices
        total_training_correct_classfied += (labels == predicted).sum().item()

        if iteration%250==0:
            print(f"iteration {iteration+1} done.")

    mean_loss = total_training_loss/total_training_samples
    accuracy = total_training_correct_classfied/total_training_samples
    print(f"epoch: {epoch+1}, train loss: {mean_loss}, train accuracy: {accuracy}, time taken: {time.time()-start}")

    # Periodic evaluation on the test loader is currently disabled:
    # if (epoch-1)%3==0:
    #     print_test_details(model, criterion, test_loader)
    print()





In [None]:
# Translation table between two naming schemes for the same Xception tensors:
# keys are state-dict names in the source checkpoint layout (conv1, bn1,
# block1.rep.N, ...), values are the matching names in this notebook's Xception
# module. Presumably the keys follow the pretrained checkpoint referenced by
# xception_weight_path - confirm against that file.
# The table is generated instead of written out by hand; the generated pairs
# and their order are the same as in the explicit 276-entry literal.
_BN_FIELDS = ("weight", "bias", "running_mean", "running_var", "num_batches_tracked")


def _bn_pairs(old_bn, new_prefix):
    """Pairs for one batch-norm layer: `old_bn.<field>` -> `new_prefix.bn.<field>`."""
    return [(f"{old_bn}.{field}", f"{new_prefix}.bn.{field}") for field in _BN_FIELDS]


def _plain_conv_pairs(old_conv, old_bn, new_prefix):
    """Pairs for a regular conv + batch-norm unit (a BNConv2D on the new side)."""
    return [(f"{old_conv}.weight", f"{new_prefix}.conv.weight")] + _bn_pairs(old_bn, new_prefix)


def _separable_pairs(old_conv, old_bn, new_prefix):
    """Pairs for a separable conv + batch-norm unit (a BN_DepthwiseSeparableConv2D)."""
    return [
        (f"{old_conv}.conv1.weight", f"{new_prefix}.depthwise_conv.weight"),
        (f"{old_conv}.pointwise.weight", f"{new_prefix}.pointwise_conv.weight"),
    ] + _bn_pairs(old_bn, new_prefix)


def _strided_block_pairs(old_block, first_rep, new_skip, new_convs):
    """Pairs for a block with a skip convolution and two separable convolutions.

    `first_rep` is the index inside `old_block.rep` of the first separable
    conv; its batch norm sits at the next index and the second conv three
    positions later.
    """
    pairs = _plain_conv_pairs(f"{old_block}.skip", f"{old_block}.skipbn", new_skip)
    for position, new_name in enumerate(new_convs):
        rep = first_rep + 3 * position
        pairs += _separable_pairs(f"{old_block}.rep.{rep}", f"{old_block}.rep.{rep + 1}", new_name)
    return pairs


def _build_old_new_mapping():
    """Assemble the full old-name -> new-name table in network order."""
    pairs = []

    # Stem: two regular convolutions.
    pairs += _plain_conv_pairs("conv1", "bn1", "block1_conv1")
    pairs += _plain_conv_pairs("conv2", "bn2", "block1_conv2")

    # Entry-flow residual blocks. Only block1 has its first conv at rep.0.
    pairs += _strided_block_pairs("block1", 0, "residual1", ("block2_conv1", "block2_conv2"))
    pairs += _strided_block_pairs("block2", 1, "residual2", ("block3_conv1", "block3_conv2"))
    pairs += _strided_block_pairs("block3", 1, "residual3", ("block4_conv1", "block4_conv2"))

    # Middle flow: block4..block11 -> middle_layers.0..7, convs at rep.1/4/7 -> main.1/3/5.
    for middle_index in range(8):
        old_block = f"block{middle_index + 4}"
        for position in range(3):
            rep = 1 + 3 * position
            pairs += _separable_pairs(
                f"{old_block}.rep.{rep}",
                f"{old_block}.rep.{rep + 1}",
                f"middle_layers.{middle_index}.main.{2 * position + 1}",
            )

    # Exit flow: one residual block, then two stand-alone separable convolutions.
    pairs += _strided_block_pairs("block12", 1, "exit_residual", ("exitblock1_conv1", "exitblock1_conv2"))
    pairs += _separable_pairs("conv3", "bn3", "exitblock2_conv1")
    pairs += _separable_pairs("conv4", "bn4", "exitblock2_conv2")

    # The classifier keeps its name.
    pairs += [("fc.weight", "fc.weight"), ("fc.bias", "fc.bias")]

    return dict(pairs)


old_new_mapping = _build_old_new_mapping()

In [None]:
# Scratch instance of BNConv2D. kernel_size is a required argument of
# BNConv2D.__init__, so it has to be passed explicitly; 3 matches the
# 3 -> 32 stem convolution of the Xception model.
waw = BNConv2D(3, 32, kernel_size=3)