### SegNet-Basic.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderBlock(nn.Module):
    """
    A single encoder block for SegNet-Basic.

    Layer pipeline:
        Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2

    The pooling layer is created with return_indices=True, so besides the
    downsampled feature map the block also hands back the positions of the
    maxima. Those indices are what a MaxUnpool2d layer needs in order to
    undo the pooling later on.
    """
    def __init__(self, in_channels, out_channels):
        """
        Build the layers of the encoder block.

        Args:
            in_channels (int): Channel count of the incoming feature map.
            out_channels (int): Channel count produced by both convolutions.
        """
        super().__init__()  # Required so nn.Module can register the sub-layers below.

        # Both convolutions share the same geometry:
        #   kernel_size=3, padding=1 -> height and width are unchanged by the convolution.
        #   bias=False               -> each convolution is followed directly by BatchNorm,
        #                               whose learnable shift makes a conv bias redundant.
        conv_options = dict(kernel_size=3, padding=1, bias=False)

        # Stage 1: in_channels -> out_channels, normalized per channel.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Stage 2: keeps the channel count at out_channels.
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 2x2 window moved in steps of 2 -> spatial size is halved (e.g. 10x10 -> 5x5).
        # return_indices=True makes the layer return (pooled, indices) instead of
        # just the pooled tensor.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, x):
        """
        Run the encoder block.

        Args:
            x (torch.Tensor): Input feature map laid out as (N, C, H, W).

        Returns:
            tuple: A pair (pooled_x, indices):
                - pooled_x (torch.Tensor): Feature map after the 2x2 max-pool.
                - indices (torch.Tensor): Positions of the maxima chosen by the pool.
        """
        # Stage 1: convolution, normalization, then the ReLU non-linearity.
        features = self.conv1(x)
        features = self.bn1(features)
        features = F.relu(features)

        # Stage 2: same pattern on the stage-1 activations.
        features = self.conv2(features)
        features = self.bn2(features)
        features = F.relu(features)

        # Downsample and keep the argmax positions for the decoder side.
        pooled_x, indices = self.pool(features)
        return pooled_x, indices

In [3]:
class DecoderBlock(nn.Module):
    """
    A single decoder block for SegNet-Basic.

    Layer pipeline:
        MaxUnpool2x2 (driven by encoder indices) -> Conv7x7 -> BatchNorm -> Conv7x7 -> BatchNorm

    Unlike the encoder there is no ReLU after the normalizations, and the
    convolutions are bias-free with a 7x7 kernel.
    """
    def __init__(self, in_channels, out_channels):
        """
        Build the layers of the decoder block.

        Args:
            in_channels (int): Channel count of the tensor that is unpooled and fed
                               to the first convolution.
            out_channels (int): Channel count produced by both convolutions.
        """
        super().__init__()

        # Inverse of a MaxPool2d(kernel_size=2, stride=2): values are scattered back to
        # the positions given by the pooling indices, every other position is zero.
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

        # Shared convolution geometry:
        #   kernel_size=7, padding=3 -> height and width are unchanged; the wide window
        #                               lets each output pixel draw on many of the sparse
        #                               non-zero entries left by the unpooling.
        #   bias=False               -> BatchNorm follows each convolution and supplies
        #                               the learnable shift.
        conv_options = dict(kernel_size=7, padding=3, bias=False)

        # Stage 1: in_channels -> out_channels.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Stage 2: refines the features, channel count stays at out_channels.
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, indices, output_size):
        """
        Run the decoder block.

        Args:
            x (torch.Tensor): Low-resolution input laid out as (N, C, H_low, W_low).
            indices (torch.Tensor): Max-pooling indices produced by the pooling step
                                    that this block is meant to undo.
            output_size (tuple): Target size of the unpooled map, i.e. the size the
                                 feature map had before it was pooled. Passing it
                                 explicitly resolves the ambiguity that arises when the
                                 pre-pool height/width was odd.

        Returns:
            torch.Tensor: Dense feature map with out_channels channels at the
                          unpooled resolution.
        """
        # Scatter the low-res values back to their original argmax positions.
        sparse_map = self.unpool(x, indices, output_size=output_size)

        # Stage 1: convolution followed by normalization (deliberately no ReLU).
        dense_map = self.conv1(sparse_map)
        dense_map = self.bn1(dense_map)

        # Stage 2: same pattern (deliberately no ReLU).
        dense_map = self.conv2(dense_map)
        dense_map = self.bn2(dense_map)

        return dense_map

In [4]:
class SegNetBasic(nn.Module):
    """
    The full SegNet-Basic architecture for semantic segmentation.

    Four EncoderBlocks successively halve the spatial resolution while recording
    their max-pooling indices; four DecoderBlocks then undo the poolings in reverse
    order using those indices. The last decoder emits one channel per class and the
    result is returned as per-pixel log-probabilities.
    """
    def __init__(self, num_classes=11, in_channels=3):
        """
        Constructor for SegNetBasic.

        Args:
            num_classes (int): Number of semantic classes for segmentation (e.g., 11 for CamVid).
            in_channels (int): Number of input channels for the image (e.g., 3 for RGB).
        """
        super(SegNetBasic, self).__init__()

        # Channel progression.
        # Encoder: in_channels -> 64 -> 128 -> 256 -> 512
        # Decoder: 512 -> 256 -> 128 -> 64 -> num_classes
        self.encoder_channels = [in_channels, 64, 128, 256, 512]
        self.decoder_channels = [512, 256, 128, 64]

        # --- Encoder Pathway ---
        # Each encoder returns (pooled tensor, pooling indices).
        self.enc1 = EncoderBlock(self.encoder_channels[0], self.encoder_channels[1]) # in_channels -> 64
        self.enc2 = EncoderBlock(self.encoder_channels[1], self.encoder_channels[2]) # 64 -> 128
        self.enc3 = EncoderBlock(self.encoder_channels[2], self.encoder_channels[3]) # 128 -> 256
        self.enc4 = EncoderBlock(self.encoder_channels[3], self.encoder_channels[4]) # 256 -> 512

        # --- Decoder Pathway ---
        # decK undoes the pooling of enc(5-K): dec1 <-> enc4, ..., dec4 <-> enc1.
        self.dec1 = DecoderBlock(self.decoder_channels[0], self.decoder_channels[1]) # 512 -> 256
        self.dec2 = DecoderBlock(self.decoder_channels[1], self.decoder_channels[2]) # 256 -> 128
        self.dec3 = DecoderBlock(self.decoder_channels[2], self.decoder_channels[3]) # 128 -> 64
        # The last decoder maps straight to the class scores, so no separate
        # classifier layer is defined; log_softmax is applied in forward().
        self.dec4 = DecoderBlock(self.decoder_channels[3], num_classes)              # 64 -> num_classes

    def forward(self, x):
        """
        Defines the forward pass for SegNetBasic.

        Args:
            x (torch.Tensor): Input image tensor (N, in_channels, H, W). Because the
                              resolution is halved four times, H and W should each be
                              at least 16; they do not have to be multiples of 16.

        Returns:
            torch.Tensor: Log-probabilities for each pixel belonging to a class
                          (N, num_classes, H, W). Pair this output with nn.NLLLoss;
                          nn.CrossEntropyLoss would apply log_softmax a second time.
        """
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4)
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4)

        # --- Encoder Forward Pass ---
        # For every encoder remember (a) its pooling indices and (b) the (H, W) of the
        # map that was pooled. The encoder convolutions preserve height and width, so
        # the encoder's *input* (H, W) equals the pre-pool (H, W). Only the spatial
        # part is stored: the channel count of the encoder input differs from that of
        # the tensor being unpooled and is irrelevant to MaxUnpool2d.
        unpool_info = []
        for encoder in encoders:
            pre_pool_hw = x.shape[-2:]
            x, pool_indices = encoder(x)
            unpool_info.append((pool_indices, pre_pool_hw))

        # --- Decoder Forward Pass ---
        # Walk the recorded information backwards so that the deepest encoder's
        # indices are consumed first: (N, 512, H/16, W/16) -> ... -> (N, num_classes, H, W).
        # Supplying the exact pre-pool size lets MaxUnpool2d restore odd dimensions.
        for decoder, (pool_indices, pre_pool_hw) in zip(decoders, reversed(unpool_info)):
            x = decoder(x, pool_indices, output_size=pre_pool_hw)

        # Normalize the class scores across the channel dimension (dim=1) to obtain
        # per-pixel log-probabilities.
        return F.log_softmax(x, dim=1)

In [None]:
# Conceptual Dataset Class (you'd replace this with a real dataset like CamVidDataset)
class DummySegmentationDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that fabricates random (image, mask) pairs.

    Intended purely as a stand-in so the training loop can be exercised without
    real data. Samples are drawn fresh on every access, so indexing the same
    position twice yields different tensors.
    """
    def __init__(self, num_samples=100, img_height=360, img_width=480, num_classes=11):
        """
        Args:
            num_samples (int): Value reported by len(); number of valid indices.
            img_height (int): Height of every generated image and mask.
            img_width (int): Width of every generated image and mask.
            num_classes (int): Mask labels are drawn from 0 .. num_classes - 1.
        """
        self.num_samples = num_samples
        self.img_height = img_height
        self.img_width = img_width
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of samples the dataset claims to hold."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Generate one random sample.

        Args:
            idx (int): Sample position; negative values count from the end, as with
                       a list. The value only selects validity, not content.

        Returns:
            tuple: (image, mask) where image is a float tensor of shape
                   (3, img_height, img_width) and mask is a long tensor of shape
                   (img_height, img_width) with values in [0, num_classes).

        Raises:
            IndexError: If idx is outside [-num_samples, num_samples). Without this,
                        plain iteration (``for sample in dataset``) would never stop,
                        because Python's fallback iteration protocol relies on
                        IndexError to detect the end of the sequence.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset with {self.num_samples} samples"
            )

        # Random dummy image (3 channels, e.g. RGB).
        image = torch.randn(3, self.img_height, self.img_width)
        # Random dummy segmentation mask: one integer class label per pixel.
        mask = torch.randint(0, self.num_classes, (self.img_height, self.img_width), dtype=torch.long)
        return image, mask

# Instantiate the dataset and DataLoader.
# The dummy dataset is a placeholder: swap it for real data loading logic
# (e.g. a torchvision dataset, or a custom CamVid dataset that applies the
# resize / normalize / to-tensor transformations to images and masks).
train_dataset = DummySegmentationDataset() # All defaults: 100 samples of 360x480 with 11 classes.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True) # Batches of 4, reshuffled every epoch.

# For a realistic example, consider an image preprocessing step:
# from torchvision import transforms
# preprocess = transforms.Compose([
#     transforms.Resize((360, 480)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])
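# # For masks: the integer class IDs must survive untouched, so
# #   - resize with nearest-neighbour interpolation (the default bilinear mode would blend labels),
# #   - use PILToTensor instead of ToTensor (ToTensor rescales uint8 values to [0, 1],
# #     after which .long() would turn almost every label into 0),
# #   - apply no normalization.
# from torchvision.transforms import InterpolationMode
# mask_preprocess = transforms.Compose([
#     transforms.Resize((360, 480), interpolation=InterpolationMode.NEAREST),
#     transforms.PILToTensor(),
#     lambda x: x.squeeze(0).long() # Masks are typically 1-channel, ensure long type
# ])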

In [None]:
# Instantiate the model
num_classes = 11 # For example, CamVid dataset
segnet_model = SegNetBasic(num_classes=num_classes, in_channels=3)

# Define the loss function
# Two equivalent pairings exist in PyTorch:
#   1. model returns log-probabilities (log_softmax applied) -> nn.NLLLoss
#   2. model returns raw scores (logits)                     -> nn.CrossEntropyLoss,
#      which applies log_softmax internally before the negative log-likelihood.
# Both expect integer class labels as the ground truth.
#
# SegNetBasic.forward ends with F.log_softmax(output, dim=1), so pairing 1 applies here.
# To switch to pairing 2 (the more common choice in practice), change that line to
# `return output` and replace the criterion below with nn.CrossEntropyLoss().
# Do not mix the two: CrossEntropyLoss on log-probabilities applies log_softmax twice.
criterion = nn.NLLLoss() # Matches the log-probabilities returned by the model.

In [None]:
# Define the optimizer
learning_rate = 0.01 # Step size handed to the optimizer below.
                     # NOTE(review): picked as a generic starting value; it may be on the
                     # high side for Adam — confirm against training behaviour.
momentum = 0.9       # Read only by the commented-out SGD alternative below (value stated
                     # to follow the paper's SGD setup); the Adam call does not use it.
# optimizer = torch.optim.SGD(segnet_model.parameters(), lr=learning_rate, momentum=momentum)
optimizer = torch.optim.Adam(segnet_model.parameters(), lr=learning_rate) # Adam is often easier to tune initially

In [None]:
# Training driver: runs `num_epochs` passes over `train_loader`, updating
# `segnet_model` with `optimizer` to minimize `criterion`, and prints the
# running / per-epoch average loss.
num_epochs = 5 # For demonstration, typically hundreds of epochs for real training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Use GPU if available
segnet_model.to(device) # Move model parameters and buffers to the selected device

print(f"Training SegNet-Basic on {device}")

for epoch in range(num_epochs):
    segnet_model.train() # Training mode: BatchNorm layers normalize with batch statistics
                         # and update their running estimates.

    running_loss = 0.0 # Sum of per-batch losses for this epoch.
    for batch_idx, (images, masks) in enumerate(train_loader):
        # 1. Move data to the same device as the model
        images = images.to(device) # Input images
        masks = masks.to(device)   # Ground-truth class indices per pixel

        # 2. Zero the gradients
        # backward() accumulates into .grad, so stale gradients from the previous
        # batch must be cleared before computing new ones.
        optimizer.zero_grad()

        # 3. Forward pass
        # Calling the model invokes SegNetBasic.forward.
        outputs = segnet_model(images) # Shape: (N, num_classes, H, W). Log-probabilities.

        # 4. Calculate loss
        # Compares the predicted log-probabilities with the true labels.
        loss = criterion(outputs, masks) # NLLLoss expects log-probabilities and class indices.

        # 5. Backward pass
        # Computes d(loss)/d(parameter) for every parameter that requires gradients
        # and stores the result in that parameter's .grad attribute.
        loss.backward()

        # 6. Optimizer step
        # Applies the optimizer's update rule (e.g. Adam) using the stored gradients.
        optimizer.step()

        running_loss += loss.item() # .item() extracts the Python float from the scalar tensor.

        # Optional: print the running average loss every 50 batches.
        # NOTE(review): with the default DummySegmentationDataset (100 samples) and
        # batch_size=4 an epoch has only 25 batches, so this branch never fires —
        # confirm the intended logging interval.
        if (batch_idx + 1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], "
                  f"Loss: {running_loss / (batch_idx + 1):.4f}")

    # Mean of the per-batch losses over the whole epoch.
    print(f"Epoch {epoch+1} finished. Average Loss: {running_loss / len(train_loader):.4f}")

    # Optional: Evaluation on a validation set after each epoch (highly recommended for real training).
    # Requires a `val_loader`, which is not defined in this notebook.
    # segnet_model.eval() # Evaluation mode: BatchNorm uses its running (population) statistics
    # with torch.no_grad(): # Disable gradient tracking for evaluation (saves memory and time)
    #     val_loss = 0.0
    #     for images, masks in val_loader:
    #         images = images.to(device)
    #         masks = masks.to(device)
    #         outputs = segnet_model(images)
    #         loss = criterion(outputs, masks)
    #         val_loss += loss.item()
    #     print(f"Validation Loss: {val_loss / len(val_loader):.4f}")

print("Training complete!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Re-define our EncoderBlock and DecoderBlock for clarity within this example context
class EncoderBlockTrace(nn.Module):
    """
    Shape-tracing variant of the encoder block.

    Same layers as the regular encoder (two Conv3x3 -> BatchNorm -> ReLU stages
    followed by a 2x2 max-pool with indices), but forward() prints the tensor
    shapes it sees and additionally returns the size of the feature map right
    before pooling.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Size-preserving, bias-free 3x3 convolutions (BatchNorm follows each one).
        conv_options = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, x):
        """Return (pooled_x, indices, pre_pool_size) and print the shapes along the way."""
        print(f"  Encoder Input shape: {x.shape}")
        # Two identical Conv -> BN -> ReLU stages.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(norm(conv(x)))
        # Size of the map about to be pooled; the decoder needs it for unpooling.
        pre_pool_size = x.shape
        pooled_x, indices = self.pool(x)
        print(f"  Encoder Post-Pool shape: {pooled_x.shape}, Pre-Pool size for decoder: {pre_pool_size}")
        return pooled_x, indices, pre_pool_size

class DecoderBlockTrace(nn.Module):
    """
    Shape-tracing variant of the decoder block.

    Same layers as the regular decoder (index-driven 2x2 unpooling followed by
    two Conv7x7 -> BatchNorm stages without ReLU), but forward() prints the
    tensor shapes before unpooling, after unpooling and after the convolutions.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        # Size-preserving, bias-free 7x7 convolutions (BatchNorm follows each one).
        conv_options = dict(kernel_size=7, padding=3, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, indices, output_size):
        """Unpool x with the given indices to output_size, densify it, and print the shapes."""
        print(f"  Decoder Input shape: {x.shape}, Target Unpool Size: {output_size}, Indices shape: {indices.shape}")
        # Scatter values back to their argmax positions; everything else is zero.
        x = self.unpool(x, indices, output_size=output_size)
        print(f"  Decoder Post-Unpool (Sparse) shape: {x.shape}")
        # Two identical Conv -> BN stages (no ReLU in the decoder).
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = norm(conv(x))
        print(f"  Decoder Post-Conv (Dense) shape: {x.shape}")
        return x

class SegNetBasicTrace(nn.Module):
    """
    Shape-tracing variant of the full network.

    Chains four EncoderBlockTrace stages and four DecoderBlockTrace stages and
    prints section headers so the per-block shape printouts can be followed
    from the input image down to the final per-pixel log-probabilities.
    """
    def __init__(self, num_classes=2, in_channels=3):
        super().__init__()
        self.encoder_channels = [in_channels, 64, 128, 256, 512]
        self.decoder_channels = [512, 256, 128, 64]

        enc_ch = self.encoder_channels
        self.enc1 = EncoderBlockTrace(enc_ch[0], enc_ch[1])
        self.enc2 = EncoderBlockTrace(enc_ch[1], enc_ch[2])
        self.enc3 = EncoderBlockTrace(enc_ch[2], enc_ch[3])
        self.enc4 = EncoderBlockTrace(enc_ch[3], enc_ch[4])

        dec_ch = self.decoder_channels
        self.dec1 = DecoderBlockTrace(dec_ch[0], dec_ch[1])
        self.dec2 = DecoderBlockTrace(dec_ch[1], dec_ch[2])
        self.dec3 = DecoderBlockTrace(dec_ch[2], dec_ch[3])
        # Last decoder maps directly to the class scores.
        self.dec4 = DecoderBlockTrace(dec_ch[3], num_classes)

    def forward(self, x):
        """Run encoder then decoder, printing shapes, and return log-probabilities."""
        print("\n--- Encoder Pathway ---")
        # Collect (indices, pre-pool size) per encoder so the decoders can
        # consume them in reverse order.
        recorded = []
        features = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features, pool_indices, pre_pool_size = encoder(features)
            recorded.append((pool_indices, pre_pool_size))

        print("\n--- Decoder Pathway ---")
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4)
        for decoder, (pool_indices, pre_pool_size) in zip(decoders, reversed(recorded)):
            features = decoder(features, pool_indices, output_size=pre_pool_size)
        output = features

        print(f"\n--- Final Output ---")
        print(f"Final output shape (logits): {output.shape}")
        # Normalize the class scores over the channel dimension.
        return F.log_softmax(output, dim=1)

# Instantiate the tracing model and push one random image through it so that
# every block prints the shapes it sees.
dummy_input = torch.randn(1, 3, 16, 16) # N=1, C=3, H=16, W=16; 16 halves cleanly four times (16 -> 8 -> 4 -> 2 -> 1)
segnet_trace_model = SegNetBasicTrace(num_classes=2, in_channels=3)
segnet_trace_model.eval() # Eval mode: BatchNorm uses its (untrained) running statistics instead of batch statistics

print(f"Starting SegNet-Basic Trace with input shape: {dummy_input.shape}")
traced_output = segnet_trace_model(dummy_input) # Triggers the per-block shape printouts
print(f"Log-softmax output shape: {traced_output.shape}")

# Expected output shapes:
# Encoder:
# (1, 3, 16, 16) -> (1, 64, 8, 8)
# (1, 64, 8, 8) -> (1, 128, 4, 4)
# (1, 128, 4, 4) -> (1, 256, 2, 2)
# (1, 256, 2, 2) -> (1, 512, 1, 1) # Deepest encoder output

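# Decoder (Input to decoder, indices from corresponding encoder, output_size for unpooling)
# Note: sizeK is recorded *after* encoder K's convolutions (right before pooling), so its
# channel count is that encoder's output channels. MaxUnpool2d only uses the trailing (H, W).
# (1, 512, 1, 1) + indices4 + size4 (1,512,2,2) -> (1, 256, 2, 2)
# (1, 256, 2, 2) + indices3 + size3 (1,256,4,4) -> (1, 128, 4, 4)
# (1, 128, 4, 4) + indices2 + size2 (1,128,8,8) -> (1, 64, 8, 8)
# (1, 64, 8, 8) + indices1 + size1 (1,64,16,16) -> (1, 2, 16, 16) # Final output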