# Medical Image Enhancement GAN - Local Setup

## Project Overview
This notebook implements a Generative Adversarial Network (GAN) for enhancing medical images from the Kvasir-SEG dataset. The model takes synthetically degraded images and restores them to high quality using a U-Net generator trained against a PatchGAN discriminator.

## Setup Instructions

### 1. Install Required Dependencies
```bash
pip install torch torchvision matplotlib pillow numpy scikit-image
```

### 2. Download and Prepare Dataset
1. Download the Kvasir-SEG dataset from: https://datasets.simula.no/kvasir-seg/
2. Extract the dataset in your project directory
3. Ensure the folder structure is:
   ```
   your_project/
   ├── kvasir-seg/
   │   └── Kvasir-SEG/
   │       └── images/
   │           ├── image_00001.jpg
   │           ├── image_00002.jpg
   │           └── ...
   ├── output/ (will be created automatically)
   └── ganproject (2).ipynb
   ```
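
Before training, it can help to confirm that the images are where the notebook expects them. The following is a minimal sketch, not part of the notebook: `count_images` is a hypothetical helper, and the path used in the example is an assumption that should match `Config.DATA_ROOT`.

```python
import os

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')  # extensions the dataset class looks for


def count_images(root_dir):
    """Return the number of image files directly inside root_dir (0 if it is missing)."""
    if not os.path.isdir(root_dir):
        return 0
    return sum(1 for name in os.listdir(root_dir) if name.lower().endswith(IMAGE_EXTENSIONS))


if __name__ == '__main__':
    # Assumed location; adjust to wherever the dataset was extracted.
    root = os.path.join('kvasir-seg', 'Kvasir-SEG', 'images')
    print(f"{count_images(root)} images found in {root}")
```

A count of 0 usually means the path is wrong or the archive was extracted one level too deep.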

### 3. System Requirements
- GPU recommended (CUDA compatible) for faster training
- At least 8GB RAM
- Python 3.7+ with PyTorch installed

### 4. Expected Output
- Training will create enhanced images in `output/generated_images/`
- Model checkpoints will be saved in `output/generated_images/checkpoints/`
- Training progress will be displayed with loss metrics and sample images

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg19
import os
from PIL import Image
import numpy as np
import random
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import matplotlib.pyplot as plt
import shutil

## Detailed Code Explanation

### 1. Import Libraries Analysis

Let's break down each import and understand its purpose in our GAN implementation:

**Deep Learning Framework:**
- `torch`: Core PyTorch library for tensor operations and neural networks
- `torch.nn`: Neural network modules (layers, activation functions, loss functions)
- `torch.optim`: Optimization algorithms (Adam, SGD, etc.)

**Data Handling:**
- `Dataset`: Base class for creating custom datasets
- `DataLoader`: Efficient data loading with batching and shuffling
- `transforms`: Image preprocessing and augmentation utilities
- `vgg19`: Pre-trained VGG19 model for perceptual loss calculation

**Image Processing:**
- `PIL.Image`: Loading and basic image operations
- `numpy`: Numerical operations and array manipulations
- `skimage.metrics`: Image quality metrics (PSNR, SSIM)

**Utilities:**
- `matplotlib.pyplot`: Plotting and visualization
- `os`: File system operations
- `random`: Random number generation for data augmentation
- `shutil`: High-level file operations

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 2. Device Configuration

This line automatically detects the best available hardware for training:

- **CUDA GPU**: If available, uses GPU acceleration (much faster for deep learning)
- **CPU Fallback**: Uses CPU if no GPU is detected (slower but still functional)

**Why GPU is Important:**
- Neural networks involve massive matrix operations
- GPUs have thousands of cores optimized for parallel computation
- Training is often an order of magnitude or more faster on a GPU than on a CPU, depending on the model and hardware
- Essential for practical deep learning model training

In [3]:
import os

class Config:
    # Local paths for running on your device
    # os.path.join avoids backslash escapes in the string and works on every OS
    DATA_ROOT = os.path.join('kvasir-seg', 'Kvasir-SEG', 'images')  # Local path to the dataset
    OUTPUT_DIR = 'output/generated_images'  # Local output directory
    CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, 'checkpoints')

    IMAGE_SIZE = 256
    NUM_CHANNELS = 3

    BATCH_SIZE = 4
    NUM_EPOCHS = 50
    LEARNING_RATE_G = 2e-4
    LEARNING_RATE_D = 1e-4
    BETA1 = 0.5
    LAMBDA_L1 = 70.0
    LAMBDA_PERCEPTUAL = 10.0

   
    G_TRAIN_MULTIPLIER = 3 

    NOISE_STD_DEV = 0.15
    NOISE_STD_DEV_VARIATION = 0.05
    
    SALT_VS_PEPPER_RATIO = 0.5
    SP_NOISE_PROB = 0.01
    
    BLUR_KERNEL_SIZE = 5
    BRIGHTNESS_FACTOR = 0.2
    CONTRAST_FACTOR = 0.2
    SATURATION_FACTOR = 0.2
    HUE_FACTOR = 0.05

    
    LR_DECAY_START_EPOCH = 30 
    SAVE_EVERY_N_EPOCHS = 5
    LOG_EVERY_N_BATCHES = 10

    DEVICE = DEVICE


### 3. Configuration Class - Detailed Parameter Analysis

The Config class centralizes all hyperparameters and settings. Let's examine each category:

#### **A. File Paths & Directory Structure**
```python
DATA_ROOT = os.path.join('kvasir-seg', 'Kvasir-SEG', 'images')  # Input dataset location
OUTPUT_DIR = 'output/generated_images'       # Where results are saved
CHECKPOINT_DIR = ...                         # Model weights storage
```

#### **B. Image Processing Parameters**
```python
IMAGE_SIZE = 256        # All images resized to 256x256 pixels
NUM_CHANNELS = 3        # RGB color channels (Red, Green, Blue)
```
- **Why 256x256?** Good balance between detail preservation and computational efficiency
- **RGB channels:** Kvasir-SEG contains color endoscopy frames, so all three color channels are kept

#### **C. Training Hyperparameters**
```python
BATCH_SIZE = 4              # Process 4 images simultaneously
NUM_EPOCHS = 50             # Complete 50 training cycles
LEARNING_RATE_G = 2e-4      # Generator learning rate (0.0002)
LEARNING_RATE_D = 1e-4      # Discriminator learning rate (0.0001)
BETA1 = 0.5                 # Adam optimizer momentum parameter
```

**Critical Details:**
- **Different Learning Rates:** The generator's learning rate (2e-4) is twice the discriminator's (1e-4), which slows the discriminator down and helps keep it from becoming too powerful
- **Small Batch Size:** Accommodates limited GPU memory, especially important for medical imaging
- **Adam Optimizer:** Adapts the step size per parameter; `BETA1 = 0.5` (instead of the default 0.9) is a common choice for GAN training
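
To get a feel for what these settings imply, the arithmetic below counts optimizer steps. It is a rough sketch: the figure of 1,000 images is an assumption about the Kvasir-SEG release being used, and the generator multiplier of 3 is the `G_TRAIN_MULTIPLIER` value from the same `Config` class.

```python
import math

num_images = 1000          # assumption: size of the Kvasir-SEG image folder
batch_size = 4             # Config.BATCH_SIZE
num_epochs = 50            # Config.NUM_EPOCHS
g_train_multiplier = 3     # Config.G_TRAIN_MULTIPLIER

# DataLoader keeps the final partial batch by default, hence the ceiling.
batches_per_epoch = math.ceil(num_images / batch_size)
discriminator_steps = batches_per_epoch * num_epochs
generator_steps = discriminator_steps * g_train_multiplier

print(batches_per_epoch, discriminator_steps, generator_steps)  # 250 12500 37500
```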

#### **D. Loss Function Weights**
```python
LAMBDA_L1 = 70.0           # Pixel-wise reconstruction importance
LAMBDA_PERCEPTUAL = 10.0   # High-level feature similarity importance
G_TRAIN_MULTIPLIER = 3     # Train generator 3x more than discriminator
```

**Why These Weights Matter:**
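- **High L1 weight (70.0):** Makes pixel-wise reconstruction the dominant term, which keeps outputs close to the ground truth
- **Perceptual weight (10.0):** Encourages realistic textures by matching VGG features
- **Generator multiplier:** Gives the generator three updates per discriminator update so the discriminator does not overpower it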
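
How the weights shape the total generator loss can be seen with a small calculation. The raw loss values below are made-up numbers for illustration only, not measurements from a training run.

```python
LAMBDA_L1 = 70.0
LAMBDA_PERCEPTUAL = 10.0

# Hypothetical unweighted loss values for one batch (illustrative only)
raw_gan = 0.70
raw_l1 = 0.05
raw_perceptual = 0.30

weighted = {
    'gan': raw_gan,
    'l1': raw_l1 * LAMBDA_L1,
    'perceptual': raw_perceptual * LAMBDA_PERCEPTUAL,
}
total = sum(weighted.values())

for name, value in weighted.items():
    print(f"{name:>10}: {value:.2f} ({100 * value / total:.0f}% of total)")
```

With numbers like these, the adversarial term contributes only a small share of the total, so the two reconstruction terms steer most of the gradient.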

In [4]:
def apply_degradation(image_tensor, config):
    image_tensor_cpu = image_tensor.cpu()
    degraded_tensor = image_tensor_cpu.clone() 

    current_noise_std_dev = config.NOISE_STD_DEV + random.uniform(-config.NOISE_STD_DEV_VARIATION, config.NOISE_STD_DEV_VARIATION)
    current_noise_std_dev = max(0, current_noise_std_dev)

    gaussian_noise = torch.randn_like(degraded_tensor) * current_noise_std_dev
    degraded_tensor = degraded_tensor + gaussian_noise
    degraded_tensor = torch.clamp(degraded_tensor, 0, 1)


    num_pixels = degraded_tensor.numel()
    num_sp_noise_pixels = int(num_pixels * config.SP_NOISE_PROB)

    noise_indices = torch.randperm(num_pixels)[:num_sp_noise_pixels]

    salt_mask = torch.rand(num_sp_noise_pixels) < config.SALT_VS_PEPPER_RATIO
    pepper_mask = ~salt_mask

    degraded_tensor_flat = degraded_tensor.view(-1)

    # Apply pepper noise (set to 0)
    degraded_tensor_flat[noise_indices[pepper_mask]] = 0.0

    # Apply salt noise (set to 1)
    degraded_tensor_flat[noise_indices[salt_mask]] = 1.0
    
    # Reshape back to original dimensions
    degraded_tensor = degraded_tensor_flat.view(degraded_tensor.shape)


    if config.BLUR_KERNEL_SIZE > 0:
        blur_transform = transforms.GaussianBlur(
            kernel_size=(config.BLUR_KERNEL_SIZE, config.BLUR_KERNEL_SIZE),
            sigma=(0.1, 2.0)
        )
        degraded_tensor = blur_transform(degraded_tensor)

    jitter_transform = transforms.ColorJitter(
        brightness=config.BRIGHTNESS_FACTOR,
        contrast=config.CONTRAST_FACTOR,
        saturation=config.SATURATION_FACTOR,
        hue=config.HUE_FACTOR
    )
    degraded_tensor = jitter_transform(degraded_tensor)

    return degraded_tensor

### 4. Image Degradation Function - Simulating Real-World Problems

This function artificially degrades high-quality images to create training pairs. Here's the step-by-step process:

#### **Step 1: Gaussian Noise Addition**
```python
current_noise_std_dev = config.NOISE_STD_DEV + random.uniform(...)
gaussian_noise = torch.randn_like(degraded_tensor) * current_noise_std_dev
```
- **Purpose:** Simulates sensor noise and electrical interference
- **Implementation:** Adds zero-mean Gaussian values to every tensor element, then clamps the result to [0, 1]
- **Randomization:** Draws the standard deviation uniformly from 0.15 ± 0.05 for each image to increase training diversity
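
The same idea in plain Python, operating on a short list of values in [0, 1] instead of a tensor. This is an illustrative sketch; `add_gaussian_noise` is a hypothetical helper and not a function from the notebook.

```python
import random

NOISE_STD_DEV = 0.15
NOISE_STD_DEV_VARIATION = 0.05


def add_gaussian_noise(values, rng):
    """Add zero-mean Gaussian noise with a randomly chosen std, then clamp to [0, 1]."""
    std = NOISE_STD_DEV + rng.uniform(-NOISE_STD_DEV_VARIATION, NOISE_STD_DEV_VARIATION)
    std = max(0.0, std)  # guard against a negative std if the variation exceeds the base
    noisy = [min(1.0, max(0.0, v + rng.gauss(0.0, std))) for v in values]
    return noisy, std


rng = random.Random(0)  # fixed seed so the sketch is repeatable
noisy, std = add_gaussian_noise([0.0, 0.25, 0.5, 0.75, 1.0], rng)
print(f"std used: {std:.3f}")
print([round(v, 3) for v in noisy])
```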

#### **Step 2: Salt & Pepper Noise**
```python
num_sp_noise_pixels = int(num_pixels * config.SP_NOISE_PROB)  # 1% of tensor elements
salt_mask = torch.rand(num_sp_noise_pixels) < config.SALT_VS_PEPPER_RATIO
```
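- **Salt Noise:** Randomly chosen values set to the maximum, 1.0 (bright specks)
- **Pepper Noise:** Randomly chosen values set to the minimum, 0.0 (dark specks)
- **Granularity:** `numel()` counts every channel value, so the noise hits individual R, G, or B values and the specks are usually colored rather than pure white or black
- **Medical Relevance:** Simulates dead pixels, dust on the lens, or transmission errors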
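
For a 3×256×256 tensor, the counts work out as follows. The sketch uses only the standard library and mirrors the index selection on a flat list; it is an illustration, not the notebook's tensor code.

```python
import random

channels, height, width = 3, 256, 256
SP_NOISE_PROB = 0.01
SALT_VS_PEPPER_RATIO = 0.5

num_elements = channels * height * width           # what tensor.numel() returns
num_noisy = int(num_elements * SP_NOISE_PROB)
print(num_elements, num_noisy)  # 196608 1966

rng = random.Random(0)
flat = [0.5] * num_elements                        # stand-in for the flattened image
for index in rng.sample(range(num_elements), num_noisy):
    # Each selected element becomes salt (1.0) or pepper (0.0)
    flat[index] = 1.0 if rng.random() < SALT_VS_PEPPER_RATIO else 0.0

changed = sum(1 for v in flat if v != 0.5)
print(changed)  # 1966
```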

#### **Step 3: Gaussian Blur**
```python
blur_transform = transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))
```
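- **Purpose:** Approximates defocus and other optical softening (Gaussian blur is isotropic, so it only roughly stands in for directional motion blur)
- **Kernel Size:** 5x5 pixel neighborhood for blur calculation
- **Variable Sigma:** Sigma is sampled uniformly from [0.1, 2.0] on each call, giving a random blur strength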

#### **Step 4: Color Jitter**
```python
jitter_transform = transforms.ColorJitter(
    brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
)
```
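- **Brightness:** Scale factor sampled from [0.8, 1.2]
- **Contrast:** Scale factor sampled from [0.8, 1.2]
- **Saturation:** Scale factor sampled from [0.8, 1.2]
- **Hue:** Shift sampled from [-0.05, 0.05] as a fraction of the hue circle, about ±18° (kept small to preserve medically relevant colors)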
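
The parameter-to-range mapping can be written out explicitly. The helper below is a hypothetical illustration of how `ColorJitter` interprets a single float for each argument; it does not call torchvision.

```python
def jitter_ranges(brightness, contrast, saturation, hue):
    """Return the sampling interval implied by each ColorJitter float argument."""
    def multiplicative(x):
        # Brightness, contrast and saturation are scale factors centred on 1
        return (max(0.0, 1.0 - x), 1.0 + x)

    return {
        'brightness': multiplicative(brightness),
        'contrast': multiplicative(contrast),
        'saturation': multiplicative(saturation),
        'hue': (-hue, hue),                      # additive shift, fraction of the hue circle
        'hue_degrees': (-hue * 360, hue * 360),  # the same shift expressed in degrees
    }


ranges = jitter_ranges(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05)
for name, (low, high) in ranges.items():
    print(f"{name}: [{low:g}, {high:g}]")
```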

**Why This Approach?**
- Creates realistic training data without needing actual degraded medical images
- Allows controlled experimentation with different degradation types
- Ensures the model learns to handle various real-world imaging problems

In [5]:
def save_images(low_res, real_high_res, fake_high_res, epoch, batch_idx, output_dir):
    low_res_np = (low_res * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)
    real_high_res_np = (real_high_res * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)
    fake_high_res_np = (fake_high_res * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)

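    # Use the actual batch size so a short batch does not raise an IndexError;
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    num_samples = low_res_np.shape[0]
    fig, axes = plt.subplots(num_samples, 3, figsize=(9, 3 * num_samples), squeeze=False)
    fig.suptitle(f'Epoch {epoch}, Batch {batch_idx}', fontsize=16)

    for i in range(num_samples):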
        axes[i, 0].imshow(low_res_np[i])
        axes[i, 0].set_title('Low-Quality Input')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(real_high_res_np[i])
        axes[i, 1].set_title('Real High-Quality')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(fake_high_res_np[i])
        axes[i, 2].set_title('Generated High-Quality')
        axes[i, 2].axis('off')
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    
    filename = os.path.join(output_dir, f'epoch_{epoch}_batch_{batch_idx}.png')
    
    try:
        plt.savefig(filename)
        print(f"Successfully saved image to: {filename}")
    except Exception as e:
        print(f"Error saving image {filename}: {e}")
    finally:
        plt.close(fig)

In [6]:
def calculate_metrics(real_images, fake_images):
    real_images_np = (real_images * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)
    fake_images_np = (fake_images * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)

    psnr_scores = []
    ssim_scores = []

    for i in range(real_images_np.shape[0]):
        psnr = peak_signal_noise_ratio(real_images_np[i], fake_images_np[i], data_range=1.0)
        ssim = structural_similarity(real_images_np[i], fake_images_np[i], data_range=1.0, channel_axis=-1)
        psnr_scores.append(psnr)
        ssim_scores.append(ssim)
    
    return np.mean(psnr_scores), np.mean(ssim_scores)

In [7]:
class KvasirSEGDataset(Dataset):
    def __init__(self, root_dir, image_size, transform=None, degradation_transform=None):
        self.root_dir = root_dir
        self.image_size = image_size
        # Sort for a reproducible order and match extensions case-insensitively
        self.image_files = sorted(f for f in os.listdir(root_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png')))
        self.transform = transform
        self.degradation_transform = degradation_transform

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        
        print(f"Dataset initialized with {len(self.image_files)} images from {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
            high_quality_image = self.transform(image)

            high_quality_image_0_1 = (high_quality_image * 0.5 + 0.5)
            
            if self.degradation_transform:
                low_quality_image_0_1 = self.degradation_transform(high_quality_image_0_1)
            else:
                low_quality_image_0_1 = high_quality_image_0_1

            low_quality_image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(low_quality_image_0_1)

            return low_quality_image, high_quality_image
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy tensor if image loading fails
            dummy_tensor = torch.zeros(3, self.image_size, self.image_size)
            return dummy_tensor, dummy_tensor

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


In [8]:
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

### 5. U-Net Building Blocks - Architecture Deep Dive

#### **UNetDown: Encoder Block Analysis**
```python
layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
```

**Convolutional Layer Parameters:**
- **Kernel Size 4:** 4x4 filter for feature detection
- **Stride 2:** Moves filter 2 pixels at a time → halves image dimensions
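- **Padding 1:** With a 4x4 kernel and stride 2, one pixel of padding makes the output exactly half the input size
- **bias=False:** Batch normalization supplies its own learnable shift, so a convolution bias would be redundant (blocks created with `normalize=False` therefore have no bias term at all)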

**Layer Sequence:**
1. **Convolution:** Feature extraction and downsampling
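2. **Batch Normalization:** Stabilizes training and helps with gradient flow (skipped when `normalize=False`)
3. **LeakyReLU:** Activation function with slope 0.2 for negative inputs, so negative values still pass a small gradient
4. **Dropout:** Randomly zeros activations to reduce overfitting (added only when `dropout > 0`)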
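
The halving can be checked with the standard convolution output-size formula. The helper name `conv_out_size` is introduced here for illustration; the loop of eight steps corresponds to the eight `UNetDown` blocks that the generator stacks.

```python
def conv_out_size(size, kernel=4, stride=2, padding=1):
    """Spatial output size of nn.Conv2d along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 256
sizes = [size]
for _ in range(8):  # the generator stacks eight UNetDown blocks
    size = conv_out_size(size)
    sizes.append(size)

print(sizes)  # [256, 128, 64, 32, 16, 8, 4, 2, 1]
```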

#### **UNetUp: Decoder Block Analysis**
```python
nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
```

**Transposed Convolution (Upsampling):**
- **Purpose:** Increases image resolution while learning features
- **Kernel Size 4:** Maintains symmetry with encoder
- **Stride 2:** Doubles image dimensions
- **Skip Connection:** `torch.cat((x, skip_input), 1)` appends the matching encoder feature map along the channel axis, which preserves fine details
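
The doubling and the channel growth from concatenation follow from the transposed-convolution size formula. Both helper names below are hypothetical and exist only for this sketch.

```python
def conv_transpose_out_size(size, kernel=4, stride=2, padding=1):
    """Spatial output size of nn.ConvTranspose2d along one dimension (no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def unet_up_shape(in_size, out_channels, skip_channels):
    """Shape (channels, size) after upsampling and concatenating a skip tensor."""
    return out_channels + skip_channels, conv_transpose_out_size(in_size)


print(conv_transpose_out_size(1))        # 2
print(unet_up_shape(1, 512, 512))        # (1024, 2)
```

The concatenation is why every decoder block after the first receives twice as many input channels as the previous block's convolution produced.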

**Why Skip Connections?**
- **Problem:** Deep networks lose fine details during downsampling
- **Solution:** Directly connect encoder features to decoder
- **Result:** Combines high-level understanding with low-level details
- **Medical Importance:** Preserves critical anatomical structures

In [9]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels * 2, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, 1, 1, bias=False)
        )

    def forward(self, img_A, img_B):
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

### 6. Discriminator Architecture - PatchGAN Deep Analysis

#### **Input Processing:**
```python
def forward(self, img_A, img_B):
    img_input = torch.cat((img_A, img_B), 1)  # Concatenate along channel dimension
```
- **img_A:** Degraded input image (3 channels)
- **img_B:** High-quality target/generated image (3 channels)
- **Result:** 6-channel input (concatenated RGB + RGB)

#### **Network Architecture Breakdown:**

**Layer 1:** `Conv2d(6→64, kernel=4, stride=2)` 
- Input: 256×256×6 → Output: 128×128×64
- No batch norm for first layer (common GAN practice)

**Layer 2:** `Conv2d(64→128)` + BatchNorm + LeakyReLU
- Input: 128×128×64 → Output: 64×64×128
- Batch normalization stabilizes training

**Layer 3:** `Conv2d(128→256)` + BatchNorm + LeakyReLU  
- Input: 64×64×128 → Output: 32×32×256
- Increasing feature channels, decreasing spatial resolution

**Layer 4:** `Conv2d(256→512, stride=1)` + BatchNorm + LeakyReLU
- Input: 32×32×256 → Output: 31×31×512
- Stride 1 with a 4×4 kernel and padding 1 shrinks each side by only one pixel, keeping the patch grid dense for the final classification

**Layer 5:** `Conv2d(512→1, stride=1)` 
- Input: 31×31×512 → Output: 30×30×1
- Final classification layer (outputs raw logits; the sigmoid is applied inside BCEWithLogitsLoss)

#### **PatchGAN Concept:**
- **Traditional Discriminator:** Single "real/fake" decision for entire image
- **PatchGAN:** 30×30 = 900 separate "real/fake" decisions per image, each based on a 70×70 pixel receptive field of the input
- **Advantage:** Focuses on local texture and detail quality
- **Medical Benefit:** Ensures fine anatomical structures are preserved accurately
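
Both the 30×30 output grid and the size of the input patch behind each decision can be derived from the layer settings. The sketch below is plain arithmetic on the (kernel, stride, padding) triples of the five convolutions; the function names are chosen for this example.

```python
# (kernel, stride, padding) for the five convolutions in the Discriminator
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def output_size(size, layers):
    """Apply the Conv2d output-size formula layer by layer."""
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size


def receptive_field(layers):
    """Walk backwards from one output unit to the input patch it depends on."""
    field = 1
    for kernel, stride, _ in reversed(layers):
        field = (field - 1) * stride + kernel
    return field


print(output_size(256, LAYERS))   # 30
print(receptive_field(LAYERS))    # 70
```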

In [10]:
class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg19(pretrained=True).features
        for param in vgg.parameters():
            param.requires_grad = False
        
        self.features = nn.Sequential(
            vgg[:2],
            vgg[2:7],
            vgg[7:12],
            vgg[12:21]
        ).to(DEVICE).eval()

        self.criterion = nn.L1Loss()

    def forward(self, fake_img, real_img):
        normalize_vgg = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        fake_img_norm = normalize_vgg((fake_img * 0.5 + 0.5))
        real_img_norm = normalize_vgg((real_img * 0.5 + 0.5))

        fake_features = self.features(fake_img_norm)
        real_features = self.features(real_img_norm)
        
        return self.criterion(fake_features, real_features)

### 7. Perceptual Loss - Advanced Feature Matching

#### **Core Concept:**
Instead of comparing pixels directly, perceptual loss compares high-level features extracted by a pre-trained neural network.

#### **VGG19 Feature Extraction:**
```python
vgg = vgg19(pretrained=True).features  # Load pre-trained ImageNet weights
```

**Why VGG19?**
- Trained on millions of natural images
- Learns hierarchical feature representations
- Lower layers: edges, textures, colors
- Higher layers: shapes, objects, semantic content

#### **Feature Layer Selection:**
```python
self.features = nn.Sequential(
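    vgg[:2],    # conv1_1 + ReLU: basic edges and colors
    vgg[2:7],   # through conv2_1 + ReLU: simple textures
    vgg[7:12],  # through conv3_1 + ReLU: more complex patterns
    vgg[12:21]  # through conv4_1 + ReLU: higher-level structure
)
```

Because the four slices are chained inside one `nn.Sequential`, they behave like `vgg[:21]`: the loss compares only the final output (the activation after `conv4_1`), not the intermediate feature maps.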

#### **Normalization for VGG:**
```python
normalize_vgg = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
```
- **Purpose:** The pre-trained VGG19 weights expect inputs normalized with the ImageNet channel means and standard deviations
- **Process:** Map the GAN range [-1, 1] back to [0, 1] with `x * 0.5 + 0.5`, then apply the ImageNet normalization
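
As a worked example, the two-step conversion for a single RGB pixel looks like this. `gan_to_vgg` is a hypothetical helper written in plain Python; the notebook performs the same arithmetic on whole tensors.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def gan_to_vgg(pixel):
    """Convert one RGB pixel from the GAN range [-1, 1] to VGG's normalized input."""
    unit = [c * 0.5 + 0.5 for c in pixel]  # [-1, 1] -> [0, 1]
    return [(u - m) / s for u, m, s in zip(unit, IMAGENET_MEAN, IMAGENET_STD)]


print([round(v, 3) for v in gan_to_vgg((-1.0, -1.0, -1.0))])  # darkest pixel
print([round(v, 3) for v in gan_to_vgg((1.0, 1.0, 1.0))])     # brightest pixel
```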

#### **Loss Calculation:**
```python
fake_features = self.features(fake_img_norm)
real_features = self.features(real_img_norm)
return self.criterion(fake_features, real_features)  # L1 Loss between features
```

#### **Medical Imaging Benefits:**
1. **Texture Preservation:** Ensures tissue textures look natural
2. **Structural Consistency:** Maintains anatomical relationships
3. **Visual Realism:** Generated images appear more convincing to medical professionals
4. **Detail Enhancement:** Focuses on perceptually important features rather than every pixel

#### **Mathematical Formulation:**
```
Perceptual Loss = ||φ(generated_image) - φ(target_image)||₁
```
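Where φ represents the VGG feature extraction function. Because `nn.L1Loss` uses mean reduction by default, the norm is averaged over all feature elements.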
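
A toy version of that computation, with short lists standing in for flattened feature maps. The numbers are invented for illustration and `mean_l1` is a hypothetical helper.

```python
def mean_l1(features_a, features_b):
    """Mean absolute difference, matching nn.L1Loss with its default 'mean' reduction."""
    assert len(features_a) == len(features_b)
    return sum(abs(a - b) for a, b in zip(features_a, features_b)) / len(features_a)


# Toy stand-ins for flattened feature maps phi(generated) and phi(target)
phi_generated = [0.2, 1.5, 0.0, 3.1]
phi_target = [0.0, 1.0, 0.5, 3.1]

print(mean_l1(phi_generated, phi_target))  # (0.2 + 0.5 + 0.5 + 0.0) / 4
```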

In [11]:
def train_model():
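    print(f"Using device: {DEVICE}")

    # Create the output and checkpoint directories up front; without them,
    # plt.savefig and torch.save fail on a fresh checkout.
    os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)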

    dataset = KvasirSEGDataset(
        root_dir=Config.DATA_ROOT,
        image_size=Config.IMAGE_SIZE,
        degradation_transform=lambda img_tensor: apply_degradation(img_tensor, Config)
    )
    # Reduced num_workers for local execution (use 0 for Windows compatibility)
    dataloader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)

    if len(dataloader) == 0:
        print(f"Error: Dataloader is empty. No images found in {Config.DATA_ROOT} or batch size is too large for the dataset.")
        print("Please ensure the Kvasir-SEG dataset is correctly placed and contains images.")
        return # Exit if no data is found

    generator = UnetGenerator(Config.NUM_CHANNELS, Config.NUM_CHANNELS).to(DEVICE)
    discriminator = Discriminator(Config.NUM_CHANNELS).to(DEVICE)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()
    perceptual_loss_fn = PerceptualLoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=Config.LEARNING_RATE_G, betas=(Config.BETA1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=Config.LEARNING_RATE_D, betas=(Config.BETA1, 0.999))

    def lambda_rule(epoch):
        if epoch < Config.LR_DECAY_START_EPOCH:
            return 1.0
        else:
            return 1.0 - (epoch - Config.LR_DECAY_START_EPOCH) / (Config.NUM_EPOCHS - Config.LR_DECAY_START_EPOCH)

    scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
    scheduler_D = optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda_rule)

    print("Starting Training Loop...")
    for epoch in range(Config.NUM_EPOCHS):
        for i, (low_res_img, high_res_img) in enumerate(dataloader):
            low_res_img = low_res_img.to(DEVICE)
            high_res_img = high_res_img.to(DEVICE)


            optimizer_D.zero_grad()

            fake_high_res_for_D = generator(low_res_img)

            output_real = discriminator(low_res_img, high_res_img)
            # Derive label shapes from the discriminator output (30x30 patches at
            # IMAGE_SIZE = 256) instead of hard-coding them
            real_labels = torch.ones_like(output_real)
            loss_D_real = criterion_GAN(output_real, real_labels)

            fake_labels = torch.zeros_like(output_real)
            output_fake = discriminator(low_res_img, fake_high_res_for_D.detach())
            loss_D_fake = criterion_GAN(output_fake, fake_labels)

            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()


            for _ in range(Config.G_TRAIN_MULTIPLIER):
                optimizer_G.zero_grad()
                
                fake_high_res_img = generator(low_res_img)

                output_fake_for_G = discriminator(low_res_img, fake_high_res_img)
                loss_G_GAN = criterion_GAN(output_fake_for_G, real_labels)

                loss_G_L1 = criterion_L1(fake_high_res_img, high_res_img) * Config.LAMBDA_L1

                loss_G_perceptual = perceptual_loss_fn(fake_high_res_img, high_res_img) * Config.LAMBDA_PERCEPTUAL

                loss_G = loss_G_GAN + loss_G_L1 + loss_G_perceptual
                loss_G.backward() 
                optimizer_G.step()

            if (i + 1) % Config.LOG_EVERY_N_BATCHES == 0:
                print(f"Epoch [{epoch+1}/{Config.NUM_EPOCHS}], Batch [{i+1}/{len(dataloader)}] | "
                      f"D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f} "
                      f"(GAN: {loss_G_GAN.item():.4f}, L1: {loss_G_L1.item():.4f}, Perceptual: {loss_G_perceptual.item():.4f})")
        
        scheduler_G.step()
        scheduler_D.step()
        print(f"Epoch {epoch+1} completed. Current LR G: {optimizer_G.param_groups[0]['lr']:.6f}, LR D: {optimizer_D.param_groups[0]['lr']:.6f}")


        if (epoch + 1) % Config.SAVE_EVERY_N_EPOCHS == 0 or (epoch + 1) == Config.NUM_EPOCHS:
            print(f"--- Entering save/checkpoint block for epoch {epoch+1} ---")
            print(f"Saving samples and checkpoint for epoch {epoch+1}...")
            with torch.no_grad():
                try:
                    sample_low_res, sample_high_res = next(iter(dataloader))
                    sample_low_res = sample_low_res.to(DEVICE)
                    sample_high_res = sample_high_res.to(DEVICE)
                    
                    generated_high_res = generator(sample_low_res) 
                    save_images(sample_low_res, sample_high_res, generated_high_res, epoch + 1, i + 1, Config.OUTPUT_DIR)
                except StopIteration:
                    print(f"Warning: Dataloader exhausted at epoch {epoch+1} for sample saving. Skipping image save for this epoch.")


            try:
                torch.save(generator.state_dict(), os.path.join(Config.CHECKPOINT_DIR, f'generator_epoch_{epoch+1}.pth'))
                torch.save(discriminator.state_dict(), os.path.join(Config.CHECKPOINT_DIR, f'discriminator_epoch_{epoch+1}.pth'))
                print(f"Models saved to {Config.CHECKPOINT_DIR}")
            except Exception as e:
                print(f"Error saving model checkpoints for epoch {epoch+1}: {e}")

    print("Training complete.")

### 8. Training Process - GAN Training Dynamics Explained

#### **A. Training Loop Structure**

The training alternates between two competing networks:

```python
for epoch in range(Config.NUM_EPOCHS):
    for i, (low_res_img, high_res_img) in enumerate(dataloader):
        # 1. Train Discriminator (1 step)
        # 2. Train Generator (3 steps) - G_TRAIN_MULTIPLIER = 3
```

#### **B. Discriminator Training Phase**

**Step 1: Real Image Classification**
```python
output_real = discriminator(low_res_img, high_res_img)
loss_D_real = criterion_GAN(output_real, real_labels)  # Should output "1" (real)
```

**Step 2: Fake Image Classification**  
```python
fake_high_res_for_D = generator(low_res_img)
output_fake = discriminator(low_res_img, fake_high_res_for_D.detach())
loss_D_fake = criterion_GAN(output_fake, fake_labels)  # Should output "0" (fake)
```

**Key Detail:** `.detach()` prevents gradients from flowing to generator during discriminator training.

**Combined Discriminator Loss:**
```python
loss_D = (loss_D_real + loss_D_fake) * 0.5
```
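
To interpret the logged value, it helps to know what this loss evaluates to in two reference cases. The sketch below re-implements binary cross-entropy on a single logit with the standard library; it is an illustration of the formula, not the PyTorch implementation itself.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit (one element)."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(logit_real, logit_fake):
    # Same averaging as loss_D = (loss_D_real + loss_D_fake) * 0.5
    return 0.5 * (bce_with_logits(logit_real, 1.0) + bce_with_logits(logit_fake, 0.0))


print(round(discriminator_loss(0.0, 0.0), 4))    # 0.6931, undecided discriminator
print(round(discriminator_loss(4.0, -4.0), 4))   # confident and correct discriminator
```

An undecided discriminator (logit 0, probability 0.5) gives ln 2 ≈ 0.693, while a discriminator that separates real from fake confidently drives the loss toward 0.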

#### **C. Generator Training Phase (3x per discriminator step)**

**Multi-Component Loss Function:**

**1. Adversarial Loss:**
```python
output_fake_for_G = discriminator(low_res_img, fake_high_res_img)
loss_G_GAN = criterion_GAN(output_fake_for_G, real_labels)  # Fool discriminator
```
*Goal:* Make discriminator classify generated images as "real"

**2. L1 Reconstruction Loss:**
```python
loss_G_L1 = criterion_L1(fake_high_res_img, high_res_img) * Config.LAMBDA_L1
```
*Goal:* Pixel-wise accuracy (weighted by λ=70.0)

**3. Perceptual Loss:**
```python
loss_G_perceptual = perceptual_loss_fn(fake_high_res_img, high_res_img) * Config.LAMBDA_PERCEPTUAL
```
*Goal:* High-level feature similarity (weighted by λ=10.0)

**Total Generator Loss:**
```python
loss_G = loss_G_GAN + loss_G_L1 + loss_G_perceptual
```

#### **D. Learning Rate Scheduling**

```python
def lambda_rule(epoch):
    if epoch < Config.LR_DECAY_START_EPOCH:  # First 30 epochs (0-based indices 0-29)
        return 1.0  # Keep original learning rate
    else:  # Epoch indices 30-49
        return 1.0 - (epoch - 30) / (50 - 30)  # Linear decay toward 0
```

**Purpose:** Stabilizes training in later epochs, prevents oscillation around optimal solution.
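
Evaluating the rule at a few epoch indices shows the shape of the schedule. The constants are copied from `Config`; the base rate used for printing is the generator's `LEARNING_RATE_G`.

```python
NUM_EPOCHS = 50
LR_DECAY_START_EPOCH = 30
BASE_LR_G = 2e-4


def lambda_rule(epoch):
    """Multiplier applied to the base learning rate at a 0-based epoch index."""
    if epoch < LR_DECAY_START_EPOCH:
        return 1.0
    return 1.0 - (epoch - LR_DECAY_START_EPOCH) / (NUM_EPOCHS - LR_DECAY_START_EPOCH)


for epoch in (0, 29, 30, 40, 49):
    print(f"epoch index {epoch:2d}: lr = {BASE_LR_G * lambda_rule(epoch):.6f}")
```

The multiplier is still 1.0 at index 30, falls to 0.5 at index 40, and is 0.05 during the last training epoch (index 49); it would only reach 0 at index 50, after training has finished.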

#### **E. Why Multiple Generator Updates?**

- **Problem:** Discriminator can become too strong, making generator unable to learn
- **Solution:** Train generator 3 times per discriminator update
- **Result:** Balanced competition between networks

#### **F. Training Monitoring**

Every 10 batches, the code logs:
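- Discriminator loss (near ln 2 ≈ 0.69 when the discriminator cannot tell real from fake; values close to 0 mean it is winning)
- Generator GAN loss (also near ln 2 at equilibrium; it typically fluctuates rather than decreasing steadily)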
- Generator L1 loss (pixel accuracy)
- Generator perceptual loss (feature similarity)

**Healthy Training Signs:**
- Losses neither increase nor decrease dramatically
- Generated images gradually improve in quality
- No mode collapse (generator producing identical images)

In [12]:
class UnetGenerator(nn.Module):
    """
    U-Net Generator architecture for image-to-image translation.
    """
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # Encoder
        self.down1 = UNetDown(in_channels, 64, normalize=False) # 256 -> 128
        self.down2 = UNetDown(64, 128) # 128 -> 64
        self.down3 = UNetDown(128, 256) # 64 -> 32
        self.down4 = UNetDown(256, 512, dropout=0.5) # 32 -> 16
        self.down5 = UNetDown(512, 512, dropout=0.5) # 16 -> 8
        self.down6 = UNetDown(512, 512, dropout=0.5) # 8 -> 4
        self.down7 = UNetDown(512, 512, dropout=0.5) # 4 -> 2
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) # 2 -> 1

        # Decoder
        self.up1 = UNetUp(512, 512, dropout=0.5) # 1 -> 2 (concat with down7)
        self.up2 = UNetUp(1024, 512, dropout=0.5) # 2 -> 4 (concat with down6)
        self.up3 = UNetUp(1024, 512, dropout=0.5) # 4 -> 8 (concat with down5)
        self.up4 = UNetUp(1024, 512) # 8 -> 16 (concat with down4)
        self.up5 = UNetUp(1024, 256) # 16 -> 32 (concat with down3)
        self.up6 = UNetUp(512, 128) # 32 -> 64 (concat with down2)
        self.up7 = UNetUp(256, 64) # 64 -> 128 (concat with down1)

        self.final_conv = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1), # 128 -> 256
            nn.Tanh() # Output pixels in [-1, 1]
        )

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Decoder with skip connections
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final_conv(u7)

### 9. Complete U-Net Generator - Architectural Masterpiece

#### **U-Net Design Philosophy**
- **U-Shape:** Symmetric encoder-decoder with skip connections
- **Medical Heritage:** Originally designed for biomedical image segmentation
- **Key Insight:** Combine global context (from bottleneck) with local details (from skip connections)

#### **Encoder Path (Contracting Path)**
```
Input: 256×256×3 (RGB medical image)
    ↓ down1: Conv 3→64, stride=2
128×128×64
    ↓ down2: Conv 64→128, stride=2  
64×64×128
    ↓ down3: Conv 128→256, stride=2
32×32×256
    ↓ down4: Conv 256→512, stride=2 + Dropout(0.5)
16×16×512
    ↓ down5: Conv 512→512, stride=2 + Dropout(0.5)
8×8×512
    ↓ down6: Conv 512→512, stride=2 + Dropout(0.5)
4×4×512
    ↓ down7: Conv 512→512, stride=2 + Dropout(0.5)
2×2×512
    ↓ down8: Conv 512→512, stride=2 + Dropout(0.5)
1×1×512 (Bottleneck - Global Context)
```

#### **Decoder Path (Expansive Path)**
```
1×1×512 (Bottleneck)
    ↑ up1: TransConv 512→512 + concat(down7) = 1024 channels
2×2×1024
    ↑ up2: TransConv 1024→512 + concat(down6) = 1024 channels  
4×4×1024
    ↑ up3: TransConv 1024→512 + concat(down5) = 1024 channels
8×8×1024
    ↑ up4: TransConv 1024→512 + concat(down4) = 1024 channels
16×16×1024
    ↑ up5: TransConv 1024→256 + concat(down3) = 512 channels
32×32×512
    ↑ up6: TransConv 512→128 + concat(down2) = 256 channels
64×64×256
    ↑ up7: TransConv 256→64 + concat(down1) = 128 channels
128×128×128
    ↑ final_conv: TransConv 128→3 + Tanh()
Output: 256×256×3 (Enhanced medical image)
```
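
The spatial sizes and channel counts in the two diagrams above can be checked with plain arithmetic. The sketch below assumes every `UNetDown` and `UNetUp` block uses a 4×4 kernel with stride 2 and padding 1, the same settings as `final_conv`. The helper names `conv_out`, `tconv_out`, and `trace_shapes` are hypothetical and are not part of the notebook.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after a strided convolution (floor division, as in PyTorch)."""
    return (size + 2 * padding - kernel) // stride + 1


def tconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after a transposed convolution without output padding."""
    return (size - 1) * stride - 2 * padding + kernel


# Output channels of down1..down8 and up1..up7, copied from the constructor.
ENC_CHANNELS = [64, 128, 256, 512, 512, 512, 512, 512]
DEC_CHANNELS = [512, 512, 512, 512, 256, 128, 64]


def trace_shapes(size=256):
    """Return (encoder, decoder) lists of (name, height/width, channels)."""
    encoder = []
    for i, ch in enumerate(ENC_CHANNELS, start=1):
        size = conv_out(size)
        encoder.append((f"down{i}", size, ch))

    decoder = []
    # Skip connections pair up1 with down7, up2 with down6, ..., up7 with down1.
    skips = ENC_CHANNELS[-2::-1]
    for i, (ch, skip_ch) in enumerate(zip(DEC_CHANNELS, skips), start=1):
        size = tconv_out(size)
        decoder.append((f"up{i}", size, ch + skip_ch))  # channels after concat
    return encoder, decoder


encoder, decoder = trace_shapes(256)
for name, size, ch in encoder + decoder:
    print(f"{name}: {size}x{size}x{ch}")
print("final_conv output size:", tconv_out(decoder[-1][1]))
```

Because each decoder stage concatenates a skip tensor, the input channel count of the next `UNetUp` is the sum of the two, which is why `up2` through `up5` are declared with 1024 input channels.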

#### **Skip Connection Magic**
```python
u1 = self.up1(d8, d7)  # Upsample the bottleneck (d8), then concatenate with down7 features (d7)
```

**Information Flow:**
- **Encoder:** Captures "what" (semantic understanding)
- **Decoder:** Reconstructs "where" (spatial localization) 
- **Skip Connections:** Preserve "how" (fine-grained details)

#### **Dropout Strategy**
- **Deep Layers (down4-down8, up1-up3):** 50% dropout reduces overfitting
- **Shallow Layers:** No dropout, which preserves important low-level features
- **Medical Rationale:** Helps the model generalize across different imaging conditions
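
To make the behaviour of a 50% dropout layer concrete, here is a minimal pure-Python sketch of inverted dropout, the scheme `nn.Dropout` follows: during training each value is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`; in evaluation mode the input passes through unchanged. The `dropout` function below is a hypothetical illustration, not the PyTorch implementation.

```python
import random


def dropout(values, p=0.5, training=True, rng=None):
    """Inverted dropout on a list of floats (illustrative only)."""
    if not training or p == 0.0:
        return list(values)  # evaluation mode: identity
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)  # keeps the expected activation unchanged
    return [v * scale if rng.random() >= p else 0.0 for v in values]


features = [0.2, -1.0, 0.7, 1.5]
print("train:", dropout(features, training=True, rng=random.Random(0)))
print("eval: ", dropout(features, training=False))
```

This is also why `generator.eval()` matters before measuring quality: with dropout still active, the deep layers would randomly discard half of their features on every forward pass.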

#### **Channel Progression Logic**
- **Encoder:** 3→64→128→256→512→512... (increasing feature complexity)
- **Decoder:** ...512→256→128→64→3 (decreasing to output channels)
- **Symmetry:** Enables precise reconstruction with learned enhancements

#### **Output Activation: Tanh()**
```python
nn.Tanh()  # Output pixels in [-1, 1]
```
**Purpose:** Matches the [-1, 1] range of the normalized inputs, so generated and real images share the same scale, which helps keep training stable
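
As a small illustration of the range mapping, the sketch below assumes the dataset normalizes pixels with mean 0.5 and standard deviation 0.5 (that is, [0, 1] to [-1, 1]); the notebook's actual transform may differ. The function names `normalize` and `denormalize` are hypothetical.

```python
def normalize(pixel):
    """Map a pixel from [0, 1] to [-1, 1] (equivalent to mean=0.5, std=0.5)."""
    return (pixel - 0.5) / 0.5


def denormalize(value):
    """Map a Tanh output from [-1, 1] back to [0, 1], clamped to the valid range."""
    return min(1.0, max(0.0, (value + 1.0) / 2.0))


for pixel in (0.0, 0.25, 1.0):
    restored = denormalize(normalize(pixel))
    print(f"{pixel} -> {normalize(pixel)} -> {restored}")
```

The inverse mapping is needed before computing metrics with `data_range=1.0`, since those expect images in [0, 1].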

In [13]:
def evaluate_model(generator_path):
    print("\nStarting Evaluation...")
    generator = UnetGenerator(Config.NUM_CHANNELS, Config.NUM_CHANNELS).to(DEVICE)
    generator.load_state_dict(torch.load(generator_path, map_location=DEVICE))
    generator.eval()

    dataset = KvasirSEGDataset(
        root_dir=Config.DATA_ROOT,
        image_size=Config.IMAGE_SIZE,
        degradation_transform=lambda img_tensor: apply_degradation(img_tensor, Config)
    )
    # Reduced num_workers for local execution
    eval_dataloader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

    if len(eval_dataloader) == 0:
        print("Error: Evaluation Dataloader is empty. Cannot perform evaluation.")
        return

    all_psnr = []
    all_ssim = []

    with torch.no_grad():
        for i, (low_res_img, high_res_img) in enumerate(eval_dataloader):
            low_res_img = low_res_img.to(DEVICE)
            high_res_img = high_res_img.to(DEVICE)

            fake_high_res_img = generator(low_res_img)
            
            psnr, ssim = calculate_metrics(high_res_img, fake_high_res_img)
            all_psnr.append(psnr)
            all_ssim.append(ssim)

            if (i + 1) % Config.LOG_EVERY_N_BATCHES == 0:
                print(f"Eval Batch [{i+1}/{len(eval_dataloader)}] | Avg PSNR: {np.mean(all_psnr):.2f} | Avg SSIM: {np.mean(all_ssim):.4f}")

    avg_psnr = np.mean(all_psnr)
    avg_ssim = np.mean(all_ssim)
    print("\n--- Evaluation Results ---")
    print(f"Average PSNR: {avg_psnr:.2f}")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print("Evaluation complete.")

### 10. Model Evaluation - Quantitative Quality Assessment

#### **Evaluation Process Overview**
1. Load trained generator model
2. Process images from `Config.DATA_ROOT` through the degradation → enhancement pipeline (`evaluate_model` does not create a separate held-out split)
3. Compare enhanced images with original high-quality images
4. Calculate objective quality metrics
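
One detail of step 4: `evaluate_model` averages one value per batch. If the dataset size is not a multiple of the batch size, the smaller final batch gets the same weight as a full one. A sample-weighted mean avoids this. The sketch below assumes `calculate_metrics` returns per-batch averages; `weighted_mean` is a hypothetical helper and the numbers are made up for illustration.

```python
def weighted_mean(batch_values, batch_sizes):
    """Average per-batch metrics, weighting each batch by its sample count."""
    total = sum(batch_sizes)
    if total == 0:
        raise ValueError("no samples to average")
    return sum(v * n for v, n in zip(batch_values, batch_sizes)) / total


# Illustrative values only: three batches, the last one partially filled.
batch_psnr = [28.0, 30.0, 24.0]
batch_sizes = [4, 4, 2]

plain = sum(batch_psnr) / len(batch_psnr)
weighted = weighted_mean(batch_psnr, batch_sizes)
print(f"plain mean: {plain:.2f} dB, sample-weighted mean: {weighted:.2f} dB")
```

When every batch is full, both averages coincide, so this only matters for batch sizes that do not divide the dataset evenly.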

#### **PSNR (Peak Signal-to-Noise Ratio)**
```python
psnr = peak_signal_noise_ratio(real_images_np[i], fake_images_np[i], data_range=1.0)
```

**Mathematical Definition:**
```
PSNR = 20 × log₁₀(MAX_I / √MSE)
```
Where:
- MAX_I = maximum possible pixel value (1.0 for normalized images)
- MSE = Mean Squared Error between images

**Interpretation:**
- **Higher = Better** (typically 20-40 dB for good quality)
- Measures pixel-level accuracy
- Medical relevance: Quantifies preservation of fine anatomical details
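
The definition above can be worked through by hand on a tiny example. This is a pure-Python sketch of the formula, not the `skimage` implementation; the `mse` and `psnr` helpers and the pixel values are illustrative.

```python
import math


def mse(reference, restored):
    """Mean squared error between two equally sized sequences of pixels."""
    if len(reference) != len(restored) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    return sum((a - b) ** 2 for a, b in zip(reference, restored)) / len(reference)


def psnr(reference, restored, max_value=1.0):
    """PSNR in dB: 20 * log10(MAX_I / sqrt(MSE)); infinite for identical inputs."""
    err = mse(reference, restored)
    if err == 0:
        return math.inf
    return 20.0 * math.log10(max_value / math.sqrt(err))


reference = [0.0, 0.5, 1.0, 0.25]
restored = [0.1, 0.4, 0.9, 0.35]  # every pixel is off by 0.1, so MSE = 0.01
print(f"PSNR: {psnr(reference, restored):.2f} dB")  # about 20 dB
```

Halving the per-pixel error to 0.05 divides the MSE by four and raises the PSNR by roughly 6 dB, which is a useful rule of thumb when reading the evaluation output.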

#### **SSIM (Structural Similarity Index)**
```python
ssim = structural_similarity(real_images_np[i], fake_images_np[i], 
                           data_range=1.0, channel_axis=-1)
```

**SSIM Components:**
1. **Luminance:** `l(x,y) = (2μₓμᵧ + c₁)/(μₓ² + μᵧ² + c₁)`
2. **Contrast:** `c(x,y) = (2σₓσᵧ + c₂)/(σₓ² + σᵧ² + c₂)`  
3. **Structure:** `s(x,y) = (σₓᵧ + c₃)/(σₓσᵧ + c₃)`

**Combined SSIM:**
```
SSIM(x,y) = l(x,y) × c(x,y) × s(x,y)
```
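
To show how the three components combine, the sketch below computes SSIM once over a whole signal using the formulas above, with the conventional constants c₁ = (0.01·L)², c₂ = (0.03·L)², and c₃ = c₂/2, where L is the data range. This is a simplification: `structural_similarity` evaluates the statistic over local sliding windows and averages the result, so its values will differ. The function `ssim_global` is hypothetical.

```python
import math


def ssim_global(x, y, data_range=1.0):
    """Single-window SSIM: luminance * contrast * structure (illustrative)."""
    n = len(x)
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((v - mu_x) ** 2 for v in x) / n
    var_y = sum((v - mu_y) ** 2 for v in y) / n
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    sd_x, sd_y = math.sqrt(var_x), math.sqrt(var_y)

    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    c3 = c2 / 2.0

    luminance = (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1)
    contrast = (2 * sd_x * sd_y + c2) / (var_x + var_y + c2)
    structure = (cov + c3) / (sd_x * sd_y + c3)
    return luminance * contrast * structure


x = [0.1, 0.4, 0.6, 0.9]
y = [0.2, 0.35, 0.7, 0.8]
print(f"SSIM(x, x) = {ssim_global(x, x):.4f}")
print(f"SSIM(x, y) = {ssim_global(x, y):.4f}")
```

With c₃ = c₂/2 the contrast and structure terms collapse into the single fraction `(2σₓᵧ + c₂)/(σₓ² + σᵧ² + c₂)`, which is the form most implementations compute.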

**Interpretation:**
- **Range:** -1 to 1 (1 = identical images)
- **Advantage:** Correlates better with human perception than PSNR
- **Medical relevance:** Ensures structural integrity of anatomical features

#### **Why Both Metrics?**
- **PSNR:** Pixel-perfect accuracy (important for diagnostic details)
- **SSIM:** Perceptual quality (important for visual assessment)
- **Combined:** Comprehensive quality evaluation

#### **Evaluation Best Practices**
```python
generator.eval()  # Disable dropout and batch norm training mode
with torch.no_grad():  # Disable gradient computation for efficiency
```

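**No-Gradient Context:**
- Skips building the autograd graph, which speeds up inference
- Ensures no gradients are computed or accumulated during evaluation
- Reduces memory consumption, since intermediate activations are not kept for backpropagation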

#### **Expected Results for Medical Images:**
- **Good PSNR:** > 25 dB (sharp, detailed reconstruction)
- **Good SSIM:** > 0.85 (perceptually similar to original)
- **Excellent Performance:** PSNR > 30 dB, SSIM > 0.90

#### **Clinical Validation Considerations:**
While PSNR and SSIM provide objective measures, clinical validation would require:
- Expert clinician (e.g., endoscopist) assessment of diagnostic quality
- Comparison with original diagnosis accuracy
- Evaluation across diverse pathological conditions

In [14]:
if __name__ == "__main__":
    if not os.path.exists(Config.DATA_ROOT):
        print(f"Error: Data root directory not found at {Config.DATA_ROOT}")
        print("Please download the Kvasir-SEG dataset and extract it such that")
        print(f"the image files are located under '{Config.DATA_ROOT}'")
        print("For example, if you download 'Kvasir-SEG.zip', extract it,")
        print("you should have a structure like: data/Kvasir-SEG/images/image_00001.jpg")
        raise SystemExit(1)  # Stop here; exit() is intended for interactive shells only

    if os.path.exists(Config.OUTPUT_DIR):
        print(f"Clearing existing output directory: {Config.OUTPUT_DIR}")
        shutil.rmtree(Config.OUTPUT_DIR)
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)
    print(f"Output directory '{Config.OUTPUT_DIR}' created and ready.")

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    train_model()

    last_generator_checkpoint = os.path.join(Config.CHECKPOINT_DIR, f'generator_epoch_{Config.NUM_EPOCHS}.pth')
    if os.path.exists(last_generator_checkpoint):
        evaluate_model(last_generator_checkpoint)
    else:
        print(f"Could not find the last generator checkpoint at {last_generator_checkpoint} for evaluation.")
        print("Ensure training completed successfully.")

Output directory 'output/generated_images' created and ready.


NameError: name 'np' is not defined

### 11. Complete Execution Pipeline - Putting It All Together

#### **A. Pre-Execution Validation**
```python
if not os.path.exists(Config.DATA_ROOT):
    print(f"Error: Data root directory not found at {Config.DATA_ROOT}")
    raise SystemExit(1)
```
- **Safety Check:** Prevents training from starting when the dataset is missing
- **User Guidance:** Prints instructions for setting up the dataset

#### **B. Environment Preparation**
```python
if os.path.exists(Config.OUTPUT_DIR):
    shutil.rmtree(Config.OUTPUT_DIR)  # Clean slate for new training
os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)
```

**Clean Start Strategy:**
- Removes previous training artifacts
- Ensures consistent output organization
- Prevents confusion from mixed training runs

#### **C. Reproducibility Setup**
```python
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

**Reproducibility Components:**
- **CPU Random Seeds:** Python `random`, NumPy, and PyTorch CPU operations
- **GPU Random Seeds:** CUDA operations across all devices
- **cuDNN Settings:**
  - `deterministic=True`: Restricts cuDNN to deterministic algorithms
  - `benchmark=False`: Disables the autotuner, which may otherwise select different convolution algorithms from run to run

**Trade-off:** Slight performance reduction in exchange for more reproducible runs (some GPU operations can still be nondeterministic)
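
The seeding calls can be grouped into one helper so every entry point seeds the same generators. The sketch below is a hypothetical `seed_everything` function, not part of the notebook; it seeds NumPy and PyTorch only if they are importable, so it also runs in a bare Python environment.

```python
import random


def seed_everything(seed=42):
    """Seed every random number generator available in this environment."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
print("identical sequences:", first == second)
```

Re-seeding before each run is what makes the random degradations and weight initialization repeatable; the cuDNN flags then cover the GPU kernels.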

#### **D. Training Pipeline Execution**
```python
train_model()  # Main training loop (50 epochs; can take several hours)
```

**What Happens During Training:**
1. **Data Loading:** Batch-wise image loading with degradation
2. **Model Updates:** Alternating discriminator/generator training
3. **Monitoring:** Loss logging every 10 batches
4. **Checkpointing:** Model saving every 5 epochs
5. **Visualization:** Sample image generation for progress tracking

#### **E. Automatic Evaluation**
```python
last_generator_checkpoint = os.path.join(Config.CHECKPOINT_DIR, f'generator_epoch_{Config.NUM_EPOCHS}.pth')
if os.path.exists(last_generator_checkpoint):
    evaluate_model(last_generator_checkpoint)
```

**Smart Evaluation:**
- Only runs if the final checkpoint file exists, i.e. training reached the last epoch and saved it
- Uses the final trained model (epoch `Config.NUM_EPOCHS`, 50 in this configuration)
- Provides quantitative assessment of enhancement quality
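
If training is interrupted before the last epoch, the check above finds nothing to evaluate. One possible fallback is to pick the newest generator checkpoint that does exist. The sketch below assumes the `generator_epoch_{N}.pth` naming shown above; `find_latest_generator_checkpoint` is a hypothetical helper, not part of the notebook.

```python
import os
import re

# Matches the checkpoint naming used by the training loop, e.g. generator_epoch_25.pth
_CHECKPOINT_PATTERN = re.compile(r"^generator_epoch_(\d+)\.pth$")


def find_latest_generator_checkpoint(checkpoint_dir):
    """Return the path of the highest-epoch generator checkpoint, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    best_epoch, best_name = -1, None
    for name in os.listdir(checkpoint_dir):
        match = _CHECKPOINT_PATTERN.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_name = int(match.group(1)), name
    return os.path.join(checkpoint_dir, best_name) if best_name else None


latest = find_latest_generator_checkpoint("output/generated_images/checkpoints")
print("latest checkpoint:", latest)
```

The returned path could be passed to `evaluate_model` in place of `last_generator_checkpoint`; the epoch number is compared numerically so that epoch 10 sorts after epoch 5.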

#### **F. Expected Training Timeline**
**Hardware Dependent (rough estimates; actual times vary with batch size and data loading):**
- **CPU Only:** ~12-24 hours for 50 epochs
- **GPU (GTX 1060+):** ~2-4 hours for 50 epochs
- **High-End GPU (RTX 3080+):** ~30-60 minutes for 50 epochs

#### **G. Output Structure After Completion**
```
output/
└── generated_images/
    ├── epoch_5_batch_X.png     # Training progress images
    ├── epoch_10_batch_X.png
    ├── ...
    └── checkpoints/
        ├── generator_epoch_5.pth     # Model weights
        ├── discriminator_epoch_5.pth
        ├── ...
        ├── generator_epoch_50.pth    # Final trained model
        └── discriminator_epoch_50.pth
```

#### **H. Success Indicators**
**During Training:**
- Steadily decreasing generator L1 loss (pixel accuracy improving)
- Stable discriminator loss (around 0.5 here) that neither collapses toward 0 nor grows steadily (balanced competition)
- Visual improvement in generated sample images

**After Training:**
- PSNR > 25 dB (good pixel-level accuracy)
- SSIM > 0.85 (good perceptual quality)
- Checkpoints saved without errors

#### **I. Troubleshooting Common Issues**
- **"CUDA out of memory":** Reduce `BATCH_SIZE` from 4 to 2 or 1
- **"Dataset empty":** Verify the Kvasir-SEG folder structure
- **"Training very slow":** Ensure the GPU is properly detected and utilized

In [None]:
import os

class Config:
    DATA_ROOT = 'data/Kvasir-SEG/images'  # Local path to the dataset

print(f"Checking images in: {Config.DATA_ROOT}")

if not os.path.exists(Config.DATA_ROOT):
    print(f"Error: The specified data root directory does NOT exist: {Config.DATA_ROOT}")
    print("Please ensure your Kvasir-SEG dataset is correctly extracted and placed.")
    print("Expected structure: data/Kvasir-SEG/images/")
    print("\nTo set up the dataset:")
    print("1. Download the Kvasir-SEG dataset")
    print("2. Extract it to create a 'data' folder in your project directory")
    print("3. Ensure the structure is: data/Kvasir-SEG/images/image_*.jpg")
else:
    image_files = [f for f in os.listdir(Config.DATA_ROOT) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if len(image_files) == 0:
        print(f"Warning: No image files (jpg, jpeg, png) found in {Config.DATA_ROOT}.")
        print("Please verify the contents of this directory.")
    else:
        print(f"Found {len(image_files)} image files. Here are the first 5:")
        for i, filename in enumerate(image_files[:5]):
            print(os.path.join(Config.DATA_ROOT, filename))
        print("Dataset path appears correct and images are found.")


Checking images in: kvasirseg/Kvasir-SEG/images
Error: The specified data root directory does NOT exist: kvasirseg/Kvasir-SEG/images
Please ensure your Kvasir-SEG dataset is correctly extracted and placed.
Expected structure: /kaggle/working/kvasirseg/Kvasir-SEG/images/


### 12. Clinical Applications & Research Impact

#### **Medical Imaging Enhancement Applications**

**A. Endoscopic Imaging Improvements**
- **Challenge:** Poor lighting conditions during procedures
- **Solution:** GAN enhances visibility of tissue structures
- **Benefit:** Improved polyp detection and diagnostic accuracy

**B. Historical Image Archive Enhancement**  
- **Challenge:** Older medical images with lower quality
- **Solution:** Retroactively enhance archived images
- **Benefit:** Improved analysis of longitudinal patient data

**C. Real-Time Clinical Integration**
- **Challenge:** Live imaging quality during procedures
- **Solution:** Real-time enhancement for better visualization
- **Benefit:** Immediate diagnostic improvement

#### **Technical Achievements of This Implementation**

**1. Multi-Loss Architecture:**
- Combines pixel-level accuracy (L1) with perceptual quality (VGG features)
- Adversarial training ensures realistic image generation
- Balanced approach for medical imaging requirements

**2. Robust Degradation Simulation:**
- Gaussian noise: Sensor imperfections
- Salt/pepper noise: Transmission errors  
- Blur: Motion artifacts
- Color variations: Lighting inconsistencies

**3. Medical-Specific Optimizations:**
- Conservative color adjustments (hue varied by ±5%, compared with ±20% for brightness)
- High L1 loss weight (70.0) for diagnostic accuracy
- U-Net architecture preserves fine anatomical details

#### **Research Extensions**

**A. Multi-Modal Enhancement:**
- Extend to CT, MRI, X-ray modalities
- Modality-specific degradation patterns
- Cross-modal knowledge transfer

**B. Pathology-Aware Enhancement:**
- Condition-specific training datasets
- Pathology preservation during enhancement
- Clinical validation studies

**C. Real-Time Deployment:**
- Model optimization for edge devices
- Integration with endoscopic equipment
- Live enhancement during procedures

#### **Ethical Considerations**

**A. Diagnostic Responsibility:**
- Enhanced images should supplement, not replace, original data
- Clear labeling of AI-enhanced content
- Clinician training on AI-enhanced interpretation

**B. Validation Requirements:**
- Clinical trials comparing diagnostic accuracy
- Inter-observer agreement studies
- Long-term patient outcome analysis

**C. Regulatory Compliance:**
- FDA/CE marking for clinical deployment
- Quality assurance protocols
- Audit trails for enhanced images

#### **Performance Benchmarks**

**Research Quality Targets:**
- **PSNR > 30 dB:** Research-grade enhancement
- **SSIM > 0.90:** Clinical-grade perceptual quality
- **Processing Time < 100ms:** Real-time feasibility
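
The processing-time target can be checked with a simple timing harness. The sketch below measures the median latency of an arbitrary callable; `measure_latency_ms` and `dummy_enhance` are hypothetical names, and the dummy workload is only a stand-in for a real forward pass through the generator.

```python
import time


def measure_latency_ms(fn, warmup=3, runs=20):
    """Return the median wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are not timed
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    timings.sort()
    return timings[len(timings) // 2]


def dummy_enhance():
    """Stand-in workload; replace with a call to the trained generator."""
    return sum(i * i for i in range(10_000))


latency = measure_latency_ms(dummy_enhance)
print(f"median latency: {latency:.2f} ms | under 100 ms budget: {latency < 100.0}")
```

When timing a model on a GPU, CUDA kernels run asynchronously, so `torch.cuda.synchronize()` should be called before reading the clock; otherwise the measurement only captures the kernel launch.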

**Clinical Deployment Requirements:**
- **Accuracy:** No false enhancement of pathological features
- **Consistency:** Reliable performance across imaging conditions  
- **Integration:** Seamless workflow incorporation
- **Training:** Minimal additional staff education required

This implementation provides a solid foundation for medical image enhancement research and potential clinical applications, with careful consideration of both technical performance and medical safety requirements.