In [None]:
import torch
import torch.nn as nn

In [None]:
class block(nn.Module):
  """Two-branch inverted-bottleneck residual block.

  Both branches expand ``in_channels`` to ``in_channels * expansion`` channels
  with a 1x1 convolution, apply a 3x3 convolution (which carries ``stride``)
  and project to ``out_channels`` with another 1x1 convolution. Every
  convolution is followed by batch normalisation and Mish. ``pass1`` uses a
  dense 3x3 convolution, ``pass2`` a grouped one (``groups=in_channels``).
  The two branch outputs are summed and the skip connection is added on top.

  Args:
    in_channels: Number of channels of the input tensor.
    out_channels: Number of channels produced by each branch.
    stride: Stride of the 3x3 convolution in both branches.
    expansion: Multiplier applied to ``in_channels`` to obtain the width of
      the middle layers.
    downsample: Optional module applied to the input before it is added as
      the skip connection. When ``None`` the input is added unchanged, so it
      must already be addable to the branch output. The module should not
      modify its input in place, because the input tensor is passed to it
      directly.
  """

  def __init__(self, in_channels, out_channels, stride=1, expansion=6, downsample=None):
    super().__init__()
    middle_channels = in_channels * expansion

    # Dense branch: ordinary 3x3 convolution over all middle channels.
    self.pass1 = self._branch(in_channels, middle_channels, out_channels, stride, groups=1)
    # Grouped branch: the 3x3 convolution is split into ``in_channels`` groups.
    self.pass2 = self._branch(in_channels, middle_channels, out_channels, stride, groups=in_channels)

    self.downsample = downsample

  @staticmethod
  def _branch(in_channels, middle_channels, out_channels, stride, groups):
    """Build one expand (1x1) -> spatial (3x3) -> project (1x1) branch."""
    return nn.Sequential(
        nn.Conv2d(in_channels, middle_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(middle_channels),
        nn.Mish(),
        nn.Conv2d(middle_channels, middle_channels, kernel_size=3, stride=stride, padding=1,
                  groups=groups, bias=False),
        nn.BatchNorm2d(middle_channels),
        nn.Mish(),
        nn.Conv2d(middle_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.Mish(),
    )

  def forward(self, x):
    """Return ``pass1(x) + pass2(x) + skip(x)``.

    ``skip`` is ``downsample`` when one was supplied, otherwise the identity.
    """
    out = self.pass1(x) + self.pass2(x)
    # No copy of ``x`` is needed for the skip path: the branches do not modify
    # it in place and the in-place add below targets the freshly allocated sum.
    identity = x if self.downsample is None else self.downsample(x)
    out += identity
    return out



class Network(nn.Module):
  """Image classifier: a stem convolution, seven ``block`` stages and a head.

  The stem is a 7x7 stride-2 convolution with batch normalisation and Mish.
  The seven stages are built by ``_make_layer``; four of them use stride 2,
  so together with the stem the spatial resolution is reduced by a factor of
  32. The head is a 3x3 convolution to 1280 channels, batch normalisation,
  Mish, a 7x7 average pool, a linear classifier and ``log_softmax``.

  Because the head uses ``AvgPool2d(7)`` followed by ``Linear(1280, ...)``,
  the pooled map must be 1x1, i.e. the feature map entering the pool must be
  between 7 and 13 pixels per side (for example a 224x224 input).

  Args:
    in_channels: Number of channels of the input images.
    num_classes: Number of output classes.
  """

  def __init__(self, in_channels=3, num_classes=10):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(32)
    # The attribute is named ``relu6`` but the activation is Mish; the name is
    # kept so existing references to it continue to work.
    self.relu6 = nn.Mish()
    # Defined but not applied in ``forward``.
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Arguments: block class, expansion, in_channels, out_channels, repeats, stride.
    self.block1 = self._make_layer(block, 1, 32, 16, 1, 1)
    self.block2 = self._make_layer(block, 6, 16, 24, 2, 2)
    self.block3 = self._make_layer(block, 6, 24, 32, 3, 2)
    self.block4 = self._make_layer(block, 6, 32, 64, 4, 2)
    self.block5 = self._make_layer(block, 6, 64, 96, 3, 1)
    self.block6 = self._make_layer(block, 6, 96, 160, 3, 2)
    self.block7 = self._make_layer(block, 6, 160, 320, 1, 1)

    self.conv2 = nn.Conv2d(320, 1280, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(1280)
    self.avgPool = nn.AvgPool2d(7)
    self.flatten = nn.Flatten()
    self.fc = nn.Linear(1280, num_classes)

    # Defined but not applied in ``forward``.
    self.dropout = nn.Dropout(0.1)

  def forward(self, x):
    """Return per-class log-probabilities for a batch of images.

    Args:
      x: Input batch; ``conv1`` expects ``in_channels`` channels.

    Returns:
      Tensor with one row per sample holding ``log_softmax`` over the
      ``num_classes`` logits (dimension 1).
    """
    x = self.relu6(self.bn1(self.conv1(x)))

    stages = (self.block1, self.block2, self.block3, self.block4,
              self.block5, self.block6, self.block7)
    for stage in stages:
      x = stage(x)

    x = self.relu6(self.bn2(self.conv2(x)))
    x = self.flatten(self.avgPool(x))
    x = self.fc(x)
    # ``nn.functional`` is reachable through the existing ``torch.nn`` import,
    # so no separate ``F`` alias is required.
    return nn.functional.log_softmax(x, dim=1)

  def _make_layer(self, block, expansion, in_channels, out_channels, repeats, stride):
    """Build one stage as an ``nn.Sequential`` of ``block`` instances.

    The first block maps ``in_channels`` to ``out_channels`` with the given
    ``stride``; the remaining ``repeats - 1`` blocks keep ``out_channels`` and
    use stride 1.

    Args:
      block: Block class to instantiate.
      expansion: Expansion factor forwarded to every block.
      in_channels: Channels entering the stage.
      out_channels: Channels leaving the stage.
      repeats: Total number of blocks in the stage.
      stride: Stride of the first block.

    Returns:
      ``nn.Sequential`` containing the blocks in order.
    """
    downsample = None
    # The first block changes the spatial size and/or the channel count, so
    # its skip connection needs a 1x1 projection to match the branch output.
    # Stages that keep both unchanged use a plain identity skip.
    if stride != 1 or in_channels != out_channels:
      downsample = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(out_channels),
      )

    layers = [block(in_channels, out_channels, stride, expansion, downsample)]
    layers.extend(
        block(out_channels, out_channels, 1, expansion) for _ in range(repeats - 1)
    )
    return nn.Sequential(*layers)


# Build the network with its default arguments (in_channels=3, num_classes=10).
model = Network()