In [None]:
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
%matplotlib widget
import numpy as np

from color_model.base_color_model import BaseColorModel
from util.colorspace import XYZ2RGB, LMS2XYZ, iDKL2LMS, RGB2sRGB, sRGB2RGB
# Collapse the per-stage conversion matrices into one matrix. With column
# vectors the product applies right-to-left: iDKL2LMS, then LMS2XYZ, then XYZ2RGB.
# NOTE(review): assumes every factor is a square matrix of matching size
# (presumably 3x3) — confirm against util.colorspace.
iDKL2RGB = XYZ2RGB @ LMS2XYZ @ iDKL2LMS
# Matrix inverse of the combined transform, used for the opposite direction;
# np.linalg.inv raises LinAlgError if the product is singular.
RGB2iDKL = np.linalg.inv(iDKL2RGB)

In [None]:
# Path of the colour-model file handed to `load`, relative to the working directory.
MODEL_PATH = "io/color_model/model.pth"

# Module-level model instance; the functions below read it as a global.
model = BaseColorModel()
model.load(MODEL_PATH)

def get_ellipse_mesh(srgb, ecc, resolution):
    """Sample the model's ellipse around a colour as a 3-D surface mesh.

    The module-level ``model`` is queried for the two ellipse extents at
    ``srgb`` / ``ecc``. The ellipse is written as a quadratic form in iDKL,
    carried over to RGB with ``RGB2iDKL`` and then sampled as a filled disc
    centred on ``sRGB2RGB(srgb)``.

    Args:
        srgb: array holding the colour; it is given a leading batch axis for
            the model call and converted with ``sRGB2RGB`` for the centre.
            Expected to be 1-D (the caller in this notebook passes 3 values).
        ecc: array holding the eccentricity, given a leading batch axis for
            the model call.
        resolution: number of samples used for both the radial and the
            angular direction of the disc.

    Returns:
        Array of shape ``(3, resolution, resolution)`` stacking the x, y and
        z coordinate grids of the mesh.
    """
    # Batch-of-one query; entries [0, 0] and [0, 1] are used as the two
    # semi-axis lengths of the ellipse.
    extents = model.compute_ellipses(srgb[None, :], ecc[None, :])
    semi_w, semi_h = extents[0, 0], extents[0, 1]

    # Quadratic form of the ellipse in iDKL. The flat ellipse is faked as an
    # ellipsoid whose third semi-axis has a tiny epsilon length.
    flat_depth = 1e-5
    quad_idkl = np.array([
        [1 / semi_w**2, 0, 0],
        [0, 1 / semi_h**2, 0],
        [0, 0, 1 / flat_depth**2],
    ])
    # Substituting x_idkl = RGB2iDKL @ x_rgb into the quadratic form.
    quad_rgb = RGB2iDKL.T @ quad_idkl @ RGB2iDKL
    centre = sRGB2RGB(srgb)

    # Filled unit disc in the z = 0 plane: rows are radii, columns are angles.
    radii = np.linspace(0, 1, resolution)
    angles = np.linspace(0, np.pi * 2, resolution)
    disc_x = np.outer(radii, np.cos(angles))
    disc_y = np.outer(radii, np.sin(angles))
    disc = np.stack([disc_x, disc_y, np.zeros_like(disc_x)], 0).reshape(3, -1)

    # eigh returns eigenvalues in ascending order, so column 2 belongs to the
    # largest eigenvalue, i.e. the shortest semi-axis (1 / sqrt(eigenvalue)).
    lam, vec = np.linalg.eigh(quad_rgb)
    # Fix the sign ambiguity of the last eigenvector: keep it on the side
    # with a non-negative first component.
    if vec[:, 2] @ np.array([1, 0, 0]) < 0:
        vec[:, 2] = -vec[:, 2]
    # Order the two in-plane principal axes so that the basis is right-handed.
    right_handed = np.cross(vec[:, 0], vec[:, 1]) @ vec[:, 2] > 0
    first, second = (0, 1) if right_handed else (1, 0)
    semi_axes = [1 / np.sqrt(lam[k]) * vec[:, k] for k in (first, second, 2)]
    basis = np.stack(semi_axes, axis=-1)

    # Map the disc through the principal-axis basis and shift it to the centre.
    # The third basis column only multiplies the all-zero z row of the disc.
    surface = basis @ disc + centre[:, None]
    return surface.reshape(3, *disc_x.shape)

def gen_figure(elev=30, azim=60):
    """Create a figure with a single 3-D axes framed as the unit RGB cube.

    Args:
        elev: elevation angle of the initial camera view, passed to
            ``ax.view_init``. Defaults to 30.
        azim: azimuth angle of the initial camera view, passed to
            ``ax.view_init``. Defaults to 60.

    Returns:
        Tuple ``(fig, ax)`` with the new figure and its 3-D axes.
    """
    fig = plt.figure()
    # Create the axes from `fig` itself instead of going through pyplot's
    # implicit "current figure", so the returned pair always belongs together.
    ax = fig.add_subplot(projection="3d")
    # One axis per colour channel, each limited to the [0, 1] range.
    ax.set(
        xlabel="Red",
        ylabel="Green",
        zlabel="Blue",
        xlim=(0, 1),
        ylim=(0, 1),
        zlim=(0, 1),
    )
    ax.view_init(elev, azim)

    return fig, ax

In [None]:
rgb = np.array([0.5, 0.5, 0.5])  # linear sRGB
ecc = np.array([10.])  # in degrees
resolution = 60  # mesh granularity

# Bind the figure to a real name: assigning to `_` in IPython shadows the
# last-output variable and stops it from being updated afterwards.
fig, ax = gen_figure()
# The mesh helper takes the gamma-encoded colour, so convert from linear first.
srgb = RGB2sRGB(rgb)
mesh = get_ellipse_mesh(srgb, ecc, resolution)
# The surface is tinted with the colour it surrounds. The trailing `;` keeps
# the returned collection's repr out of the cell output.
ax.plot_surface(*mesh, rstride=10, cstride=10, color=srgb, alpha=0.75);