# In [3]:
import torch
import torch.nn as nn

def get_activation(activation_name):
    """Return a new activation module for a case-insensitive name.

    Args:
        activation_name: One of "relu", "gelu", "silu" or "mish", in any
            letter case.

    Returns:
        A freshly constructed ``torch.nn`` activation module.

    Raises:
        ValueError: If the (lower-cased) name is not a supported activation.
    """
    # Map normalized names to torch.nn class *names*; the class itself is only
    # looked up on the chosen branch, mirroring the original lazy if/elif chain.
    class_names = {
        "relu": "ReLU",
        "gelu": "GELU",
        "silu": "SiLU",
        "mish": "Mish",
    }
    key = activation_name.lower()
    if key not in class_names:
        raise ValueError(f"Unsupported activation: {key}")
    return getattr(nn, class_names[key])()

class CustomCNN(nn.Module):
    """Stack of conv -> activation -> max-pool stages followed by an MLP head.

    Each entry of ``num_filters`` adds one stage: a stride-1 ``nn.Conv2d`` with
    "same" padding (spatial size preserved for any ``kernel_size``), the chosen
    activation, and ``nn.MaxPool2d(2)``, which halves height and width
    (flooring odd sizes). The head flattens the final feature map and applies
    Linear -> activation -> Linear.

    Args:
        input_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        num_filters: Output channels of each conv stage; its length sets the
            number of stages. An empty sequence yields a plain MLP over the
            raw input.
        kernel_size: Convolution kernel size (int or tuple accepted by
            ``nn.Conv2d``).
        activation: Activation name accepted by ``get_activation``.
        dense_units: Width of the hidden fully connected layer.
        input_size: Height/width of the square input images. It is used to
            size the first ``nn.Linear``, so inputs of any other size will fail
            in ``forward``. Defaults to 224.

    Raises:
        ValueError: If ``activation`` is unsupported, or ``input_size`` is too
            small to survive ``len(num_filters)`` halvings.
    """

    def __init__(self,
                 input_channels=3,
                 num_classes=10,
                 num_filters=(32, 64, 128, 128, 256),
                 kernel_size=3,
                 activation="relu",
                 dense_units=128,
                 input_size=224):
        super().__init__()

        # Resolve the activation once up front so an invalid name fails before
        # any layers are built; this instance is reused in the classifier head.
        head_activation = get_activation(activation)

        self.features = nn.Sequential()
        in_channels = input_channels
        for i, out_channels in enumerate(num_filters, start=1):
            # padding="same" keeps H and W unchanged for every kernel size
            # (stride is 1), so only the pooling layers change the spatial
            # size and feature_map_size below stays correct.
            self.features.add_module(
                f'conv{i}',
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=kernel_size, padding="same"),
            )
            # A fresh module per position, so hooks/inspection on one stage do
            # not silently apply to every stage.
            self.features.add_module(f'activation{i}', get_activation(activation))
            self.features.add_module(f'pool{i}', nn.MaxPool2d(kernel_size=2))
            in_channels = out_channels

        # Each MaxPool2d(2) floors the side length by half; repeated flooring
        # equals a single floor division by 2**k.
        num_stages = len(num_filters)
        self.feature_map_size = input_size // (2 ** num_stages)
        if self.feature_map_size < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {num_stages} "
                f"pooling stages (need at least {2 ** num_stages})"
            )

        # in_channels now holds the channel count of the last stage (or the
        # input channel count when there are no conv stages).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * self.feature_map_size * self.feature_map_size,
                      dense_units),
            head_activation,
            nn.Linear(dense_units, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: Image batch; assumed (batch, input_channels, input_size,
                input_size), as the linear layer size was fixed at
                construction. TODO confirm against callers.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.features(x)
        x = self.classifier(x)
        return x



# --- Captured notebook cell output (execution residue, not code) ---
#   0%|          | 314M/240G [01:26<18:19:53, 3.63MB/s]
#
#
# KeyboardInterrupt: