In [None]:
from openexr.spectralexr import SpectralEXRFile
from radiometry.cmf import CMF
from radiometry.spectrum import Spectrum

import compress.moments as mnt
import compress.util as mntutil 

import os
import matplotlib.pyplot as plt

import numpy as np

# Load an image and misc. assets

In [None]:
# Spectral EXR to analyse. The default is a machine-specific absolute path;
# set the SPECTRAL_EXR_PATH environment variable to run on another machine
# without editing the notebook.
image_path = os.environ.get(
    "SPECTRAL_EXR_PATH",
    "/home/afichet/Repositories/spectral-compress/build/macbeth.exr",
)

# CIE 2006 2-degree colour matching functions and the D65 illuminant, used
# below to render spectral images to linear sRGB for display.
cmf_cie_2006_2 = CMF(os.path.join('..', 'data', 'cmf', 'ciexyz06_2deg.csv'))
illu_d65 = Spectrum(os.path.join('..', 'data', 'spectra', 'illuminant_D65.csv'))

In [None]:
# Open the spectral EXR and keep its reflective layer: the sampled
# wavelengths (nm) and the matching spectral image.
image_in = SpectralEXRFile(image_path)
wavelengths, spectral_image = (
    image_in.reflective_wavelengths_nm,
    image_in.reflective_image,
)

In [None]:
# Render the spectral image to linear sRGB under D65 and display it.
rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    wavelengths, spectral_image
)

fig, ax = plt.subplots()
ax.set_xticks([])
ax.set_yticks([])
# Out-of-gamut pixels can have negative linear components, which a fractional
# power turns into NaN: clip to the displayable range before the 2.2 gamma.
ax.imshow(np.clip(rgb_image, 0., 1.)**(1./2.2))
fig.tight_layout()
plt.show()

# Compression pipeline

## Compression

### Current (bounded case)

In [None]:
# Map wavelengths to phases, build the signal -> moments basis and run the
# current bounded forward transform; mins/maxs are needed to invert it.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
basis = mnt.get_basis_signal_to_moments(phases)

(compressed_moment_image,
 mins,
 maxs) = mnt.bounded_forward(basis, spectral_image)

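### Alternative (bounded and unbounded)

In addition to the normalized moments, this scheme stores the image's global maximum and an 8-bit scale per pixel, relative to that maximum. Each pixel's AC moments are divided by the resulting approximate per-pixel maximum before the bounded compression is applied.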

In [None]:
def bounded_w_upper_forward(basis, spectral_image):
    """Bounded moment compression with a per-pixel upper bound.

    Every pixel's AC moments (index >= 1) are divided by an approximation of
    that pixel's maximum value, ``global_max * relative_scale / 255``. The
    bounded compression is then applied pixel by pixel. The DC moment is not
    scaled.

    Args:
        basis: signal -> moments projection matrix, e.g. from
            ``mnt.get_basis_signal_to_moments``. The final reshape assumes it
            yields as many moments as there are wavelengths (square basis).
        spectral_image: array of shape (w, h, n_wavelengths).

    Returns:
        A tuple ``(normalized_scaled_moments, relative_scale, mins, maxs,
        global_max)``:
        - ``normalized_scaled_moments``: shape (w, h, n_moments), output of
          ``mntutil.normalize`` on the compressed scaled moments;
        - ``relative_scale``: uint8 array of shape (w, h, 1), each pixel's max
          relative to ``global_max``, quantized to 0..255 with ``ceil``;
        - ``mins``, ``maxs``: normalization bounds from ``mntutil.normalize``;
        - ``global_max``: maximum value over the whole image.

        Two cases get ``relative_scale == 0`` and zero AC moments instead of
        NaN: pixels whose max is <= 0 (e.g. all-zero pixels), and every pixel
        of an image whose global max is <= 0.
    """
    w, h, n_moments = spectral_image.shape
    spectrum = spectral_image.reshape((w * h, n_moments))
    moments = np.real(spectrum @ basis)

    # 1. Max value over all wavelengths & all pixels.
    global_max = np.max(spectral_image)

    # 2. Each pixel's max relative to the global max, quantized to 8 bits.
    #    ceil() keeps global_max * relative_scale / 255 >= the pixel's max (up
    #    to float rounding). Guard global_max <= 0, and clip so that negative
    #    ratios map to 0 instead of wrapping around when cast to uint8.
    local_max = np.max(spectrum, axis=1, keepdims=True)
    if global_max > 0:
        rel_scale = local_max / global_max
    else:
        rel_scale = np.zeros(local_max.shape)
    relative_scale = np.clip(np.ceil(rel_scale * 255), 0, 255).astype(np.uint8)

    # 3. Rescale the AC components; the DC component is kept as-is. Where the
    #    quantized scale is 0 the AC moments stay 0 instead of 0/0 = NaN
    #    (bounded_w_upper_backward multiplies them by that same 0 scale).
    scale = global_max * (relative_scale / 255.)
    scaled_moments = np.zeros_like(moments)
    scaled_moments[:, 0] = moments[:, 0]
    np.divide(moments[:, 1:], scale, out=scaled_moments[:, 1:], where=scale != 0)

    # 4. Apply the bounded compression pixel by pixel.
    # TODO: confirm bounded_compress_real_trigonometric_moments copes with
    #       all-zero moments (fully black pixels).
    compressed_scaled_moments = np.zeros_like(moments)

    for i in range(w * h):
        compressed_scaled_moments[i] = mnt.bounded_compress_real_trigonometric_moments(scaled_moments[i])

    # 5. Scaling to make AC components fit in [0..1].
    normalized_scaled_moments, mins, maxs = mntutil.normalize(compressed_scaled_moments)

    normalized_scaled_moments = normalized_scaled_moments.reshape((w, h, n_moments))
    relative_scale = relative_scale.reshape((w, h, 1))

    return normalized_scaled_moments, relative_scale, mins, maxs, global_max

In [None]:
# Same moment basis as the current pipeline, fed to the alternative encoder,
# which also returns a global max and a per-pixel relative scale.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
basis = mnt.get_basis_signal_to_moments(phases)

compressed_scaled_moment_image, relative_scale, mins_scaled, maxs_scaled, global_max = \
    bounded_w_upper_forward(basis, spectral_image)

## Decompression

### Current (bounded case)

In [None]:
# Decoding needs the compressed moments, mins, maxs and the wavelengths.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
inv_basis = mnt.get_basis_moment_to_signal(phases)

d_bounded_spectral_image = mnt.bounded_backward(
    inv_basis, compressed_moment_image, mins, maxs
)

In [None]:
# Render original and decoded spectra to linear sRGB (D65) side by side.
org_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    image_in.reflective_wavelengths_nm, image_in.reflective_image
)

d_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    wavelengths, d_bounded_spectral_image
)

fig, ax = plt.subplots(1, 2)

# Clip to [0, 1] before the 2.2 gamma: negative (out-of-gamut) linear values
# would otherwise become NaN under the fractional power.
for panel_ax, panel_title, panel_img in ((ax[0], 'Original image', org_rgb_image),
                                         (ax[1], 'Decompressed image', d_rgb_image)):
    panel_ax.set_title(panel_title)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    panel_ax.imshow(np.clip(panel_img, 0., 1.)**(1./2.2))

plt.show()

### Alternative (bounded and unbounded)

In [None]:
def bounded_w_upper_backward(
    inv_basis,
    compressed_moments: np.array,
    relative_scale: np.array,
    mins:np.array, maxs: np.array, global_max: float):
    """Decode the output of ``bounded_w_upper_forward`` back to spectra.

    Steps:
    1. Denormalize with ``mins``/``maxs``.
    2. Run the bounded decompression on each pixel.
    3. Multiply every pixel's AC moments (index >= 1) by
       ``global_max * relative_scale / 255``.
    4. Project back to the signal domain with ``inv_basis``.

    Args:
        inv_basis: moments -> signal matrix (e.g. from
            ``mnt.get_basis_moment_to_signal``).
        compressed_moments: array of shape (w, h, n_moments).
        relative_scale: per-pixel 8-bit scale, reshapeable to (w * h, 1).
        mins, maxs: normalization bounds returned by the forward pass.
        global_max: maximum value of the original image.

    Returns:
        Array of shape (w, h, n_moments) holding the reconstructed signal.
    """
    width, height, n_coeffs = compressed_moments.shape
    n_pixels = width * height

    flat = mntutil.denormalize(compressed_moments.reshape((n_pixels, n_coeffs)), mins, maxs)
    pixel_scale = global_max * (relative_scale.reshape((n_pixels, 1)) / 255.)

    recovered = np.zeros((n_pixels, n_coeffs))
    for px in range(n_pixels):
        recovered[px] = mnt.bounded_decompress_real_trigonometric_moments(flat[px])

    # Undo the per-pixel AC scaling; the DC moment was never scaled.
    recovered[:, 1:] *= pixel_scale

    return (recovered @ inv_basis).reshape((width, height, n_coeffs))

In [None]:
# Decode the alternative representation with the moment -> signal basis.
phases = mnt.wavelengths_to_phase(np.array(wavelengths))
inv_basis = mnt.get_basis_moment_to_signal(phases)

d_bounded_w_upper_spectral_image = bounded_w_upper_backward(
    inv_basis,
    compressed_scaled_moment_image,
    relative_scale,
    mins_scaled,
    maxs_scaled,
    global_max,
)

In [None]:
# Render original and decoded (alternative scheme) spectra to linear sRGB (D65).
org_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    image_in.reflective_wavelengths_nm, image_in.reflective_image
)

d_rgb_image = cmf_cie_2006_2.get_sRGB_lin_reflective_img(
    illu_d65.data[:,0], illu_d65.data[:,1],
    wavelengths, d_bounded_w_upper_spectral_image
)

fig, ax = plt.subplots(1, 2)

# Clip to [0, 1] before the 2.2 gamma: negative (out-of-gamut) linear values
# would otherwise become NaN under the fractional power.
for panel_ax, panel_title, panel_img in ((ax[0], 'Original image', org_rgb_image),
                                         (ax[1], 'Decompressed image', d_rgb_image)):
    panel_ax.set_title(panel_title)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    panel_ax.imshow(np.clip(panel_img, 0., 1.)**(1./2.2))

plt.show()

## Comparison of quantization curves

In [None]:
from compress.quantization import rrmse, quantize_dequantize

def bounded_generate_quantization_curve(wavelengths: np.array, ref: np.array, n_bits: int, err_fun=rrmse, dc_bits: int = 16) -> np.array:
    """Greedy per-moment bit allocation for the current bounded pipeline.

    Moment 0 gets ``dc_bits`` bits and moment 1 gets ``n_bits``. The baseline
    error is computed after ``quantize_dequantize(norm_moments, 1, n_bits)``.
    For each later moment ``i``, the bit depth starts at moment ``i - 1``'s
    value. It is then lowered one bit at a time while the error after
    ``quantize_dequantize(norm_moments, i, b)`` stays strictly below the
    baseline.

    NOTE(review): presumably ``quantize_dequantize(x, i, b)`` quantizes moment
    ``i`` to ``b`` bits; confirm whether it also touches other moments.

    Args:
        wavelengths: sampled wavelengths (nm) of ``ref``.
        ref: reference spectral image, shape (w, h, n_wavelengths).
        n_bits: bit depth assigned to moment 1 (defines the error baseline).
        err_fun: ``err_fun(wavelengths, ref, reconstructed) -> float``.
        dc_bits: bit depth assigned to moment 0 (default 16, as before).

    Returns:
        uint8 array with one bit depth per moment.
    """
    phases = mnt.wavelengths_to_phase(wavelengths)

    signal_to_moments = mnt.get_basis_signal_to_moments(phases)
    moments_to_signal = np.linalg.inv(signal_to_moments)

    norm_moments, mins, maxs = mnt.bounded_forward(signal_to_moments, ref)
    n_moments = norm_moments.shape[2]

    bits = np.zeros((n_moments,), dtype=np.uint8)
    bits[0] = dc_bits
    if n_moments < 2:
        # Only a DC component: nothing else to allocate.
        return bits
    bits[1] = n_bits

    # Determine error baseline.
    q_norm_moments  = quantize_dequantize(norm_moments, 1, n_bits)
    backward_signal = mnt.bounded_backward(moments_to_signal, q_norm_moments, mins, maxs)
    err_init = err_fun(wavelengths, ref, backward_signal)

    for i in range(2, n_moments):
        bits[i] = bits[i - 1]

        for b in range(int(bits[i]), 0, -1):
            # Decrease the bit depth while keeping the error below the baseline.
            q_norm_moments = quantize_dequantize(norm_moments, i, b)
            backward_signal = mnt.bounded_backward(moments_to_signal, q_norm_moments, mins, maxs)

            err = err_fun(wavelengths, ref, backward_signal)

            if err >= err_init:
                break

            bits[i] = b

    return bits

In [None]:
def bounded_w_upper_generate_quantization_curve(wavelengths: np.array, ref: np.array, n_bits: int, err_fun=rrmse, dc_bits: int = 16, verbose: bool = False) -> np.array:
    """Greedy per-moment bit allocation for the bounded-with-upper-bound scheme.

    This function uses the same search as the current-pipeline curve.
    Moment 0 gets ``dc_bits`` bits and moment 1 gets ``n_bits``. The baseline
    error is computed after ``quantize_dequantize(norm_moments, 1, n_bits)``.
    For each later moment ``i``, the bit depth starts at moment ``i - 1``'s
    value. It is then lowered while the error after
    ``quantize_dequantize(norm_moments, i, b)`` stays strictly below the
    baseline. Decoding uses ``bounded_w_upper_backward`` with the side data
    returned by ``bounded_w_upper_forward``.

    NOTE(review): presumably ``quantize_dequantize(x, i, b)`` quantizes moment
    ``i`` to ``b`` bits; confirm whether it also touches other moments.

    Args:
        wavelengths: sampled wavelengths (nm) of ``ref``.
        ref: reference spectral image, shape (w, h, n_wavelengths).
        n_bits: bit depth assigned to moment 1 (defines the error baseline).
        err_fun: ``err_fun(wavelengths, ref, reconstructed) -> float``.
        dc_bits: bit depth assigned to moment 0 (default 16, as before).
        verbose: print ``(err_init, err, b, i)`` for every trial when True.

    Returns:
        uint8 array with one bit depth per moment.
    """
    _, _, n_moments = ref.shape

    phases = mnt.wavelengths_to_phase(wavelengths)

    signal_to_moments = mnt.get_basis_signal_to_moments(phases)
    moments_to_signal = np.linalg.inv(signal_to_moments)

    norm_moments, relative_scale, mins, maxs, global_max = bounded_w_upper_forward(signal_to_moments, ref)

    bits = np.zeros((n_moments,), dtype=np.uint8)
    bits[0] = dc_bits
    if n_moments < 2:
        # Only a DC component: nothing else to allocate.
        return bits
    bits[1] = n_bits

    # Determine error baseline.
    q_norm_moments  = quantize_dequantize(norm_moments, 1, n_bits)
    backward_signal = bounded_w_upper_backward(moments_to_signal, q_norm_moments, relative_scale, mins, maxs, global_max)
    err_init = err_fun(wavelengths, ref, backward_signal)

    for i in range(2, n_moments):
        bits[i] = bits[i - 1]

        for b in range(int(bits[i]), 0, -1):
            # Decrease the bit depth while keeping the error below the baseline.
            q_norm_moments = quantize_dequantize(norm_moments, i, b)
            backward_signal = bounded_w_upper_backward(moments_to_signal, q_norm_moments, relative_scale, mins, maxs, global_max)

            err = err_fun(wavelengths, ref, backward_signal)
            if verbose:
                print(err_init, err, b, i)
            if err >= err_init:
                break

            bits[i] = b

    return bits

In [None]:
# Bit allocation curves for both schemes, with 12 bits assigned to moment 1.
q_curve_bounded = bounded_generate_quantization_curve(
    np.array(wavelengths), spectral_image, 12)
q_curve_bounded_w_upper = bounded_w_upper_generate_quantization_curve(
    np.array(wavelengths), spectral_image, 12)

In [None]:
# Bits allocated to each moment by the two schemes.
fig, ax = plt.subplots(1, 1)

moment_index = np.arange(len(wavelengths))
ax.plot(moment_index, q_curve_bounded, label='bounded')
ax.plot(moment_index, q_curve_bounded_w_upper, label='bounded w/ upper')
ax.set_xlabel('Moment index')
ax.set_ylabel('Bits')
ax.legend()

plt.show()