In [1]:
import numpy as np, torch, scipy.io as io, glob, os, OpenEXR, matplotlib.pyplot as plt, h5py, PIL.Image as Image, pathlib
import sys
sys.path.insert(0, "../..")
# Make CUDA number GPUs by PCI bus ID so that the index in `device` refers to a stable physical card.
# This has to be set before CUDA is first initialized in this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = "cuda:1"

import utils.helper_functions as helper
import utils.diffuser_utils as diffuser_utils
import dataset.preprocess_data as prep_data
import train

%load_ext autoreload
%autoreload 2

%matplotlib inline

# Directory that the ground-truth file stems (the `*_gt_name` variables) are built under.
SAVE_GT_PATH = "/home/cfoley/defocuscamdata/recons/sim_comparison_figure/model_input_gts"

# DATA

In [2]:
# Source files for the four example scenes. The paths are kept under their own `*_path` names
# (instead of being overwritten by the loaded arrays) so that this cell can be re-run safely.
#harvard_bookshelf = "/home/cfoley/defocuscamdata/recons/sim_comparison_figure/sample_data_preprocessed/imgh2_patch_0.mat"
harvard_bushes_path = "/home/cfoley/defocuscamdata/recons/sim_comparison_figure/sample_data_preprocessed/imgf8_patch_0.mat"
kaist_img_path = "/home/cfoley/defocuscamdata/sample_data/kaistdata/scene03_reflectance.mat"
fruit_artichoke_path = "/home/cfoley/defocuscamdata/sample_data/fruitdata/pca/internals_artichoke_SegmentedCroppedCompressed.mat"
icvl_color_checker_path = "/home/cfoley/defocuscamdata/sample_data/icvldata/IDS_COLORCHECK_1020-1223.mat"

# ICVL: read the 'rad' dataset fully into memory, then close the HDF5 file before further processing.
with h5py.File(icvl_color_checker_path, "r") as icvl_file:
    icvl_rad = np.asarray(icvl_file['rad'])
# Move the first axis last, flip both spatial axes, project (second argument 30 -- presumably the
# number of output bands, confirm against prep_data.project_spectral) and crop a 520 x 620 window.
icvl_color_checker = prep_data.project_spectral(icvl_rad.transpose(1,2,0)[::-1, ::-1], 30)[300:820, 200:820]

# KAIST: crop a (420*5) x (620*5) window from the 'ref' array before projecting.
kaist_img = prep_data.project_spectral(io.loadmat(kaist_img_path)['ref'][300:300+420*5, 200:200+620*5], 30)

# Harvard: the preprocessed patch is used as stored under 'image'.
harvard_bushes = io.loadmat(harvard_bushes_path)['image']

# Fruit: decode the compressed .mat contents, project, then swap the two spatial axes.
fruit_artichoke = prep_data.project_spectral(prep_data.read_compressed(io.loadmat(fruit_artichoke_path)), 30).transpose(1,0,2)

In [3]:
def show_fc(img, fc_range=(420,720)):
    """Show a false-color preview of a spectral cube in a new 100-dpi figure.

    Args:
        img: Spectral cube passed straight to ``helper.select_and_average_bands``.
        fc_range: Range forwarded as ``fc_range`` to that helper
            (presumably wavelength bounds in nm -- confirm against the helper).

    Returns:
        None. The figure is displayed as a side effect.
    """
    plt.figure(dpi=100)
    false_color = helper.select_and_average_bands(img, fc_range=fc_range)
    # Scale by the global maximum so the brightest value maps to 1.0 for imshow.
    peak = np.max(false_color)
    plt.imshow(false_color / peak)
    plt.show()

In [None]:
# Preview each loaded scene. The calls are separate statements (not one tuple expression) so the
# cell does not echo a meaningless `(None, None, None, None)` after the figures.
show_fc(icvl_color_checker)
show_fc(fruit_artichoke, fc_range=(400,780))
show_fc(harvard_bushes)
show_fc(kaist_img)

In [5]:
def prep_image(image, crop_shape, patch_shape):
    """Crop, downsample and rescale a spectral cube into a 5-D torch tensor.

    Each band ``image[..., i]`` is cropped to the top-left ``crop_shape`` window and passed to
    ``diffuser_utils.pyramid_down`` together with ``patch_shape``; the bands are then stacked
    along a new leading axis.

    Normalization: the minimum is subtracted only when it is positive (a negative minimum is
    left untouched), and the result is divided by its maximum. The output is therefore
    guaranteed to peak at 1 but not guaranteed to be non-negative.

    Args:
        image: Array indexed as ``[row, col, band]``.
        crop_shape: ``(rows, cols)`` of the top-left window to keep.
        patch_shape: Target shape handed to ``diffuser_utils.pyramid_down``.

    Returns:
        ``torch.Tensor`` with two leading singleton axes followed by the band-first cube.
    """
    rows, cols = crop_shape[0], crop_shape[1]

    bands = []
    for band_idx in range(image.shape[-1]):
        cropped_band = image[:rows, :cols, band_idx]
        bands.append(diffuser_utils.pyramid_down(cropped_band, patch_shape))
    cube = np.stack(bands, 0)

    floor = max(0., np.min(cube))
    cube = cube - floor
    cube = cube / np.max(cube)
    return torch.tensor(cube)[None, None,...]

def save_image_fc_npy(image, savename, fc_range=(420,720)):
    """Save a spectral cube as ``<savename>.npy`` plus a false-color ``<savename>.png`` preview.

    Args:
        image: Spectral cube; written unchanged with ``np.save`` and also passed to
            ``helper.select_and_average_bands`` to build the preview.
        savename: Output path without extension. Its parent directory is created if missing.
        fc_range: Range forwarded as ``fc_range`` to ``helper.select_and_average_bands``.

    Returns:
        The false-color preview as a ``PIL.Image.Image``.
    """
    # Saving into a directory that does not exist yet would otherwise raise in np.save.
    out_dir = os.path.dirname(savename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print("Saving: ", savename + ".npy")
    np.save(savename + ".npy", image)

    print("Saving fc: ", savename + ".png")
    fc_img = helper.select_and_average_bands(image, fc_range=fc_range)

    # Scale by the peak (skipped for an all-non-positive image to avoid dividing by zero), then
    # clamp to [0, 1]: casting negative floats to uint8 is not well defined and would otherwise
    # wrap around into bright pixels. Non-negative inputs produce the same preview as before.
    peak = fc_img.max()
    if peak > 0:
        fc_img = fc_img / peak
    fc_img = np.clip(fc_img, 0.0, 1.0)

    fc_img = Image.fromarray((fc_img * 255).astype(np.uint8))
    fc_img.save(savename + ".png")
    return fc_img

In [6]:
# Ground truths: for every scene build the output file stem under SAVE_GT_PATH and the
# normalized model-input tensor. Each scene is prepared from its full spatial extent
# (crop_shape == its own shape) with (420,620) passed as the patch shape.
harvard_bushes_gt_name = os.path.join(SAVE_GT_PATH, "harvard_bushes_gt")
harvard_bushes_gt = prep_image(harvard_bushes, harvard_bushes.shape[:2], (420,620))

fruit_artichoke_gt_name = os.path.join(SAVE_GT_PATH, "fruit_artichoke_gt")
fruit_artichoke_gt = prep_image(fruit_artichoke, fruit_artichoke.shape[:2], (420,620))

icvl_color_checker_gt_name = os.path.join(SAVE_GT_PATH, "icvl_color_checker_gt")
icvl_color_checker_gt = prep_image(icvl_color_checker, icvl_color_checker.shape[:2], (420,620))

kaist_img_gt_name = os.path.join(SAVE_GT_PATH, "kaist_scene03_gt")
kaist_img_gt = prep_image(kaist_img, kaist_img.shape[:2], (420,620))


In [None]:
# Drop the two leading singleton axes added by prep_image and move the band axis last before saving.
save_image_fc_npy(harvard_bushes_gt[0,0].numpy().transpose(1,2,0), harvard_bushes_gt_name)

In [None]:
# Same layout change as the other ground truths; this scene uses the wider (400,780) false-color range.
save_image_fc_npy(fruit_artichoke_gt[0,0].numpy().transpose(1,2,0), fruit_artichoke_gt_name, fc_range=(400,780))

In [None]:
# Drop the two leading singleton axes added by prep_image and move the band axis last before saving.
save_image_fc_npy(icvl_color_checker_gt[0,0].numpy().transpose(1,2,0), icvl_color_checker_gt_name)

In [None]:
# Drop the two leading singleton axes added by prep_image and move the band axis last before saving.
save_image_fc_npy(kaist_img_gt[0,0].numpy().transpose(1,2,0), kaist_img_gt_name)

# FISTA RECONS

In [None]:
# Build the model described by the static FISTA config and split it into its two stages:
# `fm` (model.model1) produces the simulated measurements, `rm` (model.model2) reconstructs from them.
# From here on `config`, `model`, `fm` and `rm` refer to the FISTA setup until a later section rebinds them.
fista_config = "/home/cfoley/SpectralDefocusCam/notebooks/recons_sim_fista/fista_config_static.yml"
config = helper.read_config(fista_config)
model = train.get_model(config, device=device)
fm, rm = model.model1, model.model2

In [None]:
# Run `fm` on the Harvard bushes ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(harvard_bushes_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# `recon` is whatever the most recently run reconstruction cell assigned.
savename = os.path.join(config["save_recon_path"], f"harvard_bushes_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

In [None]:
# Run `fm` on the fruit artichoke ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(fruit_artichoke_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# The preview uses the (400,780) false-color range, as for this scene's ground truth.
savename = os.path.join(config["save_recon_path"], f"fruit_artichoke_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename, fc_range=(400,780))

In [None]:
# Run `fm` on the ICVL color checker ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(icvl_color_checker_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# NOTE(review): this stem says "icvl_colorpalette" while the ground-truth and learned-recon files
# use "icvl_color_checker" -- confirm that whatever consumes these files expects the different name.
savename = os.path.join(config["save_recon_path"], f"icvl_colorpalette_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

In [None]:
# Run `fm` on the KAIST scene03 ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(kaist_img_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# `recon` is whatever the most recently run reconstruction cell assigned.
savename = os.path.join(config["save_recon_path"], f"kaist_scene03_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

# LEARNED RECONS

In [10]:
import overlap_stitch as stitch_utils

In [None]:
# Checkpoint of the learned model. The training config is read from the checkpoint's own directory.
# The commented-out path is an alternative (2-measurement) checkpoint.
# NOTE(review): trained_weights_path is only used here to locate training_config.yml; presumably the
# config itself tells train.get_model which weights to load -- confirm.
#trained_weights_path = "/home/cfoley/defocuscamdata/models/checkpoint_train_02_07_2024_lsi_adjoint_condunet_L1psf_2meas.yml/2024_03_11_14_44_55/saved_model_ep28_testloss_0.04970918206568315.pt"
trained_weights_path = "/home/cfoley/defocuscamdata/models/checkpoint_train_02_07_2024_lsi_adjoint_condunet_firstlastonly_L1psf_3meas.yml/2024_03_20_00_38_33/saved_model_ep68_testloss_0.039786975884608014.pt"
learned_config = os.path.join(pathlib.Path(trained_weights_path).parent, "training_config.yml")

config = helper.read_config(learned_config)
# Notebook-specific overrides of the training config: redirect where recons are saved, mark the data
# as not precomputed, and switch off the 'fwd_mask_noise' forward-model operation.
config["save_recon_path"] = "/home/cfoley/defocuscamdata/recons/sim_comparison_figure/learned_recons"
config['data_precomputed'] = False
config['forward_model_params']['operations']['fwd_mask_noise'] = False
model = train.get_model(config, device=device)
# Inference only: put the model in evaluation mode.
model.eval()
# Rebinds `config`, `model`, `fm` and `rm` from the FISTA section to the learned setup.
fm, rm = model.model1, model.model2

# Print the forward model's settings as a sanity check on the overrides above.
print(fm.passthrough)
print(fm.operations)

In [11]:
def prep_image_learned(image, crop_shape, patch_shape):
    """Prepare a spectral cube for the learned model.

    This performs exactly the same crop / downsample / normalize steps as ``prep_image`` (the
    body used to be a line-for-line copy of it), so it now delegates to that function to keep
    the two from drifting apart.

    Args:
        image: Array indexed as ``[row, col, band]``.
        crop_shape: ``(rows, cols)`` of the top-left window to keep.
        patch_shape: Target shape handed to ``diffuser_utils.pyramid_down``.

    Returns:
        ``torch.Tensor`` with two leading singleton axes followed by the band-first cube.
    """
    return prep_image(image, crop_shape, patch_shape)

def patchwise_predict_image_learned(image : torch.Tensor, model):
    """Reconstruct an image larger than the model's patch size by stitching overlapping patches.

    The image is tiled with overlapping patches whose size is taken from the last two dimensions
    of ``model.model1.psfs``. Each patch is passed through ``model.model1``, the result is
    standardized (zero mean, unit std) and passed to ``model.model2``, and the prediction is
    rescaled with the ground-truth patch's own mean and std. Patch edges that do not lie on an
    image border are trimmed before accumulation, and overlapping contributions are averaged.

    Args:
        image: Tensor whose last two dimensions are spatial. It is indexed on the host
            (``.numpy()`` is called on patch statistics), so it is expected to live on the CPU;
            patches are moved to the notebook-level ``device`` for inference.
        model: Object exposing ``model1`` (with a ``psfs`` attribute) and ``model2``.

    Returns:
        ``np.ndarray`` with negative values clamped to zero and axes transposed ``(1, 2, 0)``
        relative to ``image.squeeze()``. Pixels that no trimmed patch covers are NaN, because the
        accumulated sum is divided by a zero contribution count there.
    """
    patchy, patchx = model.model1.psfs.shape[-2:]
    half_y, half_x = patchy // 2, patchx // 2
    img_h, img_w = image.shape[-2], image.shape[-1]

    patch_centers = stitch_utils.get_overlapping_positions(
        (img_h // 2, img_w // 2),
        image.shape[-2:],
        (patchy, patchx),
        min_overlap=64 # The higher this is, the less edge artifacts may show up
    )

    prediction = np.zeros(image.squeeze().shape)
    contributions_mask = np.zeros(image.shape[-2:])
    for ceny, cenx in patch_centers:
        top, bottom = ceny - half_y, ceny + half_y
        left, right = cenx - half_x, cenx + half_x
        patch_gt = image[..., top:bottom, left:right]

        # Inference only: skip building the autograd graph for every patch.
        with torch.no_grad():
            sim = model.model1(patch_gt.to(device))
            pred = model.model2((sim - sim.mean()) / sim.std()).detach().cpu().numpy()
        # Undo the standardization using the ground-truth patch statistics.
        pred = pred*patch_gt.std().numpy() + patch_gt.mean().numpy()

        # ------------ REMOVE NON IMAGE-BORDERING PATCH EDGE ARTIFACTS ----------- #
        crop_width = pred.shape[-1]//10 # same margin on both axes; assumes the patch is square

        # Only trim sides that are interior to the image; sides on the image border are kept whole.
        crop_top = 0 if top == 0 else crop_width
        crop_bottom = 0 if bottom == img_h else crop_width
        crop_left = 0 if left == 0 else crop_width
        crop_right = 0 if right == img_w else crop_width

        # Explicit end indices (rather than `:-crop_width`) so that a zero margin keeps the whole
        # patch instead of producing an empty slice.
        pred_h, pred_w = pred.shape[-2], pred.shape[-1]
        pred = pred[..., crop_top:pred_h - crop_bottom, crop_left:pred_w - crop_right]
        top, bottom = top + crop_top, bottom - crop_bottom
        left, right = left + crop_left, right - crop_right

        # Accumulate the trimmed patch and count how many patches touched each pixel.
        prediction[..., top:bottom, left:right] += pred.squeeze()
        contributions_mask[top:bottom, left:right] += 1

    prediction = prediction / contributions_mask
    return np.maximum(0, prediction).transpose(1,2,0)

In [12]:
# Patch-wise learned reconstruction of the Harvard bushes scene (returned as a band-last numpy array).
recon = patchwise_predict_image_learned(harvard_bushes_gt, model)

In [None]:
# The stem has no interpolated fields, so it is a plain string rather than an f-string.
savename = os.path.join(config["save_recon_path"], "harvard_bushes_learned_recon_learned_condunet_L1psf_L1mask_3meas")
save_image_fc_npy(recon, savename)

In [14]:
# Patch-wise learned reconstruction of the fruit artichoke scene (returned as a band-last numpy array).
recon = patchwise_predict_image_learned(fruit_artichoke_gt, model)

In [None]:
# The stem has no interpolated fields, so it is a plain string rather than an f-string.
# The preview uses the (400,780) false-color range, as for this scene's ground truth.
savename = os.path.join(config["save_recon_path"], "fruit_artichoke_learned_recon_learned_condunet_L1psf_L1mask_3meas")
save_image_fc_npy(recon, savename, fc_range=(400,780))

In [16]:
# Patch-wise learned reconstruction of the ICVL color checker scene (returned as a band-last numpy array).
recon = patchwise_predict_image_learned(icvl_color_checker_gt, model)

In [None]:
# The stem has no interpolated fields, so it is a plain string rather than an f-string.
savename = os.path.join(config["save_recon_path"], "icvl_color_checker_learned_recon_learned_condunet_L1psf_L1mask_3meas")
save_image_fc_npy(recon, savename)

In [18]:
# Patch-wise learned reconstruction of the KAIST scene03 image (returned as a band-last numpy array).
recon = patchwise_predict_image_learned(kaist_img_gt, model)

In [None]:
# The stem has no interpolated fields, so it is a plain string rather than an f-string.
savename = os.path.join(config["save_recon_path"], "kaist_scene03_learned_recon_learned_condunet_L1psf_L1mask_3meas")
save_image_fc_npy(recon, savename)

# HANDSHAKE RECONS

In [None]:
# Build the model described by the handshake config and rebind `config`, `model`, `fm` and `rm` to it.
handshake_config = "/home/cfoley/SpectralDefocusCam/notebooks/figure_generation/handshake_random_config.yml"
config = helper.read_config(handshake_config)
model = train.get_model(config, device=device)
fm, rm = model.model1, model.model2

# Solver overrides. `rm.L /= 10` is an in-place update, but `rm` is freshly created above in this
# same cell, so re-running the cell does not compound the division.
# NOTE(review): `L` is presumably the solver's step-size / Lipschitz constant -- confirm.
rm.L /= 10
rm.iters = 151
rm.print_every = 50

In [None]:
# Run `fm` on the Harvard bushes ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(harvard_bushes_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# `recon` is whatever the most recently run reconstruction cell assigned.
savename = os.path.join(config["save_recon_path"], f"harvard_bushes_handshake_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

In [None]:
# Run `fm` on the fruit artichoke ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(fruit_artichoke_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# The preview uses the (400,780) false-color range, as for this scene's ground truth.
savename = os.path.join(config["save_recon_path"], f"fruit_artichoke_handshake_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename, fc_range=(400,780))

In [None]:
# Run `fm` on the ICVL color checker ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(icvl_color_checker_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# NOTE(review): this stem says "icvl_colorpalette" while the ground-truth and learned-recon files
# use "icvl_color_checker" -- confirm that whatever consumes these files expects the different name.
savename = os.path.join(config["save_recon_path"], f"icvl_colorpalette_handshake_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

In [None]:
# Run `fm` on the KAIST scene03 ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(kaist_img_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# `recon` is whatever the most recently run reconstruction cell assigned.
savename = os.path.join(config["save_recon_path"], f"kaist_scene03_handshake_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

# DIFFUSER RECONS

In [None]:
# Build the model described by the diffuser config and rebind `config`, `model`, `fm` and `rm` to it.
# The path lives in its own `diffuser_config` variable instead of reusing the handshake section's
# `handshake_config` name, which previously ended up pointing at the diffuser file.
diffuser_config = "/home/cfoley/SpectralDefocusCam/notebooks/figure_generation/diffuser_config.yml"
config = helper.read_config(diffuser_config)
model = train.get_model(config, device=device)
fm, rm = model.model1, model.model2

# Solver overrides. `rm.L /= 5` is an in-place update, but `rm` is freshly created above in this
# same cell, so re-running the cell does not compound the division.
# NOTE(review): `L` is presumably the solver's step-size / Lipschitz constant -- confirm.
rm.L /= 5
rm.print_every = 80
rm.plot = True
rm.iters = 250

In [None]:
# Run `fm` on the Harvard bushes ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(harvard_bushes_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# `recon` is whatever the most recently run reconstruction cell assigned.
savename = os.path.join(config["save_recon_path"], f"harvard_bushes_diffuser_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

In [None]:
# Run `fm` on the fruit artichoke ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(fruit_artichoke_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# The preview uses the (400,780) false-color range, as for this scene's ground truth.
savename = os.path.join(config["save_recon_path"], f"fruit_artichoke_diffuser_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename, fc_range=(400,780))

In [None]:
# Run `fm` on the ICVL color checker ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(icvl_color_checker_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# NOTE(review): this stem says "icvl_colorpalette" while the ground-truth and learned-recon files
# use "icvl_color_checker" -- confirm that whatever consumes these files expects the different name.
savename = os.path.join(config["save_recon_path"], f"icvl_colorpalette_diffuser_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)

In [None]:
# Run `fm` on the KAIST scene03 ground truth, pass the result (dims 0 and 2 squeezed out) to `rm`,
# and read the reconstruction from `rm.out_img` rather than from the call's return value.
sim = fm(kaist_img_gt.to(device))
rm(sim.squeeze(dim=(0,2)))
recon = rm.out_img

In [None]:
# The file stem records rm.psfs.shape[0] and rm's tv_lambda / tv_lambdaw / tv_lambdax values.
# `recon` is whatever the most recently run reconstruction cell assigned.
savename = os.path.join(config["save_recon_path"], f"kaist_scene03_diffuser_fista_recon_{rm.psfs.shape[0]}_{rm.tv_lambda}_{rm.tv_lambdaw}_{rm.tv_lambdax}")
save_image_fc_npy(recon, savename)