In [1]:
from src.models.jelly import ShallowCSNN
from src.utils.dataloaders import create_dataloaders
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchmetrics import F1Score, AUROC, Recall, Specificity, Accuracy
import torch
from src.utils.parameters import instantiate_cls
from torchvision import transforms
from spikingjelly.activation_based import neuron, functional
from src.datasets.custom import CustomImageFolder
from torch.utils.data import DataLoader

In [4]:
from torch import nn

class CNN_F(nn.Module):
    """Convolutional image classifier: five conv blocks followed by a 3-layer MLP head.

    NOTE(review): the name and layer sizes suggest this is modelled on the
    "CNN-F" network of Chatfield et al. (2014) -- confirm if that matters.

    Feature extractor (all activations are LeakyReLU, all norms are
    LocalResponseNorm(5)):
        block1: Conv 11x11, stride 4, 64 ch   -> act -> LRN -> MaxPool(2)
        block2: Conv 5x5, pad 2, 256 ch       -> act -> LRN -> MaxPool(2)
        block3: Conv 3x3, pad 1, 256 ch       -> act -> LRN
        block4: Conv 3x3, pad 1, 256 ch       -> act -> LRN
        block5: Conv 3x3, pad 1, 256 ch       -> act -> LRN -> MaxPool(2)
    Head: Flatten -> Linear(., 4096) -> act -> Dropout(0.5)
          -> Linear(4096, 4096) -> act -> Dropout(0.5) -> Linear(4096, n_output)

    Args:
        n_input: Number of input channels (e.g. 1 for grayscale, 3 for RGB).
        n_output: Number of output logits.
        in_size: Height/width of the square inputs. Used to size the first
            fully connected layer, so ``forward`` must receive tensors of shape
            ``(batch, n_input, in_size, in_size)``.

    ``forward`` returns unnormalised logits of shape ``(batch, n_output)``.
    """

    def __init__(self, n_input=3, n_output=4, in_size=224):
        super(CNN_F, self).__init__()

        self.n_input = n_input
        self.n_output = n_output
        self.in_size = in_size

        self.block1 = nn.Sequential(
            nn.Conv2d(n_input, 64, 11, stride=4, padding=0),
            nn.LeakyReLU(),
            nn.LocalResponseNorm(5),
            nn.MaxPool2d(2),
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(64, 256, 5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.LocalResponseNorm(5),
            nn.MaxPool2d(2),
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.LocalResponseNorm(5),
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.LocalResponseNorm(5),
        )

        self.block5 = nn.Sequential(
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.LocalResponseNorm(5),
            nn.MaxPool2d(2),
        )

        # Derive the flattened feature width from a dummy pass instead of
        # hard-coding it, so the head always matches `n_input` / `in_size`.
        # (For in_size=224: 224 -> 54 -> 27 -> 27 -> 13 -> 13 -> 13 -> 13 -> 6,
        # i.e. 256 * 6 * 6 = 9216 features.) The feature blocks contain no
        # dropout/batch-norm, so running them here has no side effects.
        with torch.no_grad():
            dummy = torch.zeros(1, n_input, in_size, in_size)
            n_flat = self._forward_features(dummy).flatten(1).shape[1]

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flat, 4096),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, n_output),
        )

    def _forward_features(self, x):
        """Run the five convolutional blocks in order and return their output."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

    def forward(self, x):
        """Map a ``(batch, n_input, in_size, in_size)`` tensor to ``(batch, n_output)`` logits."""
        return self.classifier(self._forward_features(x))


In [None]:
from torchsummary import summary

# Fall back to CPU so this cell also runs on machines without a GPU.
# torchsummary's `device` argument defaults to "cuda" and must match the
# device the model lives on, so pass it explicitly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CNN_F(1, 4, 224).to(device)
summary(model, (1, 224, 224), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 54, 54]           7,808
         LeakyReLU-2           [-1, 64, 54, 54]               0
 LocalResponseNorm-3           [-1, 64, 54, 54]               0
         MaxPool2d-4           [-1, 64, 27, 27]               0
            Conv2d-5          [-1, 256, 27, 27]         409,856
         LeakyReLU-6          [-1, 256, 27, 27]               0
 LocalResponseNorm-7          [-1, 256, 27, 27]               0
         MaxPool2d-8          [-1, 256, 13, 13]               0
            Conv2d-9          [-1, 256, 13, 13]         590,080
        LeakyReLU-10          [-1, 256, 13, 13]               0
LocalResponseNorm-11          [-1, 256, 13, 13]               0
           Conv2d-12          [-1, 256, 13, 13]         590,080
        LeakyReLU-13          [-1, 256, 13, 13]               0
LocalResponseNorm-14          [-1, 256,

: 