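# Reconstruction of a matrix

Fit a fixed random target matrix with several parametric models (a dense learnable matrix, a sum of parametric 2D bump functions, and a Gaussian-mixture-style model) and compare their parameter counts and final reconstruction loss.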

In [None]:
# Automatically re-import edited modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the directory one level above the current working directory importable.
ROOT_FOLDER = os.path.join(".", "..")
if ROOT_FOLDER not in sys.path:
    sys.path.insert(0, ROOT_FOLDER)

# Prefer a CUDA device when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cpu')

## Utils

In [2]:
def generate_tensor_in_range(
    size: tuple,
    value_range: tuple = (0, 1),
    seed: int = 420,
    dtype: torch.dtype = torch.float32,
    device: str | torch.device | None = DEVICE,
) -> torch.Tensor:
    """Create a reproducible random tensor with values drawn from ``value_range``.

    Floating-point dtypes are sampled uniformly and rescaled from ``[0, 1)`` to
    ``[min, max)``. Integer dtypes are sampled with ``torch.randint`` from
    ``[int(min), int(max))``; in both cases ``max`` is exclusive.

    Args:
        size: Shape of the tensor to create.
        value_range: ``(min, max)`` pair; ``min`` must be strictly less than ``max``.
        seed: Seed for a private ``torch.Generator`` so repeated calls return
            the same tensor.
        dtype: Dtype of the result.
        device: Device for the result. ``None`` means PyTorch's current
            default device.

    Returns:
        A tensor of shape ``size`` and dtype ``dtype`` on ``device``.

    Raises:
        RuntimeError: If ``min`` is not strictly less than ``max``.
    """
    min_val, max_val = value_range
    if not min_val < max_val:
        raise RuntimeError(
            f"The minimum value of the range ({min_val}) must be strictly "
            f"less than the maximum value ({max_val})."
        )

    # The generator must live on the same device as the tensor it fills.
    # Previously it was always created on the global DEVICE, which broke any
    # call with a different (or ``None``) device. An empty tensor resolves
    # ``None`` to the current default device.
    target_device = torch.empty(0, device=device).device
    generator = torch.Generator(device=target_device).manual_seed(seed)

    if dtype.is_floating_point:
        tensor = torch.rand(
            size, generator=generator, dtype=dtype, device=target_device
        )
        return tensor * (max_val - min_val) + min_val

    return torch.randint(
        low=int(min_val),
        high=int(max_val),
        size=size,
        dtype=dtype,
        generator=generator,
        device=target_device,
    )

In [3]:
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar entries in all of ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

## Models

### Basic

In [4]:
class Basic(nn.Module):
    """Baseline model whose output is a directly learnable ``(m, n)`` matrix.

    Every entry of the reconstruction is its own parameter, so the model has
    ``m * n`` parameters.
    """

    def __init__(
        self,
        matrix_shape: tuple[int, int] | torch.Size,
        device=DEVICE,
        seed: int = 420,
        **kwargs,
    ):
        """
        Args:
            matrix_shape: Shape ``(m, n)`` of the matrix to reconstruct.
            device: Device for the weight and for the RNG that initialises it.
            seed: Seed for the initial (standard-normal) weights.
            **kwargs: Ignored. Accepted so callers such as ``reconstruct`` can
                pass model-specific options (e.g. ``k``) uniformly.
        """
        super().__init__()
        self.m, self.n = matrix_shape[0], matrix_shape[1]

        # Seed a generator on the same device as the weight. A generator on
        # the global DEVICE fails whenever ``device`` differs from it.
        generator = torch.Generator(device=device).manual_seed(seed)
        self.w = nn.Parameter(
            torch.randn(
                self.m, self.n, generator=generator, dtype=torch.float32, device=device
            )
        )

    def forward(self):
        """Return the learnable ``(m, n)`` weight matrix itself."""
        return self.w

### Original: sum of k parametric 2D bump functions

In [5]:
def fn(x, y, a, b, c, d, u, v):
    """Evaluate ``c * exp(-(a*x - u)**2 - (b*y - v)**2) + d`` element-wise.

    All arguments are combined with ordinary tensor broadcasting, so e.g.
    ``[k, 1, 1]`` parameters against ``[rows, cols]`` coordinates give a
    ``[k, rows, cols]`` result.
    """
    shifted_x = a * x - u
    shifted_y = b * y - v
    exponent = -(shifted_x**2) - shifted_y**2
    return c * torch.exp(exponent) + d


def generate_grid_torch(
    rows: int,
    cols: int,
    values_range: tuple[float, float] = (-1, 1),
    device: torch.device = DEVICE,
):
    """Build a ``[rows, cols, 2]`` tensor of ``(x, y)`` grid coordinates.

    Both axes are evenly spaced over ``values_range`` (endpoints included).
    ``x`` varies along the column axis and ``y`` along the row axis; channel 0
    of the result holds ``x`` and channel 1 holds ``y``.
    """
    row_positions = torch.linspace(*values_range, rows, device=device)
    col_positions = torch.linspace(*values_range, cols, device=device)
    grid_y, grid_x = torch.meshgrid(row_positions, col_positions, indexing="ij")
    coordinates = torch.stack([grid_x, grid_y], dim=-1)
    return coordinates


class FunctionLayer(nn.Module):
    """Build an ``[in_features, out_features]`` matrix as a sum of ``k`` basis functions.

    Each basis function is ``f`` evaluated on a fixed coordinate grid from
    ``generate_grid_torch`` (x varies along columns, y along rows), using its
    own learnable scalars ``a, b, c, d, u, v``. The ``k`` evaluations are
    summed and an optional bias is added.

    Args:
        in_features: Number of rows of the produced matrix.
        out_features: Number of columns of the produced matrix.
        k: Number of basis functions to sum.
        f: Callable ``f(x, y, a, b, c, d, u, v)``. It receives
            ``[in_features, out_features]`` coordinate tensors and ``[k, 1, 1]``
            parameters, and its result is summed over the first axis.
        bias: ``None`` (no bias), ``"single"`` (one learnable scalar) or
            ``"multiple"`` (one learnable value per output column).
        activation: Stored as ``self.activation`` but NOT applied in ``forward``.
            NOTE(review): the default is a single module instance shared by
            every layer built with the default.
        generator: Optional ``torch.Generator`` used for all initial values;
            it must be on ``device``.
        device: Device for the parameters and the coordinate grid.

    Raises:
        ValueError: If ``bias`` is not ``None``/falsy, ``"single"`` or ``"multiple"``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        k: int,
        f,
        bias: Optional[Literal["single", "multiple"]] = None,
        activation=nn.LeakyReLU(),
        generator=None,
        device: torch.device = DEVICE,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k = k
        self.f = f
        self.activation = activation

        # Shape [in_features, out_features, 2]: (x, y) per matrix entry.
        points = generate_grid_torch(in_features, out_features, device=device)
        self.register_buffer("points", points)

        def per_basis_scalar() -> nn.Parameter:
            # One scalar per basis function, shaped [k, 1, 1] so it broadcasts
            # against the [in_features, out_features] coordinate grids.
            return nn.Parameter(
                torch.randn(
                    k, 1, 1, dtype=torch.float32, device=device, generator=generator
                )
            )

        # Creation order (a, b, c, d, u, v, then bias) fixes how the shared
        # generator is consumed, and therefore the initial values.
        self.param_a = per_basis_scalar()
        self.param_b = per_basis_scalar()
        self.param_c = per_basis_scalar()
        self.param_d = per_basis_scalar()
        self.param_u = per_basis_scalar()
        self.param_v = per_basis_scalar()

        if not bias:
            self.b = 0
        elif bias == "single":
            self.b = nn.Parameter(
                torch.randn(1, dtype=torch.float32, device=device, generator=generator)
            )
        elif bias == "multiple":
            # Shape [out_features]: broadcasts over rows, i.e. one bias per column.
            self.b = nn.Parameter(
                torch.randn(
                    out_features, dtype=torch.float32, device=device, generator=generator
                )
            )
        else:
            # Previously any unrecognised value (e.g. a typo) silently meant "no bias".
            raise ValueError(
                f"bias must be None, 'single' or 'multiple'; got {bias!r}"
            )

    def forward(self):
        """Return the ``[in_features, out_features]`` matrix built from the basis functions."""
        # Coordinate grids, each of shape [in_features, out_features].
        x_coords = self.points[..., 0]
        y_coords = self.points[..., 1]

        # [k, 1, 1] parameters broadcast against the grids.
        f_values = self.f(
            x_coords,
            y_coords,
            self.param_a,
            self.param_b,
            self.param_c,
            self.param_d,
            self.param_u,
            self.param_v,
        )

        # Sum over the k basis functions. Note that each basis' offset ``d``
        # ends up as part of one overall additive constant after this sum.
        ws = f_values.sum(dim=0)

        return ws + self.b

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, k={self.k}"


class Reconstructor(nn.Module):
    """Reconstruct a matrix with a single ``FunctionLayer`` (a sum of ``k`` basis functions).

    Args:
        matrix_shape: Shape ``(m, n)`` of the matrix to reconstruct.
        f: Basis function passed to ``FunctionLayer``.
        k: Number of basis functions.
        bias: Bias mode forwarded to ``FunctionLayer``.
        seed: Seed for the generator that initialises all parameters.
        device: Device for the parameters and for that generator.
    """

    def __init__(
        self,
        matrix_shape: tuple[int, int] | torch.Size,
        f=fn,
        k: int = 64,
        bias: Optional[Literal["single", "multiple"]] = None,
        seed: int = 420,
        device: torch.device = DEVICE,
    ):
        super().__init__()

        # The generator and the parameters it initialises must share a device,
        # so both are driven by ``device``.
        generator = torch.Generator(device=device).manual_seed(seed)

        self.layer = FunctionLayer(
            in_features=matrix_shape[0],
            out_features=matrix_shape[1],
            k=k,
            f=f,
            activation=nn.LeakyReLU(),
            bias=bias,
            generator=generator,
            device=device,
        )

    def forward(self):
        """Return the reconstructed ``(m, n)`` matrix."""
        # Call the submodule rather than ``.forward()`` so registered hooks run.
        return self.layer()

### Gaussian Mixture

In [6]:
class GaussianMixture2D(nn.Module):
    """
    Approximates an m-by-n matrix as a sum of k Gaussian bumps evaluated on a
    fixed coordinate grid (x and y both in [-1, 1]), plus a per-column bias.

    The parameters are not one scalar per Gaussian; they vary along one axis
    of the matrix:

    * ``amplitude``, ``log_sigma_x``, ``log_sigma_y``: ``[k, 1, n]`` (per column)
    * ``center_x``, ``center_y``: ``[k, m, 1]`` (per row)
    * ``bias``: ``[1, n]`` (per column)

    so the model has ``3*k*n + 2*k*m + n`` parameters.

    Args:
        matrix_shape: Shape ``(m, n)`` of the matrix to reconstruct.
        k: Number of Gaussian components.
        device: Device for parameters, buffers and the initialising generator.
        seed: Seed for the standard-normal initial parameter values.
    """

    def __init__(
        self,
        matrix_shape: tuple[int, int] | torch.Size,
        k: int,
        device=DEVICE,
        seed: int = 420,
    ):
        super().__init__()
        self.m, self.n = matrix_shape[0], matrix_shape[1]
        self.k = k
        # Seed a generator on the same device as the parameters it initialises;
        # a generator on the global DEVICE fails whenever ``device`` differs.
        generator = torch.Generator(device=device).manual_seed(seed)

        # Fixed grid of coordinates in [-1, 1]: x varies along columns, y along rows.
        xs = torch.linspace(-1, 1, self.n, device=device)
        ys = torch.linspace(-1, 1, self.m, device=device)
        yy, xx = torch.meshgrid(ys, xs, indexing="ij")  # each [m, n]
        self.register_buffer("X", xx)
        self.register_buffer("Y", yy)

        # Learnable parameters. Creation order fixes how the generator is
        # consumed, and therefore the initial values.
        self.amplitude = nn.Parameter(
            torch.randn(k, 1, self.n, generator=generator, device=device)
        )
        self.center_x = nn.Parameter(
            torch.randn(k, self.m, 1, generator=generator, device=device)
        )
        self.center_y = nn.Parameter(
            torch.randn(k, self.m, 1, generator=generator, device=device)
        )
        # Unconstrained width parameters; despite the name they are mapped to
        # positive widths with softplus (not exp) in ``forward``.
        self.log_sigma_x = nn.Parameter(
            torch.randn(k, 1, self.n, generator=generator, device=device)
        )
        self.log_sigma_y = nn.Parameter(
            torch.randn(k, 1, self.n, generator=generator, device=device)
        )
        self.bias = nn.Parameter(
            torch.randn(1, self.n, generator=generator, device=device)
        )

    def forward(self):
        """Return the ``[m, n]`` reconstruction: the summed Gaussians plus bias."""
        # softplus keeps widths positive; the 1e-3 floor keeps them away from 0.
        sigma_x = F.softplus(self.log_sigma_x) + 1e-3  # [k, 1, n]
        sigma_y = F.softplus(self.log_sigma_y) + 1e-3  # [k, 1, n]

        # Add a leading axis so the grid broadcasts against the k components.
        X = self.X.unsqueeze(0)  # [1, m, n]
        Y = self.Y.unsqueeze(0)  # [1, m, n]
        cx = self.center_x  # [k, m, 1]
        cy = self.center_y  # [k, m, 1]

        # Squared distances normalised by 2 * sigma^2, each broadcast to [k, m, n].
        dx2 = ((X - cx) ** 2) / (2 * sigma_x**2)
        dy2 = ((Y - cy) ** 2) / (2 * sigma_y**2)
        gaussians = self.amplitude * torch.exp(-(dx2 + dy2))  # [k, m, n]
        out = gaussians.sum(dim=0) + self.bias  # sum over k -> [m, n], + per-column bias
        return out

## Reconstruction

In [7]:
def reconstruct(
    y: torch.Tensor,
    model_class: type[nn.Module] = Reconstructor,
    epochs: int = 100,
    log_interval: int = 10,
    loss_fn: Literal["mse", "mae"] = "mse",
    lr: float = 1e-2,
    **kwargs,
) -> "numpy.ndarray":
    """Fit a freshly built ``model_class`` so that its output matches ``y``.

    The model takes no input: ``model()`` returns its current reconstruction,
    which is trained with full-batch Adam against ``y``.

    Args:
        y: Target matrix. Its shape is passed to the model as ``matrix_shape``,
            and it is moved to ``DEVICE`` to match the model.
        model_class: ``nn.Module`` subclass built as
            ``model_class(matrix_shape=y.shape, **kwargs)``.
        epochs: Number of optimisation steps.
        log_interval: The loss is printed at epochs 1, 1 + log_interval, ...
            and at the final epoch.
        loss_fn: ``"mse"`` (``nn.MSELoss``) or ``"mae"`` (``nn.L1Loss``).
        lr: Adam learning rate.
        **kwargs: Forwarded to ``model_class``.

    Returns:
        The final reconstruction as a NumPy array (copied to CPU).

    Raises:
        ValueError: If ``loss_fn`` is not ``"mse"`` or ``"mae"``.
    """
    # Previously any value other than "mse" silently selected L1.
    loss_classes = {"mse": torch.nn.MSELoss, "mae": torch.nn.L1Loss}
    if loss_fn not in loss_classes:
        raise ValueError(f"loss_fn must be 'mse' or 'mae', got {loss_fn!r}")
    _loss_fn = loss_classes[loss_fn]()

    # The model lives on DEVICE, so the target must too.
    target = y.to(DEVICE)
    model = model_class(matrix_shape=y.shape, **kwargs).to(DEVICE)

    print(f"Model with {count_parameters(model)} parameters\n")

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()

        # Call the module (not ``.forward()``) so registered hooks run.
        y_hat = model()

        loss = _loss_fn(y_hat, target)
        loss.backward()

        optimizer.step()

        if (epoch - 1) % log_interval == 0 or epoch == epochs:
            print(f"Epoch {epoch:0>3}/{epochs}: Loss: {loss.item():.3f}")

    with torch.no_grad():
        reconstruction = model()
    # ``detach`` is still needed because some models (e.g. ``Basic``) return a
    # Parameter directly; ``.cpu()`` is required because ``.numpy()`` fails on
    # GPU tensors.
    return reconstruction.detach().cpu().numpy()

## Experiments

In [8]:
y = generate_tensor_in_range((100, 100), value_range=(-10, 10))

# Baseline: one free parameter per matrix entry. ``Basic`` ignores extra
# options such as ``k``, so none are passed here.
_ = reconstruct(
    y,
    model_class=Basic,
    loss_fn="mae",
    epochs=1000,
    log_interval=100,
    lr=1e-2,
)

Model with 10000 parameters

Epoch 001/1000: Loss: 5.370
Epoch 101/1000: Loss: 4.413
Epoch 201/1000: Loss: 3.543
Epoch 301/1000: Loss: 2.761
Epoch 401/1000: Loss: 2.071
Epoch 501/1000: Loss: 1.477
Epoch 601/1000: Loss: 0.982
Epoch 701/1000: Loss: 0.587
Epoch 801/1000: Loss: 0.293
Epoch 901/1000: Loss: 0.101
Epoch 1000/1000: Loss: 0.021


In [11]:
y = generate_tensor_in_range((100, 100), value_range=(-10, 10))

# Sum of k=512 parametric bump functions plus one learnable bias per column.
_ = reconstruct(
    y,
    model_class=Reconstructor,
    k=512,
    bias="multiple",
    loss_fn="mae",
    lr=1e-2,
    epochs=1000,
    log_interval=100,
)

Model with 3172 parameters

Epoch 001/1000: Loss: 7.688
Epoch 101/1000: Loss: 4.933
Epoch 201/1000: Loss: 4.908
Epoch 301/1000: Loss: 4.898
Epoch 401/1000: Loss: 4.894
Epoch 501/1000: Loss: 4.892
Epoch 601/1000: Loss: 4.892
Epoch 701/1000: Loss: 4.890
Epoch 801/1000: Loss: 4.891
Epoch 901/1000: Loss: 4.889
Epoch 1000/1000: Loss: 4.892


In [10]:
y = generate_tensor_in_range((100, 100), value_range=(-10, 10))

# Gaussian-mixture-style model with k=16 components, trained with a smaller
# learning rate than the other experiments.
_ = reconstruct(
    y,
    model_class=GaussianMixture2D,
    k=16,
    loss_fn="mae",
    lr=5e-3,
    epochs=1000,
    log_interval=100,
)

Model with 8100 parameters

Epoch 001/1000: Loss: 5.137
Epoch 101/1000: Loss: 4.397
Epoch 201/1000: Loss: 3.843
Epoch 301/1000: Loss: 3.391
Epoch 401/1000: Loss: 3.048
Epoch 501/1000: Loss: 2.794
Epoch 601/1000: Loss: 2.603
Epoch 701/1000: Loss: 2.456
Epoch 801/1000: Loss: 2.337
Epoch 901/1000: Loss: 2.242
Epoch 1000/1000: Loss: 2.165
