In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import Tensor

In [6]:
class ResidualBlock(nn.Module):
    """Basic ResNet-style residual block: two 3x3 conv-BN stages plus a shortcut.

    Main path: ``conv1 -> bn1 -> ReLU -> conv2 -> bn2``.
    Shortcut: the identity when input and output shapes already match;
    otherwise a 1x1 strided convolution (``downsample``) followed by batch
    norm (``sh_conv``), so that the residual can be added element-wise.
    A final ReLU is applied after the addition.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution; a value > 1 spatially downsamples the input.
    """

    def __init__(self, inplanes: int, planes: int, stride: int) -> None:
        super().__init__()
        self.in_channels = inplanes
        self.out_channels = planes
        self.stride = stride
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
        )
        self.bn1 = nn.BatchNorm2d(self.out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            padding=1,
        )
        self.bn2 = nn.BatchNorm2d(self.out_channels)
        # Projection shortcut (1x1 conv + BN): maps the input to the output's
        # channel count and spatial size. It is always constructed, but
        # forward() only uses it when the identity shortcut would not fit.
        self.downsample = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=self.stride,
        )
        self.sh_conv = nn.BatchNorm2d(self.out_channels)

    def forward(self, X: Tensor) -> Tensor:
        """Apply the residual block.

        Args:
            X: Input feature map with ``inplanes`` channels
               (assumed layout (N, C, H, W), as required by ``nn.Conv2d``).

        Returns:
            Feature map with ``planes`` channels, spatially downsampled by
            ``stride``, after the final ReLU.
        """
        # The identity shortcut only works when channels and spatial size
        # are unchanged; otherwise project the input with conv + BN.
        if self.in_channels != self.out_channels or self.stride > 1:
            residual = self.sh_conv(self.downsample(X))
        else:
            residual = X

        out = self.relu(self.bn1(self.conv1(X)))
        # The second conv consumes the first stage's output, not the raw input.
        out = self.bn2(self.conv2(out))

        # Out-of-place add keeps autograd safe regardless of what `residual` aliases.
        out = out + residual
        return self.relu(out)

In [26]:
# Smoke test: push a random batch through a block whose channel count changes
# (3 -> 64), which exercises the projection-shortcut path.
torch.manual_seed(0)  # reproducible random input and weight initialisation

batch_size = 16
channels = 3  # number of input channels
height, width = 32, 32  # spatial size of the image

# Build a random example input of shape (batch, channels, height, width).
input_tensor = torch.randn(batch_size, channels, height, width)

# Residual block parameters.
inplanes = channels
planes = 64
stride = 1

# Create an instance of the residual block.
residual_block = ResidualBlock(inplanes, planes, stride)

# Forward pass only: no gradients are needed for a shape check.
with torch.no_grad():
    output_tensor = residual_block(input_tensor)

# Channels become `planes`; with kernel 3 and padding 1 the spatial size is
# floor((H - 1) / stride) + 1.
expected_hw = ((height - 1) // stride + 1, (width - 1) // stride + 1)
assert output_tensor.shape == (batch_size, planes, *expected_hw), output_tensor.shape

# Report input and output sizes.
print("Dimensioni dell'input:", input_tensor.size())
print("Dimensioni dell'output:", output_tensor.size())

Dimensioni dell'input: torch.Size([16, 3, 32, 32])
Dimensioni dell'output: torch.Size([16, 64, 32, 32])
