In [1]:
from torchvision.models import resnet50, resnet34

In [20]:
import torch
from torchvision.models import resnet50, ResNet34_Weights
from torchvision.models.feature_extraction import create_feature_extractor

# 1) Build backbone
m = resnet34(weights=ResNet34_Weights.DEFAULT).eval()


In [38]:
x = torch.rand( 1,64,64)
x = x.repeat(1, 3, 1, 1) if len(x.shape) == 4 else x.repeat(3, 1, 1)
print(x.shape)
x = m.conv1(x)    # 64, 32, 32
print(x.shape)
x = m.bn1(x)      # 64, 32, 32
print(x.shape)
x = m.relu(x)     # 64, 32, 32
print(x.shape)
x = m.maxpool(x)  # 64, 16, 16
print(x.shape)
x = m.layer1(x)   # 64, 16, 16
print(x.shape)
x = m.layer2(x)   # 128, 8, 8
print(x.shape)
x = m.layer3(x)   # 256, 4, 4
print(x.shape)
x = m.layer4(x)   # 512, 2, 2
print(x.shape)

torch.Size([3, 64, 64])
torch.Size([64, 32, 32])


ValueError: expected 4D input (got 3D input)

In [56]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import resnet50, resnet34


class UpSampleBlock(nn.Module):
    """
    Upsampling block for the U-Net generator using interpolation followed by convolution.

    Consists of Upsample -> Conv2d -> BatchNorm -> Activation -> Conv2d -> BatchNorm -> Activation

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        use_batchnorm (bool): Whether to use batch normalization.
        interpolation_mode (str): Interpolation mode for upsampling.
                                 Options: 'nearest', 'bilinear', 'bicubic'.
        activation (callable): Activation function for upsampling.
    """
    def __init__(self, in_channel, skip_in_channel, out_channel, use_batchnorm=True, interpolation_mode='bilinear', activation=nn.LeakyReLU()):
        super(UpSampleBlock, self).__init__()

        self.interpolation_mode = interpolation_mode

        # Convolution after upsampling
        self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channel) if use_batchnorm else nn.Identity()
        self.activation1 = activation

        # Second convolution (after concatenation with skip connection)
        self.conv2 = nn.Conv2d(in_channel + skip_in_channel, out_channel, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel) if use_batchnorm else nn.Identity()
        self.activation2 = activation

    def forward(self, x, skip_features):
        x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=False if self.interpolation_mode != 'nearest' else None)
        # apply convolution after upsampling
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation1(x)

        # Concatenate with skip connection
        x = torch.cat([x, skip_features], dim=1)

        # Apply second convolution
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation2(x)

        return x


class UNetGenerator(nn.Module):
    """
    U-Net Generator for image-to-image translation tasks.

    Features:
    - Flexible input/output channels for various tasks (e.g., 1->3 for colorization)
    - Skip connections to preserve spatial information
    - Interpolation-based upsampling to avoid checkerboard artifacts
    - Configurable depth to handle different image resolutions

    Args:
        in_channels (int): Number of input image channels.
        out_channels (int): Number of output image channels.
        init_features (int): Number of features in the first layer, doubles with each downsampling.
        depth (int): Depth of the U-Net, number of downsampling operations.
        use_batchnorm (bool): Whether to use batch normalization.
        use_maxpool (bool): Whether to use max pooling for downsampling.
        interpolation_mode (str): Mode for upsampling interpolation ('nearest', 'bilinear', 'bicubic').
        activation (callable): Activation function for downsampling and upsampling.
        final_activation (callable): Activation function for the final layer.
    """
    def __init__(self, in_channels=1, out_channels=3, init_features=64, use_batchnorm=True, use_maxpool=True, interpolation_mode='bilinear', activation=nn.LeakyReLU(), final_activation=nn.Tanh()):
        super(UNetGenerator, self).__init__()

        self.interpolation_mode = interpolation_mode
        # Encoder (downsampling) path
        self.encoder = resnet34(weights=ResNet34_Weights.DEFAULT)
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512) if use_batchnorm else nn.Identity(),
            activation,
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512) if use_batchnorm else nn.Identity(),
            activation,
        )
        self.d_layer1 = UpSampleBlock(512, 256, 256, use_batchnorm, interpolation_mode, activation)
        self.d_layer2 = UpSampleBlock(256, 128, 128, use_batchnorm, interpolation_mode, activation)
        self.d_layer3 = UpSampleBlock(128, 64, 64, use_batchnorm, interpolation_mode, activation)
        self.d_layer4 = UpSampleBlock(64, 64, 64, use_batchnorm, interpolation_mode, activation)

        self.d_conv5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(64) if use_batchnorm else nn.Identity()
        self.activation5 = activation
        # Final layer
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.final_activation = final_activation

    def forward(self, x):
        # Store skip connections
        skip_connections = []
        x = x.repeat(1, 3, 1, 1)
        # Encoder path
        x = self.encoder.conv1(x)    # 64, 32, 32
        x = self.encoder.bn1(x)      # 64, 32, 32
        x = self.encoder.relu(x)     # 64, 32, 32
        skip_connections.append(x)
        x = self.encoder.maxpool(x)  # 64, 16, 16
        x = self.encoder.layer1(x)   # 64, 16, 16
        skip_connections.append(x)
        x = self.encoder.layer2(x)   # 128, 8, 8
        skip_connections.append(x)
        x = self.encoder.layer3(x)   # 256, 4, 4
        skip_connections.append(x)
        x = self.encoder.layer4(x)   # 512, 2, 2

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path
        x = self.d_layer1(x, skip_connections[3])
        x = self.d_layer2(x, skip_connections[2])
        x = self.d_layer3(x, skip_connections[1])
        x = self.d_layer4(x, skip_connections[0])
        
        # Last upsampling
        x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=False if self.interpolation_mode != 'nearest' else None)
        x = self.d_conv5(x)
        x = self.bn5(x)
        x = self.activation5(x)

        # Final convolution and activation
        x = self.final_conv(x)
        x = self.final_activation(x)

        return x


def print_unet_dimensions(input_size=(64,64), in_channels=1, out_channels=3):
    """
    Helper function to print the dimensions of each layer in the U-Net.

    Args:
        input_size (tuple): Input image dimensions (height, width).
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """
    print(f"U-Net Dimensions for {input_size[0]}x{input_size[1]} image:")

    # Calculate appropriate depth
    min_dim = min(input_size)
    max_depth = 0
    while min_dim >= 8:  # Ensure smallest feature map is at least 8x8
        min_dim = min_dim // 2
        max_depth += 1

    print(f"Recommended max depth: {max_depth}")

    # Print layer dimensions
    current_h, current_w = input_size
    features = 64  # starting features

    print(f"Input: {in_channels}x{current_h}x{current_w}")
    print(f"Initial: {features}x{current_h}x{current_w}")

    # Encoder
    for i in range(max_depth):
        current_h, current_w = current_h // 2, current_w // 2
        features *= 2
        print(f"Encoder {i + 1}: {features}x{current_h}x{current_w}")

    # Decoder
    for i in range(max_depth):
        features //= 2
        current_h, current_w = current_h * 2, current_w * 2
        print(f"Decoder {i + 1}: {features}x{current_h}x{current_w}")

    print(f"Output: {out_channels}x{current_h}x{current_w}")


if __name__ == "__main__":
    # Example usage
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # For colorization task
    model = UNetGenerator(in_channels=1, out_channels=3, final_activation=nn.Tanh(), activation=nn.LeakyReLU())

    # Print model structure
    print(model)
    print(f"Number of trainable parameters {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Check dimensions
    print_unet_dimensions()

    # Test with sample input
    x = torch.randn(1, 1, 64, 64)  # Batch size 1, 1 channel (grayscale), 64x64 image
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")

    # Test backward pass
    print("\nTesting backward pass...")

    # Move to GPU if available
    if torch.cuda.is_available():
        model = model.to(device)
        x = x.to(device)

    # Create a fake target (simulating a real colorized image)
    target = torch.randn(1, 3, 64, 64)
    if torch.cuda.is_available():
        target = target.to(device)

    # Define a loss function
    criterion = nn.MSELoss()

    # Forward pass
    output = model(x)

    # Calculate loss
    loss = criterion(output, target)
    print(f"Loss value: {loss.item()}")

    # Backward pass
    loss.backward()

    # Check that gradients have been calculated
    has_gradients = all(param.grad is not None for param in model.parameters() if param.requires_grad)
    print(f"All parameters have gradients: {has_gradients}")

    # Print gradient statistics for verification
    total_norm = 0
    param_count = 0
    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = param.grad.data.norm(2)
            total_norm += param_norm.item() ** 2

            # Print first few parameter gradients as examples
            if param_count < 3:
                print(f"Gradient norm for {name}: {param_norm}")
            param_count += 1

    total_norm = total_norm ** 0.5
    print(f"Total gradient norm: {total_norm}")

    # Reset gradients
    model.zero_grad()
    print("Backward pass completed successfully")


UNetGenerator(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
m

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  