In [1]:
!pip install opencv-python numpy matplotlib PyWavelets scikit-image tqdm torch

[0m

In [13]:
import glob
import os
import random
import shutil
import time
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pywt  # This imports from the PyWavelets package
import torch
from tqdm.notebook import tqdm

class SpecularReflectionDatasetGenerator:
    """Build an (original, highlighted, mask) image dataset of detected specular reflections.

    A pixel is a reflection candidate when EITHER
      * it is bright and weakly saturated in HSV (V > ``v_threshold`` and S < ``s_threshold``), OR
      * its min-max normalised wavelet detail magnitude exceeds ``threshold``.
    The candidate mask is cleaned with a 3x3 morphological close then open, and only
    external blobs with contour area > ``min_area`` are kept (filled, so holes are closed).

    Files are written under ``output_dir``:
      original/<name>.jpg, highlighted/<name>.jpg, masks/<name>.png (values 0/255),
      and, when ``debug`` is set, debug/<name>_comparison.jpg and debug/<name>_wavelet.jpg.
    """

    def __init__(self,
                 wavelet='db4',         # Wavelet type
                 threshold=0.15,        # Threshold on normalised wavelet detail magnitude
                 level=3,               # Wavelet decomposition level
                 use_gpu=True,          # Only selects self.device; processing itself uses OpenCV/NumPy/PyWavelets
                 highlight_color=(0, 255, 0),  # Paint colour in OpenCV BGR order (green)
                 debug=True,            # Whether to save debug images
                 output_dir="reflection_dataset",  # Root directory for all generated files
                 v_threshold=200,       # HSV value above which a pixel counts as "bright"
                 s_threshold=40,        # HSV saturation below which a pixel counts as "unsaturated"
                 min_area=10):          # Blobs with contour area <= this are discarded

        self.wavelet = wavelet
        self.threshold = threshold
        self.level = level
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.highlight_color = highlight_color
        self.debug = debug
        self.v_threshold = v_threshold
        self.s_threshold = s_threshold
        self.min_area = min_area

        # Create directories for dataset
        self.output_dir = output_dir
        self.original_dir = os.path.join(self.output_dir, "original")
        self.highlighted_dir = os.path.join(self.output_dir, "highlighted")
        self.mask_dir = os.path.join(self.output_dir, "masks")
        self.debug_dir = os.path.join(self.output_dir, "debug")

        for directory in (self.original_dir, self.highlighted_dir, self.mask_dir, self.debug_dir):
            os.makedirs(directory, exist_ok=True)

        print(f"GPU Available: {torch.cuda.is_available()}")
        print(f"Using GPU: {self.use_gpu}")

        self.device = torch.device("cuda:0") if self.use_gpu else torch.device("cpu")

    @staticmethod
    def _normalize_features(features):
        """Min-max scale ``features`` into [0, 1); a constant map becomes all zeros.

        The 1e-6 in the denominator avoids division by zero for flat inputs.
        """
        f_min = np.min(features)
        return (features - f_min) / (np.max(features) - f_min + 1e-6)

    def _calculate_wavelet_features(self, frame):
        """Return a full-resolution map of the strongest wavelet detail magnitude per pixel.

        ``frame`` may be a 3-channel BGR image (converted to grayscale) or already 2-D.
        Every detail band (horizontal, vertical, diagonal) of every decomposition level is
        taken in absolute value, resized back to the image size, and reduced with a per-pixel max.
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if frame.ndim == 3 else frame

        # coeffs[0] is the approximation band; coeffs[1:] are per-level (cH, cV, cD) detail tuples.
        coeffs = pywt.wavedec2(gray, self.wavelet, level=self.level)

        height, width = gray.shape[:2]
        # Detail bands are smaller than the image (roughly halved per level); upsample them to align.
        detail_maps = [
            cv2.resize(np.abs(band), (width, height))
            for level_bands in coeffs[1:]
            for band in level_bands
        ]
        return np.max(np.stack(detail_maps), axis=0)

    def _detect_specular_regions(self, frame, features=None):
        """Return a uint8 mask (1 = reflection, 0 = background) for a BGR ``frame``.

        Args:
            frame: BGR image (HSV conversion requires 3 channels).
            features: Optional precomputed output of ``_calculate_wavelet_features(frame)``;
                computed here when omitted.
        """
        if features is None:
            features = self._calculate_wavelet_features(frame)

        # Colour cue: bright and weakly saturated pixels.
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        _, saturation, value = cv2.split(hsv)
        color_mask = np.logical_and(value > self.v_threshold, saturation < self.s_threshold)

        # Texture cue: strong wavelet detail response.
        wavelet_mask = self._normalize_features(features) > self.threshold

        # Either cue alone marks a candidate pixel (plain union, no weighting).
        combined_mask = np.logical_or(color_mask, wavelet_mask)

        # Close small gaps, then remove isolated specks.
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(combined_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        # Keep only sufficiently large outer blobs; thickness -1 fills them, closing interior holes.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        refined_mask = np.zeros_like(mask)
        for contour in contours:
            if cv2.contourArea(contour) > self.min_area:
                cv2.drawContours(refined_mask, [contour], -1, 1, -1)

        return refined_mask

    def _create_highlighted_image(self, frame, mask):
        """Return a copy of ``frame`` with every non-zero ``mask`` pixel set to ``highlight_color``.

        The input frame is not modified. The colour is applied in the frame's own channel
        order (BGR for images loaded with OpenCV).
        """
        highlighted = frame.copy()
        highlighted[mask.astype(bool)] = self.highlight_color
        return highlighted

    def _save_debug_visualization(self, frame, highlighted, features, name):
        """Write a side-by-side original|highlighted image and a JET-coloured wavelet map to ``debug_dir``."""
        normalized = self._normalize_features(features)
        wavelet_vis = cv2.applyColorMap((normalized * 255).astype(np.uint8), cv2.COLORMAP_JET)
        comparison = np.hstack((frame, highlighted))
        cv2.imwrite(os.path.join(self.debug_dir, f"{name}_comparison.jpg"), comparison)
        cv2.imwrite(os.path.join(self.debug_dir, f"{name}_wavelet.jpg"), wavelet_vis)

    def _process_and_save(self, frame, name, save_debug):
        """Detect reflections in ``frame``, write its dataset files as ``name``, return (highlighted, mask).

        The wavelet features are computed once and reused for the debug visualisation.
        """
        features = self._calculate_wavelet_features(frame)
        mask = self._detect_specular_regions(frame, features=features)
        highlighted = self._create_highlighted_image(frame, mask)

        cv2.imwrite(os.path.join(self.original_dir, f"{name}.jpg"), frame)
        cv2.imwrite(os.path.join(self.highlighted_dir, f"{name}.jpg"), highlighted)
        cv2.imwrite(os.path.join(self.mask_dir, f"{name}.png"), mask * 255)

        if save_debug:
            self._save_debug_visualization(frame, highlighted, features, name)

        return highlighted, mask

    def process_video(self, input_path, output_prefix, sample_rate=1, max_frames=None):
        """
        Sample frames from a video and write a dataset triple for each sampled frame.

        Args:
            input_path: Path to the input video file
            output_prefix: Prefix for output filenames (``<prefix>_0000``, ``<prefix>_0001``, ...)
            sample_rate: Save every N-th frame read (indices 0, N, 2N, ...); must be >= 1
            max_frames: Maximum number of frames to *read* before sampling, so at most
                ceil(max_frames / sample_rate) frames are saved (None for all)

        Returns:
            Number of frames saved

        Raises:
            ValueError: if ``sample_rate`` < 1 or the video cannot be opened
        """
        if sample_rate < 1:
            raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

        video = cv2.VideoCapture(input_path)
        if not video.isOpened():
            raise ValueError(f"Could not open video file: {input_path}")

        frames_read = 0
        saved_frames = 0
        start_time = time.time()

        try:
            # The reported frame count comes from container metadata and may be 0 or negative
            # when unavailable; in that case read until the stream ends.
            reported = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_limit = reported if reported > 0 else None
            if max_frames is not None:
                frame_limit = max_frames if frame_limit is None else min(frame_limit, max_frames)

            with tqdm(total=frame_limit) as progress:
                while frame_limit is None or frames_read < frame_limit:
                    keep = frames_read % sample_rate == 0
                    if keep:
                        ok, frame = video.read()
                    else:
                        # grab() advances past the frame without decoding it.
                        ok = video.grab()
                    if not ok:
                        break

                    if keep:
                        self._process_and_save(
                            frame,
                            f"{output_prefix}_{saved_frames:04d}",
                            save_debug=self.debug and saved_frames % 10 == 0,
                        )
                        saved_frames += 1

                    frames_read += 1
                    progress.update(1)
        finally:
            # Release even when interrupted (e.g. KeyboardInterrupt) so the capture is not leaked.
            video.release()

        processing_time = time.time() - start_time
        print(f"Read {frames_read} frames and saved {saved_frames} frame pairs in {processing_time:.2f} seconds")

        return saved_frames

    def process_image_directory(self, input_dir, output_prefix, file_pattern='*.jpg'):
        """
        Process all images in a directory to generate dataset of original and highlighted images.

        Files are processed in sorted path order so the ``<prefix>_<index>`` naming is reproducible.
        Unreadable files are reported and skipped, but still consume an index.

        Args:
            input_dir: Directory containing input images
            output_prefix: Prefix for output filenames
            file_pattern: Glob pattern to match image files

        Returns:
            Number of images processed
        """
        image_files = sorted(glob.glob(os.path.join(input_dir, file_pattern)))

        processed_images = 0
        start_time = time.time()

        for i, image_path in enumerate(tqdm(image_files)):
            frame = cv2.imread(image_path)
            if frame is None:
                print(f"Could not read image: {image_path}")
                continue

            self._process_and_save(
                frame,
                f"{output_prefix}_{i:04d}",
                save_debug=self.debug and i % 10 == 0,
            )
            processed_images += 1

        processing_time = time.time() - start_time
        print(f"Processed {processed_images} images in {processing_time:.2f} seconds")

        return processed_images

    def process_single_image(self, image_path, output_prefix):
        """
        Process a single image and generate dataset entries.

        Args:
            image_path: Path to the input image
            output_prefix: Filename (without extension) used for every output file

        Returns:
            Tuple of (original BGR image, highlighted BGR image, uint8 0/1 mask)

        Raises:
            ValueError: if the image cannot be read
        """
        frame = cv2.imread(image_path)
        if frame is None:
            raise ValueError(f"Could not read image: {image_path}")

        highlighted_frame, specular_mask = self._process_and_save(frame, output_prefix, save_debug=self.debug)

        return frame, highlighted_frame, specular_mask

    def _copy_split(self, file_names, split_dir, clear_existing):
        """Copy the original/highlighted/mask files for each name in ``file_names`` into ``split_dir``."""
        if clear_existing and os.path.isdir(split_dir):
            # Stale files from a previous split would otherwise leak samples across train/val.
            shutil.rmtree(split_dir)
        for sub_dir in ("original", "highlighted", "masks"):
            os.makedirs(os.path.join(split_dir, sub_dir), exist_ok=True)

        for file_name in file_names:
            mask_name = os.path.splitext(file_name)[0] + ".png"
            # shutil avoids the shell, so names with spaces/metacharacters are copied verbatim.
            shutil.copy2(os.path.join(self.original_dir, file_name),
                         os.path.join(split_dir, "original", file_name))
            shutil.copy2(os.path.join(self.highlighted_dir, file_name),
                         os.path.join(split_dir, "highlighted", file_name))
            shutil.copy2(os.path.join(self.mask_dir, mask_name),
                         os.path.join(split_dir, "masks", mask_name))

    def create_train_val_split(self, val_ratio=0.2, seed=None, clear_existing=True):
        """
        Copy the generated dataset into ``<output_dir>/train`` and ``<output_dir>/val``.

        Args:
            val_ratio: Portion of data to use for validation, in [0, 1]
            seed: Optional seed for a reproducible shuffle. When None the global ``random``
                module is used, so ``random.seed(...)`` still controls the split.
            clear_existing: Delete existing ``train``/``val`` directories first so a re-run
                cannot leave the same sample in both splits.

        Returns:
            Dictionary with train/val split information

        Raises:
            ValueError: if ``val_ratio`` is outside [0, 1]
            FileNotFoundError: if a highlighted image or mask is missing for an original
        """
        if not 0.0 <= val_ratio <= 1.0:
            raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

        # Sort first: glob order is filesystem-dependent, which would defeat seeding.
        all_files = sorted(os.path.basename(f) for f in glob.glob(os.path.join(self.original_dir, "*.jpg")))
        if seed is None:
            random.shuffle(all_files)
        else:
            random.Random(seed).shuffle(all_files)

        split_idx = int(len(all_files) * (1 - val_ratio))
        train_files = all_files[:split_idx]
        val_files = all_files[split_idx:]

        self._copy_split(train_files, os.path.join(self.output_dir, "train"), clear_existing)
        self._copy_split(val_files, os.path.join(self.output_dir, "val"), clear_existing)

        split_info = {
            "total_files": len(all_files),
            "train_files": len(train_files),
            "val_files": len(val_files),
            "train_ratio": 1 - val_ratio,
            "val_ratio": val_ratio
        }

        print(f"Created dataset split: {split_info['train_files']} training samples, {split_info['val_files']} validation samples")

        return split_info

    def show_sample_detections(self, num_samples=5):
        """
        Plot random (original, highlighted, mask) triples from the generated dataset.

        Args:
            num_samples: Number of samples to display (capped at the number available)
        """
        all_files = sorted(glob.glob(os.path.join(self.original_dir, "*.jpg")))
        if len(all_files) == 0:
            print("No samples available. Generate dataset first.")
            return

        samples = random.sample(all_files, min(num_samples, len(all_files)))

        # squeeze=False keeps `axes` 2-D even for a single row.
        fig, axes = plt.subplots(len(samples), 3, figsize=(15, 5 * len(samples)), squeeze=False)

        for row, sample_path in zip(axes, samples):
            base_name = os.path.basename(sample_path)
            base_name_no_ext = os.path.splitext(base_name)[0]

            original = cv2.imread(sample_path)
            highlighted = cv2.imread(os.path.join(self.highlighted_dir, base_name))
            mask = cv2.imread(os.path.join(self.mask_dir, base_name_no_ext + ".png"), cv2.IMREAD_GRAYSCALE)

            panels = (
                (cv2.cvtColor(original, cv2.COLOR_BGR2RGB), 'Original', None),
                (cv2.cvtColor(highlighted, cv2.COLOR_BGR2RGB), 'Highlighted', None),
                (mask, 'Mask', 'gray'),
            )
            for ax, (image, title, cmap) in zip(row, panels):
                ax.imshow(image, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

        plt.tight_layout()
        plt.show()

# Function to generate dataset from video for self-supervised learning
def generate_reflection_dataset(input_video_path,
                               threshold=0.15,
                               wavelet='db4',
                               level=3,
                               highlight_color=(0, 255, 0),
                               sample_rate=5,
                               max_frames=None,
                               use_gpu=True,
                               debug=True,
                               create_split=True,
                               num_preview_samples=3):
    """
    Generate a dataset of original and highlighted images for self-supervised learning.

    Runs ``SpecularReflectionDatasetGenerator.process_video`` on the video, optionally
    creates a train/val split, and plots a few random samples.

    Parameters:
    -----------
    input_video_path : str
        Path to the input video file; its basename (without extension) prefixes output files
    threshold : float
        Threshold on the normalised wavelet detail magnitude (default: 0.15)
    wavelet : str
        Wavelet type to use for the transform (default: 'db4')
    level : int
        Level of wavelet decomposition (default: 3)
    highlight_color : tuple
        Colour painted over detected reflections, in OpenCV **BGR** channel order
        (default: green (0, 255, 0))
    sample_rate : int
        Save every Nth frame read from the video (default: 5)
    max_frames : int or None
        Maximum number of frames to read from the video before sampling
        (default: None, read all)
    use_gpu : bool
        Forwarded to the generator, which only uses it to pick ``generator.device``;
        detection itself is done with OpenCV/PyWavelets/NumPy (default: True)
    debug : bool
        Whether to save debug visualizations (default: True)
    create_split : bool
        Whether to create train/validation split (default: True)
    num_preview_samples : int
        Number of random samples to plot at the end; 0 disables the preview (default: 3)

    Returns:
    --------
    generator : SpecularReflectionDatasetGenerator
        The dataset generator object for further use
    """
    generator = SpecularReflectionDatasetGenerator(
        wavelet=wavelet,
        threshold=threshold,
        level=level,
        use_gpu=use_gpu,
        highlight_color=highlight_color,
        debug=debug
    )

    video_name = os.path.splitext(os.path.basename(input_video_path))[0]
    num_frames = generator.process_video(
        input_video_path,
        video_name,
        sample_rate=sample_rate,
        max_frames=max_frames
    )

    print(f"Generated {num_frames} samples from video")

    # Splitting an empty dataset is pointless.
    if create_split and num_frames > 0:
        split_info = generator.create_train_val_split()
        print(f"Train/val split: {split_info}")

    # plt.subplots cannot build a figure with zero rows, so skip the preview entirely.
    if num_preview_samples > 0:
        generator.show_sample_detections(num_samples=num_preview_samples)

    return generator

# Run reflection detection on two sample images and plot original / highlighted / mask side by side
def process_sample_images(image1_path, image2_path, threshold=0.15):
    """
    Detect reflections in two images, save their dataset entries, and plot the results.

    Each image is written via ``process_single_image`` under the names "sample1" and
    "sample2", then shown as one figure row: original, highlighted, mask.

    Parameters:
    -----------
    image1_path : str
        Path to the first sample image
    image2_path : str
        Path to the second sample image
    threshold : float
        Threshold on the normalised wavelet detail magnitude (default: 0.15)

    Returns:
    --------
    Two tuples, one per image, each (original BGR, highlighted BGR, 0/1 mask)
    """
    generator = SpecularReflectionDatasetGenerator(threshold=threshold)

    inputs = ((image1_path, "sample1"), (image2_path, "sample2"))
    results = [generator.process_single_image(path, prefix) for path, prefix in inputs]

    row_titles = (
        ('Original 1', 'Highlighted 1', 'Mask 1'),
        ('Original 2', 'Highlighted 2', 'Mask 2'),
    )

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    for row_axes, (original, highlighted, mask), titles in zip(axes, results, row_titles):
        # OpenCV images are BGR; matplotlib expects RGB.
        row_axes[0].imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
        row_axes[1].imshow(cv2.cvtColor(highlighted, cv2.COLOR_BGR2RGB))
        row_axes[2].imshow(mask, cmap='gray')
        for ax, title in zip(row_axes, titles):
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

    first, second = results
    return first, second

# Example usage:
# 1. Inspect detections on a couple of still images first (optional):
# process_sample_images('image1.jpg', 'image2.jpg', threshold=0.15)

# 2. Generate a complete dataset from video
VIDEO_PATH = 'video.mp4'  # input video, relative to the notebook's working directory
SEED = 42                 # seeds the global `random` module (train/val shuffle, preview sampling)

random.seed(SEED)
generator = generate_reflection_dataset(
    VIDEO_PATH,
    threshold=0.08,               # Adjust sensitivity
    highlight_color=(0, 255, 0),  # Green highlighting (BGR order)
    sample_rate=5,                # Process every 5th frame
    max_frames=5000,              # Limit number of frames read (optional)
    debug=True                    # Save debug visualizations
)

GPU Available: True
Using GPU: True


  0%|          | 0/1548 [00:00<?, ?it/s]

KeyboardInterrupt: 