# Load an image

In [None]:
from openexr.spectralexr import SpectralEXRFile
from radiometry.cmf import CMF
from radiometry.spectrum import Spectrum

import compress.moments as mnt
import compress.util as mntutil 

import os
import matplotlib.pyplot as plt

import numpy as np

In [None]:
# Colour-matching functions read from ciexyz06_2deg.csv (by its name, the
# CIE 2006 2-degree observer) and the D65 illuminant spectrum, both taken
# from the repository's ../data directory.
cmf_cie_2006_2 = CMF(
    os.path.join('..', 'data', 'cmf', 'ciexyz06_2deg.csv')
)
illu_d65 = Spectrum(
    os.path.join('..', 'data', 'spectra', 'illuminant_D65.csv')
)

In [None]:
# Spectral EXR to compress. Set the SPECTRAL_EXR_IMAGE environment variable
# to run the notebook on another machine or image; the default is the
# original hard-coded location.
image_path = os.environ.get(
    "SPECTRAL_EXR_IMAGE",
    "/home/afichet/Repositories/spectral-compress/build/macbeth.exr",
)

image_in = SpectralEXRFile(image_path)

# Linear sRGB preview of the reflective layer, computed under D65.
rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    image_in.reflective_wavelengths_nm, image_in.reflective_image
)

# Clip to [0, 1] before the 1/2.2 display gamma. A negative (out-of-gamut)
# component raised to a fractional power becomes NaN, and matplotlib expects
# float RGB data in [0, 1].
plt.imshow(np.clip(rgb_image, 0, 1) ** (1. / 2.2))
plt.show()


# Compression pipeline

## Spectral Image to Moment Image

In [None]:
# Project every pixel's reflectance spectrum onto the moment basis.
spectral_image = image_in.reflective_image
wavelengths = image_in.reflective_wavelengths_nm

phases = mnt.wavelengths_to_phase(np.array(wavelengths))
basis = mnt.get_basis_signal_to_moments(phases)

# Matrix product over the last (wavelength) axis: one moment vector per
# pixel. Assumes basis is (n_wavelengths, n_moments); TODO confirm.
moment_image = spectral_image @ basis

## Compression

### Current (bounded case)

In [None]:
# Bounded compression, expanded pixel by pixel: compress each real moment
# vector, then normalize globally. mins/maxs are needed again when
# decompressing.
w, h, n_moments = moment_image.shape
moments = np.real(moment_image.reshape((w * h, n_moments)))

compressed_moments = np.zeros_like(moments)
for pixel, pixel_moments in enumerate(moments):
    compressed_moments[pixel, :] = mnt.bounded_compress_real_trigonometric_moments(pixel_moments)

normalized_moments, mins, maxs = mntutil.normalize(compressed_moments)
normalized_moments = normalized_moments.reshape((w, h, n_moments))


### Alternative (bounded and unbounded)

In [None]:
# Recomputed so this cell does not depend on the "Current" cell having run.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
basis  = mnt.get_basis_signal_to_moments(phases)
moment_image = spectral_image @ basis

# 1. Global maximum of the spectral image; per-pixel scales are stored
#    relative to it.
max_v = np.max(spectral_image)

# 2. Per-pixel maximum as a fraction of max_v, quantized to 8 bits. ceil()
#    rounds up, so the quantized scale is never below the true relative max.
w, h, n_moments = moment_image.shape
spectrum = spectral_image.reshape((w * h, len(wavelengths)))
moments = np.real(moment_image.reshape((w * h, n_moments)))

local_max = np.max(spectrum, axis=1, keepdims=True)
if max_v > 0:
    r_scale = local_max / max_v
else:
    # All-zero image: there is nothing to scale, and dividing by
    # max_v would produce NaN.
    r_scale = np.zeros_like(local_max, dtype=np.float64)
# Clip before casting so a negative pixel maximum cannot wrap around in uint8.
relative_scale = np.clip(np.ceil(r_scale * 255), 0, 255).astype(np.uint8)

relative_scale_image = relative_scale.reshape((w, h, ))

# 3. Divide the AC components (every moment except the 0th) by the
#    dequantized scale. Pixels whose scale is 0 (e.g. black pixels) are left
#    unscaled. Dividing them would give 0/0 = NaN, which would likely
#    propagate into the min/max computed by normalize() for the whole image.
#    Decompression multiplies those AC terms by the same 0 scale.
scale = max_v * (relative_scale / 255.)
scaled_moments = moments.copy()
np.divide(moments[:, 1:], scale, out=scaled_moments[:, 1:], where=scale > 0)

# 4. Bounded compression of each rescaled moment vector, then global
#    normalization (mins_scaled/maxs_scaled are needed to decompress).
# TODO(review): confirm bounded_compress_real_trigonometric_moments accepts
# an all-zero moment vector (black pixels).
compressed_scaled_moments = np.zeros_like(scaled_moments)
for i, pixel_moments in enumerate(scaled_moments):
    compressed_scaled_moments[i, :] = mnt.bounded_compress_real_trigonometric_moments(pixel_moments)

normalized_scaled_moments, mins_scaled, maxs_scaled = mntutil.normalize(compressed_scaled_moments)
normalized_scaled_moments = normalized_scaled_moments.reshape((w, h, n_moments))

## Decompression

### Current (bounded case)

In [None]:
# Inverse of the "Current (bounded case)" pipeline.
# Inputs: normalized_moments, mins, maxs and wavelengths.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
inv_basis = mnt.get_basis_moment_to_signal(phases)

w, h, n_moments = normalized_moments.shape
d_normalized_moments = normalized_moments.reshape((w * h, n_moments))
d_compressed_moments = mntutil.denormalize(d_normalized_moments, mins, maxs)

# Undo the bounded compression pixel by pixel.
d_moments = np.zeros_like(d_compressed_moments)
for pixel, pixel_compressed in enumerate(d_compressed_moments):
    d_moments[pixel, :] = mnt.bounded_decompress_real_trigonometric_moments(pixel_compressed)

# Moments -> one spectrum per pixel, then back to image layout.
d_signals = d_moments @ inv_basis
d_spectral_image = d_signals.reshape((w, h, len(wavelengths)))

In [None]:
# Side-by-side preview: the original image vs. the current d_spectral_image.
# Both are rendered to linear sRGB under D65.
org_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    image_in.reflective_wavelengths_nm, image_in.reflective_image
)

d_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    wavelengths, d_spectral_image
)

fig, ax = plt.subplots(1, 2)

# Clip to [0, 1] before the 1/2.2 display gamma. The reconstructed image can
# contain negative components, and a negative value raised to a fractional
# power becomes NaN.
for panel, title, img in zip(ax,
                             ('Original image', 'Decompressed image'),
                             (org_rgb_image, d_rgb_image)):
    panel.set_title(title)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.imshow(np.clip(img, 0, 1) ** (1. / 2.2))

plt.show()

### Alternative (bounded and unbounded)

In [None]:
# Inverse of the "Alternative (bounded and unbounded)" pipeline.
# Inputs: normalized_scaled_moments, mins_scaled, maxs_scaled, max_v,
# relative_scale_image and wavelengths.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
inv_basis = mnt.get_basis_moment_to_signal(phases)

w, h, n_moments = normalized_scaled_moments.shape
d_normalized_scaled_moments = normalized_scaled_moments.reshape((w * h, n_moments))
d_relative_scale = relative_scale_image.reshape((w * h, ))
d_compressed_scaled_moments = mntutil.denormalize(
    d_normalized_scaled_moments, mins_scaled, maxs_scaled
)

d_moments = np.zeros_like(d_compressed_scaled_moments)
for pixel, (pixel_compressed, q_scale) in enumerate(
        zip(d_compressed_scaled_moments, d_relative_scale)):
    # Dequantize the 8-bit relative scale, undo the bounded compression,
    # then restore the magnitude of the AC components (all but moment 0).
    scale = max_v * (q_scale / 255.)
    d_moments[pixel, :] = mnt.bounded_decompress_real_trigonometric_moments(pixel_compressed)
    d_moments[pixel, 1:] *= scale

# Moments -> one spectrum per pixel, then back to image layout.
d_signals = d_moments @ inv_basis
d_spectral_image = d_signals.reshape((w, h, len(wavelengths)))

In [None]:
# Side-by-side preview: the original image vs. the current d_spectral_image
# (from the alternative pipeline when the cells run in order). Both are
# rendered to linear sRGB under D65.
org_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    image_in.reflective_wavelengths_nm, image_in.reflective_image
)

d_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    wavelengths, d_spectral_image
)

fig, ax = plt.subplots(1, 2)

# Clip to [0, 1] before the 1/2.2 display gamma. The reconstructed image can
# contain negative components, and a negative value raised to a fractional
# power becomes NaN.
for panel, title, img in zip(ax,
                             ('Original image', 'Decompressed image'),
                             (org_rgb_image, d_rgb_image)):
    panel.set_title(title)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.imshow(np.clip(img, 0, 1) ** (1. / 2.2))

plt.show()

## Comparison of quantization curves