#### Chest X-ray Pneumonia Detection: Exploratory Data Analysis & Data Preprocessing
#### Dataset: https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
#### Author: Cholpon Zhakshylykova


In [2]:
#---------------Import Libraries and Setup

import os
import random
import warnings
from pathlib import Path
from typing import Dict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import kagglehub

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ---------------- Configuration and Data Download
# Fetch the Kaggle dataset through kagglehub and point DATA_ROOT at the
# "chest_xray" folder inside the returned download directory.
dataset_dir = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")
DATA_ROOT = os.path.join(dataset_dir, "chest_xray")
print("DATA_ROOT:", DATA_ROOT)

# Output directory for every figure saved by this notebook.
os.makedirs("plots", exist_ok=True)

# Silence all warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Seed the Python, PyTorch and NumPy random number generators with the same
# fixed value so sampling in later cells is repeatable.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Plot styling: the matplotlib style is applied first and the seaborn
# palette second, so the palette is not overridden by the style sheet.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


Downloading from https://www.kaggle.com/api/v1/datasets/download/paultimothymooney/chest-xray-pneumonia?dataset_version_number=2...


100%|██████████| 2.29G/2.29G [02:51<00:00, 14.4MB/s]


Extracting files...
DATA_ROOT: /Users/cholponzhakshylykova/.cache/kagglehub/datasets/paultimothymooney/chest-xray-pneumonia/versions/2/chest_xray


In [None]:
# Cell 3: Define ChestXrayEDA Class
class ChestXrayEDA:
    """Exploratory data analysis helper for the Chest X-ray Pneumonia dataset.

    The dataset root must contain one directory per split (``train``, ``val``,
    ``test``), each holding one directory per class (``NORMAL``,
    ``PNEUMONIA``). The layout is validated on construction. Figures produced
    by the plotting methods are saved under ``plots/`` (relative to the
    current working directory) and then shown.
    """

    # Lower-case extensions that count as image files everywhere in the class,
    # so counting, sampling and property analysis all look at the same files.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Directory that receives the saved figures.
    PLOTS_DIR = "plots"

    def __init__(self, data_root: str):
        """Store the dataset location and validate its directory layout.

        Args:
            data_root: Path of the directory containing the split folders.

        Raises:
            FileNotFoundError: If the root, a split or a class folder is missing.
        """
        self.data_root = Path(data_root)
        self.splits = ['train', 'val', 'test']
        self.classes = ['NORMAL', 'PNEUMONIA']
        self.dataset_stats = {}
        self._validate_dataset_structure()

    def _validate_dataset_structure(self):
        """Check that every ``<root>/<split>/<class>`` directory exists.

        Raises:
            FileNotFoundError: Naming the first missing directory.
        """
        if not self.data_root.exists():
            raise FileNotFoundError(f"Dataset root not found: {self.data_root}")

        for split in self.splits:
            split_path = self.data_root / split
            if not split_path.exists():
                raise FileNotFoundError(f"Split directory not found: {split_path}")

            for cls in self.classes:
                class_path = split_path / cls
                if not class_path.exists():
                    raise FileNotFoundError(f"Class directory not found: {class_path}")

    def _list_images(self, split: str, cls: str) -> list:
        """Return the image files of one split/class folder, sorted by path.

        Sorting makes seeded ``random.sample`` calls reproducible, because the
        order returned by ``iterdir`` depends on the file system.
        """
        class_path = self.data_root / split / cls
        return sorted(
            f for f in class_path.iterdir()
            if f.is_file() and f.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def _save_figure(self, fig, filename: str):
        """Save ``fig`` into the plots directory, creating it if necessary."""
        os.makedirs(self.PLOTS_DIR, exist_ok=True)
        fig.savefig(os.path.join(self.PLOTS_DIR, filename), dpi=300, bbox_inches='tight')

    @staticmethod
    def _percentage(part: int, whole: int) -> float:
        """Return ``part`` as a percentage of ``whole`` (0.0 when ``whole`` is 0)."""
        return (part / whole) * 100 if whole else 0.0

    @staticmethod
    def _plot_class_histogram(ax, series, labels, xlabel: str, title: str):
        """Draw one histogram per class on ``ax`` with shared styling."""
        ax.hist(series, bins=30, alpha=0.7, label=labels)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')
        ax.set_title(title)
        ax.legend()
        ax.grid(alpha=0.3)

    @staticmethod
    def _plot_class_boxplot(ax, series, labels, ylabel: str, title: str):
        """Draw one box per class on ``ax``.

        Tick labels are set on the axis rather than through ``boxplot``'s
        ``labels`` keyword, which is deprecated in recent Matplotlib releases.
        """
        ax.boxplot(series)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(alpha=0.3)

    def analyze_dataset_distribution(self) -> Dict:
        """Count images per split and class, print a report and return the counts.

        The result is also stored in ``self.dataset_stats``.

        Returns:
            ``{split: {'NORMAL': int, 'PNEUMONIA': int, 'total': int}}``.
        """
        stats = {}
        for split in self.splits:
            counts = {cls: len(self._list_images(split, cls)) for cls in self.classes}
            counts['total'] = sum(counts.values())
            stats[split] = counts

        self.dataset_stats = stats

        print("="*60)
        print("DATASET DISTRIBUTION ANALYSIS")
        print("="*60)

        for split in self.splits:
            split_total = stats[split]['total']
            print(f"\n{split.upper()} SET:")
            for cls in self.classes:
                count = stats[split][cls]
                # _percentage guards against an empty split (total == 0)
                percentage = self._percentage(count, split_total)
                print(f"  {cls:>10}: {count:>5} images ({percentage:.1f}%)")
            print(f"  {'TOTAL':>10}: {split_total:>5} images")

        # Overall statistics
        total_images = sum(stats[split]['total'] for split in self.splits)
        total_normal = sum(stats[split]['NORMAL'] for split in self.splits)
        total_pneumonia = sum(stats[split]['PNEUMONIA'] for split in self.splits)

        print("\nOVERALL DATASET:")
        print(f"  {'NORMAL':>10}: {total_normal:>5} images ({self._percentage(total_normal, total_images):.1f}%)")
        print(f"  {'PNEUMONIA':>10}: {total_pneumonia:>5} images ({self._percentage(total_pneumonia, total_images):.1f}%)")
        print(f"  {'TOTAL':>10}: {total_images:>5} images")

        # Class imbalance analysis: majority/minority ratio per split and overall.
        # The small epsilon keeps the division defined when a class is empty.
        imbalance_threshold = 1.2
        imbalance_info = []

        for split in self.splits:
            n_normal = stats[split]['NORMAL']
            n_pneumonia = stats[split]['PNEUMONIA']
            ratio = max(n_normal, n_pneumonia) / (min(n_normal, n_pneumonia) + 1e-9)
            imbalance_info.append((split, ratio))

        overall_ratio = max(total_normal, total_pneumonia) / (min(total_normal, total_pneumonia) + 1e-9)

        print("\nCLASS IMBALANCE ANALYSIS & RECOMMENDATION:")
        for split, ratio in imbalance_info:
            print(f"  {split.upper()} set imbalance ratio: {ratio:.2f} (max/min)")
        print(f"  OVERALL imbalance ratio: {overall_ratio:.2f} (max/min)")

        if overall_ratio > imbalance_threshold:
            print("\nRecommendation: There is a significant class imbalance.")
            print("It is recommended to use OVERSAMPLING (or class weighting) during model training to address this.")
        else:
            print("\nNo significant class imbalance detected.")

        return stats

    def visualize_distribution(self):
        """Plot the class distribution as a 2x2 figure and save it.

        Panels: stacked bars, grouped bars, overall pie chart and the
        pneumonia:normal ratio per split. Runs the distribution analysis first
        if it has not been run yet. Saved as ``plots/dataset_distribution.png``.
        """
        if not self.dataset_stats:
            self.analyze_dataset_distribution()

        splits = list(self.dataset_stats.keys())
        normal_counts = [self.dataset_stats[split]['NORMAL'] for split in splits]
        pneumonia_counts = [self.dataset_stats[split]['PNEUMONIA'] for split in splits]
        tick_names = [s.capitalize() for s in splits]

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        x = np.arange(len(splits))

        # Stacked bar chart
        ax = axes[0, 0]
        width = 0.6
        ax.bar(x, normal_counts, width, label='NORMAL', alpha=0.8)
        ax.bar(x, pneumonia_counts, width, bottom=normal_counts, label='PNEUMONIA', alpha=0.8)
        ax.set_xlabel('Dataset Split')
        ax.set_ylabel('Number of Images')
        ax.set_title('Dataset Distribution by Split (Stacked)')
        ax.set_xticks(x)
        ax.set_xticklabels(tick_names)
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        # Grouped bar chart
        ax = axes[0, 1]
        width = 0.35
        ax.bar(x - width/2, normal_counts, width, label='NORMAL', alpha=0.8)
        ax.bar(x + width/2, pneumonia_counts, width, label='PNEUMONIA', alpha=0.8)
        ax.set_xlabel('Dataset Split')
        ax.set_ylabel('Number of Images')
        ax.set_title('Class Distribution Comparison')
        ax.set_xticks(x)
        ax.set_xticklabels(tick_names)
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        # Overall pie chart (skipped for an empty dataset: a pie of zeros is undefined)
        ax = axes[1, 0]
        total_normal = sum(normal_counts)
        total_pneumonia = sum(pneumonia_counts)
        if total_normal + total_pneumonia > 0:
            ax.pie([total_normal, total_pneumonia], labels=['NORMAL', 'PNEUMONIA'],
                   autopct='%1.1f%%', startangle=90, colors=['lightblue', 'lightcoral'])
        else:
            ax.text(0.5, 0.5, 'No images found', ha='center', va='center',
                    transform=ax.transAxes)
            ax.axis('off')
        ax.set_title('Overall Class Distribution')

        # Imbalance ratio chart
        ax = axes[1, 1]
        imbalance_ratios = []
        split_labels = []

        for split in splits:
            normal = self.dataset_stats[split]['NORMAL']
            pneumonia = self.dataset_stats[split]['PNEUMONIA']
            ratio = pneumonia / normal if normal > 0 else 0
            imbalance_ratios.append(ratio)
            split_labels.append(f"{split.capitalize()}\n({pneumonia}:{normal})")

        bars = ax.bar(range(len(splits)), imbalance_ratios, alpha=0.8)
        ax.set_xlabel('Dataset Split')
        ax.set_ylabel('Pneumonia:Normal Ratio')
        ax.set_title('Class Imbalance by Split')
        ax.set_xticks(range(len(splits)))
        ax.set_xticklabels(split_labels)
        ax.axhline(y=1, color='red', linestyle='--', alpha=0.7, label='Balanced')
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        # Add value labels on bars
        for bar, ratio in zip(bars, imbalance_ratios):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                    f'{ratio:.2f}', ha='center', va='bottom')

        plt.tight_layout()
        self._save_figure(fig, "dataset_distribution.png")
        plt.show()

    def sample_images_visualization(self, n_samples: int = 8):
        """Show up to ``n_samples`` random training images per class.

        One row per class; files that cannot be opened are replaced by an error
        note in their cell. Saved as ``plots/sample_images.png``.

        Args:
            n_samples: Number of columns, i.e. images requested per class.

        Raises:
            ValueError: If ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")

        # squeeze=False keeps ``axes`` two-dimensional even for a single column
        fig, axes = plt.subplots(len(self.classes), n_samples, figsize=(20, 8), squeeze=False)

        # Hide every axis up front so cells without an image stay blank
        for ax in axes.flat:
            ax.axis('off')

        for class_idx, class_name in enumerate(self.classes):
            image_files = self._list_images('train', class_name)
            sampled_files = random.sample(image_files, min(n_samples, len(image_files)))

            for img_idx, img_path in enumerate(sampled_files):
                ax = axes[class_idx, img_idx]
                try:
                    with Image.open(img_path) as src:
                        img = src.convert('L')
                except Exception:
                    # Best effort: a broken file must not abort the whole figure
                    ax.text(0.5, 0.5, f'Error loading\n{img_path.name}',
                            ha='center', va='center', transform=ax.transAxes)
                    continue
                ax.imshow(img, cmap='gray')
                ax.set_title(f'{class_name}\n{img_path.name}', fontsize=10)

        plt.suptitle('Sample Images from Each Class', fontsize=16)
        plt.tight_layout()
        self._save_figure(fig, "sample_images.png")
        plt.show()

    def analyze_image_properties(self, sample_size: int = 100):
        """Collect width, height and file size for a random sample of images.

        Up to ``sample_size`` images are sampled from every split/class folder.
        Summary statistics are printed; files that cannot be read are skipped
        and their number is reported.

        Args:
            sample_size: Maximum number of images inspected per split and class.

        Returns:
            ``{split: {class: {'widths': [...], 'heights': [...],
            'file_sizes': [...]}}}`` with sizes in pixels and bytes. The three
            lists of one entry always have equal length.
        """
        print("\n" + "="*60)
        print("IMAGE PROPERTIES ANALYSIS")
        print("="*60)

        properties = {split: {cls: {'widths': [], 'heights': [], 'file_sizes': []}
                              for cls in self.classes} for split in self.splits}
        skipped = 0

        for split in self.splits:
            for cls in self.classes:
                image_files = self._list_images(split, cls)

                # Sample random images for analysis
                sample_files = random.sample(image_files, min(sample_size, len(image_files)))
                bucket = properties[split][cls]

                for img_path in sample_files:
                    try:
                        with Image.open(img_path) as img:
                            width, height = img.size
                        file_size = img_path.stat().st_size
                    except Exception:
                        skipped += 1
                        continue
                    # Append only after everything succeeded so the lists stay aligned
                    bucket['widths'].append(width)
                    bucket['heights'].append(height)
                    bucket['file_sizes'].append(file_size)

        # Display statistics
        for split in self.splits:
            print(f"\n{split.upper()} SET:")
            for cls in self.classes:
                widths = properties[split][cls]['widths']
                heights = properties[split][cls]['heights']
                file_sizes = properties[split][cls]['file_sizes']

                if widths:  # Only if we have data
                    print(f"  {cls}:")
                    print(f"    Width:  {np.mean(widths):.0f} ± {np.std(widths):.0f} px (range: {np.min(widths)}-{np.max(widths)})")
                    print(f"    Height: {np.mean(heights):.0f} ± {np.std(heights):.0f} px (range: {np.min(heights)}-{np.max(heights)})")
                    print(f"    Size:   {np.mean(file_sizes)/1024:.1f} ± {np.std(file_sizes)/1024:.1f} KB")

        if skipped:
            print(f"\n  Skipped {skipped} unreadable image file(s).")

        return properties

    def plot_image_properties(self, properties: Dict):
        """Plot the distributions of the sampled image properties.

        Produces a 2x3 figure: histograms of width, height, file size and
        aspect ratio, plus box plots of width and height, each split by class.
        Saved as ``plots/image_properties.png``.

        Args:
            properties: The structure returned by ``analyze_image_properties``.
        """
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        labels = list(self.classes)

        # Pool the per-split samples so each class has one list per property
        all_widths = {cls: [] for cls in labels}
        all_heights = {cls: [] for cls in labels}
        all_file_sizes = {cls: [] for cls in labels}

        for split in self.splits:
            for cls in labels:
                all_widths[cls].extend(properties[split][cls]['widths'])
                all_heights[cls].extend(properties[split][cls]['heights'])
                all_file_sizes[cls].extend(properties[split][cls]['file_sizes'])

        width_data = [all_widths[cls] for cls in labels]
        height_data = [all_heights[cls] for cls in labels]
        # File sizes are collected in bytes and plotted in KB
        size_data = [np.array(all_file_sizes[cls]) / 1024 for cls in labels]
        # Aspect ratio (width / height); zero-height entries are ignored
        ratio_data = [
            [w / h for w, h in zip(all_widths[cls], all_heights[cls]) if h > 0]
            for cls in labels
        ]

        self._plot_class_histogram(axes[0, 0], width_data, labels,
                                   'Width (pixels)', 'Image Width Distribution')
        self._plot_class_histogram(axes[0, 1], height_data, labels,
                                   'Height (pixels)', 'Image Height Distribution')
        self._plot_class_histogram(axes[0, 2], size_data, labels,
                                   'File Size (KB)', 'File Size Distribution')
        self._plot_class_histogram(axes[1, 0], ratio_data, labels,
                                   'Aspect Ratio (W/H)', 'Aspect Ratio Distribution')

        # Box plots for dimensions
        self._plot_class_boxplot(axes[1, 1], width_data, labels,
                                 'Width (pixels)', 'Width Distribution (Box Plot)')
        self._plot_class_boxplot(axes[1, 2], height_data, labels,
                                 'Height (pixels)', 'Height Distribution (Box Plot)')

        plt.tight_layout()
        self._save_figure(fig, "image_properties.png")
        plt.show()

# Cell feedback only: confirms the class definition above executed without error.
print("✅ ChestXrayEDA class defined successfully!")

In [None]:
# Cell 4: Initialize EDA Object and Validate Dataset
print("🫁 CHEST X-RAY PNEUMONIA DETECTION - EDA & PREPROCESSING")
print("=" * 70)
# Initialize EDA object. The constructor checks the split/class directory
# layout and raises FileNotFoundError if a directory is missing, so the
# success message below is only reached when validation passed.
eda = ChestXrayEDA(DATA_ROOT)
print("✅ Dataset structure validated successfully!")

In [None]:
# Cell 5: Analyze Dataset Distribution
# Prints the per-split / per-class image counts and keeps the returned counts
# in `dataset_stats` for the preprocessing recommendations cell.
print("\n📊 ANALYZING DATASET DISTRIBUTION...")
dataset_stats = eda.analyze_dataset_distribution()

In [None]:
# Cell 6: Visualize Dataset Distribution
# Draws the 2x2 distribution figure and saves it as plots/dataset_distribution.png.
print("\n🎨 CREATING DISTRIBUTION VISUALIZATIONS...")
eda.visualize_distribution()


In [None]:
# Cell 7: Display Sample Images
# Shows 6 randomly sampled training images per class and saves the grid as
# plots/sample_images.png.
print("\n🖼️  DISPLAYING SAMPLE IMAGES...")
eda.sample_images_visualization(n_samples=6)

In [None]:
# Cell 8: Analyze Image Properties
# Samples up to 150 images per split and class; the returned widths, heights
# and file sizes are kept in `properties` for the plotting cell.
print("\n📏 ANALYZING IMAGE PROPERTIES...")
properties = eda.analyze_image_properties(sample_size=150)

In [None]:
# Cell 9: Plot Image Properties
# Plots the sampled properties from the previous cell and saves the figure as
# plots/image_properties.png.
print("\n📈 PLOTTING IMAGE PROPERTIES...")
eda.plot_image_properties(properties)

In [None]:
# Cell 10: Data Augmentation Visualization
def visualize_augmentations(data_root: str, n_augmentations: int = 5):
    """Show one random training image per class next to augmented versions of it.

    The figure has one row per class (NORMAL, PNEUMONIA): the first column is
    the original grayscale image, the remaining ``n_augmentations`` columns are
    independent random draws from the augmentation pipeline. The figure is
    saved as ``plots/data_augmentation.png`` and then shown.

    Args:
        data_root: Dataset root containing ``train/NORMAL`` and ``train/PNEUMONIA``.
        n_augmentations: Number of augmented variants shown per class.

    Raises:
        ValueError: If ``n_augmentations`` is negative.
        FileNotFoundError: If a class folder contains no image files.
    """
    print("\n🔄 DEMONSTRATING DATA AUGMENTATION...")

    if n_augmentations < 0:
        raise ValueError("n_augmentations must be non-negative")

    # Define augmentation transforms
    transform_augment = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
    ])

    image_suffixes = ('.jpg', '.jpeg', '.png')

    # squeeze=False keeps ``axes`` two-dimensional even when n_augmentations == 0
    fig, axes = plt.subplots(2, n_augmentations + 1, figsize=(20, 8), squeeze=False)

    for class_idx, class_name in enumerate(['NORMAL', 'PNEUMONIA']):
        sample_path = Path(data_root) / 'train' / class_name
        # Case-insensitive extension match; sorted so a seeded random.choice
        # picks the same file regardless of file-system listing order.
        sample_files = sorted(
            f for f in sample_path.iterdir()
            if f.is_file() and f.suffix.lower() in image_suffixes
        )
        if not sample_files:
            raise FileNotFoundError(f"No image files found in {sample_path}")
        sample_img_path = random.choice(sample_files)

        # Load original image as 8-bit grayscale
        with Image.open(sample_img_path) as src:
            original_img = src.convert('L')

        # Show original. A fixed 0-255 range keeps the display comparable with
        # the augmented panels instead of stretching contrast per image.
        axes[class_idx, 0].imshow(original_img, cmap='gray', vmin=0, vmax=255)
        axes[class_idx, 0].set_title(f'Original\n{class_name}')
        axes[class_idx, 0].axis('off')

        # Show augmented versions
        for i in range(n_augmentations):
            augmented = transform_augment(original_img)
            # ToTensor yields values in [0, 1]; pinning vmin/vmax stops imshow
            # from auto-normalising each panel, which would visually cancel
            # the brightness/contrast jitter being demonstrated.
            axes[class_idx, i + 1].imshow(augmented.squeeze(0).numpy(), cmap='gray',
                                          vmin=0.0, vmax=1.0)
            axes[class_idx, i + 1].set_title(f'Augmented {i + 1}')
            axes[class_idx, i + 1].axis('off')

    plt.suptitle('Data Augmentation Examples', fontsize=16)
    plt.tight_layout()
    # Make sure the output directory exists even if the setup cell was skipped
    os.makedirs("plots", exist_ok=True)
    fig.savefig(os.path.join("plots", "data_augmentation.png"), dpi=300, bbox_inches='tight')
    plt.show()

# Apply augmentation visualization: one original plus 5 augmented variants
# per class, drawn from the downloaded dataset's train split.
visualize_augmentations(DATA_ROOT, n_augmentations=5)

In [None]:
# Cell 11: Preprocessing Recommendations
def preprocessing_recommendations(dataset_stats: Dict, imbalance_threshold: float = 1.5):
    """Print preprocessing recommendations derived from the dataset counts.

    Args:
        dataset_stats: Mapping ``{split: {'NORMAL': int, 'PNEUMONIA': int,
            'total': int}}`` that must contain the splits ``train``, ``val``
            and ``test``.
        imbalance_threshold: Majority/minority ratio above which the classes
            are reported as imbalanced.

    Raises:
        KeyError: If a required split or count is missing from ``dataset_stats``.
    """
    print("\n💡 PREPROCESSING RECOMMENDATIONS")
    print("=" * 60)

    # Calculate class imbalance over all splits
    total_normal = sum(counts['NORMAL'] for counts in dataset_stats.values())
    total_pneumonia = sum(counts['PNEUMONIA'] for counts in dataset_stats.values())
    majority = max(total_normal, total_pneumonia)
    minority = min(total_normal, total_pneumonia)

    if minority > 0:
        imbalance_ratio = majority / minority
    elif majority > 0:
        # One class is completely missing: the most extreme imbalance possible
        imbalance_ratio = float('inf')
    else:
        # No images at all: nothing to compare, treat as balanced
        imbalance_ratio = 1.0

    print("1. CLASS IMBALANCE:")
    if imbalance_ratio > imbalance_threshold:
        print("   ⚠️  Significant class imbalance detected")
        print("   📋 Recommendations:")
        print("      - Use a class-weighted loss function or a WeightedRandomSampler")
        print("      - Apply SMOTE or oversampling techniques")
        print("      - Consider focal loss for training")
    else:
        print("   ✅ Classes are relatively balanced")

    print("\n2. DATA AUGMENTATION:")
    print("   📋 Recommended augmentations:")
    print("      - Random rotation (±15°)")
    print("      - Random horizontal flip")
    print("      - Random brightness/contrast adjustment")
    print("      - Random zoom/scaling")
    print("      - Gaussian noise (medical imaging)")

    print("\n3. PREPROCESSING PIPELINE:")
    print("   📋 Recommended steps:")
    print("      - Resize to consistent dimensions (224x224 or 256x256)")
    print("      - Normalize pixel values to [0,1] or [-1,1]")
    print("      - Apply histogram equalization if needed")
    print("      - Consider lung segmentation for better focus")

    print("\n4. TRAIN/VALIDATION/TEST SPLIT:")
    train_total = dataset_stats['train']['total']
    val_total = dataset_stats['val']['total']
    test_total = dataset_stats['test']['total']
    total_images = train_total + val_total + test_total

    if total_images > 0:
        val_fraction = val_total / total_images
        print(f"   Current split: {train_total/total_images:.1%} / {val_fraction:.1%} / {test_total/total_images:.1%}")

        if val_fraction < 0.15:
            print("   ⚠️  Validation set might be too small")
            print("   📋 Consider stratified sampling for better validation")
    else:
        # Proportions are undefined for an empty dataset
        print("   ⚠️  No images found; split proportions are unavailable")

    print("\n5. MODEL CONSIDERATIONS:")
    print("   📋 Recommendations:")
    print("      - Use transfer learning (ResNet, DenseNet, EfficientNet)")
    print("      - Apply gradual unfreezing")
    print("      - Use appropriate metrics (AUC, F1-score, sensitivity, specificity)")
    print("      - Implement early stopping and learning rate scheduling")

# Generate recommendations from the counts returned by
# eda.analyze_dataset_distribution() in the distribution cell.
preprocessing_recommendations(dataset_stats)


In [None]:
# Cell 12: Summary and Next Steps
# The closing report is kept as data and printed line by line, in order.
_summary_report = (
    "\n🎯 SUMMARY AND NEXT STEPS",
    "=" * 60,
    # What this notebook did
    "✅ COMPLETED TASKS:",
    "   • Dataset structure validation",
    "   • Distribution analysis across splits and classes",
    "   • Image properties analysis",
    "   • Sample visualization",
    "   • Data augmentation demonstration",
    "   • Preprocessing recommendations",
    # Figures written by the plotting cells
    "\n📁 GENERATED FILES:",
    "   • plots/dataset_distribution.png",
    "   • plots/sample_images.png",
    "   • plots/image_properties.png",
    "   • plots/data_augmentation.png",
    # Suggested follow-up work
    "\n🚀 NEXT STEPS:",
    "   1. Implement data preprocessing pipeline",
    "   2. Create balanced data loaders",
    "   3. Design CNN architecture or use transfer learning",
    "   4. Implement training loop with proper validation",
    "   5. Evaluate model performance with medical metrics",
    "   6. Deploy model with proper validation and monitoring",
    "\n✅ EDA COMPLETED SUCCESSFULLY!",
)

for _report_line in _summary_report:
    print(_report_line)