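To hammer home how layer dimensions work: a fully convolutional network accepts any input size, while a network with a fully connected head is tied to the input size it was built for.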

In [3]:
import torch
import torch.nn as nn

# Define network 1: no dependency on the input image size (convolutional layers only)
class ConvDeconvNet(nn.Module):
    """Convolutional encoder followed by a transposed-convolution decoder.

    The encoder applies two stride-2 3x3 convolutions (3 -> 64 -> 128
    channels), each followed by a ReLU. The decoder applies two stride-2 3x3
    transposed convolutions (128 -> 64 -> 3 channels), with a ReLU in between
    and a sigmoid at the end. Every layer is configured by channel counts
    only; no spatial size is passed to any constructor.
    """

    def __init__(self):
        super().__init__()

        # Settings shared by both layers of each half of the network.
        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        up_cfg = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        # Layers are created in encoder-then-decoder order.
        down_a = nn.Conv2d(3, 64, **down_cfg)
        down_b = nn.Conv2d(64, 128, **down_cfg)
        self.encoder = nn.Sequential(down_a, nn.ReLU(), down_b, nn.ReLU())

        up_a = nn.ConvTranspose2d(128, 64, **up_cfg)
        up_b = nn.ConvTranspose2d(64, 3, **up_cfg)
        self.decoder = nn.Sequential(up_a, nn.ReLU(), up_b, nn.Sigmoid())

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for that encoding."""
        return self.decoder(self.encoder(x))

# Define network 2: depends on the input image size (the first Linear layer fixes the flattened feature count)
class ConvFCNet(nn.Module):
    """Convolutional feature extractor followed by a fully connected classifier.

    Two stride-2 3x3 convolutions (3 -> 64 -> 128 channels) feed a flatten
    step and two linear layers (-> 256 -> ``num_classes``). The first linear
    layer needs to know the flattened feature count, which depends on the
    spatial size of the input image; the network therefore only accepts
    images of the size it was constructed for.

    Args:
        input_size: Spatial size of the images the network will receive.
            Either an int (square images) or a ``(height, width)`` pair.
            Defaults to 32, which gives the original ``128 * 8 * 8`` features.
        num_classes: Number of output units. Defaults to 10.

    Raises:
        ValueError: If a spatial dimension is smaller than 1.
    """

    # Hyper-parameters shared by both convolutions; also used to derive the
    # flattened feature count so the two can never drift apart.
    _KERNEL, _STRIDE, _PADDING = 3, 2, 1
    _NUM_CONVS = 2

    def __init__(self, input_size=32, num_classes=10):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size must be at least 1 in each dimension, got {input_size!r}"
            )

        conv_cfg = dict(
            kernel_size=self._KERNEL, stride=self._STRIDE, padding=self._PADDING
        )
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, **conv_cfg),
            nn.ReLU(),
            nn.Conv2d(64, 128, **conv_cfg),
            nn.ReLU(),
        )

        # Track how each convolution shrinks the feature map, e.g.
        # 32 -> 16 -> 8 for the default input size.
        for _ in range(self._NUM_CONVS):
            height = self._conv_output_size(height)
            width = self._conv_output_size(width)

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            # Here the input image size matters: in_features is fixed at build time.
            nn.Linear(128 * height * width, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    @classmethod
    def _conv_output_size(cls, size):
        """Return the output length of one convolution along a dimension of length ``size``."""
        return (size + 2 * cls._PADDING - cls._KERNEL) // cls._STRIDE + 1

    def forward(self, x):
        """Run the convolutions, then the classifier, and return the class scores."""
        conv_out = self.conv_layers(x)
        fc_out = self.fc_layers(conv_out)
        return fc_out

In [4]:
# Random test image laid out as (batch=1, channels=3, height=32, width=32).
input_image = torch.randn((1, 3, 32, 32))

In [5]:
# Build the conv/deconv network and push the 32x32 image through it.
net1 = ConvDeconvNet()
output1 = net1(input_image)
print("Output shape of ConvDeconvNet:", output1.size())

Output shape of ConvDeconvNet: torch.Size([1, 3, 32, 32])


In [6]:
# Build the conv + fully connected network and push the same 32x32 image through it.
net2 = ConvFCNet()
output2 = net2(input_image)
print("Output shape of ConvFCNet:", output2.size())

Output shape of ConvFCNet: torch.Size([1, 10])


In [7]:
# Second random test image, now 64x64: (batch=1, channels=3, height=64, width=64).
input_image_2 = torch.randn((1, 3, 64, 64))

In [8]:
# Feed the larger image to the already-built conv/deconv network.
output1_2 = net1(input_image_2)
print("Output shape of ConvDeconvNet with 64x64 input:", output1_2.size())

Output shape of ConvDeconvNet with 64x64 input: torch.Size([1, 3, 64, 64])


In [9]:
# We expect an error here: net2's first Linear layer was sized for a 32x32
# input, so the 64x64 image flattens to a different number of features.
# The error is the point of the demo, so catch and display it instead of
# letting it abort a Restart & Run All.
try:
    output2_2 = net2(input_image_2)
except RuntimeError as err:
    print(f"RuntimeError: {err}")
else:
    print("Output shape of ConvFCNet with 64x64 input:", output2_2.shape)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x32768 and 8192x256)