# PyTorch tutorial 4 — training a Zernike estimator from AO data

Now that we can generate AO data, we can train a **larger model** to predict **many Zernike modes** (e.g. 50–200; the first experiment below uses 20) from a wavefront-sensor (WFS) or PSF-like image.

We will:
1. Define the **model family** (CNN encoder, shallow UNet, or Conv+MLP) for modal regression.
2. Choose a **target representation** (Zernike ordered vector, NCPA modes, or DM modal basis).
3. Train with a **multi-output regression loss** (typically `MSELoss`) and explain why the loss is still **a scalar** even though the network outputs a whole vector of coefficients.
4. Add **metrics** beyond the loss: RMS phase error, per-mode MSE, and % energy on the first modes.
5. Compare **optimizers** (Adam vs AdamW) and **schedulers** (StepLR / Cosine) for faster convergence.
6. Save the **trained checkpoint** and run **inference** on new turbulence realizations.

By the end of this notebook you will have a **trained Zernike regressor** that takes AO-like data and outputs a vector of modal coefficients.


In [None]:
import math
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

from Functions.mvm import MVM
from Functions.fourier_masks import genOTF_PWFS4
from Functions.pupils import CreateTelescopePupil
from Functions.vonkarman_model_newv3 import VonKarmanPhaseScreenGenerator 

In [None]:
# Initial simulation settings. N and size_pupils_pixels are reassigned (both to
# 256) in the following cells.
N = 128  # Grid size
D = 3.0  # Telescope diameter in meters
size_pupils_pixels = 64  # Pupil size in pixels
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def pwfs_propagation(pupil_amp, pupil_phase, pyr_mask, N):
    """Propagate a pupil field through a PWFS after zero-padding it to 4N x 4N.

    ``pupil_amp`` and ``pupil_phase`` are zero-padded (centred) on their last
    two dimensions up to 4N x 4N. They are combined into the complex field
    ``amp * exp(i * phase)`` and Fourier-transformed to the focal plane, where
    the field is fftshift-ed so DC sits at the centre. The result is multiplied
    by ``pyr_mask[0, 0]``, transformed back, and returned as the detector
    intensity |u|^2.

    ``pyr_mask`` is expected to be 4N x 4N in its last two dimensions.
    """
    side = 4 * N
    pad_h = side - pupil_amp.shape[-2]
    pad_w = side - pupil_amp.shape[-1]
    # F.pad order for the last two dims: (left, right, top, bottom).
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    amp = F.pad(pupil_amp, padding)
    phs = F.pad(pupil_phase, padding)

    field = amp * torch.exp(1j * phs)
    focal = torch.fft.fftshift(torch.fft.fft2(field), dim=(-2, -1))
    filtered = torch.fft.ifftshift(focal * pyr_mask[0, 0], dim=(-2, -1))
    detector = torch.fft.ifft2(filtered)
    return detector.real ** 2 + detector.imag ** 2

def make_circular_pupil(N, D, device):
    """Return an N x N float mask equal to 1 inside a disc of diameter ``D`` pixels.

    The disc is centred on pixel (N//2, N//2); a pixel is inside when its
    distance to the centre is <= D/2.
    """
    coords = torch.arange(N, device=device) - N // 2
    rows, cols = torch.meshgrid(coords, coords, indexing='ij')
    radius = (cols**2 + rows**2).sqrt()
    inside = radius <= (D / 2)
    return inside.float()

In [None]:
# Flat-wavefront reference: propagate a zero phase through the PWFS and plot it.
N = 256  # Grid size
D = 3.0  # Telescope diameter in meters
size_pupils_pixels = 256  # Pupil size in pixels

# pupil = make_circular_pupil(N, size_pupils_pixels, device=device)
pupil = CreateTelescopePupil(N, "disc", device=device)
# The mask is built on a 4N x 4N grid, matching the zero-padding done in pwfs_propagation.
pyr_mask = genOTF_PWFS4(N_fourier_points=N*4, N_points_aperture=size_pupils_pixels, separation=1.1, device=device)
phi = torch.zeros((1,1,N,N), device=device)  # zero phase everywhere

# real**2 + imag**2 is already >= 0; clamp_min only acts as a safeguard.
piston_pwfs = pwfs_propagation(pupil, phi, pyr_mask, N).clamp_min(0.0)

# Crop a 2k x 2k window around the detector centre (k = 1.15 x size_pupils_pixels,
# capped at half the detector size) and resample it back to N x N.
npix = piston_pwfs.shape[-1] // 2
k = int(1.15 * size_pupils_pixels)  
k = min(k, npix)                              
crop_pwfs = piston_pwfs[..., npix-k:npix+k, npix-k:npix+k]
crop_pwfs = F.interpolate(crop_pwfs, size=(N, N), mode='bilinear', align_corners=False)

plt.figure(figsize=(16,4))
plt.subplot(1,4,1)
plt.imshow(pupil.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("Pupil amplitude")
plt.subplot(1,4,2)
plt.imshow(torch.angle(pyr_mask).squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS Pyramid phase mask")
plt.subplot(1,4,3)
plt.imshow(piston_pwfs.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS response to piston aberration")
plt.subplot(1,4,4)
plt.imshow(crop_pwfs.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("Cropped PWFS response")
plt.show()

In [None]:
# Turbulent case: propagate a von Karman phase screen through the same PWFS.
N = 256  # Grid size
D = 3.0  # Telescope diameter in meters
size_pupils_pixels = 256  # Pupil size in pixels

# pupil = make_circular_pupil(N, size_pupils_pixels, device=device)
pupil = CreateTelescopePupil(N, "disc", device=device)
# Mask on the 4N x 4N grid that pwfs_propagation zero-pads to.
pyr_mask = genOTF_PWFS4(N_fourier_points=N*4, N_points_aperture=size_pupils_pixels, separation=1.1, device=device)

batch_size = 1
r0 = 0.1   # Fried parameter in meters
L0 = 20.0   # Outer scale in meters
l0 = 0.001  # Inner scale in meters

vk_rand = VonKarmanPhaseScreenGenerator(
    N=N, D_tel=D, r0=r0, L0=L0, l0=l0,
    pupil_mask=pupil, device=device, batch_size=batch_size
)

phase = vk_rand.generate_total_phase().detach()

phase_pwfs = pwfs_propagation(pupil, phase, pyr_mask, N).clamp_min(0.0)

# Same crop as the flat-wavefront case: 2k x 2k around the detector centre,
# k = 1.15 x size_pupils_pixels (capped at half the detector), resampled to N x N.
npix = phase_pwfs.shape[-1] // 2
k = min(int(1.15 * size_pupils_pixels), npix)
crop_pwfs = phase_pwfs[..., npix-k:npix+k, npix-k:npix+k]
crop_pwfs = F.interpolate(crop_pwfs, size=(N, N), mode='bilinear', align_corners=False)

plt.figure(figsize=(16,4))
plt.subplot(1,4,1)
plt.imshow(phase.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("Turbulent phase screen")
plt.subplot(1,4,2)
plt.imshow(torch.angle(pyr_mask).squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS Pyramid phase mask")
plt.subplot(1,4,3)
plt.imshow(phase_pwfs.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS response to turbulent phase")
plt.subplot(1,4,4)
plt.imshow(crop_pwfs.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("Cropped PWFS response")
plt.show()

In [None]:
def add_noise(x, rn=2.0, photons=1e4, scale=1.0, use_poisson=True):
    """Return a noisy, non-negative copy of intensity image ``x``.

    The real part of ``x`` is clipped at 0 and multiplied by ``scale``. When
    ``use_poisson`` is set, shot noise is drawn at ``photons`` counts per unit
    intensity and rescaled back. Gaussian noise with std ``rn`` is then added
    and the result is clipped at 0 again.
    """
    signal = torch.clamp(x.real.float(), min=0) * scale
    if use_poisson:
        counts = torch.poisson(torch.clamp(signal * photons, min=0))
        signal = counts / photons
    noisy = signal + torch.randn_like(signal) * rn
    return noisy.clamp_(min=0)

def normalize(x, norm='minmax'):
    """Normalize each image independently over its last two dimensions.

    ``norm='minmax'`` maps every image to [0, 1]. ``norm='zscore'`` subtracts
    the per-image mean and divides by the per-image (unbiased) standard
    deviation. A 1e-10 term in the denominator avoids division by zero.

    Raises:
        ValueError: if ``norm`` is neither 'minmax' nor 'zscore'.
    """
    if norm not in ('minmax', 'zscore'):
        raise ValueError("Unsupported normalization method")
    eps = 1e-10
    if norm == 'zscore':
        centre = x.mean(dim=(-1, -2), keepdim=True)
        spread = x.std(dim=(-1, -2), keepdim=True)
        return (x - centre) / (spread + eps)
    lo = x.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values
    hi = x.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values
    return (x - lo) / (hi - lo + eps)

# Add detector noise to the cropped PWFS image from the previous cell, then z-score it.
I_noisy = add_noise(crop_pwfs, rn=0.2, photons=10e3, scale=2.0, use_poisson=True)
I_norm = normalize(I_noisy, norm='zscore')

plt.figure(figsize=(16,4))
plt.subplot(1,3,1)
plt.imshow(crop_pwfs.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS intensity (clean)")
plt.subplot(1,3,2)
plt.imshow(I_noisy.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS intensity (noisy)")
plt.subplot(1,3,3)
plt.imshow(I_norm.squeeze().detach().cpu().numpy(), cmap='viridis')
plt.colorbar()
plt.title("PWFS intensity (normalized)")
# Explicit show, as in the other plotting cells (also avoids echoing the Text object).
plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import random

def genOTF_PWFS4(N_fourier_points=128, N_points_aperture=64, separation=1.0, device='cpu'):
    """Transfer mask of an unmodulated 4-sided PWFS (simple frequency-domain model).

    Each of the four quadrants of the Fourier plane receives a linear phase
    ramp ``-beta * (sx*FX + sy*FY)``, where ``beta = separation * pi /
    (N_fourier_points / N_points_aperture)``. The quadrant masks use a
    Heaviside step with value 0.5 on the axes, so facets share their boundary
    pixels.

    Returns:
        complex64 tensor of shape [1, 1, N_fourier_points, N_fourier_points].

    NOTE(review): the mask is built with its apex at the array centre and then
    fftshift-ed, so the apex ends up at index [0, 0]. pwfs_forward multiplies
    it with an fftshift-ed (centred) focal field. Confirm this convention is
    intended.
    """
    sampling = N_fourier_points / N_points_aperture
    beta = separation * math.pi / sampling
    half = N_fourier_points // 2
    idx = torch.arange(N_fourier_points, device=device, dtype=torch.float32)
    freq = (idx - half) * (2.0 * half / N_fourier_points)
    FX, FY = torch.meshgrid(freq, freq, indexing='ij')

    half_step = torch.tensor(0.5, device=device)

    def step(t):
        return torch.heaviside(t, half_step)

    quadrants = ((1, 1), (1, -1), (-1, -1), (-1, 1))
    mask_sum = torch.zeros_like(FX, dtype=torch.complex64)
    for sx, sy in quadrants:
        facet = step(sx * FX) * step(sy * FY)
        tilt = -beta * (sx * FX + sy * FY)
        mask_sum = mask_sum + facet * torch.exp(1j * tilt)
    return torch.fft.fftshift(mask_sum)[None, None]

def pwfs_forward(pupil_amp, pupil_phase, pyr_mask):
    """Simple PWFS propagation: pupil -> focal plane -> pyramid -> detector (intensity).

    No zero-padding is applied; the FFT grid is the pupil grid itself. The
    focal field is fftshift-ed (DC centred) before being multiplied by
    ``pyr_mask[0, 0]``. The function returns |u|^2 on the detector.
    """
    field = pupil_amp * torch.exp(1j * pupil_phase)
    focal = torch.fft.fftshift(torch.fft.fft2(field), dim=(-2, -1))
    after_pyramid = focal * pyr_mask[0, 0]
    detector = torch.fft.ifft2(torch.fft.ifftshift(after_pyramid, dim=(-2, -1)))
    return detector.real ** 2 + detector.imag ** 2

def make_circular_pupil(N, D, device):
    """Return an N x N float mask equal to 1 inside a disc of diameter ``D`` pixels.

    The disc is centred on pixel (N//2, N//2); a pixel is inside when its
    distance to the centre is <= D/2.
    """
    coords = torch.arange(N, device=device) - N // 2
    rows, cols = torch.meshgrid(coords, coords, indexing='ij')
    radius = (cols**2 + rows**2).sqrt()
    inside = radius <= (D / 2)
    return inside.float()

def add_noise(x, rn=2.0, photons=1e4, scale=1.0, use_poisson=True):
    """Return a noisy, non-negative copy of intensity image ``x``.

    The real part of ``x`` is clipped at 0 and multiplied by ``scale``. When
    ``use_poisson`` is set, shot noise is drawn at ``photons`` counts per unit
    intensity and rescaled back. Gaussian noise with std ``rn`` is then added
    and the result is clipped at 0 again.
    """
    signal = torch.clamp(x.real.float(), min=0) * scale
    if use_poisson:
        counts = torch.poisson(torch.clamp(signal * photons, min=0))
        signal = counts / photons
    noisy = signal + torch.randn_like(signal) * rn
    return noisy.clamp_(min=0)

In [None]:
def noll_to_nm(j):
    """Convert a 1-based Noll index ``j`` to Zernike orders ``(n, m)``.

    Follows the Noll (1976) ordering. Within each radial order ``n``, modes are
    sorted by ``|m|``. Even ``j`` get ``m > 0`` (cosine term) and odd ``j`` get
    ``m < 0`` (sine term), matching the sign convention used by ``zernike``.
    Examples: j=1 -> (0, 0), j=2 -> (1, 1), j=3 -> (1, -1), j=4 -> (2, 0),
    j=5 -> (2, -2), j=6 -> (2, 2), j=7 -> (3, -1).

    Raises:
        ValueError: if ``j < 1``.
    """
    if j < 1:
        raise ValueError(f"Noll index must be >= 1, got {j}")
    # Radial order: smallest n whose cumulative mode count (n+1)(n+2)/2 reaches j.
    n = 0
    while (n + 1) * (n + 2) // 2 < j:
        n += 1
    # 1-based position of j inside radial order n.
    k = j - n * (n + 1) // 2
    if n % 2 == 0:
        m_abs = 2 * (k // 2)
    else:
        m_abs = 1 + 2 * ((k - 1) // 2)
    m = m_abs if j % 2 == 0 else -m_abs
    return n, m

def zernike_radial(n, m, r):
    """Evaluate the Zernike radial polynomial R_n^{|m|} at radii ``r``.

    Uses the finite sum
        R(r) = sum_{s=0}^{(n-|m|)//2} (-1)^s C(n-s, s) C(n-2s, (n-|m|)//2 - s) r^(n-2s),
    which equals the usual factorial form. ``n - |m|`` is expected to be even
    and non-negative; this is not checked.
    """
    half = (n - abs(m)) // 2
    terms = (
        (-1) ** s * math.comb(n - s, s) * math.comb(n - 2 * s, half - s) * r ** (n - 2 * s)
        for s in range(half + 1)
    )
    return sum(terms, torch.zeros_like(r))

def zernike(n, m, rho, theta):
    """Evaluate the (unnormalized) real Zernike polynomial Z_n^m at (rho, theta).

    ``m > 0`` multiplies the radial part by ``cos(m*theta)``, ``m < 0`` by
    ``sin(|m|*theta)``, and ``m == 0`` returns the radial part alone.
    """
    radial = zernike_radial(n, m, rho)
    if m == 0:
        return radial
    azimuth = torch.cos(m * theta) if m > 0 else torch.sin(-m * theta)
    return radial * azimuth

def build_zernike_stack(N, D, device, num_modes, exclude_piston=True):
    """Build ``num_modes`` pupil-masked, unit-std Zernike modes on an N x N grid.

    Modes follow the index order of ``noll_to_nm``. They start at j=2 when
    ``exclude_piston`` is True and at j=1 otherwise. Each mode is evaluated on
    a circular pupil of diameter ``D`` pixels centred at (N//2, N//2),
    multiplied by that pupil, and scaled to unit standard deviation inside it
    so coefficient amplitudes are comparable across modes. A mode that is
    constant over the pupil (piston) has zero standard deviation; it is scaled
    to unit RMS instead, so it is not blown up by the 1e-8 guard.

    Returns:
        Z_stack: tensor [num_modes, N, N] (shape [0, N, N] when num_modes == 0)
        pupil: float tensor [N, N], 1 inside the pupil and 0 outside
        modes: list of (j, n, m) tuples, one per slice of Z_stack
    """
    yy = torch.arange(N, device=device) - N // 2
    xx = torch.arange(N, device=device) - N // 2
    Y, X = torch.meshgrid(yy, xx, indexing='ij')
    r = torch.sqrt(X**2 + Y**2) / (D / 2)   # radius normalized to the unit disc
    th = torch.atan2(Y, X)
    pupil = (r <= 1.0).float()
    inside = pupil.bool()
    rho = torch.clamp(r, 0, 1)

    # Consecutive indices, optionally skipping j=1 (piston).
    first_j = 2 if exclude_piston else 1
    modes = [(j, *noll_to_nm(j)) for j in range(first_j, first_j + num_modes)]

    Zs = []
    for j, n, m in modes:
        Z = zernike(n, m, rho, th) * pupil
        Z_in = Z[inside]
        if Z_in.numel() == 0:
            scale = torch.tensor(1.0, device=device)
        else:
            scale = Z_in.std()
            # Constant modes (piston) have std 0, and a 1-pixel pupil gives NaN:
            # fall back to the RMS value in those cases.
            if not torch.isfinite(scale) or scale < 1e-6:
                scale = Z_in.pow(2).mean().sqrt()
        Zs.append(Z / (scale + 1e-8))

    if not Zs:
        return torch.zeros((0, N, N), device=device), pupil, modes
    Z_stack = torch.stack(Zs, dim=0)   # [M, N, N]
    return Z_stack, pupil, modes

In [None]:
class PWFS_Zernike_Dataset(Dataset):
    """Synthetic PWFS dataset: random Zernike phase -> z-scored PWFS image.

    Each sample draws ``num_modes`` coefficients uniformly in
    ``[-coeff_scale, coeff_scale)`` and sums the matching unit-std Zernike
    modes into a phase screen. The screen is propagated through the 4-sided
    PWFS model (``pwfs_forward``) and the intensity image is z-scored.

    Sample ``idx`` is a deterministic function of ``(seed, idx)``: the same
    index always yields the same (image, coefficients) pair, and datasets with
    different seeds use independent streams. This keeps e.g. a validation set
    fixed across epochs and independent of the global RNG.

    Args:
        n: number of samples (``len(dataset)``).
        img_size: simulation grid size N; images are N x N.
        pupil_diam: pupil diameter in pixels.
        separation: pupil separation factor passed to ``genOTF_PWFS4``.
        num_modes: number of Zernike coefficients per sample.
        coeff_scale: half-width of the uniform coefficient distribution.
        device: device used for the simulation; items are returned on CPU.
        seed: dataset seed.
        exclude_piston: skip piston (j=1) when building the mode basis.

    Items:
        x: float32 tensor [1, N, N] (z-scored intensity), on CPU.
        y: float32 tensor [num_modes] (coefficients), on CPU.
    """

    def __init__(self,
                 n=2000,
                 img_size=128,
                 pupil_diam=64,
                 separation=1.0,
                 num_modes=10,
                 coeff_scale=0.2,
                 device='cpu',
                 seed=0,
                 exclude_piston=True):
        self.n = n
        self.N = img_size
        self.D = pupil_diam
        self.device = torch.device(device)
        self.num_modes = num_modes
        self.coeff_scale = coeff_scale
        self.seed = int(seed)
        # Kept for backward compatibility: constructing the dataset has always
        # reseeded the global RNGs, and code run afterwards (e.g. model
        # initialisation) inherits that state. Sampling itself no longer
        # uses the global RNG.
        random.seed(seed); torch.manual_seed(seed)

        self.otf_pyr = genOTF_PWFS4(self.N, self.D, separation, device=self.device)
        self.Z_stack, self.pupil, self.modes = build_zernike_stack(
            self.N, self.D, self.device, num_modes=num_modes, exclude_piston=exclude_piston
        )

    def __len__(self):
        return self.n

    def _sample_coeffs(self, idx):
        """Draw the coefficient vector of sample ``idx`` from a private CPU generator."""
        # Pack (seed, idx) into one seed so different dataset seeds never share
        # a stream (for idx < 2**32).
        gen = torch.Generator().manual_seed((self.seed << 32) + idx)
        c = torch.rand(self.num_modes, generator=gen) * 2 - 1
        return (c * self.coeff_scale).to(self.device)

    def __getitem__(self, idx):
        idx = int(idx)
        if idx < 0:
            idx += self.n
        if not 0 <= idx < self.n:
            raise IndexError(f"index {idx} out of range for dataset of size {self.n}")
        c = self._sample_coeffs(idx)                         # [M]
        phi = (c.view(-1, 1, 1) * self.Z_stack).sum(dim=0)   # [N,N]
        I = pwfs_forward(self.pupil, phi, self.otf_pyr)
        # Per-image z-score so the network sees intensity-independent inputs.
        m, s = I.mean(), I.std()
        I = (I - m) / (s + 1e-6)
        x = I.unsqueeze(0).float().cpu()          # [1,N,N]
        y = c.float().cpu()                       # [M]
        return x, y

In [None]:
class TinyPWFSNet(nn.Module):
    """Small CNN regressor: image [B, 1, H, W] -> coefficients [B, out_dim].

    Three conv stages (16, 32, 64 channels, with a 2x max-pool after each of
    the first two) are followed by global average pooling. This makes the head
    independent of the input's spatial size. An MLP head (64 -> 64 -> out_dim)
    then outputs the coefficients.
    """

    def __init__(self, out_dim):
        super().__init__()
        # Layer order fixes both the state_dict keys and the parameter-init
        # order; keep it unchanged for checkpoint compatibility.
        features = [
            nn.Conv2d(1, 16, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.backbone = nn.Sequential(*features)
        regressor = [nn.Flatten(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, out_dim)]
        self.head = nn.Sequential(*regressor)

    def forward(self, x):
        embedding = self.backbone(x)
        return self.head(embedding)

In [None]:
def rmse(a, b):
    """Root-mean-square of ``a - b`` over all (broadcast) elements, as a 0-d tensor."""
    diff = a - b
    return diff.pow(2).mean().sqrt()

@torch.no_grad()
def eval_epoch(model, loader, device, loss_fn):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        (mean_loss, rmse):
            ``mean_loss`` is ``loss_fn`` averaged over samples, with each batch
            weighted by its size.
            ``rmse`` is the root-mean-square error over every predicted value
            in the dataset, sqrt(sum((pred - y)^2) / count).
        Both are NaN when the loader yields no data.
    """
    model.eval()
    total_loss = 0.0
    total_sq_err = 0.0
    total_elems = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        total_loss += loss_fn(pred, y).item() * y.size(0)
        # Accumulate squared errors instead of averaging per-batch RMSEs:
        # the mean of batch RMSEs is not the dataset RMSE (it is biased low).
        sq_err = (pred - y) ** 2
        total_sq_err += sq_err.sum().item()
        total_elems += sq_err.numel()
    n = len(loader.dataset)
    if n == 0 or total_elems == 0:
        return float('nan'), float('nan')
    return total_loss / n, math.sqrt(total_sq_err / total_elems)

def train_epoch(model, loader, device, opt, loss_fn):
    """Run one optimisation pass over ``loader``.

    For every batch: forward, ``loss_fn``, zero grads, backward, ``opt.step()``.
    Returns the per-sample mean of the batch losses, each weighted by its batch
    size and measured before the step.
    """
    model.train()
    running = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_loss = loss_fn(model(inputs), targets)
        opt.zero_grad(set_to_none=True)
        batch_loss.backward()
        opt.step()
        running += batch_loss.item() * targets.size(0)
    return running / len(loader.dataset)

In [None]:
# Training configuration. `device` is re-defined here as a plain string
# ('cuda' or 'cpu'); the dataset and .to() calls below accept either form.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

NUM_MODES    = 20  # Number of Zernike coefficients to regress (piston excluded by the dataset default)
COEFF_SCALE  = 0.20  # Coefficients are drawn uniformly in [-COEFF_SCALE, COEFF_SCALE) per mode
IMG          = 128  # Simulated PWFS image size (IMG x IMG pixels)
PUPIL_PIX    = 60  # Pupil diameter in pixels inside the IMG grid
BATCH_TR     = 32  # Training batch size
BATCH_VAL    = 64  # Validation batch size
EPOCHS       = 50
LR           = 1e-3  # Optimizer learning rate

In [None]:
# Training and validation sets share every simulation setting and differ only in
# their seed (0 vs 1). The PWFS_Zernike_Dataset constructor also calls
# random.seed / torch.manual_seed, so the global RNG state after this cell is the
# one seeded by val_ds (seed=1).
train_ds = PWFS_Zernike_Dataset(n=4000, img_size=IMG, pupil_diam=PUPIL_PIX,
                                separation=1.1, num_modes=NUM_MODES,
                                coeff_scale=COEFF_SCALE, device=device, seed=0)

val_ds   = PWFS_Zernike_Dataset(n=800,  img_size=IMG, pupil_diam=PUPIL_PIX,
                                separation=1.1, num_modes=NUM_MODES,
                                coeff_scale=COEFF_SCALE, device=device, seed=1)

In [None]:
# Training batches are shuffled; validation order is fixed. num_workers=0 keeps
# sample simulation in the main process. The datasets may hold tensors on
# `device` (possibly CUDA), which is presumably why worker processes are not used.
train_dl = DataLoader(train_ds, batch_size=BATCH_TR, shuffle=True, num_workers=0)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_VAL, shuffle=False, num_workers=0)

model = TinyPWFSNet(out_dim=NUM_MODES).to(device)
# AdamW: Adam with decoupled weight decay.
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# MSELoss (default reduction='mean') averages over batch and modes -> one scalar.
loss_fn = nn.MSELoss()

In [None]:
# Training loop: each epoch does one optimisation pass over train_dl and then a
# no-grad evaluation on val_dl. Per-epoch values are stored for plotting.
train_curve, val_curve, val_rmse = [], [], []

for ep in range(1, EPOCHS + 1):
    tr = train_epoch(model, train_dl, device, opt, loss_fn)
    va, va_r = eval_epoch(model, val_dl, device, loss_fn)
    train_curve.append(tr); val_curve.append(va); val_rmse.append(va_r)
    if ep % 5 == 0 or ep == 1:  # log the first epoch and every 5th one
        print(f'Epoch {ep:02d} | train={tr:.4f} | val={va:.4f} | val RMSE={va_r:.4f}')

# Loss curves (MSE on the coefficient vectors).
plt.figure(); plt.plot(train_curve, label='train'); plt.plot(val_curve, label='val')
plt.legend(); plt.title(f'PWFS — loss (M={NUM_MODES})'); plt.grid(True); plt.show()

# Validation RMSE per epoch.
plt.figure(); plt.plot(val_rmse, marker='o')
plt.title(f'PWFS — val RMSE (M={NUM_MODES})'); plt.grid(True); plt.show()


@torch.no_grad()
def show_preds(model, loader, k=3):
    """Plot the first ``k`` samples of ``loader`` with GT and predicted coefficients.

    Each input image is shown on its own figure. The title lists the first (up
    to) 4 ground-truth and predicted coefficients. Inputs are fed to the model
    exactly as the loader yields them. For PWFS_Zernike_Dataset these are the
    clean, per-image z-scored images the model is trained on.

    No noise is injected here. ``add_noise`` clamps values to >= 0, so applying
    it to z-scored inputs would zero every below-mean pixel and show
    predictions on out-of-distribution images.
    """
    if k <= 0:
        return
    model.eval()
    # Run on the model's own device instead of relying on a global `device`.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')
    shown = 0
    for x, y in loader:
        x, y = x.to(dev), y.to(dev)
        p = model(x)
        take = min(k - shown, x.size(0))
        Mshow = min(4, y.shape[-1])
        for i in range(take):
            img = x[i, 0].cpu().numpy()
            plt.figure(); plt.imshow(img, cmap='viridis'); plt.axis('off')
            gt_str = ", ".join(f"{y[i, j].item():.2f}" for j in range(Mshow))
            pr_str = ", ".join(f"{p[i, j].item():.2f}" for j in range(Mshow))
            plt.title(f'GT[0:{Mshow}]={gt_str} | Pred[0:{Mshow}]={pr_str}')
            plt.show()
        shown += take
        if shown >= k:
            break

# Visualize predictions for the first 3 validation samples.
show_preds(model, val_dl, k=3)
