diff --git a/src/axiomatic/pic_helpers.py b/src/axiomatic/pic_helpers.py index 3f74acd..b72ba55 100644 --- a/src/axiomatic/pic_helpers.py +++ b/src/axiomatic/pic_helpers.py @@ -452,7 +452,7 @@ def print_statements( print("\n-----------------------------------\n") -def _str_units_to_float(str_units: str) -> float: +def _str_units_to_float(str_units: str) -> Optional[float]: unit_conversions = { "nm": 1e-3, "um": 1, @@ -460,9 +460,9 @@ def _str_units_to_float(str_units: str) -> float: "m": 1e6, } match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str_units) - numeric_value = float(match.group(1) if match else 1.55) - unit = match.group(2) if match else "um" - return float(numeric_value * unit_conversions[unit]) + numeric_value = float(match.group(1)) if match else None + unit = match.group(2) if match else None + return float(numeric_value * unit_conversions[unit]) if unit in unit_conversions and numeric_value is not None else None def get_wavelengths_to_plot(statements: StatementDictionary, num_samples: int = 100) -> Tuple[List[float], List[float]]: @@ -484,7 +484,7 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float, vlines = vlines | { _str_units_to_float(wl) for wl in (comp.arguments["wavelengths"] if isinstance(comp.arguments["wavelengths"], list) else []) - if isinstance(wl, str) + if isinstance(wl, str) and _str_units_to_float(wl) is not None } if "wavelength_range" in comp.arguments: if ( @@ -492,8 +492,11 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float, and len(comp.arguments["wavelength_range"]) == 2 and all(isinstance(wl, str) for wl in comp.arguments["wavelength_range"]) ): - min_wl = min(min_wl, _str_units_to_float(comp.arguments["wavelength_range"][0])) - max_wl = max(max_wl, _str_units_to_float(comp.arguments["wavelength_range"][1])) + mi = _str_units_to_float(comp.arguments["wavelength_range"][0]) + ma = _str_units_to_float(comp.arguments["wavelength_range"][1]) + if mi is not None and ma is not None: + min_wl = min(min_wl, mi) + max_wl = max(max_wl, ma) return min_wl, max_wl, vlines for cost_stmt in statements.cost_functions or []: @@ -508,8 +511,8 @@ def update_wavelengths(mapping: Dict[str, Optional[Computation]], min_wl: float, min_wl = min(min_wl, min(vlines)) max_wl = max(max_wl, max(vlines)) if min_wl >= max_wl: - avg_wl = sum(vlines) / len(vlines) if vlines else 1550 - min_wl, max_wl = avg_wl - 10, avg_wl + 10 + avg_wl = sum(vlines) / len(vlines) if vlines else 1.55 + min_wl, max_wl = avg_wl - 0.01, avg_wl + 0.01 else: range_size = max_wl - min_wl min_wl -= 0.2 * range_size