# Distributions

Suppose we are trying to emulate a simulator's outputs $f(x) \in \mathbb{R}^d$. If we consider a batch of $n$ outputs, then we'd want to retrieve the output mean $\mu \in \mathbb{R}^{nd}$ and covariance $\Sigma \in \mathbb{R}^{nd \times nd}$:
$$
\Sigma =
\begin{bmatrix}
\Sigma_{11} & \Sigma_{12} & \cdots & \Sigma_{1n} \\[6pt]
\Sigma_{21} & \Sigma_{22} & \cdots & \Sigma_{2n} \\[6pt]
\vdots      & \vdots      & \ddots & \vdots      \\[6pt]
\Sigma_{n1} & \Sigma_{n2} & \cdots & \Sigma_{nn}
\end{bmatrix} \in \mathbb{R}^{nd \times nd}
\quad \text{s.t.} \quad 
\Sigma_{ij} \in \mathbb{R}^{d \times d}
$$
where $\Sigma_{ii}$ is the *marginal covariance* of sample $i$ and $\Sigma_{ij}$ (for $i \neq j$) is the *cross-covariance* between samples $i$ and $j$.

There are a couple of scenarios we want to consider:
1. **Full** (the `Dense` class in the code below): no simplification to the above covariance matrix.
2. **Block-diagonal**: $\Sigma_{ij} = 0$ for all $i\neq j$, i.e. no correlations between samples.
3. **Diagonal**: $\Sigma_{ij} = 0$ for all $i\neq j$ and $\Sigma_{ii}^{(a, b)} = 0$ for all $a \neq b$, i.e. no correlations between samples or between output dimensions.
4. **Separable**: $\Sigma = \Sigma_{\text{N}} \otimes \Sigma_{\text{D}}$ s.t. $\Sigma_{\text{N}} \in \mathbb{R}^{n \times n}$ and $\Sigma_{\text{D}} \in \mathbb{R}^{d \times d}$, i.e. correlations between samples and correlations between dimensions are modelled separately.

Note that:
$$
\begin{aligned}
\text{Full} &\supseteq \text{Block-Diagonal} \supseteq \text{Diagonal} \\
\text{Full} &\supseteq \text{Separable}
\end{aligned}
$$

In [56]:
import torch, time
from autoemulate.experimental.data.gaussian import Dense, BlockDiagonal, Diagonal, Separable
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# Batch of n samples, each with d output dimensions.
n = 4
d = 3
# One mean vector per sample, stacked into an (n, d) tensor.
mean = torch.randn((n, d))

## Full $(nd, nd)$ covariance

In [33]:
# Random symmetric (nd, nd) covariance, built as the Gram matrix of a Gaussian draw.
cov = torch.randn(n * d, n * d)
cov = torch.matmul(cov, cov.T)
dist = Dense(mean, cov)
# Summary statistics of the covariance; the tuple is the cell's displayed output.
dist.logdet(), dist.trace(), dist.max_eig()

(tensor(19.2076), tensor(159.7938), tensor(37.4083))

## Block-diagonal $(n, d, d)$ covariance

In [34]:
# One random (d, d) Gram matrix per sample, with a small diagonal jitter added.
cov = torch.randn(n, d, d)
cov = torch.matmul(cov, cov.transpose(-1, -2)) + 1e-4 * torch.eye(d)
dist = BlockDiagonal(mean, cov)
# Summary statistics of the covariance; the tuple is the cell's displayed output.
dist.logdet(), dist.trace(), dist.max_eig()

(tensor(1.3373), tensor(51.1198), tensor(13.1288))

## Diagonal $(n, d)$ covariance

In [35]:
# Non-negative per-sample, per-dimension entries for the (n, d) diagonal covariance.
cov = torch.randn(n, d).abs()
dist = Diagonal(mean, cov)
# Summary statistics of the covariance; the tuple is the cell's displayed output.
dist.logdet(), dist.trace(), dist.max_eig()

(tensor(-14.7925), tensor(5.3312), tensor(1.3319))

## Separable $(n, n)$ and $(d, d)$ covariance

In [36]:
# Sample-level (n, n) factor: random Gram matrix plus a small diagonal jitter.
cov_n = torch.randn(n, n)
cov_n = torch.matmul(cov_n, cov_n.T) + 1e-4 * torch.eye(n)
# Dimension-level (d, d) factor, built the same way.
cov_d = torch.randn(d, d)
cov_d = torch.matmul(cov_d, cov_d.T) + 1e-4 * torch.eye(d)
dist = Separable(mean, cov_n, cov_d)
# Summary statistics of the covariance; the tuple is the cell's displayed output.
dist.logdet(), dist.trace(), dist.max_eig()

(tensor(-0.6665), tensor(98.4998), tensor(39.5591))

## Comparing dense to specializations

In [None]:
def time_ops(n: int, d: int, n_trials: int):
    """Time ``logdet``, ``trace`` and ``max_eig`` for structured vs. dense covariances.

    For each structured covariance (block-diagonal, diagonal, separable) the
    equivalent dense ``(n*d, n*d)`` matrix is also built, so every specialised
    distribution is compared against ``Dense`` on the same covariance.

    Args:
        n: Number of samples in the batch.
        d: Number of output dimensions per sample.
        n_trials: Number of timed calls averaged per operation (at least 1).

    Returns:
        A dict mapping a distribution label to
        ``{operation name: mean seconds per call}``.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    # Draw the mean here so its shape always matches the requested (n, d),
    # instead of relying on a notebook-global `mean` of some other size.
    mean = torch.randn(n, d)

    # Covariances
    cov_block = torch.randn(n, d, d)
    cov_block = cov_block @ cov_block.transpose(-1, -2) + torch.eye(d) * 1e-4
    cov_diag = torch.abs(torch.randn(n, d))
    cov_n = torch.randn(n, n)
    cov_n = cov_n @ cov_n.T + torch.eye(n) * 1e-4
    cov_d = torch.randn(d, d)
    cov_d = cov_d @ cov_d.T + torch.eye(d) * 1e-4

    # Dense versions
    dense_block = torch.block_diag(*cov_block)
    dense_diag = torch.diag(cov_diag.flatten())
    dense_separable = torch.kron(cov_n, cov_d)

    # Distributions, each structured one paired with its dense equivalent
    distributions = {
        "Dense (block-diagonal)": Dense(mean, dense_block),
        "BlockDiagonal": BlockDiagonal(mean, cov_block),
        "Dense (diagonal)": Dense(mean, dense_diag),
        "Diagonal": Diagonal(mean, cov_diag),
        "Dense (separable)": Dense(mean, dense_separable),
        "Separable": Separable(mean, cov_n, cov_d),
    }

    timings = {}
    for label, dist in distributions.items():
        per_op = {}
        for op_name in ("logdet", "trace", "max_eig"):
            op = getattr(dist, op_name)
            start = time.perf_counter()
            for _ in range(n_trials):
                op()
            per_op[op_name] = (time.perf_counter() - start) / n_trials
        timings[label] = per_op
    return timings

In [55]:
# Scratch check: a flattened (n, d) tensor of non-negative entries becomes an (nd, nd) diagonal matrix.
torch.diag(torch.randn(n, d).abs().flatten())

tensor([[0.1070, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.3002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.1768, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.9874, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.6125, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5119, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.3682, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5751, 0.0000,
         0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7662,

In [45]:
# Scratch check: n blocks of shape (d, d) assemble into a single (nd, nd) matrix.
torch.block_diag(*torch.randn(n, d, d).unbind(0)).size()

torch.Size([12, 12])

In [None]:
import timeit
import torch
import pandas as pd
from dataclasses import dataclass

# The distribution classes used below (Dense, BlockDiagonal, Diagonal, Separable)
# must already be in scope, e.g. via the import at the top of this notebook.

# Helper to generate a random SPD matrix of size m×m
def random_spd(m: int) -> torch.Tensor:
    """Return a random symmetric positive-definite ``m x m`` matrix.

    The Gram matrix of a standard-normal draw is positive semi-definite;
    adding ``m`` times the identity shifts its spectrum up by ``m``.
    """
    root = torch.randn(m, m)
    gram = torch.matmul(root, root.T)
    shift = m * torch.eye(m)
    return gram + shift

def benchmark(fn, repeats=5):
    """Return the mean time, in seconds, of a single call to ``fn``.

    ``fn`` is invoked exactly once per repeat (``repeats`` calls in total)
    and the individual timings are averaged.
    """
    durations = timeit.repeat(fn, repeat=repeats, number=1)
    return sum(durations) / len(durations)

def main():
    """Benchmark ``logdet``, ``trace`` and ``max_eig`` per covariance structure and print a table."""
    # Sizes kept small so the dense (n*d, n*d) covariance fits in memory
    n = d = 10
    mean = torch.randn(n, d)

    # Random covariance parameters, one set per structure
    full_cov = random_spd(n * d)
    block_cov = torch.stack([random_spd(d) for _ in range(n)])
    diag_var = torch.rand(n, d) + 0.1
    sample_cov = random_spd(n)
    dim_cov = random_spd(d)

    # Build every distribution up front so construction is not part of the timings
    candidates = {
        "Dense": Dense(mean, full_cov),
        "BlockDiagonal": BlockDiagonal(mean, block_cov),
        "Diagonal": Diagonal(mean, diag_var),
        "Separable": Separable(mean, sample_cov, dim_cov),
    }

    # One table row per distribution, one column per timed operation
    timings = [
        {
            "Method": label,
            "logdet (s)": benchmark(dist.logdet),
            "trace   (s)": benchmark(dist.trace),
            "max_eig (s)": benchmark(dist.max_eig),
        }
        for label, dist in candidates.items()
    ]

    print(pd.DataFrame(timings))

In [38]:
# Run the benchmark; main() prints the resulting timing table.
main()

          Method  logdet (s)  trace   (s)  max_eig (s)
0          Dense    0.002031     0.000011     0.000645
1  BlockDiagonal    0.000418     0.000009     0.000124
2       Diagonal    0.000008     0.000003     0.000007
3      Separable    0.000098     0.000006     0.000088
