# **Project Goals**

The goals of the project are as follows:

1: Write your own code (not an existing repo on the internet) that performs single-shot ptychography using structured illumination, as demonstrated in the papers by Levitan et al.

2: Answer the following questions:

a. What happens if the illumination (= probe) field has zero crossings? Is this a problem for the method? Do zero crossings result in unmeasured points?

b. Suppose the method is simulated using 2x2 upsampling (that is, 2x2 probe pixels determine the average phase in a 1x1 super-pixel inside the object). Is a relative phase shift between the 2x2 pixels inside the probe required? Or can the modulation be amplitude-only (that is, no phase shift, only an intensity variation)?

c. What is the minimum amount of upsampling needed?

d. Is there a class of illumination functions for which the method fails?

e. Try adding regularization to the method. Please include L1, L2, and total variation regularization.

# **Project Ptychography**

Illumination: A grid of partially overlapping beams, each approaching the sample from a different angle.

Reconstruction: Randomized probe imaging algorithm. 

General strategy: Divide the intensity pattern into a collection of individual, smaller diffraction patterns. The smaller patterns are centered on an individual beam and tagged with a corresponding translation at the same plane.

X-ray problem: It is difficult to generate a grid of identical beams with sufficient angular separation and uniform intensities. The grating equation is $d \sin\theta_m = m\lambda$, so for a given diffraction angle a small wavelength requires a small grating period.
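To get a feel for the scale involved, the grating equation can be evaluated for two wavelengths. The numbers below (a 633 nm visible wavelength, a 1 nm X-ray wavelength, and a 1 degree first-order angle) are hypothetical values chosen only for illustration, and `grating_period` is a helper name introduced here.

```python
import math

def grating_period(wavelength_m: float, angle_deg: float, order: int = 1) -> float:
    """Grating period d solved from d * sin(theta_m) = m * lambda."""
    return order * wavelength_m / math.sin(math.radians(angle_deg))

# Hypothetical values, chosen only for illustration.
angle_deg = 1.0
d_visible = grating_period(633e-9, angle_deg)  # visible light
d_xray = grating_period(1e-9, angle_deg)       # soft X-ray

print(f"visible: d = {d_visible * 1e6:.1f} um")
print(f"x-ray:   d = {d_xray * 1e9:.1f} nm")
```

Under these assumptions the visible-light grating has a period of roughly 36 um, while the X-ray grating needs a period of roughly 57 nm for the same angular separation.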

Feature: It recovers both the probe and the object, so the probe does not need to be known in advance. However, dividing the detector into sub-patterns limits the resolution by limiting the numerical aperture (NA).

Method used in the single-shot article:

A purely iterative algorithm for single-shot ptychography which uses a pre-calibrated probe and operates on full diffraction patterns without partitioning them into a ptychography dataset.

This method overcomes a limitation of the general strategy: when the object contains high-frequency components, the scattering from neighboring beams overlaps and interferes, which causes reconstructions to perform poorly.

The forward model for this algorithm:

$$I = \sum_{n=1}^{N} \left| \mathcal{F} \left\{ P_n \cdot \mathcal{F}^{-1} \left\{ \mathcal{U} \cdot \exp(i T) \right\} \right\} \right|^2$$

$P_n$ is the discrete representation of the $n$th mode of the pre-calibrated probe, $\mathcal{U}$ is a zero-padding operator, and $T$ is a low-resolution representation of the object's transmission function.

The zero-padding operator $\mathcal{U}$ acts as a band-limiting constraint that stabilizes the inverse problem.

$T$ is constrained to be purely real, which imposes an additional phase-only constraint on the object. Allowing $T$ to be complex-valued removes the phase-only constraint. The final object function is defined as:

$$ O = \exp(iT)$$
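To see why a real $T$ gives a phase-only object, it may help to split $T$ into real and imaginary parts. The symbols $T_r$ and $T_i$ are introduced here only for illustration:

$$ O = \exp\big(i(T_r + iT_i)\big) = \exp(-T_i)\,\exp(iT_r), \qquad |O| = \exp(-T_i) $$

With $T_i = 0$ the modulus is 1 everywhere, so only the phase varies. A nonzero $T_i$ changes the amplitude as well, with positive values attenuating it.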

The first step of the model is to upsample the low-resolution object $O$ by padding it with zeros in Fourier space.

In [10]:
def forward_model(obj):
    return obj
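The stub above returns its input unchanged. One way it could be filled in is sketched below. This is only a sketch under stated assumptions: the zero padding is applied to the Fourier transform of $\exp(iT)$, as the sentence on upsampling describes; array sizes are even; the probe modes are stacked along the first axis; and the names `fourier_upsample` and `forward_model_sketch` are hypothetical.

```python
import jax.numpy as jnp

def fourier_upsample(field: jnp.ndarray, out_shape: tuple) -> jnp.ndarray:
    """Upsample a 2D field by zero-padding its centered Fourier spectrum."""
    ny, nx = field.shape
    my, mx = out_shape
    spectrum = jnp.fft.fftshift(jnp.fft.fft2(field))
    py, px = (my - ny) // 2, (mx - nx) // 2
    padded = jnp.pad(spectrum, ((py, my - ny - py), (px, mx - nx - px)))
    # Rescale so that a constant field keeps its value after upsampling
    scale = (my * mx) / (ny * nx)
    return scale * jnp.fft.ifft2(jnp.fft.ifftshift(padded))

def forward_model_sketch(T: jnp.ndarray, probes: jnp.ndarray) -> jnp.ndarray:
    """Simulated intensity: sum over probe modes of |F{P_n * upsampled exp(iT)}|^2.

    T has shape (ny, nx); probes has shape (N, my, mx) (assumed layout).
    """
    obj = jnp.exp(1j * T)
    obj_up = fourier_upsample(obj, probes.shape[-2:])
    exit_waves = probes * obj_up           # broadcast over the mode axis
    far_field = jnp.fft.fft2(exit_waves)   # FFT over the last two axes
    return jnp.sum(jnp.abs(far_field) ** 2, axis=0)

# Hypothetical sizes: a 4x4 object upsampled 2x2 onto an 8x8 probe grid, two modes.
T = jnp.zeros((4, 4))
probes = jnp.exp(1j * 0.1 * jnp.arange(128.0).reshape(2, 8, 8))
intensity = forward_model_sketch(T, probes)
print(intensity.shape)
```

With $T = 0$ the upsampled object is 1 everywhere, so the result reduces to the summed diffraction pattern of the probe modes alone, which makes a convenient sanity check.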

## **Reconstruction**

To perform a reconstruction, we start with an initial guess of the object function and use the forward model to simulate the corresponding diffraction pattern.

Next, we calculate the normalized mean squared error between the measured diffraction amplitudes and the simulated diffraction amplitudes, which include a known detector background:

$$ L = \frac{1}{\sum_{ij}I_{ij}} \sum_{ij}\left( \sqrt{|\tilde{E}_{ij}|^2 + B_{ij}} - \sqrt{I_{ij}}\right)^2$$

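Here $\tilde{E}_{ij}$ is the simulated diffraction amplitude, $B_{ij}$ is the detector background, and $I_{ij}$ is the measured intensity at detector pixel $(i, j)$. The equation translates to Python as follows: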

In [None]:
import jax.numpy as jnp
from chex import assert_equal_shape

def loss_function(simulated: jnp.ndarray, background: jnp.ndarray, measured: jnp.ndarray) -> jnp.ndarray:
    '''
    Calculates the normalized mean squared error (L) between the simulated and measured diffraction amplitudes.

    Args:
        simulated (jnp.ndarray): The simulated diffraction pattern (Intensity, |E|^2) calculated from a forward model
        background (jnp.ndarray): The detector background (Intensity)
        measured (jnp.ndarray): The measured diffraction pattern (Intensity)

    Returns:
        jnp.ndarray: The normalized mean squared error (L) as a scalar array
    '''
    # Assert that the three arrays have the same shape
    assert_equal_shape([simulated, background, measured])

    # Calculate the factor to normalize the error
    factor = 1.0 / (jnp.sum(measured) + 1e-10)

    # Calculate the simulated diffraction amplitudes
    simulated_amp = jnp.sqrt(simulated + background)

    # Calculate the measured diffraction amplitudes
    measured_amp = jnp.sqrt(measured)

    # Calculate the normalized mean squared error
    loss = factor * jnp.sum((simulated_amp - measured_amp) ** 2)

    return loss

Example usage:

In [None]:
simulated = jnp.array([[1.0, 4.0], [9.0, 16.0]])  # Simulated intensities |E|^2 (real, non-negative)
background = jnp.array([[0.1, 0.1], [0.1, 0.1]])   # Background
measured = jnp.array([[1.5, 4.5], [9.5, 16.5]])   # Measured intensities

print("loss =", loss_function(simulated, background, measured))

loss = 0.0014692659


We then use automatic differentiation to calculate the Wirtinger derivative of the loss function with respect to the object guess $T$:

$$ \frac{\partial L}{\partial T} = \sum_{ij} \frac{\partial L}{\partial |\tilde{E}_{ij}|^2} \frac{\partial |\tilde{E}_{ij}|^2}{\partial T}$$

In [11]:
from jax import grad

def derivative_loss_function_wrt_obj(obj: jnp.ndarray, background: jnp.ndarray, measured: jnp.ndarray) -> jnp.ndarray:
    '''
    Calculates the derivative of the loss function with respect to the object guess obj.

    Args:
        obj (jnp.ndarray): The object guess (complex or real)
        background (jnp.ndarray): The detector background (Intensity)
        measured (jnp.ndarray): The measured diffraction pattern (Intensity)

    Returns:
        jnp.ndarray: The derivative of the loss function with respect to obj, with the same shape as obj
    '''
    # Compose the forward model and the loss into one scalar-valued function of the object
    def loss_wrt_obj(o: jnp.ndarray) -> jnp.ndarray:
        simulated = forward_model(o)
        return loss_function(simulated, background, measured)

    # jax.grad applies the chain rule through both steps automatically.
    # For a complex obj, gradient descent steps along the negative complex conjugate of the returned array.
    return grad(loss_wrt_obj)(obj)
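The derivative is what drives the iterative update of $T$. The sketch below shows one way such a loop could look on a toy problem. Several things here are assumptions and not taken from the source: plain gradient descent with a fixed step size, a single probe mode without upsampling, a random complex probe, and all names (`toy_forward`, `toy_loss`, `step_size`).

```python
import jax
import jax.numpy as jnp

def toy_forward(T, probe):
    """Single-mode far-field intensity |F{P * exp(iT)}|^2 (no upsampling)."""
    return jnp.abs(jnp.fft.fft2(probe * jnp.exp(1j * T))) ** 2

def toy_loss(T, probe, background, measured):
    """Normalized amplitude error, same form as the loss L above."""
    simulated = toy_forward(T, probe)
    diff = jnp.sqrt(simulated + background) - jnp.sqrt(measured)
    return jnp.sum(diff ** 2) / jnp.sum(measured)

# Hypothetical test problem: random complex probe and a smooth phase object.
n = 8
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
probe = jax.random.normal(key_re, (n, n)) + 1j * jax.random.normal(key_im, (n, n))
yy, xx = jnp.meshgrid(jnp.arange(n), jnp.arange(n), indexing="ij")
T_true = 0.5 * jnp.sin(2 * jnp.pi * xx / n) * jnp.cos(2 * jnp.pi * yy / n)
background = jnp.full((n, n), 1e-3)
measured = toy_forward(T_true, probe) + background

# Differentiate the scalar loss with respect to its first argument (T)
loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))

T = jnp.zeros((n, n))   # initial guess: flat phase
step_size = 1.0         # assumed value, not tuned
losses = []
for _ in range(200):
    loss, g = loss_and_grad(T, probe, background, measured)
    losses.append(float(loss))
    T = T - step_size * g   # T is real, so the gradient is real as well

print("initial loss:", losses[0], "final loss:", losses[-1])
```

Because $T$ is kept real here, the update uses the gradient directly. A practical reconstruction would likely use a more capable optimizer and the full multi-mode forward model.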
