# Detailed walk through for PtyRAD

- Created with PtyRAD 0.1.0b10
- Requires PtyRAD >= 0.1.0b10
- Latest demo params files / scripts: https://github.com/chiahao3/ptyrad/tree/main/demo
- Documentation: https://ptyrad.readthedocs.io/en/latest/
- PtyRAD paper: https://doi.org/10.1093/mam/ozaf070
- PtyRAD arXiv: https://arxiv.org/abs/2505.07814
- Zenodo record: https://doi.org/10.5281/zenodo.15273176
- Box folder: https://cornell.box.com/s/n5balzf88jixescp9l15ojx7di4xn1uo
- YouTube channel: https://www.youtube.com/@ptyrad_official

**Before running this notebook, you must first follow the instructions in `README.md` to:**
1. Create the Python environment with all the Python packages it depends on, such as PyTorch
2. Activate that Python environment
3. Install the `ptyrad` package into your activated Python environment (this only needs to be done once)
4. Download the demo data into `demo/data/` as listed in `demo/data/data_url.txt`

> Note: This notebook showcases only the "reconstruction" mode. Most of the wrapper classes / functions are exposed so that you can see how the different components work together.

Author: Chia-Hao Lee, cl2696@cornell.edu

# 01. Imports

In [None]:
import os
from random import shuffle

import numpy as np
import torch

# Path to the demo/ folder, so that data/ and params/ can be accessed with the relative paths used below.
# Leave "../" as-is if you launched this notebook from the `ptyrad/demo/scripts/` folder;
# otherwise change it to the ABSOLUTE PATH of your demo/ folder.
work_dir = "../"

# Record the directory the kernel started in (only on the first execution of this cell) and
# resolve `work_dir` against it. A bare `os.chdir("../")` would climb one more level every time
# this cell is re-run; anchoring to the start dir makes re-running the cell idempotent.
# `os.path.join` ignores the start dir when `work_dir` is already an absolute path.
if "_notebook_start_dir" not in globals():
    _notebook_start_dir = os.getcwd()
os.chdir(os.path.join(_notebook_start_dir, work_dir))
print("Current working dir: ", os.getcwd())
# The printed working dir should be ".../ptyrad/demo" to locate the demo params files easily
# Note that the output/ directory will be automatically generated under your working directory

In [None]:
from ptyrad.load import load_params
from ptyrad.initialization import Initializer
from ptyrad.models import PtychoAD
from ptyrad.losses import CombinedLoss
from ptyrad.constraints import CombinedConstraint
from ptyrad.reconstruction import (
    recon_step,
    create_optimizer,
    make_batches,
    select_scan_indices,
    parse_torch_compile_configs,
    toggle_grad_requires,
)
from ptyrad.save import save_results, copy_params_to_dir, make_output_folder
from ptyrad.utils import (
    CustomLogger,
    get_blob_size,
    parse_sec_to_time_str,
    print_system_info,
    set_gpu_device,
    time_sync,
    vprint,
)
from ptyrad.visualization import (
    plot_forward_pass,
    plot_pos_grouping,
    plot_scan_positions,
    plot_summary,
)

In [None]:
# Logger is created before anything else, presumably so that the output of the calls below is captured.
# Log file name is 'ptyrad_log.txt'; with log_dir='auto' the folder is chosen by PtyRAD
# (NOTE(review): the setup cell later calls `logger.flush_to_file(log_dir=output_path)` — confirm semantics).
logger = CustomLogger(
    log_file='ptyrad_log.txt',
    log_dir='auto',
    prefix_time='datetime',
    show_timestamp=True,
)

# Print environment information, then select the GPU with id 0 as the compute `device`
print_system_info()
device = set_gpu_device(gpuid=0)

# 02. Initialize optimization

In [None]:
params_path = "params/tBL_WSe2_reconstruct.yml"

# Since PtyRAD 0.1.0b8, validation auto-fills defaults and checks that the parameters are consistent.
# If validation misbehaves (e.g., false positives or unexpected errors), you can temporarily pass
# `validate=False`, but then your params file must be complete on its own.
# If this happens, please report the bug so the validation logic can be improved.
params = load_params(params_path, validate=True)

# Pull out each parameter section by name (same lookup order as listed below).
(
    init_params,
    hypertune_params,  # Parsed but not needed in this demo notebook
    model_params,
    loss_params,
    constraint_params,
    recon_params,
) = (
    params.get(section_name)
    for section_name in (
        'init_params',
        'hypertune_params',
        'model_params',
        'loss_params',
        'constraint_params',
        'recon_params',
    )
)

In [None]:
# Run the initialization configured by `init_params`. The returned object exposes `init_variables`,
# a dict read throughout this notebook (e.g. "crop_pos", "probe_pos_shifts", "N_scans", "dx").
init = Initializer(init_params).init_all()

In [None]:
# Initial scan positions: cropped positions plus the initial probe position shifts, plotted for a visual check
init_vars = init.init_variables
pos = init_vars["crop_pos"] + init_vars["probe_pos_shifts"]
plot_scan_positions(pos, figsize=(8, 8))

In [None]:
# Build the PtyRAD model from the initialized variables and `model_params`, placed on `device`
model = PtychoAD(init.init_variables, model_params, device=device)
# Create the optimizer from the model's optimizer settings and its optimizable parameters
optimizer = create_optimizer(model.optimizer_params, model.optimizable_params)

## Check the forward pass

In [None]:
# Display exponent passed to the plot (presumably applied to the diffraction patterns to compress their range)
dp_power = 0.5
# Two random scan indices in [0, N_scans); repeats are possible
indices = np.random.randint(low=0, high=init.init_variables["N_scans"], size=2)
plot_forward_pass(model, indices, dp_power)

# Set up the loss and constraint functions

In [None]:
# Loss built from the terms configured in `loss_params`; passed to every reconstruction step
loss_fn = CombinedLoss(loss_params, device=device)

# Constraints configured in `constraint_params`; also passed to every reconstruction step
constraint_fn = CombinedConstraint(constraint_params, device=device)

# 03. Main optimization loop

In [None]:
# Unpack the reconstruction settings (defaults are auto-filled by `load_params(..., validate=True)`)
NITER             = recon_params.get('NITER')
INDICES_MODE      = recon_params.get('INDICES_MODE')
batch_config      = recon_params.get('BATCH_SIZE', {})
grad_accumulation = batch_config.get("grad_accumulation", 1)
sub_batch_size    = batch_config.get('size')            # assumes recon_step splits each batch into `grad_accumulation` chunks of this size — confirm
batch_size        = sub_batch_size * grad_accumulation  # positions per batch handed to make_batches, i.e. per parameter update
GROUP_MODE        = recon_params.get('GROUP_MODE')
SAVE_ITERS        = recon_params.get('SAVE_ITERS')
output_dir        = recon_params.get('output_dir')
recon_dir_affixes = recon_params.get('recon_dir_affixes')
selected_figs     = recon_params.get('selected_figs')
compiler_configs  = parse_torch_compile_configs(recon_params.get('compiler_configs'))
copy_params       = recon_params.get('copy_params')

# Current probe positions (crop positions + optimizable shifts) and summed probe intensity, as NumPy arrays
pos = (model.crop_pos + model.opt_probe_pos_shifts).detach().cpu().numpy()
probe_int = model.get_complex_probe_view().abs().pow(2).sum(0).detach().cpu().numpy()
dx = init.init_variables["dx"]
d_out = get_blob_size(dx, probe_int, output="d90")  # d_out unit is in Ang

indices = select_scan_indices(
    init.init_variables['N_scan_slow'],
    init.init_variables['N_scan_fast'],
    subscan_slow=INDICES_MODE.get('subscan_slow'),
    subscan_fast=INDICES_MODE.get('subscan_fast'),
    mode=INDICES_MODE.get('mode', 'random'),
)

batches = make_batches(indices, pos, batch_size, mode=GROUP_MODE)
# `batch_size` already includes `grad_accumulation`, so report the per-chunk size times the accumulation steps
# (multiplying `batch_size` by `grad_accumulation` again would double-count it)
vprint(f"The effective batch size (i.e., how many probe positions are simultaneously used for 1 update of ptychographic parameters) is batch_size * grad_accumulation = {sub_batch_size} * {grad_accumulation} = {batch_size}")

fig_grouping = plot_pos_grouping(
    pos,
    batches,
    circle_diameter=d_out / dx,  # presumably converts d_out (Ang) to pixels with dx as the pixel size — confirm
    diameter_type="90%",
    dot_scale=1,
    show_fig=True,
    pass_fig=True,
)

# Reset explicitly so that a path left over from a previous run of this notebook is never reused
output_path = None
if SAVE_ITERS is not None:
    output_path = make_output_folder(
        output_dir,
        indices,
        init_params,
        recon_params,
        model,
        constraint_params,
        loss_params,
        recon_dir_affixes
    )

    fig_grouping.savefig(os.path.join(output_path, "summary_pos_grouping.png"))

    if copy_params:
        copy_params_to_dir(params_path, output_path, params)

# Flush to file after the output_path is created.
# Without an output folder (SAVE_ITERS is None) there is no destination, so skip instead of raising NameError.
if logger is not None and logger.flush_file and output_path is not None:
    logger.flush_to_file(log_dir = output_path)

In [None]:
start_t = time_sync()
vprint("### Starting the PtyRADSolver in reconstruction mode ###")
vprint(" ")

# torch.compile options
if compiler_configs is None:
    compiler_configs = {'disable': True} # Default to not use the compiler for maximal support for different machines

for niter in range(1, NITER + 1):

    # Toggle the grad calculation to enable or disable AD update on tensors at certain iterations
    toggle_grad_requires(model, niter, verbose=True)

    # Apply torch.compile to `recon_step`.
    # compilation_iters is expected to always contain niter=1; also checking `niter == 1` guarantees that
    # `recon_step_compiled` is (re)built in this run instead of reusing one left over from an earlier notebook run.
    if niter == 1 or niter in model.compilation_iters:
        vprint(f"Setting up PyTorch compiler with {compiler_configs}")
        torch._dynamo.reset()  # clear previously compiled graphs before recompiling
        recon_step_compiled = torch.compile(recon_step, **compiler_configs)

    # Visit the batches in a new random order every iteration (in-place shuffle)
    shuffle(batches)
    batch_losses = recon_step_compiled(
        batches, grad_accumulation, model, optimizer, loss_fn, constraint_fn, niter
    )

    ## Saving intermediate results
    if SAVE_ITERS is not None and niter % SAVE_ITERS == 0:
        with torch.no_grad():
            # Note that `params` stores the original params from the configuration file,
            # while `model` contains the actual params that could be updated by meas_crop, meas_pad, or meas_resample
            save_results(
                output_path,
                model,
                params,
                optimizer,
                niter,
                indices,
                batch_losses,
            )

            ## Saving summary
            plot_summary(
                output_path,
                model,
                niter,
                indices,
                init.init_variables,
                selected_figs=selected_figs,
                show_fig=False,
                save_fig=True,
            )
vprint(f"### Finish {NITER} iterations, averaged iter_t = {np.mean(model.iter_times):.5g} sec ###")
vprint(" ")
end_t = time_sync()
solver_t = end_t - start_t
# `time_str` carries its own leading ", " so no space is needed before it in the message below
time_str = f", or {parse_sec_to_time_str(solver_t)}" if solver_t > 60 else ""
vprint(f"### The PtyRADSolver is finished in {solver_t:.3f} sec{time_str} ###")

if logger is not None and logger.flush_file:
    logger.close()