In [12]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torchsummary


class FireStorm(nn.Module):
    """Fire block from SqueezeNet, extended with BatchNorm and LeakyReLU.

    A 1x1 "squeeze" convolution first reduces the channel count. Two parallel
    "expand" convolutions (1x1, and 3x3 with ``padding=1``) are then applied to
    the squeezed tensor and their results are concatenated along dimension 1.
    Every convolution is followed by ``LeakyReLU`` and then ``BatchNorm2d``,
    in that order.

    The output therefore has ``expand1x1_planes + expand3x3_planes`` channels.
    """

    def __init__(
        self,
        inplanes: int,
        squeeze_planes: int,
        expand1x1_planes: int,
        expand3x3_planes: int,
    ) -> None:
        """Build the squeeze layer and the two expand branches.

        The ratio between ``expand1x1_planes`` and ``expand3x3_planes`` decides
        how many of the output channels come from each filter size.

        Args:
            inplanes: Number of channels of the input tensor.
            squeeze_planes: Number of channels produced by the squeeze layer.
                Both expand convolutions read this many channels, so a smaller
                value means fewer parameters.
            expand1x1_planes: Number of channels produced by the 1x1 branch.
            expand3x3_planes: Number of channels produced by the 3x3 branch.
        """
        super().__init__()
        self.inplanes = inplanes

        # Squeeze stage: 1x1 conv -> LeakyReLU -> BatchNorm.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.LeakyReLU(inplace=True)
        self.squeeze_bn = nn.BatchNorm2d(squeeze_planes)

        # Expand branch with 1x1 filters.
        self.expand1x1 = nn.Conv2d(
            squeeze_planes, expand1x1_planes, kernel_size=1
        )
        self.expand1x1_activation = nn.LeakyReLU(inplace=True)
        self.expand1x1_bn = nn.BatchNorm2d(expand1x1_planes)

        # Expand branch with 3x3 filters; padding=1 keeps the spatial size
        # equal to the 1x1 branch so the two outputs can be concatenated.
        self.expand3x3 = nn.Conv2d(
            squeeze_planes, expand3x3_planes, kernel_size=3, padding=1
        )
        self.expand3x3_activation = nn.LeakyReLU(inplace=True)
        self.expand3x3_bn = nn.BatchNorm2d(expand3x3_planes)

    def forward(self, x):
        """Apply squeeze, then both expand branches, and concatenate on dim 1."""
        squeezed = self.squeeze(x)
        squeezed = self.squeeze_activation(squeezed)
        squeezed = self.squeeze_bn(squeezed)

        branch_1x1 = self.expand1x1(squeezed)
        branch_1x1 = self.expand1x1_activation(branch_1x1)
        branch_1x1 = self.expand1x1_bn(branch_1x1)

        branch_3x3 = self.expand3x3(squeezed)
        branch_3x3 = self.expand3x3_activation(branch_3x3)
        branch_3x3 = self.expand3x3_bn(branch_3x3)

        return torch.cat([branch_1x1, branch_3x3], 1)
            



class StormColorModel(nn.Module):
    """Custom SqueezeNet variant operating on single-channel input.

    Differences from SqueezeNet:
      * Fire modules are replaced by ``FireStorm`` modules.
      * The final convolutional classifier is replaced by fully connected
        layers.
      * LeakyReLU is used instead of ReLU.
      * The number of fire modules and their filter counts are changed.

    The feature extractor ends with a ``FireStorm`` block emitting
    256 + 256 = 512 channels, which is what the classifier head consumes.
    The head uses the same layout as ``StormModel2``: global average pooling,
    flatten, ``Linear(512, 512)``, LeakyReLU, dropout, and a final
    ``Linear(512, num_classes)``.
    """

    def __init__(self, num_classes: int = 251, dropout: float = 0.5) -> None:
        """Build the feature extractor and the classifier head.

        Args:
            num_classes: Size of the output of the last linear layer.
            dropout: Drop probability of the dropout layer in the head.
        """
        super().__init__()
        self.num_classes = num_classes

        self.features = nn.Sequential(

            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True),

            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),

            FireStorm(64, 16, 64, 64),
            FireStorm(128, 16, 64, 64),

            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),

            FireStorm(128, 32, 128, 128),
            FireStorm(256, 32, 128, 128),

            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),

            FireStorm(256, 48, 192, 192),
            FireStorm(384, 64, 256, 256),

        )

        # Previously missing: without a head and a forward() the module could
        # not be called and the `dropout` argument was silently ignored.
        # Created after `features` so the feature layers are initialised
        # exactly as before.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, self.num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor followed by the classifier head.

        Args:
            x: Batched input with one channel, i.e. ``(N, 1, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        x = self.features(x)
        return self.classifier(x)

class StormModel2(nn.Module):
    """Larger SqueezeNet-style classifier built from ``FireStorm`` blocks.

    The feature extractor takes 3-channel input and consists of a strided 3x3
    convolution followed by three max-pooled stages of ``FireStorm`` blocks
    (2, 2 and 3 blocks respectively). The classifier head applies global
    average pooling, flattens, and runs two linear layers with LeakyReLU and
    dropout in between.
    """

    def __init__(self, num_classes: int = 251, dropout: float = 0.5) -> None:
        """Build the feature extractor and the classifier head.

        Args:
            num_classes: Size of the output of the last linear layer.
            dropout: Drop probability of the dropout layer in the head.
        """
        super().__init__()
        self.num_classes = num_classes

        feature_layers = [
            # Stem
            nn.Conv2d(3, 64, kernel_size=3, stride=2),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            # Stage 1 -> 128 channels
            FireStorm(64, 16, 64, 64),
            FireStorm(128, 16, 64, 64),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            # Stage 2 -> 256 channels
            FireStorm(128, 32, 128, 128),
            FireStorm(256, 32, 128, 128),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            # Stage 3 -> 512 channels (one block more than StormColorModel)
            FireStorm(256, 48, 192, 192),
            FireStorm(384, 64, 192, 192),
            FireStorm(384, 64, 256, 256),
        ]
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, self.num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for a batch of 3-channel inputs."""
        feature_maps = self.features(x)
        logits = self.classifier(feature_maps)
        # The head already flattens; this keeps the original final flatten.
        return torch.flatten(logits, 1)




In [13]:
# Print a per-layer table (output shape and parameter count) for a freshly
# constructed StormModel2, using a 3x224x224 input size and running on the CPU.
torchsummary.summary(StormModel2(), (3, 224, 224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 111, 111]           1,792
         LeakyReLU-2         [-1, 64, 111, 111]               0
         MaxPool2d-3           [-1, 64, 55, 55]               0
            Conv2d-4           [-1, 16, 55, 55]           1,040
         LeakyReLU-5           [-1, 16, 55, 55]               0
       BatchNorm2d-6           [-1, 16, 55, 55]              32
            Conv2d-7           [-1, 64, 55, 55]           1,088
         LeakyReLU-8           [-1, 64, 55, 55]               0
       BatchNorm2d-9           [-1, 64, 55, 55]             128
           Conv2d-10           [-1, 64, 55, 55]           9,280
        LeakyReLU-11           [-1, 64, 55, 55]               0
      BatchNorm2d-12           [-1, 64, 55, 55]             128
        FireStorm-13          [-1, 128, 55, 55]               0
           Conv2d-14           [-1, 16,