# PyTorch tutorial 3 — Atmospheric data generation for AO (with 3 animations)

In this notebook we will:
1. **Generate random atmospheric phase screens** (Von Kármán).
2. **Animate** three situations:
   - i.i.d. **random** phase screens,
   - a **dynamic** (frozen‑flow) single layer,
   - a **multi‑layer** atmosphere.
3. **Decompose** one phase into **Zernike modes** and **reconstruct** it to compare.
4. Wrap everything into a **PyTorch `Dataset`** that returns `(input, zernike_coeffs)` and is ready for training.

The helper functions ship in the `Functions` folder of the workshop repository, so this notebook **keeps the animation cells** and **keeps the original generators**, only ordering them so that everything is explained before it is used.


# Google Colab

In [None]:
!git clone https://github.com/PUCVOptolab/MACHINE_LEARNING_FOR_ADAPTIVE_OPTICS_WORKSHOP.git
%cd MACHINE_LEARNING_FOR_ADAPTIVE_OPTICS_WORKSHOP

## Learning objectives

By the end of this notebook you should be able to:

1. Explain what the Von Kármán model is and what `r0`, `L0`, `l0` do.
2. Generate and **visualize** atmospheric phase screens.
3. Animate **random**, **dynamic**, and **multi‑layer** atmospheres.
4. Build a **Zernike basis** on a pupil and project a phase onto it.
5. Compare a **true** phase to its **Zernike reconstruction**.
6. Build an **AO dataset** `(input, zernikes)` for training a neural network.


In [None]:
import torch, math, numpy as np
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation


device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)

## 1. From toy inputs to **Von Kármán** turbulence

Real AO systems must cope with **2D phase screens** produced by the atmosphere. A common model is **Von Kármán**, where the phase $\phi(x,y)$ has a power spectral density (PSD) controlled by:

- **`r0`** (Fried parameter): sets the turbulence strength; a larger `r0` means weaker turbulence and better seeing.
- **`L0`** (outer scale): controls energy at **large** spatial scales.
- **`l0`** (inner scale): damps the **high‑frequency** end.
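
As a rough illustration of how the three parameters shape the spectrum, the sketch below evaluates the modified Von Kármán phase PSD in its commonly quoted form, $\Phi(f) = 0.023\, r_0^{-5/3}\, (f^2 + 1/L_0^2)^{-11/6} \exp[-(f/f_m)^2]$ with $f_m = 5.92/(2\pi l_0)$ and $f$ in cycles per metre. The function name `von_karman_psd` is hypothetical and is not part of the workshop's `Functions` package, which may use a different normalization.

```python
import math

def von_karman_psd(f, r0, L0, l0):
    """Modified Von Karman phase PSD [rad^2 m^2] at spatial frequency f [cycles/m]."""
    f0 = 1.0 / L0                      # outer-scale cutoff: flattens the PSD for f << f0
    fm = 5.92 / (2.0 * math.pi * l0)   # inner-scale cutoff: damps the PSD for f >> fm
    return (0.023 * r0 ** (-5.0 / 3.0)
            * (f ** 2 + f0 ** 2) ** (-11.0 / 6.0)
            * math.exp(-(f / fm) ** 2))

# Doubling r0 (weaker turbulence) scales the PSD by 2**(-5/3) at every frequency.
ratio_r0 = (von_karman_psd(1.0, r0=0.30, L0=25.0, l0=0.01)
            / von_karman_psd(1.0, r0=0.15, L0=25.0, l0=0.01))
print(f"PSD ratio when r0 doubles: {ratio_r0:.3f}")

# A smaller outer scale removes energy at large spatial scales (low f).
low_f = 0.01
less_energy = (von_karman_psd(low_f, 0.15, L0=5.0, l0=0.01)
               < von_karman_psd(low_f, 0.15, L0=25.0, l0=0.01))
print("Smaller L0 lowers the low-frequency PSD:", less_energy)
```

The `r0` term is a pure scale factor, while `L0` and `l0` only change the shape of the spectrum at its low and high ends.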

We will use three generators:
1) a **random / i.i.d.** screen source (each frame is independent),  
2) a **dynamic** infinite generator (frozen‑flow) with **wind**,  
3) a **multi‑layer** model that **sums** several dynamic layers.

In [None]:
from Functions.pupils import CreateTelescopePupil
from Functions.vonkarman_model_newv3 import VonKarmanPhaseScreenGenerator  # adjust path as needed
from Functions.vonkarman_model_newv3 import InfiniteVonKarmanPhaseScreenGenerator
from Functions.vonkarman_model_newv3 import MultiLayerPhaseScreen

N = 128  # Grid size
D = 3.0  # Telescope diameter in meters
pupil = CreateTelescopePupil(N, "disc", device=device)  # use the detected device so the cell also runs without a GPU


## 2.1 Random / i.i.d. atmospheric screens

**What it shows:** every frame is a **new** independent Von Kármán realization.  
**Take‑home:** there is **no temporal correlation** → the pattern **jumps** between frames.  
**Use:** quick statistical demos, data augmentation. Not realistic as a time series.
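
To make "independent realization" concrete, here is a minimal NumPy sketch of the classic FFT method: complex white noise is filtered by the square root of the Von Kármán PSD and transformed back to the spatial domain. It is an illustration under assumptions, not the implementation inside `VonKarmanPhaseScreenGenerator`: the helper `random_screen` is hypothetical, the inner scale `l0` is omitted, and the normalization is only approximate.

```python
import numpy as np

def random_screen(N, D, r0, L0, rng):
    """One FFT-filtered Von Karman phase screen of shape (N, N), roughly in radians."""
    dx = D / N                                   # pixel size [m]
    fx = np.fft.fftfreq(N, d=dx)                 # spatial frequencies [cycles/m]
    FX, FY = np.meshgrid(fx, fx, indexing="ij")
    f2 = FX ** 2 + FY ** 2
    psd = 0.023 * r0 ** (-5 / 3) * (f2 + 1.0 / L0 ** 2) ** (-11 / 6)
    psd[0, 0] = 0.0                              # remove piston (zero frequency)
    df = 1.0 / (N * dx)                          # frequency step [cycles/m]
    noise = rng.standard_normal((N, N)) + 1j * rng.standard_normal((N, N))
    spectrum = noise * np.sqrt(psd) * df         # white noise shaped by sqrt(PSD)
    return np.real(np.fft.ifft2(spectrum)) * N * N

rng = np.random.default_rng(0)
# Every call draws fresh noise, so the frames share statistics but not content.
batch = np.stack([random_screen(64, 3.0, 0.15, 25.0, rng) for _ in range(4)])[:, None]
print(batch.shape)   # (B, 1, N, N), the layout used throughout this notebook
```

Because each frame starts from new noise, nothing ties frame `k` to frame `k+1`, which is why the animation below appears to jump.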

In [None]:
batch_size = 10

vk_rand = VonKarmanPhaseScreenGenerator(
    N=N, D_tel=D, r0=0.15, L0=25.0, l0=0.01,
    pupil_mask=pupil, device=device, batch_size=batch_size
)

phase = vk_rand.generate_total_phase().detach().cpu()

frames = []         # Save frames here
labels = []         # Optional: text for each frame

for i in range(batch_size):
    phase_ran = phase[i,:,:,:]        # (1, N, N): i-th screen of the batch
    phase_t = phase_ran.squeeze()     # -> (N, N)
    arr = phase_t.detach().cpu().numpy()

    # Normalization for better visualization
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    # Apply pupil mask if needed
    arr = arr * pupil.cpu().squeeze().numpy()

    frames.append(arr)
    labels.append(f"Frame {i+1}")

fig, ax = plt.subplots(figsize=(5,5), dpi=100)
ax.set_axis_off()
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

# Initial image
im = ax.imshow(frames[0], cmap='viridis', animated=True)
txt = ax.text(5, 10, labels[0], color='white', fontsize=10)

fps = 5  # 5 frames per second

def update(k):
    im.set_array(frames[k])
    txt.set_text(labels[k])
    return im, txt

movie = animation.FuncAnimation(
    fig,
    update,
    frames=len(frames),
    interval=1000/fps,   # ms between frames
    blit=True
)
plt.close(fig)
HTML(movie.to_jshtml())


## 2.2 Dynamic (frozen‑flow) single layer

**What it shows:** a **time‑evolving** layer that **moves** with a wind vector.  
**Take‑home:** consecutive frames are **correlated** → the pattern **drifts**.  
**Parameters:** `wind_speed` (m/s), `wind_dir_deg`, `fps`.
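
Under the frozen-flow hypothesis the screen is simply translated between frames, so the three parameters fix how far the pattern moves per frame. The arithmetic below uses the values from the next cell; treating `wind_dir_deg` as an angle measured from the +x axis is an assumption, since the convention used by `InfiniteVonKarmanPhaseScreenGenerator` is not stated here.

```python
import math

N, D = 128, 3.0          # grid size and telescope diameter [m], as in this notebook
wind_speed = 12.0        # [m/s]
wind_dir_deg = 30.0      # [deg], assumed to be measured from the +x axis
fps = 500.0              # simulated frames per second

pixel_scale = D / N                      # metres per pixel
shift_m = wind_speed / fps               # metres travelled between two frames
shift_px = shift_m / pixel_scale         # the same distance in pixels
dx_px = shift_px * math.cos(math.radians(wind_dir_deg))
dy_px = shift_px * math.sin(math.radians(wind_dir_deg))
crossing_frames = D / shift_m            # frames needed to cross the full pupil

print(f"shift per frame: {shift_px:.3f} px (dx={dx_px:.3f}, dy={dy_px:.3f})")
# shift per frame: 1.024 px (dx=0.887, dy=0.512)
print(f"frames to cross the pupil: {crossing_frames:.0f}")
# frames to cross the pupil: 125
```

A shift of about one pixel per frame is small compared with the 128-pixel pupil, so consecutive frames stay strongly correlated.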

In [None]:
dyn = InfiniteVonKarmanPhaseScreenGenerator(
    N=N, D_tel=D,
    r0=0.15, L0=25.0,
    pupil_mask=pupil,
    device=device,
    wind_speed=12,
    wind_dir_deg=30.0,
    fps=500
)

steps = 100          # Frames in the animation
frames = []         # Save frames here
labels = []         # Optional: text for each frame

for i in range(steps):
    phase = dyn.evolve()           # (B, 1, N, N)
    phase_t = phase.squeeze()     # -> (N, N) if B=1
    arr = phase_t.detach().cpu().numpy()

    # Normalization for better visualization
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    # Apply pupil mask if needed
    arr = arr * pupil.cpu().squeeze().numpy()

    frames.append(arr)
    labels.append(f"Frame {i+1}")

fig, ax = plt.subplots(figsize=(5,5), dpi=100)
ax.set_axis_off()
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

# Initial image
im = ax.imshow(frames[0], cmap='viridis', animated=True)
txt = ax.text(5, 10, labels[0], color='white', fontsize=10)

fps = 10  # 10 frames per second

def update(k):
    im.set_array(frames[k])
    txt.set_text(labels[k])
    return im, txt

movie = animation.FuncAnimation(
    fig,
    update,
    frames=len(frames),
    interval=1000/fps,   # ms between frames
    blit=True
)
plt.close(fig)
HTML(movie.to_jshtml())


## 2.3 Multi‑layer atmosphere

**What it shows:** several layers, each with its own **strength**, **wind speed/direction**, and optional **weight** (`Cn2`).  
**Take‑home:** the phase has features at multiple scales drifting in different directions → **more realistic** input for the network.
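
One common way to reason about layer weights is that turbulence strength adds as $r_0^{-5/3}$, so a layer carrying a fraction $w_i$ of the total has $r_{0,i} = r_0\, w_i^{-3/5}$. The sketch below checks that identity with the weights used in the next cell. This convention is an assumption: `MultiLayerPhaseScreen` may combine `r0` and `Cn2` differently, and the helpers `layer_r0` and `combined_r0` are hypothetical.

```python
def layer_r0(r0_total, weights):
    """Per-layer Fried parameters for fractional weights (normalized to sum to 1)."""
    total = sum(weights)
    return [r0_total * (w / total) ** (-3.0 / 5.0) for w in weights]

def combined_r0(r0_layers):
    """Effective Fried parameter of stacked layers: strengths add as r0**(-5/3)."""
    return sum(r ** (-5.0 / 3.0) for r in r0_layers) ** (-3.0 / 5.0)

weights = [0.7, 0.2, 0.1]                 # same Cn2 fractions as the layers below
r0_layers = layer_r0(0.2, weights)        # weak layers get a large individual r0
r0_back = combined_r0(r0_layers)          # recombining recovers the total r0

print([round(r, 3) for r in r0_layers])
print(round(r0_back, 6))
```

The layer with the largest weight ends up with the smallest individual `r0`, meaning it contributes most of the phase variance.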

In [None]:
layers = [
    dict(r0=0.2, L0=25.0, l0=0.01, wind_speed=15, wind_dir_deg=135, Cn2=0.7),
    dict(r0=0.2, L0=25.0, l0=0.01, wind_speed=18, wind_dir_deg=45.0, Cn2=0.2),
    dict(r0=0.2, L0=25.0, l0=0.01, wind_speed=20, wind_dir_deg=15.0, Cn2=0.1)
]

ml = MultiLayerPhaseScreen(
    layers=layers,
    N=N,
    D_tel=D,
    device=device,
    pupil_mask=pupil,
)

steps = 100          # Frames in the animation
frames = []         # Save frames here
labels = []         # Optional: text for each frame

for i in range(steps):
    phase = ml.evolve()           # (B, 1, N, N)
    phase_t = phase.squeeze()     # -> (N, N) if B=1
    arr = phase_t.detach().cpu().numpy()

    # Normalization for better visualization
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    # Apply pupil mask if needed
    arr = arr * pupil.cpu().squeeze().numpy()

    frames.append(arr)
    labels.append(f"Frame {i+1}")

fig, ax = plt.subplots(figsize=(5,5), dpi=100)
ax.set_axis_off()
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

# Initial image
im = ax.imshow(frames[0], cmap='viridis', animated=True)
txt = ax.text(5, 10, labels[0], color='white', fontsize=10)

fps = 10  # 10 frames per second

def update(k):
    im.set_array(frames[k])
    txt.set_text(labels[k])
    return im, txt

movie = animation.FuncAnimation(
    fig,
    update,
    frames=len(frames),
    interval=1000/fps,   # ms between frames
    blit=True
)
plt.close(fig)
HTML(movie.to_jshtml())


## 3. Zernike decomposition and reconstruction

**Why Zernike?** We want to convert a 2D phase map (e.g., 128×128) into a **compact target vector** for learning. Zernikes are standard in optics/AO and allow reconstruction and **residual visualization**.

**Pipeline:**  
1) **Build** the Zernike basis on the same grid/pupil.  
2) **Project**: $c_k = \langle \phi, Z_k \rangle$.  
3) **Reconstruct**: $\hat{\phi} = \sum_k c_k Z_k$.  
4) **Compare** original vs reconstructed vs residual.
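
The four steps can be sketched in plain NumPy with only three modes (tip, tilt, defocus). This is a toy under assumptions, not how `MVM.phase2zern` works internally: the projection is done by least squares because Zernikes sampled on a pixel grid are only approximately orthogonal, and for an exactly orthonormal basis this reduces to the inner product above.

```python
import numpy as np

N = 64
y, x = np.mgrid[-1:1:N * 1j, -1:1:N * 1j]     # unit-radius pupil coordinates
rho2 = x ** 2 + y ** 2
inside = rho2 <= 1.0                          # boolean disc pupil

# 1) Build: Noll-normalized tip, tilt and defocus, sampled inside the pupil only
Z = np.stack([2 * x, 2 * y, np.sqrt(3) * (2 * rho2 - 1)], axis=-1)[inside]  # (n_pix, 3)

# Synthetic phase: known coefficients plus a small unmodelled component
rng = np.random.default_rng(0)
c_true = np.array([0.5, -0.3, 0.8])
phi = Z @ c_true + 0.05 * rng.standard_normal(Z.shape[0])

# 2) Project: least-squares fit of the coefficients
c_hat, *_ = np.linalg.lstsq(Z, phi, rcond=None)

# 3) Reconstruct and 4) compare through the residual RMS over the pupil
phi_hat = Z @ c_hat
residual_rms = np.sqrt(np.mean((phi - phi_hat) ** 2))

print("estimated coefficients:", np.round(c_hat, 3))
print("residual RMS:", round(float(residual_rms), 3))
```

The residual contains whatever the chosen modes cannot represent, so it shrinks as more modes are added to the basis.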


In [None]:
from Functions.mvm import MVM

def genOTF_PWFS4(N_fourier_points=128, N_points_aperture=64, separation=2.0, device='cpu'):
    beta = separation * math.pi / (N_fourier_points / N_points_aperture)
    x = torch.arange(N_fourier_points, device=device, dtype=torch.float32)
    center = N_fourier_points // 2
    fx = (x - center) * (2.0 * center / N_fourier_points)
    FX, FY = torch.meshgrid(fx, fx, indexing='ij')
    H = lambda t: torch.heaviside(t, torch.tensor(0.5, device=device))
    pym = 0j * FX
    for sx, sy in ((1,1),(1,-1),(-1,-1),(-1,1)):
        mask  = H(sx*FX) * H(sy*FY)
        phase = -beta * (sx*FX + sy*FY)
        pym   = pym + mask * torch.exp(1j * phase)
    return torch.fft.fftshift(pym).unsqueeze(0).unsqueeze(0)


N = 128; aper = 64
pwfs = genOTF_PWFS4(N_fourier_points=N, N_points_aperture=aper, separation=1.0, device=device)

Zj = [2,200] # max 400
jmodes = np.arange(Zj[0], Zj[1]-1)  # mode indices 2..198 (np.arange excludes the stop value)
mvm = MVM(aperture=pupil,FourierFilter=pwfs,jModes=jmodes,device=device)

In [None]:
vk_rand = VonKarmanPhaseScreenGenerator(
    N=N, D_tel=D, r0=0.15, L0=25.0, l0=0.01,
    pupil_mask=pupil, device=device, batch_size=1
)

phase = vk_rand.generate_total_phase()
z_coeffs = mvm.phase2zern(phase)[0]
pha_zern = mvm.zern2phase(z_coeffs)

plt.figure(figsize=(8,4))
plt.subplot(1,2,1)
plt.imshow(phase[0,0,:,:].cpu(), cmap='viridis')
plt.colorbar(label="phase [rad or a.u.]")
plt.title("One Von Kármán phase screen")
plt.axis("off")

plt.subplot(1,2,2)
plt.imshow(pha_zern[0,0,:,:].cpu(), cmap='viridis')
plt.colorbar(label="phase [rad or a.u.]")
plt.title("Reconstructed Zernike phase")
plt.axis("off")

plt.show()

In [None]:
steps = 50
frames = []          # "true" atmospheric phase
frames_zernike = []  # reconstructed phase
labels = []

for i in range(steps):
    phase = dyn.evolve()                  # (1,1,N,N)
    zcoeffs = mvm.phase2zern(phase.to(device))[0]          # (nmodes,)
    phi_zernike = mvm.zern2phase(zcoeffs).unsqueeze(0)     # (1,1,N,N)

    arr = phase.squeeze().detach().cpu().numpy()
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    arr = arr * pupil.cpu().squeeze().numpy()

    arr_z = phi_zernike.squeeze().detach().cpu().numpy()
    arr_z = (arr_z - arr_z.min()) / (arr_z.max() - arr_z.min() + 1e-8)
    arr_z = arr_z * pupil.cpu().squeeze().numpy()

    frames.append(arr)
    frames_zernike.append(arr_z)
    labels.append(f"Frame {i+1}")

fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3, figsize=(12,4), dpi=100,
    gridspec_kw={'width_ratios': [1,1,1]}
)

for ax in (ax1, ax2, ax3):
    ax.set_axis_off()

fig.subplots_adjust(left=0.02, right=0.98, top=0.95, bottom=0.05, wspace=0.05)

im1 = ax1.imshow(frames[0], cmap='viridis', animated=True)
ax1.set_title("Phase (atm)")

im2 = ax2.imshow(frames_zernike[0], cmap='viridis', animated=True)
ax2.set_title("Phase (Zernike)")

im3 = ax3.imshow(frames[0] - frames_zernike[0], cmap='viridis', animated=True)
ax3.set_title("Difference")

txt = ax1.text(5, 10, labels[0], color='white', fontsize=10)

fps = 10

def update(k):
    im1.set_array(frames[k])
    im2.set_array(frames_zernike[k])
    im3.set_array(frames[k] - frames_zernike[k])
    txt.set_text(labels[k])
    return im1, im2, im3, txt

movie = animation.FuncAnimation(
    fig,
    update,
    frames=len(frames),
    interval=1000/fps,
    blit=False
)

plt.close(fig)
HTML(movie.to_jshtml())


## 4. Dataset and DataLoader (for training)

We now wrap the **random** generator and the **Zernike projection** into a PyTorch `Dataset`.

- **Input (`x`)**: by default the **phase** (shape `(1, H, W)`).  
  Later, you can switch to a **PWFS image** or **PSF**, keeping the same API.  
- **Target (`y`)**: Zernike coefficient vector (shape `(n_modes,)`).

With a `DataLoader` we can shuffle and choose a batch size to feed models.


In [None]:
import math
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

from Functions.pupils import CreateTelescopePupil
from Functions.vonkarman_model_newv3 import VonKarmanPhaseScreenGenerator  # adjust path as needed
from Functions.mvm import MVM


class VKZernikeDataset(Dataset):
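    def __init__(self, vk_rand, mvm, length=1000, r0_range=None):
        """Yield (phase, zernike_coeffs) pairs generated on the fly.

        vk_rand:  random phase-screen generator (batch_size=1).
        mvm:      object providing phase2zern().
        length:   nominal number of samples per epoch.
        r0_range: optional (min, max) to draw a new r0 for every sample.
        """
        self.vk = vk_rand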
        self.mvm = mvm
        self.length = length
        self.r0_range = r0_range

    def __len__(self): 
        return self.length

    def __getitem__(self, idx):
        # update r0 if needed
        if self.r0_range is not None:
            r0 = float(torch.empty(1).uniform_(*self.r0_range))
            try:
                self.vk.update_parameters(r0=r0)
            except Exception:
                # Best effort: if the generator cannot update r0, keep its current value
                pass

        phase = self.vk.generate_total_phase()
        # zernike coefficients (n_modes,) with your MVM
        z = self.mvm.phase2zern(phase)[0].float()
        # input = phase image (1,N,N)
        x = phase[0].float()
        return x, z


In [None]:
def genOTF_PWFS4(N_fourier_points=128, N_points_aperture=64, separation=2.0, device='cpu'):
    beta = separation * math.pi / (N_fourier_points / N_points_aperture)
    x = torch.arange(N_fourier_points, device=device, dtype=torch.float32)
    center = N_fourier_points // 2
    fx = (x - center) * (2.0 * center / N_fourier_points)
    FX, FY = torch.meshgrid(fx, fx, indexing='ij')
    H = lambda t: torch.heaviside(t, torch.tensor(0.5, device=device))
    pym = 0j * FX
    for sx, sy in ((1,1),(1,-1),(-1,-1),(-1,1)):
        mask  = H(sx*FX) * H(sy*FY)
        phase = -beta * (sx*FX + sy*FY)
        pym   = pym + mask * torch.exp(1j * phase)
    return torch.fft.fftshift(pym).unsqueeze(0).unsqueeze(0)


N = 128; aper = 64; D = 3.0
device = "cuda" if torch.cuda.is_available() else "cpu"

Zj = [2,200] # max 400
jmodes = np.arange(Zj[0], Zj[1]-1)  # mode indices 2..198 (np.arange excludes the stop value)
pupil = CreateTelescopePupil(N, "disc", device=device)
pwfs = genOTF_PWFS4(N_fourier_points=N, N_points_aperture=aper, separation=1.0, device=device)
mvm = MVM(aperture=pupil,FourierFilter=pwfs,jModes=jmodes,device=device)

vk_rand = VonKarmanPhaseScreenGenerator(
    N=N, D_tel=D, r0=0.15, L0=25.0, l0=0.01,
    pupil_mask=pupil, device=device, batch_size=1
)

ds = VKZernikeDataset(vk_rand, mvm, length=500, r0_range=(0.10, 0.30))
dl = DataLoader(ds, batch_size=10, shuffle=True)

xb, yb = next(iter(dl))
print("x batch:", xb.shape)
print("y batch:", yb.shape)



In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 5))

data = 1

plt.subplot(1, 2, 1)
phase_img = xb[data, 0].detach().cpu().numpy()      
plt.imshow(phase_img, cmap='viridis')
plt.colorbar(label="phase [rad or a.u.]")
plt.title("Input phase from Von Kármán model")
plt.axis("off")

plt.subplot(1, 2, 2)
z = yb[data].detach().cpu().numpy().reshape(-1)   # full coefficient vector of one sample, flattened to (n_modes,)
x = np.arange(len(z))                             

plt.bar(x, z, width=0.8)
plt.axhline(0, color='k', linewidth=1)            
plt.ylim(-1.0, 1.0)                               
plt.title("Target Zernike coefficients")
plt.ylabel("Coefficient [a.u.]")

step = max(1, len(z)//10)
plt.xticks(x[::step], x[::step])

plt.tight_layout()
plt.show()