# Medical Image Enhancement GAN - Local Setup

## Project Overview
This notebook implements a Generative Adversarial Network (GAN) for enhancing medical images from the Kvasir-SEG dataset. The model takes degraded medical images and restores them to high quality using a U-Net generator and discriminator architecture.

## Setup Instructions

### 1. Install Required Dependencies
```bash
pip install torch torchvision matplotlib pillow numpy scikit-image
```

### 2. Download and Prepare Dataset
1. Download the Kvasir-SEG dataset from: https://datasets.simula.no/kvasir-seg/
2. Extract the dataset in your project directory
3. Ensure the folder structure is:
   ```
   your_project/
   ├── data/
   │   └── Kvasir-SEG/
   │       └── images/
   │           ├── image_00001.jpg
   │           ├── image_00002.jpg
   │           └── ...
   ├── output/ (will be created automatically)
   └── ganproject (2).ipynb
   ```

### 3. System Requirements
- GPU recommended (CUDA compatible) for faster training
- At least 8GB RAM
- Python 3.7+ with PyTorch installed

### 4. Expected Output
- Training will create enhanced images in `output/generated_images/`
- Model checkpoints will be saved in `output/generated_images/checkpoints/`
- Training progress will be displayed with loss metrics and sample images

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg19
import os
from PIL import Image
import numpy as np
import random
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import matplotlib.pyplot as plt
import shutil

In [17]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
import os

class Config:
    # Local paths for running on your device
    DATA_ROOT = 'kvasir-seg\Kvasir-SEG\images'  # Local path to the dataset
    OUTPUT_DIR = 'output/generated_images'  # Local output directory
    CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, 'checkpoints')

    IMAGE_SIZE = 256
    NUM_CHANNELS = 3

    BATCH_SIZE = 4
    NUM_EPOCHS = 50
    LEARNING_RATE_G = 2e-4
    LEARNING_RATE_D = 1e-4
    BETA1 = 0.5
    LAMBDA_L1 = 70.0
    LAMBDA_PERCEPTUAL = 10.0

   
    G_TRAIN_MULTIPLIER = 3 

    NOISE_STD_DEV = 0.15
    NOISE_STD_DEV_VARIATION = 0.05
    
    SALT_VS_PEPPER_RATIO = 0.5
    SP_NOISE_PROB = 0.01
    
    BLUR_KERNEL_SIZE = 5
    BRIGHTNESS_FACTOR = 0.2
    CONTRAST_FACTOR = 0.2
    SATURATION_FACTOR = 0.2
    HUE_FACTOR = 0.05

    
    LR_DECAY_START_EPOCH = 30 
    SAVE_EVERY_N_EPOCHS = 5
    LOG_EVERY_N_BATCHES = 10

    DEVICE = DEVICE


In [19]:
def apply_degradation(image_tensor, config):
    image_tensor_cpu = image_tensor.cpu()
    degraded_tensor = image_tensor_cpu.clone() 

    current_noise_std_dev = config.NOISE_STD_DEV + random.uniform(-config.NOISE_STD_DEV_VARIATION, config.NOISE_STD_DEV_VARIATION)
    current_noise_std_dev = max(0, current_noise_std_dev)

    gaussian_noise = torch.randn_like(degraded_tensor) * current_noise_std_dev
    degraded_tensor = degraded_tensor + gaussian_noise
    degraded_tensor = torch.clamp(degraded_tensor, 0, 1)


    num_pixels = degraded_tensor.numel()
    num_sp_noise_pixels = int(num_pixels * config.SP_NOISE_PROB)

    noise_indices = torch.randperm(num_pixels)[:num_sp_noise_pixels]

    salt_mask = torch.rand(num_sp_noise_pixels) < config.SALT_VS_PEPPER_RATIO
    pepper_mask = ~salt_mask

    degraded_tensor_flat = degraded_tensor.view(-1)

    # Apply pepper noise (set to 0)
    degraded_tensor_flat[noise_indices[pepper_mask]] = 0.0

    # Apply salt noise (set to 1)
    degraded_tensor_flat[noise_indices[salt_mask]] = 1.0
    
    # Reshape back to original dimensions
    degraded_tensor = degraded_tensor_flat.view(degraded_tensor.shape)


    if config.BLUR_KERNEL_SIZE > 0:
        blur_transform = transforms.GaussianBlur(
            kernel_size=(config.BLUR_KERNEL_SIZE, config.BLUR_KERNEL_SIZE),
            sigma=(0.1, 2.0)
        )
        degraded_tensor = blur_transform(degraded_tensor)

    jitter_transform = transforms.ColorJitter(
        brightness=config.BRIGHTNESS_FACTOR,
        contrast=config.CONTRAST_FACTOR,
        saturation=config.SATURATION_FACTOR,
        hue=config.HUE_FACTOR
    )
    degraded_tensor = jitter_transform(degraded_tensor)

    return degraded_tensor

In [20]:
def save_images(low_res, real_high_res, fake_high_res, epoch, batch_idx, output_dir):
    low_res_np = (low_res * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)
    real_high_res_np = (real_high_res * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)
    fake_high_res_np = (fake_high_res * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)

    fig, axes = plt.subplots(Config.BATCH_SIZE, 3, figsize=(9, 3 * Config.BATCH_SIZE))
    fig.suptitle(f'Epoch {epoch}, Batch {batch_idx}', fontsize=16)

    for i in range(Config.BATCH_SIZE):
        axes[i, 0].imshow(low_res_np[i])
        axes[i, 0].set_title('Low-Quality Input')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(real_high_res_np[i])
        axes[i, 1].set_title('Real High-Quality')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(fake_high_res_np[i])
        axes[i, 2].set_title('Generated High-Quality')
        axes[i, 2].axis('off')
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    
    filename = os.path.join(output_dir, f'epoch_{epoch}_batch_{batch_idx}.png')
    
    try:
        plt.savefig(filename)
        print(f"Successfully saved image to: {filename}")
    except Exception as e:
        print(f"Error saving image {filename}: {e}")
    finally:
        plt.close(fig)

In [21]:
def calculate_metrics(real_images, fake_images):
    real_images_np = (real_images * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)
    fake_images_np = (fake_images * 0.5 + 0.5).cpu().detach().numpy().transpose(0, 2, 3, 1)

    psnr_scores = []
    ssim_scores = []

    for i in range(real_images_np.shape[0]):
        psnr = peak_signal_noise_ratio(real_images_np[i], fake_images_np[i], data_range=1.0)
        ssim = structural_similarity(real_images_np[i], fake_images_np[i], data_range=1.0, channel_axis=-1)
        psnr_scores.append(psnr)
        ssim_scores.append(ssim)
    
    return np.mean(psnr_scores), np.mean(ssim_scores)

In [22]:
class KvasirSEGDataset(Dataset):
    def __init__(self, root_dir, image_size, transform=None, degradation_transform=None):
        self.root_dir = root_dir
        self.image_size = image_size
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]
        self.transform = transform
        self.degradation_transform = degradation_transform

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        
        print(f"Dataset initialized with {len(self.image_files)} images from {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
            high_quality_image = self.transform(image)

            high_quality_image_0_1 = (high_quality_image * 0.5 + 0.5)
            
            if self.degradation_transform:
                low_quality_image_0_1 = self.degradation_transform(high_quality_image_0_1)
            else:
                low_quality_image_0_1 = high_quality_image_0_1

            low_quality_image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(low_quality_image_0_1)

            return low_quality_image, high_quality_image
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy tensor if image loading fails
            dummy_tensor = torch.zeros(3, self.image_size, self.image_size)
            return dummy_tensor, dummy_tensor

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


In [23]:
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

In [24]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels * 2, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 4, 1, 1, bias=False)
        )

    def forward(self, img_A, img_B):
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

In [25]:
class PerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg19(pretrained=True).features
        for param in vgg.parameters():
            param.requires_grad = False
        
        self.features = nn.Sequential(
            vgg[:2],
            vgg[2:7],
            vgg[7:12],
            vgg[12:21]
        ).to(DEVICE).eval()

        self.criterion = nn.L1Loss()

    def forward(self, fake_img, real_img):
        normalize_vgg = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        fake_img_norm = normalize_vgg((fake_img * 0.5 + 0.5))
        real_img_norm = normalize_vgg((real_img * 0.5 + 0.5))

        fake_features = self.features(fake_img_norm)
        real_features = self.features(real_img_norm)
        
        return self.criterion(fake_features, real_features)

In [26]:
def train_model():
    print(f"Using device: {DEVICE}")

    dataset = KvasirSEGDataset(
        root_dir=Config.DATA_ROOT,
        image_size=Config.IMAGE_SIZE,
        degradation_transform=lambda img_tensor: apply_degradation(img_tensor, Config)
    )
    # Reduced num_workers for local execution (use 0 for Windows compatibility)
    dataloader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)

    if len(dataloader) == 0:
        print(f"Error: Dataloader is empty. No images found in {Config.DATA_ROOT} or batch size is too large for the dataset.")
        print("Please ensure the Kvasir-SEG dataset is correctly placed and contains images.")
        return # Exit if no data is found

    generator = UnetGenerator(Config.NUM_CHANNELS, Config.NUM_CHANNELS).to(DEVICE)
    discriminator = Discriminator(Config.NUM_CHANNELS).to(DEVICE)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()
    perceptual_loss_fn = PerceptualLoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=Config.LEARNING_RATE_G, betas=(Config.BETA1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=Config.LEARNING_RATE_D, betas=(Config.BETA1, 0.999))

    def lambda_rule(epoch):
        if epoch < Config.LR_DECAY_START_EPOCH:
            return 1.0
        else:
            return 1.0 - (epoch - Config.LR_DECAY_START_EPOCH) / (Config.NUM_EPOCHS - Config.LR_DECAY_START_EPOCH)

    scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
    scheduler_D = optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=lambda_rule)

    print("Starting Training Loop...")
    for epoch in range(Config.NUM_EPOCHS):
        for i, (low_res_img, high_res_img) in enumerate(dataloader):
            low_res_img = low_res_img.to(DEVICE)
            high_res_img = high_res_img.to(DEVICE)


            optimizer_D.zero_grad()

            real_labels = torch.ones(low_res_img.size(0), 1, 30, 30, device=DEVICE)
            fake_high_res_for_D = generator(low_res_img)
            
            output_real = discriminator(low_res_img, high_res_img)
            loss_D_real = criterion_GAN(output_real, real_labels)

            fake_labels = torch.zeros(low_res_img.size(0), 1, 30, 30, device=DEVICE)
            output_fake = discriminator(low_res_img, fake_high_res_for_D.detach())
            loss_D_fake = criterion_GAN(output_fake, fake_labels)

            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()


            for _ in range(Config.G_TRAIN_MULTIPLIER):
                optimizer_G.zero_grad()
                
                fake_high_res_img = generator(low_res_img)

                output_fake_for_G = discriminator(low_res_img, fake_high_res_img)
                loss_G_GAN = criterion_GAN(output_fake_for_G, real_labels)

                loss_G_L1 = criterion_L1(fake_high_res_img, high_res_img) * Config.LAMBDA_L1

                loss_G_perceptual = perceptual_loss_fn(fake_high_res_img, high_res_img) * Config.LAMBDA_PERCEPTUAL

                loss_G = loss_G_GAN + loss_G_L1 + loss_G_perceptual
                loss_G.backward() 
                optimizer_G.step()

            if (i + 1) % Config.LOG_EVERY_N_BATCHES == 0:
                print(f"Epoch [{epoch+1}/{Config.NUM_EPOCHS}], Batch [{i+1}/{len(dataloader)}] | "
                      f"D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f} "
                      f"(GAN: {loss_G_GAN.item():.4f}, L1: {loss_G_L1.item():.4f}, Perceptual: {loss_G_perceptual.item():.4f})")
        
        scheduler_G.step()
        scheduler_D.step()
        print(f"Epoch {epoch+1} completed. Current LR G: {optimizer_G.param_groups[0]['lr']:.6f}, LR D: {optimizer_D.param_groups[0]['lr']:.6f}")


        if (epoch + 1) % Config.SAVE_EVERY_N_EPOCHS == 0 or (epoch + 1) == Config.NUM_EPOCHS:
            print(f"--- Entering save/checkpoint block for epoch {epoch+1} ---")
            print(f"Saving samples and checkpoint for epoch {epoch+1}...")
            with torch.no_grad():
                try:
                    sample_low_res, sample_high_res = next(iter(dataloader))
                    sample_low_res = sample_low_res.to(DEVICE)
                    sample_high_res = sample_high_res.to(DEVICE)
                    
                    generated_high_res = generator(sample_low_res) 
                    save_images(sample_low_res, sample_high_res, generated_high_res, epoch + 1, i + 1, Config.OUTPUT_DIR)
                except StopIteration:
                    print(f"Warning: Dataloader exhausted at epoch {epoch+1} for sample saving. Skipping image save for this epoch.")


            try:
                torch.save(generator.state_dict(), os.path.join(Config.CHECKPOINT_DIR, f'generator_epoch_{epoch+1}.pth'))
                torch.save(discriminator.state_dict(), os.path.join(Config.CHECKPOINT_DIR, f'discriminator_epoch_{epoch+1}.pth'))
                print(f"Models saved to {Config.CHECKPOINT_DIR}")
            except Exception as e:
                print(f"Error saving model checkpoints for epoch {epoch+1}: {e}")

    print("Training complete.")

In [27]:
class UnetGenerator(nn.Module):
    """
    U-Net Generator architecture for image-to-image translation.
    """
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # Encoder
        self.down1 = UNetDown(in_channels, 64, normalize=False) # 256 -> 128
        self.down2 = UNetDown(64, 128) # 128 -> 64
        self.down3 = UNetDown(128, 256) # 64 -> 32
        self.down4 = UNetDown(256, 512, dropout=0.5) # 32 -> 16
        self.down5 = UNetDown(512, 512, dropout=0.5) # 16 -> 8
        self.down6 = UNetDown(512, 512, dropout=0.5) # 8 -> 4
        self.down7 = UNetDown(512, 512, dropout=0.5) # 4 -> 2
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) # 2 -> 1

        # Decoder
        self.up1 = UNetUp(512, 512, dropout=0.5) # 1 -> 2 (concat with down7)
        self.up2 = UNetUp(1024, 512, dropout=0.5) # 2 -> 4 (concat with down6)
        self.up3 = UNetUp(1024, 512, dropout=0.5) # 4 -> 8 (concat with down5)
        self.up4 = UNetUp(1024, 512) # 8 -> 16 (concat with down4)
        self.up5 = UNetUp(1024, 256) # 16 -> 32 (concat with down3)
        self.up6 = UNetUp(512, 128) # 32 -> 64 (concat with down2)
        self.up7 = UNetUp(256, 64) # 64 -> 128 (concat with down1)

        self.final_conv = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1), # 128 -> 256
            nn.Tanh() # Output pixels in [-1, 1]
        )

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Decoder with skip connections
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final_conv(u7)

In [28]:
def evaluate_model(generator_path):
    print("\nStarting Evaluation...")
    generator = UnetGenerator(Config.NUM_CHANNELS, Config.NUM_CHANNELS).to(DEVICE)
    generator.load_state_dict(torch.load(generator_path, map_location=DEVICE))
    generator.eval()

    dataset = KvasirSEGDataset(
        root_dir=Config.DATA_ROOT,
        image_size=Config.IMAGE_SIZE,
        degradation_transform=lambda img_tensor: apply_degradation(img_tensor, Config)
    )
    # Reduced num_workers for local execution
    eval_dataloader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

    if len(eval_dataloader) == 0:
        print(f"Error: Evaluation Dataloader is empty. Cannot perform evaluation.")
        return

    all_psnr = []
    all_ssim = []

    with torch.no_grad():
        for i, (low_res_img, high_res_img) in enumerate(eval_dataloader):
            low_res_img = low_res_img.to(DEVICE)
            high_res_img = high_res_img.to(DEVICE)

            fake_high_res_img = generator(low_res_img)
            
            psnr, ssim = calculate_metrics(high_res_img, fake_high_res_img)
            all_psnr.append(psnr)
            all_ssim.append(ssim)

            if (i + 1) % Config.LOG_EVERY_N_BATCHES == 0:
                print(f"Eval Batch [{i+1}/{len(eval_dataloader)}] | Avg PSNR: {np.mean(all_psnr):.2f} | Avg SSIM: {np.mean(all_ssim):.4f}")

    avg_psnr = np.mean(all_psnr)
    avg_ssim = np.mean(all_ssim)
    print(f"\n--- Evaluation Results ---")
    print(f"Average PSNR: {avg_psnr:.2f}")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print("Evaluation complete.")

In [29]:
if __name__ == "__main__":
    if not os.path.exists(Config.DATA_ROOT):
        print(f"Error: Data root directory not found at {Config.DATA_ROOT}")
        print("Please download the Kvasir-SEG dataset and extract it such that")
        print(f"the image files are located under '{Config.DATA_ROOT}'")
        print("For example, if you download 'Kvasir-SEG.zip', extract it,")
        print("you should have a structure like: kvasirseg/Kvasir-SEG/images/image_00001.jpg")
        exit()

    if os.path.exists(Config.OUTPUT_DIR):
        print(f"Clearing existing output directory: {Config.OUTPUT_DIR}")
        shutil.rmtree(Config.OUTPUT_DIR)
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)
    print(f"Output directory '{Config.OUTPUT_DIR}' created and ready.")

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    train_model()

    last_generator_checkpoint = os.path.join(Config.CHECKPOINT_DIR, f'generator_epoch_{Config.NUM_EPOCHS}.pth')
    if os.path.exists(last_generator_checkpoint):
        evaluate_model(last_generator_checkpoint)
    else:
        print(f"Could not find the last generator checkpoint at {last_generator_checkpoint} for evaluation.")
        print("Ensure training completed successfully.")

Clearing existing output directory: output/generated_images
Output directory 'output/generated_images' created and ready.
Using device: cpu
Dataset initialized with 1000 images from kvasir-seg\Kvasir-SEG\images


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to C:\Users\ATHARVA KHOLLAM/.cache\torch\hub\checkpoints\vgg19-dcbb9e9d.pth
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to C:\Users\ATHARVA KHOLLAM/.cache\torch\hub\checkpoints\vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [04:05<00:00, 2.34MB/s] 



Starting Training Loop...
Epoch [1/50], Batch [10/250] | D Loss: 0.7300 | G Loss: 25.0315 (GAN: 0.8251, L1: 11.4368, Perceptual: 12.7696)
Epoch [1/50], Batch [10/250] | D Loss: 0.7300 | G Loss: 25.0315 (GAN: 0.8251, L1: 11.4368, Perceptual: 12.7696)
Epoch [1/50], Batch [20/250] | D Loss: 0.7201 | G Loss: 20.8651 (GAN: 0.8032, L1: 9.3744, Perceptual: 10.6875)
Epoch [1/50], Batch [20/250] | D Loss: 0.7201 | G Loss: 20.8651 (GAN: 0.8032, L1: 9.3744, Perceptual: 10.6875)
Epoch [1/50], Batch [30/250] | D Loss: 0.7109 | G Loss: 17.9303 (GAN: 0.7601, L1: 6.8540, Perceptual: 10.3162)
Epoch [1/50], Batch [30/250] | D Loss: 0.7109 | G Loss: 17.9303 (GAN: 0.7601, L1: 6.8540, Perceptual: 10.3162)
Epoch [1/50], Batch [40/250] | D Loss: 0.6940 | G Loss: 15.7866 (GAN: 0.7347, L1: 5.7114, Perceptual: 9.3405)
Epoch [1/50], Batch [40/250] | D Loss: 0.6940 | G Loss: 15.7866 (GAN: 0.7347, L1: 5.7114, Perceptual: 9.3405)
Epoch [1/50], Batch [50/250] | D Loss: 0.7171 | G Loss: 17.4491 (GAN: 0.7770, L1: 6.51

  generator.load_state_dict(torch.load(generator_path, map_location=DEVICE))


Dataset initialized with 1000 images from kvasir-seg\Kvasir-SEG\images
Eval Batch [10/250] | Avg PSNR: 28.70 | Avg SSIM: 0.7827
Eval Batch [10/250] | Avg PSNR: 28.70 | Avg SSIM: 0.7827
Eval Batch [20/250] | Avg PSNR: 28.77 | Avg SSIM: 0.7851
Eval Batch [20/250] | Avg PSNR: 28.77 | Avg SSIM: 0.7851
Eval Batch [30/250] | Avg PSNR: 28.66 | Avg SSIM: 0.7854
Eval Batch [30/250] | Avg PSNR: 28.66 | Avg SSIM: 0.7854
Eval Batch [40/250] | Avg PSNR: 28.64 | Avg SSIM: 0.7848
Eval Batch [40/250] | Avg PSNR: 28.64 | Avg SSIM: 0.7848
Eval Batch [50/250] | Avg PSNR: 28.73 | Avg SSIM: 0.7854
Eval Batch [50/250] | Avg PSNR: 28.73 | Avg SSIM: 0.7854
Eval Batch [60/250] | Avg PSNR: 28.78 | Avg SSIM: 0.7859
Eval Batch [60/250] | Avg PSNR: 28.78 | Avg SSIM: 0.7859
Eval Batch [70/250] | Avg PSNR: 28.73 | Avg SSIM: 0.7825
Eval Batch [70/250] | Avg PSNR: 28.73 | Avg SSIM: 0.7825
Eval Batch [80/250] | Avg PSNR: 28.76 | Avg SSIM: 0.7840
Eval Batch [80/250] | Avg PSNR: 28.76 | Avg SSIM: 0.7840
Eval Batch [90/25

In [30]:
import os

class Config:
    DATA_ROOT = 'data/Kvasir-SEG/images'  # Local path to the dataset

print(f"Checking images in: {Config.DATA_ROOT}")

if not os.path.exists(Config.DATA_ROOT):
    print(f"Error: The specified data root directory does NOT exist: {Config.DATA_ROOT}")
    print("Please ensure your Kvasir-SEG dataset is correctly extracted and placed.")
    print("Expected structure: data/Kvasir-SEG/images/")
    print("\nTo set up the dataset:")
    print("1. Download the Kvasir-SEG dataset")
    print("2. Extract it to create a 'data' folder in your project directory")
    print("3. Ensure the structure is: data/Kvasir-SEG/images/image_*.jpg")
else:
    image_files = [f for f in os.listdir(Config.DATA_ROOT) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if len(image_files) == 0:
        print(f"Warning: No image files (jpg, jpeg, png) found in {Config.DATA_ROOT}.")
        print("Please verify the contents of this directory.")
    else:
        print(f"Found {len(image_files)} image files. Here are the first 5:")
        for i, filename in enumerate(image_files[:5]):
            print(os.path.join(Config.DATA_ROOT, filename))
        print("Dataset path appears correct and images are found.")


Checking images in: data/Kvasir-SEG/images
Error: The specified data root directory does NOT exist: data/Kvasir-SEG/images
Please ensure your Kvasir-SEG dataset is correctly extracted and placed.
Expected structure: data/Kvasir-SEG/images/

To set up the dataset:
1. Download the Kvasir-SEG dataset
2. Extract it to create a 'data' folder in your project directory
3. Ensure the structure is: data/Kvasir-SEG/images/image_*.jpg
