In [7]:
"""
Specular Reflection Processor for Endoscopic Video

Usage: set ``videoname`` below to the path of your video and run the whole
script (or notebook cell). Every ``sample_rate``-th frame is saved as an
original / highlighted / processed / mask set under ``reflection_processing/``,
and the processed samples are also written to an output video there.
"""

# ============= CONFIGURATION (CHANGE THIS) =============
videoname = "video.mp4"  # Path to the input video file

# ============= PROCESSING PARAMETERS =============
# Tunable detection / output settings
threshold = 0.08        # Cut-off on the min-max normalized wavelet response (0-1); lower flags more pixels. Typical range 0.05-0.20
sample_rate = 10        # Save one frame set every N frames (must be >= 1)
max_frames = None       # Maximum number of frames to read, or None for the whole video
highlight_color = (0, 255, 0)  # BGR color painted over detected reflections (green)
use_debug = True        # Save 2x2 debug comparison images (every 10th saved frame set)
use_gpu = True          # Select a CUDA torch device if available; NOTE: no processing step currently uses that device
fill_radius = 10        # Forwarded to the reflection-filling step, which currently ignores it (filling uses earlier frames only)

# ============= IMPORTS =============
import cv2
import numpy as np
import os
import pywt
import matplotlib.pyplot as plt
import time
from tqdm import tqdm  # Using tqdm instead of tqdm.notebook for broader compatibility
import torch
from pathlib import Path
import glob

# ============= MAIN CLASS =============
class SpecularReflectionProcessor:
    """Detect, highlight and suppress specular reflections in video frames.

    Detection combines two cues per frame: an HSV rule (V > 200 and S < 40)
    and a thresholded wavelet high-frequency response, cleaned up with
    morphology and a minimum contour area. The non-reflective pixels of every
    frame are accumulated in a "baseline repository"; in sampled frames,
    reflective pixels are replaced with the last repository value recorded
    at that location, when one exists (otherwise they are left unchanged).

    Outputs (under ``output_dir``): original / highlighted / processed JPEGs,
    PNG masks, optional debug comparisons and a video of processed samples.
    """

    def __init__(self,
                 wavelet='db4',
                 threshold=0.15,
                 level=3,
                 use_gpu=True,
                 highlight_color=(0, 255, 0),
                 debug=True,
                 output_dir="reflection_processing"):
        """Configure detection and create the output folder tree.

        Args:
            wavelet: PyWavelets wavelet name passed to ``pywt.wavedec2``.
            threshold: Cut-off on the min-max normalized wavelet response
                (0-1); pixels above it are treated as reflective.
            level: Number of wavelet decomposition levels.
            use_gpu: Select a CUDA torch device when available. NOTE: only
                ``self.device`` depends on this; no processing step uses it.
            highlight_color: BGR color painted over detected reflections.
            debug: Save a 2x2 comparison image for every 10th saved sample.
            output_dir: Root folder for all outputs (created if missing).
        """
        self.wavelet = wavelet
        self.threshold = threshold
        self.level = level
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.highlight_color = highlight_color
        self.debug = debug

        # Output folder layout
        self.output_dir = output_dir
        self.original_dir = os.path.join(self.output_dir, "original")
        self.highlighted_dir = os.path.join(self.output_dir, "highlighted")
        self.processed_dir = os.path.join(self.output_dir, "processed")
        self.mask_dir = os.path.join(self.output_dir, "masks")
        self.debug_dir = os.path.join(self.output_dir, "debug")

        for directory in [self.output_dir, self.original_dir, self.highlighted_dir,
                          self.processed_dir, self.mask_dir, self.debug_dir]:
            os.makedirs(directory, exist_ok=True)

        print(f"GPU Available: {torch.cuda.is_available()}")
        print(f"Using GPU: {self.use_gpu}")

        self.device = torch.device("cuda:0") if self.use_gpu else torch.device("cpu")

        # Baseline repository state (see _update_baseline_repository)
        self.baseline_repository = None  # last non-reflective value per pixel
        self.baseline_valid = None       # bool map: pixel has ever been non-reflective
        self.last_valid_frame = None     # copy of the most recent frame seen

    def _reset_repository(self):
        """Forget all accumulated baseline data (done at the start of each video)."""
        self.baseline_repository = None
        self.baseline_valid = None
        self.last_valid_frame = None

    def _calculate_wavelet_features(self, frame):
        """Return a per-pixel map of the strongest wavelet detail magnitude.

        The grayscale frame is decomposed with ``pywt.wavedec2``. The absolute
        horizontal, vertical and diagonal detail bands of every level are
        resized to the frame size and combined with a per-pixel maximum, so
        the result has the frame's height and width.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if frame.ndim == 3 else frame.copy()

        coeffs = pywt.wavedec2(gray, self.wavelet, level=self.level)

        size = (gray.shape[1], gray.shape[0])  # cv2.resize expects (width, height)
        # coeffs[0] is the approximation; each later entry is a (cH, cV, cD) tuple
        details = [cv2.resize(np.abs(band), size)
                   for level_bands in coeffs[1:]
                   for band in level_bands]
        return np.max(np.stack(details), axis=0)

    @staticmethod
    def _normalize_features(features):
        """Min-max scale *features* to roughly [0, 1]; the 1e-6 guards against flat input."""
        low = np.min(features)
        return (features - low) / (np.max(features) - low + 1e-6)

    def _detect_with_features(self, frame):
        """Detect reflections and also return the normalized wavelet map.

        Returns:
            (mask, normalized_features): ``mask`` is a uint8 array of 0/1 with
            the frame's height and width; ``normalized_features`` is the
            min-max scaled wavelet response used for the decision (reused by
            the debug visualization so the transform is not recomputed).
        """
        normalized_features = self._normalize_features(self._calculate_wavelet_features(frame))

        # Bright, weakly saturated pixels are treated as specular highlights
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        _, s, v = cv2.split(hsv)
        hsv_mask = np.logical_and(v > 200, s < 40)

        wavelet_mask = normalized_features > self.threshold
        combined_mask = np.logical_or(hsv_mask, wavelet_mask)

        # Close small gaps, then remove isolated specks
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(combined_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Keep only filled outer contours whose area exceeds 10 px
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        refined_mask = np.zeros_like(mask)
        for contour in contours:
            if cv2.contourArea(contour) > 10:
                cv2.drawContours(refined_mask, [contour], -1, 1, -1)

        return refined_mask, normalized_features

    def _detect_specular_regions(self, frame):
        """Return the 0/1 uint8 reflection mask for a BGR frame."""
        return self._detect_with_features(frame)[0]

    def _create_highlighted_image(self, frame, mask):
        """Return a copy of *frame* with masked pixels painted ``highlight_color``."""
        highlighted = frame.copy()
        highlighted[mask.astype(bool)] = self.highlight_color
        return highlighted

    def _update_baseline_repository(self, frame, mask):
        """Record the non-reflective pixels of *frame* in the baseline repository.

        ``baseline_repository`` holds, per pixel, the most recent value seen
        while that pixel was not reflective; ``baseline_valid`` marks pixels
        that have received such a value at least once. Validity is tracked
        separately (rather than using 0 as "no data") so that legitimately
        dark pixels with a zero channel are still usable for filling.
        The repository is (re)initialized on first use or if the frame shape
        changes.
        """
        if (self.baseline_repository is None
                or self.baseline_valid is None
                or self.baseline_repository.shape != frame.shape):
            self.baseline_repository = np.zeros_like(frame)
            self.baseline_valid = np.zeros(frame.shape[:2], dtype=bool)

        clean = ~mask.astype(bool)
        self.baseline_repository[clean] = frame[clean]
        self.baseline_valid |= clean
        self.last_valid_frame = frame.copy()

    def _inpaint_reflections(self, frame, mask, radius=10):
        """Replace reflective pixels with repository data where it exists.

        Pixels flagged in *mask* that have never been observed non-reflective
        keep their original values.

        Args:
            frame: The input frame.
            mask: Binary mask of reflective areas (nonzero = reflective).
            radius: Unused; kept so existing callers keep working.

        Returns:
            A new array; *frame* itself is not modified.
        """
        processed = frame.copy()
        if self.baseline_repository is None or self.baseline_valid is None:
            return processed

        fill = mask.astype(bool) & self.baseline_valid
        if np.any(fill):
            processed[fill] = self.baseline_repository[fill]
        return processed

    @staticmethod
    def _output_fps(source_fps, sample_rate):
        """Frame rate for the sampled output video, never below 1 fps.

        Keeps fractional rates (29.97 / 10 -> 2.997) instead of truncating.
        An unknown (0 / non-finite) or very low source rate falls back to
        1 fps, because a 0-fps writer cannot be opened.
        """
        rate = float(source_fps) / sample_rate
        if not np.isfinite(rate) or rate < 1.0:
            return 1.0
        return rate

    def _save_sample(self, frame, specular_mask, normalized_features,
                     frame_filename, out, saved_index, fill_radius):
        """Write one original/highlighted/processed/mask set (plus optional debug image)."""
        highlighted_frame = self._create_highlighted_image(frame, specular_mask)
        processed_frame = self._inpaint_reflections(frame, specular_mask, radius=fill_radius)

        cv2.imwrite(os.path.join(self.original_dir, f"{frame_filename}.jpg"), frame)
        cv2.imwrite(os.path.join(self.highlighted_dir, f"{frame_filename}.jpg"), highlighted_frame)
        cv2.imwrite(os.path.join(self.processed_dir, f"{frame_filename}.jpg"), processed_frame)
        cv2.imwrite(os.path.join(self.mask_dir, f"{frame_filename}.png"), specular_mask * 255)

        out.write(processed_frame)

        if self.debug and saved_index % 10 == 0:
            wavelet_vis = cv2.applyColorMap((normalized_features * 255).astype(np.uint8), cv2.COLORMAP_JET)
            # 2x2 grid: original | highlighted  /  processed | wavelet response
            top_row = np.hstack((frame, highlighted_frame))
            bottom_row = np.hstack((processed_frame, wavelet_vis))
            comparison = np.vstack((top_row, bottom_row))
            cv2.imwrite(os.path.join(self.debug_dir, f"{frame_filename}_comparison.jpg"), comparison)

    def _process_frames(self, video, out, output_prefix, sample_rate,
                        max_frames, fill_radius, total):
        """Read frames until end of stream or *max_frames*; return (read, saved) counts."""
        processed_frames = 0
        saved_frames = 0
        with tqdm(total=total, desc="Processing video") as progress:
            while max_frames is None or processed_frames < max_frames:
                ret, frame = video.read()
                if not ret:
                    break

                # Every frame feeds the repository; only sampled frames are saved
                specular_mask, features = self._detect_with_features(frame)
                self._update_baseline_repository(frame, specular_mask)

                if processed_frames % sample_rate == 0:
                    self._save_sample(frame, specular_mask, features,
                                      f"{output_prefix}_{saved_frames:04d}",
                                      out, saved_frames, fill_radius)
                    saved_frames += 1

                processed_frames += 1
                progress.update(1)
        return processed_frames, saved_frames

    def process_video(self, input_path, output_prefix, sample_rate=10, max_frames=None, fill_radius=10):
        """
        Process video to detect, highlight and remove specular reflections.

        Frames are read until the stream ends (the container's reported frame
        count is only used for the progress bar). Capture and writer are
        always released, so an interrupted run still leaves a finalized
        output video.

        Args:
            input_path: Path to input video file
            output_prefix: Prefix for output filenames
            sample_rate: Save a frame every N frames (must be >= 1)
            max_frames: Maximum number of frames to read, or None for all
            fill_radius: Forwarded to ``_inpaint_reflections`` (currently unused there)

        Returns:
            Number of frame sets saved.

        Raises:
            ValueError: if *sample_rate* < 1 or the video cannot be opened.
        """
        if sample_rate < 1:
            raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

        start_time = time.time()
        self._reset_repository()

        video = cv2.VideoCapture(input_path)
        out = None
        try:
            if not video.isOpened():
                raise ValueError(f"Could not open video file: {input_path}")

            reported_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = video.get(cv2.CAP_PROP_FPS)
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

            output_video_path = os.path.join(self.output_dir, f"{output_prefix}_processed.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_video_path, fourcc,
                                  self._output_fps(fps, sample_rate), (width, height))
            if not out.isOpened():
                print(f"Warning: could not open video writer for {output_video_path}; "
                      "frames will still be saved as images")

            # Progress-bar total only; may be unknown (<= 0) for some containers
            total = reported_count if reported_count > 0 else None
            if max_frames is not None:
                total = max_frames if total is None else min(total, max_frames)

            processed_frames, saved_frames = self._process_frames(
                video, out, output_prefix, sample_rate, max_frames, fill_radius, total)
        finally:
            video.release()
            if out is not None:
                out.release()

        processing_time = time.time() - start_time
        print(f"Processed {processed_frames} frames and saved {saved_frames} frame sets in {processing_time:.2f} seconds")
        print(f"Output video saved to: {output_video_path}")

        return saved_frames

    def display_sample_results(self, num_samples=3):
        """Plot original / highlighted / processed images for a few saved sets.

        Samples are every 10th saved original (sorted by filename), up to
        *num_samples*. Sets with a missing or unreadable image are skipped.
        The figure is saved to ``<output_dir>/sample_results.png`` and shown.
        """
        original_files = sorted(glob.glob(os.path.join(self.original_dir, "*.jpg")))
        if not original_files:
            print("No processed files found.")
            return
        if num_samples < 1:
            print("No samples requested.")
            return

        rows = []
        for orig_path in original_files[::10][:num_samples]:
            base_name = os.path.basename(orig_path)
            paths = (orig_path,
                     os.path.join(self.highlighted_dir, base_name),
                     os.path.join(self.processed_dir, base_name))
            images = [cv2.imread(path) for path in paths]
            if any(image is None for image in images):
                print(f"Skipping {base_name}: missing or unreadable image")
                continue
            # OpenCV loads BGR; matplotlib expects RGB
            rows.append([cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images])

        if not rows:
            print("No complete sample sets to display.")
            return

        titles = ('Original', 'Reflections Highlighted', 'Processed (Reflections Removed)')
        # squeeze=False keeps axes 2-D even for a single row
        fig, axes = plt.subplots(len(rows), 3, figsize=(15, 5 * len(rows)), squeeze=False)
        for row_axes, images in zip(axes, rows):
            for ax, image, title in zip(row_axes, images, titles):
                ax.imshow(image)
                ax.set_title(title)
                ax.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_dir, "sample_results.png"))
        plt.show()
        plt.close(fig)

# ============= MAIN FUNCTION =============
if __name__ == "__main__":
    # Output files are prefixed with the video's file name (no directory, no extension)
    video_base_name = os.path.splitext(os.path.basename(videoname))[0]

    print(f"Processing video: {videoname}")

    processor = SpecularReflectionProcessor(
        threshold=threshold,
        highlight_color=highlight_color,
        use_gpu=use_gpu,
        debug=use_debug,
    )

    # Report the directory the processor actually uses instead of a hard-coded name
    print(f"Output will be saved to the '{processor.output_dir}' directory")
    # The inpainting step ignores its radius argument, so say so rather than
    # claiming neighborhood averaging that does not happen
    print(f"Fill radius set to {fill_radius} px (currently unused: reflections are filled from earlier frames only)")

    saved_frames = processor.process_video(
        videoname,
        video_base_name,
        sample_rate=sample_rate,
        max_frames=max_frames,
        fill_radius=fill_radius,
    )

    if saved_frames:
        print("\nDisplaying sample results...")
        processor.display_sample_results(num_samples=3)
    else:
        print("\nNo frame sets were saved; skipping sample display.")

    print("\nProcessing complete!")
    print("Results saved in the following directories:")
    print(f"  - Original frames: {processor.original_dir}")
    print(f"  - Highlighted frames: {processor.highlighted_dir}")
    print(f"  - Processed frames: {processor.processed_dir}")
    print(f"  - Masks: {processor.mask_dir}")
    if use_debug:
        print(f"  - Debug visualizations: {processor.debug_dir}")
    print(f"  - Processed video: {os.path.join(processor.output_dir, f'{video_base_name}_processed.mp4')}")

Processing video: video.mp4
Output will be saved to the 'reflection_processing' directory
Using fill radius of 10 pixels for neighborhood averaging
GPU Available: True
Using GPU: True


Processing video:  72%|███████▏  | 1114/1548 [02:31<00:58,  7.38it/s]


KeyboardInterrupt: 