In [1]:
import sys
sys.path.insert(0, "../")

import jax.numpy as jnp
from jax import lax
from jax import random
from flax import linen as nn
import jax
import jax.numpy as jnp
test = jnp.ones((3,3))
import optax

import numpy as np
import torch
import matplotlib.pyplot as plt
import torchvision
import os
from PIL import Image
from scipy import optimize
import scipy.signal as sp
import models.deep_image_prior.sdc_config3 as sdc
import wandb
from ipywidgets import IntProgress
from IPython.display import display
import time


class SkipConnection(nn.Module):
    """Residual block computing ``x + leaky_relu(conv3x3(x))``.

    Attributes:
        features: output channels of the 3x3 convolution. The convolution output
            is added element-wise onto ``x``, so it has to be shape-compatible
            with ``x`` (in this notebook it is always called with the same
            channel width as its input).
    """
    features: int

    @nn.compact
    def __call__(self, x):
        conv = nn.Conv(features=self.features, kernel_size=(3, 3), strides=(1, 1))
        branch = nn.leaky_relu(conv(x), negative_slope=0.01)
        # identity shortcut: add the untouched input back onto the conv branch
        return x + branch

class Upsample(nn.Module):
    """Spatial upsampling via a single 3x3 transposed convolution with stride 2.

    Attributes:
        features: number of output channels of the transposed convolution.
    """
    features: int

    def setup(self):
        # Registered under the attribute name ``upsample_layer`` (this fixes the
        # parameter-tree key for this submodule).
        self.upsample_layer = nn.ConvTranspose(
            features=self.features,
            kernel_size=(3, 3),
            strides=(2, 2),
        )

    def __call__(self, x):
        upsampled = self.upsample_layer(x)
        return upsampled

class HyperspectralNetwork(nn.Module):
    """Untrained generator used for the hyperspectral reconstruction.

    Layout: 3x3 conv -> leaky ReLU -> ``num_skip_blocks`` residual
    ``SkipConnection`` blocks -> ``Upsample`` (stride-2 transposed conv) down to a
    single output channel.

    Attributes:
        features: channel width of the hidden convolutions and residual blocks.
            Defaults to 45, the value that was previously hard-coded.
        num_skip_blocks: number of residual blocks. Defaults to 5, the value that
            was previously hard-coded.
    """
    features: int = 45
    num_skip_blocks: int = 5

    @nn.compact
    def __call__(self, x):
        # Initial convolution lifts the input to `features` channels so that the
        # residual additions inside SkipConnection are shape-compatible.
        x = nn.Conv(features=self.features, kernel_size=(3, 3))(x)
        x = nn.leaky_relu(x, negative_slope=0.01)
        for _ in range(self.num_skip_blocks):
            x = SkipConnection(features=self.features)(x)
        # Upsample spatially and collapse to one output channel
        x = Upsample(features=1)(x)
        return x
    
def color_visualize(image, wavelengths, title='', figsize=(10,10)):
    """Show a false-colour RGB rendering of a hyperspectral cube in a new figure.

    Args:
        image: cube whose axis 0 is the spectral axis (it becomes the last axis
            after the transpose below); axes 1 and 2 are passed to
            ``sdc.HSI2RGB_jax`` as the image size.
        wavelengths: wavelength grid forwarded to ``sdc.HSI2RGB_jax``;
            presumably one entry per slice along axis 0 — TODO confirm.
        title: optional figure title; skipped when it is the empty string.
        figsize: matplotlib figure size.
    """
    n_bands = image.shape[0]
    n_rows = image.shape[1]
    n_cols = image.shape[2]

    # Flatten to a (pixels, bands) matrix with the spectral axis last
    bands_last = jnp.transpose(image, (1, 2, 0))
    pixels_by_band = jnp.reshape(bands_last, [-1, n_bands])
    rgb = sdc.HSI2RGB_jax(wavelengths, pixels_by_band , n_rows, n_cols, 65, False)

    plt.figure(figsize=figsize)
    # Raising to the 0.6 power brightens values in (0, 1) for display
    plt.imshow(rgb**.6)
    if title != '':
        plt.title(title)
    plt.axis('off')
    plt.show()

def wandb_log_meas(wandb_log, meas):
    """Render the experimental measurement and store the figure in ``wandb_log``.

    Args:
        wandb_log: dict of items to be passed to ``wandb.log``; mutated in place.
        meas: image drawn in grayscale with a fixed [0, 1] display range
            (presumably 2-D).

    Returns:
        The same ``wandb_log`` dict, with the figure under
        ``'experimental_measurement'``.
    """
    fig, ax = plt.subplots(figsize=(10,10))
    # Plot the measurement on a fixed [0, 1] grey scale
    ax.imshow(meas, cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
    wandb_log['experimental_measurement'] = fig
    # Close exactly this figure (not pyplot's "current" one) so repeated calls in
    # a training loop do not accumulate open figures; the Figure object is
    # still referenced from the dict for logging.
    plt.close(fig)
    return wandb_log

def wandb_log_sim_meas(wandb_log, meas):
    """Render a simulated measurement and store the figure in ``wandb_log``.

    Args:
        wandb_log: dict of items to be passed to ``wandb.log``; mutated in place.
        meas: simulated measurement drawn in grayscale with a fixed [0, 1]
            display range (presumably 2-D).

    Returns:
        The same ``wandb_log`` dict, with the figure under
        ``'simulated_measurement'``.
    """
    fig, ax = plt.subplots(figsize=(10,10))
    ax.imshow(meas, cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
    wandb_log['simulated_measurement'] = fig
    # Close exactly this figure so periodic logging does not leak open figures;
    # the Figure object is still referenced from the dict.
    plt.close(fig)
    return wandb_log

def wandb_log_false_color_recon(wandb_log, recon, wavelengths):
    """Render a false-colour RGB view of a reconstruction and store it in ``wandb_log``.

    Args:
        wandb_log: dict of items to be passed to ``wandb.log``; mutated in place.
        recon: hyperspectral cube with the spectral axis first (it is moved last
            before flattening); axes 1 and 2 are passed to ``sdc.HSI2RGB_jax``
            as the image size.
        wavelengths: wavelength grid forwarded to ``sdc.HSI2RGB_jax``.

    Returns:
        The same ``wandb_log`` dict, with the figure under ``'false_color_recon'``.
    """
    # Flatten to a (pixels, bands) matrix and convert to RGB
    HSI_data = jnp.transpose(recon, (1, 2, 0))
    HSI_data = jnp.reshape(HSI_data, [-1, recon.shape[0]])
    false_color_image = sdc.HSI2RGB_jax(wavelengths, HSI_data , recon.shape[1], recon.shape[2], 65, False)

    fig, ax = plt.subplots(figsize=(10,10))
    # Raising to the 0.6 power brightens values in (0, 1) for display
    ax.imshow(false_color_image**.6)
    ax.axis('off')
    wandb_log['false_color_recon'] = fig
    # Close exactly this figure so periodic logging does not leak open figures
    plt.close(fig)
    return wandb_log

def wandb_log_psf(wandb_log, psf):
    """Render the PSF in grayscale and store the figure in ``wandb_log``.

    Args:
        wandb_log: dict of items to be passed to ``wandb.log``; mutated in place.
        psf: image to draw; unlike the measurement plots, no fixed display range
            is set, so matplotlib autoscales it.

    Returns:
        The same ``wandb_log`` dict, with the figure under ``'psf'``.
    """
    fig, ax = plt.subplots(figsize=(10,10))
    ax.imshow(psf, cmap='gray')
    ax.axis('off')
    wandb_log['psf'] = fig
    # Close exactly this figure; the Figure object is still referenced from the dict
    plt.close(fig)
    return wandb_log

def wandb_log_ground_truth(wandb_log, gt):
    """Render the ground-truth image and store the figure in ``wandb_log``.

    Args:
        wandb_log: dict of items to be passed to ``wandb.log``; mutated in place.
        gt: either an RGB image of shape (H, W, 3), drawn in colour, or any
            other image, drawn in grayscale.

    Returns:
        The same ``wandb_log`` dict, with the figure under ``'ground_truth'``.
    """
    fig, ax = plt.subplots(figsize=(10,10))
    # Treat the image as RGB only when it is 3-D with a trailing axis of length 3;
    # checking shape[-1] alone would misclassify a 2-D image that is 3 pixels wide.
    is_rgb = gt.ndim == 3 and gt.shape[-1] == 3
    if is_rgb:
        ax.imshow(gt)
    else:
        ax.imshow(gt, cmap='gray')
    ax.axis('off')
    wandb_log['ground_truth'] = fig
    # Close exactly this figure; the Figure object is still referenced from the dict
    plt.close(fig)
    return wandb_log

def crop2D(target_shape, mat):
    """Center-crop the last two axes of ``mat`` to ``target_shape[-2:]``.

    Works on any array type that supports ``...`` slicing (NumPy, JAX, torch).

    Args:
        target_shape: shape whose last two entries are the desired
            (height, width); any leading entries are ignored.
        mat: array with at least two dimensions; leading axes are kept as-is.

    Returns:
        A slice of ``mat`` with shape ``mat.shape[:-2] + (target_h, target_w)``.

    Raises:
        ValueError: if ``mat`` has fewer than two dimensions, or is smaller than
            the target along either cropped axis.
    """
    if mat.ndim < 2:
        raise ValueError('crop2D requires an array with at least 2 dimensions')
    target_h, target_w = target_shape[-2], target_shape[-1]
    height, width = mat.shape[-2], mat.shape[-1]
    if target_h > height or target_w > width:
        raise ValueError(
            f'crop2D target size {(target_h, target_w)} exceeds array size {(height, width)}')
    # Slice by explicit start + length so an odd size difference still yields
    # exactly the target size (the extra row/column is dropped from the end).
    top = (height - target_h) // 2
    left = (width - target_w) // 2
    return mat[..., top:top + target_h, left:left + target_w]

def process_image(object, spectral_filter, psf):
    """Simulate a sensor measurement from a hyperspectral object.

    Each slice ``object[c]`` is convolved with the shared 2-D ``psf`` through
    zero-padded FFTs. The magnitude of the result is weighted by
    ``spectral_filter``, and the weighted slices are summed over axis 0.

    Args:
        object: 3-D array (channels, H, W).
        spectral_filter: array broadcastable against (channels, H, W).
        psf: 2-D array (H, W); must have the same spatial size as ``object``,
            otherwise the padded spectra cannot be multiplied.

    Returns:
        2-D array: the simulated measurement, summed over the channel axis.
    """
    def _pad_to_double(arr):
        # Zero-pad axes 1 and 2 from n to 2n samples (ceil(n/2) before,
        # floor(n/2) after) so the FFT product is a linear rather than circular
        # convolution. Widths are plain Python ints because jnp.pad needs static
        # widths; this keeps the function traceable under jax.jit.
        widths = ((0, 0),) + tuple(((n + 1) // 2, n // 2) for n in arr.shape[1:3])
        return jnp.pad(arr, widths, mode='constant', constant_values=0)

    # (1, H, W) so the psf spectrum broadcasts across every channel
    expanded_psf = jnp.asarray(psf)[None, ...]
    padded_psf = _pad_to_double(expanded_psf)
    padded_object = _pad_to_double(object)

    fft_product = jnp.fft.fft2(padded_psf) * jnp.fft.fft2(padded_object)

    # Shift over the spatial axes, then keep the central window of the
    # object's size and take the magnitude of the complex result.
    ifft_product = jnp.fft.ifftshift(jnp.fft.ifft2(fft_product), axes=(1,2))
    ifft_product = abs(crop2D(object.shape, ifft_product))
    ifft_product = ifft_product * jnp.asarray(spectral_filter)

    return jnp.sum(ifft_product, axis=0)




In [4]:
# --- Experiment configuration ------------------------------------------------
calibration_location = 'recon_materials_organized'
psf_name = 'psf_2023-11-16.pt'
calibration_wavelengths_file = 'connie_cal_waves.pt'
filter_cube_file = 'connie_normalized_filter_cube.pt'
datafolder = '/media/azuldata/neerja/2023-11-17/Bear10x'
ground_truth_spectra_locs = ['/media/azuldata/neerja/2022-04-08/greenbeads_oldslide_emissionspectra.csv', 
                            '/media/azuldata/neerja/2022-04-08/redbeads_fresh_emissionspectra.csv']
bits = 16  # raw images are divided by 2**bits to scale them into [0, 1]
gpu = 1
crop_indices = [420,1752,1150,2926]  # passed to sdc.cropci; TODO document the index order
wvmin = 450
wvmax = 800
wvstep = 8
downsample_factor = 1
sample = datafolder.split('/')[-1]
max_iters = 10000
step_size = 1e-5
loss_type = 'l2'  # 'l1' or 'l2'
kprint = 10  # log reconstruction figures every kprint iterations

run_name = (f'{sample}_loss_type={loss_type}_kmax={max_iters}'
            f'_stepsize={step_size}_downsample={downsample_factor}')

# --- Load the measurement and display it to check the crop indices -----------
# Prefer a single 'meas.tiff'; if it cannot be read, average every frame in the
# 'measurements' sub-folder instead. Catch Exception (not a bare except) and
# report the reason, so real decoding errors (e.g. a missing TIFF codec) are
# visible rather than silently triggering the fallback.
try:
    sample_meas = sdc.importTiff(datafolder,'meas.tiff')/2**bits
except Exception as err:
    print(f'Could not load meas.tiff ({err!r}); averaging the measurements folder instead')
    sample_meas = torch.mean(sdc.tif_loader(os.path.join(datafolder,'measurements'))/2**bits,0)
try:
    background = sdc.importTiff(datafolder,'bg.tiff')/2**bits
except Exception as err:
    print('No background image found, continuing without background subtraction')
    print(f'  (reason: {err!r})')
    background = torch.zeros(sample_meas.shape)
meas = sdc.cropci((sample_meas-background).clip(0,1),crop_indices)

sdc.bw_visualize(meas)

IntProgress(value=0, max=1)

ValueError: <COMPRESSION.LZW: 5> requires the 'imagecodecs' package

In [5]:
# Number GPUs in PCI bus order, then expose only the one chosen by `gpu` (0-3).
# NOTE(review): CUDA_VISIBLE_DEVICES only has an effect if it is set before the
# CUDA backend is initialised. The imports cell already evaluates
# `jnp.ones((3,3))`, which presumably initialises JAX's backend, so JAX may
# ignore this setting. Consider moving these two lines above `import jax`.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

# Set up the Weights & Biases run
wandb.init(
    # Project where this run will be logged
    project='HyperSpectralDiffuserScope_UnlearnedRecons', 
    # Explicit run name (otherwise wandb assigns a random one)
    name=run_name,
    # Hyperparameters and run metadata
    config={"wvmin": wvmin,
            "wvmax": wvmax,
            "wvstep": wvstep,
            "max_iters": max_iters,
            "step_size": step_size, 
            "downsample_factor": downsample_factor,
            "sample": sample,
            "gpu": gpu,
            "loss_type": loss_type,
            "kprint": kprint}
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33memarkley[0m ([33mwallerlab[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# All tensors in this cell stay on `device` (the CPU).
device = 'cpu'

# Load the wavelength calibration and build the target wavelength grid.
# NOTE: torch.load unpickles its input; only load trusted calibration files.
wv = torch.load(os.path.join(calibration_location, calibration_wavelengths_file),map_location='cpu')
wavelengths = np.arange(wvmin,wvmax+wvstep,wvstep)

# Load and crop the filter cube, bin it over the wavelength grid, normalise to max 1
normalized_filter_cube = torch.load(os.path.join(calibration_location,filter_cube_file),map_location='cpu')
filterstack = sdc.cropci(normalized_filter_cube, crop_indices)
msum = sdc.sumFilterArray(filterstack,wv,wvmin,wvmax,wvstep)
spectral_filter = msum/torch.amax(msum)

# Load the PSF and center-crop it to the spectral filter's spatial size.
# Clip negative values *before* normalising so the PSF sums to exactly 1
# (normalising first and clipping afterwards changes the total).
sensor_psf = torch.load(os.path.join(calibration_location, psf_name),map_location='cpu')
ccrop = torchvision.transforms.CenterCrop(spectral_filter.shape[1:])
psf = ccrop(sensor_psf)
psf = psf.clip(0)
psf = psf/torch.sum(psf)

# Load the measurement: a single 'meas.tiff' if readable, otherwise the average
# of the 'measurements' folder. The reason for any failure is reported.
try:
    sample_meas = sdc.importTiff(datafolder,'meas.tiff')/2**bits
except Exception as err:
    print(f'Could not load meas.tiff ({err!r}); averaging the measurements folder instead')
    sample_meas = torch.mean(sdc.tif_loader(os.path.join(datafolder,'measurements'))/2**bits,0)
try:
    background = sdc.importTiff(datafolder,'bg.tiff')/2**bits
except Exception as err:
    print('No background image found, continuing without background subtraction')
    print(f'  (reason: {err!r})')
    background = torch.zeros(sample_meas.shape)
meas = sdc.cropci((sample_meas-background).clip(0,1),crop_indices)


# Load the ground-truth image. Non-RGB images are scaled by the bit depth and
# rotated by -90 degrees; every image is then normalised to max 1.
try:
    gt = sdc.importTiff(datafolder,'gt.tiff')
    is_rgb = gt.ndim == 3 and gt.shape[-1] == 3
    if not is_rgb:
        gt = gt/2**bits
        gt = torchvision.transforms.functional.rotate(gt.unsqueeze(0),-90).squeeze()
    gt = gt/torch.max(gt)
except Exception as err:
    print('No ground truth image found, continuing without ground truth')
    print(f'  (reason: {err!r})')
    gt = torch.zeros(meas.shape)

# Load the reference emission spectra, resample them onto `wavelengths`, and
# normalise to max 1. The legend label is the file-name prefix before the first '_'.
legend = []
spectra = []
for loc in ground_truth_spectra_locs:
    gt_wv,gt_int = sdc.loadspectrum(loc)
    gt_int = np.interp(wavelengths,gt_wv,gt_int)
    gt_int = gt_int/np.max(gt_int)
    gt_int = torch.tensor(gt_int).to(device)
    spectra.append(gt_int)
    legend.append(loc.split('/')[-1].split('.')[0].split('_')[0])



# Spatially downsample and move to `device`
spectral_filter = spectral_filter[:,::downsample_factor,::downsample_factor].to(device)
psf = psf[::downsample_factor,::downsample_factor].to(device)
meas = meas[::downsample_factor,::downsample_factor].to(device)
gt = gt.to(device)

808


In [7]:
# Network input: a fixed random tensor of shape (batch, wavelength, height, width, channels).
# Flax's nn.Conv with a 2-D kernel treats every axis before the two spatial axes
# as a batch axis, so each wavelength slice is processed independently.
# The number of wavelength slices is taken from the spectral filter so the
# reconstruction always matches the forward model in process_image.
num_wavelengths = spectral_filter.shape[0]
latent_size = 450  # spatial size of the random input (Upsample then applies a stride-2 transposed conv)
input_shape = (1, num_wavelengths, latent_size, latent_size, 1)

# Randomly initialize the network and the input (fixed seed for reproducibility)
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
init_input = jax.random.normal(init_rng, input_shape)

hyperspectral_network = HyperspectralNetwork()
params = hyperspectral_network.init(rng, init_input)

# One forward pass to find the reconstruction's spatial size
output = hyperspectral_network.apply(params, init_input)
target_shape = output.squeeze().shape[1:]

# Crop psf, spectral filter and measurement to the reconstruction's spatial size
sim_psf = crop2D(target_shape, psf)
sim_spectral_filter = crop2D(target_shape, spectral_filter)
meas = jnp.asarray(meas)
target_meas = crop2D(target_shape, meas)

# Sanity-check the forward model on the untrained output. This reuses the pass
# above instead of running the (expensive) network a second time.
recon = output.squeeze()
sim_meas = process_image(recon, sim_spectral_filter, sim_psf)


def l1_loss(params, init_input):
    """Sum of absolute differences between the simulated and target measurement."""
    recon = hyperspectral_network.apply(params, init_input).squeeze().clip(0)
    sim_meas = process_image(recon, sim_spectral_filter, sim_psf)
    return jnp.sum(jnp.abs(sim_meas - target_meas))


def l2_loss(params, init_input):
    """Sum of squared differences between the simulated and target measurement."""
    recon = hyperspectral_network.apply(params, init_input).squeeze().clip(0)
    sim_meas = process_image(recon, sim_spectral_filter, sim_psf)
    return jnp.sum((sim_meas - target_meas)**2)


# Fail fast on an unsupported loss_type instead of leaving loss_and_grad undefined
loss_functions = {'l1': l1_loss, 'l2': l2_loss}
if loss_type not in loss_functions:
    raise ValueError(f'Unknown loss_type {loss_type!r}; expected one of {sorted(loss_functions)}')

# Loss value and gradient with respect to params only.
# NOTE(review): wrapping this in jax.jit would remove per-iteration dispatch
# overhead, but tracing requires process_image to pass static (Python int) pad
# widths to jnp.pad.
loss_and_grad = jax.value_and_grad(loss_functions[loss_type], argnums=0)

network_optimizer = optax.adam(step_size)
network_optimizer_state = network_optimizer.init(params)

for ii in range(max_iters):
    loss, grad = loss_and_grad(params, init_input)

    # float(): log a plain Python number rather than a JAX array
    wandb_log = {'loss': float(loss)}
    if ii == 0:
        # Reference images, logged once
        wandb_log = wandb_log_meas(wandb_log, target_meas)
        wandb_log = wandb_log_psf(wandb_log, psf)
        wandb_log = wandb_log_ground_truth(wandb_log, gt)

    if ii % kprint == 0:
        # Reconstruction and simulated measurement for the current (pre-update) params
        recon = hyperspectral_network.apply(params, init_input).squeeze().clip(0)
        current_meas = process_image(recon, sim_spectral_filter, sim_psf)
        wandb_log = wandb_log_sim_meas(wandb_log, current_meas)
        wandb_log = wandb_log_false_color_recon(wandb_log, recon/jnp.max(recon)*jnp.sum(recon,0)[None,...], wavelengths)

    updates, network_optimizer_state = network_optimizer.update(grad, network_optimizer_state, params)
    params = optax.apply_updates(params, updates)
    wandb.log(wandb_log)



2024-01-08 11:03:50.077598: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng35{k2=1,k5=3,k14=2} for conv (f32[1,45,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,45,901,901]{3,2,1,0}, f32[45,45,899,899]{3,2,1,0}), window={size=899x899 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0} is taking a while...
2024-01-08 11:03:54.756379: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 5.678864012s
Trying algorithm eng35{k2=1,k5=3,k14=2} for conv (f32[1,45,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,45,901,901]{3,2,1,0}, f32[45,45,899,899]{3,2,1,0}), window={size=899x899 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0} is taking a while...
2024-01-08 11:03:55

KeyboardInterrupt: 