In [1]:
from dataclasses import dataclass, field
from typing import Dict, List, Any, Callable, Optional, Tuple

import healpy as hp
import numpy as np

# What is a scene?
As in conventional rendering, a scene is a representation of all involved objects and volumes. There are some differences in our case:

- A scene can consist of multiple pointings, i.e. it can be split over time and observation direction, but all pointings share the same atmosphere, instrument and emitters. It thus represents a view of the same world (instrument, atmosphere, sky) at different times and positions.
- There are no reflective surfaces; the scene is essentially a rasterizer that only accounts for light transfer.
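- All objects and volumes have parameters that can be changed. Calling the scene's `.update` method with a dictionary of `parameter_name: value` pairs applies the changes to the owning objects and updates the scene.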

In [2]:
class Parameterized:
    """Mixin for objects that expose named, settable parameters.

    Subclasses store their parameters in a ``self._parameters`` dict. Only
    names that already exist in that dict can be changed through
    ``set_parameter``; new parameters cannot be introduced that way.
    """

    def get_parameters(self) -> Dict[str, Any]:
        """Return a snapshot of the current parameters of this object.

        Returns:
            A shallow copy of ``self._parameters``, or an empty dict when the
            subclass defines no parameters. Mutating the returned dict does not
            change the object; use ``set_parameter`` for that.
        """
        # Copy so callers cannot bypass set_parameter by editing the dict.
        return dict(getattr(self, '_parameters', {}))

    def set_parameter(self, name: str, value: Any) -> bool:
        """Set the value of an existing parameter.

        Args:
            name: Parameter name; must already be a key of ``self._parameters``.
            value: New value, stored as-is (no validation).

        Returns:
            True if the parameter exists and was updated, False otherwise.
        """
        parameters = getattr(self, '_parameters', None)
        if parameters is None or name not in parameters:
            return False
        parameters[name] = value
        return True

@dataclass
class SceneComponents:
    """Container for all scene components.

    The catalog/model annotations are string forward references: ``Catalog``,
    ``Instrument`` and ``Atmosphere`` are not defined yet when this cell runs
    (``Instrument`` and ``Atmosphere`` come from a later cell), and class
    annotations are otherwise evaluated when the class is created, which
    raises ``NameError`` on a fresh kernel.
    """

    # Diffuse components: map name -> HEALPix map array
    healpix_maps: Dict[str, np.ndarray] = field(default_factory=dict)

    # Catalogs (default_factory gives every instance its own list)
    catalogs: List["Catalog"] = field(default_factory=list)

    # Models; None when the scene has no such model
    instrument: Optional["Instrument"] = None
    atmosphere: Optional["Atmosphere"] = None

class Scene:
    """Scene containing all components and parameters for rendering.

    The tunable values of the scene's instrument, atmosphere and catalogs are
    flattened into one table keyed by dotted names:

    - ``instrument.<key>``
    - ``atmosphere.<key>``
    - ``catalog.<CatalogClass>_<index>.<key>``

    ``update`` writes a value back to the object that owns it, so the table
    and the component objects stay in sync. Entries placed directly into
    ``self.parameters`` that no component owns (the "direct" parameters,
    e.g. ``obs.*``) are preserved when the table is rebuilt.
    """

    def __init__(self, components: "SceneComponents"):
        """Create a scene from a container of components.

        Args:
            components: Object exposing ``instrument`` and ``atmosphere``
                (either may be ``None``) and a ``catalogs`` list.
        """
        self.components = components

        # Build dynamic parameter table from all components
        self._build_parameter_table()

    def update(self, params: Dict[str, Any]):
        """Update scene parameters.

        Args:
            params: Dictionary of parameter_path: value pairs, e.g.
                ``{"instrument.rotation": 0.1}``.

        Unknown names, and values rejected by the owning object, are reported
        with a printed warning rather than an exception.
        """
        for param_name, value in params.items():
            if param_name in self._parameter_map:
                # Update in the source object
                obj, key = self._parameter_map[param_name]
                if obj.set_parameter(key, value):
                    self.parameters[param_name] = value
                    print(f"Updated {param_name} = {value}")
                else:
                    print(f"Warning: Failed to update {param_name}")
            elif param_name in self.parameters:
                # Direct parameter (like obs.*) that no component owns
                self.parameters[param_name] = value
                print(f"Updated {param_name} = {value}")
            else:
                print(f"Warning: Parameter {param_name} not found in scene")

    def step(self):
        """Update internal states of all stateful catalogs.

        A catalog counts as stateful when it provides a callable
        ``sample_state``; other catalogs are skipped instead of raising
        ``AttributeError``.
        """
        for catalog in self.components.catalogs:
            sample_state = getattr(catalog, "sample_state", None)
            if callable(sample_state):
                sample_state()

    def _register(self, prefix: str, obj: Any):
        """Add every parameter of ``obj`` to the table as ``<prefix>.<key>``."""
        for key, value in obj.get_parameters().items():
            param_name = f"{prefix}.{key}"
            self.parameters[param_name] = value
            self._parameter_map[param_name] = (obj, key)

    def _build_parameter_table(self):
        """Dynamically build parameter table from all components."""
        # Direct parameters are not backed by any component, so they cannot be
        # re-collected below; carry them over instead of silently dropping them.
        old_parameters = getattr(self, "parameters", {})
        old_map = getattr(self, "_parameter_map", {})
        self.parameters = {
            name: value
            for name, value in old_parameters.items()
            if name not in old_map
        }
        self._parameter_map = {}  # Maps parameter names to (owner object, key)

        if self.components.instrument is not None:
            self._register("instrument", self.components.instrument)

        if self.components.atmosphere is not None:
            self._register("atmosphere", self.components.atmosphere)

        # The index keeps names unique when several catalogs share a class.
        for i, catalog in enumerate(self.components.catalogs):
            self._register(f"catalog.{catalog.__class__.__name__}_{i}", catalog)

    def get_parameter_table(self) -> Dict[str, Any]:
        """Get full parameter table.

        The table is rebuilt first so values changed directly on the component
        objects (bypassing ``update``) are picked up. Returns a shallow copy,
        so mutating the result does not affect the scene.
        """
        self._build_parameter_table()
        return self.parameters.copy()

    def print_parameters(self):
        """Pretty print all parameters, grouped by their leading category."""
        table = self.get_parameter_table()

        print("\n" + "="*60)
        print("SCENE PARAMETERS")
        print("="*60)

        # Group by category (text before the first '.'); sorting first keeps
        # both the categories and their entries in alphabetical order.
        categories: Dict[str, List[Tuple[str, Any]]] = {}
        for key, value in sorted(table.items()):
            categories.setdefault(key.split('.')[0], []).append((key, value))

        for category, params in categories.items():
            print(f"\n[{category.upper()}]")
            for key, value in params:
                # bool is a subclass of int; exclude it so flags print as
                # True/False instead of 1.0000/0.0000.
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    formatted = f"{value:>10.4f}"
                else:
                    formatted = f"{str(value):>10}"
                print(f"  {key:<40} = {formatted}")

In [16]:
class Atmosphere(Parameterized):
    """Atmospheric model with dynamic parameters.

    ``model_type`` selects a preset from ``_PRESETS``; each preset exposes a
    different parameter set through the ``Parameterized`` interface.
    """

    # Reference wavelength of the simplified wavelength scaling below
    # (presumably metres, i.e. 550 nm).
    _REFERENCE_WAVELENGTH = 550e-9

    # Default parameters per model type. Copied per instance in __init__ so
    # instances never share (and mutate) the same dict.
    _PRESETS: Dict[str, Dict[str, Any]] = {
        "standard": {
            'aod': 0.1,  # Aerosol optical depth
            'water_vapor': 10.0,  # mm
            'pressure': 1013.25,  # mbar
        },
        "tropical": {
            'aod': 0.15,
            'water_vapor': 30.0,
            'pressure': 1013.25,
            'ozone': 260.0,  # Dobson units - only in tropical model
            'humidity': 0.8,  # Only in tropical
        },
        "arctic": {
            'aod': 0.05,
            'water_vapor': 2.0,
            'pressure': 1013.25,
            'ozone': 350.0,  # Different ozone
            'temperature': -30.0,  # Only in arctic
        },
    }

    def __init__(self, model_type: str = "standard"):
        """Create an atmosphere from one of the named presets.

        Args:
            model_type: One of "standard", "tropical" or "arctic".

        Raises:
            ValueError: If ``model_type`` is not a known preset. Previously such
                an object was created without parameters and failed later in
                ``extinction``.
        """
        if model_type not in self._PRESETS:
            raise ValueError(
                f"Unknown atmosphere model_type {model_type!r}; "
                f"expected one of {sorted(self._PRESETS)}"
            )
        self.model_type = model_type
        self._parameters = dict(self._PRESETS[model_type])

    def extinction(self, wavelengths: np.ndarray, airmass: float) -> np.ndarray:
        """Compute atmospheric extinction as an elementwise transmission factor.

        Simplified model: ``exp(-aod * airmass * (550e-9 / wavelengths)**4)``.

        Args:
            wavelengths: Wavelengths, in the same units as the 550e-9 reference.
            airmass: Airmass along the line of sight.

        Returns:
            Array broadcast from ``wavelengths`` with values in (0, 1] for
            non-negative ``aod`` and ``airmass``.
        """
        aod = self._parameters.get('aod', 0.1)
        # NOTE(review): the same lambda^-4 scaling is applied to the aerosol
        # optical depth; confirm this simplification is intended.
        return np.exp(-aod * airmass * (self._REFERENCE_WAVELENGTH / wavelengths)**4)

    def scattering(self, wavelengths: np.ndarray) -> np.ndarray:
        """Compute atmospheric scattering: ``0.1 * (550e-9 / wavelengths)**4``.

        The 0.1 amplitude is hard-coded and does not depend on the model
        parameters.
        """
        return 0.1 * (self._REFERENCE_WAVELENGTH / wavelengths)**4

class Instrument(Parameterized):
    """Instrument model with dynamic parameters.

    ``instrument_type`` selects a detector preset from ``_PRESETS``
    ("ccd", "cmos" or "emccd"); each preset exposes its own parameter set
    through the ``Parameterized`` interface.
    """

    # Centre and 1/e half-width of the Gaussian bandpass
    # (presumably metres: 550 nm centre, 100 nm width).
    _BANDPASS_CENTER = 550e-9
    _BANDPASS_WIDTH = 100e-9

    # Default parameters per detector type. Copied per instance in __init__
    # so instances never share (and mutate) the same dict.
    _PRESETS: Dict[str, Dict[str, Any]] = {
        "ccd": {
            'xshift': 0.0,
            'yshift': 0.0,
            'rotation': 0.0,
            'gain': 1.0,
            'read_noise': 5.0,
            'dark_current': 0.1,
        },
        "cmos": {
            'xshift': 0.0,
            'yshift': 0.0,
            'rotation': 0.0,
            'gain': 1.5,
            'read_noise': 2.0,
            # CMOS specific
            'rolling_shutter': True,
            'pixel_variation': 0.02,
        },
        "emccd": {
            'xshift': 0.0,
            'yshift': 0.0,
            'rotation': 0.0,
            'gain': 1.0,
            'em_gain': 100.0,  # Electron multiplication
            'read_noise': 0.1,
            'clock_voltage': 10.0,  # EMCCD specific
        },
    }

    def __init__(self, instrument_type: str = "ccd",
                 detector_shape: Tuple[int, int] = (1024, 1024)):
        """Create an instrument from one of the named detector presets.

        Args:
            instrument_type: One of "ccd", "cmos" or "emccd".
            detector_shape: Shape of ``pixel_response``; the default matches
                the previously hard-coded 1024 x 1024 array.

        Raises:
            ValueError: If ``instrument_type`` is not a known preset. Previously
                such an object was created silently without parameters.
        """
        if instrument_type not in self._PRESETS:
            raise ValueError(
                f"Unknown instrument_type {instrument_type!r}; "
                f"expected one of {sorted(self._PRESETS)}"
            )
        self.instrument_type = instrument_type
        self._parameters = dict(self._PRESETS[instrument_type])

        # Uniform (all-ones) per-pixel response
        self.pixel_response = np.ones(detector_shape)

    def bandpass(self, wavelengths: np.ndarray) -> np.ndarray:
        """Instrument bandpass: a Gaussian ``exp(-((w - center) / width)**2)``.

        Returns values in (0, 1], peaking at 1 at the centre wavelength.
        """
        return np.exp(-((wavelengths - self._BANDPASS_CENTER) / self._BANDPASS_WIDTH)**2)

In [17]:
# Build a demo scene: default instrument and atmosphere, one catalog, and two
# synthetic all-sky HEALPix maps (Gaussian noise around a constant level).
NSIDE = 64  # HEALPix resolution of the synthetic maps
SEED = 42   # fixed seed so the maps are reproducible under Restart & Run All

rng = np.random.default_rng(SEED)
npix = hp.nside2npix(NSIDE)

inst = Instrument()
atmo = Atmosphere()
# NOTE(review): `Catalog` is not defined in any cell of this notebook; it must
# come from an earlier (deleted) cell or an import. Add it before this cell,
# otherwise this line raises NameError on a fresh kernel.
cat = Catalog()

components = SceneComponents(
    healpix_maps={
        'airglow': rng.standard_normal(npix) * 0.1 + 1.0,
        'zodiacal': rng.standard_normal(npix) * 0.05 + 0.5,
    },
    catalogs=[cat],
    instrument=inst,
    atmosphere=atmo,
)

scene = Scene(components)

In [25]:
# Show the flattened parameter table of the freshly built scene.
# NOTE(review): the recorded output below does not match the defaults in the
# Atmosphere/Instrument cell (pressure and dark_current print 1.2000 versus
# 1013.25 and 0.1 in the code). It comes from an earlier kernel state;
# Restart & Run All to refresh it.
scene.print_parameters()


SCENE PARAMETERS

[ATMOSPHERE]
  atmosphere.aod                           =     0.1000
  atmosphere.pressure                      =     1.2000
  atmosphere.water_vapor                   =    10.0000

[INSTRUMENT]
  instrument.dark_current                  =     1.2000
  instrument.gain                          =     1.0000
  instrument.read_noise                    =     5.0000
  instrument.rotation                      =     0.0000
  instrument.xshift                        =     0.0000
  instrument.yshift                        =     0.0000


In [26]:
# Nudge the instrument rotation; Scene.update routes the value to the
# Instrument object that owns the "rotation" parameter.
rotation_change = {'instrument.rotation': 0.1}
scene.update(rotation_change)

Updated instrument.rotation = 0.1


In [27]:
# Print the table again: instrument.rotation should now read 0.1000 after the
# update in the previous cell.
scene.print_parameters()


SCENE PARAMETERS

[ATMOSPHERE]
  atmosphere.aod                           =     0.1000
  atmosphere.pressure                      =     1.2000
  atmosphere.water_vapor                   =    10.0000

[INSTRUMENT]
  instrument.dark_current                  =     1.2000
  instrument.gain                          =     1.0000
  instrument.read_noise                    =     5.0000
  instrument.rotation                      =     0.1000
  instrument.xshift                        =     0.0000
  instrument.yshift                        =     0.0000
