<a href="https://colab.research.google.com/github/ZhengyiGuo2002/CSDM-code/blob/main/CSDM_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

# ==========================================================
# Global config you can set
# ==========================================================
# Name of the torchvision dataset class; it is resolved with getattr(datasets, ...).
DATASET_NAME = "MNIST"   # e.g., "MNIST"; if your torchvision has it, "QMNIST", etc.
# Root directory handed to the dataset constructor (download target).
DATA_ROOT    = "mnist_data"
SIDE_X       = 48        # side length the uncompressed images are resized to (SIDE_X x SIDE_X)
SIDE_Y       = 28        # compressed side length m (i.e., y has M = m*m measurements)
SEED         = 1234      # seed for the generator that draws the sketch matrix A
# Compute device: CUDA when available, otherwise CPU.
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size shared by every DataLoader built below.
batch = 64


# ==========================================================
# 0.  Base dataset + uncompressed dataloaders
# ==========================================================
def build_base_datasets(dataset_name=DATASET_NAME, img_side=SIDE_X, root=DATA_ROOT, download=True):
    """
    Build the train and test splits of a torchvision dataset.

    Every image is bilinearly resized to (img_side, img_side) and converted
    with ToTensor, so items are float tensors in [0,1] of shape
    (C, img_side, img_side).

    dataset_name : attribute name looked up on torchvision.datasets; the class
                   must accept the MNIST-style (root, train, transform, download)
                   constructor arguments.
    returns      : (train_ds, test_ds)
    """
    resize = transforms.Resize(
        (img_side, img_side),
        interpolation=InterpolationMode.BILINEAR,
    )
    pipeline = transforms.Compose([resize, transforms.ToTensor()])

    dataset_cls = getattr(datasets, dataset_name)
    # Train split first, then test split (same construction order as before).
    train_split, test_split = (
        dataset_cls(root=root, train=is_train, transform=pipeline, download=download)
        for is_train in (True, False)
    )
    return train_split, test_split

def build_uncompressed_loaders(batch_size, dataset_name=DATASET_NAME, img_side=SIDE_X, root=DATA_ROOT, download=True):
    """
    Build the base datasets and wrap them in DataLoaders.

    The training loader is shuffled; the test loader keeps dataset order.
    returns: ((train_ds, test_ds), (train_loader, test_loader))
    """
    datasets_pair = build_base_datasets(dataset_name, img_side, root, download)
    train_split, test_split = datasets_pair
    loaders_pair = (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )
    return (train_split, test_split), loaders_pair

# ==========================================================
# 1.  COMPRESS – Gaussian sketch & dataset wrapper
# ==========================================================
def _infer_dims(ds):
    x0, _ = ds[0]
    assert x0.ndim == 3 and x0.shape[1] == x0.shape[2], "Dataset must yield (C,H,W) with square H==W."
    C, H, W = x0.shape
    D = C * H * W
    return C, H, W, D

def _make_A(D, M, seed=SEED):
    """
    Draw an (M, D) Gaussian sketch matrix with i.i.d. N(0, 1/M) entries.

    When `seed` is not None a dedicated CPU generator seeded with it is used,
    so the same seed reproduces the same matrix; with seed=None the global
    RNG is used. The matrix is created on the CPU.
    """
    gen = None
    if seed is not None:
        gen = torch.Generator(device="cpu").manual_seed(seed)
    entry_std = 1.0 / math.sqrt(M)
    return torch.normal(mean=0.0, std=entry_std, size=(M, D), generator=gen)

class CompressedDataset(Dataset):
    """
    Dataset view that replaces each image of a base dataset by its linear
    measurements y = A x, where x is the flattened image.

    Each item is a triple:
        y_vec : (M,)                 flat measurement vector
        y_img : (1, m_side, m_side)  the same values viewed as a square map
        label : label returned by the base dataset
    """
    def __init__(self, base_ds, A, m_side: int):
        super().__init__()
        self.ds = base_ds
        self.A = A                      # measurement matrix, shape (M, D)
        self.m_side = m_side
        self.M = m_side * m_side

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        image, label = self.ds[idx]
        flat = image.reshape(-1)        # (D,)
        # A and the image must live on the same device for the product.
        measurements = self.A.mv(flat)  # (M,)
        square = measurements.view(1, self.m_side, self.m_side)
        return measurements, square, label

def build_compressed_loaders(train_ds, test_ds, batch_size, m_side=SIDE_Y, seed=SEED):
    """
    Draw one fixed Gaussian matrix A and build loaders over the compressed
    train/test datasets in two layouts:
      - vector batches (B, M) with labels
      - image batches (B, 1, m_side, m_side) with labels

    Training loaders are shuffled, test loaders are not.
    returns: A, (train_vec_loader, test_vec_loader), (train_img_loader, test_img_loader)
    """
    *_, D = _infer_dims(train_ds)
    M = m_side * m_side
    assert M <= D, f"Compressed dimension M={M} should be <= D={D} for a true compression."

    A = _make_A(D, M, seed=seed)  # created on the CPU
    comp_train = CompressedDataset(train_ds, A, m_side)
    comp_test = CompressedDataset(test_ds, A, m_side)

    def _collate_field(field):
        # Dataset items are (y_vec, y_img, label); stack the chosen field
        # and collect the labels into a long tensor.
        def _collate(samples):
            stacked = torch.stack([sample[field] for sample in samples], dim=0)
            labels = torch.tensor([sample[2] for sample in samples], dtype=torch.long)
            return stacked, labels
        return _collate

    def _loader_pair(collate):
        train_loader = DataLoader(comp_train, batch_size=batch_size, shuffle=True, collate_fn=collate)
        test_loader = DataLoader(comp_test, batch_size=batch_size, shuffle=False, collate_fn=collate)
        return train_loader, test_loader

    vec_loaders = _loader_pair(_collate_field(0))   # (B, M) batches
    img_loaders = _loader_pair(_collate_field(1))   # (B, 1, m, m) batches
    return A, vec_loaders, img_loaders

# ==========================================================
# Example usage (keeps your original variable names where sensible)
# ==========================================================
# 0) Build the uncompressed datasets and loaders (no compression applied).
# The module-level `batch` is used as the batch size.
(train_ds_raw, test_ds_raw), (tr_loader, te_loader) = build_uncompressed_loaders(
    batch_size=batch,
    dataset_name=DATASET_NAME,
    img_side=SIDE_X,
    root=DATA_ROOT,
    download=True
)

# 1) Build compressed loaders (vector and image layouts) that share one fixed matrix A.
A_cpu, (tr_loader_c_vec, te_loader_c_vec), (tr_loader_c_img, te_loader_c_img) = build_compressed_loaders(
    train_ds=train_ds_raw,
    test_ds=test_ds_raw,
    batch_size=batch,
    m_side=SIDE_Y,
    seed=SEED
)

# (Optional) Copy of A on the compute device for later operations.
A_gpu = A_cpu.to(device)

# Diffusion

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional

class RunningNormalizer:
    """
    Global scalar standardizer: x_norm = (x - mean) / (std + eps).

    A single mean and a single std are shared by all coordinates. Both are
    None until fit_from_loader has been called.
    """
    def __init__(self, eps: float = 1e-6):
        self.mean = None
        self.std = None
        self.eps = eps

    @torch.no_grad()
    def fit_from_loader(self, loader, is_vector: bool, device: torch.device, max_batches: Optional[int] = 50):
        """
        Estimate the global mean/std from the first batches of `loader`.

        loader      : yields (y, label) pairs; only y is used
        is_vector   : when False, y is flattened to (B, -1) before accumulation
        max_batches : stop after this many batches; None scans the whole loader
        The variance is floored at 1e-8 before taking the square root.
        """
        total, total_sq, count = 0.0, 0.0, 0
        for seen, (y, _) in enumerate(loader, start=1):
            y = y.to(device=device, dtype=torch.float32)
            if not is_vector:
                y = y.view(y.size(0), -1)
            # Sums are accumulated as Python floats.
            total += y.sum().item()
            total_sq += (y ** 2).sum().item()
            count += y.numel()
            if max_batches is not None and seen >= max_batches:
                break

        safe_count = max(1, count)
        mu = total / safe_count
        variance = total_sq / safe_count - mu * mu
        self.mean = mu
        self.std = math.sqrt(max(variance, 1e-8))

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw values to standardized values."""
        scale = self.std + self.eps
        return (x - self.mean) / scale

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of normalize: map standardized values back to raw values."""
        scale = self.std + self.eps
        return x * scale + self.mean

# ==========================================================
# SDE definitions (VE, VP) & marginals
# ==========================================================
class VPSDE:
    """
    Variance-preserving SDE with a linear noise schedule on t in [0, 1]:
        dX = -0.5 β(t) X dt + sqrt(β(t)) dW,   β(t) = beta_min + t (beta_max - beta_min)
    Perturbation kernel: X_t = α(t) X_0 + σ(t) ε with
        α(t) = exp(-0.5 ∫_0^t β),   σ(t)^2 = 1 - α(t)^2
    """
    def __init__(self, beta_min=0.1, beta_max=20.0):
        self.beta_min = float(beta_min)
        self.beta_max = float(beta_max)
        self.delta = self.beta_max - self.beta_min

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Linear schedule β(t)."""
        return self.beta_min + t * self.delta

    def a(self, t: torch.Tensor) -> torch.Tensor:
        """Signal coefficient α(t); the integral of β is beta_min*t + 0.5*delta*t^2."""
        integrated_beta = self.beta_min * t + 0.5 * self.delta * t * t
        return torch.exp(-0.5 * integrated_beta)

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Noise level σ(t); 1 - α^2 is floored at 1e-10 before the square root."""
        alpha = self.a(t)
        return (1.0 - alpha * alpha).clamp(min=1e-10).sqrt()

    def g2(self, t: torch.Tensor) -> torch.Tensor:
        """Squared diffusion coefficient g(t)^2 = β(t)."""
        return self.beta(t)

    def f(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Drift f(t, x) = -0.5 β(t) x, with β(t) broadcast over the non-batch dims of x."""
        broadcast_shape = (-1,) + (1,) * (x.dim() - 1)
        return -0.5 * self.beta(t).view(broadcast_shape) * x

    def prior_sample(self, shape, device):
        """Draw a standard normal sample of the given shape (terminal distribution)."""
        return torch.randn(shape, device=device)


# ==========================================================
# Losses (DSM): target score = -epsilon / sigma(t)
# ==========================================================
def dsm_loss(model, x0, sde, t, weight_type="sigma2", is_image: bool = False):
    """
    Denoising score-matching loss for an SDE whose perturbation kernel is
    x_t = a(t) x0 + sigma(t) eps (the `sde` must expose `a` and `sigma`).

    model       : callable (x_t, t) -> predicted score, same shape as x0
    x0          : (B, ...) clean batch, e.g. (B, M) or (B, 1, m, m)
    t           : (B,) diffusion times
    weight_type : "sigma2" weights the squared error by sigma(t)^2;
                  any other value leaves it unweighted
    is_image    : accepted for call-site compatibility; not used here
    returns     : scalar mean loss
    """
    # Broadcast per-sample scalars over all non-batch dimensions.
    per_sample = (-1,) + (1,) * (x0.dim() - 1)

    noise = torch.randn_like(x0)
    alpha = sde.a(t).view(per_sample)
    sigma = sde.sigma(t).view(per_sample)

    noisy = alpha * x0 + sigma * noise
    score_target = -noise / sigma   # score of the Gaussian perturbation kernel

    squared_error = (model(noisy, t) - score_target) ** 2
    if weight_type == "sigma2":
        squared_error = squared_error * (sigma ** 2)
    return squared_error.mean()

# ==========================================================
# Training loops
# ==========================================================
def train_score_img(
    tr_loader_c_img, te_loader_c_img,
    device: torch.device,
    sde_type: str = "VE",
    sigma_min: float = 0.01, sigma_max: float = 50.0,  # VE
    beta_min: float = 0.1,  beta_max: float = 20.0,    # VP
    lr: float = 2e-4, epochs: int = 10,
    weight_type: str = "sigma2",
    normalize: bool = True,
):
    """
    Train a UNet score model on compressed images of shape (B, 1, m, m).

    Each epoch runs one pass over the training loader (AdamW, gradient-norm
    clipping at 1.0), then one pass over the test loader to obtain a validation
    loss, which drives a ReduceLROnPlateau scheduler.

    sde_type  : "VE" (case-insensitive) selects VESDE, anything else VPSDE
    normalize : when True, a RunningNormalizer is fitted on up to 300 training
                batches and applied to every batch
    returns   : (model, sde, normalizer)
    """
    normalizer = RunningNormalizer()
    if normalize:
        normalizer.fit_from_loader(tr_loader_c_img, is_vector=False, device=device, max_batches=300)

    # NOTE(review): VESDE is not defined in the visible part of this file, and
    # dsm_loss relies on the VP-style `sde.a`; confirm the VE path before use.
    wants_ve = sde_type.upper() == "VE"
    if wants_ve:
        sde = VESDE(sigma_min=sigma_min, sigma_max=sigma_max)
    else:
        sde = VPSDE(beta_min=beta_min, beta_max=beta_max)

    model = UNetScoreNetImg(base_ch=128, tdim=256).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)

    # Halve the learning rate when the validation loss stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='min', factor=0.5, patience=2, threshold=1e-4,
        cooldown=1, min_lr=2e-6
    )

    t_floor = 1e-5  # smallest diffusion time that is ever sampled

    def _loss_on(batch_imgs):
        # Move to device, optionally standardize, draw t ~ U[t_floor, 1), evaluate DSM loss.
        batch_imgs = batch_imgs.to(device=device, dtype=torch.float32)
        if normalize:
            batch_imgs = normalizer.normalize(batch_imgs)
        t = torch.rand(batch_imgs.size(0), device=device) * (1.0 - t_floor) + t_floor
        return dsm_loss(model, batch_imgs, sde, t, weight_type=weight_type, is_image=True)

    for ep in tqdm(range(epochs)):
        # ---- training pass ----
        model.train()
        train_total, train_steps = 0.0, 0
        for y_img, _ in tr_loader_c_img:
            loss = _loss_on(y_img)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            train_total += float(loss.item())
            train_steps += 1

        # ---- validation pass ----
        model.eval()
        with torch.no_grad():
            val_total, val_steps = 0.0, 0
            for y_img_val, _ in te_loader_c_img:
                val_total += float(_loss_on(y_img_val).item())
                val_steps += 1
            val_loss = val_total / max(1, val_steps)

        scheduler.step(val_loss)
        current_lr = opt.param_groups[0]['lr']
        train_avg = train_total / max(1, train_steps)

        print(f"[Img][{sde_type}] epoch {ep+1}/{epochs}  "
              f"train_loss={train_avg:.4f}  val_loss={val_loss:.4f}  lr={current_lr:.2e}")

    return model, sde, normalizer

# ==========================================================
# Sampling (reverse-time SDE: Euler–Maruyama predictor + optional Langevin corrector; stochastic)
# ==========================================================
@torch.no_grad()
def sample_compressed_img(
    model, sde, normalizer, num_samples, m_side, device,
    steps=1000, snr=0.15, use_pc=True, clamp_range=(-1.0, 1.0)
):
    """
    Draw samples by integrating the reverse-time VP SDE from t=1 to t=0 with
    an Euler–Maruyama predictor and, optionally, one Langevin corrector step
    per time step.

    model       : score network, called as model(x, t) with t of shape (N,)
    sde         : object providing prior_sample(shape, device) and g2(t) = β(t)
    normalizer  : if normalizer.mean is not None, samples are denormalized
    num_samples : number of samples N
    m_side      : spatial side length; samples have shape (N, 1, m_side, m_side)
    steps       : number of uniform time steps on [0, 1]
    snr         : target signal-to-noise ratio used for the corrector step size
    use_pc      : enable the corrector step
    clamp_range : (low, high) applied to the final output, or None to return
                  the unclamped samples. The default keeps the historical
                  [-1, 1] clamp; measurements y = A x are not bounded by
                  construction, so pass None when decoding them.

    Returns a tensor of shape (N, 1, m_side, m_side) in the domain the model
    was trained on (compressed or image space).
    """
    model.eval()
    x = sde.prior_sample((num_samples, 1, m_side, m_side), device=device)  # state at t=1

    t_grid = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t, t_next = t_grid[i], t_grid[i + 1]
        dt = t_next - t  # negative: time runs backwards
        tb = torch.full((num_samples,), float(t.item()), device=device)

        s = model(x, tb)  # score estimate at (x, t)

        beta = sde.g2(tb).view(-1, 1, 1, 1)  # β(t)
        # Reverse-SDE drift: f(x, t) - g(t)^2 * score, with f = -0.5 β x.
        drift = -0.5 * beta * x - beta * s
        diff = torch.sqrt(torch.clamp(beta * (-dt), min=1e-12))

        # Predictor: Euler–Maruyama step.
        x = x + drift * dt + diff * torch.randn_like(x)

        if use_pc:
            # Corrector: one Langevin step whose size is set from the target SNR.
            # NOTE(review): the score `s` from before the predictor step is
            # reused here rather than re-evaluated at the updated x.
            noise_std = diff.mean()
            score_norm = torch.sqrt((s ** 2).mean(dim=(1, 2, 3), keepdim=True) + 1e-12)
            step_size = (snr * noise_std / (score_norm + 1e-12)) ** 2
            x = x + step_size * s + torch.sqrt(2.0 * step_size) * torch.randn_like(x)

    if normalizer.mean is not None:
        x = normalizer.denormalize(x)
    if clamp_range is None:
        return x
    low, high = clamp_range
    return x.clamp(low, high)

# ==========================================================
# Decoding sampled compressed data with your FISTA
# ==========================================================
@torch.no_grad()
def reconstruct_from_compressed_vectors(
    y_vec: torch.Tensor,        # (N, M)
    A: torch.Tensor,            # (M, D)  on SAME device as y_vec or will be moved
    img_shape: Tuple[int,int,int],  # (C,H,W)
    lam: float,
    L: Optional[float] = None,
    max_iter: int = 500,
    tol: float = 1e-5,
) -> torch.Tensor:
    """
    Decode measurement vectors into images with fista_lasso_batch.

    y_vec is moved to A's device and dtype, decoded to (N, D), reshaped to
    (N, C, H, W), clamped to [0, 1] and returned on the CPU. The fraction of
    decoded entries outside [0, 1] (before that final clamp) is printed.
    """
    target_device = A.device
    measurements = y_vec.to(device=target_device, dtype=A.dtype)
    X_hat = fista_lasso_batch(A, measurements, lam=lam, L=L, max_iter=max_iter, tol=tol)  # (N, D)

    outside_unit_range = (X_hat < 0) | (X_hat > 1)
    frac_clipped = outside_unit_range.float().mean().item()
    print(f"clipped fraction: {frac_clipped:.3f}")

    C, H, W = img_shape
    images = X_hat.view(-1, C, H, W)
    return images.clamp(0, 1).cpu()

@torch.no_grad()
def reconstruct_from_compressed_images(
    y_img: torch.Tensor,        # (N, 1, m, m)
    A: torch.Tensor,            # (M, D)
    img_shape: Tuple[int,int,int],
    lam: float,
    L: Optional[float] = None,
    max_iter: int = 500,
    tol: float = 1e-5,
) -> torch.Tensor:
    """
    Decode image-shaped measurements: flatten each sample to a vector and
    delegate to reconstruct_from_compressed_vectors.
    """
    num_samples = y_img.size(0)
    flat_measurements = y_img.view(num_samples, -1)
    return reconstruct_from_compressed_vectors(
        flat_measurements, A, img_shape, lam, L=L, max_iter=max_iter, tol=tol
    )

In [None]:
# ===== Minimal, Drop-in Better UNet for compressed "images" =====
# Implements:
# (1) Upsample+Conv (no transposed conv -> fewer checkerboard artifacts)
# (2) 1x1 conv for channel alignment (no F.pad)
# (3) FiLM time conditioning (scale-shift after norm)
# (4) Two down/up scales + lightweight 2D attention at bottleneck
# (5) Wider base channels + zero-inited output head

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# --- (Aux) Sinusoidal time embedding ---
class SinusoidalTimeEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of a scalar time value."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        t: (B,) time values
        returns: (B, dim) = [sin(t*w_1..w_h), cos(t*w_1..w_h)] with h = dim // 2
        frequencies spaced log-uniformly from 1 to 10000; for odd dim one
        zero column is appended.
        """
        n_freq = self.dim // 2
        log_lo, log_hi = math.log(1.0), math.log(10000.0)
        freqs = torch.linspace(log_lo, log_hi, n_freq, device=t.device).exp()

        phase = t.unsqueeze(1) * freqs.unsqueeze(0)          # (B, n_freq)
        features = torch.cat((phase.sin(), phase.cos()), dim=-1)
        if self.dim % 2 == 1:
            features = F.pad(features, (0, 1))               # pad last dim to `dim`
        return features


# --- (Aux) helpers ---
def zero_module(m: nn.Module) -> nn.Module:
    """Set every parameter of `m` to zero in place and return the same module."""
    with torch.no_grad():
        for param in m.parameters():
            param.zero_()
    return m

def _gn_groups(c: int, max_groups: int = 32) -> int:
    """Pick a GroupNorm groups count that divides c (<= max_groups)."""
    for g in range(min(max_groups, c), 0, -1):
        if c % g == 0: return g
    return 1


class Skip1x1(nn.Module):
    """Skip-path channel adapter: identity when channel counts match, otherwise a 1x1 conv."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        if in_ch == out_ch:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class Down(nn.Module):
    """Downsampling by a 3x3 convolution with stride 2 and padding 1."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.op = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        downsampled = self.op(x)
        return downsampled


class Up(nn.Module):
    """Upsampling: 2x nearest-neighbor interpolation followed by a 3x3 conv (padding 1)."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.ups = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enlarged = self.ups(x)
        return self.conv(enlarged)


class SelfAttention2d(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map,
    with a GroupNorm before the attention and a residual connection around it.
    The attention matrix is (HW x HW) per head, so use it at small resolution.
    """
    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        assert channels % num_heads == 0, "channels must be divisible by num_heads"
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5   # 1 / sqrt(head_dim)

        # 1x1 convs act as per-position linear layers.
        self.qkv = nn.Conv2d(channels, 3 * channels, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

        # Pre-attention normalization.
        self.norm = nn.GroupNorm(_gn_groups(channels), channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        n_tokens = h * w

        qkv = self.qkv(self.norm(x))  # (B, 3C, H, W)
        # Split into q, k, v and lay each out as (B, heads, HW, head_dim).
        q, k, v = (
            part.view(b, self.num_heads, self.head_dim, n_tokens).permute(0, 1, 3, 2).contiguous()
            for part in torch.chunk(qkv, 3, dim=1)
        )

        logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, heads, HW, HW)
        weights = torch.softmax(logits, dim=-1)
        mixed = torch.matmul(weights, v)                            # (B, heads, HW, head_dim)

        # Back to (B, C, H, W), project, and add the residual.
        mixed = mixed.permute(0, 1, 3, 2).contiguous().view(b, c, h, w)
        return x + self.proj(mixed)


class ResBlockFiLM(nn.Module):
    """
    Residual block with FiLM time conditioning.

    Layout: GroupNorm -> SiLU -> conv1 -> GroupNorm -> FiLM(scale, shift from
    the time embedding) -> SiLU -> dropout -> conv2, added to a (possibly
    1x1-projected) skip of the input. conv2 is zero-initialized, so at
    initialization the block output equals the skip path.
    """
    def __init__(self, in_ch: int, out_ch: int, tdim: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(_gn_groups(in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)

        self.norm2 = nn.GroupNorm(_gn_groups(out_ch), out_ch)
        # Last conv of the block starts at zero for stability.
        self.conv2 = zero_module(nn.Conv2d(out_ch, out_ch, 3, padding=1))

        # Maps the time embedding to per-channel (scale, shift).
        self.emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(tdim, 2 * out_ch)
        )

        self.skip = Skip1x1(in_ch, out_ch)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """x: (B, in_ch, H, W), temb: (B, tdim) -> (B, out_ch, H, W)"""
        hidden = self.conv1(F.silu(self.norm1(x)))

        # FiLM: modulate the normalized activations with (1 + scale) and shift.
        scale, shift = self.emb(temb).chunk(2, dim=1)        # each (B, out_ch)
        scale = scale.unsqueeze(-1).unsqueeze(-1)
        shift = shift.unsqueeze(-1).unsqueeze(-1)
        hidden = (1 + scale) * self.norm2(hidden) + shift

        hidden = self.dropout(F.silu(hidden))
        residual = self.conv2(hidden)
        return self.skip(x) + residual


class UNetScoreNetImg(nn.Module):
    """
    UNet score network for single-channel (B, 1, H, W) inputs.

    - three resolutions (two stride-2 downsamplings, two 2x upsamplings)
    - two FiLM-conditioned residual blocks per resolution
    - self-attention at the middle resolution and in the bottleneck
    - zero-initialized output convolution

    NOTE: H and W should be divisible by 4 (e.g., 28 -> 14 -> 7) so that the
    upsampled maps match the skip connections they are concatenated with.
    """
    def __init__(self, base_ch: int = 64, tdim: int = 256, dropout: float = 0.1):
        super().__init__()
        c1, c2, c4 = base_ch, 2 * base_ch, 4 * base_ch

        def res(cin, cout):
            return ResBlockFiLM(cin, cout, tdim, dropout=dropout)

        # Time conditioning: sinusoidal features followed by a 2-layer MLP.
        self.time_emb = SinusoidalTimeEmbedding(tdim)
        self.time_mlp = nn.Sequential(
            nn.Linear(tdim, 4 * tdim),
            nn.SiLU(),
            nn.Linear(4 * tdim, tdim),
        )

        # Encoder (two blocks per scale)
        self.in_conv = nn.Conv2d(1, c1, 3, padding=1)
        self.e1 = res(c1, c1)
        self.e1b = res(c1, c1)

        self.down1 = Down(c1, c1)
        self.e2 = res(c1, c2)
        self.e2b = res(c2, c2)
        self.attn_mid = SelfAttention2d(c2, num_heads=max(4, c2 // 64))

        self.down2 = Down(c2, c2)
        self.e3 = res(c2, c4)
        self.e3b = res(c4, c4)

        # Bottleneck
        self.mid1 = res(c4, c4)
        self.attn = SelfAttention2d(c4, num_heads=4)
        self.mid2 = res(c4, c4)

        # Decoder (two blocks per scale); the first block of each scale takes
        # the upsampled features concatenated with the encoder skip.
        self.up1 = Up(c4, c2)
        self.d2 = res(c4, c2)
        self.d2b = res(c2, c2)

        self.up2 = Up(c2, c1)
        self.d1 = res(c2, c1)
        self.d1b = res(c1, c1)

        self.out = zero_module(nn.Conv2d(c1, 1, 3, padding=1))

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: (B, 1, H, W), t: (B,)
        returns: (B, 1, H, W)
        """
        temb = self.time_mlp(self.time_emb(t))

        # Encoder, full resolution
        skip_full = self.e1b(self.e1(self.in_conv(x), temb), temb)

        # Encoder, half resolution (+ attention)
        skip_half = self.e2b(self.e2(self.down1(skip_full), temb), temb)
        skip_half = self.attn_mid(skip_half)

        # Encoder, quarter resolution
        feat = self.e3b(self.e3(self.down2(skip_half), temb), temb)

        # Bottleneck
        feat = self.mid2(self.attn(self.mid1(feat, temb)), temb)

        # Decoder, back to half resolution
        feat = torch.cat([self.up1(feat), skip_half], dim=1)
        feat = self.d2b(self.d2(feat, temb), temb)

        # Decoder, back to full resolution
        feat = torch.cat([self.up2(feat), skip_full], dim=1)
        feat = self.d1b(self.d1(feat, temb), temb)

        return self.out(F.silu(feat))

# Decompress

In [None]:
# ==========================================================
# FISTA / LASSO utilities
# ==========================================================

def soft_threshold(x: torch.Tensor, tau: float) -> torch.Tensor:
    """Shrink every entry of x toward zero by tau (the proximal map of tau * ||.||_1)."""
    shrunk_magnitude = (x.abs() - tau).clamp(min=0.0)
    return x.sign() * shrunk_magnitude

@torch.no_grad()
def estimate_lipschitz_squared(A: torch.Tensor, iters: int = 50) -> float:
    """
    Estimate ||A||_2^2 by power iteration on A^T A, starting from a random
    unit vector. The estimate is inflated by 10% so that 1/L is a
    conservative step size. Returns a Python float.
    """
    assert A.ndim == 2, "A must be 2D (M x D)"
    D = A.shape[1]
    device, dtype = A.device, A.dtype

    def _random_unit():
        r = torch.randn(D, device=device, dtype=dtype)
        return r / (r.norm() + 1e-12)

    v = _random_unit()
    for _ in range(iters):
        w = A.t() @ (A @ v)      # (A^T A) v
        w_norm = w.norm()
        # Restart from a fresh random direction if the iterate collapsed.
        v = _random_unit() if w_norm <= 1e-20 else w / w_norm

    # ||A v||^2 for the (unit) dominant direction approximates sigma_max^2.
    sigma_max = (A @ v).norm()
    return 1.1 * float(sigma_max.item() ** 2)

@torch.no_grad()
def fista_lasso_batch(
    A: torch.Tensor,      # (M, D)
    Y: torch.Tensor,      # (B, M) batched measurements
    lam: float,           # lambda in objective
    L: Optional[float] = None,      # Lipschitz constant (||A||_2^2). If None, estimated.
    max_iter: int = 500,
    tol: float = 1e-5,
    verbose: bool = False,
    x_init: Optional[torch.Tensor] = None,   # optional warm start (B, D)
):
    """
    Batched FISTA for the box-constrained LASSO

        min_X  0.5 ||A X^T - Y^T||_F^2 + lam ||X||_1   subject to 0 <= X <= 1,

    solved independently for every row of X.

    A        : (M, D) measurement matrix
    Y        : (B, M) measurements; moved to A's device and dtype
    lam      : l1 weight
    L        : Lipschitz constant of the gradient (||A||_2^2); estimated with
               estimate_lipschitz_squared when None
    max_iter : maximum number of iterations
    tol      : stop when the objective (checked every 10 iterations and on the
               last one) changes by at most tol * max(previous objective, 1)
    verbose  : print the objective at every check
    x_init   : optional (B, D) warm start; zeros are used when None
    returns  : X of shape (B, D) with entries in [0, 1]
    """
    assert A.ndim == 2 and Y.ndim == 2, "A is (M,D), Y is (B,M)"
    M, D = A.shape
    B, M2 = Y.shape
    assert M == M2, "Measurement dimension mismatch."

    device = A.device
    dtype  = A.dtype
    Y = Y.to(device=device, dtype=dtype)

    # Lipschitz constant and step size
    if L is None:
        L = estimate_lipschitz_squared(A)
    step = 1.0 / L

    # Initialize
    if x_init is None:
        X = torch.zeros(B, D, device=device, dtype=dtype)
    else:
        X = x_init.to(device=device, dtype=dtype).clone()
    Yk = X.clone()
    At = A.t()

    # The momentum scalar stays a Python float: this avoids rebuilding a
    # tensor from a tensor (torch.tensor(<tensor>) warns) every iteration.
    t = 1.0
    prev_obj = float("inf")

    for k in range(max_iter):
        # Gradient of the smooth part at Yk: (Yk A^T - Y) A, shape (B, D)
        R = Yk @ At - Y
        G = R @ A

        # Proximal step for the l1 term, then projection onto the box [0, 1]
        X_next = soft_threshold(Yk - step * G, lam * step)
        X_next = X_next.clamp(0, 1)

        # Nesterov momentum
        t_next = 0.5 * (1.0 + math.sqrt(1.0 + 4.0 * t * t))
        Yk = X_next + ((t - 1.0) / t_next) * (X_next - X)

        # Periodic convergence check on the batch-averaged objective
        if (k % 10 == 0) or (k == max_iter - 1):
            res = X_next @ At - Y
            obj = 0.5 * (res.pow(2).sum(dim=1)).mean() + lam * (X_next.abs().sum(dim=1)).mean()
            if verbose:
                print(f"[FISTA] iter={k:04d}  obj={obj.item():.6f}")
            obj_val = obj.item()
            # Relative improvement small enough?
            if math.isfinite(prev_obj) and abs(prev_obj - obj_val) <= tol * max(prev_obj, 1.0):
                X = X_next
                break
            prev_obj = obj_val

        X, t = X_next, t_next

    return X  # (B, D)

# ==========================================================
# Reconstruction helpers for your compressed loaders
# ==========================================================
@torch.no_grad()
def reconstruct_from_image_loader(
    loader_img,           # yields (y_img, label) with y_img shape (B, 1, m, m)
    A: torch.Tensor,      # (M, D)
    img_shape,            # (C, H, W) of original images
    lam: float = 0.01,
    L: float = None,
    max_iter: int = 500,
    tol: float = 1e-5,
    clip01: bool = True,
    verbose: bool = False
):
    """
    Reconstruct images from a loader of image-shaped measurements.

    Every batch is flattened to (B, M), moved to A's device and dtype, and
    decoded with fista_lasso_batch.

    loader_img : yields (y_img, label) batches
    A          : (M, D) measurement matrix
    img_shape  : (C, H, W) used to reshape the decoded vectors
    lam, max_iter, tol, verbose : forwarded to fista_lasso_batch
    L          : Lipschitz constant; when None it is estimated once here and
                 reused for all batches, instead of being re-estimated (with
                 a fresh random start) for every batch
    clip01     : clamp the reconstructions to [0, 1]
    returns    : (X_imgs, labels) on the CPU, concatenated over all batches
    """
    C, H, W = img_shape
    device = A.device
    dtype = A.dtype

    # Loop-invariant: A does not change between batches.
    if L is None:
        L = estimate_lipschitz_squared(A)

    recons, labels = [], []
    for y_img, lbl in tqdm(loader_img):
        # y_img: (B, 1, m, m) -> (B, M)
        y_vec = y_img.view(y_img.size(0), -1).to(device=device, dtype=dtype)
        X_hat = fista_lasso_batch(A, y_vec, lam=lam, L=L, max_iter=max_iter, tol=tol, verbose=verbose)
        X_hat = X_hat.view(-1, C, H, W)
        if clip01:
            X_hat = X_hat.clamp(0.0, 1.0)
        recons.append(X_hat.cpu())
        labels.append(lbl.detach().cpu())

    X_imgs = torch.cat(recons, dim=0)
    labels = torch.cat(labels, dim=0)
    return X_imgs, labels

In [None]:
# ==========================================================
# Example: end-to-end usage
# ==========================================================

# --- Relies on objects created by the earlier cells: train_ds_raw, A_cpu,
# --- tr_loader_c_img / te_loader_c_img and the FISTA utilities.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_shape = tuple(train_ds_raw[0][0].shape)         # (C, H, W)
M = SIDE_Y * SIDE_Y
A_dev = A_cpu.to(device)
L_hat = estimate_lipschitz_squared(A_dev)           # estimated once and reused by the decoder below

# ---------------------------
# 3) Train on compressed images (UNet)
# ---------------------------
m = SIDE_Y
# sde_type="VP" selects VPSDE; sigma_min/sigma_max are only passed to VESDE
# by train_score_img, so they have no effect in this call.
img_model, img_sde, img_norm = train_score_img(
    tr_loader_c_img, te_loader_c_img,
    device=device, sde_type="VP",
    sigma_min=0.01, sigma_max=50.0,
    lr=4e-4, epochs=50, weight_type="sigma2", normalize=True
)

# 4) Sample compressed images and reconstruct
# NOTE: by default sample_compressed_img clamps its output to [-1, 1] before it is decoded.
y_samp_img = sample_compressed_img(img_model, img_sde, img_norm, num_samples=64, m_side=m, device=device, steps=1000)
X_recon_img = reconstruct_from_compressed_images(y_samp_img, A_dev, img_shape, lam=0.01, L=L_hat, max_iter=700)

# Compare original and reconstructed images using MSE

In [None]:
import torch
import math
from typing import Dict, List, Tuple

# -----------------------------
# Metrics
# -----------------------------
@torch.no_grad()
def compute_metrics(gt: torch.Tensor, pred: torch.Tensor) -> Dict[str, float]:
    """
    Compare reconstructions with ground truth.

    gt, pred: (N, C, H, W), expected in [0,1] (PSNR uses a peak value of 1.0).
    MSE, MAE and PSNR are computed per sample; the returned dict holds their
    means plus the 95th percentile of MSE and the 5th percentile of PSNR,
    all as Python floats. Raises AssertionError on a shape mismatch.
    """
    assert gt.shape == pred.shape, f"Shape mismatch: {gt.shape} vs {pred.shape}"

    per_sample_dims = (1, 2, 3)
    err = pred - gt
    mse = err.pow(2).mean(dim=per_sample_dims)
    mae = err.abs().mean(dim=per_sample_dims)
    # The small constant keeps PSNR finite for perfect reconstructions.
    psnr = 10.0 * torch.log10(torch.ones_like(mse) / (mse + 1e-12))

    def _as_float(value: torch.Tensor) -> float:
        return float(value.item())

    return {
        "MSE_mean": _as_float(mse.mean()),
        "MAE_mean": _as_float(mae.mean()),
        "PSNR_mean": _as_float(psnr.mean()),
        # Tail statistics for robustness
        "MSE_p95": _as_float(mse.quantile(0.95)),
        "PSNR_p05": _as_float(psnr.quantile(0.05)),
    }

@torch.no_grad()
def collect_ground_truth(loader) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Drain `loader` and stack everything it yields.

    Each item of `loader` must be an (images, labels) pair of tensors.
    Returns (X_gt, labels), the batch-wise concatenation along dim 0 of all
    images and all labels, both made contiguous. Row order follows iteration
    order, so the loader should not shuffle if alignment with another loader
    matters.
    """
    batches = [(images, targets) for images, targets in loader]
    X_gt = torch.cat([images for images, _ in batches], dim=0).contiguous()
    labels = torch.cat([targets for _, targets in batches], dim=0).contiguous()
    return X_gt, labels

# -----------------------------
# End-to-end evaluation helpers
# -----------------------------

@torch.no_grad()
def evaluate_reconstruction_image(
    te_loader_uncompressed,       # yields (img, label)
    te_loader_c_img,              # yields (y_img, label)
    A_dev,
    img_shape: Tuple[int,int,int],
    lam: float,
    L: float,
    max_iter: int = 500,
    tol: float = 1e-5,
    verbose: bool = False,
) -> Dict[str, float]:
    """
    Score FISTA reconstructions of a compressed test set against the originals.

    The uncompressed loader is drained with collect_ground_truth, the
    image-form compressed loader is decoded with reconstruct_from_image_loader
    (forwarding lam, L, max_iter, tol and verbose), and the two image stacks
    are compared with compute_metrics.

    The label tensors from both loaders are compared; if they differ a warning
    is printed, but the metrics are still computed and returned.
    """
    reference_imgs, reference_lbls = collect_ground_truth(te_loader_uncompressed)

    decoder_options = dict(lam=lam, L=L, max_iter=max_iter, tol=tol, verbose=verbose)
    recon_imgs, recon_lbls = reconstruct_from_image_loader(
        te_loader_c_img, A_dev, img_shape, **decoder_options
    )

    labels_agree = torch.equal(reference_lbls, recon_lbls)
    if not labels_agree:
        print("[warn] label misalignment detected (batch sizes or ordering differ). Metrics may be off.")

    return compute_metrics(reference_imgs, recon_imgs)

# -----------------------------
# Optional: quick λ sweep on a subset to pick a good starting point
# -----------------------------
@torch.no_grad()
def lambda_sweep(
    te_loader_uncompressed,
    te_loader_c_vec,     # compressed loader; image-form is accepted with use_image_loader=True
    A_dev,
    img_shape,
    L: float,
    lam_list: List[float] = (0.001, 0.003, 0.01, 0.03, 0.1),
    max_iter: int = 500,
    max_batches: int = 5,   # keep it small for speed; set None for full eval
    use_image_loader: bool = False,
) -> List[Tuple[float, Dict[str, float]]]:
    """
    Try several LASSO weights and report reconstruction metrics for each.

    Args:
        te_loader_uncompressed: yields (image, label); source of ground truth.
        te_loader_c_vec: yields (measurement, label) in the same order as the
            uncompressed loader. This is the loader that is decoded in BOTH
            modes; no global loader is consulted.
        A_dev: sensing matrix handed to fista_lasso_batch; measurements are
            moved to its device and dtype.
        img_shape: (C, H, W) used to reshape the decoded vectors.
        L: step-size constant forwarded to fista_lasso_batch.
        lam_list: candidate regularisation weights.
        max_iter: iteration cap forwarded to fista_lasso_batch.
        max_batches: number of leading batches taken from each loader
            (None means the whole loader).
        use_image_loader: when True, each measurement batch is flattened to
            (batch, -1) before decoding.

    Returns:
        A list of (lam, metrics) tuples in the order of lam_list, where
        metrics is the dict produced by compute_metrics. One summary line per
        lam is also printed.
    """
    def _reached_limit(batch_index: int) -> bool:
        # Same stopping rule for the ground-truth and the compressed loader.
        return (max_batches is not None) and (batch_index + 1 >= max_batches)

    # Ground-truth subset (labels are not needed for the metrics).
    gt_batches = []
    for b, (x, _) in enumerate(te_loader_uncompressed):
        gt_batches.append(x)
        if _reached_limit(b):
            break
    X_gt = torch.cat(gt_batches, dim=0)

    # Decode the matching compressed subset once per candidate lambda.
    results = []
    for lam in lam_list:
        recons = []
        seen = 0
        for b, (y, _) in enumerate(te_loader_c_vec):
            if use_image_loader:
                # Image-form measurements -> one row per sample.
                y = y.reshape(y.size(0), -1)
            X_hat = fista_lasso_batch(A_dev, y.to(A_dev.device, A_dev.dtype),
                                      lam=lam, L=L, max_iter=max_iter)
            recons.append(X_hat.cpu())
            seen += y.size(0)
            if _reached_limit(b):
                break

        X_hat = torch.cat(recons, dim=0).view(seen, *img_shape).clamp(0, 1)
        metrics = compute_metrics(X_gt[:seen], X_hat)
        results.append((lam, metrics))
        print(f"[λ={lam:.4g}] MSE={metrics['MSE_mean']:.5f}  PSNR={metrics['PSNR_mean']:.3f} dB")
    return results

In [None]:
# Re-derive the shared inputs so this cell can be run on its own
# (train_ds_raw, A_cpu, device and the loaders must already exist).
img_shape = tuple(train_ds_raw[0][0].shape)  # (C,H,W)
A_dev = A_cpu.to(device)
L_hat = estimate_lipschitz_squared(A_dev)    # from your FISTA module

# Image-form compressed test set: decode every test measurement with FISTA
# (lam=0.003, up to 500 iterations) and compare against the original images.
metrics_img = evaluate_reconstruction_image(
    te_loader, te_loader_c_img, A_dev, img_shape, lam=0.003, L=L_hat, max_iter=500
)
print("Image recon:", metrics_img)

# Optional: pick a λ quickly
# te_loader_c_img is passed positionally in the slot named te_loader_c_vec;
# use_image_loader=True selects the branch that flattens image-form batches.
# max_batches=100 caps how many batches are taken from each loader.
_ = lambda_sweep(
    te_loader, te_loader_c_img, A_dev, img_shape, L=L_hat,
    lam_list=[0.001, 0.003, 0.01, 0.03, 0.1], max_iter=500, max_batches=100, use_image_loader=True
)

In [None]:
import torch
import matplotlib.pyplot as plt

# ---------- helpers to collect the first N items from loaders ----------
@torch.no_grad()
def _collect_first_n_uncompressed(loader, N):
    """
    Return the first N (image, label) rows produced by `loader`.

    Batches are pulled until at least N rows have been gathered (or the loader
    runs out), then both stacks are truncated to N rows and made contiguous.
    At least one batch is always consumed before the size check.
    """
    image_chunks, label_chunks = [], []
    gathered = 0
    for images, targets in loader:
        image_chunks.append(images)
        label_chunks.append(targets)
        gathered += images.size(0)
        if gathered >= N:
            break
    X = torch.cat(image_chunks, dim=0)[:N].contiguous()   # (N, C, H, W)
    Y = torch.cat(label_chunks, dim=0)[:N].contiguous()   # (N,)
    return X, Y

@torch.no_grad()
def _collect_first_n_img(loader_img, N):
    """
    Return the first N (measurement image, label) rows from `loader_img`.

    Batches are read until N or more rows are available (or the loader is
    exhausted); the concatenated stacks are then cut to N rows and made
    contiguous. The first batch is always read before the size check.
    """
    meas_chunks, label_chunks = [], []
    gathered = 0
    for y_img, lbl in loader_img:
        meas_chunks.append(y_img)
        label_chunks.append(lbl)
        gathered += y_img.size(0)
        if gathered >= N:
            break
    measurements = torch.cat(meas_chunks, dim=0)[:N].contiguous()   # (N, 1, m, m)
    labels = torch.cat(label_chunks, dim=0)[:N].contiguous()        # (N,)
    return measurements, labels

# ---------- FISTA decode wrappers for N samples ----------
@torch.no_grad()
def _reconstruct_from_vectors(Y_vec, A, img_shape, lam=0.01, L=None, max_iter=500, tol=1e-5, clamp01=True):
    """
    Decode measurement vectors into images with fista_lasso_batch.

    Y_vec is cast to A's device and dtype, decoded with the given lam, L,
    max_iter and tol, and the result is viewed as (-1, C, H, W) where
    (C, H, W) = img_shape. With clamp01=True the pixels are clipped to [0, 1].
    The returned tensor lives on the CPU.
    """
    placement = dict(device=A.device, dtype=A.dtype)
    C, H, W = img_shape
    decoded = fista_lasso_batch(
        A, Y_vec.to(**placement),
        lam=lam, L=L, max_iter=max_iter, tol=tol,
    )  # expected (N, D) with D == C*H*W
    images = decoded.view(-1, C, H, W)
    return (images.clamp(0, 1) if clamp01 else images).cpu()

@torch.no_grad()
def _reconstruct_from_images(Y_img, A, img_shape, lam=0.01, L=None, max_iter=500, tol=1e-5, clamp01=True):
    """
    Decode image-form measurements.

    Each sample of Y_img is flattened to a single row (batch dimension kept)
    and the result is handed to _reconstruct_from_vectors with all remaining
    arguments forwarded unchanged.
    """
    batch = Y_img.size(0)
    flat = Y_img.view(batch, -1)
    return _reconstruct_from_vectors(
        flat, A, img_shape,
        lam=lam, L=L, max_iter=max_iter, tol=tol, clamp01=clamp01,
    )

# ---------- plotting helper ----------
def _show_grid(images, title, nrows=4, ncols=4):
    """
    Plot up to nrows*ncols images in a grid and show the figure.

    images: (N, C, H, W) in [0,1]. Single-channel images are drawn in
            grayscale with a fixed [0,1] range; otherwise the channel axis is
            moved last and the image is drawn as colour.
    title:  figure title.
    nrows, ncols: grid size. Panels beyond N are left blank.

    The axes are always requested as a 2-D array (squeeze=False), so grids
    with a single row or a single column work as well.
    """
    N, C, _, _ = images.shape
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*2.2, nrows*2.2), squeeze=False)
    fig.suptitle(title, fontsize=14)
    # ravel() walks the grid row by row, i.e. panel i sits at (i // ncols, i % ncols).
    for i, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if i >= N:
            continue
        im = images[i].detach().cpu()
        if C == 1:
            ax.imshow(im[0].numpy(), cmap='gray', vmin=0.0, vmax=1.0)
        else:
            ax.imshow(im.permute(1, 2, 0).numpy())
    # Leave headroom for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# ---------- main: collect, reconstruct, visualize ----------
@torch.no_grad()
def visualize_original_vs_reconstructed(
    te_loader, te_loader_c_img,
    A_dev, L_hat, img_shape,
    lam=0.01, max_iter=500, N=16
):
    """
    Plot the first N test images next to their FISTA reconstructions.

    te_loader yields (image, label); te_loader_c_img yields the matching
    image-form measurements and labels. The measurements are decoded with
    _reconstruct_from_images using lam, L_hat and max_iter. A warning is
    printed when the two label sequences differ. Two figures are produced via
    _show_grid with its default grid size; the first title always reads
    "first 16" regardless of N.
    """
    # Step 1: reference images
    originals, labels_ref = _collect_first_n_uncompressed(te_loader, N)  # (N,C,H,W)

    # Step 2: measurements and their reconstruction
    measurements, labels_meas = _collect_first_n_img(te_loader_c_img, N)   # (N,1,m,m)
    recon = _reconstruct_from_images(
        measurements, A_dev, img_shape, lam=lam, L=L_hat, max_iter=max_iter
    )

    # Step 3: both loaders should have produced the same samples
    same_order = torch.equal(labels_ref, labels_meas)
    if not same_order:
        print("[warn] labels not perfectly aligned across loaders; "
              "ensure all test loaders are shuffle=False and built from the same base dataset.")

    # Step 4: one figure per stack
    figures = (
        (originals, "Original test images (first 16)"),
        (recon, "Reconstructed from IMAGE compressed measurements (FISTA)"),
    )
    for stack, heading in figures:
        _show_grid(stack, title=heading)

# ----------------- run it -----------------
# Plots the first 16 original test images and their FISTA reconstructions from
# the image-form measurements, using the A_dev / L_hat / img_shape set above.
# You can tweak lam/max_iter/N as needed
# (note: the first plot title inside the function is hard-coded to "first 16").
visualize_original_vs_reconstructed(
    te_loader=te_loader,
    te_loader_c_img=te_loader_c_img,
    A_dev=A_dev,
    L_hat=L_hat,
    img_shape=img_shape,
    lam=0.003,
    max_iter=500,
    N=16
)

In [None]:
def _to_images(x: torch.Tensor, img_shape):
    """
    Return `x` in image layout.

    A 2-D input (N, D) is viewed as (N, C, H, W) with (C, H, W) = img_shape.
    Any other rank is returned untouched (img_shape is not inspected then).
    """
    if x.dim() != 2:
        return x
    C, H, W = img_shape
    batch = x.size(0)
    return x.view(batch, C, H, W)

def _show_grid(images: torch.Tensor, title: str, n: int = 16, ncols: int = 4):
    """
    Plot the first `n` images of a batch in a grid with `ncols` columns.

    Args:
        images: (N, C, H, W) tensor. Single-channel images are drawn in
            grayscale with a fixed [0, 1] range; otherwise channels are moved
            last and the image is drawn as colour.
        title: figure title.
        n: maximum number of images to show (capped at N).
        ncols: number of grid columns; the row count is derived from n.

    Raises:
        ValueError: if there is nothing to display (n or N is 0).
    """
    images = images.detach().cpu()
    n = min(n, images.size(0))
    if n <= 0:
        # A grid with zero rows cannot be created; fail with a clear message.
        raise ValueError("_show_grid: no images to display")
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)

    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid (which
    # would otherwise come back as a bare Axes object) is handled uniformly.
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.2*ncols, 2.2*nrows), squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes):
        ax.axis('off')
        if i >= n:
            continue  # unused trailing panels stay blank
        img = images[i]
        if img.size(0) == 1:   # grayscale
            ax.imshow(img[0].numpy(), cmap='gray', vmin=0.0, vmax=1.0)
        else:                  # RGB
            ax.imshow(img.permute(1, 2, 0).numpy())

    fig.suptitle(title, fontsize=12)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# ---- Use it ----
# img_shape must be your original (uncompressed) image shape, i.e. (C, H, W).
# This cell does not sample or decode anything itself: it only reshapes the
# X_recon_img produced earlier (if it is still flat) and plots it.
X_recon_img_imgs = _to_images(X_recon_img, img_shape)  # should already be (N,C,H,W)

# Show at most 36 reconstructions in a 6-column grid.
_show_grid(X_recon_img_imgs, "Reconstruction from IMAGE measurements (FISTA)",  n=36, ncols=6)

In [None]:
@torch.no_grad()
def sample_compressed_img(
    model, sde, normalizer, num_samples, m_side, device,
    steps=1000, snr=0.15, use_pc=True
):
    """
    Draw samples by integrating a reverse-time SDE from t=1 to t=0.

    Args:
        model: callable model(x, t) whose output is used as the score.
        sde: object providing prior_sample(shape, device=...) and g2(t).
        normalizer: object with a `mean` attribute and denormalize(x); the
            samples are de-normalized only when `mean` is not None.
        num_samples: number of samples N.
        m_side: side length of the square sample.
        device: device for the time grid and the prior draw.
        steps: number of uniform time steps between 1.0 and 0.0.
        snr: scale factor in the corrector step size.
        use_pc: when True, one Langevin-style corrector update follows each
            Euler-Maruyama predictor update.

    Returns:
        Tensor of shape (N, 1, m_side, m_side), clamped to [-1, 1].

    NOTE(review): the drift below has the form of a reverse-time SDE with
    forward drift -0.5*beta*x and diffusion^2 = beta, i.e. it presumably
    assumes a VP-type sde with g2(t) = beta(t) - confirm against the sde class.
    """
    model.eval()
    x = sde.prior_sample((num_samples, 1, m_side, m_side), device=device)  # t=1

    # steps+1 uniformly spaced times from 1.0 down to 0.0.
    t_grid = torch.linspace(1.0, 0.0, steps+1, device=device)
    for i in tqdm(range(steps)):
        t, t_next = t_grid[i], t_grid[i+1]
        dt = (t_next - t)  # negative: time runs backwards
        # Same scalar time for every sample in the batch.
        tb = torch.full((num_samples,), float(t.item()), device=device)

        s = model(x, tb)  # score(x,t)

        beta = sde.g2(tb).view(-1,1,1,1)          # = β(t); broadcast over (1, m, m)
        drift = -0.5 * beta * x - beta * s
        # Noise scale sqrt(beta * |dt|); the clamp guards the sqrt against 0.
        diff  = torch.sqrt(torch.clamp(beta * (-dt), min=1e-12))

        # Predictor: Euler–Maruyama (noise is injected on every step, including the last)
        z = torch.randn_like(x)
        x = x + drift * dt + diff * z

        if use_pc:
            # Corrector: one Langevin step.
            # NOTE(review): `s` is the score computed BEFORE the predictor
            # update (at the old x and t); it is not re-evaluated here -
            # confirm this is intended.
            # The inner no_grad() is redundant given the decorator above.
            with torch.no_grad():
                noise_std = diff.mean()
                # Per-sample RMS of the score.
                score_norm = torch.sqrt((s**2).mean(dim=(1,2,3), keepdim=True) + 1e-12)
                step_size = (snr * noise_std / (score_norm + 1e-12))**2
                for _ in range(1):  # single corrector iteration
                    zc = torch.randn_like(x)
                    x = x + step_size * s + torch.sqrt(2.0 * step_size) * zc

    # Undo training-time normalization if the normalizer holds statistics.
    if normalizer.mean is not None:
        x = normalizer.denormalize(x)
    return x.clamp(-1, 1)

# Draw 64 measurement images with the sampler defined above, decode them with
# FISTA (lam=0.05, up to 1000 iterations) and show at most 49 results in a
# 7-column grid. This overwrites the global y_samp_img / X_recon_img.
y_samp_img = sample_compressed_img(img_model, img_sde, img_norm, num_samples=64, m_side=m, device=device, steps=1000)
X_recon_img = reconstruct_from_compressed_images(y_samp_img, A_dev, img_shape, lam=0.05, L=L_hat, max_iter=1000)
X_recon_img_imgs = _to_images(X_recon_img, img_shape)  # should already be (N,C,H,W)
_show_grid(X_recon_img_imgs, "Reconstruction from IMAGE measurements (FISTA)",  n=49, ncols=7)

In [None]:
@torch.no_grad()
def sample_compressed_img_pf_heun(
    model, sde, normalizer, num_samples, m_side, device, steps=1200
):
    """
    Draw samples with a deterministic Heun (Euler predictor + averaged slope)
    integration from t=1 down to t=1e-5.

    Args:
        model: callable model(x, t) whose output is used as the score.
        sde: object providing prior_sample(shape, device=...), g2(t) and sigma(t).
        normalizer: object with a `mean` attribute and denormalize(x); the
            samples are de-normalized only when `mean` is not None.
        num_samples: number of samples N.
        m_side: side length of the square sample.
        device: device for the time grid and the prior draw.
        steps: number of integration steps.

    Returns:
        Tensor of shape (N, 1, m_side, m_side), clamped to [-1, 1].

    No noise is injected after the prior draw. NOTE(review): the slope
    -0.5*beta*(x + score) presumably is the probability-flow ODE of a VP-type
    sde with g2(t) = beta(t) - confirm against the sde class.
    """
    model.eval()
    x = sde.prior_sample((num_samples,1,m_side,m_side), device=device)
    # Quadratic time grid: starts at t=1.0 and ends at t=1e-5, with the
    # smallest steps near the end of the trajectory.
    s = torch.linspace(0, 1, steps+1, device=device)
    t_grid = (s**2).flip(0) * (1 - 1e-5) + 1e-5

    for i in tqdm(range(steps)):
        t1, t2 = t_grid[i], t_grid[i+1]
        dt = t2 - t1  # negative: time runs backwards

        # Stage 1: slope at the current point (x, t1).
        tb1  = torch.full((num_samples,), float(t1.item()), device=device)
        beta1 = sde.g2(tb1).view(-1,1,1,1)
        s1    = model(x, tb1)

        # Rescale the score so that its per-sample RMS does not exceed 5 / sigma(t1).
        sigma1 = sde.sigma(tb1).view(-1,1,1,1)
        s1n = torch.sqrt((s1*s1).mean(dim=(1,2,3), keepdim=True) + 1e-12)
        maxn = 5.0 / (sigma1 + 1e-12)
        s1   = s1 * torch.minimum(torch.ones_like(s1n), maxn/s1n)

        f1 = -0.5 * beta1 * (x + s1)
        x_e = x + f1 * dt  # Euler proposal at t2

        # Stage 2: slope at the proposal (x_e, t2).
        # NOTE(review): unlike s1, s2 is NOT rescaled - confirm this asymmetry is intended.
        tb2   = torch.full((num_samples,), float(t2.item()), device=device)
        beta2 = sde.g2(tb2).view(-1,1,1,1)
        s2    = model(x_e, tb2)
        f2 = -0.5 * beta2 * (x_e + s2)

        # Advance with the average of the two slopes.
        x = x + 0.5 * (f1 + f2) * dt

    # Undo training-time normalization if the normalizer holds statistics.
    if normalizer.mean is not None:
        x = normalizer.denormalize(x)
    return x.clamp(-1, 1)

# Same pipeline as the previous cell but with the deterministic Heun sampler:
# 64 samples in 1000 steps, FISTA decoding (lam=0.05, up to 700 iterations),
# then at most 49 reconstructions shown in a 7-column grid.
# This overwrites the global y_samp_img / X_recon_img.
y_samp_img = sample_compressed_img_pf_heun(img_model, img_sde, img_norm, num_samples=64, m_side=m, device=device, steps=1000)
X_recon_img = reconstruct_from_compressed_images(y_samp_img, A_dev, img_shape, lam=0.05, L=L_hat, max_iter=700)
X_recon_img_imgs = _to_images(X_recon_img, img_shape)  # should already be (N,C,H,W)
_show_grid(X_recon_img_imgs, "Reconstruction from IMAGE measurements (FISTA)",  n=49, ncols=7)