In [1]:
# temporary solution to crashing
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
#import torchvision.transforms as transforms
from torch import Tensor
from torchvision.ops import Conv2dNormActivation

import matplotlib.pyplot as plt
import numpy as np
#import mysql.connector as connector

from pathlib import Path

In [None]:
# issues to debug: since groups == hidden_dim and hidden_dim == int(round(in_channels * expand_ratio))
# the in_channels are not divisible by groups

class InvertedResidual(nn.Module):
    def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=None):
        super().__init__()

        assert stride == 1 or stride == 2, 'stride must be 1 or 2'

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        
        hidden_dim = int(round(in_channels * expand_ratio))
        # print(hidden_dim)

        layers = [] 

        if expand_ratio != 1:
            # add expansion 
            layers.append(Conv2dNormActivation(in_channels, hidden_dim, kernel_size=3, stride=stride, groups=hidden_dim, 
                                               norm_layer=norm_layer, activation_layer=nn.ReLU6))
        # depthwise conv => pointwise => norm_layer
        layers.extend([Conv2dNormActivation(hidden_dim, hidden_dim, kernel_size=3, stride=stride, groups=hidden_dim, norm_layer=norm_layer, activation_layer=nn.ReLU6), 
                                                nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False), 
                                                norm_layer(out_channels)])
        
        self.conv = nn.Sequential(*layers)
        print(self.conv)
        self.stride = stride
        self.out_channels = out_channels
        self.in_channels = in_channels
        self._is_cn = stride > 1 # downsample indicator - what does this mean?
        self.use_res_conn = stride == 1 and in_channels == out_channels

    def forward(self, x):
        if self.use_res_conn:
            return x + self.conv(x)
        else:                
            return self.conv(x)
        

In [None]:
# testing the inverted residual layers
# everything seems to work except there is an issue with "groups"

test_inv = InvertedResidual(3, 30, 1, 1)
test_input = torch.randn(1, 3, 32, 32)

test_inv(test_input).shape

Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (1): Conv2dNormActivation(
    (0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (2): Conv2d(1, 30, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (3): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


torch.Size([1, 30, 32, 32])