In [1]:
import numpy as np, torch, scipy.io as io, os, matplotlib.pyplot as plt, h5py, PIL.Image as Image, pathlib
import sys
sys.path.insert(0, "../..")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import utils.helper_functions as helper
import utils.diffuser_utils as diffuser_utils
import dataset.preprocess_data as prep_data
import train

%load_ext autoreload
%autoreload 2

%matplotlib inline

# NOTE(review): SAVE_GT_PATH is not referenced anywhere else in this notebook (the ground-truth
# inputs are saved under OUTPUTS_DIRPATH/"ground_truth" instead), and it is a hardcoded absolute
# path on one developer's machine. Consider removing it or deriving it from OUTPUTS_DIRPATH.
SAVE_GT_PATH = "/home/cfoley/defocuscamdata/recons/sim_comparison_figure/model_input_gts"

## Notebook to reproduce Figure 6: simulated methods comparison collage

We select four images from the available datasets and visually benchmark our method against comparable methods, using the ground-truth images as reference.
All image outputs are false-color projections of 3D hyperspectral volumes.

### Setup

Let's start by getting and standardizing the image samples. You will need to have run the data fetching script in the `/studies` [README.md](../README.md) first, so that this notebook has local access to the sample data.

In [2]:
device = "cuda:0" # Change this to your desired gpu, or "cpu" for CPU processing

# Data, output and config folders are resolved against the current working directory
# (the notebook's own folder when it is launched from Jupyter in the usual way).
NOTEBOOK_DIRPATH = pathlib.Path().resolve()
SAMPLE_DATA_DIRPATH = os.path.join(NOTEBOOK_DIRPATH, "data") # TODO: script to get this from google drive. These have some preprocessing applied already
OUTPUTS_DIRPATH = os.path.join(NOTEBOOK_DIRPATH, "outputs")
CONFIGURATION_FILE_DIRPATH = os.path.join(NOTEBOOK_DIRPATH, "configs")

# Second argument passed to prep_data.project_spectral for every sample
# (presumably the number of output spectral bands — confirm in dataset/preprocess_data.py).
SPECTRAL_PROJECTION_BANDS = 30

# Keep file paths and loaded arrays under different names so neither shadows the other.
harvard_bushes_path = os.path.join(SAMPLE_DATA_DIRPATH, "imgf8_patch_0.mat")
kaist_img_path = os.path.join(SAMPLE_DATA_DIRPATH, "scene03_reflectance.mat")
fruit_artichoke_path = os.path.join(SAMPLE_DATA_DIRPATH, "internals_artichoke_SegmentedCroppedCompressed.mat")
icvl_color_checker_path = os.path.join(SAMPLE_DATA_DIRPATH, "IDS_COLORCHECK_1020-1223.mat")

# The ICVL file is HDF5 (read with h5py rather than scipy.io). Use a context manager so the
# file handle is closed as soon as the cube has been copied into memory.
with h5py.File(icvl_color_checker_path, "r") as icvl_file:
    icvl_rad = np.asarray(icvl_file["rad"])
# Move the first axis last (channel-last, as prep_image expects), flip both leading axes,
# project, then crop.
icvl_color_checker = prep_data.project_spectral(
    icvl_rad.transpose(1, 2, 0)[::-1, ::-1], SPECTRAL_PROJECTION_BANDS
)[300:820, 200:820]
del icvl_rad  # only the projected crop is needed; release the full cube

# KAIST: crop 420*5 x 620*5 pixels (5x the standardized (420, 620) patch shape used below).
kaist_img = prep_data.project_spectral(
    io.loadmat(kaist_img_path)["ref"][300:300 + 420 * 5, 200:200 + 620 * 5], SPECTRAL_PROJECTION_BANDS
)
harvard_bushes = io.loadmat(harvard_bushes_path)["image"]
fruit_artichoke = prep_data.project_spectral(
    prep_data.read_compressed(io.loadmat(fruit_artichoke_path)), SPECTRAL_PROJECTION_BANDS
).transpose(1, 0, 2)

icvl_color_checker.shape, kaist_img.shape, harvard_bushes.shape, fruit_artichoke.shape

In [3]:
def prep_image(image, crop_shape, patch_shape):
    """
    Standardize one (H x W x C) sample image into a model-ready tensor.

    Our sample images differ in size, so each is cropped to ``crop_shape`` spatially and
    every channel is resized with ``diffuser_utils.pyramid_down(..., patch_shape)``.
    The channels are stacked along a new leading axis. The stack is shifted down by its
    minimum (only when that minimum is positive) and then divided by its maximum.

    Returns:
        A ``torch`` tensor of shape (1, 1, C, *pyramid_down output shape*).
    """
    n_rows, n_cols = crop_shape[0], crop_shape[1]
    cropped = image[:n_rows, :n_cols]

    resized_channels = [
        diffuser_utils.pyramid_down(cropped[:, :, band], patch_shape)
        for band in range(image.shape[-1])
    ]
    stacked = np.stack(resized_channels, 0)

    # A negative minimum is deliberately left in place; only a positive floor is removed.
    floor = max(0., np.min(stacked))
    stacked = stacked - floor
    stacked = stacked / np.max(stacked)

    # Add leading batch and channel singleton dims.
    return torch.tensor(stacked)[None, None, ...]


def save_image_fc_npy(image, savename, fc_range=(420,720)):
    """
    Save a hyperspectral image as ``<savename>.npy`` plus a false-color ``<savename>.png``.

    Args:
        image: array written as-is with ``np.save`` and projected to false color with
            ``helper.select_and_average_bands`` (the ground-truth cells in this notebook
            pass channel-last arrays).
        savename: output path *without* extension. Missing parent directories are created.
        fc_range: spectral range forwarded to ``helper.select_and_average_bands``; pick it
            to match the image's dataset.

    Returns:
        The false-color ``PIL.Image`` that was written to ``<savename>.png``.
    """
    parent_dir = os.path.dirname(savename)
    # exist_ok avoids a check-then-create race. An empty dirname means "current directory",
    # and os.makedirs("") would raise, so skip it.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print("Saving: ", savename + ".npy")
    np.save(savename + ".npy", image)

    print("Saving fc: ", savename + ".png")
    fc_img = helper.select_and_average_bands(image, fc_range=fc_range)
    fc_max = fc_img.max()
    # Guard against a zero/negative max, which would otherwise fill the image with NaN/inf.
    if fc_max > 0:
        fc_img = fc_img / fc_max
    # Clip before the uint8 cast: negative values (possible in reconstructions) would
    # otherwise wrap around to bright garbage pixels.
    fc_img = Image.fromarray((np.clip(fc_img, 0, 1) * 255).astype(np.uint8))
    fc_img.save(savename + ".png")
    return fc_img

def show_fc(img, fc_range=(420,720)):
    """
    Display the false-color projection of a hyperspectral image inline.

    ``fc_range`` is forwarded to ``helper.select_and_average_bands`` and should match the
    image's dataset. The projection is divided by its maximum before plotting.
    """
    _, ax = plt.subplots(dpi=100)
    false_color = helper.select_and_average_bands(img, fc_range=fc_range)
    ax.imshow(false_color / np.max(false_color))
    plt.show()

In [4]:
# Unstandardized images in false color. Let's visualize them here.
# The artichoke sample is viewed over (400, 780); the others use show_fc's default (420, 720).
for raw_image, raw_fc_range in (
    (icvl_color_checker, (420, 720)),
    (fruit_artichoke, (400, 780)),
    (harvard_bushes, (420, 720)),
    (kaist_img, (420, 720)),
):
    show_fc(raw_image, fc_range=raw_fc_range)

In [5]:
# Next, we standardize the images to a common shape. Let's save these standardized inputs off below,
# and visualize them as we do.

# Common (rows, cols) shape every sample is resized to before simulation and reconstruction.
STANDARD_PATCH_SHAPE = (420, 620)

# Crop shape = full spatial shape, i.e. no cropping; only resizing and normalization are applied.
harvard_bushes_gt = prep_image(harvard_bushes, harvard_bushes.shape[:2], STANDARD_PATCH_SHAPE)
fruit_artichoke_gt = prep_image(fruit_artichoke, fruit_artichoke.shape[:2], STANDARD_PATCH_SHAPE)
icvl_color_checker_gt = prep_image(icvl_color_checker, icvl_color_checker.shape[:2], STANDARD_PATCH_SHAPE)
kaist_img_gt = prep_image(kaist_img, kaist_img.shape[:2], STANDARD_PATCH_SHAPE)

In [6]:
harvard_bushes_gt_name = os.path.join(OUTPUTS_DIRPATH, "ground_truth", "harvard_bushes")
# Drop the two leading singleton dims added by prep_image and move channels last.
harvard_bushes_gt_hwc = harvard_bushes_gt[0, 0].numpy().transpose(1, 2, 0)
save_image_fc_npy(harvard_bushes_gt_hwc, harvard_bushes_gt_name)

In [7]:
fruit_artichoke_gt_name = os.path.join(OUTPUTS_DIRPATH, "ground_truth", "fruit_artichoke")
# Drop the two leading singleton dims added by prep_image and move channels last.
fruit_artichoke_gt_hwc = fruit_artichoke_gt[0, 0].numpy().transpose(1, 2, 0)
# The artichoke sample is visualized over (400, 780) rather than the default (420, 720).
save_image_fc_npy(fruit_artichoke_gt_hwc, fruit_artichoke_gt_name, fc_range=(400, 780))

In [8]:
icvl_color_checker_gt_name = os.path.join(OUTPUTS_DIRPATH, "ground_truth", "icvl_color_checker")
# Drop the two leading singleton dims added by prep_image and move channels last.
icvl_color_checker_gt_hwc = icvl_color_checker_gt[0, 0].numpy().transpose(1, 2, 0)
save_image_fc_npy(icvl_color_checker_gt_hwc, icvl_color_checker_gt_name)

In [9]:
kaist_img_gt_name = os.path.join(OUTPUTS_DIRPATH, "ground_truth", "kaist_scene03")
# Drop the two leading singleton dims added by prep_image and move channels last.
kaist_img_gt_hwc = kaist_img_gt[0, 0].numpy().transpose(1, 2, 0)
save_image_fc_npy(kaist_img_gt_hwc, kaist_img_gt_name)

### Reconstructions

Each section below reproduces the reconstructed images in one column of the figure. Some of these reconstructions run many FISTA iterations, so these cells may take a while; a GPU helps considerably.

In [10]:
from models.ensemble import SSLSimulationModel
# Some hyperparameters for visualization in this notebook.
FISTA_ITERATIONS_BETWEEN_PRINT = 50  # assigned to each FISTA model's `print_every`
FISTA_PLOT_WITH_PRINT = False  # set to True if you want to see intermediate FISTA reconstructions


def _config_file(filename):
    """Return the path of ``filename`` inside CONFIGURATION_FILE_DIRPATH."""
    return os.path.join(CONFIGURATION_FILE_DIRPATH, filename)


# One YAML config per imaging system / reconstruction method.
handshake_fista_config = _config_file("handshake_fista.yml")
diffusercam_fista_config = _config_file("diffusercam_fista.yml")
defocuscam_fista_config = _config_file("defocuscam_fista.yml")
defocuscam_learned_config = _config_file("defocuscam_learned.yml")

In [11]:
def get_model(config_path) -> SSLSimulationModel:
    """
    Build the SSLSimulationModel described by the YAML file at ``config_path``.

    The config is parsed with ``helper.read_config`` and handed to ``train.get_model`` on the
    notebook-level ``device``. ``model.model1`` is the simulated forward model (the imager) and
    ``model.model2`` is the reconstruction model (learned or iterative). Both are printed
    before returning.
    """
    simulation_config = helper.read_config(config_path)
    ssl_model = train.get_model(simulation_config, device=device)

    print("Simulation model loaded: ", ssl_model.model1.operations)
    print(f"Reconstruction model loaded: {ssl_model.model2}")
    return ssl_model

def get_fista_model(config_path: str, fista_lr_mult = 1.0) -> SSLSimulationModel:
    """
    Load a model whose reconstruction stage is FISTA and apply this notebook's settings.

    ``fista_lr_mult`` multiplies the reconstruction model's ``L`` attribute. Progress
    printing/plotting is configured from FISTA_ITERATIONS_BETWEEN_PRINT and
    FISTA_PLOT_WITH_PRINT. The resulting FISTA parameters are printed for reference.
    """
    model = get_model(config_path)
    fista = model.model2

    fista.L = fista.L * fista_lr_mult
    fista.print_every = FISTA_ITERATIONS_BETWEEN_PRINT
    fista.plot = FISTA_PLOT_WITH_PRINT

    print(f"FISTA params: prior={fista.prox_method}, L={fista.L}, tau={fista.tau}, tv_lambda="
          f"{fista.tv_lambda}, tv_lambdaw={fista.tv_lambdaw}, tv_lambdax={fista.tv_lambdax}")
    return model

def run_fista_single_image(model, image):
    """
    Run the hyperspectral image through our simulation and reconstruction models.

    Args:
        model: object exposing ``model1`` (the forward/simulation model) and ``model2``
            (the reconstruction model, which must have a ``device`` attribute), e.g. the
            return value of ``get_fista_model``.
        image: hyperspectral tensor as produced by ``prep_image``. It is moved to
            ``model.model2.device`` before the forward simulation.

    Returns:
        ``model.model2.out_img`` after the reconstruction call. The call's own return value
        is not used; the result is read from that attribute.
    """
    forward_model, recon_model = model.model1, model.model2
    simulated_measurement = forward_model(image.to(recon_model.device))
    # Remove dims 0 and 2 (squeeze only drops them if they are singletons) before
    # handing the measurement to the reconstruction model.
    recon_model(simulated_measurement.squeeze(dim=(0,2)).to(recon_model.device))
    return recon_model.out_img

def save_reconstruction(recon: np.ndarray, save_namekey: str, fc_range=(420,720)):
    """
    Save a reconstruction under ``OUTPUTS_DIRPATH/reconstructions/<save_namekey>``.

    Writes the ``.npy`` array and a false-color ``.png`` (projected over ``fc_range``) via
    ``save_image_fc_npy``. Returns the false-color image so the calling cell can ``display`` it.
    """
    return save_image_fc_npy(
        recon,
        os.path.join(OUTPUTS_DIRPATH, "reconstructions", save_namekey),
        fc_range,
    )

#### HSI DiffuserCam with FISTA

In [None]:
# Simulate and reconstruct all four standardized samples with HSI DiffuserCam + FISTA.
diffusercam_fista_model = get_fista_model(diffusercam_fista_config, fista_lr_mult=0.2)

diffusercam_fista_harvard_bushes_savename = "harvard_bushes_diffusercam_fista_recon"
diffusercam_fista_icvl_savename = "icvl_color_checker_diffusercam_fista_recon"
diffusercam_fista_fruit_savename = "fruit_artichoke_diffusercam_fista_recon"
diffusercam_fista_kaist_savename = "kaist_scene03_diffusercam_fista_recon"

harvard_bushes_diffusercam_fista_recon = run_fista_single_image(diffusercam_fista_model, harvard_bushes_gt)
display(save_reconstruction(harvard_bushes_diffusercam_fista_recon, diffusercam_fista_harvard_bushes_savename))

icvl_color_checker_diffusercam_fista_recon = run_fista_single_image(diffusercam_fista_model, icvl_color_checker_gt)
display(save_reconstruction(icvl_color_checker_diffusercam_fista_recon, diffusercam_fista_icvl_savename))

# The artichoke sample is rendered with fc_range=(400, 780) everywhere else in this notebook
# (raw preview, saved ground truth, spectral plot). Use the same range so its false-color
# reconstruction is comparable to the ground truth.
fruit_artichoke_diffusercam_fista_recon = run_fista_single_image(diffusercam_fista_model, fruit_artichoke_gt)
display(save_reconstruction(
    fruit_artichoke_diffusercam_fista_recon, diffusercam_fista_fruit_savename, fc_range=(400, 780)
))

kaist_scene03_diffusercam_fista_recon = run_fista_single_image(diffusercam_fista_model, kaist_img_gt)
display(save_reconstruction(kaist_scene03_diffusercam_fista_recon, diffusercam_fista_kaist_savename))

del diffusercam_fista_model

#### Handshake Camera with FISTA

In [None]:
# Simulate and reconstruct all four standardized samples with the Handshake camera + FISTA.
handshake_fista_model = get_fista_model(handshake_fista_config, fista_lr_mult=0.1)

handshake_fista_harvard_bushes_savename = "harvard_bushes_handshake_fista_recon"
handshake_fista_icvl_savename = "icvl_color_checker_handshake_fista_recon"
handshake_fista_fruit_savename = "fruit_artichoke_handshake_fista_recon"
handshake_fista_kaist_savename = "kaist_scene03_handshake_fista_recon"

# Run reconstructions and save
harvard_bushes_handshake_fista_recon = run_fista_single_image(handshake_fista_model, harvard_bushes_gt)
display(save_reconstruction(harvard_bushes_handshake_fista_recon, handshake_fista_harvard_bushes_savename))

icvl_color_checker_handshake_fista_recon = run_fista_single_image(handshake_fista_model, icvl_color_checker_gt)
display(save_reconstruction(icvl_color_checker_handshake_fista_recon, handshake_fista_icvl_savename))

# The artichoke sample is rendered with fc_range=(400, 780) everywhere else in this notebook
# (raw preview, saved ground truth, spectral plot). Use the same range so its false-color
# reconstruction is comparable to the ground truth.
fruit_artichoke_handshake_fista_recon = run_fista_single_image(handshake_fista_model, fruit_artichoke_gt)
display(save_reconstruction(
    fruit_artichoke_handshake_fista_recon, handshake_fista_fruit_savename, fc_range=(400, 780)
))

kaist_scene03_handshake_fista_recon = run_fista_single_image(handshake_fista_model, kaist_img_gt)
display(save_reconstruction(kaist_scene03_handshake_fista_recon, handshake_fista_kaist_savename))

del handshake_fista_model

#### DefocusCam with FISTA

In [14]:
# Simulate and reconstruct all four standardized samples with DefocusCam + FISTA.
defocuscam_fista_model = get_fista_model(defocuscam_fista_config, fista_lr_mult=1)

defocuscam_fista_harvard_bushes_savename = "harvard_bushes_defocuscam_fista_recon"
defocuscam_fista_icvl_savename = "icvl_color_checker_defocuscam_fista_recon"
defocuscam_fista_fruit_savename = "fruit_artichoke_defocuscam_fista_recon"
defocuscam_fista_kaist_savename = "kaist_scene03_defocuscam_fista_recon"

harvard_bushes_defocuscam_fista_recon = run_fista_single_image(defocuscam_fista_model, harvard_bushes_gt)
display(save_reconstruction(harvard_bushes_defocuscam_fista_recon, defocuscam_fista_harvard_bushes_savename))

icvl_color_checker_defocuscam_fista_recon = run_fista_single_image(defocuscam_fista_model, icvl_color_checker_gt)
display(save_reconstruction(icvl_color_checker_defocuscam_fista_recon, defocuscam_fista_icvl_savename))

# The artichoke sample is rendered with fc_range=(400, 780) everywhere else in this notebook
# (raw preview, saved ground truth, spectral plot). Use the same range so its false-color
# reconstruction is comparable to the ground truth.
fruit_artichoke_defocuscam_fista_recon = run_fista_single_image(defocuscam_fista_model, fruit_artichoke_gt)
display(save_reconstruction(
    fruit_artichoke_defocuscam_fista_recon, defocuscam_fista_fruit_savename, fc_range=(400, 780)
))

kaist_scene03_defocuscam_fista_recon = run_fista_single_image(defocuscam_fista_model, kaist_img_gt)
display(save_reconstruction(kaist_scene03_defocuscam_fista_recon, defocuscam_fista_kaist_savename))

del defocuscam_fista_model

#### DefocusCam with a learned CondUNet

Because our learned model's weights were trained on a square patch of the spectral filter mask, with square PSFs,
we predict the image in square patches and blend the results together.

In [15]:
from patch_predict_utils import patchwise_predict_image_learned

# Simulate all four standardized samples with DefocusCam and reconstruct them patchwise
# with the learned model.
defocuscam_learned_model = get_model(defocuscam_learned_config)

defocuscam_learned_harvard_bushes_savename = "harvard_bushes_defocuscam_learned_recon"
defocuscam_learned_icvl_savename = "icvl_color_checker_defocuscam_learned_recon"
defocuscam_learned_fruit_savename = "fruit_artichoke_defocuscam_learned_recon"
defocuscam_learned_kaist_savename = "kaist_scene03_defocuscam_learned_recon"

# Run reconstructions and save
harvard_bushes_defocuscam_learned_recon = patchwise_predict_image_learned(defocuscam_learned_model, harvard_bushes_gt)
display(save_reconstruction(harvard_bushes_defocuscam_learned_recon, defocuscam_learned_harvard_bushes_savename))

icvl_color_checker_defocuscam_learned_recon = patchwise_predict_image_learned(defocuscam_learned_model, icvl_color_checker_gt)
display(save_reconstruction(icvl_color_checker_defocuscam_learned_recon, defocuscam_learned_icvl_savename))

# The artichoke sample is rendered with fc_range=(400, 780) everywhere else in this notebook
# (raw preview, saved ground truth, spectral plot). Use the same range so its false-color
# reconstruction is comparable to the ground truth.
fruit_artichoke_defocuscam_learned_recon = patchwise_predict_image_learned(defocuscam_learned_model, fruit_artichoke_gt)
display(save_reconstruction(
    fruit_artichoke_defocuscam_learned_recon, defocuscam_learned_fruit_savename, fc_range=(400, 780)
))

kaist_scene03_defocuscam_learned_recon = patchwise_predict_image_learned(defocuscam_learned_model, kaist_img_gt)
display(save_reconstruction(kaist_scene03_defocuscam_learned_recon, defocuscam_learned_kaist_savename))

del defocuscam_learned_model

# Spectral Profile Plots

In [87]:
from cleanplots import *
from matplotlib.ticker import FuncFormatter
import cv2
import seaborn as sns

def draw_plot_marker(image_path, point, fc_range, radius=5, savename=""):
    """
    Load a hyperspectral ``.npy`` cube, render it in false color, and stamp a filled white
    circle at ``point`` = (y, x) to show which pixel the spectral plot uses.

    The false-color image is ``helper.select_and_average_bands(cube, fc_range)`` passed through
    ``helper.value_norm`` and scaled by 255. If ``savename`` is non-empty, the marked image is
    also written to that path.

    Returns:
        The marked image as a ``PIL.Image`` (uint8).
    """
    false_color = helper.select_and_average_bands(np.load(image_path), fc_range)
    canvas = np.copy(helper.value_norm(false_color) * 255)
    center_xy = (int(point[1]), int(point[0]))  # OpenCV takes (x, y); `point` is (y, x)

    cv2.circle(canvas, center_xy, radius, (255, 255, 255), -1)

    marked = Image.fromarray(canvas.astype(np.uint8))
    if savename:
        with open(savename, "wb") as handle:
            marked.save(handle)
        print(f"Saved marked gt to {savename}")
    return marked


def plot_vectors(
    npy_files: list, 
    pointyx: tuple[int, int], 
    fc_range : tuple, # use appropriate visualization range of each image's dataset
    model_names: list = ["Reference", "Defocuscam (Learned)", "Defocuscam (Fista)", "DiffuserCam", "Handshake"],
    color_idcs: list = [3, 0, 7, 9, 1],
    savename: str = "",
    show_legend: bool = True,
    average_over_surrounding: int = 6
) -> plt.Figure:
    """
    Generate spectral line plots at one pixel for each of our reconstructions.

    Args:
        npy_files: paths to ``.npy`` cubes indexed as [y, x, band]. Curve ``i`` uses
            ``model_names[i]``, ``color_idcs[i]`` and the i-th line style (cycled).
        pointyx: (y, x) pixel whose spectrum is plotted.
        fc_range: (start, end) wavelengths spanned by the x-axis.
        model_names: legend label per file, in ``npy_files`` order (read only, never mutated).
        color_idcs: index into a 10-color seaborn husl palette per file (read only).
        savename: if non-empty, the figure is saved here. The format follows the extension,
            and missing parent directories are created.
        show_legend: draw the legend above the axes.
        average_over_surrounding: spectra are averaged over a square window of side
            ``2 * (average_over_surrounding // 2) + 1`` centered on ``pointyx``
            (0 or 1 -> the single pixel). The window is clipped at the image's top/left edge.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``npy_files`` is empty.

    Each cube is normalized with ``helper.value_norm`` and its last band is dropped. Its curve
    is then scaled by (its raw max / the largest raw max across files), so relative brightness
    between reconstructions is kept.
    NOTE(review): the remaining bands are spread evenly over ``fc_range`` on the x-axis —
    confirm this matches the datasets' band wavelengths.
    """
    if not npy_files:
        raise ValueError("plot_vectors needs at least one .npy file")

    linestyles = ["-", "--", "-.", ":", (0, (10, 3))]

    py, px = pointyx
    half = average_over_surrounding // 2
    # Clamp the window start at 0. A negative slice start would wrap to the far edge of the
    # image and give an empty (NaN) or wrong spectrum for points near the border.
    y0, x0 = max(py - half, 0), max(px - half, 0)
    y1, x1 = py + half + 1, px + half + 1

    colors = sns.husl_palette(n_colors=10, l=0.5)
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.2)

    # Load each cube and keep the (window-averaged) spectrum at the requested point.
    data, maxvals = [], []
    for file in npy_files:
        hsimage = np.load(file)
        maxvals.append(np.max(hsimage))
        window = helper.value_norm(hsimage)[y0:y1, x0:x1, :-1]
        data.append(np.mean(window, axis=(0, 1)))

    wavs = np.linspace(fc_range[0], fc_range[1], len(data[0]))

    # Plotting
    fig = plt.figure(dpi=100, figsize=(17,7))
    plt.rcParams['font.family'] = 'Arial'
    ax = fig.add_subplot()
    for i, spectrum in enumerate(data):
        ax.plot(
            wavs,
            spectrum / (max(maxvals) / maxvals[i]),
            color=colors[color_idcs[i]],
            label=model_names[i],
            linewidth=12,
            linestyle=linestyles[i % len(linestyles)],
        )

    def format_y_tick(value, pos):
        return '{:.1e}'.format(value)
    ax.yaxis.set_major_formatter(FuncFormatter(format_y_tick))
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(7)
        ax.spines[side].set_color("black")
    ax.tick_params(axis="x", labelsize=40)
    if show_legend:
        ax.legend(fontsize=50, framealpha=1, loc='lower center', bbox_to_anchor=(0.5, 1.02))
    if savename:
        save_dir = os.path.dirname(savename)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # Pass the path (not an open file handle) so matplotlib infers the format from the
        # extension; with a file handle it falls back to rcParams["savefig.format"].
        fig.savefig(savename)
        print(f"Saved plot to {savename}")
    return fig

# Legend labels and palette indices for the spectral plots, in npy-file order
# (same entries as plot_vectors' defaults).
model_names = [
    "Reference",
    "Defocuscam (Learned)",
    "Defocuscam (Fista)",
    "DiffuserCam",
    "Handshake",
]
color_idcs = [3, 0, 7, 9, 1]

In [None]:
# ICVL color checker
# File order must match plot_vectors' default labels:
# Reference, DefocusCam (learned), DefocusCam (FISTA), DiffuserCam, Handshake.
icvl_files = [icvl_color_checker_gt_name + ".npy"] + [
    os.path.join(OUTPUTS_DIRPATH, "reconstructions", name + ".npy")
    for name in (
        defocuscam_learned_icvl_savename,
        defocuscam_fista_icvl_savename,
        diffusercam_fista_icvl_savename,
        handshake_fista_icvl_savename,
    )
]
icvl_range = (420, 720)
icvl_point = (219,249)  # (y, x)

icvl_plot = plot_vectors(
    icvl_files, 
    icvl_point, 
    icvl_range, 
    savename=os.path.join(OUTPUTS_DIRPATH, "plots", os.path.basename(icvl_color_checker_gt_name) + ".png"),
    average_over_surrounding=0
)
icvl_marked_gt = draw_plot_marker(
    icvl_files[0], 
    icvl_point, 
    icvl_range, 
    savename=icvl_color_checker_gt_name + "_marked.png",
)

display(icvl_marked_gt)
# Figure.show() does not render on the non-GUI inline backend (it only warns); plt.show()
# renders the plot_vectors figure here, below the marked ground truth.
plt.show()

In [None]:
# Harvard bushes
# File order must match plot_vectors' default labels:
# Reference, DefocusCam (learned), DefocusCam (FISTA), DiffuserCam, Handshake.
harvard_bushes_files = [harvard_bushes_gt_name + ".npy"] + [
    os.path.join(OUTPUTS_DIRPATH, "reconstructions", name + ".npy")
    for name in (
        defocuscam_learned_harvard_bushes_savename,
        defocuscam_fista_harvard_bushes_savename,
        diffusercam_fista_harvard_bushes_savename,
        handshake_fista_harvard_bushes_savename,
    )
]
harvard_bushes_range = (420, 720)
harvard_bushes_point = (216,522)  # (y, x)

harvard_bushes_plot = plot_vectors(
    harvard_bushes_files, 
    harvard_bushes_point, 
    harvard_bushes_range,
    savename=os.path.join(OUTPUTS_DIRPATH, "plots", os.path.basename(harvard_bushes_gt_name) + ".png"),
    average_over_surrounding=0
)
harvard_bushes_marked_gt = draw_plot_marker(
    harvard_bushes_files[0],
    harvard_bushes_point, 
    harvard_bushes_range, 
    savename=harvard_bushes_gt_name + "_marked.png"
)

display(harvard_bushes_marked_gt)
# Figure.show() does not render on the non-GUI inline backend (it only warns); plt.show()
# renders the plot_vectors figure here, below the marked ground truth.
plt.show()

In [None]:
# Geissen artichoke
# File order must match plot_vectors' default labels:
# Reference, DefocusCam (learned), DefocusCam (FISTA), DiffuserCam, Handshake.
fruit_files = [fruit_artichoke_gt_name + ".npy"] + [
    os.path.join(OUTPUTS_DIRPATH, "reconstructions", name + ".npy")
    for name in (
        defocuscam_learned_fruit_savename,
        defocuscam_fista_fruit_savename,
        diffusercam_fista_fruit_savename,
        handshake_fista_fruit_savename,
    )
]
fruit_range = (400, 780)
fruit_point = (116,488)  # (y, x)

fruit_plot = plot_vectors(
    fruit_files, 
    fruit_point, 
    fruit_range,
    savename=os.path.join(OUTPUTS_DIRPATH, "plots", os.path.basename(fruit_artichoke_gt_name) + ".png"),
    average_over_surrounding=0
)
fruit_marked_gt = draw_plot_marker(
    fruit_files[0],
    fruit_point, 
    fruit_range, 
    savename=fruit_artichoke_gt_name + "_marked.png"
)

display(fruit_marked_gt)
# Figure.show() does not render on the non-GUI inline backend (it only warns); plt.show()
# renders the plot_vectors figure here, below the marked ground truth.
plt.show()

In [None]:
# KAIST scene
# File order must match plot_vectors' default labels:
# Reference, DefocusCam (learned), DefocusCam (FISTA), DiffuserCam, Handshake.
kaist_files = [kaist_img_gt_name + ".npy"] + [
    os.path.join(OUTPUTS_DIRPATH, "reconstructions", name + ".npy")
    for name in (
        defocuscam_learned_kaist_savename,
        defocuscam_fista_kaist_savename,
        diffusercam_fista_kaist_savename,
        handshake_fista_kaist_savename,
    )
]
kaist_range = (420, 720)
kaist_point = (201, 141)  # (y, x)

kaist_plot = plot_vectors(
    kaist_files, 
    kaist_point, 
    kaist_range,
    savename=os.path.join(OUTPUTS_DIRPATH, "plots", os.path.basename(kaist_img_gt_name) + ".png"),
    average_over_surrounding=0
)
kaist_marked_gt = draw_plot_marker(
    kaist_files[0],
    kaist_point, 
    kaist_range, 
    savename=kaist_img_gt_name + "_marked.png"
)

display(kaist_marked_gt)
# Figure.show() does not render on the non-GUI inline backend (it only warns); plt.show()
# renders the plot_vectors figure here, below the marked ground truth.
plt.show()