In [1]:
import torch
from torch.nn import functional as F
import cv2
import numpy as np
from pathlib import Path
from natsort import natsorted
import matplotlib.pyplot as plt
from typing import Literal

#### Develop RAW images
```bash
dcraw -q 3 -4 -T -o 1 data/door_stack/*.nef
```

#### Linearize rendered images

In [28]:
image_paths = natsorted(Path('data/door_stack/door_stack').glob('*.jpg'))
num_images = len(image_paths)
height, width = cv2.imread(str(image_paths[0]), cv2.IMREAD_UNCHANGED).shape[:2]
channel = 3
selected_index = torch.randperm(height * width)[:200].numpy().tolist()
selected_pixels = [cv2.cvtColor(cv2.imread(str(p), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB).reshape(-1, channel)[selected_index].flatten() for p in image_paths]
selected_pixels = np.stack(selected_pixels, axis=0).transpose()

In [29]:
exposure_times = np.arange(1, num_images + 1)
exposure_times = 2 ** (exposure_times - 1)/2048

In [30]:
z_min = 0.05
z_max = 0.95
def w_uniform(z, *_):
    return torch.where(torch.logical_and(z >= z_min, z <= z_max), 1, 0)
def w_tent(z, *_):
    return torch.where(torch.logical_and(z >= z_min, z <= z_max), torch.min(z, 1 - z), 0)
def w_gaussian(z, *_):
    return torch.where(torch.logical_and(z >= z_min, z <= z_max), torch.exp(-4*(z - 0.5)**2/0.5**2), 0)
def w_photon(z,t):
    return torch.where(torch.logical_and(z >= z_min, z <= z_max), t, 0)

In [67]:

from typing import Callable


w_functions = {
    'uniform': w_uniform,
    'tent': w_tent,
    'gaussian': w_gaussian,
    'photon': w_photon
}


def gsolve(Z, B, l, weighting_function: Callable):
    """
    Z: Tensor of shape (num_pixels, num_images)
    B: Tensor of shape (num_images,)
    l: Lambda, smoothness weight
    weighting_function: Function to compute the weighting for each pixel value
    Returns:
    g: Tensor of shape (256,) - log exposure corresponding to pixel value z
    lE: Tensor of shape (num_pixels,) - log irradiance at each pixel location
    """
    Z = Z.long()  # Ensure pixel values are indices
    n = 256
    num_pixels, num_images = Z.shape
    A_rows = num_pixels * num_images + n + 1
    A_cols = n + num_pixels

    A = torch.zeros((A_rows, A_cols), dtype=torch.float32)
    b = torch.zeros((A_rows), dtype=torch.float32)

    k = 0
    for i in range(num_pixels):
        for j in range(num_images):
            z_ij = Z[i, j]
            wij = weighting_function(z_ij.float()/255, torch.exp(B[j]))
            A[k, z_ij] = wij
            A[k, n + i] = -wij
            b[k] = wij * B[j]
            k += 1

    # Fix the curve by setting its middle value to 0
    A[k, 128] = 1
    k += 1

    # Smoothness equations
    for i in range(0, n - 2):
        w_i = weighting_function(torch.tensor(i / 255), None) if weighting_function != w_photon else 1
        A[k, i] = l * w_i
        A[k, i + 1] = -2 * l * w_i
        A[k, i + 2] = l * w_i
        k += 1

    # Solve the system using SVD
    b = b.view(-1, 1)
    print(f"Solving system of size A: {A.shape}, b: {b.shape}")
    x = torch.linalg.lstsq(A, b, driver = "gelsd").solution
    print(x.shape)
    x = x[:A_cols]

    g = x[:n].squeeze()
    lE = x[n:].squeeze()
    
    return g, lE

In [68]:
method = 'gaussian'
lamb = 0

g, lE = gsolve(torch.tensor(selected_pixels), torch.log(torch.tensor(exposure_times).float()), lamb, w_functions[method])
g

Solving system of size A: torch.Size([9857, 856]), b: torch.Size([9857, 1])
torch.Size([856, 1])


tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.7826e-08,
         3.2187e-05,  9.1158e-06,  3.0875e-05,  4.8608e-05, -2.7180e-05,
        -3.6851e-05,  2.6584e-05,  2.9802e-06, -2.9768e+00, -2.8815e+00,
        -2.8117e+00, -2.7179e+00, -2.6056e+00, -2.5550e+00, -2.5446e+00,
        -2.4420e+00, -2.3435e+00, -2.3265e+00, -2.2960e+00, -2.2154e+00,
        -2.1233e+00, -2.1100e+00, -2.0936e+00, -2.0419e+00, -1.9957e+00,
        -1.9719e+00, -1.9580e+00, -1.9166e+00, -1.9043e+00, -1.8516e+00,
        -1.8055e+00, -1.7841e+00, -1.7667e+00, -1.7449e+00, -1.7311e+00,
        -1.7008e+00, -1.6481e+00, -1.5968e+00, -1.5401e+00, -1.5495e+00,
        -1.5001e+00, -1.4992e+00, -1.4797e+00, -1.4236e+00, -1.4200e+00,
        -1.4026e+00, -1.3569e+00, -1.3647e+00, -1.3413e+00, -1.2937e+00,
        -1.2825e+00, -1.2741e+00, -1.2778e+00, -1.2215e+00, -1.1990e+00,
        -1.2587e+00, -1.2418e+00, -1.1585e+00, -1.1325e+00, -1.1279e+00,
        -1.1128e+00, -1.0888e+00, -1.0727e+00, -1.0