In [10]:
import sys
sys.path.append('/home/jaekim/ws/git/myHDR/source')

import os
import h5py
import torch
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm
import shutil
import imageio
import matplotlib.pyplot as plt

from utils import *

## Utilities

Plotting helpers for quick visual inspection of tensors and their value histograms.

In [11]:
def show_tensor(tensor, title=''):
    """Display ``tensor`` as an image with a title and a colorbar.

    Singleton dimensions are squeezed away and the remaining axes are then
    reversed with ``.T`` before plotting (e.g. a 3-D input of shape (A, B, C)
    is drawn as (C, B, A)).

    NOTE(review): the axis reversal presumably compensates for a (C, W, H)
    storage layout in the HDF5 files; confirm, otherwise the image is shown
    transposed. Recent PyTorch versions also warn about ``.T`` on tensors
    that are not 2-D.

    Args:
        tensor: array-like supporting ``.squeeze()`` and ``.T`` (torch tensor
            or numpy array).
        title: text shown above the figure.
    """
    plt.imshow(tensor.squeeze().T)
    plt.title(title)
    plt.colorbar()
    plt.show()

In [12]:
def show_hist(tensor, boundary, bin=256):
    """Plot a histogram of every value in ``tensor`` and mark boundary values.

    Args:
        tensor: torch tensor of any shape. It is detached, moved to the CPU
            and flattened before plotting.
        boundary: iterable of threshold values. Each one is drawn as a dashed
            red vertical line labelled with its value, e.g. the ``boundary``
            list returned by ``get_map_by_hist_kmean``.
        bin: number of histogram bins. The parameter name is kept for backward
            compatibility even though it shadows the built-in ``bin``.
    """
    # Descriptive local name: calling this ``np`` would shadow the numpy module.
    values = tensor.detach().cpu().numpy().ravel()
    plt.hist(values, bins=bin)
    plt.grid(True)

    has_boundary = False
    for b in boundary:
        plt.axvline(x=b, color='red', linestyle='--', label=f'Boundary: {b:.4f}')
        has_boundary = True

    # axvline labels only become visible once a legend is drawn.
    if has_boundary:
        plt.legend()

    plt.show()

## Map generation

In [13]:
def get_map_by_hist_kmean(diff_tensor, k=2):
    """Quantize a map into ``k`` levels by 1-D k-means on its pixel values.

    Each pixel value is clustered on its own; spatial position is ignored.
    Cluster ids are relabeled so that label 0 is the cluster with the smallest
    center and label ``k - 1`` is the one with the largest.

    Args:
        diff_tensor: torch tensor, presumably a (1, H, W) or (H, W)
            difference map (TODO: confirm with callers). It is detached and
            moved to the CPU.
        k: number of clusters / output levels. Must be <= 256 because the
            mask is stored as uint8.

    Returns:
        mask: uint8 numpy array with the same shape as ``diff_tensor``,
            holding labels in ``[0, k)`` ordered by cluster center.
        boundary: list of ``k - 1`` midpoints between consecutive sorted
            cluster centers.
    """
    values = diff_tensor.detach().cpu().numpy().reshape(-1, 1)

    # K-means clustering on the scalar pixel values.
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(values)  # shape: (num_pixels,)
    # ravel() rather than squeeze(): keeps a 1-D array even when k == 1.
    centers = kmeans.cluster_centers_.ravel()  # shape: (k,)

    # Cluster ids in ascending order of their center.
    order = np.argsort(centers)
    sorted_centers = centers[order]
    boundary = [(sorted_centers[i] + sorted_centers[i + 1]) / 2 for i in range(k - 1)]

    # rank[old_id] is that cluster's position in ascending-center order, so
    # indexing with the raw labels relabels every pixel in one vectorized step.
    rank = np.empty_like(order)
    rank[order] = np.arange(k)
    mask = rank[labels].reshape(diff_tensor.shape).astype(np.uint8)

    return mask, boundary

In [14]:
def get_sat_map(label, mid_exp):
    """Flag pixels that would be clipped when ``label`` is scaled by ``mid_exp``.

    A pixel is flagged when ANY channel of ``label * mid_exp`` lies outside
    ``[0, 1]``, i.e. when clamping to that range would change it.

    Args:
        label: float tensor whose first axis is the channel axis (the
            reduction runs over ``dim=0``). Presumably the linear HDR ground
            truth shaped (C, H, W); TODO confirm.
        mid_exp: scalar multiplier applied before clamping. The caller passes
            ``exp[1]`` from the file's 'EXP' dataset. NOTE(review): this
            treats it as a linear gain; confirm EXP is not stored in stops.

    Returns:
        Tensor with the same shape, dtype and device as ``label``: 1 where the
        pixel saturates, 0 otherwise. The flag is repeated across the channel
        axis, so every channel of a pixel carries the same value.
    """
    scaled = label * mid_exp
    # Test the clamp condition directly rather than comparing
    # clamp(label * e) / e against label: that multiply/divide round trip is
    # not exact in floating point and would flag pixels that are not clipped.
    clipped = ((scaled < 0) | (scaled > 1)).any(dim=0, keepdim=True)
    return torch.where(clipped, torch.ones_like(label), torch.zeros_like(label))


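## Execution

Compute a saturation map for every training `.h5` file and store it in place under the `MAP_sat` key.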

In [16]:
h5_dir = '/home/jaekim/ws/data/Kalantari/HDF/aligned/Training'

# os.listdir returns entries in filesystem-dependent order; sorting makes the
# processing order and the progress log reproducible.
h5_paths = sorted(os.path.join(h5_dir, f) for f in os.listdir(h5_dir) if f.endswith('.h5'))

for i, path in enumerate(h5_paths):
    file_name = os.path.basename(path)
    print(f'Processing {i + 1}/{len(h5_paths)}: {file_name}')
    # Opened in 'r+' mode because the computed map is written back into the same file.
    with h5py.File(path, 'r+') as f:
        # 'IN' is read as 3-channel slices along axis 0 (channels 9-11, 12-14, 15-17).
        data1 = f['IN'][3*3:4*3, :, :]  # short after gain adjustment
        data2 = f['IN'][4*3:5*3, :, :]  # mid after gain adjustment
        data3 = f['IN'][5*3:6*3, :, :]  # long after gain adjustment
        label = f['GT'][:]
        exp = f['EXP'][:]

        data1 = torch.from_numpy(data1).float()
        data2 = torch.from_numpy(data2).float()
        data3 = torch.from_numpy(data3).float()
        label = torch.from_numpy(label).float()

        # TODO: put map function here (before tonemapping)
        # exp[1] is used as the mid-exposure multiplier.
        sat_map = get_sat_map(label, exp[1])

        # TODO: put map function here (after tonemapping)
        #data1 = tonemap(data1, 'mu')
        #data2 = tonemap(data2, 'mu')
        #data3 = tonemap(data3, 'mu')
        #label = tonemap(label, 'mu')

        # TODO : write map to h5
        write_map_key = 'MAP_sat'
        write_map_data = sat_map

        # Replace a map left over from a previous run.
        if write_map_key in f:
            del f[write_map_key]
        # Always write write_map_data, so switching to another map only needs
        # the assignment above. Convert explicitly to a CPU numpy array for h5py.
        f.create_dataset(write_map_key, data=write_map_data.cpu().numpy())

Processing 1/74: 001.h5
Processing 2/74: 002.h5
Processing 3/74: 003.h5
Processing 4/74: 004.h5
Processing 5/74: 005.h5
Processing 6/74: 006.h5
Processing 7/74: 007.h5
Processing 8/74: 008.h5
Processing 9/74: 009.h5
Processing 10/74: 010.h5
Processing 11/74: 011.h5
Processing 12/74: 012.h5
Processing 13/74: 013.h5
Processing 14/74: 014.h5
Processing 15/74: 015.h5
Processing 16/74: 016.h5
Processing 17/74: 017.h5
Processing 18/74: 018.h5
Processing 19/74: 019.h5
Processing 20/74: 020.h5
Processing 21/74: 021.h5
Processing 22/74: 022.h5
Processing 23/74: 023.h5
Processing 24/74: 024.h5
Processing 25/74: 025.h5
Processing 26/74: 026.h5
Processing 27/74: 027.h5
Processing 28/74: 028.h5
Processing 29/74: 029.h5
Processing 30/74: 030.h5
Processing 31/74: 031.h5
Processing 32/74: 032.h5
Processing 33/74: 033.h5
Processing 34/74: 034.h5
Processing 35/74: 035.h5
Processing 36/74: 036.h5
Processing 37/74: 037.h5
Processing 38/74: 038.h5
Processing 39/74: 039.h5
Processing 40/74: 040.h5
Processin