In [None]:
"""
Specular Reflection Processor for Endoscopic Video

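Just change the videoname variable at the top of this file and run the entire script.
The program processes the video and writes a composite video to the "reflection_processing"
folder; when sample_rate is greater than 0 it also saves sample frames to their respective folders.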
"""

# ============= CONFIGURATION (CHANGE THIS) =============
# Filename (or path) of the video to process.
videoname = "video.mp4"

# ============= PROCESSING PARAMETERS =============
# Tunable settings; adjust them as needed.

# Threshold for specular reflection detection (suggested range 0.05-0.20).
threshold = 0.06

# Save a sample frame every N frames; 0 disables saving individual frames.
sample_rate = 0

# Upper bound on the number of frames processed, or None to process all.
max_frames = 1200

# Color painted over detected reflections (green).
highlight_color = (0, 255, 0)

# Whether to save debug visualizations.
use_debug = True

# Whether to use the GPU for processing if one is available.
use_gpu = True

# Radius (in pixels) to average for filling when no data is available.
fill_radius = 10

# ----- Color deviation filter parameters -----

# Maximum allowed color difference for repository updates (0-255).
max_color_deviation = 150

# Minimum number of frames used to build the initial repository before the
# deviation filter is applied.
min_initial_frames = 20

# Factor that tightens the deviation threshold for previously reflective
# areas (0.0-1.0).
strict_update_factor = 0.5

# Maximum pixel brightness allowed into the repository (0-255).
max_brightness = 220

# ============= IMPORTS =============
import cv2
import numpy as np
import os
import pywt
import matplotlib.pyplot as plt
import time
from tqdm import tqdm  # Using tqdm instead of tqdm.notebook for broader compatibility
import torch
from pathlib import Path
import glob

# ============= MAIN CLASS =============
class SpecularReflectionProcessor:
    def __init__(self, 
                 wavelet='db4',         
                 threshold=0.15,        
                 level=3,               
                 use_gpu=True,          
                 highlight_color=(0, 255, 0),  
                 debug=True):           
        
        self.wavelet = wavelet
        self.threshold = threshold
        self.level = level
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.highlight_color = highlight_color
        self.debug = debug
        
        # Create output directories
        self.output_dir = "reflection_processing"
        self.original_dir = os.path.join(self.output_dir, "original")
        self.highlighted_dir = os.path.join(self.output_dir, "highlighted")
        self.processed_dir = os.path.join(self.output_dir, "processed")  
        self.mask_dir = os.path.join(self.output_dir, "masks")
        self.debug_dir = os.path.join(self.output_dir, "debug")
        
        # Create directories if they don't exist
        for directory in [self.output_dir, self.original_dir, self.highlighted_dir, 
                          self.processed_dir, self.mask_dir, self.debug_dir]:
            os.makedirs(directory, exist_ok=True)
        
        print(f"GPU Available: {torch.cuda.is_available()}")
        print(f"Using GPU: {self.use_gpu}")
        
        if self.use_gpu:
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
            
        # Initialize repository of non-reflective pixels
        self.baseline_repository = None
        self.last_valid_frame = None
        self.processed_frame_count = 0
    
    def _calculate_wavelet_features(self, frame):
        """
        Build a per-pixel map of wavelet detail strength for ``frame``.

        The frame is reduced to a single channel, decomposed with
        ``self.wavelet`` at ``self.level`` levels, and the absolute value of
        every detail band is resized to the frame size. The result is the
        element-wise maximum over all of those bands.
        """
        # Three-dimensional input is treated as BGR; anything else is used as-is.
        if frame.ndim == 3:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        else:
            gray = frame.copy()

        decomposition = pywt.wavedec2(gray, self.wavelet, level=self.level)

        # cv2.resize expects the target size as (width, height).
        target_size = (gray.shape[1], gray.shape[0])

        # decomposition[0] is the approximation; the remaining entries each
        # hold the three detail bands of one level.
        resized_bands = []
        for band_a, band_b, band_c in decomposition[1:]:
            for band in (band_a, band_b, band_c):
                resized_bands.append(cv2.resize(np.abs(band), target_size))

        # Strongest detail response across every band and level.
        return np.stack(resized_bands).max(axis=0)
    
    def _detect_specular_regions(self, frame):
        """
        Detect specular reflection regions using wavelets and HSV analysis.

        A pixel is a candidate when it is either bright and desaturated
        (V > 200 and S < 40) or its normalized wavelet feature exceeds
        ``self.threshold``. The candidate mask is cleaned with a 3x3
        morphological close followed by an open, and only outer contours whose
        area is greater than 10 are kept (drawn filled).

        Args:
            frame: Input image; converted from BGR to HSV, so three channels
                are required.

        Returns:
            uint8 mask with 1 on detected reflections and 0 elsewhere.
        """
        # Wavelet detail strength per pixel
        features = self._calculate_wavelet_features(frame)

        # High Value and low Saturation indicates specular reflection.
        # Hue is not needed, so only the S and V planes are read.
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        saturation = hsv[:, :, 1]
        value = hsv[:, :, 2]
        bright_desaturated = np.logical_and(value > 200, saturation < 40)

        # Min-max normalize the wavelet features; the epsilon avoids a
        # division by zero on a constant feature map.
        feature_min = np.min(features)
        feature_range = np.max(features) - feature_min + 1e-6
        normalized_features = (features - feature_min) / feature_range
        wavelet_mask = normalized_features > self.threshold

        # Either cue is enough to flag a pixel
        combined_mask = np.logical_or(bright_desaturated, wavelet_mask)

        # Clean up the mask: close small gaps, then remove isolated specks
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(combined_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # findContours returns (contours, hierarchy) on OpenCV 4.x but
        # (image, contours, hierarchy) on 3.x; the contours are always the
        # second-to-last element.
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        refined_mask = np.zeros_like(mask)

        # Only keep contours of meaningful size
        for contour in contours:
            if cv2.contourArea(contour) > 10:  # Minimum area threshold
                cv2.drawContours(refined_mask, [contour], -1, 1, -1)

        return refined_mask
    
    def _create_highlighted_image(self, frame, mask):
        """Create version of frame with reflections highlighted in green."""
        highlighted = frame.copy()
        mask_bool = mask.astype(bool)
        highlighted[mask_bool] = self.highlight_color
        return highlighted
    
    def _update_baseline_repository(self, frame, mask, frame_count=0):
        """
        Fold the trustworthy pixels of ``frame`` into the baseline repository.

        A pixel is "safe" when it is outside the reflection mask and its
        grayscale brightness does not exceed the module-level
        ``max_brightness``. The first call seeds the repository with the frame
        and zeroes every unsafe pixel. While fewer than ``min_initial_frames``
        frames have been processed, all safe pixels overwrite the repository.
        Afterwards a safe pixel is written only if its summed per-channel
        difference to the repository is below ``max_color_deviation``; pixels
        the repository holds no data for must stay below
        ``max_color_deviation * strict_update_factor``.

        A repository pixel counts as "no data" when any of its channels is 0.

        Args:
            frame: The input frame (BGR).
            mask: Binary mask of reflective areas (non-zero = reflection).
            frame_count: Accepted but not used by this method.
        """
        is_reflective = mask.astype(bool)

        # Grayscale intensity serves as the brightness measure.
        brightness = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        is_too_bright = brightness > max_brightness

        # Unsafe = reflective OR too bright; safe is its complement.
        unsafe = is_reflective | is_too_bright
        safe = ~unsafe

        # First frame: seed the repository and stop.
        if self.baseline_repository is None:
            self.baseline_repository = frame.copy()
            self.baseline_repository[unsafe] = 0
            self.last_valid_frame = frame.copy()
            self.processed_frame_count = 1
            return

        if self.processed_frame_count < min_initial_frames:
            # Warm-up phase: every safe pixel is accepted.
            self.baseline_repository[safe] = frame[safe]
        else:
            # Per-pixel color deviation, summed across the channels.
            deviation = np.abs(
                frame.astype(np.float32) - self.baseline_repository.astype(np.float32)
            ).sum(axis=2)

            within_limit = deviation < max_color_deviation
            within_strict_limit = deviation < max_color_deviation * strict_update_factor

            # Pixels for which the repository already holds data.
            has_data = (self.baseline_repository != 0).all(axis=2)

            # Safe pixels within the normal limit are accepted; where the
            # repository has no data yet, the stricter limit must hold as well.
            accepted = safe & within_limit & (has_data | within_strict_limit)
            self.baseline_repository[accepted] = frame[accepted]

        # Remember the latest frame and advance the counter.
        self.last_valid_frame = frame.copy()
        self.processed_frame_count += 1

        # Debug aid: report repository fill level every 10th frame.
        if self.debug and self.processed_frame_count % 10 == 0:
            filled = (self.baseline_repository != 0).all(axis=2).sum()
            capacity = self.baseline_repository.shape[0] * self.baseline_repository.shape[1]
            print(f"Repository coverage: {filled}/{capacity} pixels ({filled/capacity*100:.2f}%)")
    
    def _inpaint_reflections(self, frame, mask, radius=10):
        """
        Replace reflective areas using baseline repository only. 
        If no repository data exists, keep original pixels.
        
        Args:
            frame: The input frame
            mask: Binary mask of reflective areas
            radius: Unused parameter (kept for compatibility)
            
        Returns:
            Processed frame with reflections removed where repository data exists
        """
        if self.baseline_repository is None:
            # If no repository exists yet, just return the original frame
            return frame.copy()
        
        # Create a processed frame
        processed = frame.copy()
        
        # Get reflective areas
        reflective_areas = mask.astype(bool)
        
        # Use the baseline repository where valid data exists
        valid_repository = (self.baseline_repository != 0).all(axis=2) if len(self.baseline_repository.shape) == 3 else (self.baseline_repository != 0)
        usable_repository = valid_repository & reflective_areas
        
        if np.any(usable_repository):
            # Use repository data where available
            processed[usable_repository] = self.baseline_repository[usable_repository]
        
        # For areas not in repository, keep original pixels
        # (This happens automatically since we're using a copy of the original frame)
        
        return processed
    
    def process_video(self, input_path, output_prefix, sample_rate=10, max_frames=None, fill_radius=10):
        """
        Process video to detect, highlight and remove specular reflections.
        Creates a multi-view output video showing original, processed, and highlighted views.

        The composite is twice the input width and height: original (top left),
        processed (top right) and highlighted (bottom, horizontally centered).
        The baseline repository is reset at the start of every call.

        Args:
            input_path: Path to input video file
            output_prefix: Prefix for output filenames
            sample_rate: Save a frame every N frames (0 to disable)
            max_frames: Maximum number of frames to process, or None for all
            fill_radius: Radius parameter (unused but kept for compatibility)

        Returns:
            Number of sample frame sets written to the output folders.

        Raises:
            ValueError: If the input video cannot be opened.
        """
        # Imported locally so this method does not depend on the module-level
        # import block.
        import itertools
        from collections import deque

        # Open the video file
        video = cv2.VideoCapture(input_path)
        if not video.isOpened():
            video.release()
            raise ValueError(f"Could not open video file: {input_path}")

        out = None
        processed_frames = 0
        saved_frames = 0
        output_video_path = os.path.join(self.output_dir, f"{output_prefix}_tri_view.mp4")

        try:
            # Get video properties
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = video.get(cv2.CAP_PROP_FPS)
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Keep fractional rates (e.g. 29.97) instead of truncating them.
            # Some sources report 0 or NaN; fall back to a common default then.
            if not fps > 0:
                fps = 30.0

            # A non-positive frame count means the backend could not determine
            # it; in that case read until the stream ends.
            if total_frames <= 0:
                total_frames = None

            # Limit frames if specified
            if max_frames is not None:
                total_frames = max_frames if total_frames is None else min(total_frames, max_frames)
                total_frames = max(total_frames, 0)

            if total_frames is None:
                frame_indices = itertools.count()
                total_label = "?"
            else:
                frame_indices = range(total_frames)
                total_label = total_frames

            # Create composite video dimensions
            # Layout: Top row: Original and Processed side by side
            #         Bottom row: Highlighted reflections
            comp_width = width * 2
            comp_height = height * 2

            # For creating output video
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (comp_width, comp_height))
            if not out.isOpened():
                # Writing to an unopened writer silently does nothing, so say so.
                print(f"Warning: could not open video writer for {output_video_path}; "
                      "the composite video will not be written.")

            # Label style, shared by every frame
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.8
            font_color = (255, 255, 255)
            line_thickness = 2

            # Only the 20 most recent timings are needed for the FPS overlay.
            recent_times = deque(maxlen=20)

            self.baseline_repository = None  # Reset repository
            self.processed_frame_count = 0  # Reset frame counter for color deviation filtering

            start_time = time.time()

            for i in tqdm(frame_indices, total=total_frames, desc="Processing video"):
                frame_start_time = time.time()

                ret, frame = video.read()
                if not ret:
                    break

                # Every frame contributes to the repository.
                specular_mask = self._detect_specular_regions(frame)
                self._update_baseline_repository(frame, specular_mask, i)

                # Create highlighted version
                highlighted_frame = self._create_highlighted_image(frame, specular_mask)

                # Create processed version (with reflections removed)
                processed_frame = self._inpaint_reflections(frame, specular_mask)

                # Rolling average of the per-frame processing time
                recent_times.append(time.time() - frame_start_time)
                avg_time = sum(recent_times) / len(recent_times)
                # Guard against a zero average on coarse timers.
                fps_text = f"Processing: {1.0/max(avg_time, 1e-9):.2f} frames/sec"

                # Create the composite frame on a black canvas
                composite = np.zeros((comp_height, comp_width, 3), dtype=np.uint8)

                # Add original frame (top left)
                composite[0:height, 0:width] = frame

                # Add processed frame (top right)
                composite[0:height, width:width*2] = processed_frame

                # Add highlighted frame (bottom, centered)
                composite[height:height*2, width//2:width//2+width] = highlighted_frame

                # Add labels to each view
                cv2.putText(composite, "Original", (10, 30), font, font_scale, font_color, line_thickness)
                cv2.putText(composite, "Processed", (width + 10, 30), font, font_scale, font_color, line_thickness)
                cv2.putText(composite, "Highlighted", (width//2 + 10, height + 30), font, font_scale, font_color, line_thickness)

                # Add FPS counter and frame info
                cv2.putText(composite, fps_text, (10, comp_height - 20), font, font_scale, font_color, line_thickness)
                frame_info = f"Frame: {i+1}/{total_label} | Filters: {'On' if self.processed_frame_count >= min_initial_frames else 'Off'} | Max bright: {max_brightness}"
                cv2.putText(composite, frame_info, (comp_width - 500, comp_height - 20), font, font_scale, font_color, line_thickness)

                # Write to video
                out.write(composite)

                # Save individual frames if sample_rate is greater than 0
                if sample_rate > 0 and i % sample_rate == 0:
                    frame_filename = f"{output_prefix}_{saved_frames:04d}"

                    cv2.imwrite(os.path.join(self.original_dir, f"{frame_filename}.jpg"), frame)
                    cv2.imwrite(os.path.join(self.highlighted_dir, f"{frame_filename}.jpg"), highlighted_frame)
                    cv2.imwrite(os.path.join(self.processed_dir, f"{frame_filename}.jpg"), processed_frame)
                    # The mask holds 0/1; scale it to 0/255 so it is visible.
                    cv2.imwrite(os.path.join(self.mask_dir, f"{frame_filename}.png"), specular_mask * 255)

                    # Save the composite view
                    cv2.imwrite(os.path.join(self.debug_dir, f"{frame_filename}_composite.jpg"), composite)

                    saved_frames += 1

                processed_frames += 1
        finally:
            # Release the capture and writer even if a frame raised.
            video.release()
            if out is not None:
                out.release()

        # Calculate overall processing stats
        total_processing_time = time.time() - start_time

        print(f"Processed {processed_frames} frames in {total_processing_time:.2f} seconds")
        print(f"Average processing speed: {processed_frames/max(total_processing_time, 1e-9):.2f} frames/second")
        print(f"Output composite video saved to: {output_video_path}")

        return saved_frames
    
    def display_sample_results(self, num_samples=3):
        """Display sample results using matplotlib."""
        original_files = sorted(glob.glob(os.path.join(self.original_dir, "*.jpg")))
        if not original_files:
            print("No processed files found.")
            return
            
        # Select a few samples (every 10th saved frame)
        samples = []
        for i, path in enumerate(original_files):
            if i % 10 == 0 and len(samples) < num_samples:
                samples.append(path)
        
        if not samples:
            samples = original_files[:num_samples]
            
        # Create figure
        fig, axes = plt.subplots(len(samples), 3, figsize=(15, 5 * len(samples)))
        if len(samples) == 1:
            axes = [axes]
            
        for i, orig_path in enumerate(samples):
            base_name = os.path.basename(orig_path)
            
            # Load images
            original = cv2.imread(orig_path)
            highlighted = cv2.imread(os.path.join(self.highlighted_dir, base_name))
            processed = cv2.imread(os.path.join(self.processed_dir, base_name))
            
            # Convert to RGB for display
            original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
            highlighted = cv2.cvtColor(highlighted, cv2.COLOR_BGR2RGB)
            processed = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
            
            # Display
            axes[i][0].imshow(original)
            axes[i][0].set_title('Original')
            axes[i][0].axis('off')
            
            axes[i][1].imshow(highlighted)
            axes[i][1].set_title('Reflections Highlighted')
            axes[i][1].axis('off')
            
            axes[i][2].imshow(processed)
            axes[i][2].set_title('Processed (Reflections Removed)')
            axes[i][2].axis('off')
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, "sample_results.png"))
        plt.show()

# ============= MAIN FUNCTION =============
if __name__ == "__main__":
    # Output files are prefixed with the video's filename without extension.
    video_base_name, _extension = os.path.splitext(os.path.basename(videoname))

    print(f"Processing video: {videoname}")
    print("Output will be saved to the 'reflection_processing' directory")
    print(f"Maximum brightness threshold: {max_brightness}")

    # Build the processor from the settings at the top of the file.
    processor = SpecularReflectionProcessor(
        threshold=threshold,
        highlight_color=highlight_color,
        use_gpu=use_gpu,
        debug=use_debug,
    )

    # Run the full pipeline over the video.
    processor.process_video(
        videoname,
        video_base_name,
        sample_rate=sample_rate,
        max_frames=max_frames,
        fill_radius=fill_radius,
    )

    # Sample frames exist only when sampling was enabled.
    if sample_rate > 0:
        print("\nDisplaying sample results...")
        processor.display_sample_results(num_samples=3)

    print("\nProcessing complete!")
    print("Results saved in the following directories:")

    # (label, location) pairs, printed in this order.
    result_locations = [
        ("Original frames", processor.original_dir),
        ("Highlighted frames", processor.highlighted_dir),
        ("Processed frames", processor.processed_dir),
        ("Masks", processor.mask_dir),
    ]
    if use_debug:
        result_locations.append(("Debug visualizations", processor.debug_dir))
    result_locations.append(
        ("Processed video", os.path.join(processor.output_dir, f"{video_base_name}_tri_view.mp4"))
    )

    for label, location in result_locations:
        print(f"  - {label}: {location}")

Processing video: video.mp4
Output will be saved to the 'reflection_processing' directory
Maximum brightness threshold: 220
GPU Available: True
Using GPU: True


Processing video:   1%|          | 10/1200 [00:01<03:22,  5.88it/s]

Repository coverage: 222121/921600 pixels (24.10%)


Processing video:   2%|▏         | 20/1200 [00:03<03:47,  5.19it/s]

Repository coverage: 225519/921600 pixels (24.47%)


Processing video:   2%|▎         | 30/1200 [00:05<03:52,  5.02it/s]

Repository coverage: 224416/921600 pixels (24.35%)


Processing video:   3%|▎         | 40/1200 [00:07<03:52,  5.00it/s]

Repository coverage: 225801/921600 pixels (24.50%)


Processing video:   4%|▍         | 50/1200 [00:09<03:58,  4.83it/s]

Repository coverage: 226716/921600 pixels (24.60%)


Processing video:   5%|▌         | 60/1200 [00:11<03:53,  4.87it/s]

Repository coverage: 225482/921600 pixels (24.47%)


Processing video:   6%|▌         | 70/1200 [00:13<03:49,  4.92it/s]

Repository coverage: 226943/921600 pixels (24.62%)


Processing video:   7%|▋         | 80/1200 [00:15<03:45,  4.96it/s]

Repository coverage: 225442/921600 pixels (24.46%)


Processing video:   8%|▊         | 90/1200 [00:17<03:36,  5.12it/s]

Repository coverage: 219770/921600 pixels (23.85%)


Processing video:   8%|▊         | 100/1200 [00:19<03:31,  5.19it/s]

Repository coverage: 216465/921600 pixels (23.49%)


Processing video:   9%|▉         | 110/1200 [00:21<03:30,  5.17it/s]

Repository coverage: 222304/921600 pixels (24.12%)


Processing video:  10%|█         | 120/1200 [00:23<03:35,  5.02it/s]

Repository coverage: 226686/921600 pixels (24.60%)


Processing video:  11%|█         | 130/1200 [00:25<03:26,  5.19it/s]

Repository coverage: 229379/921600 pixels (24.89%)


Processing video:  12%|█▏        | 140/1200 [00:27<03:26,  5.14it/s]

Repository coverage: 232038/921600 pixels (25.18%)


Processing video:  12%|█▎        | 150/1200 [00:29<03:25,  5.10it/s]

Repository coverage: 234470/921600 pixels (25.44%)


Processing video:  13%|█▎        | 160/1200 [00:30<03:26,  5.03it/s]