In [1]:
import torch
from torch import nn

In [2]:
class MLPCONV(nn.Module):
    """NiN "mlpconv" block: one spatial convolution followed by two 1x1 convolutions.

    Every convolution is followed by a ReLU, i.e. the block computes
    ``ReLU(conv1x1(ReLU(conv1x1(ReLU(conv(x))))))``. The two 1x1 convolutions act
    as a small per-pixel MLP across channels.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every convolution in the block.
        kernal_size: kernel size of the first (spatial) convolution. The spelling is
            kept as-is because callers pass it by keyword.
        stride: stride of the first convolution.
        padding: padding of the first convolution.
    """

    def __init__(self, in_channels, out_channels, kernal_size, stride, padding):
        super().__init__()

        # Build the layers in the exact order conv, relu, 1x1, relu, 1x1, relu so the
        # Sequential indices (0..5) and therefore the state_dict keys stay stable.
        stages = [nn.Conv2d(in_channels, out_channels, kernal_size, stride, padding)]
        for _ in range(2):
            stages.append(nn.ReLU())
            stages.append(nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1)))
        stages.append(nn.ReLU())

        self.mlpconv_block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv -> 1x1 -> 1x1 stack (each followed by ReLU)."""
        return self.mlpconv_block(x)


In [3]:
class NIN(nn.Module):
    """Network-in-Network classifier built from four MLPCONV blocks.

    Layout: three MLPCONV blocks, each followed by a 3x3 / stride-2 max pool, then
    ``nn.Dropout2d``, a final MLPCONV block that maps to ``out_classes`` channels,
    global average pooling to 1x1, and a flatten. There is no fully connected head:
    each output channel of the last block becomes one class score.

    Args:
        in_channels: number of channels of the input images (e.g. 3 for RGB).
        out_classes: number of classes, i.e. the width of the returned scores.
        dropout: drop probability of the ``nn.Dropout2d`` layer applied before the
            last block. Defaults to 0.5, the value that used to be hard-coded.

    Shapes:
        input:  ``(N, in_channels, H, W)`` with ``H, W >= 67`` so the feature map
                survives the stride-4 stem and the three max pools.
        output: ``(N, out_classes)``.

    Note:
        MLPCONV blocks end with a ReLU, so the returned class scores are always >= 0.
    """

    def __init__(self, in_channels, out_classes, dropout=0.5):
        super().__init__()
        # Stem: large stride-4 receptive field, then downsample.
        self.ninblock1 = MLPCONV(in_channels, 96, kernal_size=11, stride=4, padding=0)
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.ninblock2 = MLPCONV(96, 256, kernal_size=5, stride=1, padding=2)
        self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.ninblock3 = MLPCONV(256, 384, kernal_size=3, stride=1, padding=1)
        self.mp3 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Dropout2d zeroes whole channels of the 4-D feature map.
        self.dp = nn.Dropout2d(dropout)

        # Classifier block: one output channel per class, averaged over space.
        self.ninblock4 = MLPCONV(384, out_classes, kernal_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()

    def forward(self, x):
        """Map a batch ``(N, C, H, W)`` of images to class scores ``(N, out_classes)``."""
        x = self.ninblock1(x)
        x = self.mp1(x)

        x = self.ninblock2(x)
        x = self.mp2(x)

        x = self.ninblock3(x)
        x = self.mp3(x)

        x = self.dp(x)

        x = self.ninblock4(x)
        x = self.avgpool(x)  # (N, out_classes, 1, 1)
        x = self.flat(x)     # (N, out_classes)

        return x

In [4]:
# Smoke test: push a dummy batch through the model and check the output shape.
# Seeded so weights and input are reproducible. no_grad avoids building an
# autograd graph that y_dumb.grad_fn would otherwise keep alive, together with every
# activation of this 32x3x224x224 batch.
torch.manual_seed(0)
nin_model = NIN(3, 100)
x_dumb = torch.randn(32, 3, 224, 224)
with torch.no_grad():
    y_dumb = nin_model(x_dumb)
assert y_dumb.shape == (32, 100), y_dumb.shape
y_dumb.shape

torch.Size([32, 100])