In [26]:
import torch
from torch import Tensor

In [27]:
def lens_pupil_torch(x_mesh: Tensor, y_mesh: Tensor, D: float):
    """Build a binary circular aperture mask centred on the mesh origin.

    Args:
        x_mesh: x coordinates of the sampling grid, e.g. (H,W)
        y_mesh: y coordinates of the sampling grid, same shape as x_mesh
        D: float, the diameter of the lens aperture
    Returns:
        pupil: float32 tensor on x_mesh.device with a leading and a trailing
            singleton axis added ([1,H,W,1] for (H,W) meshes); 1.0 where
            x^2 + y^2 <= (D/2)^2 and 0.0 elsewhere.
    """
    radius_sq = (D / 2) ** 2
    squared_dist = torch.pow(x_mesh, 2) + torch.pow(y_mesh, 2)

    # Points lying exactly on the rim count as inside the aperture.
    inside = squared_dist <= radius_sq
    mask = inside.to(torch.float32).to(x_mesh.device)

    # Add a leading and a trailing singleton axis: (H,W) -> [1,H,W,1].
    return mask[None, ..., None]

In [28]:
def defocus_select_torch(scene_distances: Tensor, defocus_amount: Tensor,
                   total_depths: int, n_sample_depths: int, step: int):
    """Select the block of defocus amounts that belongs to the current step.

    The depth axis is divided into ``total_depths // n_sample_depths``
    consecutive blocks of ``n_sample_depths`` entries each; ``step`` cycles
    through these blocks (block index = step mod number of blocks).

    Args:
        scene_distances: tensor used only to pick the output device
        defocus_amount: tensor sliced along its first dimension
            (presumably one entry per scene depth, i.e. [total_depths] --
            confirm against the caller)
        total_depths: int, total number of depths in the scene
        n_sample_depths: int, number of consecutive depths returned per call
        step: int or single-element integer tensor, the current step counter
    Returns:
        selected_defocus_amount: defocus_amount[start:start + n_sample_depths]
            as float32 on scene_distances.device
    """
    device = scene_distances.device
    dtype = torch.float32

    # At least one block, so the modulo below never divides by zero when
    # total_depths < n_sample_depths.
    num_blocks = max(1, total_depths // n_sample_depths)

    # torch.mod() rejects a plain Python int as its input, so do the modulo in
    # Python. int() also accepts a single-element tensor, and Python's % has
    # the same sign convention as torch.mod (result follows the divisor).
    block_index = int(step) % num_blocks

    start = block_index * n_sample_depths
    stop = start + n_sample_depths
    return defocus_amount[start:stop].to(dtype).to(device)

In [29]:
def zernike_phase_torch(phase_var: Tensor, zernike_volume: Tensor,
                  wave_lengths: Tensor, bound_val: int):
    """Combine Zernike basis maps into one phase map and copy it per wavelength.

    Args:
        phase_var: [T], the Zernike coefficients, scaled by bound_val before use
        zernike_volume: [T,H,W], one basis map per coefficient (cast to float32)
        wave_lengths: [C], the wavelengths; only the size of the last
            dimension is used, the values themselves are not read
        bound_val: int, scale factor applied to every coefficient
    Returns:
        phase_map: [1,H,W,1], sum over T of bound_val * phase_var[t] * zernike_volume[t]
        phi: [1,H,W,C], phase_map repeated along the last axis
    """
    target_device = phase_var.device

    # Scale the coefficients and give them shape [T,1,1] so they broadcast
    # over the spatial axes of the basis volume.
    coeffs = phase_var[:, None, None] * bound_val
    basis = zernike_volume.to(torch.float32).to(target_device)

    weighted_sum = (coeffs * basis).sum(dim=0)  # (H,W)
    phase_map = weighted_sum.unsqueeze(0).unsqueeze(3)  # [1,H,W,1]

    n_channels = wave_lengths.shape[-1]
    phi = phase_map.repeat(1, 1, 1, n_channels)  # [1,H,W,C]
    return phase_map, phi

In [30]:
def defocus_pupil_torch(defocus_amount: Tensor, x_mesh: Tensor,
                   y_mesh: Tensor, wave_lengths: Tensor):
    """ 
    Args:
        defocus_amount: [n_sample_depths], the defocus amounts for selected depths
        x_mesh:(H,W)
        y_mesh:(H,W)
        wave_lengths: the wavelengths
    Returns:
    

SyntaxError: incomplete input (2146867805.py, line 3)