In [None]:
import numpy as np
import matplotlib.pyplot as plt
import compress.moments as moments
import compress.quantization as quantization

import csv, os

# Load dataset

In [None]:
def load_macbeth(path):
    """Read a CSV file of spectra into NumPy arrays.

    The first non-empty row is read as the wavelength axis. Every following
    non-empty row is read as one list of floats and stacked into a 2-D array
    (presumably one reflectance spectrum per Macbeth patch).

    Blank lines are skipped: `csv.reader` yields an empty list for them,
    which would otherwise make the stacked array ragged.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A tuple `(wavelengths, reflectances)` of NumPy arrays. Both are empty
        when the file contains no data rows.

    Raises:
        ValueError: If a cell cannot be converted to `float`.
        OSError: If the file cannot be opened.
    """
    # The csv module asks for files to be opened with newline=''.
    with open(path, newline='') as csv_file:
        rows = [
            [float(cell) for cell in row]
            for row in csv.reader(csv_file)
            if row  # skip blank lines
        ]

    wavelengths = rows[0] if rows else []
    reflectances = rows[1:]

    return np.array(wavelengths), np.array(reflectances)

In [None]:
# Read the wavelengths and the spectra, then convert the wavelengths to phases.
mb_wl, mb_rfl = load_macbeth(os.path.join('data', 'spectra', 'macbeth_patches.csv'))
mb_phases = moments.wavelengths_to_phase(mb_wl)

# One moment per column of the loaded table.
n_moments = mb_rfl.shape[1]

# Arrange the rows on a 6 x 4 grid; the last axis keeps the per-row values.
mb_rfl = mb_rfl.reshape((6, 4, n_moments))

# Error introduced when quantizing a moment

In [None]:
# Bit depths for which errors and quantization curves are computed.
quantization_curves_n_bits = [8, 9, 10, 11]

# One progressive-quantization error result per bit depth, in the same order
# as `quantization_curves_n_bits`.
bounded_error_progressive = [
    quantization.bounded_get_error_progressive_quantization(mb_phases, mb_rfl, n_bits)
    for n_bits in quantization_curves_n_bits
]


# unbounded_error_progressive = []

# for n_bits in quantization_curves_n_bits:
#     error = quantization.unbounded_get_error_progressive_quantization(mb_phases, mb_rfl, n_bits)
#     unbounded_error_progressive.append(error)

In [None]:
# Bar chart of the error per moment, one bar series per bit depth.
fig, ax = plt.subplots(1, 1)

# Index 0 is left out: x starts at 1 and err[1:] drops the first entry.
x = np.arange(1, n_moments)

for n_bits, err in zip(quantization_curves_n_bits, bounded_error_progressive):
    ax.bar(x, err[1:], label='{} bits'.format(n_bits))

ax.set(
    title='Error after quantizing a moment',
    xlabel='Moment index',
    ylabel='rRMSE',
    yscale='log',
)
ax.legend()

fig.tight_layout()
plt.show()

# Generation of quantization curves

In [None]:
# One quantization curve per bit depth, in the same order as
# `quantization_curves_n_bits`.
bounded_quantization_curves = [
    quantization.bounded_generate_quantization_curve(mb_phases, mb_rfl, n_bits)
    for n_bits in quantization_curves_n_bits
]

In [None]:
# Line plot of each generated quantization curve, labelled by its bit depth.
fig, ax = plt.subplots(1, 1)

# Index 0 is left out: the x axis starts at 1 and curve[1:] drops the first entry.
a = np.arange(1, n_moments)

for curve, n_bits in zip(bounded_quantization_curves, quantization_curves_n_bits):
    ax.plot(a, curve[1:], label=n_bits)

ax.set(
    title='Quantization curves',
    xlabel='Moment index',
    ylabel='Number of bits',
)
ax.legend()

fig.tight_layout()
plt.show()


# Evaluate Pareto optimality

Error for generated quantization curves

In [None]:
# One row per generated curve, holding the two values returned by
# `bounded_err_for_curve` (plotted later as total number of bits and rRMSE).
bounded_errors = np.zeros((len(bounded_quantization_curves), 2))

# NOTE(review): this call receives the wavelengths (mb_wl), whereas the other
# quantization helpers used in this notebook receive mb_phases — confirm.
for i, c in enumerate(bounded_quantization_curves):
    bounded_errors[i] = quantization.bounded_err_for_curve(mb_wl, mb_rfl, c)

Error for random quantization curves

In [None]:
# Random quantization curves, evaluated with the same error function as the
# generated curves so that both can be shown on one scatter plot.
n_iter = 300

max_bits = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

w = len(max_bits)

# Seeded generator: without a fixed seed every run draws different curves,
# so the resulting figure could not be reproduced with Restart & Run All.
rng = np.random.default_rng(0)

# One row per (iteration, max_bits entry) pair, holding the two values
# returned by `bounded_err_for_curve`.
bounded_rand_errors = np.zeros((n_iter * w, 2))

for i in range(n_iter):
    for j, n_bits in enumerate(max_bits):
        # NOTE(review): `high` is exclusive, so the drawn values lie in
        # [1, n_bits - 1] and never reach n_bits itself — confirm this is
        # intended given the name `max_bits`.
        curve = rng.integers(1, high=n_bits, size=mb_phases.shape[0])
        bounded_rand_errors[i * w + j, :] = quantization.bounded_err_for_curve(mb_wl, mb_rfl, curve)

In [None]:
# Scatter plot comparing the generated quantization curves with the random
# ones: column 0 of each error array is on the x axis, column 1 on the y axis.
fig, ax = plt.subplots(1, 1)

ax.set_title('Generated vs. random quantization curves')

# Label both series so the legend tells them apart.
ax.scatter(bounded_errors[:, 0], bounded_errors[:, 1], s=8,
           label='Generated curves')
ax.scatter(bounded_rand_errors[:, 0], bounded_rand_errors[:, 1], s=0.5,
           label='Random curves')

ax.set_xlabel('Total number of bits')
ax.set_ylabel('rRMSE')
# The x range follows the generated curves only (with a margin of 10);
# random points outside that range are not shown.
ax.set_xlim(np.min(bounded_errors[:, 0]) - 10, np.max(bounded_errors[:, 0]) + 10)
ax.set_yscale('log')
ax.legend()

fig.tight_layout()

# `plt.show()` rather than `plt.plot()`: the latter draws nothing and only
# returns an empty list; this also matches the other plotting cells.
plt.show()