In [18]:
%load_ext autoreload
%autoreload 2
import os
import sys
# Enumerate GPUs in PCI bus order so the index in CUDA_VISIBLE_DEVICES below
# matches the physical bus ordering (e.g. as shown by nvidia-smi).
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# Expose only GPU 2 to this process; machine-specific — adjust as needed.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# GPU memory policy. These are set before jax / tensorflow are imported below
# so the backends see them when they initialize:
#   - JAX/XLA: do not preallocate; allocate on demand via the platform allocator.
#   - TensorFlow: grow GPU memory on demand instead of grabbing it all.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


import optax
import equinox as eqx
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
# NOTE(review): `test` is not used anywhere in the visible cells. It looks like a
# warm-up allocation that initializes the JAX backend before TensorFlow is
# imported further down — confirm the intent before removing or moving it.
test = jnp.zeros((10,10,10))
import jax.random as random
from imaging_system import ImagingSystem, ImagingSystemProtocol

# NOTE(review): redundant — `sys` is already imported at the top of this cell.
import sys
# Make the EncodingInformation package importable (absolute, machine-specific path).
sys.path.append('/home/emarkley/Workspace/PYTHON/EncodingInformation')
from encoding_information.models.pixel_cnn import PixelCNN
from encoding_information.models.gaussian_process import FullGaussianProcess
# NOTE(review): star import — no name from it is referenced in the visible cells;
# prefer explicit imports so readers can tell where names come from.
from encoding_information.information_estimation import *

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from IPython.display import display
import wandb
# Authenticates with Weights & Biases unconditionally at import time, even if
# `use_wandb` is later set to False in the parameter cell.
wandb.login()

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from losses import PixelCNNLoss, GaussianLoss, GaussianEntropyLoss
from optimizers import IDEALOptimizer, param_labels
from imaging_systems.spectral_diffuser_scope.imaging_system import GaussianPSFLayer, GaussianSensorLayer
from imaging_systems.spectral_diffuser_scope.data_generator import SpectralDataGenerator


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define an Imaging System

In [19]:
class HyperspectralImagingSystem(ImagingSystem):
    """Learnable hyperspectral imaging system.

    The forward model applies a learnable ``GaussianPSFLayer`` followed by a
    ``GaussianSensorLayer`` (spatially varying Gaussian spectral responses), then
    floors the result at a small positive value.
    """

    psf_layer: GaussianPSFLayer
    gaussian_sensor_layer: GaussianSensorLayer
    # Static fields live in the pytree structure rather than as leaves, so they
    # are not optimized and are treated as compile-time constants under JIT.
    seed: int = eqx.field(static=True)
    rng_key: jax.Array = eqx.field(static=True)

    def __init__(self, psf_layer, gaussian_sensor_layer, seed: int = 0):
        """
        Args:
            psf_layer: Learnable PSF layer, applied first in the forward model.
            gaussian_sensor_layer: Learnable spectral sensor layer, applied second.
            seed: Seed forwarded to ``ImagingSystem`` and used to build ``rng_key``.
        """
        super().__init__(seed)
        self.psf_layer = psf_layer
        self.gaussian_sensor_layer = gaussian_sensor_layer
        self.seed = seed
        self.rng_key = random.PRNGKey(seed)

    @eqx.filter_jit
    def __call__(self, objects: jnp.ndarray) -> jnp.ndarray:
        """JIT-compiled forward pass"""
        return self.forward_model(objects)

    def forward_model(self, objects: jnp.ndarray) -> jnp.ndarray:
        """
        Runs the forward model using the hyperspectral imaging system.

        Args:
            objects: Input objects of shape (H, W, C).

        Returns:
            measurements: Output measurements of shape (H, W, C), floored at 1e-8.
        """
        key = self.next_rng_key()
        # Give each layer its own subkey: reusing one key for two stochastic
        # operations would make their random draws correlated.
        psf_key, sensor_key = random.split(key)
        x = self.psf_layer(objects, key=psf_key)
        x = self.gaussian_sensor_layer(x, key=sensor_key)
        # Floor at a small *positive* value rather than clipping at 0 —
        # presumably so downstream noise / log-likelihood terms stay finite.
        x = jnp.where(x < 1e-8, 1e-8, x)
        return x

    def reconstruct(self, measurements: jnp.ndarray) -> jnp.ndarray:
        """
        Performs reconstruction from the measurements.

        Args:
            measurements: Input measurements of shape (H, W, C).

        Returns:
            reconstructions: Currently the measurements unchanged (identity placeholder).
        """
        # Placeholder: Implement reconstruction logic if available.
        return measurements

    def toy_images(self, batch_size: int, height: int, width: int, channels: int) -> jnp.ndarray:
        """
        Generates toy images for testing the system.

        Args:
            batch_size: Number of images to generate.
            height: Height of each image.
            width: Width of each image.
            channels: Number of channels in each image.

        Returns:
            Uniform [0, 1) random images of shape (batch_size, height, width, channels).
        """
        key = self.next_rng_key()
        return random.uniform(key, shape=(batch_size, height, width, channels), minval=0, maxval=1)

    def display_measurement(self, measurement: jnp.ndarray) -> plt.Figure:
        """Render a measurement image in grayscale with a colorbar.

        Returns:
            The matplotlib Figure. It is closed before returning (presumably so a
            notebook does not auto-render it); show or log it explicitly.
        """
        fig, ax = plt.subplots(figsize=(5, 5))
        im = ax.imshow(measurement, cmap='gray')
        fig.colorbar(im, ax=ax)
        # Close exactly this figure, not whatever pyplot considers "current".
        plt.close(fig)
        return fig

    def display_object(self, object: jnp.ndarray) -> plt.Figure:
        """Render the maximum projection of an object along axis 0.

        NOTE(review): ``forward_model`` documents objects as (H, W, C), which
        would make this a max over rows; if objects passed here are
        channel-first, it is a spectral max projection — confirm with callers.

        Returns:
            The (closed) matplotlib Figure.
        """
        fig, ax = plt.subplots(figsize=(5, 5))
        im = ax.imshow(jnp.max(object, axis=0), cmap='gray')
        # Colorbar for consistency with display_measurement.
        fig.colorbar(im, ax=ax)
        plt.close(fig)
        return fig

    def display_optics(self) -> plt.Figure:
        """Summarize the learned optics in a 1x3 figure.

        Panels:
            1. The PSF returned by ``psf_layer.compute_psf()``.
            2. The spatial map of sensor filter centers (``gaussian_sensor_layer.means``)
               drawn with a colormap whose colors follow the wavelength axis.
            3. The spectral response curve of every sensor position, from
               ``gaussian_sensor_layer.get_sensor()``.

        Returns:
            The (closed) matplotlib Figure.
        """
        sensor_layer = self.gaussian_sensor_layer
        wavelengths = sensor_layer.wavelengths
        num_waves = sensor_layer.num_waves
        # Wavelength range, computed once and reused for both normalizations below.
        wave_min = np.min(wavelengths)
        wave_max = np.max(wavelengths)
        wave_span = wave_max - wave_min

        fig, (ax_psf, ax_pattern, ax_response) = plt.subplots(1, 3, figsize=(15, 5))

        # 1. PSF
        im_psf = ax_psf.imshow(self.psf_layer.compute_psf(), cmap='gray')
        fig.colorbar(im_psf, ax=ax_psf)
        ax_psf.set_title('PSF')
        ax_psf.axis('off')

        # 2. Learned pattern, colored by wavelength
        norm_wavelengths = (wavelengths - wave_min) / wave_span
        colors = plt.cm.nipy_spectral(norm_wavelengths)
        custom_cmap = LinearSegmentedColormap.from_list(
            "custom_colormap", list(zip(norm_wavelengths, colors)), N=num_waves
        )
        # Map filter centers from wavelength units onto [0, num_waves] to match
        # the colormap's vmin/vmax.
        color_mask = (sensor_layer.means - wave_min) / wave_span * num_waves
        im_pattern = ax_pattern.imshow(color_mask, cmap=custom_cmap, vmin=0, vmax=num_waves)
        cbar = fig.colorbar(im_pattern, ax=ax_pattern, ticks=np.linspace(0, num_waves, 5))
        cbar.set_label('Wavelength (nm)')
        # Relabel the [0, num_waves] ticks with the corresponding wavelengths.
        cbar.set_ticklabels(np.linspace(sensor_layer.min_wave, sensor_layer.max_wave, 5).astype(int))
        ax_pattern.set_title('Learned Pattern')
        ax_pattern.axis('off')

        # 3. Spectral response of every sensor position
        sensor = sensor_layer.get_sensor()
        n_rows, n_cols = sensor_layer.means.shape[:2]
        for ii in range(n_rows):
            for jj in range(n_cols):
                ax_response.plot(wavelengths, sensor[..., ii, jj])
        ax_response.set_title('Sensor Distribution')
        ax_response.set_xlabel('Wavelength (nm)')
        ax_response.set_ylabel('Response')

        fig.tight_layout()
        plt.close(fig)
        return fig

    def normalize_psf(self):
        """Return a copy of the system whose PSF layer is ``psf_layer.normalize_psf()``."""
        new_psf_layer = self.psf_layer.normalize_psf()
        return eqx.tree_at(lambda m: m.psf_layer, self, new_psf_layer)

    def update_means(self):
        """Return a copy of the system with ``gaussian_sensor_layer.update_means()`` applied."""
        new_sensor_layer = self.gaussian_sensor_layer.update_means()
        return eqx.tree_at(lambda m: m.gaussian_sensor_layer, self, new_sensor_layer)

    def update_stds(self):
        """Return a copy of the system with ``gaussian_sensor_layer.update_stds()`` applied."""
        new_sensor_layer = self.gaussian_sensor_layer.update_stds()
        return eqx.tree_at(lambda m: m.gaussian_sensor_layer, self, new_sensor_layer)

    def normalize(self):
        """Run all normalization and update steps.

        Applies, in order: PSF normalization, sensor-mean update, sensor-std update.
        ``self`` is not modified; ``eqx.tree_at`` builds new modules.

        Returns:
            Updated imaging system with normalized PSF and updated sensor parameters.
        """
        system = self.normalize_psf()
        system = system.update_means()
        system = system.update_stds()
        return system

# Define Imaging System Parameters

In [20]:
# general parameters
key = jax.random.PRNGKey(42)

# dataset parameters
# Folder of reference spectra. Set the SPECTRA_FOLDER environment variable to
# override this absolute, machine-specific default without editing the notebook.
spectra_folder = os.environ.get(
    'SPECTRA_FOLDER', '/home/emarkley/Workspace/PYTHON/HyperspectralIdeal/Spectra'
)
subset_fraction = 1.0
sparsity_factor = 2
photon_scale = 100.0
mosaic_rows = 19
mosaic_cols = 19
batch_size = 1

# gaussian psf layer parameters
mnist_size = 28  # side length in pixels of one MNIST digit
# NOTE: only mosaic_rows enters object_size, i.e. a square mosaic is assumed.
object_size = mosaic_rows * mnist_size
num_gaussian = 10
psf_size = (32, 32)

# define the variables for the Gaussian sensor layer
min_wave = 400
max_wave = 800
num_waves = 101
min_std = 4
max_std = 100
sensor_size = object_size
super_pixel_size = 4

# define parameters for IDEAL optimization
patch_size = 16
num_patches = 1024
patching_strategy = 'random'
num_steps = 1000
loss_type = 'gaussian'  # one of: 'pixelcnn', 'gaussian_entropy', 'gaussian'
refit_every = 50
gaussian_sigma = None  # if None, Poisson noise is used; otherwise Gaussian noise with standard deviation sigma

# wandb parameters
use_wandb = True
project_name = 'ideal_development'
# Derived from loss_type so the W&B run name and the saved checkpoint filename
# stay in sync when the loss is switched.
run_name = f'{loss_type}_loss'
log_every = 100
validate_every = 500

# Create Your Imaging System

In [21]:
# Learnable PSF: `num_gaussian` Gaussian components on a `psf_size` grid for an
# `object_size` object, initialized from the configured PRNG key.
psf_layer = GaussianPSFLayer(object_size, num_gaussian, psf_size, key=key)

# Spectral sensor: Gaussian filter responses over [min_wave, max_wave] sampled
# at `num_waves` points, with std bounds and super-pixel grouping.
gaussian_sensor_layer = GaussianSensorLayer(
    min_wave, max_wave, num_waves,
    min_std, max_std,
    sensor_size, super_pixel_size,
)

# Combine both layers; seed=0 is the class default, written out for clarity.
imaging_system = HyperspectralImagingSystem(psf_layer, gaussian_sensor_layer, seed=0)

# Choose Per-Parameter Learning Rates

In [22]:
# Build the label structure that `optax.multi_transform` (next cell) uses to route
# each learnable leaf to its own optimizer; the call also prints the learnable
# parameter names (see output below).
labels = param_labels(imaging_system)

Learnable parameters:
psf_layer.means
psf_layer.covs
psf_layer.weights
gaussian_sensor_layer.means
gaussian_sensor_layer.stds


In [23]:
# Per-group learning rates: pl_* -> PSF layer, gs_* -> Gaussian sensor layer.
pl_means_lr = 1e-2
pl_covs_lr = 1e-3
pl_weights_lr = 1e-4
gs_means_lr = 8e-2
gs_stds_lr = 8e-2

# One Adam per label. `labels` maps every learnable leaf to one of these keys,
# so each parameter group is stepped with its own learning rate.
optimizer = optax.multi_transform(
    {
        group: optax.adam(learning_rate=lr)
        for group, lr in (
            ('psf_layer.means', pl_means_lr),
            ('psf_layer.covs', pl_covs_lr),
            ('psf_layer.weights', pl_weights_lr),
            ('gaussian_sensor_layer.means', gs_means_lr),
            ('gaussian_sensor_layer.stds', gs_stds_lr),
        )
    },
    param_labels=labels,
)

# Create a Dataset

In [24]:
# Spectral data generator backed by the spectra folder (optionally subsampled).
data_generator = SpectralDataGenerator(spectra_folder, subset_fraction)

# MNIST train/test split; x_test is not used in the visible cells.
x_train, x_test = data_generator.load_mnist_data()

# Training set built from MNIST mosaics with the configured sparsity and photon
# scale (see SpectralDataGenerator.create_sparse_dataset for details).
train_dataset = data_generator.create_sparse_dataset(
    x_train,
    batch_size=batch_size,
    mosaic_rows=mosaic_rows,
    mosaic_cols=mosaic_cols,
    scale=photon_scale,
    sparsity_factor=sparsity_factor,
)

# Define the Loss Function

In [25]:
# Reject unknown loss names up front, before constructing anything.
if loss_type not in ('pixelcnn', 'gaussian_entropy', 'gaussian'):
    raise ValueError(f"Loss type {loss_type} not supported")

if loss_type == 'gaussian':
    loss_fn = GaussianLoss()
elif loss_type == 'gaussian_entropy':
    loss_fn = GaussianEntropyLoss()
else:  # 'pixelcnn' — the only loss that takes `refit_every`
    loss_fn = PixelCNNLoss(refit_every=refit_every)

# Create the Optimizer

In [26]:
# Training driver that ties together the imaging system, the optax optimizer
# and the information loss.
ideal_optimizer = IDEALOptimizer(
    imaging_system,
    optimizer,
    loss_fn,
    # patch sampling
    patch_size=patch_size,
    num_patches=num_patches,
    patching_strategy=patching_strategy,
    # noise model: None -> Poisson, otherwise Gaussian with this sigma
    gaussian_sigma=gaussian_sigma,
    # experiment tracking
    use_wandb=use_wandb,
    project_name=project_name,
    run_name=run_name,
)

# Optimize and Save the Imaging System

In [None]:
# Run the optimization over `num_steps` steps with the configured logging and
# validation cadence.
optimized_imaging_system = ideal_optimizer.optimize(
    train_dataset,
    num_steps,
    log_every=log_every,
    validate_every=validate_every,
)

# Save the system that `optimize` returned (the value bound above), rather than
# reaching into the optimizer's internal `imaging_system` attribute, which may
# not be the same object. Reload with
# `eqx.tree_deserialise_leaves(checkpoint_path, imaging_system)`.
checkpoint_path = f"{run_name}_optimized_imaging_system.eqx"
eqx.tree_serialise_leaves(checkpoint_path, optimized_imaging_system)

  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


 10%|█         | 100/1000 [00:58<08:06,  1.85it/s]