In [4]:
!pip install rawpy

Collecting rawpy
  Downloading rawpy-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)
Downloading rawpy-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m603.1 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: rawpy
Successfully installed rawpy-0.24.0


In [5]:
import numpy as np
import cv2
from scipy import ndimage
import rawpy
import imageio
import os
import glob
from scipy.signal import convolve2d

class HDRPlus:
    def __init__(self, tile_size=16, search_region=4, pyramid_levels=3, num_frames=8):
        """
        Initialize the HDR+ pipeline
        
        Parameters:
        -----------
        tile_size: int
            Size of tiles for alignment and merging (default: 16)
        search_region: int
            Size of the search region around each tile (default: 4)
        pyramid_levels: int
            Number of levels in the Gaussian pyramid for alignment (default: 3)
        num_frames: int
            Number of frames to process in a burst (default: 8)
        """
        self.tile_size = tile_size
        self.search_region = search_region
        self.pyramid_levels = pyramid_levels
        self.num_frames = num_frames

    def load_raw_burst(self, file_pattern):
        """
        Load a burst of raw images
        
        Parameters:
        -----------
        file_pattern: str
            Pattern to glob for raw files
            
        Returns:
        --------
        list of raw images
        """
        files = sorted(glob.glob(file_pattern))[:self.num_frames]
        if len(files) == 0:
            raise ValueError(f"No files found with pattern: {file_pattern}")
        
        raw_images = []
        for file in files:
            with rawpy.imread(file) as raw:
                # Get raw data without processing
                raw_data = raw.raw_image.copy()
                # Store raw object metadata
                bayer_pattern = raw.raw_pattern
                black_level = raw.black_level_per_channel[0]  # Assuming same black level for all channels
                white_level = raw.white_level
                
                raw_images.append({
                    'data': raw_data,
                    'bayer_pattern': bayer_pattern,
                    'black_level': black_level,
                    'white_level': white_level,
                    'camera_white_balance': raw.camera_white_balance,
                    'color_matrix': raw.color_matrix
                })
        
        print(f"Loaded {len(raw_images)} raw frames")
        return raw_images

    def preprocess_raw(self, raw_image):
        """
        Preprocess raw image for alignment
        
        Parameters:
        -----------
        raw_image: dict
            Raw image data and metadata
            
        Returns:
        --------
        Preprocessed image for alignment
        """
        # Subtract black level
        data = raw_image['data'].astype(np.float32)
        data -= raw_image['black_level']
        
        # Simple demosaic for alignment purpose only
        # This is a simplified version - real implementation would be more sophisticated
        h, w = data.shape
        processed = np.zeros((h, w, 3), dtype=np.float32)
        
        # Simple bilinear demosaicking based on Bayer pattern
        # This is oversimplified and would need to be replaced with proper demosaicking
        # for the actual implementation
        
        # Normalize
        processed /= raw_image['white_level'] - raw_image['black_level']
        
        # Convert to grayscale for alignment
        gray = np.mean(processed, axis=2)
        
        return gray

    def build_pyramid(self, image):
        """
        Build a Gaussian pyramid for alignment
        
        Parameters:
        -----------
        image: ndarray
            Input image
            
        Returns:
        --------
        List of pyramid levels
        """
        pyramid = [image]
        current = image
        
        for _ in range(self.pyramid_levels - 1):
            # Downsample by factor of 4 as described in the document
            current = cv2.resize(current, (current.shape[1] // 4, current.shape[0] // 4))
            pyramid.append(current)
        
        # Reverse to start with coarsest level
        return pyramid[::-1]

    def align_frames(self, frames):
        """
        Align frames hierarchically using a Gaussian pyramid
        
        Parameters:
        -----------
        frames: list
            List of preprocessed frames
            
        Returns:
        --------
        List of alignment vectors for each tile in the reference frame
        """
        # Preprocess frames for alignment
        preprocessed_frames = [self.preprocess_raw(frame) for frame in frames]
        
        # Choose the middle frame as reference
        reference_idx = len(frames) // 2
        reference = preprocessed_frames[reference_idx]
        
        # Build pyramids for all frames
        pyramids = [self.build_pyramid(frame) for frame in preprocessed_frames]
        reference_pyramid = pyramids[reference_idx]
        
        # Initialize alignment vectors (will be refined at each level)
        h, w = frames[0]['data'].shape
        tiles_y = h // self.tile_size
        tiles_x = w // self.tile_size
        
        # Initialize alignment vectors as zeros
        alignments = [np.zeros((tiles_y, tiles_x, 2), dtype=np.int32) for _ in range(len(frames))]
        
        # Process pyramid levels from coarse to fine
        for level in range(len(reference_pyramid)):
            level_ref = reference_pyramid[level]
            
            # Current scale factor
            scale = 4 ** (self.pyramid_levels - level - 1)
            
            # Process each frame
            for f, frame_pyramid in enumerate(pyramids):
                if f == reference_idx:
                    continue  # Skip reference frame
                
                level_frame = frame_pyramid[level]
                
                # Adjust tile size and search region for this pyramid level
                level_tile_size = max(self.tile_size // scale, 4)
                level_search = max(self.search_region // scale, 2)
                
                # Calculate alignments for this level
                h_level, w_level = level_ref.shape
                tiles_y_level = h_level // level_tile_size
                tiles_x_level = w_level // level_tile_size
                
                for ty in range(tiles_y_level):
                    for tx in range(tiles_x_level):
                        # Get reference tile
                        y_start = ty * level_tile_size
                        x_start = tx * level_tile_size
                        ref_tile = level_ref[y_start:y_start + level_tile_size, 
                                             x_start:x_start + level_tile_size]
                        
                        # Get previous alignment (if not at coarsest level)
                        if level > 0:
                            prev_ty = ty // 4
                            prev_tx = tx // 4
                            offset_y, offset_x = alignments[f][prev_ty, prev_tx] * 4
                        else:
                            offset_y, offset_x = 0, 0
                        
                        # Search around previous alignment
                        best_dist = float('inf')
                        best_dy, best_dx = 0, 0
                        
                        for dy in range(-level_search, level_search + 1):
                            for dx in range(-level_search, level_search + 1):
                                # Candidate position to check
                                cand_y = y_start + offset_y + dy
                                cand_x = x_start + offset_x + dx
                                
                                # Skip if out of bounds
                                if (cand_y < 0 or cand_x < 0 or 
                                    cand_y + level_tile_size > h_level or 
                                    cand_x + level_tile_size > w_level):
                                    continue
                                
                                # Get candidate tile
                                cand_tile = level_frame[cand_y:cand_y + level_tile_size, 
                                                       cand_x:cand_x + level_tile_size]
                                
                                # Compute L1 distance as described in the document
                                dist = np.sum(np.abs(ref_tile - cand_tile))
                                
                                if dist < best_dist:
                                    best_dist = dist
                                    best_dy = offset_y + dy
                                    best_dx = offset_x + dx
                        
                        # Store alignment for this level
                        if level == len(reference_pyramid) - 1:  # Finest level
                            alignments[f][ty, tx] = [best_dy, best_dx]
                        else:
                            # Scale up for next level
                            scaled_ty = min(ty * 4, tiles_y - 1)
                            scaled_tx = min(tx * 4, tiles_x - 1)
                            alignments[f][scaled_ty, scaled_tx] = [best_dy, best_dx]
        
        return alignments, reference_idx

    def merge_frames(self, raw_frames, alignments, reference_idx):
        """
        Merge aligned frames using a patch-based approach
        
        Parameters:
        -----------
        raw_frames: list
            List of raw frames
        alignments: list
            List of alignment vectors for each tile in each frame
        reference_idx: int
            Index of the reference frame
            
        Returns:
        --------
        Merged raw frame
        """
        reference = raw_frames[reference_idx]['data']
        h, w = reference.shape
        tiles_y = h // self.tile_size
        tiles_x = w // self.tile_size
        
        # Create output merged frame
        merged = np.zeros_like(reference, dtype=np.float32)
        weights = np.zeros((h, w), dtype=np.float32)
        
        # Create raised cosine window for blending tiles
        window = self.create_raised_cosine_window(self.tile_size)
        
        # Half overlap of tiles
        half_tile = self.tile_size // 2
        
        # Process each tile with half-tile overlap
        for ty in range(0, tiles_y * 2 - 1):
            for tx in range(0, tiles_x * 2 - 1):
                # Tile coordinates with half-tile steps
                y_start = ty * half_tile
                x_start = tx * half_tile
                
                # Ensure we don't go out of bounds
                if y_start + self.tile_size > h or x_start + self.tile_size > w:
                    continue
                
                # Get reference tile
                ref_tile = reference[y_start:y_start + self.tile_size, 
                                    x_start:x_start + self.tile_size]
                
                # Collect aligned tiles from all frames
                aligned_tiles = []
                tile_weights = []
                
                for f, frame in enumerate(raw_frames):
                    if f == reference_idx:
                        aligned_tiles.append(ref_tile)
                        tile_weights.append(1.0)  # Reference frame gets full weight
                        continue
                    
                    # Get alignment for the nearest full tile
                    nearest_ty = min(ty // 2, tiles_y - 1)
                    nearest_tx = min(tx // 2, tiles_x - 1)
                    dy, dx = alignments[f][nearest_ty, nearest_tx]
                    
                    # Extract aligned tile
                    aligned_y = y_start + dy
                    aligned_x = x_start + dx
                    
                    # Skip if out of bounds
                    if (aligned_y < 0 or aligned_x < 0 or 
                        aligned_y + self.tile_size > h or 
                        aligned_x + self.tile_size > w):
                        continue
                    
                    aligned_tile = frame['data'][aligned_y:aligned_y + self.tile_size, 
                                                aligned_x:aligned_x + self.tile_size]
                    
                    # Calculate similarity to reference tile (simplified patch similarity)
                    diff = np.abs(ref_tile.astype(np.float32) - aligned_tile.astype(np.float32))
                    similarity = np.exp(-np.mean(diff) / 30.0)  # Adjust this parameter as needed
                    
                    aligned_tiles.append(aligned_tile)
                    tile_weights.append(similarity)
                
                # Normalize weights
                tile_weights = np.array(tile_weights)
                tile_weights /= np.sum(tile_weights)
                
                # Compute weighted average
                merged_tile = np.zeros_like(ref_tile, dtype=np.float32)
                for i, tile in enumerate(aligned_tiles):
                    merged_tile += tile * tile_weights[i]
                
                # Apply window and add to output
                merged[y_start:y_start + self.tile_size, 
                       x_start:x_start + self.tile_size] += merged_tile * window
                weights[y_start:y_start + self.tile_size, 
                        x_start:x_start + self.tile_size] += window
        
        # Normalize by weights
        merged /= np.maximum(weights, 1e-6)
        
        # Create merged raw frame with metadata from reference
        merged_frame = raw_frames[reference_idx].copy()
        merged_frame['data'] = merged
        
        return merged_frame

    def create_raised_cosine_window(self, size):
        """
        Create a 2D raised cosine window for tile blending
        
        Parameters:
        -----------
        size: int
            Size of the window
            
        Returns:
        --------
        2D raised cosine window
        """
        x = np.linspace(-1, 1, size)
        X, Y = np.meshgrid(x, x)
        R = np.sqrt(X**2 + Y**2)
        R = np.clip(R, 0, 1)
        
        # Raised cosine function: 0.5 * (1 + cos(π*r))
        window = 0.5 * (1 + np.cos(np.pi * R))
        
        return window

    def demosaic(self, raw_data, bayer_pattern):
        """
        Simplified demosaicking with gradient correction
        
        Parameters:
        -----------
        raw_data: ndarray
            Raw image data
        bayer_pattern: ndarray
            Bayer pattern
            
        Returns:
        --------
        Demosaicked RGB image
        """
        # For simplicity, we'll use OpenCV's demosaicking function
        # In a real implementation, you'd want to implement the advanced
        # demosaicking algorithm described in the paper
        
        # Convert to OpenCV's Bayer pattern format
        # This is a placeholder - real implementation would need proper pattern detection
        cv_bayer_pattern = cv2.COLOR_BAYER_RG2RGB
        
        # Ensure data is in correct format
        raw_data = raw_data.astype(np.uint16)
        
        # Demosaic
        rgb = cv2.cvtColor(raw_data, cv_bayer_pattern)
        
        # Normalize to [0, 1]
        rgb = rgb.astype(np.float32) / 65535.0
        
        return rgb

    def chroma_denoise(self, rgb):
        """
        Apply bilinear chroma denoising
        
        Parameters:
        -----------
        rgb: ndarray
            RGB image
            
        Returns:
        --------
        Denoised RGB image
        """
        # Convert to YUV
        y = 0.299 * rgb[:,:,0] + 0.587 * rgb[:,:,1] + 0.114 * rgb[:,:,2]
        u = -0.14713 * rgb[:,:,0] - 0.28886 * rgb[:,:,1] + 0.436 * rgb[:,:,2]
        v = 0.615 * rgb[:,:,0] - 0.51499 * rgb[:,:,1] - 0.10001 * rgb[:,:,2]
        
        # Apply bilateral filter to chroma channels
        u_denoised = cv2.bilateralFilter(u, 9, 75, 75)
        v_denoised = cv2.bilateralFilter(v, 9, 75, 75)
        
        # Convert back to RGB
        r = y + 1.13983 * v_denoised
        g = y - 0.39465 * u_denoised - 0.58060 * v_denoised
        b = y + 2.03211 * u_denoised
        
        # Clip values to [0, 1]
        denoised_rgb = np.stack([r, g, b], axis=2)
        denoised_rgb = np.clip(denoised_rgb, 0, 1)
        
        return denoised_rgb

    def apply_color_correction(self, rgb, color_matrix):
        """
        Apply sRGB color correction
        
        Parameters:
        -----------
        rgb: ndarray
            RGB image
        color_matrix: ndarray
            Color matrix from metadata
            
        Returns:
        --------
        Color corrected image
        """
        # Reshape color matrix to 3x3
        if color_matrix.shape[0] > 9:  # Some cameras have a larger matrix
            color_matrix = color_matrix[:9].reshape(3, 3)
        
        # Apply color correction
        h, w, _ = rgb.shape
        flat_rgb = rgb.reshape(-1, 3).T  # Reshape to 3 x (h*w)
        corrected_flat = np.dot(color_matrix, flat_rgb).T
        corrected = corrected_flat.reshape(h, w, 3)
        
        # Clip values to [0, 1]
        corrected = np.clip(corrected, 0, 1)
        
        return corrected

    def tone_map(self, image, strength=1.0, iterations=1):
        """
        Apply tone mapping using a Laplacian pyramid
        
        Parameters:
        -----------
        image: ndarray
            Input image
        strength: float
            Strength of tone mapping
        iterations: int
            Number of iterations for high contrast scenes
            
        Returns:
        --------
        Tone mapped image
        """
        # Create a simulated brighter exposure
        bright_exposure = np.power(image, 0.5) # Simulating brighter exposure
        
        # For each iteration
        result = image.copy()
        
        for _ in range(iterations):
            # Build Laplacian pyramids for both images
            levels = 5
            gaussian_pyr_orig = [result]
            gaussian_pyr_bright = [bright_exposure]
            
            current_orig = result
            current_bright = bright_exposure
            
            for _ in range(levels-1):
                current_orig = cv2.pyrDown(current_orig)
                gaussian_pyr_orig.append(current_orig)
                
                current_bright = cv2.pyrDown(current_bright)
                gaussian_pyr_bright.append(current_bright)
            
            laplacian_pyr_orig = []
            laplacian_pyr_bright = []
            
            for i in range(levels-1):
                laplacian_orig = gaussian_pyr_orig[i] - cv2.pyrUp(gaussian_pyr_orig[i+1])
                laplacian_pyr_orig.append(laplacian_orig)
                
                laplacian_bright = gaussian_pyr_bright[i] - cv2.pyrUp(gaussian_pyr_bright[i+1])
                laplacian_pyr_bright.append(laplacian_bright)
            
            # Add the last levels
            laplacian_pyr_orig.append(gaussian_pyr_orig[-1])
            laplacian_pyr_bright.append(gaussian_pyr_bright[-1])
            
            # Calculate weights for blending based on normal distribution
            # Ideal pixel value distribution is centered around 0.5
            mean = 0.5
            std = 0.2
            
            def weight_function(x):
                return np.exp(-((x - mean) ** 2) / (2 * std ** 2))
            
            # Blend the Laplacian pyramids
            blended_pyr = []
            
            for i in range(levels):
                # Calculate weights
                orig_weight = weight_function(gaussian_pyr_orig[i])
                bright_weight = weight_function(gaussian_pyr_bright[i])
                
                # Normalize weights
                total_weight = orig_weight + bright_weight
                orig_weight = orig_weight / total_weight
                bright_weight = bright_weight / total_weight
                
                # Blend
                blended = laplacian_pyr_orig[i] * orig_weight + laplacian_pyr_bright[i] * bright_weight
                blended_pyr.append(blended)
            
            # Reconstruct the image
            result = blended_pyr[-1]
            for i in range(levels-2, -1, -1):
                result = cv2.pyrUp(result) + blended_pyr[i]
        
        # Adjust strength
        if strength != 1.0:
            result = image * (1 - strength) + result * strength
        
        return result

    def gamma_correction(self, image, gamma=2.2):
        """
        Apply gamma correction
        
        Parameters:
        -----------
        image: ndarray
            Input image
        gamma: float
            Gamma value
            
        Returns:
        --------
        Gamma corrected image
        """
        return np.power(image, 1.0/gamma)

    def unsharp_mask(self, image, strength=0.5):
        """
        Apply unsharp mask sharpening
        
        Parameters:
        -----------
        image: ndarray
            Input image
        strength: float
            Strength of sharpening
            
        Returns:
        --------
        Sharpened image
        """
        blurred = cv2.GaussianBlur(image, (0, 0), 3)
        sharpen = image + strength * (image - blurred)
        sharpen = np.clip(sharpen, 0, 1)
        
        return sharpen

    def adjust_contrast(self, image, strength=1.1):
        """
        Apply global contrast adjustment
        
        Parameters:
        -----------
        image: ndarray
            Input image
        strength: float
            Contrast adjustment strength
            
        Returns:
        --------
        Contrast adjusted image
        """
        mean = np.mean(image)
        contrast = mean + (image - mean) * strength
        contrast = np.clip(contrast, 0, 1)
        
        return contrast

    def process_pipeline(self, raw_burst):
        """
        Apply full HDR+ pipeline to a burst of raw images
        
        Parameters:
        -----------
        raw_burst: list
            List of raw images
            
        Returns:
        --------
        Processed RGB image
        """
        print("Aligning frames...")
        alignments, reference_idx = self.align_frames(raw_burst)
        
        print("Merging frames...")
        merged_raw = self.merge_frames(raw_burst, alignments, reference_idx)
        
        print("Applying finishing steps...")
        # Black level subtraction already done during merging
        
        # White balance
        wb = merged_raw['camera_white_balance'][:3]  # Get first 3 channels
        wb = wb / wb[1]  # Normalize to green channel
        
        # Demosaic
        rgb = self.demosaic(merged_raw['data'], merged_raw['bayer_pattern'])
        
        # Apply white balance
        rgb[:,:,0] *= wb[0]
        rgb[:,:,2] *= wb[2]
        
        # Chroma denoising
        rgb = self.chroma_denoise(rgb)
        
        # Color correction
        rgb = self.apply_color_correction(rgb, merged_raw['color_matrix'])
        
        # Tone mapping
        rgb = self.tone_map(rgb, strength=1.0, iterations=2)
        
        # Gamma correction
        rgb = self.gamma_correction(rgb)
        
        # Contrast adjustment
        rgb = self.adjust_contrast(rgb)
        
        # Sharpening
        rgb = self.unsharp_mask(rgb)
        
        return rgb

    def save_image(self, image, output_path):
        """
        Save processed image
        
        Parameters:
        -----------
        image: ndarray
            Processed image
        output_path: str
            Path to save the image
        """
        # Convert to 8-bit
        image_8bit = (image * 255).astype(np.uint8)
        
        # Save using imageio
        imageio.imwrite(output_path, image_8bit)
        print(f"Saved output to {output_path}")

def main():
    """
    Example usage of the HDR+ pipeline
    """
    # Initialize HDR+ pipeline
    hdr_plus = HDRPlus(tile_size=16, search_region=4, pyramid_levels=3, num_frames=8)
    
    # Load raw burst (this would be your raw files)
    raw_burst = hdr_plus.load_raw_burst("path/to/your/raw/files/*.CR2")
    
    # Process pipeline
    output_image = hdr_plus.process_pipeline(raw_burst)
    
    # Save result
    hdr_plus.save_image(output_image, "hdr_plus_output.jpg")

if __name__ == "__main__":
    main()

ValueError: No files found with pattern: path/to/your/raw/files/*.CR2