In [2]:
import torch
import numpy as np
import cv2
import os, sys, glob, copy, json
from scipy.interpolate import interp1d
from scipy.signal import convolve
import seaborn as sns

from matplotlib.ticker import FuncFormatter
import matplotlib.pyplot as plt
import PIL.Image as Image
import pathlib
sys.path.append("/home/cfoley_waller/defocam/SpectralDefocusCam")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


sys.path.insert(0, "../..")
import utils.helper_functions as helper
import utils.diffuser_utils as diffuser_utils
import dataset.precomp_dataset as ds
from models.get_model import get_model

%load_ext autoreload
%autoreload 2

# Plot resolution comparison

In [None]:
# Load the two saved reconstructions that the resolution comparison below reads
# (file names indicate the FISTA and the learned-model result for "usaf_negative").
res_target_fista = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/usaf_negative_fista_.npy")
res_target_learned = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_usaf_negative.npy")
# Inspect the learned reconstruction with the project's cube viewer
# (presumably interactive, per its name — behaviour lives in utils.helper_functions).
helper.plot_cube_interactive(res_target_learned)

In [None]:
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.2)
colors = sns.husl_palette(n_colors=10, l=0.5)

# Line profile through each reconstruction: average over the last axis, then take
# 25 consecutive rows starting at row 177 in column 242, and pass through value_norm.
vec_fista = helper.value_norm(np.mean(res_target_fista, axis=-1)[177:177+25, 242])
vec_learned = helper.value_norm(np.mean(res_target_learned, axis=-1)[177:177+25, 242])

plt.figure(dpi=100, figsize=(17,7))
plt.plot(vec_fista, color=colors[0], label = "FISTA (5)", linewidth = 12)
plt.plot(vec_learned, color=colors[7], label = "Learned (2)", linewidth = 12)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

# Style the frame: hide the top/right spines, draw the left/bottom ones thick and black.
ax = plt.gca()
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
for side in ("left", "bottom"):
    ax.spines[side].set_linewidth(4)
for side in ("left", "bottom"):
    ax.spines[side].set_color("black")
plt.legend(fontsize=20, loc=(0.5,0.1), framealpha=1)

# Plot Spectral Comparison

In [3]:
def read_gt_csv(file):
    """
    Read a ";"-separated two-column spectrum file into two numpy arrays.

    Lines that are blank, or whose first non-whitespace character is "[" or "#",
    are skipped. Every remaining line must consist of exactly two
    ";"-separated fields that parse as floats: the first is appended to the
    wavelength list, the second to the intensity list.

    Parameters
    ----------
    file : str or path-like
        Path of the text file to read.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Wavelengths and intensities, in file order.

    Raises
    ------
    ValueError
        If a data line does not have exactly two fields, or a field is not a
        valid float.
    """
    waves, intensity = [], []
    with open(file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            # Blank lines (for example a trailing empty line at the end of an
            # export) carry no data; without this guard they would reach the
            # split below and fail to unpack.
            if not line or line[0] in ("[", "#"):
                continue
            wave, intens = line.split(";")
            waves.append(float(wave))
            intensity.append(float(intens))
    return np.array(waves), np.array(intensity)

def read_gt_rmn(file, max_wave=935.358):
    """
    Read a JSON-formatted ".rmn" spectrum file.

    Only the first element of the top-level JSON list is used. Its
    'Intensities' entry becomes the intensity array, and the axis is an evenly
    spaced grid with one point per intensity, running from the record's
    'FirstWavenumber' value to `max_wave`.

    Parameters
    ----------
    file : str or path-like
        Path of the JSON file.
    max_wave : float
        Last value of the generated axis.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The generated axis and the intensities.
    """
    with open(file, 'r') as handle:
        record = json.load(handle)[0]

    axis_start = record['FirstWavenumber']
    spectrum = record['Intensities']

    # The file only supplies the start of the axis; the end is the caller's max_wave.
    axis = np.linspace(axis_start, max_wave, len(spectrum))
    return axis, np.array(spectrum)

def mov_avg_intensities(intensities, old_waves, new_waves, index_width):
    """
    Smooth a sampled spectrum with a moving average, then resample it.

    The intensities are convolved with a uniform window of `index_width`
    samples (each weight 1/index_width) using scipy's 'same' mode, so the
    result has the same length as `intensities`; samples near either end are
    averaged together with the implicit zeros beyond the data. The smoothed
    curve is then linearly interpolated from `old_waves` onto `new_waves`,
    extrapolating linearly outside the range of `old_waves`.

    Parameters
    ----------
    intensities : array-like
        Intensity samples measured at `old_waves`.
    old_waves : array-like
        Axis positions of `intensities`.
    new_waves : array-like
        Axis positions at which the smoothed curve is evaluated.
    index_width : int
        Width of the averaging window, in samples.

    Returns
    -------
    np.ndarray
        Smoothed intensities evaluated at `new_waves`.
    """
    # Uniform (box) window whose weights sum to one.
    box_window = np.ones(index_width) / index_width
    smoothed = convolve(intensities, box_window, mode='same')

    # Linear resampling onto the requested grid.
    resample = interp1d(old_waves, smoothed, kind='linear', fill_value="extrapolate")
    return resample(new_waves)


def combine_thor_ocean_waves(
    thorfile, 
    oceanfile = None, 
    minwave=390, 
    maxwave=870, 
    channels=30, 
    thor_smooth_idx=400, 
    ocean_smooth_idx=100, 
    combine_thor_ocean = True
):
    """
    Build one ground-truth spectrum on a regular grid from one or two spectrometer files.

    The grid has `channels` evenly spaced points from `minwave` to `maxwave`.
    The file read by read_gt_csv (`thorfile`) is divided by its maximum,
    smoothed over `thor_smooth_idx` samples and resampled onto the grid.

    If `oceanfile` is None, that curve is returned as is. Otherwise the file
    read by read_gt_rmn is processed the same way (smoothing width
    `ocean_smooth_idx`) and the two curves are stitched: grid points below the
    hard-coded value 500 come from the first file, the rest from the second.
    The first two grid points at or above 500 are replaced by the mean of both
    curves to soften the step at the seam.

    `combine_thor_ocean` is accepted but never read by this function.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The grid and the resampled (and possibly stitched) intensities.
    """
    ocean_start_wave = 500
    grid = np.linspace(minwave, maxwave, channels)

    # First spectrometer: normalise to peak, smooth, resample.
    thor_waves, thor_raw = read_gt_csv(thorfile)
    thor_curve = mov_avg_intensities(thor_raw / np.max(thor_raw), thor_waves, grid, thor_smooth_idx)

    if oceanfile is None:
        return grid, thor_curve

    # Second spectrometer: same treatment with its own smoothing width.
    ocean_waves, ocean_raw = read_gt_rmn(oceanfile)
    ocean_curve = mov_avg_intensities(ocean_raw / np.max(ocean_raw), ocean_waves, grid, ocean_smooth_idx)

    # Index of the first grid point that is >= ocean_start_wave.
    seam_idx = np.searchsorted(grid, ocean_start_wave)
    stitched = np.concatenate((thor_curve[:seam_idx], ocean_curve[seam_idx:]))

    # Average both sources over the two grid points at the seam.
    seam = slice(seam_idx, seam_idx + 2)
    stitched[seam] = np.mean(np.stack((thor_curve[seam], ocean_curve[seam])), axis=0)

    return grid, stitched

def draw_plot_marker(image, point, radius):
    """
    Return a copy of `image` with a filled white disc drawn on it.

    `point[0]` is used as the vertical (y) coordinate and `point[1]` as the
    horizontal (x) coordinate; they are swapped here because cv2.circle takes
    its centre as (x, y). The input image is not modified.
    """
    marked = np.copy(image)
    center_xy = (int(point[1]), int(point[0]))
    # thickness -1 asks OpenCV for a filled circle
    cv2.circle(marked, center_xy, radius, (255, 255, 255), -1)
    return marked
    
def plot_vectors(vectors: list, model_names: list, spectral_range : tuple, colormaps=(3, 0, 7), legend=False):
    """
    Plot several spectra on one figure, each divided by its own maximum.

    Parameters
    ----------
    vectors : list
        One 1-D array per curve. All curves are drawn against an axis built
        from the length of the first one, so they need the same length.
    model_names : list
        Legend labels. Curve ``i`` uses label ``i % len(model_names)``.
    spectral_range : tuple
        ``(start, stop)`` of the x axis; it is sampled with np.linspace.
    colormaps : sequence of int
        Indices into the 10-colour seaborn husl palette created here, one per
        label. If there are more labels than entries, the entries are reused
        cyclically.
    legend : bool
        Draw a legend when True.

    Raises
    ------
    ValueError
        If `vectors`, `model_names` or `colormaps` is empty.

    Notes
    -----
    Calls sns.set_style / sns.set_context, which change global plotting style,
    and ends with plt.show().
    """
    if len(vectors) == 0 or len(model_names) == 0 or len(colormaps) == 0:
        raise ValueError("plot_vectors needs at least one vector, one model name and one colour index")

    linestyles = ["-", "--", "-.", ":", (0, (10, 3))]
    n = len(model_names)

    colors = sns.husl_palette(n_colors=10, l=0.5)
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.2)

    # Shared x axis, sized from the first curve.
    wavs = np.linspace(spectral_range[0], spectral_range[1], len(vectors[0]))

    plt.figure(dpi=100, figsize=(17,7))
    ax = plt.gca()
    for i, d in enumerate(vectors):
        slot = i % n
        # Wrap the style lookups by their own lengths so that having more labels
        # than colours/line styles reuses entries instead of raising IndexError.
        color = colors[colormaps[slot % len(colormaps)]]
        linestyle = linestyles[slot % len(linestyles)]
        ax.plot(wavs, d / np.max(d), color=color, label=model_names[slot], linewidth=12, linestyle=linestyle)

    def format_y_tick(value, pos):
        # Scientific notation with one decimal, e.g. 5.0e-01
        return '{:.1e}'.format(value)

    ax.yaxis.set_major_formatter(FuncFormatter(format_y_tick))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(4)
        ax.spines[side].set_color("black")

    if legend:
        plt.legend(fontsize=50, framealpha=1)
    plt.show()


### Color palette

In [None]:
# Interactive cell
# White-balance the learned-model colour-palette reconstruction and open it in the cube viewer.
reference = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_color_palette.npy")
# Mean over a 6x6 pixel patch (rows 288-293, cols 460-465), one value per index of the
# last axis, floored at 0.4 so the division below never uses a smaller divisor.
white = np.maximum(0.4, np.mean(reference[291-3:291+3, 463-3:463+3, :], axis=(0,1)))
white_balanced = reference / white
helper.plot_cube_interactive(helper.value_norm(white_balanced), fc_range=(390,870), fc_scaling=(1.1,0.8, 0.85))

In [None]:

# Same white-balance view as the cell above, but for the FISTA reconstruction file.
# Note: this rebinds `reference`, `white` and `white_balanced`.
reference = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/color_palette_fista_.npy")
# Mean over the 6x6 patch at rows 288-293, cols 460-465, floored at 0.4.
white = np.maximum(0.4, np.mean(reference[291-3:291+3, 463-3:463+3, :], axis=(0,1)))
white_balanced = reference / white
helper.plot_cube_interactive(helper.value_norm(white_balanced), fc_range=(390,870), fc_scaling=(1.1,0.8, 0.85))

In [None]:
import tqdm


# Load and plot the combined ground-truth spectrum of every colour-palette square.
# `vectors` collects each spectrum without its last four samples; `names` the square number.
vectors, names = [], []
for i in tqdm.tqdm(range(1, 25)):
    if i in [20, 21, 23]:
        continue # thorlabs data missing
    thor_data = f"/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/color_palette/{i}.csv"
    ocean_data = f"/home/cfoley/defocuscamdata/recons/spectrometer_gts_ocean_optics_hr2000/color_palette/{i}.rmn"
    waves, intensity = combine_thor_ocean_waves(thor_data, ocean_data, 370, 890, 30)

    vectors.append(intensity[:-4])
    names.append(i)

    # Create the figure only for squares that are actually plotted, so the
    # skipped ones do not leave empty figures open.
    plt.figure(figsize = (8, 3))
    plt.plot(waves, intensity, label=str(i))
    plt.legend()
    plt.show()

### FISTA comparison figure

In [27]:
# Pixel coordinates of the sampled spots, keyed by "<scene>_<label>".
# Each value is (row, col): later cells unpack it as (py, px) and index arrays as [py, px].
# Two entries are empty tuples (no coordinate recorded); unpacking one of those
# into (py, px) would raise, so they can only be used once filled in.
point_dict = {
    "palette_0": (97, 304),
    "palette_1": (255,205),
    "palette_2": (259,289),
    #"palette_2": (184, 193), # square 8
    "stars_red":(),
    "stars_green":(351, 456),
    "stars_blue":(247, 278),
    "cards_blue": (129, 98),
    "cards_yellow": (277,263),
    "cards_red": (41,463),
    "cards_orange": (138, 304),
    "cards_green": (229, 162),
    "cards_bias": (381, 41),
    "stars_bias": (70, 533),
    "mushroom_knife_green_hackysack":(),
    "mushroom_knife_mushroom_red":(209,493),
    "mushroom_knife_bias":(375,326),
}

In [None]:
# PLOTS
# For every "palette_*" point: plot the reference spectrum against the learned
# and FISTA reconstructions sampled at that point.
learned_recon = "/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_color_palette.npy"
fista_recon = "/home/cfoley/defocuscamdata/recons/exp_results_figure/color_palette_fista_.npy"
gt_squares = [3,14,8]

# Read each reconstruction from disk once instead of once per plotted point.
palette_learned_cube = np.load(learned_recon)
palette_fista_cube = np.load(fista_recon)

for name, point in point_dict.items():
    if "palette" not in name:
        continue

    # The trailing digit of the key ("palette_0" -> 0) selects the entry of gt_squares.
    (py, px), idx = point, int(name[-1])
    fname = f"color_palette/{gt_squares[idx]}"

    thor_data = f"/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/{fname}.csv"
    ocean_data = f"/home/cfoley/defocuscamdata/recons/spectrometer_gts_ocean_optics_hr2000/{fname}.rmn"

    waves, intensity = combine_thor_ocean_waves(thor_data, ocean_data, 370, 890, 30)

    # Mean spectrum of the 6x6 pixel patch around the point; the last four
    # samples are dropped from every curve.
    learned_vec = np.mean(palette_learned_cube[py-3:py+3, px-3:px+3], axis=(0,1))[:-4]
    fista_vec = np.mean(palette_fista_cube[py-3:py+3, px-3:px+3], axis=(0,1))[:-4]
    gt_vec = intensity[:-4]

    vectors = [gt_vec, helper.value_norm(learned_vec), fista_vec]
    names = ["Reference", "Learned (5)", "FISTA (2)"]

    plot_vectors(vectors, names, (390, 870 - 16*4))

In [21]:
#REFERENCE IMAGE
# Build false-colour images of both reconstructions and mark every "palette_*" point.
# Relies on `learned_recon` / `fista_recon` set by an earlier cell.
marker_learned_cube = np.load(learned_recon)  # read once, used for both the bias pixel and the image
bias_vec = marker_learned_cube[400,524]
learned_im = helper.value_norm(helper.select_and_average_bands(marker_learned_cube - bias_vec, fc_range=(390,870), scaling=(1.1,0.8,0.8)))*255 # (1.1,0.8,0.65)
fista_im = helper.value_norm(helper.select_and_average_bands(np.load(fista_recon), fc_range=(390,870), scaling=(1.1,0.6,0.65)))*255

for name, point in point_dict.items():
    if "palette" not in name:
        continue

    # draw_plot_marker returns a marked copy, so the images are rebound each time
    learned_im = draw_plot_marker(learned_im, point, 5)
    fista_im = draw_plot_marker(fista_im, point, 5)

In [None]:
# Show the marked learned-reconstruction image; cast to uint8 before handing it to PIL.
Image.fromarray(learned_im.astype(np.uint8))

In [None]:
# Show the marked FISTA image; cast to uint8 before handing it to PIL.
Image.fromarray(fista_im.astype(np.uint8))

### Results figure

In [24]:
# Inputs for the single-point comparison in the next cell.
fname = "stars/green"
recon_name = "origami_stars_colorful"
# (row, col) of the patch used as bias and of the sampled point, from point_dict.
pybias, pxbias = point_dict["stars_bias"]
py, px = point_dict['stars_green']

thor_data = f"/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/{fname}.csv"
ocean_data = f"/home/cfoley/defocuscamdata/recons/spectrometer_gts_ocean_optics_hr2000/{fname}.rmn"
# NOTE: rebinds `learned_recon`, which an earlier cell set to the colour-palette file;
# cells that read `learned_recon` therefore depend on which of the two ran last.
learned_recon = f"/home/cfoley/defocuscamdata/recons/exp_results_figure/{recon_name}.npy"

waves, intensity = combine_thor_ocean_waves(thor_data, ocean_data, 370, 890, 30)

In [None]:
# Half-widths (in pixels) of the sampled patch and of the bias patch.
v_width, b_width = 3, 5

# Read the reconstruction once; both patches below come from the same array.
scene_cube = np.load(learned_recon)
learned_vec = np.mean(scene_cube[py-v_width:py+v_width, px-v_width:px+v_width], axis=(0,1))[:-4]
bias = np.mean(scene_cube[pybias-b_width:pybias+b_width, pxbias-b_width:pxbias+b_width],axis=(0,1))[:-4]
# Subtract the mean spectrum of the bias patch from the sampled spectrum.
learned_vec -= bias
gt_vec = intensity[:-4]

vectors = [gt_vec, helper.value_norm(learned_vec)]
names = ["Reference", "Learned (2)"]

plot_vectors(vectors, names, (390, 870 - 16*4))

In [None]:
# `palette` is otherwise only assigned by the white-balance cell that follows, so on a
# fresh kernel this cell raised NameError. Load the same file here so it runs in order.
palette = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_color_palette.npy")
helper.plot_cube_interactive(palette)

In [None]:
#White balance calibration using experimental color palette
palette = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_color_palette.npy")
# Mean spectrum of the 6x6 pixel patch at rows 397-402, cols 521-526.
bias_vec = np.mean(palette[400-3:400+3,524-3:524+3], axis=(0,1))

plt.figure(figsize=(8,4), dpi=100)
plt.plot(bias_vec, linewidth=6, color = "red")
def format_y_tick(value, pos):
    """Tick label in scientific notation with one decimal place."""
    return f"{value:.1e}"
ax = plt.gca()
ax.yaxis.set_major_formatter(FuncFormatter(format_y_tick))
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
for side in ("left", "bottom"):
    ax.spines[side].set_linewidth(4)
for side in ("left", "bottom"):
    ax.spines[side].set_color("black")
plt.show()

# False-colour rendering of the bias-subtracted cube, scaled to 0-255 and shown as a PIL image.
palette_rgb = helper.value_norm(helper.select_and_average_bands(palette - bias_vec, fc_range=(390,870), scaling=(1.1, 0.8, 0.8)))*255
palette_im = Image.fromarray(palette_rgb.astype(np.uint8))
palette_im

In [108]:
# Parallel lists for the results figure: entry i of each list belongs to the same
# plotted point (the cells below index all of them with the same i).

# Reference spectrum file for each point.
spectral_gt_names = [
    "/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/mushroom_knife/mushroom_red.csv",
    "/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/outside/six_brownwood_sun.csv",
    "/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/outside/six_red_sun.csv",
    "/home/cfoley/defocuscamdata/recons/spectrometer_gts_thorlabs_css/outside/yellowumbrella_shade.csv"
]
# Reconstruction file for each point (entries 1 and 2 are the same file).
reconstruction_files = [
    "/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_mushroom_knife.npy",
    "/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_outside_six.npy",
    "/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_outside_six.npy",
    "/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_outside_eight2.npy",
]
# (row, col) of each sampled point; later unpacked as (py, px).
image_points = [(221, 467), (304, 281), (220, 412), (314, 448)]
# Per-image `scaling` argument passed to helper.select_and_average_bands.
fc_scalings = [
    (1.1, 0.7, 0.85),
    (1.1, 0.5, 0.8),
    (1.1, 0.5, 0.8),
    (0.85, 0.95, 1.45),
]

In [None]:
#Get false color images with markers
# Uses `bias_vec` from an earlier cell; one marked uint8 image is produced per reconstruction.
images = []
for i, _ in enumerate(reconstruction_files):
    debiased_cube = np.load(reconstruction_files[i]) - bias_vec
    false_color = helper.value_norm(
        helper.select_and_average_bands(debiased_cube, fc_range=(390,870), scaling=fc_scalings[i])
    )
    ref_fc = (false_color*255).astype(np.uint8)
    ref_fc = draw_plot_marker(ref_fc, image_points[i], 5)
    images.append(Image.fromarray(ref_fc))

    # Quick on-screen preview of the marked image.
    plt.figure(dpi=70)
    plt.imshow(ref_fc)
    plt.show()

In [None]:
# Display the first marked false-colour image built in the cell above.
images[0]

In [None]:
# Generate plots
# One reference-vs-learned plot per entry of the parallel lists defined above.
# The former lookup of `bias_points[i]` is gone: that name is not defined in this
# notebook and the values unpacked from it were never used.
# `bias_vec` comes from an earlier cell; it is assigned both in the reference-image
# cell and in the white-balance calibration cell, so the last one run wins.

for gt_file, (py, px), recon_file in zip(spectral_gt_names, image_points, reconstruction_files):
    # Reference spectrum from the single csv file (no second spectrometer file).
    _, gt_vec = combine_thor_ocean_waves(gt_file, None, 370, 890, 30)
    # Mean spectrum of the 6x6 pixel patch around the point.
    recon_vec = np.mean(np.load(recon_file)[py-3:py+3, px-3:px+3], axis=(0,1))

    # Drop the last four samples of every curve; subtract 60% of the bias spectrum.
    plot_vectors([gt_vec[:-4], helper.value_norm(recon_vec[:-4] - bias_vec[:-4]*0.6)], ["Reference", "Learned (3)"], (390, 870 - 16*4))

### Intro Figure

In [None]:
# Load the reconstruction used for the intro figure.
ref = np.load("/home/cfoley/defocuscamdata/recons/exp_results_figure/saved_model_ep60_testloss_0.053416458687380604_outside_nine2.npy")
# Spectrum of the single pixel at row 364, col 323, subtracted from the whole cube for viewing.
# Note: rebinds `bias`, which an earlier cell also assigns.
bias = ref[364,323]
helper.plot_cube_interactive(ref-bias)

In [None]:
# Two (row, col) pixels of `ref` whose spectra are plotted, with their legend labels.
points = [(80, 53),(227,554)]
names = ["1.", "2."]
vectors = []
for (py, px) in points:
    # single-pixel spectrum without its last four samples; `ref` is used here
    # without the bias subtraction applied in the viewer cell above
    vectors.append(ref[py, px,:-4])

plot_vectors(vectors, names, (390, 870 - 16*4), colormaps=[0,7], legend=False)

In [None]:
# False-colour image of `ref`, scaled to 0-255, with a marker drawn at each entry of `points`.
ref_fc = helper.value_norm(helper.select_and_average_bands(ref, fc_range=(390,870), scaling=(1,0.8,0.75)))*255
for (py, px) in points:
    print(py, px)
    # draw_plot_marker returns a marked copy, so rebind the image each time
    ref_fc = draw_plot_marker(ref_fc, (py, px), 5)

# Cast to uint8 for PIL; this last expression is what the cell displays.
Image.fromarray(ref_fc.astype(np.uint8))