In [2]:
from typeguard import CollectionCheckStrategy
import typeguard
from torch import Tensor
import torch
import torch.optim as optim
import torch.fft
import torch.nn as nn
from torchvision import datasets, transforms
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Global typeguard setting: select the ALL_ITEMS collection-check strategy so that
# runtime type checks inspect every element of a collection.
# NOTE(review): this mutates process-wide typeguard configuration for the whole kernel.
typeguard.config.collection_check_strategy = CollectionCheckStrategy.ALL_ITEMS

In [None]:
# --- SIREN Layer using a 1x1 convolution ---
class SineLayerConv2D(nn.Module):
    """Pointwise (1x1) convolution followed by a sine activation.

    The layer evaluates ``sin(omega_0 * conv(x))``, where ``conv`` is an
    ``nn.Conv2d`` with ``kernel_size=1`` so the same linear map is applied at
    every spatial location.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        omega_0: float = np.pi * 10,
        is_first: bool = False,
        bias: bool = True,
    ):
        """
        Args:
            in_channels: Number of input channels (e.g., 2 for x,y coordinates).
            out_channels: Number of output channels.
            omega_0: Frequency scaling factor. Only honoured when ``is_first``
                is True; every other layer uses a scale of 1.
            is_first: Whether this is the first layer in the network.
            bias: Whether the underlying 1x1 convolution has a bias term.
        """
        super().__init__()

        # Non-first layers deliberately ignore the requested frequency and use 1.
        if is_first:
            self.omega_0 = omega_0
        else:
            self.omega_0 = 1

        # A 1x1 convolution is a per-pixel linear transform.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

        # SIREN-style weight range: 1/fan_in for the first layer,
        # sqrt(6/fan_in)/omega_0 otherwise (omega_0 is 1 here for hidden layers).
        if is_first:
            limit = 1 / in_channels
        else:
            limit = np.sqrt(6 / in_channels) / self.omega_0

        with torch.no_grad():
            self.conv.weight.uniform_(-limit, limit)
            if self.conv.bias is not None:
                # Biases always start at zero.
                self.conv.bias.zero_()

    def forward(self, x):
        """Return ``sin(omega_0 * conv(x))`` for the input tensor ``x``."""
        pre_activation = self.conv(x)
        return torch.sin(self.omega_0 * pre_activation)


In [None]:
# --- Complete SIREN Network ---
class SirenConvNet(nn.Module):
    """SIREN network assembled from pointwise (1x1) convolutions.

    The network is a first sine layer, ``num_hidden_layers`` hidden sine layers,
    and a plain 1x1 convolution as the output layer (no sine activation).
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_hidden_layers: int,
        first_omega_0: float = np.pi * 10,
    ):
        """
        Args:
            in_channels: Number of input channels (for coordinates, usually 2 for a (x,y) grid).
            hidden_channels: Number of channels in hidden layers.
            out_channels: Number of output channels (e.g., 3 for an RGB image).
            num_hidden_layers: Number of hidden sine layers placed after the first layer.
            first_omega_0: The omega_0 passed to the first layer.
        """
        super().__init__()

        # Input layer carries the high-frequency scaling.
        input_layer = SineLayerConv2D(
            in_channels, hidden_channels, omega_0=first_omega_0, is_first=True
        )

        # Hidden sine layers all share the same width and a unit omega_0.
        hidden_stack = [
            SineLayerConv2D(hidden_channels, hidden_channels, omega_0=1, is_first=False)
            for _ in range(num_hidden_layers)
        ]

        # Output projection: linear 1x1 convolution, no activation.
        output_layer = nn.Conv2d(
            hidden_channels, out_channels, kernel_size=1, bias=True
        )

        self.net = nn.Sequential(input_layer, *hidden_stack, output_layer)

        # Pixel values are expected on [-1, 1], hence the (-1, 1) data range.
        self.ssim_metric = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0))

    def loss(self, pred: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Combined reconstruction loss: MSE plus (1 - SSIM).

        Args:
            pred (torch.Tensor): The predicted image tensor.
            x (torch.Tensor): The ground truth image tensor.

        Returns:
            torch.Tensor: A scalar tensor representing the combined loss.
        """
        reconstruction_term = F.mse_loss(pred, x)
        # Higher SSIM is better, so (1 - SSIM) is the penalty. The prediction is
        # clamped to [-1, 1] before the metric; the MSE term sees it unclamped.
        bounded_pred = pred.clamp(min=-1, max=1)
        structural_term = 1 - self.ssim_metric(bounded_pred, x)
        return reconstruction_term + structural_term

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the stacked layers."""
        return self.net(x)


In [None]:
class DNArchConvNet(nn.Module):
    """
    A DNArch convolutional network built by stacking DNArchConvLayer layers.

    This network is defined with a maximum number of layers, maximum number of channels,
    and maximum kernel size. Both the convolution weights and the effective architecture
    (i.e. which parts of the kernel are used) are learned via backpropagation.

    NOTE(review): `DNArchConvLayer` is not defined in the visible cells of this
    notebook; it must be defined before this class is instantiated. A later cell
    also defines a class named `DNArchConvNet`, which shadows this one once run.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        max_num_hidden_channels: int,
        max_num_hidden_layers: int,
        max_kernel_size: int,
    ):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            max_num_hidden_channels: Maximum number of channels in the hidden layers.
            max_num_hidden_layers: Maximum number of hidden layers placed between
                the first and the final layer.
            max_kernel_size: Maximum kernel size for each DNArchConvLayer.
        """
        super().__init__()

        # Total layer count: first layer + hidden layers + final layer.
        self.L_max: int = max_num_hidden_layers + 2

        # First layer: from input channels to the hidden width.
        layers: list[nn.Module] = [
            DNArchConvLayer(in_channels, max_num_hidden_channels, max_kernel_size)
        ]

        # Hidden layers: hidden width to hidden width.
        layers.extend(
            DNArchConvLayer(
                max_num_hidden_channels, max_num_hidden_channels, max_kernel_size
            )
            for _ in range(max_num_hidden_layers)
        )

        # Final layer: from the hidden width to the desired output channels.
        layers.append(
            DNArchConvLayer(max_num_hidden_channels, out_channels, max_kernel_size)
        )

        self.net = nn.Sequential(*layers)

        # Here we define a simple MSE loss, similar to the SIREN model.
        self.loss_fn = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the stacked layers."""
        return self.net(x)

    def loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the MSE loss between the prediction and the target.

        Args:
            pred: The predicted output tensor.
            target: The ground truth tensor.

        Returns:
            The computed loss value.
        """
        return self.loss_fn(pred, target)


In [None]:
# --- DNArch Layer Definition ---
class DNArchLayer(nn.Module):
    """
    A single DNArch layer that uses two shared inside networks to generate its
    pointwise and depthwise convolution kernels. Each layer has four learnable masks:
      - A sigmoid mask for the pointwise part.
      - A sigmoid mask for the depthwise part.
      - A 2D Gaussian mask (learnable, but initialized from a Gaussian).
      - An additional sigmoid mask.

    The product of the four masks is applied to both generated kernels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        max_kernel_size: int,
        pointwise_generator: "nn.Module | None" = None,
        depthwise_generator: "nn.Module | None" = None,
    ):
        """
        Args:
            in_channels: Number of input channels of the layer.
            out_channels: Number of output channels of the layer.
            max_kernel_size: Side length of the (square) kernel grid.
            pointwise_generator: Module mapping the kernel coordinate grid to the
                pointwise weights; ``forward`` views its output as
                ``(K, out_channels, in_channels)`` with ``K = max_kernel_size ** 2``.
            depthwise_generator: Module mapping the kernel coordinate grid to the
                depthwise weights; ``forward`` views its output as ``(K, in_channels)``.

        Both generators default to ``None`` so the three-argument constructor call
        keeps working, but they are required by ``forward``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.max_kernel_size = max_kernel_size

        # Kernel generators (may be shared between several layers). Assigning an
        # nn.Module here registers it as a submodule of this layer.
        self.pointwise_generator = pointwise_generator
        self.depthwise_generator = depthwise_generator

        # Define the four masks (each of shape (max_kernel_size, max_kernel_size)):
        self.pointwise_mask = nn.Parameter(
            torch.zeros(max_kernel_size, max_kernel_size)
        )
        self.depthwise_mask = nn.Parameter(
            torch.zeros(max_kernel_size, max_kernel_size)
        )
        self.gaussian_mask = nn.Parameter(
            self.create_gaussian_mask(max_kernel_size), requires_grad=True
        )
        self.combined_mask = nn.Parameter(torch.zeros(max_kernel_size, max_kernel_size))

        self.sigmoid = nn.Sigmoid()

    def create_gaussian_mask(self, kernel_size: int) -> torch.Tensor:
        """Creates an initial 2D Gaussian mask over the kernel grid.

        The Gaussian is centred on index ``kernel_size // 2`` along both axes and
        uses ``sigma = kernel_size / 2``. Returns a ``(kernel_size, kernel_size)``
        float tensor whose value at the centre index is 1.
        """
        center = kernel_size // 2
        grid_y, grid_x = torch.meshgrid(
            torch.arange(kernel_size), torch.arange(kernel_size), indexing="ij"
        )
        grid_x = grid_x.float() - center
        grid_y = grid_y.float() - center
        sigma = kernel_size / 2.0
        gaussian = torch.exp(-(grid_x**2 + grid_y**2) / (2 * sigma**2))
        return gaussian

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate the masked kernels and apply the two convolutions to ``x``.

        Raises:
            RuntimeError: If the layer was built without kernel generators.
        """
        if self.pointwise_generator is None or self.depthwise_generator is None:
            raise RuntimeError(
                "DNArchLayer requires both pointwise_generator and "
                "depthwise_generator to be provided at construction time."
            )

        # Create coordinate grid for the kernel.
        grid = create_kernel_grid(self.max_kernel_size).to(
            x.device
        )  # (K, 2) with K = max_kernel_size^2
        K = grid.shape[0]

        # Generate convolution weights via the two inside networks.
        # Pointwise weights: (K, out_channels * in_channels) then reshaped.
        # NOTE(review): the views below require the generator outputs to hold exactly
        # K * out_channels * in_channels and K * in_channels elements respectively;
        # confirm this holds for generators shared between layers of different widths.
        pw_weights = self.pointwise_generator(grid)
        pw_weights = pw_weights.view(K, self.out_channels, self.in_channels)

        # Depthwise weights: (K, in_channels)
        dw_weights = self.depthwise_generator(grid)
        dw_weights = dw_weights.view(K, self.in_channels)

        # Compute the masks. The Gaussian mask is used as-is (no sigmoid).
        mask_pw = self.sigmoid(self.pointwise_mask).view(-1)  # (K,)
        mask_dw = self.sigmoid(self.depthwise_mask).view(-1)  # (K,)
        mask_gauss = self.gaussian_mask.view(-1)  # (K,)
        mask_comb = self.sigmoid(self.combined_mask).view(-1)  # (K,)
        total_mask = mask_pw * mask_dw * mask_gauss * mask_comb  # (K,)

        # Apply the mask to the generated weights.
        pw_weights = pw_weights * total_mask.view(K, 1, 1)
        dw_weights = dw_weights * total_mask.view(K, 1)

        # Reshape weights into kernels (reshape, because permute yields a
        # non-contiguous tensor):
        # Pointwise kernel: (out_channels, in_channels, k, k)
        pw_kernel = pw_weights.permute(1, 2, 0).reshape(
            self.out_channels,
            self.in_channels,
            self.max_kernel_size,
            self.max_kernel_size,
        )
        # Depthwise kernel: (in_channels, 1, k, k)
        dw_kernel = dw_weights.permute(1, 0).reshape(
            self.in_channels, 1, self.max_kernel_size, self.max_kernel_size
        )

        # Apply depthwise convolution first (one group per input channel).
        padding = self.max_kernel_size // 2
        x_depth = F.conv2d(
            x, dw_kernel, bias=None, groups=self.in_channels, padding=padding
        )
        # Then apply the convolution with the "pointwise" kernel, which as coded
        # is also k x k and uses the same padding.
        out = F.conv2d(x_depth, pw_kernel, bias=None, padding=padding)
        return out


# --- DNArch Network Definition ---


class DNArchConvNet(nn.Module):
    """
    DNArch-style convolutional network assembled from DNArchLayer modules.

    One pointwise generator and one depthwise generator are created once and
    handed to every layer, so the per-layer parameters are the layer masks; the
    network adds a single scalar global mask on top of the stacked layers.

    NOTE(review): an earlier cell in this notebook defines a class with the same
    name and a different constructor signature; this definition replaces it once
    this cell is run.
    """

    def __init__(
        self,
        in_channels: int,
        max_channels: int,
        out_channels: int,
        num_layers: int,
        max_kernel_size: int,
    ):
        """
        Args:
            in_channels: Number of input channels of the first layer.
            max_channels: Channel width used by the hidden layers and the generators.
            out_channels: Number of output channels of the final layer.
            num_layers: Total layer count; ``num_layers - 2`` hidden layers are
                placed between the first and the final layer.
            max_kernel_size: Kernel size passed to every layer and generator.
        """
        super().__init__()

        # Inside networks, built once and shared by every layer.
        self.pointwise_generator = PointwiseGenerator(
            in_channels=max_channels,
            out_channels=max_channels,
            max_kernel_size=max_kernel_size,
        )
        self.depthwise_generator = DepthwiseGenerator(
            in_channels=max_channels, max_kernel_size=max_kernel_size
        )

        def build_layer(channels_in: int, channels_out: int) -> nn.Module:
            # Every layer receives the same two generator instances.
            return DNArchLayer(
                channels_in,
                channels_out,
                max_kernel_size,
                self.pointwise_generator,
                self.depthwise_generator,
            )

        # (in, out) channel pairs: first layer, hidden layers, final layer.
        # A non-positive repeat count yields no hidden layers.
        channel_plan = (
            [(in_channels, max_channels)]
            + [(max_channels, max_channels)] * (num_layers - 2)
            + [(max_channels, out_channels)]
        )
        stacked = [build_layer(c_in, c_out) for c_in, c_out in channel_plan]

        self.net = nn.Sequential(*stacked)
        # Global mask for the whole network (a single learnable sigmoid parameter).
        self.global_mask = nn.Parameter(torch.zeros(1))
        self.sigmoid = nn.Sigmoid()

        self.loss_fn = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layer stack and scale the result by the sigmoid of the global mask."""
        gate = self.sigmoid(self.global_mask)
        return self.net(x) * gate

    def loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``pred`` and ``target``."""
        return self.loss_fn(pred, target)
