In [2]:
import torch
import torch.nn as nn

# Number of channels in the input images (1 = grayscale).
IMG_CHS = 1
# Side length of the square kernel shared by every convolution layer.
kernel_size = 3
# Width of the final linear layer, i.e. one output score per class.
n_classes = 24

class CNNModel(nn.Module):
    """Three-stage convolutional classifier with a two-layer dense head.

    Every stage is Conv2d -> BatchNorm2d -> ReLU -> [Dropout] -> MaxPool2d(2, 2)
    with 25, 50 and 75 output channels respectively; only the second stage
    carries dropout (p=0.2). The flattened feature map goes through
    ``LazyLinear(512)``, so the flattened size is inferred on the first forward
    pass instead of being hard-coded.

    Args:
        in_channels: Number of channels of the input images. ``None`` (default)
            falls back to the module-level ``IMG_CHS``.
        num_classes: Width of the output layer. ``None`` (default) falls back
            to the module-level ``n_classes``.
        conv_kernel: Kernel size of every convolution. ``None`` (default) falls
            back to the module-level ``kernel_size``.
        verbose: When True (the default, which keeps the previous unconditional
            behaviour) ``forward`` prints the tensor shape right before and
            right after flattening. Pass False to silence it, e.g. while
            training.
    """

    def __init__(self, in_channels=None, num_classes=None, conv_kernel=None,
                 verbose=True):
        super().__init__()
        # Resolve the fallbacks lazily so explicit arguments never touch the
        # notebook globals, while CNNModel() keeps working exactly as before.
        if in_channels is None:
            in_channels = IMG_CHS
        if num_classes is None:
            num_classes = n_classes
        if conv_kernel is None:
            conv_kernel = kernel_size

        self.verbose = verbose

        self.conv1 = self._conv_stage(in_channels, 25, conv_kernel)
        self.conv2 = self._conv_stage(25, 50, conv_kernel, dropout=0.2)
        self.conv3 = self._conv_stage(50, 75, conv_kernel)
        self.flatten = nn.Flatten()
        self.fc1 = nn.LazyLinear(512)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(512, num_classes)

    @staticmethod
    def _conv_stage(in_ch, out_ch, kernel, dropout=None):
        """Build one Conv2d -> BatchNorm2d -> ReLU -> [Dropout] -> MaxPool2d stage."""
        # kernel // 2 equals the former hard-coded padding of 1 for a 3x3
        # kernel and keeps the spatial size unchanged for any odd kernel.
        if isinstance(kernel, int):
            padding = kernel // 2
        else:
            padding = tuple(k // 2 for k in kernel)

        layers = [
            nn.Conv2d(in_ch, out_ch, kernel, stride=1, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        if dropout is not None:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.MaxPool2d(2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and return the raw output scores.

        Args:
            x: Image batch accepted by ``Conv2d``; the notebook feeds a tensor
                of shape (1, 1, 28, 28).

        Returns:
            The output of the last linear layer (no softmax is applied); for
            batched input its shape is (batch, num_classes).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.verbose:
            print("Before Flatten:", x.shape)  # shape right before flattening
        x = self.flatten(x)
        if self.verbose:
            print("Flatten shape:", x.shape)  # shape right after flattening
        x = self.fc1(x)
        # Dropout runs before ReLU here; the order is kept as in the original.
        x = self.dropout(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

In [3]:
# Smoke test: push one random 28x28 image through the model and check shapes.
# Expected with the default constants:
#   Before Flatten: torch.Size([1, 75, 3, 3])
#   Output shape:   torch.Size([1, 24])

model = CNNModel()
# Take the channel count from IMG_CHS so the dummy batch always matches the
# model's first convolution instead of silently assuming grayscale.
x = torch.randn(1, IMG_CHS, 28, 28)
out = model(x)
print("Output shape:", out.shape)
# Turn the expectation into a check so a wrong head size fails the cell.
assert out.shape == (1, n_classes), f"unexpected output shape {tuple(out.shape)}"

Before Flatten: torch.Size([1, 75, 3, 3])
Flatten shape: torch.Size([1, 675])
Output shape: torch.Size([1, 24])
