Skip to content

Commit

Permalink
Update JAX Model Based Reconstruction code
Browse files Browse the repository at this point in the history
  • Loading branch information
tomelse committed May 16, 2024
1 parent ab5ca6f commit 4a8eea7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions patato/convenience_scripts/process_msot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def main():
description = pipeline.children[0].get_algorithm_name()

# Setup the name
if args.highpass is not None:
description += "_CUSTOM_HP_FILTER"
if args.lowpass is not None:
description += "_CUSTOM_LP_FILTER"
#if args.highpass is not None:
# description += "_CUSTOM_HP_FILTER"
#if args.lowpass is not None:
# description += "_CUSTOM_LP_FILTER"
if args.run is not None:
description += "_run_" + str(args.run)
if args.wavelength is not None:
Expand Down
2 changes: 1 addition & 1 deletion patato/processing/jax_preprocessing_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def pre_compute_filter(self, n_samples: int, fs: float, irf: Array = None):
fs : float
irf : Array
"""
self.filter = jnp.array(make_filter(n_samples, fs, irf, self.hilbert, self.lp_filter, self.hp_filter))
self.filter = jnp.array(make_filter(n_samples, fs, irf if self.irf_correct else None, self.hilbert, self.lp_filter, self.hp_filter))

def _run(self, time_series: Array, detectors: Array, overall_correction_factor, **kwargs):
"""
Expand Down
6 changes: 3 additions & 3 deletions patato/recon/jax_model_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(self, n_pixels, field_of_view, pa_example: "PAData" = None, **kwarg
from jax.experimental.sparse import CSR

self._model_regulariser = (kwargs.get("model_regulariser", None), kwargs.get("model_regulariser_lambda", None))
if self._model_regulariser[0] is None:
self._model_regulariser = None
#if self._model_regulariser[0] is None:
# self._model_regulariser = None

self._model_matrix = CSR((self._model_matrix.data, self._model_matrix.indices, self._model_matrix.indptr),
shape=self._model_matrix.shape)
Expand Down Expand Up @@ -153,7 +153,7 @@ def reconstruct(self, raw_data: np.ndarray,
raise ValueError("Constraint must either be 'positive' or 'none'.")

M = self._model_matrix
if self._model_regulariser is not None:
if self._model_regulariser is not None or all([x is None for x in self._model_regulariser]):
method, lambda_reg = self._model_regulariser
else:
method, lambda_reg = None, None
Expand Down

0 comments on commit 4a8eea7

Please sign in to comment.