# Data Exploration - Diabetic Retinopathy Dataset

This notebook explores the Diabetic Retinopathy dataset structure, class distribution, and sample images.


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
import os

# Notebook-wide plotting defaults: seaborn "whitegrid" style and a 12x8 default figure size
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 8)})


In [None]:
# Locations of the raw dataset: the label CSV and the per-class image tree
base_path = Path('../data/raw/DiabeticRetinopathyDataset')
train_csv = base_path.joinpath('train.csv')
images_path = base_path.joinpath('gaussian_filtered_images', 'gaussian_filtered_images')

# Read the labels, then report the row count, a preview, and the per-diagnosis tally
df = pd.read_csv(train_csv)
print(f"Total samples in CSV: {len(df)}")
print("\nFirst few rows:", df.head(), sep="\n")
print("\nDiagnosis distribution:")
print(df['diagnosis'].value_counts().sort_index())


In [None]:
# Diagnosis label -> name of the image sub-directory holding that class
class_mapping = {
    0: 'No_DR',
    1: 'Mild',
    2: 'Moderate',
    3: 'Severe',
    4: 'Proliferate_DR'
}

# Tally the PNG files under each class directory.
# Directories that do not exist are skipped and get no entry in class_counts.
class_counts = {}
for label_dir in (images_path / name for name in class_mapping.values()):
    if not label_dir.exists():
        continue
    n_png = sum(1 for _ in label_dir.glob('*.png'))
    class_counts[label_dir.name] = n_png
    print(f"{label_dir.name}: {n_png} images")

total_images = sum(class_counts.values())
print(f"\nTotal images: {total_images}")


In [None]:
# Visualize class distribution: absolute counts (bar chart) next to relative share (pie chart)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Bar plot
classes = list(class_counts.keys())
counts = list(class_counts.values())
# Colour by class name rather than by bar position, so every severity level keeps
# its colour even if a class directory was missing and therefore skipped when counting.
severity_colors = {
    'No_DR': 'green',
    'Mild': 'yellow',
    'Moderate': 'orange',
    'Severe': 'red',
    'Proliferate_DR': 'darkred',
}
bar_colors = [severity_colors.get(name, 'gray') for name in classes]
ax1.bar(classes, counts, color=bar_colors)
ax1.set_title('Class Distribution', fontsize=14, fontweight='bold')
ax1.set_xlabel('Class', fontsize=12)
ax1.set_ylabel('Number of Images', fontsize=12)
ax1.tick_params(axis='x', rotation=45)

# Pie chart
ax2.pie(counts, labels=classes, autopct='%1.1f%%', startangle=90)
ax2.set_title('Class Distribution (Percentage)', fontsize=14, fontweight='bold')

plt.tight_layout()
# savefig does not create missing directories, so make sure the target folder exists first.
output_path = Path('../docs/class_distribution.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# Display one sample image from each class on a 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# class_mapping maps label -> directory name, so items() yields (label, name) in that order.
for ax, (class_label, class_name) in zip(axes, class_mapping.items()):
    class_dir = images_path / class_name
    if not class_dir.exists():
        continue
    # Sort so the chosen sample is the same on every run (glob order is not guaranteed).
    images = sorted(class_dir.glob('*.png'))
    if not images:
        continue
    # imshow copies the pixel data, so the file handle can be released right away.
    with Image.open(images[0]) as img:
        ax.imshow(img, cmap='gray')
    ax.set_title(f"{class_name} (Label: {class_label})", fontsize=12, fontweight='bold')

# Hide the axis frame on every panel, including the spare sixth one and any
# panel whose class directory was missing or empty.
for ax in axes:
    ax.axis('off')

plt.suptitle('Sample Images from Each Class', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
# savefig does not create missing directories, so make sure the target folder exists first.
output_path = Path('../docs/sample_images.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# Check image dimensions and colour mode on one No_DR image
# Sort so the inspected file is the same on every run; fall back to None when the
# directory is missing or empty instead of failing with a bare IndexError.
sample_img_path = next(iter(sorted((images_path / 'No_DR').glob('*.png'))), None)
if sample_img_path is None:
    print(f"No PNG images found in {images_path / 'No_DR'}")
else:
    # The context manager closes the file handle; size and mode are read inside it.
    with Image.open(sample_img_path) as sample_img:
        print(f"Image size: {sample_img.size}")
        print(f"Image mode: {sample_img.mode}")

# Check a few more images (first image of each of the first three classes)
for class_name in list(class_mapping.values())[:3]:
    class_dir = images_path / class_name
    if class_dir.exists():
        images = sorted(class_dir.glob('*.png'))
        if images:
            with Image.open(images[0]) as img:
                print(f"{class_name}: {img.size}, {img.mode}")


In [None]:
# Class imbalance analysis: per-class share of the total plus the largest/smallest ratio
print("Class Imbalance Analysis:")
print("=" * 50)
if total_images == 0:
    # Covers both an empty class_counts and directories that hold no PNGs; without
    # this guard the percentage division and max()/min() below would raise.
    print("No images were counted - check images_path before analysing class balance.")
else:
    for class_name, count in class_counts.items():
        percentage = (count / total_images) * 100
        print(f"{class_name:15s}: {count:4d} images ({percentage:5.2f}%)")

    # Look up the extremes once instead of recomputing max()/min() for every print.
    most_common = max(class_counts, key=class_counts.get)
    least_common = min(class_counts, key=class_counts.get)
    most_count = class_counts[most_common]
    least_count = class_counts[least_common]

    print(f"\n{'='*50}")
    print(f"Most common class: {most_common} ({most_count} images)")
    print(f"Least common class: {least_common} ({least_count} images)")
    if least_count == 0:
        # A class directory that exists but is empty would make the ratio divide by zero.
        print("Imbalance ratio: undefined (least common class has 0 images)")
    else:
        print(f"Imbalance ratio: {most_count / least_count:.2f}:1")


## Summary

- Dataset contains 5 classes with significant class imbalance
- Images are 224x224 pixels, grayscale (confirm against the size and mode printed by the image-dimension check above)
- No_DR class is the most common, Severe is the least common
- Data augmentation will be crucial to handle class imbalance
