# Medical Image Dataset Exploration

This notebook explores the three medical imaging datasets:
- Brain Tumor MRI Classification
- Chest X-Ray Pneumonia Detection
- Colorectal Cancer Histopathology

In [None]:
# Import required libraries
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torchvision import transforms
from PIL import Image
import pandas as pd

# Make the project's `src` directory importable. The location is resolved
# relative to the current working directory, so the notebook has to be started
# from a folder that is a sibling of `src`.
# NOTE(review): the directory is appended (lowest import priority), so an
# installed top-level package named `config`, `datasets` or `utils` would be
# imported instead of the project modules - confirm for the target environment.
src_dir = str(Path.cwd().parent / 'src')
if src_dir not in sys.path:  # re-running this cell must not add duplicates
    sys.path.append(src_dir)

# Import project modules
from config import get_config
from datasets.brain_tumor import BrainTumorDataset
from datasets.chest_xray import ChestXRayDataset
from datasets.colorectal import ColorectalDataset
from utils.visualization import visualize_batch

# Set style
plt.style.use('seaborn-v0_8')
%matplotlib inline

## 1. Brain Tumor Dataset

In [None]:
# Load Brain Tumor dataset
config_brain = get_config('brain_tumor')
# Coerce to Path so `.exists()` also works if the config stores a plain string.
brain_dataset_path = Path(config_brain['dataset']['data_path'])

# Preprocessing pipeline: resize to 224x224, then convert to a tensor.
# It is built outside the existence check so that `transform` is always defined
# once this cell has run, whether or not the brain tumor data was downloaded.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Check if dataset exists
if brain_dataset_path.exists():
    # Load train dataset
    brain_train = BrainTumorDataset(
        root_dir=brain_dataset_path,
        split='train',
        transform=transform
    )

    print(f"Dataset: {config_brain['dataset']['name']}")
    print(f"Number of classes: {config_brain['dataset']['num_classes']}")
    print(f"Classes: {', '.join(brain_train.classes)}")
    print(f"Training samples: {len(brain_train)}")

    # Get class distribution
    distribution = brain_train.get_class_distribution()
    print("\nClass distribution:")
    for class_name, count in distribution.items():
        print(f"  {class_name}: {count}")
else:
    # `brain_train` stays undefined; later cells test for it before using it.
    print(f"Brain tumor dataset not found at {brain_dataset_path}")
    print("Run 'python src/datasets/download_datasets.py' to download the dataset")

In [None]:
# Visualize sample images from brain tumor dataset
# DataLoader is imported unconditionally: other visualization cells in this
# notebook use the name as well, and an import hidden inside the `if` below
# would leave it undefined whenever the brain tumor dataset is not loaded.
from torch.utils.data import DataLoader

if 'brain_train' in locals():
    # shuffle=True makes the loader yield the images in random order
    brain_loader = DataLoader(brain_train, batch_size=16, shuffle=True)
    visualize_batch(brain_loader, brain_train.classes, num_images=16)

In [None]:
# Bar chart of the per-class sample counts in the brain tumor training split
if 'brain_train' in locals():
    distribution = brain_train.get_class_distribution()
    classes, counts = list(distribution.keys()), list(distribution.values())

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(classes, counts, color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'])
    ax.set_xlabel('Tumor Type')
    ax.set_ylabel('Number of Samples')
    ax.set_title('Brain Tumor Dataset - Class Distribution')
    # Slant the class names so longer labels do not overlap
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Write each count centered on top of its bar
    for rect, n_samples in zip(bars, counts):
        label_x = rect.get_x() + rect.get_width() / 2.
        ax.text(label_x, rect.get_height(), f'{n_samples}', ha='center', va='bottom')

    fig.tight_layout()
    plt.show()

## 2. Chest X-Ray Dataset

In [None]:
# Load Chest X-Ray dataset
config_chest = get_config('chest_xray')
# Coerce to Path so `.exists()` also works if the config stores a plain string.
chest_dataset_path = Path(config_chest['dataset']['data_path'])

if chest_dataset_path.exists():
    # Build the preprocessing pipeline (resize to 224x224, convert to tensor)
    # in this cell rather than relying on a `transform` variable created by
    # another cell, which may not have run or may not have defined it.
    chest_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    chest_train = ChestXRayDataset(
        root_dir=chest_dataset_path,
        split='train',
        transform=chest_transform
    )

    print(f"Dataset: {config_chest['dataset']['name']}")
    print(f"Number of classes: {config_chest['dataset']['num_classes']}")
    print(f"Classes: {', '.join(chest_train.classes)}")
    print(f"Training samples: {len(chest_train)}")

    # Get class distribution
    distribution = chest_train.get_class_distribution()
    print("\nClass distribution:")
    for class_name, count in distribution.items():
        print(f"  {class_name}: {count}")

    # Check for pneumonia subtypes; each subtype maps to a sized collection
    # (printed as a sample count via len()).
    subtypes = chest_train.get_pneumonia_subtypes()
    print("\nPneumonia subtypes:")
    for subtype, indices in subtypes.items():
        print(f"  {subtype}: {len(indices)} samples")
else:
    # `chest_train` stays undefined; later cells test for it before using it.
    print(f"Chest X-ray dataset not found at {chest_dataset_path}")
    print("Run 'python src/datasets/download_datasets.py' to download the dataset")

In [None]:
# Visualize sample chest X-rays
# Import DataLoader in this cell so it does not depend on an import performed
# (conditionally) inside a different cell.
from torch.utils.data import DataLoader

if 'chest_train' in locals():
    # shuffle=True makes the loader yield the images in random order
    chest_loader = DataLoader(chest_train, batch_size=16, shuffle=True)
    visualize_batch(chest_loader, chest_train.classes, num_images=16)

## 3. Colorectal Cancer Dataset

In [None]:
# Load Colorectal dataset
config_colorectal = get_config('colorectal')
# Coerce to Path so `.exists()` also works if the config stores a plain string.
colorectal_dataset_path = Path(config_colorectal['dataset']['data_path'])

if colorectal_dataset_path.exists():
    # Build the preprocessing pipeline (resize to 224x224, convert to tensor)
    # in this cell rather than relying on a `transform` variable created by
    # another cell, which may not have run or may not have defined it.
    colorectal_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    colorectal_train = ColorectalDataset(
        root_dir=colorectal_dataset_path,
        split='train',
        transform=colorectal_transform
    )

    print(f"Dataset: {config_colorectal['dataset']['name']}")
    print(f"Number of classes: {config_colorectal['dataset']['num_classes']}")
    print(f"Classes: {', '.join(colorectal_train.classes)}")
    print(f"Training samples: {len(colorectal_train)}")

    # Get class distribution, annotating each class with the description
    # looked up in the dataset's `class_descriptions` mapping
    distribution = colorectal_train.get_class_distribution()
    print("\nClass distribution:")
    for class_name, count in distribution.items():
        desc = colorectal_train.class_descriptions[class_name]
        print(f"  {class_name} ({desc}): {count}")

    # Get tissue type distribution
    tissue_dist = colorectal_train.get_cancer_vs_normal_distribution()
    print("\nTissue type distribution:")
    for tissue_type, count in tissue_dist.items():
        print(f"  {tissue_type}: {count}")
else:
    # `colorectal_train` stays undefined; later cells test for it before use.
    print(f"Colorectal dataset not found at {colorectal_dataset_path}")
    print("Run 'python src/datasets/download_datasets.py' to download the dataset")

In [None]:
# Visualize colorectal tissue samples
# Import DataLoader in this cell so it does not depend on an import performed
# (conditionally) inside a different cell.
from torch.utils.data import DataLoader

if 'colorectal_train' in locals():
    # shuffle=True makes the loader yield the images in random order
    colorectal_loader = DataLoader(colorectal_train, batch_size=16, shuffle=True)
    visualize_batch(colorectal_loader, colorectal_train.classes, num_images=16)

## 4. Dataset Comparison

In [None]:
# Compare datasets
# One row per dataset that may have been loaded earlier in the notebook:
# (notebook variable name, display name, image type, task type)
dataset_specs = (
    ('brain_train', 'Brain Tumor', 'MRI', 'Multi-class'),
    ('chest_train', 'Chest X-Ray', 'X-Ray', 'Binary'),
    ('colorectal_train', 'Colorectal', 'Microscopy', 'Multi-class'),
)

namespace = locals()
datasets_info = []
for var_name, display_name, image_type, task in dataset_specs:
    if var_name not in namespace:
        continue  # this dataset was not loaded
    loaded = namespace[var_name]
    datasets_info.append({
        'Dataset': display_name,
        'Classes': len(loaded.classes),
        'Train Samples': len(loaded),
        'Image Type': image_type,
        'Task': task,
    })

if not datasets_info:
    print("No datasets loaded. Please download the datasets first.")
else:
    df = pd.DataFrame(datasets_info)
    print("Dataset Comparison:")
    print(df.to_string(index=False))

    # Side-by-side bar charts: (DataFrame column, y-axis label, panel title)
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    panels = (
        ('Classes', 'Number of Classes', 'Classes per Dataset'),
        ('Train Samples', 'Number of Training Samples', 'Training Samples per Dataset'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for panel_ax, (column, y_label, panel_title) in zip(axes, panels):
        panel_ax.bar(df['Dataset'], df[column], color=palette)
        panel_ax.set_ylabel(y_label)
        panel_ax.set_title(panel_title)

    plt.tight_layout()
    plt.show()

## 5. Data Augmentation Examples

In [None]:
# Demonstrate data augmentation
if 'brain_train' in locals():
    # Only the label of the first sample is needed (for the figure title);
    # the transformed image returned with it is not used.
    _, label = brain_train[0]

    # Every example starts from the same 224x224 resize, so create it once.
    resize = transforms.Resize((224, 224))

    # (panel title, extra transforms applied after the resize)
    extra_ops = [
        ('Original', []),
        ('Horizontal Flip', [transforms.RandomHorizontalFlip(p=1.0)]),
        ('Rotation', [transforms.RandomRotation(15)]),
        ('Color Jitter', [transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)]),
        ('Random Affine', [transforms.RandomAffine(10, translate=(0.1, 0.1))]),
        ('Gaussian Blur', [transforms.GaussianBlur(3)]),
    ]

    # Define augmentation transforms
    augmentations = [
        (name, transforms.Compose([resize, *ops])) for name, ops in extra_ops
    ]

    # Load the untransformed source image. The context manager closes the
    # underlying file; convert() returns an independent in-memory copy.
    with Image.open(brain_train.image_paths[0]) as source_img:
        orig_img = source_img.convert('RGB')

    # Apply augmentations, one panel per entry (2 x 3 grid for 6 entries)
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes = axes.flatten()

    for ax, (name, aug) in zip(axes, augmentations):
        aug_img = aug(orig_img)
        # The pipelines above contain no tensor conversion; this branch only
        # matters if one (e.g. ToTensor) is added, converting CHW to HWC.
        if isinstance(aug_img, torch.Tensor):
            aug_img = aug_img.permute(1, 2, 0).numpy()

        ax.imshow(aug_img)
        ax.set_title(name)
        ax.axis('off')

    plt.suptitle(f'Data Augmentation Examples - {brain_train.classes[label]}', fontsize=14)
    plt.tight_layout()
    plt.show()

## 6. Save Dataset Statistics

In [None]:
# Save dataset statistics to file
import json

# Key used in the JSON output -> notebook variable holding the training split
stat_sources = {
    'brain_tumor': 'brain_train',
    'chest_xray': 'chest_train',
    'colorectal': 'colorectal_train',
}

namespace = locals()
stats = {}
for stat_key, var_name in stat_sources.items():
    if var_name in namespace:
        split = namespace[var_name]
        stats[stat_key] = {
            'num_classes': len(split.classes),
            'classes': split.classes,
            'train_samples': len(split),
            'distribution': split.get_class_distribution(),
        }

if not stats:
    print("No datasets loaded to save statistics.")
else:
    # Written to <parent of the working directory>/results, created on demand
    stats_path = Path.cwd().parent / 'results' / 'dataset_statistics.json'
    stats_path.parent.mkdir(parents=True, exist_ok=True)

    with stats_path.open('w') as out_file:
        json.dump(stats, out_file, indent=2)

    print(f"Dataset statistics saved to {stats_path}")