21 changes: 21 additions & 0 deletions src/ptychi/api/options/base.py
@@ -550,6 +550,27 @@ def get_non_data_fields(self) -> dict:
del d["initial_guess"]
return d

@dataclasses.dataclass
class SynthesisDictLearnProbeOptions(Options):
    """Options for representing the probe with a synthesis sparse dictionary."""

    d_mat: Union[ndarray, Tensor] = None
    """The synthesis sparse dictionary matrix; its columns are the basis functions
    used to represent the probe through the sparse code weights."""

    d_mat_conj_transpose: Union[ndarray, Tensor] = None
    """Conjugate transpose of the synthesis sparse dictionary matrix."""

    d_mat_pinv: Union[ndarray, Tensor] = None
    """Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix."""

    probe_sparse_code: Union[ndarray, Tensor] = None
    """Sparse code weights vector."""

    probe_sparse_code_nnz: float = None
    """Number of non-zero entries to keep when enforcing the sparsity constraint
    on the sparse code weights vector `probe_sparse_code`."""

    enabled: bool = False
    """Whether to represent the probe with the synthesis sparse dictionary."""

@dataclasses.dataclass
class PositionCorrectionOptions(Options):
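These fields together define a synthesis model: the flattened probe is approximated as `d_mat @ probe_sparse_code`, with `d_mat_pinv` used to initialize the code and `probe_sparse_code_nnz` bounding how many coefficients are kept. A minimal sketch of mutually consistent values, using made-up sizes that are not from this PR:

```python
import numpy as np

# Hypothetical sizes: a 64x64 probe flattened to 4096 pixels, 300 dictionary atoms.
h = w = 64
n_atoms = 300
rng = np.random.default_rng(0)

# d_mat: one complex basis function per column, acting on the flattened probe.
d_mat = (
    rng.standard_normal((h * w, n_atoms)) + 1j * rng.standard_normal((h * w, n_atoms))
).astype(np.complex64)

# The remaining fields are derived from d_mat and must stay consistent with it.
d_mat_conj_transpose = d_mat.conj().T          # (n_atoms, h*w)
d_mat_pinv = np.linalg.pinv(d_mat)             # (n_atoms, h*w)
probe_sparse_code_nnz = round(0.9 * n_atoms)   # how many coefficients survive thresholding

# The probe itself is modeled as d_mat @ probe_sparse_code, so the code has one row
# per atom and one column per incoherent probe mode.
```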
7 changes: 7 additions & 0 deletions src/ptychi/api/options/pie.py
@@ -22,13 +22,20 @@ class PIEObjectOptions(base.ObjectOptions):
"""


@dataclasses.dataclass
class PIEProbeExperimentalOptions(base.Options):
    sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(
        default_factory=base.SynthesisDictLearnProbeOptions
    )


@dataclasses.dataclass
class PIEProbeOptions(base.ProbeOptions):

    alpha: float = 0.1
    """
    Multiplier for the update to the probe, as defined in Table 1 of Maiden (2017).
    """

    experimental: PIEProbeExperimentalOptions = dataclasses.field(
        default_factory=PIEProbeExperimentalOptions
    )


@dataclasses.dataclass
6 changes: 6 additions & 0 deletions src/ptychi/api/task.py
@@ -162,6 +162,12 @@ def build_probe(self):
            self.probe_options.experimental.deep_image_prior_options.enabled
        ):
            self.probe = probe.DIPProbe(**kwargs)
        elif (
            isinstance(self.probe_options, api.options.PIEProbeOptions)
        ) and (
            self.probe_options.experimental.sdl_probe_options.enabled
        ):
            self.probe = probe.SynthesisDictLearnProbe(**kwargs)
        else:
            self.probe = probe.Probe(**kwargs)

4 changes: 3 additions & 1 deletion src/ptychi/data_structures/base.py
@@ -90,6 +90,7 @@ def __init__(
        name: Optional[str] = None,
        options: "api.options.base.ParameterOptions" = None,
        data_as_parameter: bool = True,
        build_optimizer: bool = True,
        *args,
        **kwargs,
    ) -> None:
@@ -161,7 +162,8 @@ def __init__(
        else:
            self.register_buffer("tensor", tensor)

        self.build_optimizer()
        if build_optimizer:
            self.build_optimizer()

@property
def shape(self) -> Tuple[int, ...]:
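The new `build_optimizer` flag lets a subclass register its actual optimizable parameter after the base constructor has run, and only then build the optimizer; `SynthesisDictLearnProbe` below relies on this. A standalone sketch of the pattern, using made-up `FakeParameter`/`FakeSparseCodeParameter` classes rather than the real `ReconstructParameter` API:

```python
import torch


# Simplified stand-in, only to illustrate the build_optimizer=False pattern;
# the real ReconstructParameter constructor is richer than this.
class FakeParameter(torch.nn.Module):
    def __init__(self, tensor: torch.Tensor, build_optimizer: bool = True):
        super().__init__()
        self.register_parameter("tensor", torch.nn.Parameter(tensor))
        self.optimizer = None
        if build_optimizer:
            self.build_optimizer()

    def build_optimizer(self):
        self.optimizer = torch.optim.SGD([self.tensor], lr=0.1)


class FakeSparseCodeParameter(FakeParameter):
    def __init__(self, tensor: torch.Tensor):
        # The optimizable parameter ("code") does not exist yet; without
        # build_optimizer=False the base constructor would call the overridden
        # build_optimizer below and fail on the missing attribute.
        super().__init__(tensor, build_optimizer=False)
        self.register_parameter("code", torch.nn.Parameter(torch.zeros(8, dtype=torch.complex64)))
        self.build_optimizer()

    def build_optimizer(self):
        self.optimizer = torch.optim.SGD([self.code], lr=0.1)


param = FakeSparseCodeParameter(torch.zeros(4, dtype=torch.complex64))
```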
69 changes: 69 additions & 0 deletions src/ptychi/data_structures/probe.py
@@ -1,5 +1,6 @@
from typing import Optional, Union, TYPE_CHECKING
import logging
import enum
import os

import numpy as np
@@ -22,11 +23,19 @@
logger = logging.getLogger(__name__)


class ProbeRepresentation(enum.StrEnum):
    NORMAL = enum.auto()
    SPARSE_CODE = enum.auto()
    DIP = enum.auto()


class Probe(dsbase.ReconstructParameter):
    # TODO: eigenmode_update_relaxation is only used for LSQML. We should create dataclasses
    # to contain additional options for ReconstructParameter classes, and subclass them for specific
    # reconstruction algorithms - for example, ProbeOptions -> LSQMLProbeOptions.
    options: "api.options.base.ProbeOptions"

    representation: ProbeRepresentation = ProbeRepresentation.NORMAL

    def __init__(
        self,
@@ -436,10 +445,70 @@ def save_tiff(self, path: str):
tifffile.imsave(fname + "_mag.tif", mag_img)
tifffile.imsave(fname + "_phase.tif", phase_img)

class SynthesisDictLearnProbe(Probe):

    representation: ProbeRepresentation = ProbeRepresentation.SPARSE_CODE

    def __init__(self, name="probe", options=None, *args, **kwargs):
        super().__init__(
            name, options, *args, build_optimizer=False, data_as_parameter=False, **kwargs
        )

        dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H = self.get_dictionary()
        self.register_buffer("dictionary_matrix", dictionary_matrix)
        self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv)
        self.register_buffer("dictionary_matrix_H", dictionary_matrix_H)

        # Stored as int64 so it can be used directly as an index when thresholding.
        probe_sparse_code_nnz = torch.tensor(
            self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.int64
        )
        self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz)

        sparse_code_probe = self.get_initial_weights()
        self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe))

        self.build_optimizer()

    def get_dictionary(self):
        """Build the dictionary matrix, its pseudoinverse, and its conjugate transpose
        from the options as complex64 tensors."""
        sdl_options = self.options.experimental.sdl_probe_options
        dictionary_matrix = torch.tensor(sdl_options.d_mat, dtype=torch.complex64)
        dictionary_matrix_pinv = torch.tensor(sdl_options.d_mat_pinv, dtype=torch.complex64)
        dictionary_matrix_H = torch.tensor(sdl_options.d_mat_conj_transpose, dtype=torch.complex64)
        return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H

    def get_initial_weights(self):
        """Initialize the sparse code by projecting the initial probe onto the
        dictionary using its pseudoinverse."""
        probe_vec = torch.reshape(
            self.data, (self.data.shape[1], self.data.shape[2] * self.data.shape[3])
        )
        probe_vec = torch.swapaxes(probe_vec, 0, -1)
        sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec
        return sparse_code_probe

    def generate(self):
        """Generate the probe from the sparse code, and write the generated probe
        into self.data.

        Returns
        -------
        Tensor
            A (n_opr_modes, n_modes, h, w) tensor giving the generated probe.
        """
        probe_vec = self.dictionary_matrix @ self.sparse_code_probe
        probe_vec = torch.swapaxes(probe_vec, 0, -1)
        probe = torch.reshape(
            probe_vec, (self.data.shape[1], self.data.shape[2], self.data.shape[3])
        )[None, ...]
        self.tensor.data = torch.stack([probe.real, probe.imag], dim=-1)
        return probe

    def build_optimizer(self):
        if self.optimizable and self.optimizer_class is None:
            raise ValueError(
                "Parameter {} is optimizable but no optimizer is specified.".format(self.name)
            )
        if self.optimizable:
            self.optimizer = self.optimizer_class([self.sparse_code_probe], **self.optimizer_params)

    def set_sparse_code(self, data):
        self.sparse_code_probe.data = data


class DIPProbe(Probe):

    options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions"
    representation: ProbeRepresentation = ProbeRepresentation.DIP

    def __init__(
        self,
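The reshapes in `get_initial_weights()` and `generate()` are a flatten/unflatten of the (n_opr_modes, n_modes, h, w) probe around the dictionary product. A standalone sketch of that round trip with toy shapes (not the class API); with a full-rank square dictionary the reconstruction is exact up to floating-point error:

```python
import torch

# Toy shapes: 1 OPR mode, 2 incoherent modes, an 8x8 probe, 64 dictionary atoms.
n_modes, h, w = 2, 8, 8
rc = h * w
n_atoms = 64
g = torch.Generator().manual_seed(0)

probe = torch.randn(1, n_modes, h, w, dtype=torch.complex64, generator=g)
d_mat = torch.randn(rc, n_atoms, dtype=torch.complex64, generator=g)
d_mat_pinv = torch.linalg.pinv(d_mat)

# get_initial_weights: flatten the probe to (rc, n_modes) and project onto the dictionary.
probe_vec = torch.swapaxes(probe.reshape(n_modes, rc), 0, -1)   # (rc, n_modes)
sparse_code = d_mat_pinv @ probe_vec                            # (n_atoms, n_modes)

# generate: synthesize the probe from the code and restore the (1, n_modes, h, w) layout.
synth_vec = torch.swapaxes(d_mat @ sparse_code, 0, -1)          # (n_modes, rc)
synth = synth_vec.reshape(n_modes, h, w)[None, ...]             # (1, n_modes, h, w)

print(torch.allclose(synth, probe, atol=1e-3))
```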
2 changes: 2 additions & 0 deletions src/ptychi/forward_models.py
@@ -491,6 +491,8 @@ def dip_generate(self):
            self.object.generate()
        if isinstance(self.probe, ptychi.data_structures.probe.DIPProbe):
            self.probe.generate()
        elif isinstance(self.probe, ptychi.data_structures.probe.SynthesisDictLearnProbe):
            self.probe.generate()

@timer()
def forward(self, indices: Tensor, return_object_patches: bool = False) -> Tensor:
41 changes: 40 additions & 1 deletion src/ptychi/reconstructors/pie.py
@@ -126,7 +126,46 @@ def compute_updates(
        )

        delta_p_i = None
        if probe.optimization_enabled(self.current_epoch):
        if probe.optimization_enabled(self.current_epoch) and self.parameter_group.probe.representation == "sparse_code":
            rc = psi_prime.shape[-1] * psi_prime.shape[-2]
            n_scpm = psi_prime.shape[-3]
            n_pos = psi_prime.shape[-4]

            psi_prime_vec = torch.reshape(psi_prime, (n_pos, n_scpm, rc))

            probe_vec = torch.reshape(self.parameter_group.probe.data[0, ...], (n_scpm, rc))

            obj_patches_vec = torch.reshape(obj_patches, (n_pos, 1, rc))

            conj_obj_patches = torch.conj(obj_patches_vec)
            abs2_obj_patches = torch.abs(obj_patches_vec) ** 2

            # z: object intensity summed over scan positions; w: regularization weight that
            # pulls weakly illuminated pixels toward the current probe estimate.
            z = torch.sum(abs2_obj_patches, dim=0)
            z_max = torch.max(z)
            w = 0.9 * (z_max - z)

            sum_spos_conjT_s_psi = torch.sum(conj_obj_patches * psi_prime_vec, 0)
            sum_spos_conjT_s_psi = torch.swapaxes(sum_spos_conjT_s_psi, 0, 1)

            w_phi = torch.swapaxes(w * probe_vec, 0, 1)
            z_plus_w = torch.swapaxes(z + w, 0, 1)

            # Solve the regularized normal equations for the sparse code.
            numer = self.parameter_group.probe.dictionary_matrix_H @ (sum_spos_conjT_s_psi + w_phi)
            denom = self.parameter_group.probe.dictionary_matrix_H @ (
                z_plus_w * self.parameter_group.probe.dictionary_matrix
            )

            sparse_code = torch.linalg.solve(denom, numer)

            # Enforce the sparsity constraint on the sparse code: keep only the
            # largest-magnitude coefficients in each column.
            abs_sparse_code = torch.abs(sparse_code)
            sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)

            # Threshold: the magnitude at sorted index `probe_sparse_code_nnz` in each column.
            sel = sparse_code_sorted.values[self.parameter_group.probe.probe_sparse_code_nnz, :]

            sparse_code = sparse_code * (abs_sparse_code >= sel)

            # Update the sparse code in the probe object.
            self.parameter_group.probe.set_sparse_code(sparse_code)
        else:
            step_weight = self.calculate_probe_step_weight(obj_patches)
            delta_p_i = step_weight * (psi_prime - psi)  # get delta p at each position
            delta_p_i = self.adjoint_shift_probe_update_direction(indices, delta_p_i, first_mode_only=True)
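The new branch solves regularized normal equations for the sparse code and then applies a per-column hard threshold on the coefficient magnitudes. That thresholding step can be exercised on its own; below is a standalone sketch (not the reconstructor API) that mirrors the indexing used above. Note that taking the cutoff at sorted index `nnz` and keeping `abs >= cutoff` retains `nnz + 1` coefficients per column when there are no ties:

```python
import torch


def hard_threshold_columns(sparse_code: torch.Tensor, nnz: int) -> torch.Tensor:
    """Zero all but the largest-magnitude coefficients of each column.

    Mirrors the update above: the cutoff is the magnitude at sorted index `nnz`,
    and every coefficient at least as large as the cutoff is kept.
    """
    abs_code = torch.abs(sparse_code)
    cutoff = torch.sort(abs_code, dim=0, descending=True).values[nnz, :]
    return sparse_code * (abs_code >= cutoff)


g = torch.Generator().manual_seed(0)
code = torch.randn(10, 3, dtype=torch.complex64, generator=g)
thresholded = hard_threshold_columns(code, nnz=4)
print((thresholded != 0).sum(dim=0))  # 5 non-zeros per column here (sorted indices 0..4 are kept)
```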
86 changes: 86 additions & 0 deletions tests/test_2d_ptycho_rpie_synthesisdictlearn.py
@@ -0,0 +1,86 @@
import argparse
import os

import torch
import numpy as np

import ptychi.api as api
from ptychi.api.task import PtychographyTask
from ptychi.utils import get_suggested_object_size, get_default_complex_dtype

import test_utils as tutils


class Test2DPtychoRPIE_SDL(tutils.TungstenDataTester):
    @tutils.TungstenDataTester.wrap_recon_tester(name="test_2d_ptycho_rpie_synthesisdictlearn")
    def test_2d_ptycho_rpie_synthesisdictlearn(self):
        self.setup_ptychi(cpu_only=False)

        data, probe, pixel_size_m, positions_px = self.load_tungsten_data(additional_opr_modes=0)

        npz_dict_file = np.load(
            os.path.join(
                self.get_ci_input_data_dir(), "zernike2D_dictionaries", "testing_sdl_dictionary.npz"
            )
        )
        D = npz_dict_file["a"]
        npz_dict_file.close()

        options = api.RPIEOptions()
        options.data_options.data = data

        options.object_options.initial_guess = torch.ones(
            [1, *get_suggested_object_size(positions_px, probe.shape[-2:], extra=100)],
            dtype=get_default_complex_dtype(),
        )
        options.object_options.pixel_size_m = pixel_size_m
        options.object_options.optimizable = True
        options.object_options.optimizer = api.Optimizers.SGD
        options.object_options.step_size = 1e-1
        options.object_options.alpha = 1e-0

        options.probe_options.initial_guess = probe
        options.probe_options.optimizable = True
        options.probe_options.optimizer = api.Optimizers.SGD
        options.probe_options.orthogonalize_incoherent_modes.enabled = True
        options.probe_options.step_size = 1e-0
        options.probe_options.alpha = 1e-0

        options.probe_options.experimental.sdl_probe_options.enabled = True
        options.probe_options.experimental.sdl_probe_options.d_mat = np.asarray(
            D, dtype=np.complex64
        )
        options.probe_options.experimental.sdl_probe_options.d_mat_conj_transpose = np.conj(
            options.probe_options.experimental.sdl_probe_options.d_mat
        ).T
        options.probe_options.experimental.sdl_probe_options.d_mat_pinv = np.linalg.pinv(
            options.probe_options.experimental.sdl_probe_options.d_mat
        )
        options.probe_options.experimental.sdl_probe_options.probe_sparse_code_nnz = np.round(
            0.90 * D.shape[-1]
        )

        options.probe_position_options.position_x_px = positions_px[:, 1]
        options.probe_position_options.position_y_px = positions_px[:, 0]
        options.probe_position_options.optimizable = False

        options.reconstructor_options.batch_size = round(data.shape[0] * 0.1)
        options.reconstructor_options.num_epochs = 50
        options.reconstructor_options.allow_nondeterministic_algorithms = False

        task = PtychographyTask(options)
        task.run()

        recon = task.get_data_to_cpu("object", as_numpy=True)[0]

        return recon


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--generate-gold", action="store_true")
    args = parser.parse_args()

    tester = Test2DPtychoRPIE_SDL()
    tester.setup_method(name="", generate_data=False, generate_gold=args.generate_gold, debug=True)
    tester.test_2d_ptycho_rpie_synthesisdictlearn()
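
The test reads the dictionary from a `.npz` file under key `"a"` in the `zernike2D_dictionaries` directory of the CI input data; how that file is produced is not part of this diff. For local experimentation, a stand-in file with the same layout (rows = flattened probe pixels, columns = atoms) could be generated as below — a placeholder random dictionary, not the Zernike-based one used in CI:

```python
import numpy as np

# Placeholder only: the loading code above just needs a complex (n_pixels, n_atoms)
# array stored under key "a".
h = w = 128          # must match the probe size of the dataset being reconstructed
n_atoms = 400
rng = np.random.default_rng(0)
D = (
    rng.standard_normal((h * w, n_atoms)) + 1j * rng.standard_normal((h * w, n_atoms))
).astype(np.complex64)
np.savez("testing_sdl_dictionary.npz", a=D)
```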