In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import cv2
import time

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Modified Dataset class that focuses on teaching tissue emulation
class TissueEmulationDataset(Dataset):
    def __init__(self, originals_dir, masks_dir, img_size=(512, 512), 
                 darkness_threshold=0.15, context_radius=3):
        self.originals_dir = originals_dir
        self.masks_dir = masks_dir
        self.img_size = img_size
        self.darkness_threshold = darkness_threshold  # Threshold to identify aperture
        self.context_radius = context_radius  # Radius for sampling surrounding tissue
        
        # Get all files from originals directory
        self.image_filenames = [f for f in os.listdir(originals_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
        
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, idx):
        # Load original image
        img_path = os.path.join(self.originals_dir, self.image_filenames[idx])
        original_img = Image.open(img_path).convert('RGB')
        
        # Find corresponding mask
        mask_filename = self.image_filenames[idx].replace('.png', '_mask.png')
        mask_path = os.path.join(self.masks_dir, mask_filename)
        
        # If mask doesn't exist with _mask suffix, try without it
        if not os.path.exists(mask_path):
            mask_path = os.path.join(self.masks_dir, self.image_filenames[idx])
        
        # Load and convert mask to binary
        if os.path.exists(mask_path):
            mask_img = Image.open(mask_path).convert('L')
        else:
            # Create an empty mask if the mask file doesn't exist
            print(f"Warning: Mask not found for {self.image_filenames[idx]}")
            mask_img = Image.new('L', original_img.size, 0)
        
        # Resize to target dimensions
        original_img = original_img.resize(self.img_size, Image.BILINEAR)
        mask_img = mask_img.resize(self.img_size, Image.NEAREST)
        
        # Convert to numpy arrays
        original_array = np.array(original_img) / 255.0
        mask_array = np.array(mask_img) / 255.0
        
        # Convert mask to binary (threshold)
        mask_array = (mask_array > 0.5).astype(np.float32)
        
        # Create aperture mask (identify dark regions to ignore)
        # Average RGB channels to get brightness
        brightness = np.mean(original_array, axis=2)
        aperture_mask = (brightness < self.darkness_threshold).astype(np.float32)
        
        # Create valid tissue mask (neither reflection nor aperture)
        valid_tissue_mask = (1.0 - mask_array) * (1.0 - aperture_mask)
        
        # For each reflection pixel, find the nearest valid tissue pixels
        # This is a simplified version - in practice you might want a more efficient algorithm
        h, w = mask_array.shape
        context_map = np.zeros_like(original_array)
        
        # For pixels in the reflection (mask_array == 1)
        mask_indices = np.where(mask_array == 1)
        for i, j in zip(mask_indices[0], mask_indices[1]):
            # Define region around the pixel
            min_i = max(0, i - self.context_radius)
            max_i = min(h, i + self.context_radius + 1)
            min_j = max(0, j - self.context_radius)
            max_j = min(w, j + self.context_radius + 1)
            
            # Extract region
            region_valid = valid_tissue_mask[min_i:max_i, min_j:max_j]
            region_orig = original_array[min_i:max_i, min_j:max_j, :]
            
            # If there are valid tissue pixels in the region
            if np.sum(region_valid) > 0:
                # Average color of valid tissue
                for c in range(3):
                    valid_colors = region_orig[:, :, c][region_valid > 0]
                    if len(valid_colors) > 0:
                        context_map[i, j, c] = np.mean(valid_colors)
            else:
                # If no valid tissue nearby, use global average of valid tissue
                for c in range(3):
                    valid_colors = original_array[:, :, c][valid_tissue_mask > 0]
                    if len(valid_colors) > 0:
                        context_map[i, j, c] = np.mean(valid_colors)
        
        # Convert to tensors
        original_tensor = torch.from_numpy(original_array.transpose(2, 0, 1)).float()
        mask_tensor = torch.from_numpy(mask_array).float().unsqueeze(0)
        aperture_tensor = torch.from_numpy(aperture_mask).float().unsqueeze(0)
        context_tensor = torch.from_numpy(context_map.transpose(2, 0, 1)).float()
        
        # Create input tensor by concatenating image, mask, and aperture mask
        input_tensor = torch.cat([original_tensor, mask_tensor, aperture_tensor], dim=0)
        
        # Model should learn to use context_map values in masked areas
        # Target is original image for non-masked areas, and context_map for masked areas
        target_tensor = original_tensor * (1.0 - mask_tensor) + context_tensor * mask_tensor
        
        return {
            'input': input_tensor,           # [C+2, H, W] - RGB + Reflection mask + Aperture mask
            'target': target_tensor,         # [C, H, W] - RGB with reflections replaced by context
            'mask': mask_tensor,             # [1, H, W] - Reflection mask
            'aperture_mask': aperture_tensor, # [1, H, W] - Aperture mask
            'filename': self.image_filenames[idx]
        }

# Modified U-Net with awareness of aperture and focus on surrounding tissue emulation
class TissueEmulationUNet(nn.Module):
    def __init__(self, in_channels=5, out_channels=3):
        super(TissueEmulationUNet, self).__init__()
        
        # Encoder
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        self.down2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        self.down3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        self.down4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Pooling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        
        # Bottleneck with attention
        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        
        # Attention mechanism to focus on relevant tissue patterns
        self.attention = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        
        # Decoder
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        self.up4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.upconv4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        # Output layers - RGB channels
        self.outconv = nn.Conv2d(32, out_channels, kernel_size=1)
        
    def forward(self, x):
        # Extract components of the input
        img = x[:, :3]           # RGB channels
        mask = x[:, 3:4]         # Reflection mask
        aperture_mask = x[:, 4:5] # Aperture mask (dark areas)
        
        # Create a combined invalid area mask (reflections + aperture)
        invalid_mask = torch.clamp(mask + aperture_mask, 0, 1)
        
        # Encoder path
        x1 = self.down1(x)
        x2 = self.maxpool(x1)
        
        x2 = self.down2(x2)
        x3 = self.maxpool(x2)
        
        x3 = self.down3(x3)
        x4 = self.maxpool(x3)
        
        x4 = self.down4(x4)
        x5 = self.maxpool(x4)
        
        # Bottleneck with attention
        x5 = self.bottleneck(x5)
        attention_map = self.attention(x5)
        x5 = x5 * attention_map
        
        # Decoder path with skip connections
        x = self.up1(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.upconv1(x)
        
        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.upconv2(x)
        
        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.upconv3(x)
        
        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.upconv4(x)
        
        # Output
        x = self.outconv(x)
        x = torch.sigmoid(x)  # Scale to [0,1]
        
        # Only replace the reflection areas, keep the original elsewhere
        output = img * (1.0 - mask) + x * mask
        
        # Ensure we don't modify the aperture areas
        output = output * (1.0 - aperture_mask) + img * aperture_mask
        
        return output

# Custom loss function focused on tissue emulation
class TissueEmulationLoss(nn.Module):
    def __init__(self, texture_weight=2.0, edge_weight=1.0):
        super(TissueEmulationLoss, self).__init__()
        self.l1_loss = nn.L1Loss()
        self.texture_weight = texture_weight
        self.edge_weight = edge_weight
        
    def edge_loss(self, pred, target, mask):
        # Sobel operators for edge detection
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=pred.device).float().view(1, 1, 3, 3).repeat(3, 1, 1, 1)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], device=pred.device).float().view(1, 1, 3, 3).repeat(3, 1, 1, 1)
        
        # Expand mask to match number of channels
        expanded_mask = mask.repeat(1, 3, 1, 1)
        
        # Detect edges
        pred_edges_x = F.conv2d(pred, sobel_x, padding=1, groups=3)
        pred_edges_y = F.conv2d(pred, sobel_y, padding=1, groups=3)
        target_edges_x = F.conv2d(target, sobel_x, padding=1, groups=3)
        target_edges_y = F.conv2d(target, sobel_y, padding=1, groups=3)
        
        # Calculate edge magnitude
        pred_edges = torch.sqrt(pred_edges_x**2 + pred_edges_y**2)
        target_edges = torch.sqrt(target_edges_x**2 + target_edges_y**2)
        
        # Only consider edges in masked areas
        edge_loss = self.l1_loss(pred_edges * expanded_mask, target_edges * expanded_mask)
        
        return edge_loss
    
    def texture_loss(self, pred, target, mask):
        # Simple texture loss using local patches
        # For a more advanced implementation, consider using Gram matrices
        # or other texture descriptors
        
        # Apply average pooling to get local features
        pred_pooled = F.avg_pool2d(pred, kernel_size=3, stride=1, padding=1)
        target_pooled = F.avg_pool2d(target, kernel_size=3, stride=1, padding=1)
        
        # Only consider texture in masked areas
        expanded_mask = mask.repeat(1, 3, 1, 1)
        texture_loss = self.l1_loss(pred_pooled * expanded_mask, target_pooled * expanded_mask)
        
        return texture_loss
    
    def forward(self, pred, target, mask, aperture_mask):
        # Calculate valid mask (not aperture)
        valid_mask = 1.0 - aperture_mask
        
        # Only evaluate loss in non-aperture regions
        pred_valid = pred * valid_mask
        target_valid = target * valid_mask
        mask_valid = mask * valid_mask
        
        # Pixel-wise L1 loss
        pixel_loss = self.l1_loss(pred_valid * mask_valid, target_valid * mask_valid)
        
        # Edge consistency loss
        edge = self.edge_loss(pred_valid, target_valid, mask_valid) * self.edge_weight
        
        # Texture matching loss
        texture = self.texture_loss(pred_valid, target_valid, mask_valid) * self.texture_weight
        
        # Total loss
        total_loss = pixel_loss + edge + texture
        
        return total_loss, {
            'pixel': pixel_loss.item(),
            'edge': edge.item(),
            'texture': texture.item()
        }

# Function to visualize results
def visualize_sample(model, batch, epoch, device):
    model.eval()
    with torch.no_grad():
        # Get input, target and masks
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        masks = batch['mask'].to(device)
        aperture_masks = batch['aperture_mask'].to(device)
        
        # Forward pass
        outputs = model(inputs)
        
        # Convert tensors to numpy for visualization
        input_img = inputs[0, :3].cpu().numpy().transpose(1, 2, 0)  # RGB channels only
        mask_img = masks[0, 0].cpu().numpy()
        aperture_img = aperture_masks[0, 0].cpu().numpy()
        target_img = targets[0].cpu().numpy().transpose(1, 2, 0)
        output_img = outputs[0].cpu().numpy().transpose(1, 2, 0)
        
        # Create figure
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        
        # Plot input image
        axes[0, 0].imshow(input_img)
        axes[0, 0].set_title('Input Image')
        axes[0, 0].axis('off')
        
        # Plot reflection mask
        axes[0, 1].imshow(mask_img, cmap='gray')
        axes[0, 1].set_title('Reflection Mask')
        axes[0, 1].axis('off')
        
        # Plot aperture mask
        axes[0, 2].imshow(aperture_img, cmap='gray')
        axes[0, 2].set_title('Aperture Mask')
        axes[0, 2].axis('off')
        
        # Plot target image
        axes[1, 0].imshow(target_img)
        axes[1, 0].set_title('Target Image')
        axes[1, 0].axis('off')
        
        # Plot output image
        axes[1, 1].imshow(output_img)
        axes[1, 1].set_title('Output Image')
        axes[1, 1].axis('off')
        
        # Plot difference image
        diff_img = np.abs(output_img - target_img)
        diff_img = diff_img / np.max(diff_img) if np.max(diff_img) > 0 else diff_img
        axes[1, 2].imshow(diff_img)
        axes[1, 2].set_title('Difference')
        axes[1, 2].axis('off')
        
        plt.tight_layout()
        plt.savefig(f'results/sample_epoch_{epoch+1}.png')
        plt.close()

# Training function
def train_model(model, train_loader, val_loader, num_epochs=30, lr=0.001):
    # Define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    # Loss function
    criterion = TissueEmulationLoss()
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'epoch_times': []
    }
    
    # Best model tracking
    best_val_loss = float('inf')
    
    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_start_time = time.time()
        train_loss = 0.0
        
        # Progress bar for training
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        
        for batch in train_bar:
            # Get input, target, masks
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            masks = batch['mask'].to(device)
            aperture_masks = batch['aperture_mask'].to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            
            # Compute loss
            loss, loss_components = criterion(outputs, targets, masks, aperture_masks)
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            
            # Update training loss
            train_loss += loss.item()
            
            # Update progress bar
            train_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'pixel': f"{loss_components['pixel']:.4f}",
                'texture': f"{loss_components['texture']:.4f}"
            })
        
        # Calculate average training loss
        avg_train_loss = train_loss / len(train_loader)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            # Progress bar for validation
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
            
            for batch in val_bar:
                inputs = batch['input'].to(device)
                targets = batch['target'].to(device)
                masks = batch['mask'].to(device)
                aperture_masks = batch['aperture_mask'].to(device)
                
                # Forward pass
                outputs = model(inputs)
                
                # Compute loss
                loss, _ = criterion(outputs, targets, masks, aperture_masks)
                
                # Update validation loss
                val_loss += loss.item()
                
                # Update progress bar
                val_bar.set_postfix({'val_loss': f"{loss.item():.4f}"})
        
        # Calculate average validation loss
        avg_val_loss = val_loss / len(val_loader)
        
        # Update learning rate
        scheduler.step(avg_val_loss)
        
        # Update history
        epoch_time = time.time() - epoch_start_time
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['epoch_times'].append(epoch_time)
        
        # Print epoch summary
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
              f"Time: {epoch_time:.1f}s, "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss
            }, 'best_tissue_emulation_model.pth')
            print(f"Saved new best model with val loss: {best_val_loss:.4f}")
        
        # Visualize a sample after each epoch
        if val_loader:
            visualize_sample(model, next(iter(val_loader)), epoch, device)
    
    return model, history

# Video processing function
def process_video(model, video_path, output_path, darkness_threshold=0.15):
    # Make sure the output directory exists
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    # Open the video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return
    
    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # Create output video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    # Load the model
    model.eval()
    
    # Create transform for preprocessing
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    
    # Function to detect specular reflections and aperture
    def detect_features(frame):
        # Convert to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
        # Convert to HSV for better reflection detection
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        
        # Extract value channel (brightness)
        v_channel = hsv[:,:,2]
        
        # Create a mask for very bright areas (specular reflections)
        _, reflection_mask = cv2.threshold(v_channel, 220, 255, cv2.THRESH_BINARY)
        
        # Create a mask for dark areas (outside the endoscope view)
        brightness = np.mean(frame_rgb / 255.0, axis=2)
        aperture_mask = (brightness < darkness_threshold).astype(np.float32)
        
        # Clean up reflection mask with morphological operations
        kernel = np.ones((5, 5), np.uint8)
        reflection_mask = cv2.morphologyEx(reflection_mask, cv2.MORPH_CLOSE, kernel)
        reflection_mask = cv2.morphologyEx(reflection_mask, cv2.MORPH_OPEN, kernel)
        
        return frame_rgb, reflection_mask / 255.0, aperture_mask
    
    # Process each frame
    with torch.no_grad():
        for i in tqdm(range(frame_count), desc="Processing video"):
            ret, frame = cap.read()
            if not ret:
                break
            
            # Detect reflection and aperture
            frame_rgb, reflection_mask, aperture_mask = detect_features(frame)
            
            # Check if reflections found
            if np.max(reflection_mask) > 0:
                # Preprocess
                frame_tensor = preprocess(frame_rgb).unsqueeze(0)
                reflection_tensor = torch.from_numpy(reflection_mask).float().unsqueeze(0).unsqueeze(0)
                reflection_tensor = F.interpolate(reflection_tensor, size=(512, 512), mode='nearest')
                aperture_tensor = torch.from_numpy(aperture_mask).float().unsqueeze(0).unsqueeze(0)
                aperture_tensor = F.interpolate(aperture_tensor, size=(512, 512), mode='nearest')
                
                # Create input by concatenating frame, reflection mask, and aperture mask
                input_tensor = torch.cat([frame_tensor, reflection_tensor, aperture_tensor], dim=1).to(device)
                
                # Process with model
                output = model(input_tensor)
                
                # Convert output to numpy
                output_np = output[0].cpu().numpy().transpose(1, 2, 0)
                output_np = (output_np * 255).astype(np.uint8)
                
                # Resize back to original dimensions
                output_np = cv2.resize(output_np, (width, height))
                
                # Convert RGB to BGR for OpenCV
                output_bgr = cv2.cvtColor(output_np, cv2.COLOR_RGB2BGR)
                
                # Create the mask at original resolution
                orig_mask = cv2.resize(reflection_mask, (width, height))
                orig_mask = np.expand_dims(orig_mask, axis=2).repeat(3, axis=2)
                
                # Blend original and processed frames
                blended = frame * (1 - orig_mask) + output_bgr * orig_mask
                blended = blended.astype(np.uint8)
                
                # Write frame
                out.write(blended)
            else:
                # No reflections, use original frame
                out.write(frame)
    
    # Release resources
    cap.release()
    out.release()
    print(f"Processed video saved to {output_path}")



def main():
    # Create results directory
    os.makedirs('results', exist_ok=True)
    
    # Parameters
    BATCH_SIZE = 4
    NUM_EPOCHS = 30
    LEARNING_RATE = 0.001
    IMG_SIZE = (512, 512)
    DARKNESS_THRESHOLD = 0.15  # Adjust based on your images
    
    # Path to your data directories
    originals_dir = 'originals'  # Directory with original images
    masks_dir = 'masks'         # Directory with reflection masks
    
    # Create the dataset
    full_dataset = TissueEmulationDataset(
        originals_dir=originals_dir,
        masks_dir=masks_dir,
        img_size=IMG_SIZE,
        darkness_threshold=DARKNESS_THRESHOLD
    )
    
    # Split into train and validation sets (80% train, 20% validation)
    dataset_size = len(full_dataset)
    train_size = int(0.8 * dataset_size)
    val_size = dataset_size - train_size
    
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size]
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=True, 
        num_workers=2
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=False, 
        num_workers=2
    )
    
    print(f"Dataset loaded - Total: {dataset_size}, Train: {train_size}, Val: {val_size}")
    
    # Initialize the model
    model = TissueEmulationUNet(in_channels=5, out_channels=3).to(device)
    
    # Print model summary
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized with {total_params:,} trainable parameters")
    
    # Train the model
    print("Starting training...")
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=NUM_EPOCHS,
        lr=LEARNING_RATE
    )
    
    # Save the final model
    torch.save(model.state_dict(), 'final_tissue_emulation_model.pth')
    
    # Plot training history
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.legend()
    plt.title('Loss History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    
    plt.subplot(1, 2, 2)
    plt.plot(history['epoch_times'])
    plt.title('Epoch Times')
    plt.xlabel('Epoch')
    plt.ylabel('Time (s)')
    
    plt.tight_layout()
    plt.savefig('results/training_history.png')
    
    print("Training completed!")
    
    # Process a video if available
    video_path = input("Enter the path to a video to process (or press Enter to skip): ")
    if video_path and os.path.exists(video_path):
        output_path = 'results/processed_video.mp4'
        print(f"Processing video {video_path}...")
        process_video(model, video_path, output_path, darkness_threshold=DARKNESS_THRESHOLD)
        print(f"Video processing completed!")

# Additionally, add functions to use a pre-trained model for inference

def load_trained_model(model_path, device=None):
    """Load a pre-trained model for inference"""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize the model
    model = TissueEmulationUNet(in_channels=5, out_channels=3).to(device)
    
    # Load the saved weights
    checkpoint = torch.load(model_path, map_location=device)
    
    # Check if we have a full checkpoint or just state_dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    
    model.eval()
    print(f"Model loaded from {model_path}")
    
    return model

def process_single_image(model, image_path, output_path=None, darkness_threshold=0.15):
    """Process a single image to remove reflections"""
    # Load image
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Could not open image {image_path}")
        return None
    
    # Detect reflection and aperture
    frame_rgb, reflection_mask, aperture_mask = detect_features(image, darkness_threshold)
    
    # Check if reflections found
    if np.max(reflection_mask) > 0:
        # Create preprocessing transform
        preprocess = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((512, 512)),
            transforms.ToTensor()
        ])
        
        # Preprocess
        frame_tensor = preprocess(frame_rgb).unsqueeze(0)
        reflection_tensor = torch.from_numpy(reflection_mask).float().unsqueeze(0).unsqueeze(0)
        reflection_tensor = F.interpolate(reflection_tensor, size=(512, 512), mode='nearest')
        aperture_tensor = torch.from_numpy(aperture_mask).float().unsqueeze(0).unsqueeze(0)
        aperture_tensor = F.interpolate(aperture_tensor, size=(512, 512), mode='nearest')
        
        # Create input by concatenating frame, reflection mask, and aperture mask
        input_tensor = torch.cat([frame_tensor, reflection_tensor, aperture_tensor], dim=1).to(device)
        
        # Process with model
        with torch.no_grad():
            output = model(input_tensor)
        
        # Convert output to numpy
        output_np = output[0].cpu().numpy().transpose(1, 2, 0)
        output_np = (output_np * 255).astype(np.uint8)
        
        # Resize back to original dimensions
        h, w = image.shape[:2]
        output_np = cv2.resize(output_np, (w, h))
        
        # Convert RGB to BGR for OpenCV
        output_bgr = cv2.cvtColor(output_np, cv2.COLOR_RGB2BGR)
        
        # Create the mask at original resolution
        orig_mask = cv2.resize(reflection_mask, (w, h))
        orig_mask = np.expand_dims(orig_mask, axis=2).repeat(3, axis=2)
        
        # Blend original and processed frames
        blended = image * (1 - orig_mask) + output_bgr * orig_mask
        processed_image = blended.astype(np.uint8)
        
        # Save if output path provided
        if output_path:
            cv2.imwrite(output_path, processed_image)
            print(f"Processed image saved to {output_path}")
        
        return processed_image
    else:
        print("No reflections detected in the image.")
        if output_path:
            cv2.imwrite(output_path, image)
        return image

def detect_features(frame, darkness_threshold=0.15):
    """Helper function to detect specular reflections and aperture in an image"""
    # Convert to RGB
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
    # Convert to HSV for better reflection detection
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    
    # Extract value channel (brightness)
    v_channel = hsv[:,:,2]
    
    # Create a mask for very bright areas (specular reflections)
    _, reflection_mask = cv2.threshold(v_channel, 220, 255, cv2.THRESH_BINARY)
    
    # Create a mask for dark areas (outside the endoscope view)
    brightness = np.mean(frame_rgb / 255.0, axis=2)
    aperture_mask = (brightness < darkness_threshold).astype(np.float32)
    
    # Clean up reflection mask with morphological operations
    kernel = np.ones((5, 5), np.uint8)
    reflection_mask = cv2.morphologyEx(reflection_mask, cv2.MORPH_CLOSE, kernel)
    reflection_mask = cv2.morphologyEx(reflection_mask, cv2.MORPH_OPEN, kernel)
    
    return frame_rgb, reflection_mask / 255.0, aperture_mask

def evaluate_model(model, test_loader):
    """Evaluate model performance on a test dataset"""
    model.eval()
    criterion = TissueEmulationLoss()
    
    total_loss = 0.0
    component_losses = {'pixel': 0.0, 'edge': 0.0, 'texture': 0.0}
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating model"):
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            masks = batch['mask'].to(device)
            aperture_masks = batch['aperture_mask'].to(device)
            
            # Forward pass
            outputs = model(inputs)
            
            # Compute loss
            loss, loss_components = criterion(outputs, targets, masks, aperture_masks)
            
            # Update losses
            total_loss += loss.item()
            for key, value in loss_components.items():
                component_losses[key] += value
    
    # Average losses
    num_batches = len(test_loader)
    avg_total_loss = total_loss / num_batches
    avg_component_losses = {k: v / num_batches for k, v in component_losses.items()}
    
    print(f"Evaluation results:")
    print(f"  Total loss: {avg_total_loss:.4f}")
    for key, value in avg_component_losses.items():
        print(f"  {key.capitalize()} loss: {value:.4f}")
    
    return avg_total_loss, avg_component_losses

# Run the main function when script is executed directly
if __name__ == "__main__":
    main()

Using device: cuda
Dataset loaded - Total: 48, Train: 38, Val: 10
Model initialized with 7,767,140 trainable parameters
Starting training...


Epoch 1/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]



Epoch 1/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 1/30 - Train Loss: nan, Val Loss: nan, Time: 2.9s, LR: 0.001000


  xx = (xx * 255).astype(np.uint8)


Epoch 2/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]



Epoch 2/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 2/30 - Train Loss: nan, Val Loss: nan, Time: 2.3s, LR: 0.001000





Epoch 3/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]



Epoch 3/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 3/30 - Train Loss: nan, Val Loss: nan, Time: 2.3s, LR: 0.001000


Epoch 4/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]




Epoch 4/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 4/30 - Train Loss: nan, Val Loss: nan, Time: 2.3s, LR: 0.001000


Epoch 5/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]




Epoch 5/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]


Epoch 5/30 - Train Loss: nan, Val Loss: nan, Time: 2.4s, LR: 0.001000


Epoch 6/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]



Epoch 6/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 00006: reducing learning rate of group 0 to 5.0000e-04.
Epoch 6/30 - Train Loss: nan, Val Loss: nan, Time: 2.4s, LR: 0.000500


Epoch 7/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]




Epoch 7/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 7/30 - Train Loss: nan, Val Loss: nan, Time: 2.4s, LR: 0.000500


Epoch 8/30 [Train]:   0%|          | 0/10 [00:00<?, ?it/s]



Epoch 8/30 [Val]:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch 8/30 - Train Loss: nan, Val Loss: nan, Time: 2.4s, LR: 0.000500


KeyboardInterrupt: 