# Linear Inverse Problems via Flow Models

## Theoretical Foundation

### Theorem 1: Conditional Vector Fields under Gaussian Probability Paths

Let $q$ be a Gaussian probability path, $q(\mathbf{x}_t \mid \mathbf{x}_1) = \mathcal{N}(\alpha_t \mathbf{x}_1, \sigma_t^2 \mathbf{I})$, where $\mathbf{x}_1$ is a clean data sample. Suppose we observe measurements $\mathbf{y} \sim q(\mathbf{y} \mid \mathbf{x}_1)$ and have an unconditional vector field $v(\mathbf{x}_t)$ that generates samples $\mathbf{x}_t \sim q(\mathbf{x}_t)$. Then the conditional vector field $v(\mathbf{x}_t, \mathbf{y})$ that generates samples $\mathbf{x}_t \sim q(\mathbf{x}_t \mid \mathbf{y})$ is:

$$v(\mathbf{x}_t, \mathbf{y}) = v(\mathbf{x}_t) + \sigma_t^2 \frac{d \ln(\alpha_t / \sigma_t)}{dt} \nabla_{\mathbf{x}_t} \ln q(\mathbf{y} \mid \mathbf{x}_t)$$
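
As a worked special case, assume the conditional OT path $\alpha_t = t$, $\sigma_t = 1 - t$ (the path the solver below is named after). The weight in front of the likelihood gradient then simplifies to $(1-t)/t$:

$$\frac{d \ln(\alpha_t / \sigma_t)}{dt} = \frac{d}{dt}\left[\ln t - \ln(1-t)\right] = \frac{1}{t} + \frac{1}{1-t} = \frac{1}{t(1-t)}$$

$$\sigma_t^2 \frac{d \ln(\alpha_t / \sigma_t)}{dt} = \frac{(1-t)^2}{t(1-t)} = \frac{1-t}{t}, \qquad v(\mathbf{x}_t, \mathbf{y}) = v(\mathbf{x}_t) + \frac{1-t}{t} \nabla_{\mathbf{x}_t} \ln q(\mathbf{y} \mid \mathbf{x}_t)$$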

### Training-Free Algorithm Adaptation

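Given a pretrained unconditional denoiser $\widehat{\mathbf{x}}_1(\mathbf{x}_t)$, or the vector field $\widehat{v}(\mathbf{x}_t)$ derived from it, we approximate the conditional vector field by replacing the intractable likelihood $q(\mathbf{y} \mid \mathbf{x}_t)$ with a tractable approximation $q^{app}(\mathbf{y} \mid \mathbf{x}_t)$: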

$$\widehat{v}(\mathbf{x}_t, \mathbf{y}) = \widehat{v}(\mathbf{x}_t) + \sigma_t^2 \frac{d \ln(\alpha_t / \sigma_t)}{dt} \gamma_t \nabla_{\mathbf{x}_t} \ln q^{app}(\mathbf{y} \mid \mathbf{x}_t)$$

where $\gamma_t$ is a weight that compensates for errors in the likelihood approximation. In the non-adaptive case it is set to $\gamma_t = 1$.
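
A minimal NumPy sketch of this correction, assuming the path is supplied as two plain functions `alpha_fn(t)` and `sigma_fn(t)`. `theorem1_weight` and `corrected_field` are hypothetical helper names, and the derivative is taken by central finite differences rather than analytically. For the conditional OT path the weight should come out as $(1-t)/t$.

```python
import numpy as np


def theorem1_weight(alpha_fn, sigma_fn, t, eps=1e-6):
    """Hypothetical helper: sigma_t^2 * d ln(alpha_t/sigma_t)/dt via a central finite difference."""
    def log_ratio(s):
        return np.log(alpha_fn(s) / sigma_fn(s))

    d_log_ratio = (log_ratio(t + eps) - log_ratio(t - eps)) / (2 * eps)
    return sigma_fn(t) ** 2 * d_log_ratio


def corrected_field(v_hat, grad_log_lik, t, alpha_fn, sigma_fn, gamma_t=1.0):
    """Hypothetical helper: v_hat + sigma_t^2 d ln(alpha_t/sigma_t)/dt * gamma_t * grad ln q^app."""
    return v_hat + theorem1_weight(alpha_fn, sigma_fn, t) * gamma_t * grad_log_lik


# Conditional OT path: alpha_t = t, sigma_t = 1 - t
def alpha_ot(t):
    return t


def sigma_ot(t):
    return 1.0 - t


for t in [0.25, 0.5, 0.8]:
    print(f"t={t}: weight={theorem1_weight(alpha_ot, sigma_ot, t):.6f}, (1-t)/t={(1 - t) / t:.6f}")

# At t = 0.5 the OT weight is 1, so with gamma_t = 2 the correction adds 2 * grad
v = corrected_field(np.zeros(3), np.ones(3), 0.5, alpha_ot, sigma_ot, gamma_t=2.0)
print(v)
```

Finite differences keep the helper path-agnostic; for a path with known closed-form derivatives, the analytic weight is cheaper and exact.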

### Likelihood Approximation for Linear Measurements

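For linear measurements $\mathbf{y} = \mathbf{A}\mathbf{x}_1 + \boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \sigma_y^2 \mathbf{I})$, and a Gaussian approximation $q(\mathbf{x}_1 \mid \mathbf{x}_t) \approx \mathcal{N}(\widehat{\mathbf{x}}_1(\mathbf{x}_t), r_t^2 \mathbf{I})$, the likelihood becomes: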

$$q^{app}(\mathbf{y} \mid \mathbf{x}_t) = \mathcal{N}(\mathbf{A}\widehat{\mathbf{x}}_1(\mathbf{x}_t), \sigma_y^2 \mathbf{I} + r_t^2 \mathbf{A}\mathbf{A}^\top)$$

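where

$$r_t^2 = \frac{\sigma_t^2}{\sigma_t^2 + \alpha_t^2}$$

is the exact posterior variance of each component of $\mathbf{x}_1$ given $\mathbf{x}_t$ when the data follow $\mathcal{N}(\mathbf{0}, \mathbf{I})$. For the conditional OT path ($\alpha_t = t$, $\sigma_t = 1 - t$) it becomes $r_t^2 = (1-t)^2 / \left(t^2 + (1-t)^2\right)$.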
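
To see where $r_t^2$ and the covariance of $q^{app}$ come from, the NumPy sketch below assumes standard Gaussian data $\mathbf{x}_1 \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ on the conditional OT path. In that case the posterior mean $\mathbb{E}[\mathbf{x}_1 \mid \mathbf{x}_t] = \alpha_t \mathbf{x}_t / (\alpha_t^2 + \sigma_t^2)$ is known in closed form, and a Monte Carlo estimate of the covariance of $\mathbf{y} - \mathbf{A}\widehat{\mathbf{x}}_1(\mathbf{x}_t)$ should match $\sigma_y^2\mathbf{I} + r_t^2\mathbf{A}\mathbf{A}^\top$. For real, non-Gaussian data this is only an approximation. The dimensions, $t$, and $\sigma_y$ are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (assumed): N(0, I) data on the conditional OT path at a fixed time t
t = 0.6
alpha_t, sigma_t = t, 1.0 - t
n, m, sigma_y, num_samples = 3, 2, 0.1, 200_000
A = rng.standard_normal((m, n))

# Sample x_1, then x_t on the path and y = A x_1 + measurement noise
x1 = rng.standard_normal((num_samples, n))
xt = alpha_t * x1 + sigma_t * rng.standard_normal((num_samples, n))
y = x1 @ A.T + sigma_y * rng.standard_normal((num_samples, m))

# Exact posterior mean E[x_1 | x_t] and r_t^2 for N(0, I) data
x1_hat = alpha_t * xt / (alpha_t**2 + sigma_t**2)
r2 = sigma_t**2 / (sigma_t**2 + alpha_t**2)

# Empirical vs predicted covariance of the residual y - A x1_hat(x_t)
residual = y - x1_hat @ A.T
emp_cov = np.cov(residual, rowvar=False)
pred_cov = sigma_y**2 * np.eye(m) + r2 * A @ A.T

print("posterior variance:", np.var(x1 - x1_hat), "vs r_t^2 =", r2)
print(np.round(emp_cov, 3))
print(np.round(pred_cov, 3))
```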

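Because this covariance does not depend on $\mathbf{x}_t$, the gradient of the log-likelihood (written as a row vector) is:

$$\nabla_{\mathbf{x}_t} \ln q^{app}(\mathbf{y} \mid \mathbf{x}_t) = (\mathbf{y} - \mathbf{A}\widehat{\mathbf{x}}_1(\mathbf{x}_t))^\top (\sigma_y^2 \mathbf{I} + r_t^2 \mathbf{A}\mathbf{A}^\top)^{-1} \mathbf{A} \frac{\partial \widehat{\mathbf{x}}_1}{\partial \mathbf{x}_t}$$

The code below computes its transpose, $\left(\frac{\partial \widehat{\mathbf{x}}_1}{\partial \mathbf{x}_t}\right)^\top \mathbf{A}^\top (\sigma_y^2 \mathbf{I} + r_t^2 \mathbf{A}\mathbf{A}^\top)^{-1} (\mathbf{y} - \mathbf{A}\widehat{\mathbf{x}}_1(\mathbf{x}_t))$, as a column vector with the same shape as $\mathbf{x}_t$.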
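
The closed-form gradient can be checked against automatic differentiation. The PyTorch sketch below assumes a hypothetical linear denoiser $\widehat{\mathbf{x}}_1(\mathbf{x}_t) = \mathbf{W}\mathbf{x}_t$, so that $\partial \widehat{\mathbf{x}}_1 / \partial \mathbf{x}_t = \mathbf{W}$, and arbitrary values for $r_t^2$ and $\sigma_y$. It differentiates $\ln \mathcal{N}(\mathbf{y}; \mathbf{A}\widehat{\mathbf{x}}_1(\mathbf{x}_t), \boldsymbol{\Sigma})$ through `torch.distributions.MultivariateNormal` and compares the result with the transposed formula.

```python
import torch

torch.manual_seed(0)
n, m = 4, 3
sigma_y, r2 = 0.1, 0.3

A = torch.randn(m, n, dtype=torch.float64)
W = torch.randn(n, n, dtype=torch.float64)  # hypothetical linear denoiser x1_hat(x) = W x
y = torch.randn(m, dtype=torch.float64)
x_t = torch.randn(n, dtype=torch.float64, requires_grad=True)

# Covariance of q^app; it does not depend on x_t
cov = sigma_y**2 * torch.eye(m, dtype=torch.float64) + r2 * A @ A.T

# Autograd gradient of ln N(y; A x1_hat(x_t), cov) with respect to x_t
dist = torch.distributions.MultivariateNormal(A @ (W @ x_t), covariance_matrix=cov)
(grad_autograd,) = torch.autograd.grad(dist.log_prob(y), x_t)

# Closed form: (d x1_hat / d x_t)^T A^T cov^{-1} (y - A x1_hat), with d x1_hat / d x_t = W
with torch.no_grad():
    residual = y - A @ (W @ x_t)
    grad_closed = W.T @ A.T @ torch.linalg.solve(cov, residual)

print(torch.allclose(grad_autograd, grad_closed))
```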


In [6]:
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Audited 3 packages in 4ms


In [7]:
!uv pip install pillow opencv-python matplotlib numpy pandas scikit-learn

Audited 6 packages in 8ms


In [8]:
import torch
import torch.nn as nn
from typing import Callable, Tuple, Optional
import numpy as np


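class GaussianProbabilityPath:
    """
    Represents a Gaussian probability path with a time-dependent noise schedule.
    
    Implements the path q(x_t | x_1) = N(α_t x_1, σ_t^2 I).
    Time convention used by the solver: t = 1 is clean data (α_1 = 1, σ_1 = 0),
    and smaller t is noisier.
    """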
    
    def __init__(self, alpha_fn: Callable, sigma_fn: Callable):
        """
        Args:
            alpha_fn: Function mapping t -> α_t (scaling factor)
            sigma_fn: Function mapping t -> σ_t (noise level)
        """
        self.alpha_fn = alpha_fn
        self.sigma_fn = sigma_fn
    
    def alpha(self, t: float) -> float:
        """Get α_t at time t"""
        return self.alpha_fn(t)
    
    def sigma(self, t: float) -> float:
        """Get σ_t at time t"""
        return self.sigma_fn(t)
    
    def r_t_squared(self, t: float) -> float:
        """
        Compute r_t^2 from Eq. (13) and paper derivation:
        r_t^2 = σ_t^2 / (σ_t^2 + α_t^2)
        """
        alpha_t = self.alpha(t)
        sigma_t = self.sigma(t)
        return (sigma_t ** 2) / (sigma_t ** 2 + alpha_t ** 2)
    
    def d_log_alpha_sigma_dt(self, t: float, eps: float = 1e-5) -> float:
        """
        Compute d ln(α_t / σ_t) / dt using finite differences
        
        This is needed for Theorem 1's correction term
        """
        log_ratio_plus = np.log(self.alpha(t + eps) / self.sigma(t + eps))
        log_ratio = np.log(self.alpha(t) / self.sigma(t))
        return (log_ratio_plus - log_ratio) / eps


class ConditionalOTFlowSolver(nn.Module):
    """
    Solves linear inverse problems via flows using conditional OT probability path.
    
    This implements Algorithm 1 based on Theorem 1 in the paper, using a pretrained 
    denoiser converted to a conditional OT probability path.
    
    Dimension agnostic: Works with signals of any shape.
    """
    
    def __init__(
        self,
        denoiser: Callable,
        measurement_matrix: torch.Tensor,
        probability_path: GaussianProbabilityPath,
        sigma_y: float = 0.0,
        gamma_t: float = 1.0,
        device: str = 'cpu'
    ):
        """
        Args:
            denoiser: Pretrained denoiser x̂_1(z_t) mapping a noisy batch of shape (1, *signal_shape) to a denoised batch of the same shape
            measurement_matrix: Measurement matrix A (shape: [m, n] where n is the flattened signal dimension)
            probability_path: GaussianProbabilityPath defining the noise schedule
            sigma_y: Standard deviation of measurement noise
            gamma_t: Adaptive weight for likelihood correction (typically 1.0)
            device: Device to run computations on ('cpu' or 'cuda')
        """
        super().__init__()
        self.denoiser = denoiser
        self.register_buffer('A', measurement_matrix)
        self.probability_path = probability_path
        self.sigma_y = sigma_y
        self.gamma_t = gamma_t
        self.device = device
        
    def initialize_xt(
        self,
        y: torch.Tensor,
        t: float,
        shape: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Initialize x_t in signal space as z_t = α_t * y_lifted + σ_t * ε, with ε ~ N(0, I).
        For the conditional OT path (α_t = t, σ_t = 1 - t) this is Eq. (14).
        
        Since y lives in measurement space (shape [m]) but the ODE runs in signal space,
        y is lifted with the pseudo-inverse, A^† y, as a rough signal-space initialization.
        
        Args:
            y: Noisy measurements (shape: [m])
            t: Initial time step in [0, 1]
            shape: Shape of the full signal (e.g., (n,) for 1D, (h, w) for 2D, (d, h, w) for 3D);
                its number of elements must equal A.shape[1]
            
        Returns:
            z_t: Initialized noisy signal in signal space with the specified shape
        """
        y = y.to(device=self.device, dtype=self.A.dtype)
        
        # The flattened signal dimension must match the number of columns of A
        n_signal = int(np.prod(shape))
        if n_signal != self.A.shape[1]:
            raise ValueError(
                f"signal shape {shape} has {n_signal} elements, but A expects {self.A.shape[1]}"
            )
        
        # Sample random noise ε ~ N(0, I) in signal space
        epsilon = torch.randn(shape, device=self.device, dtype=self.A.dtype)
        
        # Lift y to signal space using the pseudo-inverse
        A_pinv = torch.linalg.pinv(self.A)  # Shape: [n, m]
        y_lifted = (A_pinv @ y).reshape(shape)  # Shape: [n] -> signal shape
        
        # Blend the lifted measurement and the noise according to the path
        z_t = self.probability_path.alpha(t) * y_lifted + self.probability_path.sigma(t) * epsilon
            
        return z_t
    
    def compute_r_t_squared(self, t_prime: float) -> float:
        r"""
        Compute r_t'^2 from the probability path.
        
        Eq. (13):
        r_t^2=\frac{\sigma_t^2}{\sigma_t^2+\alpha_t^2}
        
        No separate formula is needed for the conditional OT path: substituting
        \alpha_t = t, \sigma_t = 1 - t gives exactly
        r_{t^{\prime}}^2=\frac{\left(1-t^{\prime}\right)^2}{t^{\prime 2}+\left(1-t^{\prime}\right)^2}
        
        Args:
            t_prime: Current time step
            
        Returns:
            r_t'^2 value
        """
        return self.probability_path.r_t_squared(t_prime)
    
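    def convert_to_vector_field(
        self,
        z_t: torch.Tensor,
        t_prime: float,
        x_pred: torch.Tensor,
        eps: float = 1e-5
    ) -> torch.Tensor:
        r"""
        Convert a denoiser prediction into the vector field of the probability path.
        
        Eq. (8):
        \widehat{\boldsymbol{v}}=\left(\alpha_t \frac{d \ln \left(\alpha_t / \sigma_t\right)}{d t}\right) \widehat{\boldsymbol{x}}_1+\frac{d \ln \sigma_t}{d t} \boldsymbol{x}_t
        
        The coefficient of \widehat{\boldsymbol{x}}_1 is evaluated as
        \dot{\alpha}_t - \alpha_t \dot{\sigma}_t / \sigma_t, which equals
        \alpha_t \, d\ln(\alpha_t/\sigma_t)/dt but stays finite at t = 0 when \alpha_0 = 0.
        
        For the conditional OT path (\alpha_t = t, \sigma_t = 1 - t) this simplifies to:
        \widehat{\boldsymbol{v}}=\frac{\widehat{\boldsymbol{x}}_1\left(\boldsymbol{z}_{t^{\prime}}\right)-\boldsymbol{z}_{t^{\prime}}}{1-t^{\prime}}
        
        Args:
            z_t: Current state (arbitrary shape)
            t_prime: Current time step
            x_pred: Denoiser prediction x̂_1(z_t') (same shape as z_t)
            eps: Finite-difference step for dα_t/dt and dσ_t/dt
            
        Returns:
            Vector field v̂ (same shape as z_t)
        """
        if t_prime >= 1.0:
            # t = 1 is the end of the path (σ_1 = 0); the integration loop never evaluates it
            return torch.zeros_like(z_t)
        
        path = self.probability_path
        alpha_t = float(path.alpha(t_prime))
        sigma_t = float(path.sigma(t_prime))
        # Forward differences for dα_t/dt and dσ_t/dt (exact, up to rounding, for linear schedules)
        d_alpha = (float(path.alpha(t_prime + eps)) - alpha_t) / eps
        d_sigma = (float(path.sigma(t_prime + eps)) - sigma_t) / eps
        
        coef_x1 = d_alpha - alpha_t * d_sigma / sigma_t  # α_t d ln(α_t/σ_t)/dt
        coef_xt = d_sigma / sigma_t                      # d ln σ_t / dt
        v_hat = coef_x1 * x_pred + coef_xt * z_t
        return v_hat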
    
    def _flatten_and_unflatten(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
        """
        Flatten a tensor to 1D and return the original shape.
        
        Args:
            x: Tensor of arbitrary shape
            
        Returns:
            Flattened tensor, original shape
        """
        original_shape = x.shape
        flattened = x.reshape(-1)
        return flattened, original_shape
    
    def compute_likelihood_gradient(
        self,
        y: torch.Tensor,
        x_pred: torch.Tensor,
        z_t: torch.Tensor,
        t_prime: float,
        r_t_squared: float,
        jacobian: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the likelihood gradient ∇_{x_t} ln q^app(y | x_t) from Theorem 1.
        
        This is the gradient of log-likelihood for the Gaussian measurement model:
        q^app(y | x_t) = N(A x̂_1(x_t), σ_y^2 I + r_t^2 A A^T)
        
        Works with signals of arbitrary shape (flattens internally).
        
        Args:
            y: Measurements (shape: [m])
            x_pred: Denoiser prediction x̂_1(z_t') (arbitrary shape)
            z_t: Current state (arbitrary shape)
            t_prime: Current time step
            r_t_squared: r_t'^2 value
            jacobian: Pre-computed Jacobian (∂x̂_1/∂z_t'). If None, uses finite differences.
            
        Returns:
            Likelihood gradient (same shape as z_t)
        """
        y = y.to(self.device)
        x_pred_flat, original_shape = self._flatten_and_unflatten(x_pred)
        z_t_flat, _ = self._flatten_and_unflatten(z_t)
        
        # Residual: y - A x̂_1
        residual = y - self.A @ x_pred_flat  # Shape: [m]
        
        # Covariance matrix: σ_y^2 I + r_t^2 A A^T
        m = self.A.shape[0]
        covariance = (
            r_t_squared * (self.A @ self.A.T) +
            (self.sigma_y ** 2) * torch.eye(m, device=self.device, dtype=self.A.dtype)
        )
        
        # Compute Jacobian if not provided
        if jacobian is None:
            jacobian = self._compute_jacobian_fd(z_t, x_pred)
        
        # Likelihood gradient: ∇_{z_t} ln q = jacobian^T @ A^T @ covariance^{-1} @ residual
        # where covariance = σ_y^2 I + r_t^2 A A^T
        inv_cov_residual = torch.linalg.solve(covariance, residual)  # Shape: [m]
        grad_measurement_space = self.A.T @ inv_cov_residual  # Shape: [n]
        grad_likelihood = jacobian.T @ grad_measurement_space  # Shape: [n]
        
        # Reshape back to original shape
        grad_likelihood = grad_likelihood.reshape(original_shape)
        
        return grad_likelihood
    
    def _compute_jacobian_fd(
        self,
        z_t: torch.Tensor,
        x_pred: torch.Tensor,
        eps: float = 1e-4
    ) -> torch.Tensor:
        """
        Compute the Jacobian ∂x̂_1/∂z_t using forward finite differences.
        
        Dimension agnostic: works with signals of arbitrary shape. It needs one
        denoiser evaluation per signal element (n calls), so it is only practical
        for small signals.
        
        Args:
            z_t: Input state (arbitrary shape)
            x_pred: Denoiser output at z_t (same shape as z_t)
            eps: Finite difference step size
            
        Returns:
            Jacobian matrix (shape: [n, n] where n is the total number of elements)
        """
        # Flatten to work with the full Jacobian; the unperturbed output is flattened once
        z_t_flat, original_shape = self._flatten_and_unflatten(z_t)
        x_pred_flat, _ = self._flatten_and_unflatten(x_pred)
        n = z_t_flat.shape[0]
        
        jacobian = torch.zeros(n, n, device=self.device, dtype=z_t.dtype)
        
        for i in range(n):
            # Perturb coordinate i of z_t
            z_t_plus_flat = z_t_flat.clone()
            z_t_plus_flat[i] += eps
            
            # Reshape back to the original shape for denoiser evaluation
            z_t_plus = z_t_plus_flat.reshape(original_shape)
            
            with torch.no_grad():
                x_pred_plus = self.denoiser(z_t_plus.unsqueeze(0)).squeeze(0)
            
            # Column i of the Jacobian: (x̂_1(z_t + eps * e_i) - x̂_1(z_t)) / eps
            x_pred_plus_flat, _ = self._flatten_and_unflatten(x_pred_plus)
            jacobian[:, i] = (x_pred_plus_flat - x_pred_flat) / eps
        
        return jacobian
    
    def correct_vector_field_theorem1(
        self,
        v_hat: torch.Tensor,
        likelihood_grad: torch.Tensor,
        t_prime: float
    ) -> torch.Tensor:
        r"""
        Apply Theorem 1 correction to the unconditional vector field.
        
        v̂_corrected = v̂(x_t) + σ_t^2 (d ln(α_t/σ_t)/dt) γ_t ∇_{x_t} ln q^app(y | x_t)
        
        Works with signals of arbitrary shape.
        
        Args:
            v_hat: Unconditional vector field (arbitrary shape)
            likelihood_grad: Gradient of log-likelihood (same shape as v_hat)
            t_prime: Current time step
            
        Returns:
            Corrected vector field (same shape as v_hat)
        """
        sigma_t = self.probability_path.sigma(t_prime)
        d_log_ratio_dt = self.probability_path.d_log_alpha_sigma_dt(t_prime)
        
        correction = (sigma_t ** 2) * d_log_ratio_dt * self.gamma_t * likelihood_grad
        v_corrected = v_hat + correction
        
        return v_corrected
    
    def forward(
        self,
        y: torch.Tensor,
        signal_shape: Tuple[int, ...],
        t_start: float = 0.5,
        n_steps: int = 100,
        use_likelihood_correction: bool = True
    ) -> torch.Tensor:
        """
        Solve the linear inverse problem by integrating the ODE.
        
        Integrates from t' = t_start to t' = 1 using the corrected vector field
        from Theorem 1.
        
        Dimension agnostic: Works with signals of arbitrary shape.
        
        Args:
            y: Noisy measurements (shape: [m])
            signal_shape: Shape of the signal (e.g., (n,) for 1D, (h, w) for 2D, (d, h, w) for 3D, etc.)
            t_start: Initial time step (typically 0.5-1.0)
            n_steps: Number of time grid points between t_start and 1 (gives n_steps - 1 Euler steps)
            use_likelihood_correction: Whether to apply Theorem 1 correction
            
        Returns:
            Reconstructed signal with the specified signal_shape
        """
        # Initialize x_t with specified shape
        z_t = self.initialize_xt(y, t_start, signal_shape)
        
        # Time steps for ODE integration from t_start to 1
        t_steps = torch.linspace(t_start, 1.0, n_steps, device=self.device)
        
        # ODE integration loop
        for i in range(len(t_steps) - 1):
            t_prime = float(t_steps[i])
            dt = float(t_steps[i + 1] - t_steps[i])
            
            # Compute r_t'^2 from probability path
            r_t_squared = self.compute_r_t_squared(t_prime)
            
            # Denoiser prediction
            with torch.no_grad():
                x_pred = self.denoiser(z_t.unsqueeze(0)).squeeze(0)
            
            # Convert to unconditional vector field
            v_hat = self.convert_to_vector_field(z_t, t_prime, x_pred)
            
            # Apply Theorem 1 correction
            if use_likelihood_correction and t_prime > 1e-6:
                likelihood_grad = self.compute_likelihood_gradient(
                    y, x_pred, z_t, t_prime, r_t_squared
                )
                v_corrected = self.correct_vector_field_theorem1(
                    v_hat, likelihood_grad, t_prime
                )
            else:
                v_corrected = v_hat
            
            # Euler step for ODE integration
            z_t = z_t + v_corrected * dt
        
        return z_t
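
The finite-difference Jacobian in `_compute_jacobian_fd` costs one denoiser call per signal element, yet the likelihood gradient only needs the vector-Jacobian product $\mathbf{u}^\top \partial\widehat{\mathbf{x}}_1/\partial\mathbf{z}_t$ with $\mathbf{u} = \mathbf{A}^\top\boldsymbol{\Sigma}^{-1}(\mathbf{y} - \mathbf{A}\widehat{\mathbf{x}}_1)$. Assuming the denoiser is a differentiable PyTorch module, a single backward pass with `torch.autograd.grad(..., grad_outputs=u)` computes it. `likelihood_grad_vjp` is a hypothetical helper, not part of the class above, and the usage check relies on a bias-free linear layer, whose Jacobian is simply its weight matrix.

```python
import torch


def likelihood_grad_vjp(denoiser, A, y, z_t, r_t_sq, sigma_y):
    """Hypothetical helper: grad_{z_t} ln q^app(y | z_t) via one vector-Jacobian product."""
    z = z_t.detach().requires_grad_(True)
    # Denoiser prediction with the autograd graph kept, flattened to match A's columns
    x_pred = denoiser(z.unsqueeze(0)).squeeze(0).reshape(-1)
    m = A.shape[0]
    cov = r_t_sq * (A @ A.T) + sigma_y**2 * torch.eye(m, dtype=A.dtype, device=A.device)
    with torch.no_grad():
        # u = A^T cov^{-1} (y - A x_pred), shape [n]
        u = A.T @ torch.linalg.solve(cov, y - A @ x_pred)
    # Backpropagating u through the denoiser yields (d x_pred / d z)^T u
    (grad,) = torch.autograd.grad(x_pred, z, grad_outputs=u)
    return grad  # same shape as z_t


# Usage check with a bias-free linear layer (its Jacobian is the weight matrix W)
torch.manual_seed(0)
n, m, r_t_sq, sigma_y = 6, 4, 0.2, 0.05
denoiser = torch.nn.Linear(n, n, bias=False, dtype=torch.float64)
A = torch.randn(m, n, dtype=torch.float64)
y = torch.randn(m, dtype=torch.float64)
z_t = torch.randn(n, dtype=torch.float64)

grad = likelihood_grad_vjp(denoiser, A, y, z_t, r_t_sq, sigma_y)

W = denoiser.weight.detach()
cov = r_t_sq * (A @ A.T) + sigma_y**2 * torch.eye(m, dtype=torch.float64)
expected = W.T @ A.T @ torch.linalg.solve(cov, y - A @ (W @ z_t))
print(torch.allclose(grad, expected))
```

Inside `compute_likelihood_gradient`, this would stand in for the explicit `jacobian.T @ grad_measurement_space` product; it requires evaluating the denoiser with gradients enabled rather than under `torch.no_grad()`.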


In [9]:

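# Choose a probability path (PP). The paths differ in how the signal-to-noise ratio
# α_t / σ_t evolves with t, which changes the initialization, r_t^2, the vector field
# (Eq. 8), and the Theorem 1 weight σ_t^2 d ln(α_t/σ_t)/dt.
# The solver integrates from t_start to t = 1, so both paths below use the flow convention:
# t = 1 is clean data (α_1 = 1, σ_1 = 0) and smaller t is noisier. The conditional OT path
# (α_t = t, σ_t = 1 - t) can be built directly with GaussianProbabilityPath.


# Probability path implementations
class VarianceExplodingPath(GaussianProbabilityPath):
    """
    Variance-exploding (VE) perturbation x_1 + s ε (Song et al.), rescaled to unit
    total variance and written in flow time, with s = 1 - t:
    
    α_t = 1 / sqrt(1 + s^2)
    σ_t = s / sqrt(1 + s^2)
    
    Note: α_t^2 + σ_t^2 = 1, and at t = 0 the noise-to-signal ratio σ_t / α_t is only 1,
    so this path never reaches pure noise.
    """
    
    def __init__(self):
        def alpha_fn(t):
            s = 1.0 - t  # VE noise level; s = 0 at the data end (t = 1)
            return np.sqrt(1.0 / (1.0 + s ** 2))
        
        def sigma_fn(t):
            s = 1.0 - t
            return s / np.sqrt(1.0 + s ** 2)
        
        super().__init__(alpha_fn, sigma_fn)


class VariancePreservingPath(GaussianProbabilityPath):
    """
    Variance-preserving (VP) path with a linear β schedule (the DDPM-style schedule),
    written in flow time, with s = 1 - t:
    
    α_t = exp(-0.5 * ∫_0^s β(u) du),  β(u) = β_min + u (β_max - β_min)
    σ_t = sqrt(1 - α_t^2)
    """
    
    def __init__(self, beta_max: float = 20.0, beta_min: float = 0.1):
        def alpha_fn(t):
            s = 1.0 - t  # diffusion time; s = 0 at the data end (t = 1)
            # Exact integral of the linear schedule: ∫_0^s β(u) du
            integral = beta_min * s + 0.5 * (beta_max - beta_min) * s ** 2
            return np.exp(-0.5 * integral)
        
        def sigma_fn(t):
            alpha_t = alpha_fn(t)
            return np.sqrt(1.0 - alpha_t ** 2)
        
        super().__init__(alpha_fn, sigma_fn)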
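
One way to see how probability paths differ is to compare their log signal-to-noise ratio $\lambda_t = \ln(\alpha_t/\sigma_t)$, whose time derivative is exactly the factor in the Theorem 1 weight. The NumPy sketch below assumes flow time (t = 1 is clean data) and re-declares the conditional OT and linear-β VP schedules as plain functions so that it runs on its own. `ot_path` and `vp_path` are hypothetical names, and the β values match the `VariancePreservingPath` defaults.

```python
import numpy as np


# Hypothetical flow-time schedules: each returns (alpha_t, sigma_t); t = 1 is clean data
def ot_path(t):
    return t, 1.0 - t


def vp_path(t, beta_min=0.1, beta_max=20.0):
    s = 1.0 - t  # diffusion time
    alpha = np.exp(-0.5 * (beta_min * s + 0.5 * (beta_max - beta_min) * s**2))
    return alpha, np.sqrt(1.0 - alpha**2)


ts = np.linspace(0.05, 0.95, 7)
log_snr = {}
for name, path in [("OT", ot_path), ("VP", vp_path)]:
    # lambda_t = ln(alpha_t / sigma_t) on a grid of interior times
    log_snr[name] = np.array([np.log(a / sig) for a, sig in map(path, ts)])
    print(name, np.round(log_snr[name], 2))
```

Both log-SNR curves increase with t, but with these defaults the VP path sits at a lower SNR than the OT path at every sampled t. A given `t_start` therefore corresponds to a noisier starting point on the VP path, and a denoiser trained on one path sees different noise levels at the same t than on another.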

In [10]:

# Example usage
# This runs the code with a dummy (untrained) denoiser
if __name__ == "__main__":
    # Dummy denoiser for demonstration: works with any signal shape by flattening
    class DummyDenoiser(nn.Module):
        def __init__(self, signal_dim: int):
            super().__init__()
            self.fc = nn.Linear(signal_dim, signal_dim)
        
        def forward(self, x):
            # Flatten all but the batch dimension, apply the layer, reshape back
            x_flat = x.reshape(x.shape[0], -1)
            out = self.fc(x_flat)
            return out.reshape(x.shape)
    
    # Setup
    n_signal = 100  # Signal dimension
    m_measurement = 50  # Measurement dimension
    
    # Create measurement matrix A (acts on the flattened signal)
    A = torch.randn(m_measurement, n_signal)
    
    # Create denoiser
    denoiser = DummyDenoiser(n_signal)
    
    # Create the conditional OT path (α_t = t, σ_t = 1 - t) that ConditionalOTFlowSolver is built around
    prob_path = GaussianProbabilityPath(alpha_fn=lambda t: t, sigma_fn=lambda t: 1.0 - t)
    
    # Create solver
    solver = ConditionalOTFlowSolver(
        denoiser=denoiser,
        measurement_matrix=A,
        probability_path=prob_path,
        sigma_y=0.01,
        gamma_t=1.0,
        device='cpu'
    )
    
    # Create measurements y
    y = torch.randn(m_measurement)
    
    # Example 1: 1D signal
    signal_shape_1d = (n_signal,)
    x_reconstructed_1d = solver(y, signal_shape=signal_shape_1d, t_start=0.5, n_steps=50)
    print(f"1D Reconstructed signal shape: {x_reconstructed_1d.shape}")
    print(f"1D Reconstructed signal (first 10 values): {x_reconstructed_1d.flatten()[:10]}")
    
    # Example 2: 2D signal (e.g., 10x10 image)
    signal_shape_2d = (10, 10)
    x_reconstructed_2d = solver(y, signal_shape=signal_shape_2d, t_start=0.5, n_steps=50)
    print(f"\n2D Reconstructed signal shape: {x_reconstructed_2d.shape}")
    print(f"2D Reconstructed signal (first row): {x_reconstructed_2d[0, :]}")
    
    # Example 3: 3D signal (e.g., 5x5x4 volume)
    signal_shape_3d = (5, 5, 4)
    x_reconstructed_3d = solver(y, signal_shape=signal_shape_3d, t_start=0.5, n_steps=50)
    print(f"\n3D Reconstructed signal shape: {x_reconstructed_3d.shape}")
    print(f"3D Reconstructed signal (first slice): {x_reconstructed_3d[0, :, :]}")


1D Reconstructed signal shape: torch.Size([100])
1D Reconstructed signal (first 10 values): tensor([ 0.0334, -0.0183, -0.0201,  0.0115,  0.0025, -0.0221,  0.0907, -0.0124,
        -0.0791, -0.0684])

2D Reconstructed signal shape: torch.Size([10, 10])
2D Reconstructed signal (first row): tensor([ 0.0143,  0.0631, -0.1100,  0.0601, -0.0096,  0.0039,  0.1075, -0.0042,
        -0.0511,  0.0011])

3D Reconstructed signal shape: torch.Size([5, 5, 4])
3D Reconstructed signal (first slice): tensor([[ 0.0084, -0.0916, -0.0161,  0.0692],
        [ 0.0059,  0.0239,  0.0988,  0.0091],
        [-0.1236, -0.1145,  0.0726, -0.0056],
        [-0.0155, -0.0116, -0.0711, -0.0462],
        [-0.0009,  0.0235, -0.1222,  0.0327]])
