In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
# set gpu to be pci bus id
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# set gpu memory usage and turnoff pre-allocated memory
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


import optax
import equinox as eqx
import jax.numpy as jnp
test = jnp.zeros((10,10,10))

import sys

sys.path.append('/home/nmsingh/dev/EncodingInformation/src')
sys.path.append('/home/nmsingh/dev/EncodingInformation/ideal')

import wandb
wandb.login()

from ideal.losses import PixelCNNLoss, GaussianLoss, GaussianEntropyLoss
from ideal.optimizers import IDEALOptimizer, param_labels
from ideal.imaging_systems.xray_ptychography.xray_ptychography import XRayPtychography
from ideal.imaging_systems.xray_ptychography.data_generator import CellDataGenerator


# Define Optimization and Logging Parameters

In [30]:
# ---- IDEAL optimization settings ----
patch_size = 16                # forwarded to IDEALOptimizer(patch_size=...)
num_patches = 1024             # forwarded to IDEALOptimizer(num_patches=...)
patching_strategy = 'random'   # forwarded to IDEALOptimizer(patching_strategy=...)
num_steps = 1000               # number of steps passed to optimize()
loss_type = 'gaussian'         # one of 'pixelcnn', 'gaussian_entropy', 'gaussian'
refit_every = 50               # only used when loss_type == 'pixelcnn'
# Noise model: None selects Poisson noise; a number selects Gaussian noise
# with that standard deviation.
gaussian_sigma = None

# ---- Weights & Biases logging ----
use_wandb = True
project_name = 'ideal_development'
run_name = 'gaussian_loss'     # also the prefix of the saved model file
log_every = 100                # passed to optimize(log_every=...)
validate_every = 500           # passed to optimize(validate_every=...)

# Create Your Imaging System

In [31]:
# Define the imaging system with XRayPtychography's default constructor
# arguments. This is the system whose parameters IDEALOptimizer optimizes below.
imaging_system = XRayPtychography()

# Choose Your Optimizer and Learning Rate

In [None]:
# Assign an optax label to each parameter of the imaging system; these labels
# are passed as `param_labels` to optax.multi_transform in the next cell.
labels = param_labels(imaging_system)

In [33]:
# Per-group optimizer: optax.multi_transform applies the transform registered
# under each label to the parameters carrying that label (labels come from
# `param_labels(imaging_system)`).
# NOTE(review): only a 'mask' group is configured, which presumably means
# param_labels only emits 'mask' — confirm, since multi_transform needs a
# transform for every label it encounters.
mask_learning_rate = 1e-3
optimizer = optax.multi_transform(
    {
        'mask': optax.adam(learning_rate=mask_learning_rate),
    },
    param_labels=labels,
)

# Batch size passed to CellDataGenerator.create_dataset in the next cell.
batch_size = 4

# Create a Dataset

In [34]:
# Data source for training: a CellDataGenerator with its default settings.
data_generator = CellDataGenerator()

# Build the training dataset. Only batch_size is passed explicitly; every other
# dataset option uses the generator's defaults.
train_dataset = data_generator.create_dataset(batch_size=batch_size)

# Define the Loss Function

In [35]:
# Dispatch table: each supported loss_type maps to a zero-argument factory,
# so only the selected loss object is ever constructed.
_loss_factories = {
    'pixelcnn': lambda: PixelCNNLoss(refit_every=refit_every),
    'gaussian_entropy': lambda: GaussianEntropyLoss(),
    'gaussian': lambda: GaussianLoss(),
}
if loss_type not in _loss_factories:
    raise ValueError(f"Loss type {loss_type} not supported")
loss_fn = _loss_factories[loss_type]()

# Create the Optimizer

In [36]:
# Combine the imaging system, the optax optimizer and the loss into an
# IDEALOptimizer, forwarding the patching, noise-model and W&B settings
# from the configuration cell.
ideal_optimizer = IDEALOptimizer(
    imaging_system,
    optimizer,
    loss_fn,
    # patch sampling
    patch_size=patch_size,
    num_patches=num_patches,
    patching_strategy=patching_strategy,
    # noise model (None -> Poisson)
    gaussian_sigma=gaussian_sigma,
    # experiment tracking
    use_wandb=use_wandb,
    project_name=project_name,
    run_name=run_name,
)

# Optimize and Save the Imaging System

In [None]:
# Run the optimization; the result is the optimized imaging system.
optimized_imaging_system = ideal_optimizer.optimize(
    train_dataset,
    num_steps,
    log_every=log_every,
    validate_every=validate_every,
)

# Save the value returned by optimize() rather than re-reading
# ideal_optimizer.imaging_system. That attribute would still hold the initial
# system unless IDEALOptimizer reassigns it during optimize().
# To reload, call eqx.tree_deserialise_leaves(save_path, like) with a freshly
# constructed XRayPtychography() as the `like` template.
save_path = f"{run_name}_optimized_imaging_system.eqx"
eqx.tree_serialise_leaves(save_path, optimized_imaging_system)
print(f"Saved optimized imaging system to {save_path}")