In [None]:
import torch
import numpy as np
import os, sys, glob, copy, json
from scipy.interpolate import interp1d
from scipy.signal import convolve

import matplotlib.pyplot as plt
import PIL.Image as Image
import pathlib
sys.path.insert(0, "/home/cfoley_waller/defocam/SpectralDefocusCam")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


sys.path.insert(0, "../..")
import train
import utils.helper_functions as helper
import utils.diffuser_utils as diffuser_utils
import dataset.precomp_dataset as ds

%load_ext autoreload
%autoreload 2

%matplotlib inline

# Load model

In [3]:
def load_model(trained_weights_path, device):
    """Rebuild a trained reconstruction model from a checkpoint, ready for inference.

    ``training_config.yml`` is read from the checkpoint's directory and patched
    before being passed to ``train.get_model``:

    * ``device`` is set to the requested device;
    * ``adj_mask_noise`` / ``fwd_mask_noise`` forward-model operations are disabled;
    * ``data_precomputed`` is False, ``preload_weights`` is True and
      ``checkpoint_dir`` points at ``trained_weights_path``.

    Args:
        trained_weights_path: path to the saved checkpoint file.
        device: device string such as ``"cuda:1"``.

    Returns:
        ``(model, config)``: the model switched to eval mode and the patched config.
    """
    run_dir = pathlib.Path(trained_weights_path).parent
    config = helper.read_config(os.path.join(run_dir, "training_config.yml"))

    config["device"] = device
    ops = config["forward_model_params"]["operations"]
    for noise_flag in ("adj_mask_noise", "fwd_mask_noise"):
        ops[noise_flag] = False
    overrides = (
        ("data_precomputed", False),
        ("preload_weights", True),
        ("checkpoint_dir", trained_weights_path),
    )
    for key, value in overrides:
        config[key] = value

    torch_device = torch.device(device)
    model = train.get_model(config=config, device=torch_device)
    model.eval()

    print(f"Model using: {torch_device}")
    return model, config

In [None]:
# Primary reconstruction model; model_stack_depth is the number of defocus
# measurements get_plot_meas selects for it by default.
model_stack_depth = 3
device = 'cuda:1'
trained_weights_path = "/home/cfoley/defocuscamdata/models/checkpoint_results_learned_largecrop_firstlast_3_config.yml/2024_03_21_21_45_46/saved_model_ep60_testloss_0.053416458687380604.pt"
model, config = load_model(trained_weights_path, device=device)

# Experimental predictions: indoor scenes

In [66]:
def preprocess_exp_meas(meas, config):
    """Load one raw measurement image, crop it around the configured centre and downsample it.

    Args:
        meas: path of the image file to read with PIL (callers pass ``.bmp`` captures).
        config: mapping providing
            ``"image_center"``: (row, col) of the crop centre in raw-image pixels,
            ``"patch_crop"``: (height, width) of the crop window,
            ``"patch_size"``: forwarded to ``diffuser_utils.pyramid_down``.

    Returns:
        The result of ``diffuser_utils.pyramid_down`` applied to the float64 crop.

    Raises:
        ValueError: if the crop window does not fit inside the image.
    """
    center = config["image_center"]
    dimy, dimx = config["patch_crop"]

    # read
    img = np.array(Image.open(meas), dtype=float)

    # crop: start at center - dim//2 and take exactly `dim` pixels. Ending the
    # slice at center + dim//2 instead would drop one row/column for odd sizes.
    top = center[0] - dimy // 2
    left = center[1] - dimx // 2
    bottom, right = top + dimy, left + dimx
    if top < 0 or left < 0 or bottom > img.shape[0] or right > img.shape[1]:
        # A negative start would silently wrap around and an overshooting end
        # would silently truncate the crop, so fail loudly instead.
        raise ValueError(
            f"Crop {dimy}x{dimx} centred at {tuple(center)} does not fit in image of shape {img.shape}"
        )
    img = img[top:bottom, left:right]

    # downsample
    return diffuser_utils.pyramid_down(img, config["patch_size"])

def get_plot_meas(exp_meas_path, config, stack_depth = model_stack_depth):
    """Load, preprocess, defocus-subsample and display a stack of experimental measurements.

    The ``*.bmp`` files in ``exp_meas_path`` are taken in sorted filename order
    and subsampled according to ``stack_depth``:

    * 2 -> every 4th file (``[0::4]``)
    * 3 -> every 2nd file (``[0::2]``)
    * 5 -> all files

    Only the selected files are passed through ``preprocess_exp_meas``. The
    selection is then shown side by side, titled by their position in the stack.

    NOTE(review): the subsampling yields exactly ``stack_depth`` images only when
    the folder holds 5 captures; confirm this for new datasets.

    Args:
        exp_meas_path: folder containing the ``.bmp`` captures.
        config: config mapping forwarded to ``preprocess_exp_meas``.
        stack_depth: number of measurements the model expects (2, 3 or 5).

    Returns:
        List of preprocessed measurements, one per selected focus level.

    Raises:
        NotImplementedError: if ``stack_depth`` is not 2, 3 or 5.
        FileNotFoundError: if ``exp_meas_path`` contains no ``.bmp`` files.
    """
    defocus_sampling = {2: slice(0, None, 4), 3: slice(0, None, 2), 5: slice(None)}
    # Validate before doing any (expensive) image loading.
    if stack_depth not in defocus_sampling:
        raise NotImplementedError(f"Only 2, 3, and 5 meas models supported: {stack_depth}")

    meas_files = sorted(glob.glob(os.path.join(exp_meas_path, "*.bmp")))
    if not meas_files:
        raise FileNotFoundError(f"No .bmp measurements found in {exp_meas_path}")

    # Preprocessing is per-file, so select first and only process what is used.
    exp_meas = [preprocess_exp_meas(m, config) for m in meas_files[defocus_sampling[stack_depth]]]

    # squeeze=False keeps a 2-D Axes array even when only one image is selected.
    fig, axes = plt.subplots(1, len(exp_meas), figsize=(4 * len(exp_meas), 4), squeeze=False)
    fig.set_dpi(180)
    for level, (ax, meas) in enumerate(zip(axes[0], exp_meas)):
        ax.imshow(meas, cmap='gray')
        ax.set_title(f"Focus level: {level}")
        ax.axis('off')
    plt.show()

    return exp_meas

def predict(exp_meas, recon_model):
    """Run a reconstruction model on a stack of preprocessed measurements.

    The measurements are stacked into a tensor of shape ``(1, N, 1, *meas.shape)``
    (``N = len(exp_meas)``), moved to the notebook-level ``device``, normalized with
    ``ds.Normalize(0, 1)`` and fed to ``recon_model``. Summary statistics of the
    input are printed.

    Args:
        exp_meas: sequence of equally shaped 2-D measurement arrays.
        recon_model: callable reconstruction network, e.g. ``model.model2``.

    Returns:
        The first batch element of the model output, moved to CPU, transposed with
        axes ``(1, 2, 0)`` (presumably channel-first -> channel-last; confirm) and
        rescaled with ``helper.value_norm``.
    """
    norm = ds.Normalize(0, 1)
    # Inference only: disabling autograd avoids keeping intermediate activations
    # alive for a backward pass that never happens.
    with torch.no_grad():
        exp_meas_stack = norm(torch.tensor(np.stack(exp_meas)).to(device)[None, :, None, ...])
        print(f"mean {exp_meas_stack.mean()} - std {exp_meas_stack.std()} - shape {exp_meas_stack.shape}")

        pred = recon_model(exp_meas_stack)[0].detach().cpu().numpy().transpose(1, 2, 0)
    return helper.value_norm(pred)

In [None]:
# Reconstruct the "usaf_negative" capture with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/03_07/exp_meas/usaf_negative"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1.1, 0.9, 0.7)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct the "color_palette" capture with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/03_07/exp_meas/color_palette"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1.1, 0.8, 0.65)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct the "mushroom_knife" capture with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/03_07/exp_meas/mushroom_knife"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [121]:
# Per-band white reference: mean over the 6x6 patch p[288:294, 460:466]
# (centred on row 291, col 463), floored at 0.4 (presumably so a dim reference
# does not over-amplify the division; confirm).
white = np.maximum(0.4, np.mean(p[288:294, 460:466], axis=(0, 1)))
white_balanced = p / white

In [None]:
scaling = (1.1, 0.8, 0.85)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(white_balanced, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the white-balanced reconstruction (.npy) and its rendered preview (.png)
# under a shared "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), white_balanced)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct the "duckincar" capture (11_21 session) with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/11_21/exp_meas/duckincar"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [58]:
# Clamp the reconstruction to [0, 99th percentile of p].
# NOTE(review): this overwrites `p`, so re-running this cell clips again at the
# quantile of the already-clipped data. Re-run the prediction cell first.
p = np.clip(p, 0, np.quantile(p, 0.99))

In [None]:
scaling = (1.15, 0.9, 0.8)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct the "rubberband_cards" capture with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/03_07/exp_meas/rubberband_cards"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1.2, 0.9, 0.75)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct the "origami_stars_colorful" capture with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/03_07/exp_meas/origami_stars_colorful"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [63]:
# Display copy of the reconstruction clamped to [0, 99.99th percentile of p];
# `p` itself is left untouched.
clipped = np.clip(p, 0, np.quantile(p, 0.9999))

In [None]:
scaling = (1.2, 0.9, 0.7)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(clipped, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# The preview is rendered from `clipped`, while the unclipped `p` is saved. Both
# share a "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

# Outdoor scenes: success cases

In [None]:
# Second checkpoint (the "2meas" model), loaded onto the same device as the primary model.
trained_weights_path_2meas = "/home/cfoley/defocuscamdata/models/checkpoint_results_learned_largecrop_config.yml/2024_03_16_21_09_51/saved_model_ep49_testloss_0.05782177185882693.pt"
model_2meas, config_2meas = load_model(trained_weights_path_2meas, device=device)

In [None]:
# Reconstruct "outside_one" with the second ("2meas") model.
# NOTE(review): the variable names say "2meas", yet all 5 focus levels are fed
# here (stack_depth=5). Confirm which depth this checkpoint was trained with.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_one"
p = predict(get_plot_meas(exp_meas_path, config_2meas, stack_depth=5), recon_model=model_2meas.model2)

In [None]:
scaling = (1.3,0.85,1)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<2meas checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path_2meas).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_six" with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_six"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1.1, 0.75, 0.75)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_eight2" with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_eight2"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1, 0.9, 0.95)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_nine2" with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_nine2"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1, 0.8, 0.75)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

# Failure cases - unstable camera position or movement in the scene

In [None]:
# Reconstruct "outside_ten" with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_ten"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1, 0.8, 0.75)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_seven" with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_seven"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [None]:
scaling = (1, 0.8, 0.75)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# Save the reconstruction (.npy) and its rendered preview (.png) under a shared
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_eleven" with the primary model.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_eleven"
p = predict(get_plot_meas(exp_meas_path, config=config), recon_model=model.model2)

In [105]:
# Spectrum of the single reference pixel at row 371, col 231, used as an
# offset when rendering this scene.
# NOTE(review): if `p` is a numpy array this is a view into `p`, not a copy.
bias = p[371, 231, :]

In [None]:
scaling = (1, 0.75, 0.9)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p - bias, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# The preview is rendered from `p - bias`, while the un-offset `p` is saved. Both
# share a "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_three" with the second ("2meas") model.
# NOTE(review): stack_depth falls back to model_stack_depth (3) here, while the
# variable names say "2meas". Confirm the depth this checkpoint expects.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_three"
p = predict(get_plot_meas(exp_meas_path, config=config_2meas), recon_model=model_2meas.model2)

In [None]:
scaling = (1, 0.7, 0.7)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# This scene is reconstructed with model_2meas, so the outputs are labelled with
# that checkpoint's name (not the primary model's). Both files share a
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path_2meas).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im

In [None]:
# Reconstruct "outside_two" with the second ("2meas") model.
# NOTE(review): stack_depth falls back to model_stack_depth (3) here, while the
# variable names say "2meas". Confirm the depth this checkpoint expects.
exp_meas_path = "/home/cfoley/defocuscamdata/calibration_data/DMM_37UX178_ML_calib_data/3_19/outside_two"
p = predict(get_plot_meas(exp_meas_path, config=config_2meas), recon_model=model_2meas.model2)

In [None]:
scaling = (1, 0.7, 0.7)
im = Image.fromarray((helper.value_norm(helper.select_and_average_bands(p, fc_range=(390,870), scaling=scaling))*255).astype(np.uint8))

# This scene is reconstructed with model_2meas, so the outputs are labelled with
# that checkpoint's name (not the primary model's). Both files share a
# "<checkpoint stem>_<scene folder>" name; create the output folder if missing.
out_dir = "/home/cfoley/defocuscamdata/recons/exp_results_figure/"
out_stem = "_".join([pathlib.Path(trained_weights_path_2meas).stem, os.path.basename(exp_meas_path)])
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, out_stem), p)
im.save(os.path.join(out_dir, f"{out_stem}_largecrop_scaling-{scaling}.png"))

im