In [1]:
import torch
from torch.nn import functional as F
import cv2
import numpy as np
from pathlib import Path
from natsort import natsorted
import matplotlib.pyplot as plt
from typing import Literal

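#### Develop RAW images
Convert the `.nef` RAW files to linear 16-bit TIFFs with `dcraw` (`-q 3`: AHD demosaicing, `-4`: linear 16-bit output, `-T`: write TIFF instead of PPM, `-o 1`: sRGB output primaries):
```bash
dcraw -q 3 -4 -T -o 1 data/door_stack/*.nef
```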

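#### Linearize rendered images
The rendered JPEGs are non-linear. Recover the camera response curve $g$ by weighted least squares over a random subset of pixels, sampled at the same locations in every exposure (Debevec & Malik, 1997).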

In [3]:
# Sample a fixed random subset of pixel locations and read the same locations
# from every exposure. The sampler is seeded so "Restart & Run All" reproduces
# the same least-squares solve.
SEED = 0
NUM_SAMPLES = 50

image_paths = natsorted(Path('data/door_stack').glob('*.jpg'))
if not image_paths:
    raise FileNotFoundError("no .jpg images found in data/door_stack")
num_images = len(image_paths)


def _read_rgb(path):
    """Read ``path`` unchanged with OpenCV and return it converted BGR -> RGB."""
    image = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if image is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"could not read image {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


first_image = _read_rgb(image_paths[0])
height, width = first_image.shape[:2]
channel = 3
# The response solve one-hot encodes sample values into 256 bins, so samples must be 8-bit codes.
assert first_image.dtype == np.uint8, f"expected 8-bit images, got {first_image.dtype}"

generator = torch.Generator().manual_seed(SEED)
selected_index = torch.randperm(height * width, generator=generator)[:NUM_SAMPLES].numpy().tolist()
# Per image: (NUM_SAMPLES*channel,) values; stacked and flattened exposure-major -> K*N*C.
selected_pixels = [_read_rgb(p).reshape(-1, channel)[selected_index].flatten() for p in image_paths]
selected_pixels = np.stack(selected_pixels, axis=0).flatten()

In [4]:
# Frame k (0-based, in natsorted order) is assigned exposure 2**k / 2048,
# i.e. 1/2048 for the first frame and doubling per frame (presumably seconds).
exposure_times = np.power(2, np.arange(num_images)) / 2048

In [5]:
z_min = 0.05  # normalized codes below this are treated as unreliable (weight 0)
z_max = 0.95  # normalized codes above this are treated as unreliable (weight 0)


def _in_range(z):
    """Boolean mask of the entries of ``z`` inside the trusted band [z_min, z_max]."""
    return (z >= z_min) & (z <= z_max)


# All weights take (z, t) so callers can use one signature; only w_photon uses t.
def w_uniform(z, *_):
    """Weight 1 inside [z_min, z_max] and 0 outside, returned in ``z``'s dtype."""
    return _in_range(z).to(z.dtype)


def w_tent(z, *_):
    """Tent weight min(z, 1 - z) inside [z_min, z_max], 0 outside."""
    return torch.where(_in_range(z), torch.minimum(z, 1 - z), 0)


def w_gaussian(z, *_):
    """Gaussian weight exp(-4 (z - 0.5)^2 / 0.5^2) inside [z_min, z_max], 0 outside."""
    return torch.where(_in_range(z), torch.exp(-4 * (z - 0.5) ** 2 / 0.5 ** 2), 0)


def w_photon(z, t):
    """Photon-noise weight: the exposure time ``t`` inside [z_min, z_max], 0 outside."""
    return torch.where(_in_range(z), t, 0)

In [6]:

w_functions = {
    'uniform': w_uniform,
    'tent': w_tent,
    'gaussian': w_gaussian,
    'photon': w_photon
}


def construct_Ab(image_tensor_flatten: torch.Tensor,  # K*N*C
                 exposure_times: torch.Tensor,  # K
                 lamb: float,
                 w_method: Literal['uniform', 'tent', 'gaussian', 'photon'] = 'uniform'):
    """Build the weighted least-squares system ``A @ v ≈ b`` for the response curve.

    The unknown ``v`` has ``256 + N*C`` entries. ``v[:256]`` is the curve ``g``
    at every 8-bit code. ``v[256:]`` holds one unknown per sampled pixel/channel
    (presumably its log radiance).

    Data rows, one per sample j in exposure k with code z:
        w(z) * (g(z) - v_j) = w(z) * log(t_k)
    Smoothness rows, one per code z (only interior codes 1..254 are non-zero):
        lamb * w(z) * (g(z-1) - 2 g(z) + g(z+1)) = 0

    Args:
        image_tensor_flatten: integer codes in [0, 255], ordered exposure-major,
            i.e. K consecutive runs of N*C samples (length K*N*C).
        exposure_times: the K exposure times, in the same order as the runs.
        lamb: smoothness-regularization strength.
        w_method: key of ``w_functions`` selecting the per-code weight.

    Returns:
        (A, b), with A of shape (K*N*C + 256, 256 + N*C) and b of shape (K*N*C + 256,).

    Raises:
        ValueError: if the samples cannot be split evenly across the K exposures.
        KeyError: if ``w_method`` is not a key of ``w_functions``.
    """
    num_codes = 256
    k = exposure_times.numel()
    num_elements = image_tensor_flatten.numel()
    if k == 0 or num_elements % k != 0:
        raise ValueError(f"{num_elements} samples cannot be split evenly across {k} exposures")
    nxc = num_elements // k
    weight_fn = w_functions[w_method]
    codes = image_tensor_flatten.long()

    # Every run of nxc consecutive samples shares one exposure time t_k.
    t_per_sample = exposure_times.float().repeat_interleave(nxc)  # K*N*C
    w_pixels = weight_fn(codes.float() / 255, t_per_sample).float()  # K*N*C

    # Data term: +w in the g(z) column, -w in the column of the sample's own unknown.
    image_tensor_one_hot = F.one_hot(codes, num_classes=num_codes).float()  # (K*N*C)x256
    image_indexs = torch.arange(nxc).repeat(k)  # K*N*C: sample index within its exposure
    image_indexs_one_hot = F.one_hot(image_indexs, num_classes=nxc).float()  # (K*N*C)x(N*C)
    A1 = image_tensor_one_hot * w_pixels.unsqueeze(1)
    A2 = image_indexs_one_hot * -w_pixels.unsqueeze(1)
    A12 = torch.cat((A1, A2), dim=1)  # (K*N*C)x(256+N*C)

    # The model is g(z) = log L + log t, so the right-hand side is w * log(t), not w * t.
    b = w_pixels * torch.log(t_per_sample)  # K*N*C
    b = torch.cat((b, torch.zeros(num_codes)), dim=0)  # K*N*C+256

    # Second-difference operator with the [1, -2, 1] stencil centred on column z.
    # Only interior codes have a full stencil. Truncated rows at z=0 and z=255
    # would implicitly pin g(-1) = g(256) = 0, so those rows are left empty.
    laplacian = torch.zeros((num_codes, num_codes))
    interior = torch.arange(1, num_codes - 1)
    laplacian[interior, interior - 1] = 1.0
    laplacian[interior, interior] = -2.0
    laplacian[interior, interior + 1] = 1.0

    if w_method == 'photon':
        # w_photon depends on the exposure time rather than the code, so smoothness
        # rows are weighted uniformly. A tensor (not the scalar 1) keeps unsqueeze valid.
        w_z = torch.ones(num_codes)
    else:
        w_z = weight_fn(torch.arange(num_codes).float() / 255).float()
    w_z = w_z * lamb  # 256

    A3 = laplacian * w_z.unsqueeze(1)  # 256x256
    A4 = torch.zeros((num_codes, nxc))  # smoothness rows do not involve per-sample unknowns
    A34 = torch.cat((A3, A4), dim=1)  # 256x(256+N*C)

    A = torch.cat((A12, A34), dim=0)  # (K*N*C+256)x(256+N*C); solve A @ v ≈ b
    return A, b

In [12]:
method = 'gaussian'  # which entry of w_functions weights the equations
lamb = 0.1           # smoothness-regularization strength

A, b = construct_Ab(
    torch.as_tensor(selected_pixels, dtype=torch.long),    # K*N*C sampled codes
    torch.as_tensor(exposure_times, dtype=torch.float32),  # K exposure times
    lamb,
    w_method=method,
)

In [13]:
# Least-squares solution of A @ v ≈ b. The first 256 entries correspond to the
# one-hot code columns (the curve g); the remaining N*C entries are the per-sample unknowns.
v = torch.linalg.lstsq(A, b, rcond=None).solution  # (256+N*C)