In [2]:
import torch
import torch.nn as nn

In [3]:
# Seed the global RNG so the sample inputs (and any weights initialised in
# later cells) are the same on every Restart & Run All
SEED = 0
torch.manual_seed(SEED)

# Image dimensions, defined once so the ANN and CNN inputs cannot drift apart
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# Define input image size
input_image_size = IMAGE_HEIGHT * IMAGE_WIDTH  # For ANN, flatten the image to a 1D vector
batch_size = 1  # Just a single batch for simplicity

# Sample image input for both ANN and CNN
# (two independent random draws; only their shapes are matched)
image_tensor_ann = torch.randn(batch_size, input_image_size)              # For ANN: (batch, H*W)
image_tensor_cnn = torch.randn(batch_size, 1, IMAGE_HEIGHT, IMAGE_WIDTH)  # For CNN: (batch, channel, H, W)


### ANN Model
class SimpleANN(nn.Module):
    """Single fully connected layer applied to a flattened image.

    Args:
        in_features: Length of each flattened input vector. When ``None``
            (the default), the notebook-level ``input_image_size`` is used,
            so ``SimpleANN()`` behaves exactly as before.
        out_features: Number of neurons in the layer. Defaults to 128.
    """

    def __init__(self, in_features=None, out_features=128):
        super().__init__()

        if in_features is None:
            # Looked up at construction time, so the zero-argument call keeps
            # following whatever the configuration cell defined.
            in_features = input_image_size

        # Fully connected layer: in_features * out_features weights
        # plus out_features biases
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer; the last dimension of ``x`` must be ``in_features``."""
        return self.fc(x)


### CNN Model
class SimpleCNN(nn.Module):
    """Single 3x3 convolution layer.

    Args:
        in_channels: Number of channels in the input image. Defaults to 1.
        out_channels: Number of convolution filters. Defaults to 32.

    The kernel size (3), stride (1) and padding (1) are fixed; with these
    values the output has the same height and width as the input.
    """

    def __init__(self, in_channels=1, out_channels=32):
        super().__init__()
        # 3x3 conv; parameter count is in_channels * out_channels * 9 weights
        # plus out_channels biases (320 with the defaults)
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              stride=1,
                              padding=1)

    def forward(self, x):
        """Convolve ``x``, expected as (batch, in_channels, height, width)."""
        return self.conv(x)

In [4]:
# Build both models (ANN first, then CNN), then run each on its sample input
ann_model, cnn_model = SimpleANN(), SimpleCNN()
ann_output = ann_model(image_tensor_ann)
cnn_output = cnn_model(image_tensor_cnn)


def _total_parameters(model):
    """Return the summed element count of every parameter tensor in ``model``."""
    return sum(tensor.numel() for tensor in model.parameters())


# Report how many parameters each model holds
ann_params = _total_parameters(ann_model)
print(f"ANN Parameter count: {ann_params}")

cnn_params = _total_parameters(cnn_model)
print(f"CNN Parameter count: {cnn_params}")

ANN Parameter count: 33554560
CNN Parameter count: 320
