In [2]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CNN(nn.Module):
    """Configurable stack of convolution blocks followed by a two-layer classifier head.

    Each block applies ``Conv2d`` -> optional ``BatchNorm2d`` -> activation ->
    2x2 max pooling with stride 2. The pooled feature map is flattened and
    passed through ``fc1`` -> activation -> dropout -> ``fc2``.

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output features produced by ``fc2``.
        filter_sizes: Output channel count of each convolution block; its
            length determines the number of blocks.
        kernel_sizes: Integer kernel size of each convolution block. Must
            provide at least one entry per block; surplus entries are ignored.
        activation_fn: Callable applied after every convolution block and
            after ``fc1``.
        dense_neurons: Number of output features of ``fc1``.
        input_size: ``(height, width)`` of the input; used to size ``fc1``.
        use_batchnorm: If true, a ``BatchNorm2d`` follows every convolution.
        dropout_rate: Probability passed to ``nn.Dropout`` (applied after ``fc1``).

    Raises:
        ValueError: If ``kernel_sizes`` has fewer entries than ``filter_sizes``,
            or if ``input_size`` shrinks to zero after the pooling steps.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=10,
                 filter_sizes=[64, 64, 64, 64, 64],
                 kernel_sizes=[3, 3, 3, 3, 3],
                 activation_fn=nn.ReLU(),
                 dense_neurons=256,
                 input_size=(224, 224),
                 use_batchnorm=False,
                 dropout_rate=0.0):
        # Zero-argument super() cannot drift out of sync with the class name.
        super().__init__()

        # Work on private copies: accepts tuples and never touches the
        # (mutable) default argument lists.
        filter_sizes = list(filter_sizes)
        kernel_sizes = list(kernel_sizes)
        if len(kernel_sizes) < len(filter_sizes):
            raise ValueError(
                f"kernel_sizes has {len(kernel_sizes)} entries but "
                f"{len(filter_sizes)} are required (one per entry of filter_sizes)"
            )

        self.conv_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.use_batchnorm = use_batchnorm
        self.activation_fn = activation_fn

        # channels[i] -> channels[i + 1] is the channel mapping of block i.
        channels = [in_channels] + filter_sizes

        # Track the spatial size through every block so fc1 matches what
        # forward() actually produces.
        height, width = input_size

        for i in range(len(filter_sizes)):
            kernel = kernel_sizes[i]
            padding = kernel // 2
            self.conv_layers.append(
                nn.Conv2d(channels[i], channels[i + 1], kernel_size=kernel, padding=padding)
            )
            if self.use_batchnorm:
                self.batch_norms.append(nn.BatchNorm2d(channels[i + 1]))

            # Stride-1 convolution: out = in + 2 * padding - kernel + 1.
            # With padding = kernel // 2 this keeps the size for odd kernels
            # but grows it by one for even kernels. The 2x2 / stride-2 max
            # pooling then halves it (floor division).
            height = (height + 2 * padding - kernel + 1) // 2
            width = (width + 2 * padding - kernel + 1) // 2

        if height < 1 or width < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for "
                f"{len(filter_sizes)} pooling steps"
            )

        # channels[-1] equals filter_sizes[-1], and falls back to in_channels
        # when there are no convolution blocks at all.
        flattened_size = channels[-1] * height * width

        # Fully connected classifier head.
        self.fc1 = nn.Linear(flattened_size, dense_neurons)
        self.fc2 = nn.Linear(dense_neurons, num_classes)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Run the convolution blocks and the classifier head on ``x``.

        Args:
            x: Input tensor accepted by ``nn.Conv2d``. Its channel count and
                spatial size should match ``in_channels`` / ``input_size``
                given at construction so the flattened features fit ``fc1``.

        Returns:
            The output of ``fc2`` (no softmax is applied here).
        """
        for i, conv in enumerate(self.conv_layers):
            x = conv(x)
            if self.use_batchnorm:
                x = self.batch_norms[i](x)
            x = self.activation_fn(x)
            x = F.max_pool2d(x, 2, 2)  # 2x2 window, stride 2

        # Collapse everything except the batch dimension; flatten() also
        # handles non-contiguous tensors, unlike view().
        x = x.flatten(1)

        x = self.activation_fn(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
import os
import numpy as np
from PIL import Image
import wandb
import matplotlib.pyplot as plt