---
draft: true
---

In [None]:
import copy
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
from typing import Callable


def coords_to_edges(coords: torch.Tensor) -> torch.Tensor:
    """Return the ``len(coords) + 1`` bin edges around uniformly spaced bin centers.

    Inverse of ``edges_to_coords`` for uniform grids: every edge lies half a
    spacing away from its neighbouring centers. The spacing is taken from the
    first two centers, so ``coords`` must have at least two elements
    (``IndexError`` otherwise). Array-likes such as NumPy arrays are converted
    with ``torch.as_tensor``; the result is always a tensor, as annotated.
    """
    coords = torch.as_tensor(coords)
    delta = coords[1] - coords[0]
    # One extra edge past the last center closes the final bin.
    last_edge = (coords[-1] + 0.5 * delta).unsqueeze(0)
    return torch.cat([coords - 0.5 * delta, last_edge])


def edges_to_coords(edges: torch.Tensor) -> torch.Tensor:
    """Return the bin centers, i.e. the midpoints of consecutive ``edges``.

    The result has one element fewer than ``edges``.
    """
    lower = edges[:-1]
    upper = edges[1:]
    return 0.5 * (lower + upper)


def rotation_matrix(angle: torch.Tensor) -> torch.Tensor:
    """Return the 2x2 matrix ``[[cos a, sin a], [-sin a, cos a]]`` for a scalar angle ``a``.

    Applied to column vectors this rotates clockwise by ``a`` radians. ``angle``
    may be a 0-dim tensor or a Python number; the result keeps the dtype and
    device of ``angle`` (integer angles give the default float dtype) and is
    differentiable with respect to it.
    """
    angle = torch.as_tensor(angle)
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    # Stacking (instead of writing into a default-dtype CPU zeros buffer)
    # avoids silently downcasting float64 angles or moving them off-device.
    return torch.stack([torch.stack([cos, sin]), torch.stack([-sin, cos])])


def marginal_pdf(
    values: torch.Tensor,
    coords: torch.Tensor,
    sigma: float = 1.0,
    epsilon: float = 1.00e-12,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gaussian-kernel density estimate of ``values`` evaluated at ``coords``.

    Args:
        values: samples along dim -2 with a trailing singleton dim, e.g. shape
            ``(N, 1)`` or ``(B, N, 1)`` (``kde_histogram`` passes ``x.unsqueeze(-1)``).
        coords: 1-D, uniformly spaced evaluation points of length ``M``; the
            spacing is taken from ``coords[1] - coords[0]``.
        sigma: kernel width (float or 0-dim tensor), in the units of ``coords``.
        epsilon: added to the normalizer to avoid dividing by zero.

    Returns:
        ``(prob, kernel_values)``. ``prob`` has shape ``(..., M)`` and each row
        is normalized so that ``sum(prob * spacing) ~= 1``. ``kernel_values`` is
        the unnormalized ``(..., N, M)`` kernel matrix.
    """
    # Broadcasting (..., N, 1) against (M,) yields (..., N, M) without
    # materializing a repeated copy of ``coords``.
    residuals = values - coords
    kernel_values = torch.exp(-0.5 * (residuals / sigma).pow(2))
    prob = torch.mean(kernel_values, dim=-2)
    delta = coords[1] - coords[0]
    # Normalize each density row independently (for a single 1-D estimate this
    # is the same as summing everything).
    scale = torch.sum(prob * delta, dim=-1, keepdim=True) + epsilon
    prob = prob / scale
    return (prob, kernel_values)


def kde_histogram(x: torch.Tensor, edges: torch.Tensor, bandwidth: float) -> torch.Tensor:
    """Smooth density histogram of the samples ``x`` on the bins given by ``edges``.

    Every sample contributes a Gaussian of width ``bandwidth`` evaluated at the
    bin centers. Only torch operations are used, so the result can be
    backpropagated through. The estimate is scaled so that
    ``sum(density * bin_width)`` is approximately 1.
    """
    density, _kernel = marginal_pdf(x.unsqueeze(-1), edges_to_coords(edges), sigma=bandwidth)
    return density


class Histogram(torch.nn.Module):
    """Histogram diagnostic: select one input column and bin it as a density.

    Args:
        edges: bin edges; ``edges[1] - edges[0]`` is stored as ``resolution``.
        bandwidth: KDE kernel width in units of ``resolution``.
        axis: index of the column (along dim 1) that is histogrammed.
        kde: if True bin with the Gaussian-kernel ``kde_histogram``; otherwise
            use ``torch.histogram(..., density=True)``.
    """

    def __init__(
        self,
        edges: torch.Tensor,
        bandwidth: float = 0.5,
        axis: int = 0,
        kde: bool = True,
    ) -> None:
        super().__init__()
        # Buffers travel with the module (.to(), .double()) and land in state_dict.
        self.register_buffer("edges", edges)
        self.register_buffer("coords", edges_to_coords(edges))
        width = edges[1] - edges[0]
        self.register_buffer("resolution", width)
        # Kernel width expressed in data units rather than bins.
        self.register_buffer("bandwidth", bandwidth * width)
        self.axis = axis
        self.kde = kde

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Return column ``self.axis`` of ``x`` (indexing along dim 1)."""
        return x[:, self.axis]

    def bin(self, x_proj: torch.Tensor) -> torch.Tensor:
        """Density histogram of already-projected 1-D samples."""
        if not self.kde:
            return torch.histogram(x_proj, self.edges, density=True).hist
        return kde_histogram(x_proj, self.edges, bandwidth=self.bandwidth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` onto ``self.axis`` and return its density histogram."""
        projected = self.project(x)
        return self.bin(projected)
    
        
class Distribution(torch.nn.Module):
    """Equal-weight mixture of Gaussian modes with learnable centers and widths.

    Args:
        locs: mode centers, one row per mode; ``locs[i]`` is added to every
            sample of mode ``i`` and ``locs.shape[-1]`` sets the sample dimension.
        stds: one width per mode; only ``abs(stds[i])`` is used when sampling,
            so the sign is irrelevant.
    """

    def __init__(self, locs: torch.Tensor, stds: torch.Tensor) -> None:
        super().__init__()
        self.nmodes = len(locs)
        self.register_parameter("locs", torch.nn.Parameter(locs))
        self.register_parameter("stds", torch.nn.Parameter(stds))

    def sample(self, size: int) -> torch.Tensor:
        """Draw ``size`` points, split as evenly as possible across the modes.

        Mode ``i`` receives ``size // nmodes`` points, plus one extra for the
        first ``size % nmodes`` modes, so exactly ``size`` rows are returned
        (for ``size <= 0`` an empty ``(0, ndim)`` tensor). Samples are grouped
        by mode in order. The noise uses the dtype and device of ``locs`` and
        the result stays differentiable w.r.t. ``locs`` and ``stds``.
        """
        base, extra = divmod(size, self.nmodes)
        ndim = self.locs.shape[-1]

        chunks = []
        for i in range(self.nmodes):
            count = base + (1 if i < extra else 0)
            if count <= 0:
                continue
            noise = torch.randn((count, ndim), dtype=self.locs.dtype, device=self.locs.device)
            chunks.append(noise * torch.abs(self.stds[i]) + self.locs[i])

        if not chunks:
            # Keep a consistent (rows, ndim) shape even when nothing is drawn.
            return self.locs.new_empty((0, ndim))
        return torch.cat(chunks, dim=0)


class LinearTransform(torch.nn.Module):
    """Linear map applied to row vectors: ``forward(x) == x @ matrix.T``.

    Args:
        matrix: the transfer matrix. A ``torch.nn.Parameter`` is kept as a
            learnable parameter; any other tensor is stored as a non-persistent
            buffer so it follows the module through ``.to()``/``.double()``
            without adding an entry to ``state_dict``.
    """

    def __init__(self, matrix: torch.Tensor) -> None:
        super().__init__()
        if isinstance(matrix, torch.nn.Parameter):
            # Plain attribute assignment registers it as a module parameter.
            self.matrix = matrix
        else:
            self.register_buffer("matrix", matrix, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ matrix.T`` (each row of ``x`` mapped by ``matrix``)."""
        return torch.matmul(x, self.matrix.T)


class ReconstructionModel:
    def __init__(
        self, 
        distribution: Distribution, 
        transforms: list[torch.nn.Module],
        diagnostics: list[torch.nn.Module],
        projections: list[torch.Tensor], 
    ) -> None:
        self.distribution = distribution
        self.transforms = transforms
        self.projections = projections
        self.diagnostics = diagnostics
        self.n = len(transforms)

    def sample(self, size: int) -> torch.Tensor:
        return self.distribution.sample(size)

    def simulate(self, size: int) -> list[torch.Tensor]:
        x = self.sample(size)
        
        projections = []
        for transform, diagnostic in zip(self.transforms, self.diagnostics):
            projection = diagnostic(transform(x))
            projections.append(projection)
        return projections

    def loss(self, size: int) -> torch.Tensor:    
        projections_pred = self.simulate(size)
        projections_meas = self.projections
        
        loss = 0.0
        for index in range(self.n):
            y_pred = projections_pred[index]
            y_meas = projections_meas[index]
            loss = loss + torch.mean(torch.abs(y_pred - y_meas))
        loss = loss / self.n
        return loss

In [None]:
# Ground truth: points on the unit circle, perturbed with isotropic Gaussian
# noise (sigma = 0.15), then scaled so each coordinate has unit std.
x_true = torch.randn((100_000, 2))
x_true = x_true / torch.norm(x_true, dim=1, keepdim=True)
x_true = x_true + 0.15 * torch.randn(x_true.shape)
x_true = x_true / torch.std(x_true, dim=0)

xmax = 3.0
limits = 2 * [(-xmax, xmax)]

fig, ax = plt.subplots(figsize=(2.5, 2.5))
ax.hist2d(x_true[:, 0], x_true[:, 1], bins=75, range=limits)
ax.set_xticks([])
ax.set_yticks([]);

In [None]:
# `nmeas` rotation angles evenly spaced over [0, pi); the endpoint pi is excluded.
nmeas = 5
angles = torch.linspace(0.0, math.pi, nmeas + 1)[:-1]

transforms = [LinearTransform(rotation_matrix(angle)) for angle in angles]

# One independent histogram diagnostic (33 edges -> 32 bins on [-xmax, xmax]) per view.
diagnostics = [Histogram(edges=torch.linspace(-xmax, xmax, 33)) for _ in transforms]

In [None]:
# Generate data: for every view, transform the ground truth and take an exact
# density histogram of its first coordinate on that diagnostic's bin edges.
with torch.no_grad():
    projections = [
        torch.histogram(transform(x_true)[:, 0], bins=diagnostic.edges, density=True).hist
        for transform, diagnostic in zip(transforms, diagnostics)
    ]

## Training

In [None]:
def plot_dist(distribution: Distribution, nsamp: int = 100_000) -> tuple:
    """Plot samples of ``distribution`` next to the ground truth.

    Left panel: 2-D histogram of ``nsamp`` draws from ``distribution`` with the
    mode centers (``distribution.locs``) overlaid in red. Right panel: 2-D
    histogram of the first ``nsamp`` rows of the module-level ``x_true``. Both
    panels use 75 bins over ``[-xmax, xmax]`` per axis (global ``xmax``).

    Returns:
        ``(fig, axs)`` from ``plt.subplots``.
    """
    # Detach so plotting also works outside ``torch.no_grad()``: ``locs`` is a
    # Parameter and the samples are built from it.
    x_pred = distribution.sample(nsamp).detach()
    locs = distribution.locs.detach()

    bins = 75
    limits = 2 * [(-xmax, xmax)]

    fig, axs = plt.subplots(ncols=2, figsize=(3.5, 1.75), constrained_layout=True)
    for ax, x in zip(axs, [x_pred, x_true[:nsamp]]):
        ax.hist2d(x[:, 0], x[:, 1], bins=bins, range=limits)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(limits[0])
        ax.set_ylim(limits[1])
    # Draw the centers after the histogram: both are collections with the same
    # default zorder, so an earlier scatter would be painted over by the mesh.
    axs[0].scatter(locs[:, 0], locs[:, 1], c="red", s=2)
    return fig, axs

In [None]:
def plot_proj(distribution: Distribution, nsamp: int = 100_000) -> tuple:
    """Plot measured (black) vs. simulated (red) projections.

    ``nsamp`` samples are drawn from ``distribution`` and pushed through the
    transforms and diagnostics of the module-level ``model``; the measured
    profiles are ``model.projections``. Each panel is scaled by the larger of
    the two profiles' maxima. Panels are laid out at most 8 per row.

    Returns:
        ``(fig, axs)`` from ``plt.subplots`` (a single Axes when there is only
        one projection, otherwise an array of Axes).
    """
    # Simulate from the distribution that was passed in; detach for plotting.
    x = distribution.sample(nsamp)
    projections_pred = [
        diagnostic(transform(x)).detach()
        for transform, diagnostic in zip(model.transforms, model.diagnostics)
    ]
    projections_meas = model.projections

    nproj = len(projections_pred)
    ncols = min(nproj, 8)
    nrows = int(math.ceil(nproj / ncols))

    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        figsize=(2.5 * ncols, 1.0 * nrows),
        constrained_layout=True,
    )
    # plt.subplots returns a bare Axes for a 1x1 grid; normalize for iteration.
    axes = list(axs.flat) if isinstance(axs, np.ndarray) else [axs]

    # Iterate over the projections actually simulated, using each view's own edges.
    for ax, proj_meas, proj_pred, diagnostic in zip(
        axes, projections_meas, projections_pred, model.diagnostics
    ):
        scale = max(proj_pred.max(), proj_meas.max())
        ax.stairs(proj_meas / scale, diagnostic.edges, lw=1.5, color="black")
        ax.stairs(proj_pred / scale, diagnostic.edges, lw=1.5, color="red")

    # Widen the autoscaled limits by 10% (x) and 20% (y).
    for ax in axes:
        ax.set_xlim(np.array(ax.get_xlim()) * 1.10)
        ax.set_ylim(np.array(ax.get_ylim()) * 1.20)
    return fig, axs

In [None]:
def train_model(
    model: ReconstructionModel,
    iters: int = 1000,
    batch_size: int = 10_000,
    lr: float = 0.001,
    reg: float = 0.0,
) -> tuple[ReconstructionModel, dict]:
    """Fit ``model.distribution`` to the measured projections with Adam.

    Args:
        model: reconstruction model; only ``model.distribution``'s parameters
            are optimized.
        iters: number of optimizer steps.
        batch_size: number of samples drawn per step to evaluate ``model.loss``.
        lr: Adam learning rate.
        reg: weight of a penalty on the mean absolute mode width
            (``abs(distribution.stds)``); 0 disables it.

    Returns:
        ``(model, history)`` where ``history["loss"]`` lists the detached total
        loss of every step.
    """
    distribution = model.distribution
    history = {"loss": []}

    # Optimize the model's own distribution with the requested learning rate.
    optimizer = torch.optim.Adam(distribution.parameters(), lr=lr)
    for _ in range(iters):
        loss = model.loss(batch_size)

        # Distribution.sample uses abs(stds), so penalize the magnitude; the raw
        # value would reward pushing widths negative (i.e. making modes wider).
        loss_reg = reg * torch.mean(torch.abs(distribution.stds))
        loss = loss + loss_reg

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        history["loss"].append(loss.detach())

    return model, history

In [None]:
# Reconstruction model
# Initial guess: `modes` Gaussian modes spaced along x in [-2, 2] (y = 0),
# all with width 0.25.
modes = 10

locs = torch.zeros((modes, 2))
locs[:, 0] = torch.linspace(-2.0, 2.0, modes)

stds = torch.ones(modes) * 0.25

# NOTE(review): `distribution` and `model` are module-level names that helper
# functions in this notebook may read as globals; keep these bindings.
distribution = Distribution(locs=locs, stds=stds)

model = ReconstructionModel(
    distribution=distribution,
    transforms=transforms,
    projections=projections,
    diagnostics=diagnostics,
)

model, history = train_model(model, lr=0.001)
# Final training loss (shown as the cell's output).
history["loss"][-1]

In [None]:
# Training-loss curve (one point per optimizer step) for the run above.
fig, ax = plt.subplots()
ax.plot(history["loss"])

In [None]:
# Compare four initializations of a 5-mode model, all with widths 0.02:
#   i == 0: every mode at the origin
#   i == 1: modes spaced along x in [-2, 2]
#   i == 2: modes spaced along y in [-2, 2]
#   i == 3: modes spaced along the diagonal x == y
# For each: show the initial samples, train for 1200 steps, then show the final
# loss, the loss curve, the fitted samples and simulated-vs-measured projections.
for i in range(4):
    modes = 5
    stds = torch.ones(modes) * 0.02
    locs = torch.zeros((modes, 2))

    if i == 1:
        locs[:, 0] = torch.linspace(-2.0, 2.0, modes)
    if i == 2:
        locs[:, 1] = torch.linspace(-2.0, 2.0, modes)
    if i == 3:
        locs[:, 0] = torch.linspace(-2.0, 2.0, modes)
        locs[:, 1] = torch.linspace(-2.0, 2.0, modes)
    
    # NOTE(review): these rebind the module-level `distribution` / `model`,
    # which helper functions in this notebook may read as globals.
    distribution = Distribution(locs=locs, stds=stds)
    model = ReconstructionModel(
        distribution=distribution, 
        transforms=transforms, 
        projections=projections, 
        diagnostics=diagnostics
    )

    # Plot without recording autograd history (presumably also needed so the
    # Parameter-derived tensors can be converted to NumPy by matplotlib).
    with torch.no_grad():
        plot_dist(distribution)
        plt.show()
    
    model, history = train_model(model, iters=1200, lr=0.001)
    
    with torch.no_grad():
        print(history["loss"][-1])
        
        fig, ax = plt.subplots(figsize=(3, 2))
        ax.plot(history["loss"])
        plt.show()
        
        plot_dist(distribution)
        plot_proj(distribution)
        plt.show()