In [None]:
import sys
from pathlib import Path

# Add project root to path so the `utils` and `data` imports below resolve
# (presumably project-local modules living under the project root).
# NOTE(review): Path.cwd().parent assumes the kernel's working directory is a
# direct child of the project root (e.g. a notebooks/ folder) — confirm.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

from utils import load_config
from data import MIMICCXRDataset, MIMICCXRDataModule

## 1. Load Configuration and Dataset

In [None]:
# Read the base configuration file located at <project_root>/configs/
config = load_config(project_root.joinpath("configs", "base_config.yaml"))
# Echo the configured data directory so a wrong path is spotted early
print(f"Data directory: {config['data']['data_dir']}")

In [None]:
# Initialise data module
# Note: Update data_dir (read from configs/base_config.yaml above) to your
# actual MIMIC-CXR path before enabling the data module below.
DATA_DIR = config["data"]["data_dir"]

# The data module setup is intentionally left commented out; uncomment it once
# the dataset is available at DATA_DIR. The later commented-out calls
# (plot_sample_images / analyse_reports) read `datamodule.train_dataset`.
# datamodule = MIMICCXRDataModule(
#     data_dir=DATA_DIR,
#     batch_size=4,
#     max_samples=100,  # Limit for exploration
# )
# datamodule.setup(stage="fit")

## 2. Visualise Sample Images

In [None]:
# Plot sample images
def plot_sample_images(dataset, num_samples=4):
    """Display the first few images of the dataset side by side.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items are
            mappings with an ``"image"`` entry.
        num_samples: Maximum number of images to display (default 4). Fewer
            panels are drawn when the dataset holds fewer samples.

    Returns:
        None. The figure is rendered via ``plt.show()``.
    """
    # Draw only as many panels as there are samples, so no blank axes remain.
    num_shown = min(num_samples, len(dataset))
    if num_shown <= 0:
        print("No samples available to plot.")
        return

    # squeeze=False always returns a 2-D array of Axes; without it a single
    # panel comes back as a bare Axes object that cannot be iterated.
    fig, axes = plt.subplots(
        1, num_shown, figsize=(4 * num_shown, 4), squeeze=False
    )

    for i, ax in enumerate(axes.ravel()):
        image = dataset[i]["image"]

        # Convert tensor to displayable format
        if hasattr(image, "numpy"):
            # assumes 3-D tensors are channel-first (C, H, W) — TODO confirm;
            # 2-D tensors are already (H, W) and must not be permuted.
            if getattr(image, "ndim", None) == 3:
                image = image.permute(1, 2, 0)
            image = image.numpy()

        ax.imshow(image, cmap="gray")
        ax.set_title(f"Sample {i}")
        ax.axis("off")

    fig.tight_layout()
    plt.show()

# Uncomment when data is available:
# plot_sample_images(datamodule.train_dataset)

## 3. Analyse Report Statistics

In [None]:
def analyse_reports(dataset, max_samples=1000):
    """Print word-count statistics for the reports and plot their histogram.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items expose
            ``.get("report")`` returning the report text (or nothing).
        max_samples: Upper bound on the number of samples inspected
            (default 1000, the previously hard-coded limit).

    Returns:
        None. Statistics are printed and the histogram is shown.
    """
    num_inspected = min(len(dataset), max_samples)

    # A missing or None report counts as an empty report (zero words).
    report_lengths = [
        len((dataset[i].get("report") or "").split())
        for i in range(num_inspected)
    ]

    # np.min / np.max raise ValueError on empty input, so stop early.
    if not report_lengths:
        print("No reports available to analyse.")
        return

    print("Report length statistics (words):")
    print(f"  Mean: {np.mean(report_lengths):.1f}")
    print(f"  Std:  {np.std(report_lengths):.1f}")
    print(f"  Min:  {np.min(report_lengths)}")
    print(f"  Max:  {np.max(report_lengths)}")

    # Plot histogram
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.hist(report_lengths, bins=50, edgecolor="black")
    ax.set_xlabel("Report Length (words)")
    ax.set_ylabel("Frequency")
    ax.set_title("Distribution of Report Lengths")
    plt.show()

# Uncomment when data is available:
# analyse_reports(datamodule.train_dataset)

## 4. Next Steps

Once data is loaded:
- Examine finding distributions
- Check image quality and resolution
- Identify potential data issues