In [13]:
import cv2
import numpy as np
import pywt
import torch
import time
import os
from collections import deque
import threading
from queue import Queue
import psutil
import matplotlib.pyplot as plt
from IPython import display # Import IPython.display

# --- RealtimeSpecularProcessor class (UNCHANGED from previous correct version) ---
# (Paste the full RealtimeSpecularProcessor class here. For brevity, I'm omitting it,
#  but ensure it's the one that includes the _rgb_to_hsv_torch method)
class RealtimeSpecularProcessor:
    def __init__(self, 
                 wavelet='db4',
                 threshold=0.06,
                 level=2,  # Reduced from 3 for speed to optimize for real-time
                 use_gpu=True,
                 detection_scale=0.5,  # Process detection at half resolution for speed
                 repository_update_rate=5,  # Update repository every N frames
                 repository_alpha=0.05): # Alpha for exponential moving average, lower for more stability
        
        self.wavelet = wavelet
        self.threshold = threshold
        self.level = level
        self.detection_scale = detection_scale
        self.repository_update_rate = repository_update_rate
        self.repository_alpha = repository_alpha
        
        # GPU setup
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda:0" if self.use_gpu else "cpu")
        
        # Repository management: Stores the averaged background image and its confidence
        self.current_repository = None # Stores the averaged background image (float32, HWC)
        self.repository_confidence_map = None # Stores confidence/update count for each pixel (float32, HW, range 0-1)
        
        self.frame_count = 0
        
        # Performance monitoring
        self.processing_times = deque(maxlen=30)
        self.fps_history = deque(maxlen=10)
        
        # Threading for background repository updates (non-blocking)
        self.repository_queue = Queue(maxsize=10) 
        self.repository_thread = threading.Thread(target=self._repository_updater, daemon=True)
        self.repository_thread.start()
        
        print(f"GPU Available: {torch.cuda.is_available()}")
        print(f"Using GPU: {self.use_gpu}")
        print(f"Detection scale: {self.detection_scale}")

    def _rgb_to_hsv_torch(self, image_rgb_chw):
        """
        Converts a batch of RGB images to HSV.
        Input: torch.Tensor (C, H, W) normalized to [0.0, 1.0].
               C=3 for R, G, B channels.
        Output: torch.Tensor (C, H, W) with H, S, V channels normalized to [0.0, 1.0].
        This is a vectorized implementation suitable for GPU.
        """
        r, g, b = image_rgb_chw[0], image_rgb_chw[1], image_rgb_chw[2]

        C_max = torch.max(r, torch.max(g, b))
        C_min = torch.min(r, torch.min(g, b))
        delta = C_max - C_min

        h = torch.zeros_like(r, device=self.device)
        s = torch.zeros_like(r, device=self.device)
        v = C_max

        # Saturation
        # Only compute saturation where C_max is not zero to avoid division by zero
        s[C_max != 0] = delta[C_max != 0] / C_max[C_max != 0]

        # Hue
        # Handle cases where delta is zero (grayscale) - h remains 0 as initialized
        # Case 1: C_max == R
        mask_r = (C_max == r) & (delta != 0)
        h[mask_r] = (60 * (((g[mask_r] - b[mask_r]) / delta[mask_r]) % 6))

        # Case 2: C_max == G
        mask_g = (C_max == g) & (delta != 0)
        h[mask_g] = (60 * (((b[mask_g] - r[mask_g]) / delta[mask_g]) + 2))

        # Case 3: C_max == B
        mask_b = (C_max == b) & (delta != 0)
        h[mask_b] = (60 * (((r[mask_b] - g[mask_b]) / delta[mask_b]) + 4))

        # Normalize hue to [0, 1] (0-360 degrees divided by 360)
        h = h / 360.0

        return torch.stack([h, s, v], dim=0)
    
    def _fast_wavelet_features(self, frame_np):
        """
        Optimized wavelet feature extraction on CPU.
        (pywt is CPU-bound, so this part remains on CPU).
        """
        # Convert to grayscale
        if len(frame_np.shape) == 3:
            gray_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2GRAY)
        else:
            gray_np = frame_np.copy()
        
        # Downsample for speed if detection_scale < 1.0
        if self.detection_scale < 1.0:
            new_size = (int(gray_np.shape[1] * self.detection_scale), 
                       int(gray_np.shape[0] * self.detection_scale))
            gray_small_np = cv2.resize(gray_np, new_size, interpolation=cv2.INTER_LINEAR)
        else:
            gray_small_np = gray_np
        
        try:
            # Apply wavelet transform with reduced levels for speed
            coeffs = pywt.wavedec2(gray_small_np, self.wavelet, level=self.level)
            
            # Initialize detail features with zeros, size of the small grayscale frame
            target_shape = gray_small_np.shape
            detail_features_np = np.zeros(target_shape, dtype=np.float32)
            
            # Extract and combine detail coefficients safely.
            # coeffs[0] is the approximation coefficient (cA_n).
            # coeffs[1:] are tuples (cH_i, cV_i, cD_i) for each level (from coarsest to finest).
            for i in range(1, len(coeffs)): 
                cH, cV, cD = coeffs[i] 
                
                # Absolute values of detail coefficients represent high-frequency energy
                h_detail = np.abs(cH)  
                v_detail = np.abs(cV)  
                
                # Resize each detail map to the target shape (downsampled gray frame size)
                # This ensures all detail maps are compatible for combination.
                if h_detail.shape != target_shape:
                    h_detail = cv2.resize(h_detail, (target_shape[1], target_shape[0]), 
                                        interpolation=cv2.INTER_LINEAR)
                if v_detail.shape != target_shape:
                    v_detail = cv2.resize(v_detail, (target_shape[1], target_shape[0]), 
                                        interpolation=cv2.INTER_LINEAR)
                
                # Combine by taking maximum across all levels to highlight strongest features
                detail_features_np = np.maximum(detail_features_np, h_detail)
                detail_features_np = np.maximum(detail_features_np, v_detail)
            
            # Resize the combined features back to the original full frame size if downsampled
            if self.detection_scale < 1.0:
                detail_features_np = cv2.resize(detail_features_np, 
                                           (gray_np.shape[1], gray_np.shape[0]), 
                                           interpolation=cv2.INTER_LINEAR)
                
        except Exception as e:
            print(f"Wavelet processing error: {e}")
            detail_features_np = np.zeros_like(gray_np, dtype=np.float32)
        
        return detail_features_np
    
    def _fast_specular_detection(self, frame_torch):
        """
        Optimized specular detection combining wavelets and HSV.
        Most operations are performed on GPU when `self.use_gpu` is True.
        """
        detection_start_time = time.time() 
        
        # Wavelet features (CPU-bound due to pywt)
        frame_np_cpu = frame_torch.cpu().numpy() # Move to CPU for pywt
        wavelet_features_np = self._fast_wavelet_features(frame_np_cpu)
        
        # Convert wavelet features to torch tensor and move to device (GPU)
        wavelet_features_torch = torch.from_numpy(wavelet_features_np).to(self.device)
        
        # HSV analysis on GPU
        frame_normalized_torch = frame_torch.float() / 255.0 
        
        # Permute from (H, W, C) to (C, H, W) for _rgb_to_hsv_torch
        # OpenCV frames are BGR, _rgb_to_hsv_torch expects RGB, so swap channels too.
        frame_chw_bgr = frame_normalized_torch.permute(2, 0, 1) 
        frame_chw_rgb = frame_chw_bgr[[2, 1, 0], :, :] # Swap B and R channels for RGB
        
        hsv_torch = self._rgb_to_hsv_torch(frame_chw_rgb) # Use the custom HSV conversion
        
        s_channel_torch = hsv_torch[1, :, :] * 255.0 # Scale S back to 0-255 for thresholding consistency
        v_channel_torch = hsv_torch[2, :, :] * 255.0 # Scale V back to 0-255
        
        # Vectorized operations for speed on GPU
        high_brightness_torch = v_channel_torch > 200
        low_saturation_torch = s_channel_torch < 40
        
        # Normalize wavelet features efficiently on GPU
        w_min, w_max = wavelet_features_torch.min(), wavelet_features_torch.max()
        if w_max > w_min:
            normalized_wavelets_torch = (wavelet_features_torch - w_min) / (w_max - w_min)
        else:
            normalized_wavelets_torch = torch.zeros_like(wavelet_features_torch)
        
        # Combine detection methods on GPU
        wavelet_mask_torch = normalized_wavelets_torch > self.threshold
        hsv_mask_torch = high_brightness_torch & low_saturation_torch
        
        combined_mask_torch = wavelet_mask_torch | hsv_mask_torch
        
        # Morphological operations (performed on CPU using OpenCV for optimized performance)
        combined_mask_np = combined_mask_torch.cpu().numpy().astype(np.uint8)
        
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        # Use MORPH_OPEN to remove small noise and break tiny connections
        refined_mask_np = cv2.morphologyEx(combined_mask_np, cv2.MORPH_OPEN, kernel) 
        # Then MORPH_CLOSE to fill small holes and connect nearby regions
        refined_mask_np = cv2.morphologyEx(refined_mask_np, cv2.MORPH_CLOSE, kernel) 
        
        # Convert refined mask back to torch tensor (boolean type) and move to device for inpainting
        refined_mask_torch = torch.from_numpy(refined_mask_np.astype(bool)).to(self.device)
        
        detection_time_taken = time.time() - detection_start_time
        return refined_mask_torch, detection_time_taken
    
    def _repository_updater(self):
        """
        Background thread for updating the repository.
        Processes frames/masks from the queue and updates the EMA repository on CPU.
        """
        while True:
            try:
                frame_data = self.repository_queue.get(timeout=1.0)
                if frame_data is None:  # Shutdown signal
                    break
                
                frame_np, mask_np = frame_data
                self._update_repository_background(frame_np, mask_np)
                self.repository_queue.task_done()
                
            except Exception: # Catch any exception (e.g., Queue empty timeout) and continue
                continue
    
    def _update_repository_background(self, frame_np, mask_np):
        """
        Update repository (EMA) and confidence map on CPU.
        This function runs in a separate thread.
        """
        mask_bool = mask_np.astype(bool)
        
        if self.current_repository is None:
            # Initialize repository with the current frame's non-reflective parts.
            self.current_repository = frame_np.astype(np.float32)
            # For reflective areas in the first frame, set to 0.0 or a neutral value.
            # These areas will only be filled once valid non-specular data is seen.
            self.current_repository[mask_bool] = 0.0 
            
            # Initialize confidence map: 1.0 for non-reflective areas, 0.0 for reflective areas.
            self.repository_confidence_map = (~mask_bool).astype(np.float32)
        else:
            frame_float = frame_np.astype(np.float32)
            non_reflective_mask = ~mask_bool
            
            # Perform Exponential Moving Average (EMA) only on non-reflective pixels.
            if np.any(non_reflective_mask):
                self.current_repository[non_reflective_mask] = (
                    self.repository_alpha * frame_float[non_reflective_mask] + 
                    (1 - self.repository_alpha) * self.current_repository[non_reflective_mask]
                )
                # Increase confidence for updated pixels, capping at 1.0.
                self.repository_confidence_map[non_reflective_mask] = np.minimum(
                    self.repository_confidence_map[non_reflective_mask] + self.repository_alpha, 1.0
                )
                
    def _fast_inpaint(self, frame_torch, mask_torch):
        """
        Fast inpainting using the background repository.
        Performed on GPU when `self.use_gpu` is True.
        """
        
        # If repository is not yet initialized, return the original frame.
        if self.current_repository is None or self.repository_confidence_map is None:
            return frame_torch 
        
        # Convert repository and confidence map to torch tensors and move to device (GPU).
        # These arrays are managed on CPU by the background thread, so they need to be moved to GPU here.
        repo_torch = torch.from_numpy(self.current_repository).to(self.device)
        confidence_torch = torch.from_numpy(self.repository_confidence_map).to(self.device)
        
        processed_frame_torch = frame_torch.clone()
        
        # Select reflective areas where the repository has accumulated sufficient confidence.
        # A confidence threshold (e.g., > 0.5) ensures we only use reliable background data.
        usable_areas_torch = mask_torch & (confidence_torch > 0.5) 
        
        if usable_areas_torch.any():
            # Expand the 2D boolean mask to 3 channels to apply to the 3-channel image.
            usable_areas_expanded = usable_areas_torch.unsqueeze(-1).expand_as(processed_frame_torch)
            
            # Apply repository data to the selected reflective areas on the GPU.
            processed_frame_torch[usable_areas_expanded] = repo_torch[usable_areas_expanded].to(processed_frame_torch.dtype)
        
        return processed_frame_torch
    
    def process_frame_realtime(self, frame_np):
        """
        Main real-time processing function for a single frame.
        Handles detection, inpainting, and performance tracking.
        """
        frame_start = time.time()
        
        # Convert numpy frame (H, W, C) to torch tensor and move to GPU.
        frame_torch = torch.from_numpy(frame_np).to(self.device)
        
        # Detect specular reflections. mask_torch will be on device (GPU).
        mask_torch, detection_time = self._fast_specular_detection(frame_torch)
        
        # Process the frame (inpainting) using the detected mask.
        processed_frame_torch = self._fast_inpaint(frame_torch, mask_torch)
        
        # Convert processed frame and mask back to numpy for OpenCV display/saving.
        processed_frame_np = processed_frame_torch.cpu().numpy()
        mask_np = mask_torch.cpu().numpy()
        
        # Queue repository update (non-blocking).
        # Send copies of frame and mask to the background thread to prevent data races.
        if self.frame_count % self.repository_update_rate == 0:
            try:
                self.repository_queue.put_nowait((frame_np.copy(), mask_np.copy())) 
            except Exception:
                # Silently skip if queue is full to maintain real-time performance.
                pass 
        
        # Performance tracking.
        total_time = time.time() - frame_start
        self.processing_times.append(total_time)
        current_fps = 1.0 / total_time if total_time > 0 else 0
        self.fps_history.append(current_fps)
        
        self.frame_count += 1
        
        return processed_frame_np, mask_np, {
            'fps': current_fps,
            'avg_fps': np.mean(self.fps_history) if self.fps_history else 0,
            'detection_time': detection_time, # Time specifically for detection stage
            'total_time': total_time,        # Total time for the entire frame processing
            'frame_count': self.frame_count
        }
    
    def get_performance_stats(self):
        """Get current performance statistics."""
        if not self.processing_times:
            return {}
        
        recent_times = list(self.processing_times)
        
        repository_size_mb = 0
        if self.current_repository is not None and self.repository_confidence_map is not None:
            # Calculate memory usage of repository data
            repository_size_mb = (self.current_repository.nbytes + self.repository_confidence_map.nbytes) / (1024 * 1024)
            
        return {
            'avg_processing_time': np.mean(recent_times),
            'max_processing_time': np.max(recent_times),
            'min_processing_time': np.min(recent_times),
            'avg_fps': np.mean(self.fps_history) if self.fps_history else 0,
            'memory_usage': psutil.Process().memory_info().rss / 1024 / 1024,  # Total process memory in MB
            'repository_size_mb': repository_size_mb # Memory for repository data itself
        }
    
    def cleanup(self):
        """Clean up resources by stopping the background repository thread."""
        try:
            self.repository_queue.put_nowait(None) # Signal shutdown to the thread
        except Exception:
            pass # Queue might be already closed or full
        
        if self.repository_thread.is_alive():
            self.repository_thread.join(timeout=2.0) # Give the thread some time to terminate
            if self.repository_thread.is_alive():
                print("Warning: Repository thread did not terminate gracefully.")
# --- End RealtimeSpecularProcessor class ---


# --- Modified Demo Functions for Jupyter Notebook using IPython.display ---

def realtime_camera_demo_jupyter(camera_id=0):
    """Demo function for real-time camera processing in Jupyter Notebook using IPython.display."""
    
    processor = RealtimeSpecularProcessor(
        threshold=0.06, detection_scale=0.7, repository_update_rate=3, use_gpu=True
    )
    
    cap = cv2.VideoCapture(camera_id)
    if not cap.isOpened():
        print(f"Error: Could not open camera {camera_id}")
        return
    
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    cap.set(cv2.CAP_PROP_FPS, 30)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    # Matplotlib Setup (still create figure once)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].set_title('Original')
    axes[1].set_title('Processed')
    axes[2].set_title('Reflection Mask')
    for ax in axes: ax.axis('off')
    
    print("Real-time processing started. Interrupt kernel (Ctrl+C or Stop button) to quit.")
    
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error reading from camera. Exiting.")
                break
            
            processed_frame, mask, frame_stats = processor.process_frame_realtime(frame)
            
            # --- Update Matplotlib Plots using IPython.display ---
            axes[0].imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            axes[1].imshow(cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB))
            axes[2].imshow(mask, cmap='gray', vmin=0, vmax=1)
            
            axes[0].set_title(f'Original\nFPS: {frame_stats["fps"]:.1f} | Avg: {frame_stats["avg_fps"]:.1f}')
            axes[1].set_title(f'Processed\nDetection: {frame_stats["detection_time"]*1000:.1f}ms | Total: {frame_stats["total_time"]*1000:.1f}ms')
            
            display.clear_output(wait=True) # Clear previous output
            display.display(fig)            # Display updated figure
            # --- End Matplotlib Update ---
            
    except KeyboardInterrupt:
        print("\nProcessing interrupted by user.")
    
    finally:
        cap.release()
        processor.cleanup()
        plt.close(fig) # Close the matplotlib figure
        display.clear_output() # Clear the final display
        print("Resources cleaned up successfully.")

def realtime_video_demo_jupyter(video_path):
    """Demo function for real-time video file processing in Jupyter Notebook using IPython.display."""
    
    processor = RealtimeSpecularProcessor(
        threshold=0.06, detection_scale=0.8, repository_update_rate=2, use_gpu=True
    )
    
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return
    
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # Matplotlib Setup
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].set_title('Original')
    axes[1].set_title('Processed')
    axes[2].set_title('Reflection Mask')
    for ax in axes: ax.axis('off')
        
    print(f"Processing {total_frames} frames from video at {fps} FPS.")
    print("Interrupt kernel (Ctrl+C or Stop button) to quit.")

    try:
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            processed_frame, mask, stats = processor.process_frame_realtime(frame)
            
            # --- Update Matplotlib Plots using IPython.display ---
            axes[0].imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            axes[1].imshow(cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB))
            axes[2].imshow(mask, cmap='gray', vmin=0, vmax=1)
            
            axes[0].set_title(f'Original\nFPS: {stats["fps"]:.1f} | Avg: {stats["avg_fps"]:.1f}')
            axes[1].set_title(f'Processed\nDetection: {stats["detection_time"]*1000:.1f}ms | Total: {stats["total_time"]*1000:.1f}ms')
            
            display.clear_output(wait=True) # Clear previous output
            display.display(fig)            # Display updated figure
            # --- End Matplotlib Update ---

            # Display progress in console (less frequent to avoid flooding)
            if frame_count % 30 == 0: 
                # Console print will be cleared by display.clear_output unless printed after display.display
                # For this reason, console prints for progress are less effective with this method
                pass 
            
            frame_count += 1
            
        final_stats = processor.get_performance_stats()
        print("\n=== Final Performance Stats ===") # This will appear after the loop finishes or is interrupted
        for k, v in final_stats.items():
            print(f"{k}: {v:.2f}")
    
    except KeyboardInterrupt:
        print("\nProcessing interrupted by user.")
    
    finally:
        cap.release()
        processor.cleanup()
        plt.close(fig)
        display.clear_output()
        print("Resources cleaned up successfully.")


# ============= CONFIGURATION (CHANGE THIS) =============
videoname = "video.mp4"  # Set this to your video filename

# ============= PROCESSING PARAMETERS =============
threshold = 0.06
max_frames = 1200
sample_rate = 10

if __name__ == "__main__":
    # This block is typically not run directly in a Jupyter cell in the same way.
    # You would usually call realtime_camera_demo_jupyter() or realtime_video_demo_jupyter(videoname)
    # directly in a Jupyter cell.
    
#     print("To run in Jupyter, call realtime_camera_demo_jupyter() or realtime_video_demo_jupyter('your_video.mp4') directly in a cell.")
    
#     Example for direct calling in a Jupyter cell:
    if os.path.exists(videoname):
        realtime_video_demo_jupyter(videoname)
    else:
        print(f"Video '{videoname}' not found. Trying camera demo.")
        realtime_camera_demo_jupyter()

Resources cleaned up successfully.
