In [None]:
# ==============================================================================
# Step 1: Imports and PReLU Implementation
# ==============================================================================

# --- 1. Import Necessary Libraries ---

import torch
import torch.nn as nn

# --- 2. Implement the Custom PReLU Activation Function ---

class PReLU(nn.Module):
    """
    Implements the Parametric Rectified Linear Unit (PReLU) activation function.

    As defined in the paper "Delving Deep into Rectifiers," PReLU is a generalization
    of ReLU where the slope for negative inputs is a learnable parameter.

    Formula: f(y_i) = max(0, y_i) + a_i * min(0, y_i)
    where 'a_i' is a learnable coefficient.

    This implementation supports both:
    1. Channel-wise PReLU: A separate 'a_i' is learned for each channel. The
       channel axis is dim 1, so inputs shaped (N, C), (N, C, L), (N, C, H, W), ...
       are all supported.
    2. Channel-shared PReLU: A single 'a' is learned and shared across all channels.
    """

    def __init__(self, num_parameters: int = 1, init: float = 0.25):
        """
        Initializes the PReLU layer.

        Args:
            num_parameters (int): The number of 'a' parameters to learn.
                                  - For channel-shared, this is 1.
                                  - For channel-wise, this is the number of channels
                                    (dim 1) of the input.
                                  Defaults to 1.
            init (float): The initial value for the 'a' parameter(s). The paper
                          suggests starting with 0.25. Defaults to 0.25.

        Raises:
            ValueError: If `num_parameters` is smaller than 1.
        """
        super(PReLU, self).__init__()
        if num_parameters < 1:
            raise ValueError(f"num_parameters must be >= 1, got {num_parameters}")

        # --- Define the Learnable Parameter 'a' ---
        # Wrapping the tensor in `nn.Parameter` registers it with the module, so it
        # appears in `model.parameters()`, receives gradients in backprop, and is
        # updated by the optimizer. `torch.empty(n).fill_(init)` allocates a 1-D
        # tensor and fills it in place (the trailing underscore marks in-place ops).
        # It is named 'weight' to follow PyTorch's parameter naming convention,
        # although it holds the negative-slope coefficients 'a'.
        self.weight = nn.Parameter(torch.empty(num_parameters).fill_(init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the PReLU activation.

        Args:
            x (torch.Tensor): The input tensor (pre-activations). For channel-wise
                              PReLU the channel axis must be dim 1, e.g. (N, C) after
                              a Linear layer or (N, C, H, W) after a conv layer.

        Returns:
            torch.Tensor: max(0, x) + a * min(0, x), same shape as `x`.

        Raises:
            ValueError: If channel-wise PReLU is used and dim 1 of `x` does not
                        match the number of learned slopes.
        """
        num_slopes = self.weight.numel()
        if num_slopes > 1:
            # Channel-wise: reshape (C,) to (1, C, 1, ..., 1) so the slopes
            # broadcast along dim 1 for any input rank. Without the explicit check
            # a mismatched input (e.g. C == 1, or a 2-D input) would broadcast
            # silently into a wrongly shaped result.
            if x.dim() < 2 or x.shape[1] != num_slopes:
                raise ValueError(
                    f"Channel-wise PReLU expects {num_slopes} channels on dim 1, "
                    f"got input of shape {tuple(x.shape)}"
                )
            slope = self.weight.view(1, num_slopes, *([1] * (x.dim() - 2)))
        else:
            # Channel-shared: a (1,) tensor broadcasts against any shape.
            slope = self.weight

        # 0-dim zero on x's device/dtype; max(0, x) is the positive part and
        # min(0, x) the negative part that gets scaled by the learnable slope.
        zero = x.new_zeros(())
        return torch.max(zero, x) + slope * torch.min(zero, x)

# --- 3. Example Usage and Verification ---

# Shape: (batch_size=1, channels=1, height=2, width=2)
dummy_input = torch.tensor([[[[-1.0, 2.0],
                              [0.5, -3.0]]]], dtype=torch.float32)
print("--- PReLU Test ---")
print("Original Input Tensor:\n", dummy_input)

# Test the channel-shared PReLU
print("\n--- Channel-Shared PReLU (a=0.25) ---")
prelu_shared = PReLU(num_parameters=1, init=0.25)
output_shared = prelu_shared(dummy_input)
print("Output Tensor:\n", output_shared)
# Negative entries are scaled by a=0.25: -1.0 -> -0.25, -3.0 -> -0.75.
assert torch.allclose(output_shared, torch.tensor([[[[-0.25, 2.0], [0.5, -0.75]]]]))

# Test the channel-wise PReLU.
# The input has a single channel, so channel-wise PReLU has exactly one slope
# (num_parameters == channel count == 1) and computes the same function as the
# shared variant; only the initial slope differs in this demo.
print("\n--- Channel-Wise PReLU (a=0.1 for the single channel) ---")
prelu_wise = PReLU(num_parameters=1, init=0.1)  # num_parameters matches the channel count
output_wise = prelu_wise(dummy_input)
print("Output Tensor:\n", output_wise)
# Negative entries are scaled by a=0.1: -1.0 -> -0.1, -3.0 -> -0.3.
assert torch.allclose(output_wise, torch.tensor([[[[-0.1, 2.0], [0.5, -0.3]]]]))

# Check that the parameter is indeed learnable
print("\nIs 'weight' a learnable parameter in prelu_shared?", prelu_shared.weight.requires_grad)

--- PReLU Test ---
Original Input Tensor:
 tensor([[[[-1.0000,  2.0000],
          [ 0.5000, -3.0000]]]])

--- Channel-Shared PReLU (a=0.25) ---
Output Tensor:
 tensor([[[[-0.2500,  2.0000],
          [ 0.5000, -0.7500]]]], grad_fn=<AddBackward0>)

--- Channel-Wise PReLU (a=0.1 for the single channel) ---
Output Tensor:
 tensor([[[[-0.1000,  2.0000],
          [ 0.5000, -0.3000]]]], grad_fn=<AddBackward0>)

Is 'weight' a learnable parameter in prelu_shared? True


In [None]:
# ==============================================================================
# Step 2: Spatial Pyramid Pooling (SPP) Implementation
# ==============================================================================

import torch
import torch.nn as nn
import math # We need the math library for ceiling function

class SpatialPyramidPooling(nn.Module):
    """
    Implements the Spatial Pyramid Pooling (SPP) layer.

    As described in "Spatial Pyramid Pooling in Deep Convolutional Networks for
    Visual Recognition" and used in this paper's architecture. SPP allows the
    network to handle variable-sized input images by producing a fixed-length
    output vector.

    The pyramid is configurable; the paper's 4-level pyramid is obtained with
    `pyramid_levels=[7, 3, 2, 1]`:
    - 7x7 bins
    - 3x3 bins
    - 2x2 bins
    - 1x1 bin (global pooling)
    """

    def __init__(self, pyramid_levels: list, mode: str = "avg"):
        """
        Initializes the SPP layer.

        Args:
            pyramid_levels (list): A list of integers defining the number of bins
                                   for each side of the square pyramid levels.
                                   Example: [7, 3, 2, 1] for 7x7, 3x3, 2x2, 1x1 bins.
            mode (str): Pooling applied inside each bin, "avg" (default) or "max".
                        The original SPP-net paper uses max pooling.

        Raises:
            ValueError: If `pyramid_levels` is empty or contains a level < 1, or
                        if `mode` is not "avg" or "max".
        """
        super(SpatialPyramidPooling, self).__init__()
        levels = list(pyramid_levels)
        if not levels:
            raise ValueError("pyramid_levels must contain at least one level")
        if any(level < 1 for level in levels):
            raise ValueError(f"pyramid levels must be >= 1, got {levels}")
        if mode not in ("avg", "max"):
            raise ValueError(f"mode must be 'avg' or 'max', got {mode!r}")
        # Store a copy so later mutation of the caller's list cannot change the layer.
        self.pyramid_levels = levels
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the SPP layer.

        Args:
            x (torch.Tensor): The input feature map from the last convolutional layer.
                              Expected shape: (batch_size, channels, height, width).

        Returns:
            torch.Tensor: The fixed-length feature vector after pooling and concatenation.
                          Shape: (batch_size, channels * total_bins), where
                          total_bins = sum(level * level for level in pyramid_levels).
                          Within each level, features are ordered (channel, row, col).

        Raises:
            ValueError: If `x` is not a 4-D tensor.
        """
        if x.dim() != 4:
            raise ValueError(f"SPP expects a 4-D (N, C, H, W) input, got shape {tuple(x.shape)}")

        # Adaptive pooling derives each bin's window from the input size, so every
        # level yields exactly level x level bins for any H, W. This is what makes
        # SPP input-size-agnostic, without hand-computed kernel/stride/padding.
        if self.mode == "avg":
            pool = nn.functional.adaptive_avg_pool2d
        else:
            pool = nn.functional.adaptive_max_pool2d

        # `flatten(1)` collapses (C, level, level) per sample. Unlike `view`, it also
        # works when the pooled tensor is not contiguous (e.g. channels_last input).
        level_outputs = [pool(x, level).flatten(1) for level in self.pyramid_levels]

        # Concatenate all levels along the feature dimension.
        return torch.cat(level_outputs, dim=1)

# --- 3. Example Usage and Verification ---

# Define the pyramid levels as specified in the paper
pyramid_spec = [7, 3, 2, 1]

# Instantiate our SPP layer
spp_layer = SpatialPyramidPooling(pyramid_levels=pyramid_spec)

# Test with two different-sized input feature maps to check that SPP
# produces a fixed-length output.
# Assume the number of channels from the last conv layer is 256.

# Test Case 1: Input feature map of size 13x13
input_map1 = torch.randn(1, 256, 13, 13)  # (N, C, H, W)
output1 = spp_layer(input_map1)

# Test Case 2: Input feature map of size 10x15 (non-square)
input_map2 = torch.randn(1, 256, 10, 15)
output2 = spp_layer(input_map2)

# Calculate the expected output size
num_channels = 256
total_bins = sum(level * level for level in pyramid_spec)  # 49 + 9 + 4 + 1 = 63
expected_output_features = num_channels * total_bins

print("--- SPP Layer Test ---")
print(f"Pyramid Levels: {pyramid_spec}")
print(f"Total Bins per Channel: {total_bins}")
print(f"Expected Output Feature Dimension: {expected_output_features}")
print("-" * 20)
print(f"Input 1 Shape: {input_map1.shape}")
print(f"Output 1 Shape: {output1.shape}")
print("-" * 20)
print(f"Input 2 Shape: {input_map2.shape}")
print(f"Output 2 Shape: {output2.shape}")
print("-" * 20)
# Actually verify before reporting success.
assert output1.shape == output2.shape == (1, expected_output_features), (
    f"SPP output shapes differ from expected: {output1.shape}, {output2.shape}"
)
print(f"Verification Successful: Both outputs have the same shape ({output1.shape[1]} features).")

--- SPP Layer Test ---
Pyramid Levels: [7, 3, 2, 1]
Total Bins per Channel: 63
Expected Output Feature Dimension: 16128
--------------------
Input 1 Shape: torch.Size([1, 256, 13, 13])
Output 1 Shape: torch.Size([1, 16128])
--------------------
Input 2 Shape: torch.Size([1, 256, 10, 15])
Output 2 Shape: torch.Size([1, 16128])
--------------------
Verification Successful: Both outputs have the same shape (16128 features).


In [10]:
# ==============================================================================
# Step 3: Assembling the Full PReLU-Net Architecture (Model B)
# ==============================================================================

# We continue using the previously imported libraries and defined custom modules.
import torch
import torch.nn as nn
import math

# Let's re-paste our custom modules here for a self-contained script.
class PReLU(nn.Module):
    """
    Parametric ReLU: f(x) = max(0, x) + a * min(0, x) with learnable slope(s) 'a'.

    `num_parameters=1` gives channel-shared PReLU; `num_parameters=C` gives
    channel-wise PReLU over dim 1 of inputs shaped (N, C), (N, C, H, W), ...
    """

    def __init__(self, num_parameters: int = 1, init: float = 0.25):
        """
        Args:
            num_parameters (int): 1 for channel-shared, or the channel count (dim 1)
                                  for channel-wise slopes. Defaults to 1.
            init (float): Initial slope value. Defaults to 0.25.

        Raises:
            ValueError: If `num_parameters` is smaller than 1.
        """
        super(PReLU, self).__init__()
        if num_parameters < 1:
            raise ValueError(f"num_parameters must be >= 1, got {num_parameters}")
        # Registered as a parameter so the optimizer learns the slopes.
        self.weight = nn.Parameter(torch.empty(num_parameters).fill_(init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply PReLU element-wise; output has the same shape as `x`.

        Raises:
            ValueError: If channel-wise and dim 1 of `x` does not match the slope count.
        """
        num_slopes = self.weight.numel()
        if num_slopes > 1:
            # Reshape (C,) -> (1, C, 1, ..., 1) to broadcast along dim 1 for any rank;
            # reject mismatches that would otherwise broadcast silently.
            if x.dim() < 2 or x.shape[1] != num_slopes:
                raise ValueError(
                    f"Channel-wise PReLU expects {num_slopes} channels on dim 1, "
                    f"got input of shape {tuple(x.shape)}"
                )
            slope = self.weight.view(1, num_slopes, *([1] * (x.dim() - 2)))
        else:
            slope = self.weight
        zero = x.new_zeros(())
        return torch.max(zero, x) + slope * torch.min(zero, x)

class SpatialPyramidPooling(nn.Module):
    """
    Spatial Pyramid Pooling: pools an (N, C, H, W) map into level x level bins for
    each pyramid level and concatenates the results into an
    (N, C * sum(level**2)) vector, independent of H and W.
    """

    def __init__(self, pyramid_levels: list, mode: str = "avg"):
        """
        Args:
            pyramid_levels (list): Bins per side for each level, e.g. [7, 3, 2, 1].
            mode (str): "avg" (default) or "max" pooling within each bin.

        Raises:
            ValueError: If `pyramid_levels` is empty or has a level < 1, or `mode`
                        is not "avg" or "max".
        """
        super(SpatialPyramidPooling, self).__init__()
        levels = list(pyramid_levels)
        if not levels:
            raise ValueError("pyramid_levels must contain at least one level")
        if any(level < 1 for level in levels):
            raise ValueError(f"pyramid levels must be >= 1, got {levels}")
        if mode not in ("avg", "max"):
            raise ValueError(f"mode must be 'avg' or 'max', got {mode!r}")
        # Copy so later mutation of the caller's list cannot change the layer.
        self.pyramid_levels = levels
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            torch.Tensor: (N, C * total_bins) features, levels concatenated in order.

        Raises:
            ValueError: If `x` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"SPP expects a 4-D (N, C, H, W) input, got shape {tuple(x.shape)}")
        if self.mode == "avg":
            pool = nn.functional.adaptive_avg_pool2d
        else:
            pool = nn.functional.adaptive_max_pool2d
        # `flatten(1)` (unlike `view`) also handles non-contiguous pooled outputs.
        return torch.cat([pool(x, level).flatten(1) for level in self.pyramid_levels], dim=1)


# --- 1. Define the Full PReLU-Net (Model B) ---

class PReLUNet(nn.Module):
    """
    Implementation of the PReLU-Net architecture (Model B, 22 layers)
    from the paper "Delving Deep into Rectifiers".

    Depth: 1 + 3 + 6 + 6 + 3 = 19 conv layers, plus 3 fully connected layers.

    NOTE(review): kernel sizes and filter widths below (3x3 convs with
    96/256/384/768/896 filters) are this implementation's choices; verify them
    against Table 3 of the paper before treating this as a faithful reproduction.
    """

    # Bins per side for each SPP level; the classifier's input width is derived from it.
    SPP_LEVELS = (7, 3, 2, 1)

    def __init__(self, num_classes: int = 1000, prelu_shared: bool = True):
        """
        Initializes the PReLU-Net.

        Args:
            num_classes (int): The number of output classes (e.g., 1000 for ImageNet).
            prelu_shared (bool): If True, use channel-shared PReLU in conv layers.
                                 If False, use channel-wise PReLU.
        """
        super(PReLUNet, self).__init__()

        # --- Convolutional feature extractor ---
        # Spatial sizes in the comments below assume a 224x224 input; SPP makes the
        # network accept other sizes as well.

        # Stage 1: 7x7 conv, 96 filters, stride 2 (224 -> 112),
        # then 3x3 max pool, stride 2 (112 -> 56).
        self.stage1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=3),
            PReLU(num_parameters=1 if prelu_shared else 96),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Stage 2: three 3x3 convs, 256 filters; pooling 56 -> 28.
        self.stage2 = self._conv_stage(96, 256, num_convs=3, prelu_shared=prelu_shared)
        # Stage 3: six 3x3 convs, 384 filters; pooling 28 -> 14.
        self.stage3 = self._conv_stage(256, 384, num_convs=6, prelu_shared=prelu_shared)
        # Stage 4: six 3x3 convs, 768 filters; pooling 14 -> 7.
        self.stage4 = self._conv_stage(384, 768, num_convs=6, prelu_shared=prelu_shared)
        # Stage 5: three 3x3 convs, 896 filters; no pooling, stays 7x7.
        self.stage5 = self._conv_stage(768, 896, num_convs=3, prelu_shared=prelu_shared, pool=False)

        # --- Spatial Pyramid Pooling ---
        self.spp = SpatialPyramidPooling(list(self.SPP_LEVELS))

        # --- Classifier ---
        # SPP emits (channels of last conv) * (total bins) features:
        # 896 * (49 + 9 + 4 + 1) = 896 * 63 for the default pyramid.
        fc_input_features = 896 * sum(level * level for level in self.SPP_LEVELS)

        self.classifier = nn.Sequential(
            nn.Linear(fc_input_features, 4096),
            # This implementation always uses a shared slope in the FC layers.
            PReLU(num_parameters=1),
            nn.Dropout(p=0.5),

            nn.Linear(4096, 4096),
            PReLU(num_parameters=1),
            nn.Dropout(p=0.5),

            # Output logits (no activation; pair with CrossEntropyLoss).
            nn.Linear(4096, num_classes),
        )

        # He initialization of all conv and linear weights (see _initialize_weights).
        self._initialize_weights()

    @staticmethod
    def _conv_stage(in_channels: int, out_channels: int, num_convs: int,
                    prelu_shared: bool, pool: bool = True) -> nn.Sequential:
        """Build `num_convs` x (3x3 same-padding conv + PReLU), optionally followed by a 3x3/2 max pool."""
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1))
            layers.append(PReLU(num_parameters=1 if prelu_shared else out_channels))
            channels = out_channels
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the full forward pass of the network.

        Args:
            x (torch.Tensor): Image batch of shape (N, 3, H, W).

        Returns:
            torch.Tensor: Logits of shape (N, num_classes).
        """
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)

        # Fixed-length vector regardless of the input's spatial size.
        x = self.spp(x)

        return self.classifier(x)

    def _initialize_weights(self):
        """
        Applies the He initialization method described in the paper.

        Every Conv2d and Linear weight is drawn from N(0, std^2) with
        std = sqrt(2 / ((1 + a^2) * fan_in)) and a = 0.25 (the PReLU init slope).
        fan_in is in_features for Linear and in_channels * kH * kW for Conv2d.
        Biases are set to zero. This also applies to the final logits layer,
        which is not followed by a PReLU.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu', a=0.25)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


# --- 2. Example Usage and Verification ---

# Instantiate the full network (using channel-shared PReLU for efficiency)
print("--- PReLU-Net (Model B) Instantiation ---")
model = PReLUNet(num_classes=1000, prelu_shared=True)
# print(model) # Uncomment to see the full architecture printed out

# Create a dummy input tensor that mimics a batch of ImageNet images.
# Shape: (batch_size=2, channels=3, height=224, width=224)
dummy_image_batch = torch.randn(2, 3, 224, 224)

# Perform a forward pass.
# This is only a shape check, so run in eval mode (dropout disabled) and without
# autograd: no_grad() avoids storing activations for this very large network.
print("\n--- Performing a forward pass ---")
print(f"Input batch shape: {dummy_image_batch.shape}")
model.eval()
with torch.no_grad():
    output = model(dummy_image_batch)
print(f"Output logits shape: {output.shape}")

# Verify the output shape is correct (batch_size, num_classes)
assert output.shape == (2, 1000)
print("\nVerification Successful: The output shape is correct.")

--- PReLU-Net (Model B) Instantiation ---

--- Performing a forward pass ---
Input batch shape: torch.Size([2, 3, 224, 224])
Output logits shape: torch.Size([2, 1000])

Verification Successful: The output shape is correct.


In [12]:
# ==============================================================================
# Step 4: Setting up Training Components
# ==============================================================================

from torch import optim

print("--- Setting up Training Components ---")

# --- 1. Model, Loss Function, and Optimizer ---

# Instantiate the model with a smaller number of classes for this example,
# e.g. a 10-class problem like CIFAR-10 (the dummy inputs below are 224x224).
model = PReLUNet(num_classes=10, prelu_shared=True)
print("Model Instantiated.")

# Define the Loss Function.
# `nn.CrossEntropyLoss` combines `nn.LogSoftmax` and `nn.NLLLoss` in one class.
# It expects raw, unnormalized logits from the model and integer class labels.
criterion = nn.CrossEntropyLoss()
print("Loss Function: CrossEntropyLoss")

# Define the Optimizer: SGD with momentum, as in the paper.
# - `lr=0.01`: initial learning rate (the paper starts at 1e-2).
# - `momentum=0.9`: accelerates SGD along consistent directions and damps oscillations.
# - `weight_decay=0.0005`: L2 regularization (5e-4 in the paper).
#
# The paper does NOT apply weight decay to the PReLU slopes a_i: decay pushes
# a_i toward zero and thus biases PReLU toward ReLU. We therefore split the
# parameters into two groups and disable weight decay for the PReLU 'weight's.
# Parameters are matched by identity because PReLU slopes are also named 'weight'.
prelu_params = [m.weight for m in model.modules() if isinstance(m, PReLU)]
prelu_param_ids = {id(p) for p in prelu_params}
decayed_params = [p for p in model.parameters() if id(p) not in prelu_param_ids]
optimizer = optim.SGD(
    [
        {"params": decayed_params, "weight_decay": 0.0005},
        {"params": prelu_params, "weight_decay": 0.0},
    ],
    lr=0.01,
    momentum=0.9,
)
print("Optimizer: SGD with momentum and weight decay")


# --- 2. A Single Training Step ---

print("\n--- Simulating a Training Loop for One Epoch ---")

# Dummy data for demonstration; a real run would iterate a `torch.utils.data.DataLoader`.
# A batch of 4 RGB images of 224x224 pixels.
dummy_inputs = torch.randn(4, 3, 224, 224)
# Corresponding dummy labels (integer class indices from 0 to 9).
dummy_labels = torch.randint(0, 10, (4,))

# Set the model to training mode so Dropout is active.
model.train()
print("1. Model set to train mode: `model.train()`")

# Clear gradients left over from any previous step; otherwise they accumulate.
optimizer.zero_grad()
print("2. Gradients zeroed: `optimizer.zero_grad()`")

# Step A: Forward pass; autograd records the operations for the backward pass.
outputs = model(dummy_inputs)
print(f"3. Forward Pass: Input shape {dummy_inputs.shape} -> Output shape {outputs.shape}")

# Step B: Loss computation against the ground-truth labels.
loss = criterion(outputs, dummy_labels)
print(f"4. Loss Computed: {loss.item():.4f}")  # .item() gets the Python scalar

# Step C: Backward pass; fills `.grad` of every parameter with requires_grad=True.
loss.backward()
print("5. Backward Pass: Gradients computed with `loss.backward()`")

# Step D: Parameter update using SGD with momentum (per-group weight decay).
optimizer.step()
print("6. Optimizer Step: Model parameters updated with `optimizer.step()`")

print("\n--- Training Step Complete ---")
print("The model's weights (and PReLU 'a' params) have been updated once.")

--- Setting up Training Components ---
Model Instantiated.
Loss Function: CrossEntropyLoss
Optimizer: SGD with momentum and weight decay

--- Simulating a Training Loop for One Epoch ---
1. Model set to train mode: `model.train()`
2. Gradients zeroed: `optimizer.zero_grad()`
3. Forward Pass: Input shape torch.Size([4, 3, 224, 224]) -> Output shape torch.Size([4, 10])
4. Loss Computed: 13.5604
5. Backward Pass: Gradients computed with `loss.backward()`
6. Optimizer Step: Model parameters updated with `optimizer.step()`

--- Training Step Complete ---
The model's weights (and PReLU 'a' params) have been updated once.
