In [7]:
!pip install pywavelets opencv-python matplotlib tqdm ipywidgets

[0m

In [13]:
import os
import random
import subprocess
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Configuration class
class Config:
    """Central collection of the settings used by the pipeline.

    Every setting is a class attribute. Creating an instance additionally
    makes sure that ``OUTPUT_DIR`` and ``MODEL_DIR`` exist on disk.
    """

    # --- Frame geometry the frames are resized to before entering the network ---
    VIDEO_WIDTH = 512
    VIDEO_HEIGHT = 512
    CHANNELS = 3  # RGB

    # --- Optimisation settings ---
    BATCH_SIZE = 4
    LEARNING_RATE = 2e-4
    NUM_EPOCHS = 30

    # --- Compute device: CUDA when available, CPU otherwise ---
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Directories that receive generated artefacts ---
    OUTPUT_DIR = "output"
    MODEL_DIR = "models"

    # --- A sample image is written once every N training batches ---
    VISUALIZE_EVERY = 20

    def __init__(self):
        # Both directories are created up front so later writes cannot fail
        # on a missing parent folder.
        for directory in (self.OUTPUT_DIR, self.MODEL_DIR):
            os.makedirs(directory, exist_ok=True)

# Enhanced frame extraction with original resolution and color
def extract_frames(video_path, output_dir="frames"):
    """Decode a video into numbered PNG frames on disk.

    Args:
        video_path: Path of the video file to decode.
        output_dir: Directory that receives ``frame_0000.png``,
            ``frame_0001.png``, ... It is created when it does not exist.

    Returns:
        A ``(frames, fps)`` tuple. ``frames`` is a list of
        ``(frame_path, (height, width))`` entries in decoding order, where the
        size is the one reported by the container. ``fps`` is the reported
        frame rate truncated to an integer.

    Raises:
        IOError: If the video cannot be opened or a frame cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    try:
        # Without this check a wrong path would silently produce zero frames
        # and only fail much later, inside training.
        if not cap.isOpened():
            raise IOError(f"Could not open video: {video_path}")

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Total frames: {frame_count}, FPS: {fps}, Resolution: {width}x{height}")

        frames = []
        # The reported frame count is metadata and may be zero or inaccurate,
        # so it is only printed; decoding continues until read() fails.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = f"{output_dir}/frame_{len(frames):04d}.png"
            if not cv2.imwrite(frame_path, frame):
                raise IOError(f"Could not write frame: {frame_path}")
            frames.append((frame_path, (height, width)))
    finally:
        # Release the decoder even when opening or writing failed.
        cap.release()
    return frames, fps

# Specialized dataset for reflection removal
class ReflectionDataset(Dataset):
    """Dataset yielding ``(frame, reflection_mask)`` tensor pairs for saved frames.

    Each item is loaded from disk, converted to RGB, resized to
    ``config.VIDEO_WIDTH`` x ``config.VIDEO_HEIGHT`` and paired with a
    heuristic reflection mask computed by :meth:`detect_reflections`.

    Args:
        frame_data: Sequence of ``(frame_path, size)`` entries; only the path is used.
        config: Object exposing ``VIDEO_WIDTH`` and ``VIDEO_HEIGHT``.
        validation: When true, only a random subset of at most 50 entries is kept.
    """

    def __init__(self, frame_data, config, validation=False):
        self.frame_data = frame_data
        self.config = config
        self.validation = validation

        # Use a subset for validation
        if validation:
            self.frame_data = random.sample(frame_data, min(50, len(frame_data)))

    def __len__(self):
        return len(self.frame_data)

    def detect_reflections(self, frame):
        """Build a heuristic reflection mask for one RGB image.

        A pixel is marked when it is both very bright and weakly saturated
        (relative to the image's own value/saturation percentiles) or when its
        local contrast is in the top 10% of the image. The mask is then
        cleaned morphologically, small components are dropped and the result
        is dilated once.

        Args:
            frame: RGB ``uint8`` image of shape ``(H, W, 3)``.

        Returns:
            ``uint8`` array of shape ``(H, W)`` holding 1 for reflection
            pixels and 0 elsewhere (values are 0/1, not 0/255).
        """
        # Convert to multiple color spaces
        hsv = cv2.cvtColor(frame, cv2.COLOR_RGB2HSV)
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

        # Extract channels
        s_channel = hsv[:, :, 1]
        v_channel = hsv[:, :, 2]

        # Thresholds adapt to each image: brightest 3% and least saturated 15%
        v_thresh = np.percentile(v_channel, 97)
        s_thresh = np.percentile(s_channel, 15)

        # Main reflection mask - high brightness and low saturation
        reflection_mask = ((v_channel > v_thresh) & (s_channel < s_thresh)).astype(np.uint8)

        # Local contrast = absolute difference to a Gaussian-blurred copy
        kernel_size = 9
        local_mean = cv2.GaussianBlur(gray, (kernel_size, kernel_size), 0)
        local_contrast = cv2.absdiff(gray, local_mean)
        contrast_thresh = np.percentile(local_contrast, 90)
        contrast_mask = (local_contrast > contrast_thresh).astype(np.uint8)

        # Combine masks (union of both cues)
        combined_mask = cv2.bitwise_or(reflection_mask, contrast_mask)

        # Clean up with morphological operations
        kernel = np.ones((5, 5), np.uint8)
        cleaned_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
        cleaned_mask = cv2.morphologyEx(cleaned_mask, cv2.MORPH_OPEN, kernel)

        # Connected component filtering to remove small noise regions.
        # A per-label lookup table clears every undersized component in one
        # pass instead of rescanning the label image once per component.
        _, labels, stats, _ = cv2.connectedComponentsWithStats(cleaned_mask, 4)
        min_size = 20
        too_small = stats[:, cv2.CC_STAT_AREA] < min_size
        too_small[0] = False  # label 0 is the background and is never cleared
        cleaned_mask[too_small[labels]] = 0

        # Dilate slightly to ensure we cover the full reflection
        final_mask = cv2.dilate(cleaned_mask, kernel, iterations=1)

        return final_mask

    def __getitem__(self, idx):
        """Return ``(frame_tensor, mask_tensor)`` for the entry at ``idx``.

        Returns:
            ``frame_tensor``: float tensor ``(3, H, W)`` with values in [0, 1].
            ``mask_tensor``: float tensor ``(1, H, W)`` with values 0.0 or 1.0.

        Raises:
            FileNotFoundError: If the frame image cannot be read.
        """
        frame_path, _ = self.frame_data[idx]
        frame = cv2.imread(frame_path)
        # imread signals failure by returning None rather than raising.
        if frame is None:
            raise FileNotFoundError(f"Could not read frame image: {frame_path}")
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Resize for consistent processing
        frame = cv2.resize(frame, (self.config.VIDEO_WIDTH, self.config.VIDEO_HEIGHT))

        # detect_reflections() already yields 0/1 values, so only the dtype
        # changes here. Dividing by 255 would shrink every masked pixel to
        # ~0.004 and make the mask effectively empty.
        mask = self.detect_reflections(frame).astype(np.float32)

        # Convert to tensors
        frame_tensor = torch.from_numpy(frame.transpose(2, 0, 1)).float() / 255.0
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)  # Add channel dimension

        return frame_tensor, mask_tensor

# Define U-Net model
class UnetDoubleConv(nn.Module):
    """Two 3x3 convolutions in a row, each followed by instance norm and ReLU.

    With ``padding=1`` the spatial size of the input is preserved; only the
    channel count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Only the first convolution changes the channel count; the second
        # one maps out_channels -> out_channels.
        for conv_inputs in (in_channels, out_channels):
            stages.append(nn.Conv2d(conv_inputs, out_channels, kernel_size=3, padding=1))
            stages.append(nn.InstanceNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class UnetDown(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a double convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        halve_resolution = nn.MaxPool2d(2)
        refine = UnetDoubleConv(in_channels, out_channels)
        # Kept as a two-element Sequential so parameter names stay "down.1.*".
        self.down = nn.Sequential(halve_resolution, refine)

    def forward(self, x):
        downsampled = self.down(x)
        return downsampled

class UnetUp(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, convolve.

    ``forward(x1, x2)`` takes the deeper feature map ``x1`` and the skip
    connection ``x2``; the result has ``out_channels`` channels and the
    spatial size of ``x2``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The transposed convolution halves the channel count; concatenating
        # the skip tensor afterwards brings it back to in_channels.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = UnetDoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Pad the upsampled map so its height/width equal those of the skip
        # tensor (any odd remainder goes to the right/bottom side).
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        pad_left = gap_w // 2
        pad_top = gap_h // 2
        upsampled = F.pad(
            upsampled,
            [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top],
        )

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class ReflectionUNet(nn.Module):
    """Four-level U-Net mapping an image to an image with the same channel count.

    Args:
        in_channels: Number of channels of the input (and of the output).
        base_channels: Width of the first encoder stage. Every deeper stage
            doubles it, so the bottleneck has ``16 * base_channels`` channels.
            The default of 64 reproduces the original 64/128/256/512/1024
            layout; smaller values reduce memory use.

    Raises:
        ValueError: If ``base_channels`` is smaller than 1.

    The output passes through ``tanh`` and therefore lies in [-1, 1].
    NOTE(review): the dataset in this notebook feeds frames scaled to [0, 1],
    so only half of that range is ever a valid target.
    """

    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be >= 1, got {base_channels}")
        width = base_channels

        # Encoder: each stage halves the resolution and doubles the channels
        self.inc = UnetDoubleConv(in_channels, width)
        self.down1 = UnetDown(width, width * 2)
        self.down2 = UnetDown(width * 2, width * 4)
        self.down3 = UnetDown(width * 4, width * 8)
        self.down4 = UnetDown(width * 8, width * 16)

        # Decoder: mirrors the encoder and consumes its skip connections
        self.up1 = UnetUp(width * 16, width * 8)
        self.up2 = UnetUp(width * 8, width * 4)
        self.up3 = UnetUp(width * 4, width * 2)
        self.up4 = UnetUp(width * 2, width)

        # Final 1x1 projection back to the input channel count
        self.outc = nn.Conv2d(width, in_channels, kernel_size=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Encoder path (intermediate maps are kept for the skip connections)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder path
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        # Final output
        logits = self.outc(x)
        return self.tanh(logits)

# Loss function for reflection removal
class ReflectionRemovalLoss(nn.Module):
    """Weighted sum of masked L1 terms and a Sobel edge-consistency term.

    ``forward`` returns ``(total, (reconstruction, coherence, edge))`` where

    * reconstruction = 10 * L1 over pixels outside the mask,
    * coherence      =  5 * L1 over pixels inside the mask,
    * edge           =  2 * (L1 of horizontal + L1 of vertical Sobel responses).

    NOTE(review): the coherence term also pulls the output towards the target
    inside the masked (reflection) region. When the target is the unmodified
    input frame, this rewards reproducing the reflection - confirm intent.
    """

    def __init__(self):
        super(ReflectionRemovalLoss, self).__init__()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()  # not used by forward(); kept as an attribute

    @staticmethod
    def _sobel_kernels(reference):
        """Return depthwise Sobel-x / Sobel-y weights matching ``reference``.

        The kernels take their channel count, dtype and device from
        ``reference`` so the loss works for any number of image channels and
        for non-float32 inputs.
        """
        channels = reference.shape[1]
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=reference.dtype, device=reference.device,
        )
        sobel_y = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
            dtype=reference.dtype, device=reference.device,
        )
        # One (1, 3, 3) filter per channel, applied with groups=channels.
        sobel_x = sobel_x.reshape(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        sobel_y = sobel_y.reshape(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        return sobel_x, sobel_y, channels

    def forward(self, outputs, targets, masks):
        """Compute the loss.

        Args:
            outputs: Predicted images, shape ``(N, C, H, W)``.
            targets: Reference images, same shape as ``outputs``.
            masks: Reflection masks broadcastable to ``outputs``
                (``(N, 1, H, W)`` in this notebook), 1 inside reflections.

        Returns:
            ``(total_loss, (reconstruction_loss, coherence_loss, edge_loss))``.
        """
        # Reconstruction loss for non-reflection areas (should be identical to input)
        non_reflection_area = 1 - masks
        reconstruction_loss = self.l1_loss(outputs * non_reflection_area, targets * non_reflection_area) * 10.0

        # Coherence loss for reflection areas
        reflection_area = masks
        coherence_loss = self.l1_loss(outputs * reflection_area, targets * reflection_area) * 5.0

        # Edge preservation loss: compare Sobel gradients of output and target
        sobel_x, sobel_y, channels = self._sobel_kernels(outputs)

        edge_x_output = F.conv2d(outputs, sobel_x, padding=1, groups=channels)
        edge_y_output = F.conv2d(outputs, sobel_y, padding=1, groups=channels)

        edge_x_target = F.conv2d(targets, sobel_x, padding=1, groups=channels)
        edge_y_target = F.conv2d(targets, sobel_y, padding=1, groups=channels)

        edge_loss = (self.l1_loss(edge_x_output, edge_x_target) +
                     self.l1_loss(edge_y_output, edge_y_target)) * 2.0

        # Calculate total loss
        total_loss = reconstruction_loss + coherence_loss + edge_loss

        return total_loss, (reconstruction_loss, coherence_loss, edge_loss)

# Function to train the model
def _save_checkpoint(path, model, optimizer, epoch, loss):
    """Write model and optimizer state plus epoch/loss bookkeeping to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }, path)


def train_model(model, frame_data, config):
    """Train the model for reflection removal.

    Up to 50 frames are held out for validation and up to 500 of the
    remaining frames are used for training. After every epoch the latest
    weights are written to ``latest_model.pth`` and, when the validation loss
    improved, to ``best_model.pth`` inside ``config.MODEL_DIR``.

    Args:
        model: Network to optimise; expected to already live on ``config.DEVICE``.
        frame_data: Sequence of ``(frame_path, size)`` entries.
        config: Settings object (batch size, learning rate, epochs, device,
            output/model directories, visualisation interval).

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: If ``frame_data`` is empty.
    """
    if len(frame_data) == 0:
        raise ValueError("frame_data is empty; there is nothing to train on")
    print(f"Training model on {len(frame_data)} frames...")

    # Split once so validation frames never appear in the training set.
    shuffled = random.sample(frame_data, len(frame_data))
    if len(shuffled) > 1:
        n_val = min(50, max(1, len(shuffled) // 10))
        val_frames = shuffled[:n_val]
        train_frames = shuffled[n_val:n_val + 500]
    else:
        # A single frame cannot be split; reuse it for both phases.
        val_frames = train_frames = shuffled

    train_dataset = ReflectionDataset(train_frames, config)
    val_dataset = ReflectionDataset(val_frames, config, validation=True)

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=1)

    # Loss and optimizer
    criterion = ReflectionRemovalLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    # The deprecated `verbose` flag is not passed; the learning rate is
    # printed with the epoch summary below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # Create output directories
    samples_dir = os.path.join(config.OUTPUT_DIR, 'samples')
    os.makedirs(samples_dir, exist_ok=True)

    best_path = os.path.join(config.MODEL_DIR, 'best_model.pth')
    latest_path = os.path.join(config.MODEL_DIR, 'latest_model.pth')

    # Training loop
    best_val_loss = float('inf')
    start_time = time.time()

    for epoch in range(config.NUM_EPOCHS):
        # Training phase
        model.train()
        train_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS} [Train]")
        for batch_idx, (frames, masks) in enumerate(progress_bar):
            frames = frames.to(config.DEVICE)
            masks = masks.to(config.DEVICE)

            # Forward pass
            outputs = model(frames)

            # The input frames double as the regression target.
            loss, (recon_loss, coherence_loss, edge_loss) = criterion(outputs, frames, masks)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'recon': f"{recon_loss.item():.4f}",
                'coh': f"{coherence_loss.item():.4f}",
                'edge': f"{edge_loss.item():.4f}"
            })

            # Save sample images periodically
            if batch_idx % config.VISUALIZE_EVERY == 0:
                save_sample_images(frames, outputs, masks, epoch, batch_idx, samples_dir)

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS} [Val]")
            for batch_idx, (frames, masks) in enumerate(progress_bar):
                frames = frames.to(config.DEVICE)
                masks = masks.to(config.DEVICE)

                # Forward pass
                outputs = model(frames)

                # Calculate loss
                loss, _ = criterion(outputs, frames, masks)
                val_loss += loss.item()

                # Update progress bar
                progress_bar.set_postfix({'val_loss': f"{loss.item():.4f}"})

        avg_val_loss = val_loss / len(val_loader)

        # Update learning rate
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        # Print epoch stats
        elapsed = time.time() - start_time
        print(f"Epoch {epoch+1}/{config.NUM_EPOCHS} - "
              f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
              f"Time: {elapsed:.1f}s, LR: {current_lr:.6f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            _save_checkpoint(best_path, model, optimizer, epoch, best_val_loss)
            print(f"New best model saved with val loss: {best_val_loss:.4f}")

        # Save latest model
        _save_checkpoint(latest_path, model, optimizer, epoch, avg_val_loss)

    print("Training completed!")
    return model

# Function to save sample images
def save_sample_images(inputs, outputs, masks, epoch, batch_idx, output_dir):
    """Write a three-panel PNG (input, model output, mask) for the batch's first sample.

    The file is named ``epoch_<epoch+1>_batch_<batch_idx>.png`` and placed in
    ``output_dir``.
    """
    def _first_sample_hwc(batch):
        # (C, H, W) tensor -> (H, W, C) numpy array, detached from the graph
        return batch[0].detach().cpu().numpy().transpose(1, 2, 0)

    # Images are clipped to [0, 1] so imshow accepts them as RGB floats.
    original = np.clip(_first_sample_hwc(inputs), 0, 1)
    restored = np.clip(_first_sample_hwc(outputs), 0, 1)
    mask_plane = masks[0][0].detach().cpu().numpy()

    # (image, title, extra imshow arguments) for each panel, left to right
    panels = (
        (original, 'Original', {}),
        (restored, 'Reflection Removed', {}),
        (mask_plane, 'Reflection Mask', {'cmap': 'gray'}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (image, title, imshow_kwargs) in zip(axes, panels):
        axis.imshow(image, **imshow_kwargs)
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    destination = os.path.join(output_dir, f'epoch_{epoch+1}_batch_{batch_idx}.png')
    plt.savefig(destination)
    # Close explicitly; this runs many times per epoch.
    plt.close(fig)

# Function to process a video with a trained model
def process_video(model, frame_data, output_path, config, fps=30):
    """Run the model over every frame and assemble the results into a video.

    For each frame the reflection mask is computed; frames with an empty mask
    are copied unchanged. Otherwise the model output is blended into the
    (resized) input using a smoothed mask, resized back to the original
    resolution and written to ``<config.OUTPUT_DIR>/processed_frames``.
    Finally ``ffmpeg`` encodes the written frames to ``output_path``.

    NOTE(review): processed frames are blended at the network resolution and
    then upscaled, so they are resampled even outside reflection regions.

    Args:
        model: Trained network living on ``config.DEVICE``.
        frame_data: Sequence of ``(frame_path, size)`` entries.
        output_path: Destination video file.
        config: Settings object (``OUTPUT_DIR``, ``DEVICE``, frame size).
        fps: Frame rate of the encoded video; non-positive values fall back to 30.

    Raises:
        FileNotFoundError: If a frame image cannot be read.
        RuntimeError: If ``ffmpeg`` is missing or exits with an error.
    """
    output_dir = os.path.join(config.OUTPUT_DIR, 'processed_frames')
    os.makedirs(output_dir, exist_ok=True)

    # Create dataset for full video
    full_dataset = ReflectionDataset(frame_data, config)

    model.eval()
    with torch.no_grad():
        for i, (frame_path, _) in enumerate(tqdm(frame_data, desc="Processing video frames")):
            # Load the untouched frame to learn its real size
            orig_frame = cv2.imread(frame_path)
            if orig_frame is None:
                raise FileNotFoundError(f"Could not read frame image: {frame_path}")
            orig_height, orig_width = orig_frame.shape[:2]

            # Get processed frame from dataset
            frame_tensor, mask_tensor = full_dataset[i]

            # Skip processing if no reflections detected
            if torch.sum(mask_tensor) == 0:
                print(f"No reflections in frame {i}, skipping")
                cv2.imwrite(f"{output_dir}/frame_{i:04d}.png", orig_frame)
                continue

            # Process the frame
            frame_tensor = frame_tensor.unsqueeze(0).to(config.DEVICE)
            output = model(frame_tensor)

            # Two 3x3 box blurs soften the mask border for smoother blending
            mask_smooth = F.avg_pool2d(F.avg_pool2d(mask_tensor, 3, stride=1, padding=1), 3, stride=1, padding=1)
            mask_smooth = mask_smooth.unsqueeze(0).to(config.DEVICE)

            # Blend: original outside the mask, generated content inside
            blended = frame_tensor * (1 - mask_smooth) + output * mask_smooth

            # Convert output tensor to image
            output_img = blended[0].cpu().numpy().transpose(1, 2, 0) * 255
            output_img = np.clip(output_img, 0, 255).astype(np.uint8)
            output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)

            # Resize back to original dimensions
            output_img = cv2.resize(output_img, (orig_width, orig_height))

            # Save processed frame
            cv2.imwrite(f"{output_dir}/frame_{i:04d}.png", output_img)

    # Make sure the destination folder exists before ffmpeg writes into it.
    output_parent = os.path.dirname(output_path)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    # A frame rate of 0 (unknown in the source container) would make ffmpeg fail.
    framerate = fps if fps and fps > 0 else 30

    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through verbatim. "-y" overwrites an existing
    # output instead of stopping at an interactive prompt.
    command = [
        "ffmpeg", "-y",
        "-framerate", str(framerate),
        "-i", f"{output_dir}/frame_%04d.png",
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        "-crf", "18",
        output_path,
    ]
    try:
        result = subprocess.run(command)
    except FileNotFoundError as exc:
        raise RuntimeError("ffmpeg executable not found; cannot encode the video") from exc
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed with exit code {result.returncode}")

    print(f"Processed video saved as {output_path}")

# Main execution function
def main(video_path, output_path, mode='train'):
    """Run the full pipeline on one video.

    Args:
        video_path: Input video file.
        output_path: Destination of the processed video.
        mode: ``'train'`` trains a fresh model on the video and then processes
            it; ``'inference'`` loads ``best_model.pth`` from the model
            directory when present (otherwise an untrained model is used) and
            processes the video.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'inference'``.
    """
    # Reject a bad mode before any expensive work (frame extraction, model
    # allocation) has been done.
    if mode not in ('train', 'inference'):
        raise ValueError(f"Unknown mode: {mode}. Use 'train' or 'inference'")

    # Initialize configuration
    config = Config()

    # Extract frames from video
    frame_data, video_fps = extract_frames(video_path)

    # Initialize model
    model = ReflectionUNet(in_channels=config.CHANNELS).to(config.DEVICE)

    if mode == 'train':
        model = train_model(model, frame_data, config)
    else:
        # Load pre-trained model if available
        model_path = os.path.join(config.MODEL_DIR, 'best_model.pth')
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=config.DEVICE)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded pre-trained model from {model_path}")
        else:
            print("No pre-trained model found, using untrained model")

    # Both modes finish by processing the video with the resulting model
    process_video(model, frame_data, output_path, config, fps=video_fps)

if __name__ == "__main__":
    # Entry point: train on video.mp4, then write the processed result
    # to output/processed_video.mp4.
    video_path, output_path = "video.mp4", "output/processed_video.mp4"
    main(video_path, output_path, mode='train')

Using device: cuda
Total frames: 1548, FPS: 30, Resolution: 1280x720
Training model on 1548 frames...


Epoch 1/30 [Train]:   0%|          | 0/125 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacty of 15.72 GiB of which 181.06 MiB is free. Process 2807253 has 8.38 GiB memory in use. Process 2805731 has 7.16 GiB memory in use. Of the allocated memory 6.54 GiB is allocated by PyTorch, and 419.38 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF