## Introduction

In this lab, we will manually reconstruct an HDR image from an SDR source and a Gain Map.
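This version uses **tile-based processing**: the floating-point colour math runs on small tiles instead of on the whole image at once. This keeps peak memory usage low enough to handle large images.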

We will:
1. Load the SDR HEIC image.
2. Load the Gain Map PNG.
3. Apply the Apple Gain Map formula (in memory-efficient tiles).
4. Save the result as a 16-bit PNG with PQ encoding and a `cICP` metadata chunk that signals BT.2020 primaries and the PQ transfer function.

In [None]:
import colour
import io
import numpy as np
import os
import png
import struct
import zlib
import xml.etree.ElementTree as ET
import gc

from PIL import Image
from pillow_heif import register_heif_opener

# Set up data directory - works in mybinder and locally.
# `__file__` is undefined inside a notebook, so the *string* '__file__' is used
# as a placeholder file name; its dirname is therefore the working directory.
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath('__file__')), 'data')

if os.path.exists(DATA_DIR):
    print(f"Data directory found: {DATA_DIR}")
else:
    os.makedirs(DATA_DIR)
    print(f"Created data directory: {DATA_DIR}")

# DATA_DIR is guaranteed to exist here (os.makedirs raises on failure),
# so its contents can be listed directly.
files = os.listdir(DATA_DIR)
if not files:
    print("Data directory is empty. Please upload your HEIC and gain map PNG files.")
else:
    print(f"Files in data directory: {files}")

## PNG Binary Helpers

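Common Python imaging libraries (Pillow, pypng) cannot write the `cICP` chunk that HDR PNGs use to declare their colour encoding. We therefore encode the 16-bit PNG normally and then inject this chunk by hand, directly after the `IHDR` header chunk (the PNG specification requires `cICP` to appear before the image data). The chunk tells the operating system (Windows/macOS) that the image uses BT.2020 colour primaries and the PQ transfer function.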

In [None]:
# Helper Functions (PNG & Metadata)
def create_png_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Serialize one PNG chunk.

    Layout: 4-byte big-endian length of ``data`` | ``chunk_type`` | ``data``
    | 4-byte big-endian CRC-32. The CRC covers the type and data fields but
    not the length field.
    """
    type_and_data = chunk_type + data
    checksum = zlib.crc32(type_and_data) & 0xFFFFFFFF
    return b"".join((
        struct.pack(">I", len(data)),
        type_and_data,
        struct.pack(">I", checksum),
    ))

def save_pq_png(filename, width, height, data_16bit):
    """Save a 16-bit RGB PNG tagged as BT.2020 / PQ via a cICP chunk.

    The pixels are encoded with pypng. A cICP chunk is then spliced in
    directly after the IHDR chunk, because the PNG specification requires
    cICP to precede PLTE and IDAT.

    Args:
        filename: Destination path of the PNG file.
        width: Image width in pixels.
        height: Image height in pixels.
        data_16bit: Rows of interleaved RGB samples in 0..65535, in any form
            accepted by ``png.Writer.write``. The caller in this notebook
            passes a ``(height, width * 3)`` uint16 array.

    Raises:
        ValueError: If the encoder output does not start with a PNG
            signature followed by an IHDR chunk.
    """
    writer = png.Writer(width=width, height=height, bitdepth=16, greyscale=False)
    buffer = io.BytesIO()
    writer.write(buffer, data_16bit)

    # cICP Chunk: Primaries=9 (BT.2020), Transfer=16 (PQ), Matrix=0, Range=1
    cicp_chunk = create_png_chunk(b'cICP', b'\x09\x10\x00\x01')

    # Zero-copy view of the encoded file. getvalue() plus slicing would
    # duplicate a buffer that can be hundreds of MB at full resolution.
    png_view = buffer.getbuffer()
    try:
        signature_len = 8
        if (len(png_view) < signature_len + 8
                or bytes(png_view[:signature_len]) != b'\x89PNG\r\n\x1a\n'
                or bytes(png_view[signature_len + 4:signature_len + 8]) != b'IHDR'):
            raise ValueError("Encoder output does not start with a PNG signature and IHDR chunk")

        # A chunk is: 4-byte length + 4-byte type + data + 4-byte CRC.
        ihdr_data_len = struct.unpack_from(">I", png_view, signature_len)[0]
        insert_at = signature_len + 4 + 4 + ihdr_data_len + 4  # 33 for a standard IHDR

        with open(filename, 'wb') as f:
            f.write(png_view[:insert_at])   # Signature and IHDR
            f.write(cicp_chunk)             # cICP Metadata
            f.write(png_view[insert_at:])   # Remaining chunks (IDAT ... IEND)
    finally:
        # Release the exported buffer so the BytesIO can be freed.
        png_view.release()
    print(f"✓ Saved: {filename}")

print("PNG helper functions loaded successfully.")

## Color Science & Metadata Logic

Here we handle the core logic. We need to:

- **Parse XMP**: Apple stores the `HDRGainMapHeadroom` value in XMP (XML) metadata embedded in the gain map image. This value describes how much brighter the HDR rendition can be than the SDR one.
- **Adapt Color Space**: Move from the phone's display space (P3) to the standard HDR space (BT.2020).

In [None]:
# Color Science Logic
def adaptation_to_bt2020(image_array, input_colorspace):
    """Re-express RGB values from ``input_colorspace`` in BT.2020 primaries.

    Both CCTF decoding and encoding are disabled, so the input is expected
    to be linear light already, and the result is linear as well. Bradford
    chromatic adaptation is requested for any white-point difference.
    """
    target_space = colour.models.RGB_COLOURSPACE_BT2020
    conversion_options = {
        "chromatic_adaptation_transform": "Bradford",
        "apply_cctf_decoding": False,
        "apply_cctf_encoding": False,
    }
    return colour.RGB_to_RGB(
        image_array,
        input_colourspace=input_colorspace,
        output_colourspace=target_space,
        **conversion_options,
    )

def parse_headroom_from_xmp(pil_image):
    """Extract the HDRGainMapHeadroom value from an image's XMP metadata.

    Apple writes the value in the ``http://ns.apple.com/HDRGainMap/1.0/``
    namespace. RDF/XMP allows two serializations, and both are accepted:

    * a child element: ``<HDRGainMap:HDRGainMapHeadroom>2.9</...>``
    * an attribute on ``rdf:Description``: ``HDRGainMap:HDRGainMapHeadroom="2.9"``

    Args:
        pil_image: A PIL image whose ``info`` dict may hold the XMP packet
            (str or bytes) under ``"xmp"`` or ``"XML:com.adobe.xmp"``.

    Returns:
        The headroom as a float. Returns 1.0 (no gain applied) when the
        metadata is missing, cannot be parsed, or lacks the key; a warning
        is printed in each of these fallback cases.
    """
    default_headroom = 1.0
    xmp_data = pil_image.info.get("xmp") or pil_image.info.get("XML:com.adobe.xmp")
    if not xmp_data:
        print("Warning: No XMP metadata found. Using default headroom of 1.0")
        return default_headroom

    # Some writers pad the XMP packet with NUL bytes, which are illegal in XML.
    if isinstance(xmp_data, (bytes, bytearray)):
        xmp_data = bytes(xmp_data).rstrip(b"\x00")
    else:
        xmp_data = str(xmp_data).rstrip("\x00")

    # ElementTree expands namespaced names to "{uri}local" for tags and attributes alike.
    headroom_key = "{http://ns.apple.com/HDRGainMap/1.0/}HDRGainMapHeadroom"
    try:
        root = ET.fromstring(xmp_data)
        for elem in root.iter():
            if elem.tag == headroom_key and elem.text and elem.text.strip():
                return float(elem.text)
            attr_value = elem.get(headroom_key)
            if attr_value:
                return float(attr_value)
    except (ET.ParseError, ValueError, TypeError) as e:
        print(f"Warning: Error parsing XMP metadata: {e}")
        return default_headroom

    print("Warning: HDRGainMapHeadroom not found in XMP metadata. Using default headroom of 1.0")
    return default_headroom

print("Color science functions loaded successfully.")

## The Reconstruction Pipeline

This is the heart of the script. We implement the Apple Gain Map formula:

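$$L_{\text{hdr}} = L_{\text{sdr}} \times \left(1 + (\text{Headroom} - 1) \times \text{GainMap}\right)$$

Here $L_{\text{sdr}}$ and $\text{GainMap}$ are *linear* values. The SDR image is decoded with the Display P3 transfer function, and the gain map with the inverse BT.709 OETF.

This formula recovers the HDR brightness values from the standard-dynamic-range image and the gain map. To keep memory usage low, we apply it one tile at a time.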

In [None]:
def process_tile(sdr_tile, gain_tile, headroom, sdr_colorspace, ref_white, peak_nits):
    """
    Run one tile through the HDR pipeline and return PQ-encoded BT.2020 values.

    Args:
        sdr_tile: SDR pixel values in ``sdr_colorspace``'s encoding, assumed
            normalized to [0, 1].
        gain_tile: Gain-map values assumed normalized to [0, 1], with the
            same shape as ``sdr_tile``.
        headroom: Gain-map headroom; a gain of 1.0 scales SDR by this factor.
        sdr_colorspace: colour ``RGB_Colourspace`` of the SDR data.
        ref_white: Luminance in nits assigned to linear SDR 1.0.
        peak_nits: Absolute luminance ceiling applied before PQ encoding.

    Returns:
        ``colour.models.eotf_inverse_BT2100_PQ`` output for the clipped
        absolute-luminance BT.2020 values.
    """
    # Undo the transfer functions so the gain is applied in linear light.
    linear_sdr = sdr_colorspace.cctf_decoding(sdr_tile)
    linear_gain = colour.models.oetf_inverse_BT709(gain_tile)

    # Apple gain-map formula: multiply SDR by 1 + (headroom - 1) * gain.
    boost = 1.0 + (headroom - 1.0) * linear_gain
    linear_bt2020 = adaptation_to_bt2020(linear_sdr * boost, sdr_colorspace)

    # Relative linear -> absolute nits, limit to the target peak, then PQ.
    nits = linear_bt2020 * ref_white
    return colour.models.eotf_inverse_BT2100_PQ(np.clip(nits, 0, peak_nits))

def apply_apple_gain_map_tiled(input_path, gain_map_path, ref_white=203, peak_nits=1000, tile_size=512):
    """
    Apply Apple's gain map formula using tile-based processing to reduce memory usage.

    Only the floating-point intermediates are tiled. The decoded SDR image,
    the resized gain map and the uint16 output array are still held at full
    resolution.

    Args:
        input_path: Path to SDR HEIC image
        gain_map_path: Path to gain map PNG
        ref_white: Reference white level in nits (default: 203); linear SDR
            1.0 is mapped to this luminance.
        peak_nits: Peak brightness in nits (default: 1000); brighter values
            are clipped before PQ encoding.
        tile_size: Edge length of the square tiles to process (default: 512)

    Returns:
        HDR image in PQ encoding as a (height, width, 3) uint16 array

    Raises:
        ValueError: If ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be a positive integer, got {tile_size}")

    sdr_colorspace = colour.models.RGB_COLOURSPACE_DISPLAY_P3
    register_heif_opener()

    sdr_img = gain_map = gain_map_resized = None
    try:
        print(f"Loading SDR image from: {input_path}")
        # 1. Load Images
        sdr_img = Image.open(input_path)
        width, height = sdr_img.size
        print(f"  SDR image size: {width}x{height}")

        print(f"Loading gain map from: {gain_map_path}")
        gain_map = Image.open(gain_map_path)
        print(f"  Gain map size: {gain_map.size}")

        # Resize gain map to match SDR
        print("Resizing gain map to match SDR image...")
        gain_map_resized = gain_map.resize(sdr_img.size, resample=Image.BICUBIC)

        # Get headroom from the original gain map's metadata
        headroom = parse_headroom_from_xmp(gain_map)
        print(f"Headroom detected: {headroom}")

        # Calculate number of tiles (ceiling division)
        tiles_x = (width + tile_size - 1) // tile_size
        tiles_y = (height + tile_size - 1) // tile_size
        total_tiles = tiles_x * tiles_y

        print(f"\nProcessing image in {tiles_x}x{tiles_y} = {total_tiles} tiles of size {tile_size}x{tile_size}")
        print("This will reduce peak memory usage.\n")

        output = np.zeros((height, width, 3), dtype=np.uint16)

        tile_count = 0
        for y in range(0, height, tile_size):
            y_end = min(y + tile_size, height)
            for x in range(0, width, tile_size):
                x_end = min(x + tile_size, width)
                box = (x, y, x_end, y_end)
                tile_count += 1

                if tile_count % 10 == 0 or tile_count == total_tiles:
                    print(f"  Processing tile {tile_count}/{total_tiles} ({100*tile_count/total_tiles:.1f}%)...", end='\r')

                # Force 3-channel RGB so RGBA / greyscale SDR sources match the gain tile.
                sdr_tile = np.array(sdr_img.crop(box).convert("RGB")).astype(np.float32) / 255.0
                gain_tile = np.array(gain_map_resized.crop(box)).astype(np.float32) / 255.0

                # Normalize the gain map to 3 channels: drop alpha (LA / RGBA),
                # then replicate single-channel maps across R, G and B.
                if gain_tile.ndim == 2:
                    gain_tile = np.stack([gain_tile] * 3, axis=-1)
                else:
                    if gain_tile.shape[2] in (2, 4):
                        gain_tile = gain_tile[..., :-1]
                    if gain_tile.shape[2] == 1:
                        gain_tile = np.concatenate([gain_tile] * 3, axis=-1)

                hdr_tile = process_tile(sdr_tile, gain_tile, headroom, sdr_colorspace, ref_white, peak_nits)

                # Round to the nearest code value; plain astype() truncates,
                # which biases every sample down by half an LSB.
                output[y:y_end, x:x_end] = np.rint(np.clip(hdr_tile, 0, 1) * 65535).astype(np.uint16)

                # Force garbage collection every 20 tiles to free memory
                if tile_count % 20 == 0:
                    gc.collect()

        print(f"\n  Processing tile {total_tiles}/{total_tiles} (100.0%)... Done!")
        print("HDR reconstruction complete!")
        return output
    finally:
        # Close every image that was opened, even if processing failed.
        for img in (sdr_img, gain_map, gain_map_resized):
            if img is not None:
                img.close()
        gc.collect()

print("Tiled gain map application function loaded successfully.")

## Execute the Pipeline

Now we'll run the complete pipeline with memory-efficient tile processing.

Make sure you have:
1. `112_115.HEIC` - The source SDR image
2. `112_115-urn_com_apple_photo_2020_aux_hdrgainmap.png` - The gain map

Both files should be in the `data/` folder.

In [None]:
# Input and output locations inside the data directory.
input_image_path = os.path.join(DATA_DIR, "112_115.HEIC")
gain_map_path = os.path.join(DATA_DIR, "112_115-urn_com_apple_photo_2020_aux_hdrgainmap.png")
output_filename = os.path.join(DATA_DIR, "output_112_115_PQ.png")

for banner_line in ("=" * 60, "HDR GAIN MAP RECONSTRUCTION PIPELINE", "=" * 60):
    print(banner_line)

# Only run the pipeline when both inputs are present.
if not os.path.exists(input_image_path):
    print(f"❌ ERROR: SDR image not found at: {input_image_path}")
    print("   Please upload '112_115.HEIC' to the data/ folder.")
elif not os.path.exists(gain_map_path):
    print(f"❌ ERROR: Gain map not found at: {gain_map_path}")
    print("   Please upload '112_115-urn_com_apple_photo_2020_aux_hdrgainmap.png' to the data/ folder.")
else:
    print("✓ Input files found!\n")

    try:
        # tile_size trades memory for speed: 512 is a good balance;
        # drop to 256 or 128 if the kernel still runs out of memory.
        hdr_uint16 = apply_apple_gain_map_tiled(
            input_image_path,
            gain_map_path,
            ref_white=203,
            peak_nits=1000,
            tile_size=512,
        )

        print("\nPreparing output for saving...")
        h, w, c = hdr_uint16.shape
        # pypng wants one flat row of interleaved R,G,B samples per scanline.
        flat_rows = hdr_uint16.reshape(h, w * c)

        print(f"Saving HDR output to: {output_filename}")
        save_pq_png(output_filename, w, h, flat_rows)

        # Release the large arrays before reporting.
        del hdr_uint16, flat_rows
        gc.collect()

        print("\n" + "=" * 60)
        print("✓ SUCCESS! HDR reconstruction complete.")
        print(f"✓ Output saved to: {output_filename}")
        print("=" * 60)

        for stat_line in (
            "\nOutput Statistics:",
            f"  Resolution: {w}x{h}",
            "  Bit depth: 16-bit",
            "  Color space: BT.2020",
            "  Transfer function: PQ (Perceptual Quantizer)",
        ):
            print(stat_line)
        print(f"  File size: {os.path.getsize(output_filename) / (1024*1024):.2f} MB")

    except Exception as e:
        import traceback
        print(f"\n❌ ERROR during processing: {e}")
        traceback.print_exc()

## Download the Output

If you're running this on mybinder.org, you can download the output file by:
1. Looking in the file browser on the left
2. Navigating to the `data/` folder
3. Right-clicking on `output_112_115_PQ.png` and selecting "Download"
