Description
Course
machine-learning-zoomcamp
Question
When building a Convolutional Neural Network (CNN) with PyTorch, such as in the homework for Module 8, one adds one or more convolution, pooling, etc. layers, then flattens the output to a vector, and finally adds one or more linear layers to produce the final prediction(s). The question is: how do you know how many features are output from the convolution layer(s), after flattening, and passed as input features to the first linear layer?
Answer
Let's start by defining a CNN model whose output is flattened and to which we'd like to add a linear layer, but we're not sure how many input features that layer should have.
import torch
import torch.nn as nn

# Convolutional layers only -- no linear layer yet
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=16,
                               kernel_size=(2, 2),
                               stride=2,
                               padding=2)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d((3, 3))

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        return x

The number of features output from the CNN layers of a model, after flattening, can be obtained in many ways:
- Attempt to calculate the number of features produced by each convolution and pooling layer by hand, factoring in the input size, kernel size(s), stride(s), padding, and so on; however, this is error-prone and I don't recommend this approach, so I won't derive the equations in detail beyond the quick sketch below.
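For reference only, here is a quick sketch of that arithmetic for the CNN defined above, using the standard output-size formula for convolution and pooling layers; the numbers agree with the torchinfo summary shown further down.
# Standard output-size formula applied by hand to the CNN above (a cross-check,
# not the recommended workflow): out = floor((in + 2*padding - kernel) / stride) + 1
size = 150                           # input image height/width (square image assumed)
size = (size + 2 * 2 - 2) // 2 + 1   # Conv2d: kernel 2, stride 2, padding 2 -> 77
size = (size - 3) // 3 + 1           # MaxPool2d(3): kernel 3, stride 3, no padding -> 25
print(16 * size * size)              # 16 output channels * 25 * 25 = 10000 flattened features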
- Use the summary function from the torchinfo package, which reports the output shape of each layer, including the flattened output of the CNN layers:
!pip install torchinfo
from torchinfo import summary
input_size = (1, 3, 150, 150) # Need to include batch size
model = CNN()
summary(model, input_size=input_size)
# Output
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
CNN [1, 10000] --
├─Conv2d: 1-1 [1, 16, 77, 77] 208
├─ReLU: 1-2 [1, 16, 77, 77] --
├─MaxPool2d: 1-3 [1, 16, 25, 25] --
==========================================================================================
Total params: 208
Trainable params: 208
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 1.23
==========================================================================================
Input size (MB): 0.27
Forward/backward pass size (MB): 0.76
Params size (MB): 0.00
Estimated Total Size (MB): 1.03
==========================================================================================
The summary function from the torchsummary package can also be used, but it requires a little more care when using a GPU, and the batch size should not be included in the input_size argument (e.g. (3, 150, 150)).
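As a minimal sketch of the torchsummary route (exact defaults can vary between versions; older releases default to device="cuda", so either pass device="cpu" or move the model to the GPU first):
# !pip install torchsummary
from torchsummary import summary

model = CNN()
# Note: no batch dimension in input_size here; device="cpu" avoids the CUDA default
# used by some torchsummary versions (an assumption -- check your installed version).
summary(model, input_size=(3, 150, 150), device="cpu")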
- Pass a dummy input into the model to measure how many features are output. Construct the model with only the convolution layer(s), like the CNN model above, then pass in a dummy tensor that matches the size of the images fed into the network:
model = CNN()
# Create a dummy input tensor (e.g., batch_size=1, image size (3, 150, 150))
dummy_input = torch.randn(1, 3, 150, 150)
# Pass the dummy input through the convolutional layers
output_features = model(dummy_input)
# Get the calculated in_features for nn.Linear
in_features_calc = output_features.size(1)
print(in_features_calc)
The in_features_calc value is the number of features that will be passed into the first linear layer of the model. For example, in the CNN class above, you would define self.fc = nn.Linear(10000, n_outputs) in __init__ (where n_outputs is however many predictions the model should produce) and call x = self.fc(x) at the end of forward. Which tensor dimension to read depends on the model's last layer; in the case above, the flatten(x, 1) call at the end of the forward pass means we want the size along the second dimension, size(1).
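Putting it together, a minimal sketch of wiring the measured feature count into the first linear layer (the class name FullCNN and num_classes are illustrative, not taken from the homework):
import torch
import torch.nn as nn

class FullCNN(nn.Module):
    def __init__(self, in_features=10000, num_classes=1):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=(2, 2), stride=2, padding=2)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d((3, 3))
        # in_features comes from the dummy-input measurement above (10000 here)
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

model = FullCNN(in_features=in_features_calc)    # or FullCNN(in_features=10000)
print(model(torch.randn(1, 3, 150, 150)).shape)  # torch.Size([1, 1])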
- It's also worth mentioning that a lazy fully connected layer can be used for the first layer after the CNN layer(s) so that the number of input features doesn't have to be calculated, but instead will be inferred automatically. See the nn.LazyLinear documentation for details.
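As a minimal sketch of that approach (the class name LazyCNN and num_classes are illustrative), the same model with nn.LazyLinear, where in_features is inferred on the first forward pass:
import torch
import torch.nn as nn

class LazyCNN(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=(2, 2), stride=2, padding=2)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d((3, 3))
        self.fc = nn.LazyLinear(num_classes)  # in_features inferred automatically

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

model = LazyCNN()
out = model(torch.randn(1, 3, 150, 150))  # first pass materializes the fc weights
print(out.shape)                          # torch.Size([1, 1])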
Checklist
- I have searched existing FAQs and this question is not already answered
- The answer provides accurate, helpful information
- I have included any relevant code examples or links