In [None]:
import torch
import torch.nn as nn

class InceptionModule(nn.Module):
    """GoogLeNet-style Inception block: four parallel branches joined on channels.

    Branches (every convolution is followed by an in-place ReLU):

    * 1x1 conv                               -> ``out_1x1`` channels
    * 1x1 reduce conv -> 3x3 conv (pad 1)    -> ``out_3x3`` channels
    * 1x1 reduce conv -> 5x5 conv (pad 2)    -> ``out_5x5`` channels
    * 3x3 max-pool (stride 1, pad 1) -> 1x1  -> ``out_maxpool`` channels

    Every branch preserves the spatial size, so the branch outputs can be
    concatenated along the channel axis.

    Args:
        in_channels: Number of channels in the input feature map.
        out_1x1: Output channels of the 1x1 branch.
        out_3x3: Output channels of the 3x3 branch.
        out_5x5: Output channels of the 5x5 branch.
        out_maxpool: Output channels of the pooling branch's 1x1 projection.
        reduce_3x3: Width of the 1x1 reduction conv in front of the 3x3 conv.
            Defaults to ``out_3x3``, which matches the original behaviour.
        reduce_5x5: Width of the 1x1 reduction conv in front of the 5x5 conv.
            Defaults to ``out_5x5``, which matches the original behaviour.

    Shape:
        - Input: ``(N, in_channels, H, W)`` or unbatched ``(in_channels, H, W)``.
        - Output: ``(N, out_channels, H, W)`` or ``(out_channels, H, W)``, where
          ``out_channels = out_1x1 + out_3x3 + out_5x5 + out_maxpool``.
    """

    def __init__(self, in_channels, out_1x1, out_3x3, out_5x5, out_maxpool,
                 reduce_3x3=None, reduce_5x5=None):
        super().__init__()

        if reduce_3x3 is None:
            reduce_3x3 = out_3x3
        if reduce_5x5 is None:
            reduce_5x5 = out_5x5

        # Total channel count produced by forward(); useful when stacking blocks.
        self.out_channels = out_1x1 + out_3x3 + out_5x5 + out_maxpool

        # NOTE: the attribute names and Sequential layouts below are kept exactly
        # as before so that existing state_dict checkpoints keep loading.

        # 1x1 convolution branch
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_1x1, kernel_size=1),
            nn.ReLU(inplace=True),
        )

        # 3x3 convolution branch: cheap 1x1 reduction first, then the 3x3 conv.
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, reduce_3x3, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduce_3x3, out_3x3, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # 5x5 convolution branch: 1x1 reduction, then the 5x5 conv.
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, reduce_5x5, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduce_5x5, out_5x5, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
        )

        # Max-pooling branch: stride 1 / padding 1 keeps H and W unchanged.
        self.branch_maxpool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, out_maxpool, kernel_size=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on channels."""
        out1x1 = self.branch1x1(x)
        out3x3 = self.branch3x3(x)
        out5x5 = self.branch5x5(x)
        out_maxpool = self.branch_maxpool(x)

        # dim=-3 is the channel axis for both batched (N, C, H, W) and
        # unbatched (C, H, W) input. dim=1 would be H in the unbatched case.
        return torch.cat([out1x1, out3x3, out5x5, out_maxpool], dim=-3)

# Example usage: build one Inception block for a 64-channel feature map.
# The four branch widths add up to 16 + 24 + 8 + 16 = 64 output channels.
in_channels, out_1x1, out_3x3, out_5x5, out_maxpool = 64, 16, 24, 8, 16

inception = InceptionModule(
    in_channels=in_channels,
    out_1x1=out_1x1,
    out_3x3=out_3x3,
    out_5x5=out_5x5,
    out_maxpool=out_maxpool,
)