In [2]:
# Locations of the three dataset splits, relative to the notebook's working directory.
TRAIN_DATA_DIR, VAL_DATA_DIR, TEST_DATA_DIR = (
    "chest_xray/train",
    "chest_xray/val",
    "chest_xray/test",
)

In [3]:
# Split name -> directory, in the order the splits are reported and plotted.
data_dirs = dict(
    train=TRAIN_DATA_DIR,
    val=VAL_DATA_DIR,
    test=TEST_DATA_DIR,
)

In [4]:
def print_dataset_stats(data_dirs):
    """Print per-split and overall image counts for the two classes.

    Parameters
    ----------
    data_dirs : dict
        Maps a split name (e.g. ``'train'``) to a directory that contains
        ``NORMAL`` and ``PNEUMONIA`` sub-directories.

    Returns
    -------
    None
        The report is written to standard output.

    Raises
    ------
    FileNotFoundError
        Propagated from ``os.listdir`` when a class directory does not exist.

    Notes
    -----
    Only regular, non-hidden files are counted, so stray entries such as
    ``.DS_Store`` or nested folders do not inflate the totals. When a split
    has no NORMAL images the ratio is reported as ``N/A`` instead of raising
    ``ZeroDivisionError``.
    """
    def count_files(class_dir):
        # Hidden entries and sub-directories are not X-ray images.
        return sum(
            1
            for entry in os.listdir(class_dir)
            if not entry.startswith('.')
            and os.path.isfile(os.path.join(class_dir, entry))
        )

    total_normal = 0
    total_pneumonia = 0

    print("Dataset Distribution:")
    print("-" * 50)

    for dataset_type, dir_path in data_dirs.items():
        normal_count = count_files(os.path.join(dir_path, 'NORMAL'))
        pneumonia_count = count_files(os.path.join(dir_path, 'PNEUMONIA'))

        total_normal += normal_count
        total_pneumonia += pneumonia_count

        # The ratio is undefined for an empty NORMAL folder.
        if normal_count:
            ratio = f"{pneumonia_count / normal_count:.2f}"
        else:
            ratio = "N/A"

        print(f"{dataset_type.capitalize()} set:")
        print(f"  Normal X-Rays: {normal_count}")
        print(f"  Pneumonia X-Rays: {pneumonia_count}")
        print(f"  Ratio (Pneumonia/Normal): {ratio}")
        print()

    print("Overall Statistics:")
    print(f"Total Normal cases: {total_normal}")
    print(f"Total Pneumonia cases: {total_pneumonia}")
    print(f"Total images: {total_normal + total_pneumonia}")

In [5]:
def plot_sample_images(data_dirs, samples_per_class=5, seed=None):
    """Show a grid of randomly chosen X-rays for every dataset split.

    One figure is drawn per split: the top row holds NORMAL samples and the
    bottom row holds PNEUMONIA samples.

    Parameters
    ----------
    data_dirs : dict
        Maps a split name (e.g. ``'train'``) to a directory that contains
        ``NORMAL`` and ``PNEUMONIA`` sub-directories.
    samples_per_class : int, default 5
        Number of columns in the grid. If a class has fewer files than this,
        all of its files are shown and the remaining cells are left blank.
    seed : int or None, default None
        When given, sampling uses a dedicated ``np.random.default_rng(seed)``
        so the same images are picked on every run. When ``None``, the global
        NumPy random state is used, as before.

    Raises
    ------
    ValueError
        If ``samples_per_class`` is smaller than 1.
    FileNotFoundError
        Propagated from ``os.listdir`` when a class directory does not exist.
    """
    if samples_per_class < 1:
        raise ValueError("samples_per_class must be at least 1")

    # The bare 'seaborn' style was renamed to 'seaborn-v0_8' in newer
    # Matplotlib releases; pick whichever name this installation provides.
    style_name = next(
        (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
        None,
    )
    if style_name is not None:
        plt.style.use(style_name)

    rng = np.random if seed is None else np.random.default_rng(seed)

    def pick_samples(class_dir):
        # Skip hidden entries (e.g. '.DS_Store') and sub-directories, which
        # cannot be opened as images. Sorting makes seeded sampling
        # independent of the file system's listing order.
        candidates = sorted(
            entry
            for entry in os.listdir(class_dir)
            if not entry.startswith('.')
            and os.path.isfile(os.path.join(class_dir, entry))
        )
        # Sampling without replacement cannot draw more items than exist.
        n_pick = min(samples_per_class, len(candidates))
        if n_pick == 0:
            return []
        return list(rng.choice(candidates, size=n_pick, replace=False))

    for dataset_type, dir_path in data_dirs.items():
        normal_path = os.path.join(dir_path, 'NORMAL')
        pneumonia_path = os.path.join(dir_path, 'PNEUMONIA')

        rows = (
            ('Normal', normal_path, pick_samples(normal_path)),
            ('Pneumonia', pneumonia_path, pick_samples(pneumonia_path)),
        )

        # squeeze=False keeps `axes` two-dimensional even for a single column.
        fig, axes = plt.subplots(2, samples_per_class, figsize=(20, 8), squeeze=False)
        fig.suptitle(f'Sample Images from {dataset_type.capitalize()} Set', fontsize=16)

        for row_idx, (label, class_dir, file_names) in enumerate(rows):
            for col_idx, ax in enumerate(axes[row_idx]):
                if col_idx < len(file_names):
                    img_path = os.path.join(class_dir, file_names[col_idx])
                    # Close the file handle as soon as the pixels are drawn.
                    with Image.open(img_path) as img:
                        ax.imshow(img, cmap='gray')
                    ax.set_title(label)
                # Hide ticks and frame, including on cells that stay empty.
                ax.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# Text summary first: class counts and imbalance ratio for every split.
print_dataset_stats(data_dirs=data_dirs)

# Then a visual check: one figure of random samples per split.
plot_sample_images(data_dirs=data_dirs)