In [1]:
import torch
import torch.nn as nn

In [None]:
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for 28x28 images (MNIST-sized by default).

    Architecture: [Conv3x3 -> ReLU -> MaxPool2] x 2 -> Flatten ->
    LazyLinear(128) -> ReLU -> Linear(num_classes).

    Args:
        in_channels: number of input image channels (1 for MNIST).
        num_classes: number of output logits (10 for MNIST).
        verbose: if True, ``forward`` prints the feature-map shape before and
            after flattening on every call. Useful for shape debugging; pass
            False when training so the loop is not flooded with output.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10, verbose: bool = True):
        super().__init__()
        self.verbose = verbose
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),  # (C, 28, 28) -> (32, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                           # -> (32, 14, 14)
            nn.Conv2d(32, 64, 3, padding=1),           # -> (64, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                           # -> (64, 7, 7)
        )

        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            # LazyLinear infers in_features automatically on the first forward
            # call (64 * 7 * 7 = 3136 for 28x28 input). Its weights are
            # uninitialized until then, so run one dummy batch before creating
            # an optimizer over model.parameters().
            nn.LazyLinear(128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, num_classes).

        Assumes ``x`` is a batched image tensor (batch, in_channels, H, W);
        nn.Flatten flattens from dim 1, so an unbatched input would be
        flattened incorrectly.
        """
        x = self.features(x)
        if self.verbose:
            print("Before Flatten:", x.shape)  # feature-map shape before flattening

        x = self.flatten(x)
        if self.verbose:
            print("After Flatten:", x.shape)   # shape after flattening

        return self.classifier(x)

In [5]:
# Smoke test: build the model and push one dummy MNIST-sized batch through it.
torch.manual_seed(0)  # fixed seed so the dummy input and LazyLinear init are reproducible
model = SimpleCNN()
dummy_input = torch.randn(1, 1, 28, 28)  # (batch, channels, height, width)
output = model(dummy_input)  # this first call also materializes the LazyLinear layer
print("Output shape:", output.shape)
assert output.shape == (1, 10), f"unexpected output shape {tuple(output.shape)}"

Before Flatten: torch.Size([1, 64, 7, 7])
After Flatten: torch.Size([1, 3136])
Output shape: torch.Size([1, 10])
