# Illuminant Chromaticity Analysis for LSMI Dataset (Nikon D810)

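This notebook analyzes the chromaticity distribution of ground-truth illuminants in the LSMI dataset (Nikon D810 subset). Each illuminant is estimated from the six gray patches of the Macbeth ColorChecker chart `mcc1` in the scene's first raw capture (`{place}_1.nef`), measured in the camera's raw color space.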

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
import rawpy
from tqdm import tqdm

plt.rcParams.update({"figure.figsize": (14, 10)})

# Constants
LSMI_ROOT = "Data/LSMI/nikon"  # relative path to the Nikon subset of LSMI
META_FILE = os.path.join(LSMI_ROOT, "meta.json")  # per-place metadata


In [None]:
# Load the per-place metadata; fall back to an empty mapping so later cells
# still run (and simply find nothing to process) when the file is missing.
if os.path.exists(META_FILE):
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(META_FILE, "r", encoding="utf-8") as f:
        meta_data = json.load(f)
    print(f"Loaded meta.json with {len(meta_data)} entries.")
else:
    print("meta.json not found!")
    meta_data = {}

In [None]:
# Full Macbeth ColorChecker chart geometry in chart units (source coordinates).
# The chart is a 4 x 6 grid of 2.5-unit square patches on a 2.75-unit pitch,
# inset 0.25 from the chart border. Patches are listed row-major (patch k is
# row k // 6, column k % 6). Each patch contributes four corners in the order
# (x0, y0), (x0 + 2.5, y0), (x0 + 2.5, y0 + 2.5), (x0, y0 + 2.5), so patch k
# occupies rows 4k .. 4k + 3 of the array (96 x 2, float32).
FULL_CELLCHART = np.float32([
    corner
    for top in 0.25 + 2.75 * np.arange(4)
    for left in 0.25 + 2.75 * np.arange(6)
    for corner in (
        [left, top],
        [left + 2.5, top],
        [left + 2.5, top + 2.5],
        [left, top + 2.5],
    )
])
# Outer outline of the chart in the same units and the same corner order.
MCCBOX = np.float32([[0.00, 0.00], [16.75, 0.00], [16.75, 11.25], [0.00, 11.25]])

def manual_perspective_transform(points, h):
    """Project 2-D points through a 3x3 homography.

    Args:
        points: Array-like of (x, y) pairs. Input that is not already 2-D is
            reshaped to (-1, 2).
        h: 3x3 homography as an ndarray.

    Returns:
        (N, 2) float array of projected points, with the homogeneous
        coordinate divided out.
    """
    pts = np.array(points)
    if pts.ndim != 2:
        pts = pts.reshape(-1, 2)
    homogeneous = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1)
    projected = homogeneous @ h.T
    return projected[:, :2] / projected[:, 2:3]

def get_patch_chroma(img, mcc_coord, scale=None):
    """Return the normalized chromaticity of the chart's gray patches.

    The chart outline ``MCCBOX`` is mapped onto the four image-space corners in
    ``mcc_coord`` with a perspective transform. Each of patches 18-23 (the gray
    row of ``FULL_CELLCHART``) is rasterized as a mask and averaged. The
    per-patch means are then averaged and normalized to sum to 1.

    Args:
        img: H x W x C image with C >= 3; channel order is preserved in the
            result (rawpy output is RGB).
        mcc_coord: Four (x, y) image-space corners, in the same order as the
            corners of ``MCCBOX``.
        scale: Factor applied to ``mcc_coord`` before use. ``None`` (default)
            keeps the legacy heuristic: halve the coordinates if any of them
            lies outside the image (assumed full-resolution coordinates used
            with a half-size decode).

    Returns:
        Length-3 float array summing to 1, or None if no patch covered any
        pixels or the averaged color is zero.
    """
    height, width = img.shape[:2]
    mcc_coord = np.asarray(mcc_coord, dtype=np.float32).reshape(4, 2)

    if scale is None:
        # NOTE(review): heuristic only -- full-resolution coordinates that happen
        # to fall inside the half-size frame are NOT scaled. Pass ``scale``
        # explicitly when the coordinate convention is known.
        if np.any(mcc_coord > np.array([width, height])):
            mcc_coord = mcc_coord * 0.5
    else:
        mcc_coord = mcc_coord * scale
    # cv2.getPerspectiveTransform only accepts float32 point arrays.
    mcc_coord = mcc_coord.astype(np.float32)

    # Raises cv2.error on invalid input rather than returning None.
    homography = cv2.getPerspectiveTransform(MCCBOX, mcc_coord)

    # NOTE(review): the brightest gray patch may be clipped in some captures;
    # saturated pixels are not excluded here.
    patch_means = []
    for idx in range(18, 24):  # gray row of the chart
        corners_dst = manual_perspective_transform(FULL_CELLCHART[4 * idx:4 * idx + 4], homography)
        mask = np.zeros((height, width), dtype=np.uint8)
        try:
            cv2.fillConvexPoly(mask, np.round(corners_dst).astype(np.int32), 1)
            if not mask.any():
                # Patch projects entirely outside the image.
                continue
            patch_means.append(cv2.mean(img, mask=mask)[:3])
        except cv2.error:
            continue

    if not patch_means:
        return None
    avg_color = np.mean(patch_means, axis=0)
    total = np.sum(avg_color)
    return avg_color / total if total > 0 else None

def extract_illuminants(meta_data, max_samples=200, seed=None):
    """Estimate the light-1 chromaticity for a random subset of LSMI places.

    For every sampled place, ``<LSMI_ROOT>/<place>/<place>_1.nef`` is decoded
    at half resolution in the camera's raw color space. The decode uses unit
    white balance, no auto-brightening and a linear 16-bit tone curve.
    ``get_patch_chroma`` then measures the chart located by
    ``meta_data[place]["MCCCoord"]["mcc1"]``.

    Args:
        meta_data: Mapping of place name -> per-place metadata dict.
        max_samples: Upper bound on the number of places processed. When there
            are more places, a random subset of this size is drawn.
        seed: Optional seed for the subset draw. ``None`` (default) uses the
            global ``random`` module state.

    Returns:
        DataFrame with columns ``r, g, b, place, light``, one row per
        successfully measured illuminant. Places are skipped when the raw file
        is missing, there is no ``mcc1`` entry, or decoding/measurement fails;
        failures are printed.
    """
    import random

    places = list(meta_data)
    if len(places) > max_samples:
        sampler = random if seed is None else random.Random(seed)
        places = sampler.sample(places, max_samples)

    print(f"Extracting illuminants from {len(places)} samples...")

    illuminants = []
    for place in tqdm(places):
        mcc_coords = meta_data[place].get("MCCCoord", {})
        raw_path = os.path.join(LSMI_ROOT, place, f"{place}_1.nef")
        # Check cheap preconditions first: decoding a NEF is the expensive step.
        if "mcc1" not in mcc_coords or not os.path.exists(raw_path):
            continue

        try:
            with rawpy.imread(raw_path) as raw:
                img_1 = raw.postprocess(
                    half_size=True,  # for speed
                    use_camera_wb=False,
                    user_wb=[1, 1, 1, 1],
                    no_auto_bright=True,
                    output_color=rawpy.ColorSpace.raw,
                    # rawpy defaults to gamma (2.222, 4.5) and 8-bit output. A
                    # non-linear curve distorts channel ratios, so request linear
                    # 16-bit data for chromaticity measurement.
                    gamma=(1, 1),
                    output_bps=16,
                )
            chroma = get_patch_chroma(img_1, np.float32(mcc_coords["mcc1"]))
        except Exception as e:
            # Best-effort batch job: report the failure and continue.
            print(f"Error processing {place}: {e}")
            continue

        if chroma is not None:
            illuminants.append({'r': chroma[0], 'g': chroma[1], 'b': chroma[2], 'place': place, 'light': 1})

    # Explicit columns keep the schema stable even when nothing was extracted.
    return pd.DataFrame(illuminants, columns=['r', 'g', 'b', 'place', 'light'])

In [None]:
# Run extraction over a random subset of places
df = extract_illuminants(meta_data, max_samples=200)
print("Extracted {} illuminants.".format(len(df)))
df.head()  # last expression in the cell: rendered by the notebook

## 1. Chromaticity Distribution Plot

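We plot the r–g chromaticity of the extracted illuminants, where r = R/(R+G+B) and g = G/(R+G+B). If `cluster_centers.npy` exists, its cluster centers are overlaid.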

In [None]:
if not df.empty:
    plt.figure(figsize=(10, 10))
    # Display colors: rescale each chromaticity so its brightest channel is 1;
    # the raw (r + g + b == 1) values would render dark.
    colors = df[['r', 'g', 'b']].values
    colors = colors / np.max(colors, axis=1, keepdims=True)

    # Label the samples so plt.legend() always has an entry, even when no
    # cluster-center file is present.
    plt.scatter(df['r'], df['g'], c=colors, s=50, edgecolors='k', alpha=0.8, label='Illuminants')
    plt.xlabel("Red Chromaticity (r)", fontsize=14)
    plt.ylabel("Green Chromaticity (g)", fontsize=14)
    plt.title("LSMI Illuminant Distribution (Sample)", fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.xlim(0, 1)
    plt.ylim(0, 1)

    # Overlay cluster centers if available.
    if os.path.exists("cluster_centers.npy"):
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load a file
        # you produced yourself.
        centers = np.load("cluster_centers.npy", allow_pickle=True)
        if centers.shape == ():
            # 0-d object array wrapping a Python object (e.g. a dict).
            centers = centers.item()

        if isinstance(centers, dict):
            labels = [str(key) for key in centers]
            center_points = np.atleast_2d(np.array(list(centers.values()), dtype=float))
        else:
            # atleast_2d so a single center stored as a 1-D vector is one point.
            center_points = np.atleast_2d(np.asarray(centers, dtype=float))
            labels = [f"C{i}" for i in range(len(center_points))]

        plt.scatter(center_points[:, 0], center_points[:, 1], c='red', marker='x', s=200, linewidths=3, label='Cluster Centers')
        for label, (cx, cy) in zip(labels, center_points[:, :2]):
            plt.annotate(label, (cx, cy), xytext=(5, 5), textcoords='offset points', fontsize=12, fontweight='bold')

    plt.legend()
    plt.show()
else:
    print("No data to plot.")

## 2. Color Map Visualization

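As in the reference notebook, we divide the observed r–g chromaticity range into a 20 × 20 grid and fill each occupied cell with the mean illuminant color of the samples that fall into it.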

In [None]:
def _chroma_bin_edges(values, n_bins):
    """Return ``n_bins + 1`` equal-width edges spanning *values*.

    A zero-width range (e.g. a single sample) is widened slightly so every bin
    has positive width and the axis limits stay valid.
    """
    lo, hi = float(np.min(values)), float(np.max(values))
    if hi <= lo:
        lo, hi = lo - 0.005, hi + 0.005
    return np.linspace(lo, hi, n_bins + 1)


if not df.empty:
    r = df['r'].values
    g = df['g'].values
    b = df['b'].values

    n_bins = 20
    r_bins = _chroma_bin_edges(r, n_bins)
    g_bins = _chroma_bin_edges(g, n_bins)

    # Bin index of every sample. np.digitize uses half-open [lo, hi) intervals,
    # so the maximum value lands past the last edge; clip it into the last bin
    # so it is drawn too.
    r_idx = np.clip(np.digitize(r, r_bins) - 1, 0, n_bins - 1)
    g_idx = np.clip(np.digitize(g, g_bins) - 1, 0, n_bins - 1)

    fig, ax = plt.subplots(figsize=(12, 10))

    # Visit only occupied cells instead of scanning the whole grid.
    for i, j in sorted(set(zip(r_idx.tolist(), g_idx.tolist()))):
        in_cell = (r_idx == i) & (g_idx == j)
        avg = np.array([r[in_cell].mean(), g[in_cell].mean(), b[in_cell].mean()])
        # Brightest channel = 1 (the per-sample chromaticities already sum to 1,
        # so normalizing by the sum would leave the colors dark).
        color = avg / avg.max()

        rect = patches.Rectangle(
            (r_bins[i], g_bins[j]),
            r_bins[i + 1] - r_bins[i],
            g_bins[j + 1] - g_bins[j],
            linewidth=0.5, edgecolor="black",
            facecolor=color, alpha=0.9,
        )
        ax.add_patch(rect)

    ax.set_xlim(r_bins[0], r_bins[-1])
    ax.set_ylim(g_bins[0], g_bins[-1])
    ax.set_xlabel("R", fontsize=14)
    ax.set_ylabel("G", fontsize=14)
    ax.set_title("Illuminant Color Map", fontsize=16)
    plt.show()