# 02 — Data Preparation: Pipeline Validation

**Goal**: Validate the data pipeline end-to-end before training.

This notebook:
1. Checks for downloaded galaxy cutouts (if they are missing, download them first with `python scripts/download_images.py`)
2. Loads the stratified splits
3. Instantiates the `EuclidDataset` with transforms
4. Verifies image loading, target shapes, and mask correctness
5. Visualizes sample images with their morphology labels
6. Tests a DataLoader batch for training readiness

## 1. Setup

In [None]:
import sys
from pathlib import Path

# Add project root to path so we can import src/.
# NOTE(review): assumes the notebook runs from a directory one level below the
# project root (e.g. notebooks/) — confirm if launched from elsewhere.
PROJECT_ROOT = Path('..').resolve()
# Guard the insert so re-running this cell does not stack duplicate entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.data.dataset import EuclidDataset, MorphologySchema
from src.data.transforms import get_transforms, get_tta_transforms

import warnings
# NOTE(review): this silences *all* warnings, which can hide real pipeline
# problems in a validation notebook; consider narrowing to specific categories.
warnings.filterwarnings('ignore')

plt.rcParams['figure.dpi'] = 120

# Seed torch so random augmentations and DataLoader shuffling are reproducible
# under Restart & Run All (the sample-selection cell seeds its own NumPy RNG).
SEED = 42
torch.manual_seed(SEED)

CATALOG_PATH = PROJECT_ROOT / 'data' / 'raw' / 'morphology_catalogue.parquet'
SPLIT_PATH = PROJECT_ROOT / 'data' / 'processed' / 'split_indices.json'
IMAGE_DIR = PROJECT_ROOT / 'data' / 'raw' / 'images'
FIGURES_DIR = PROJECT_ROOT / 'results' / 'figures'
# Later cells save figures here; savefig does not create missing directories.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

print(f'Project root: {PROJECT_ROOT}')
print(f'Catalog exists: {CATALOG_PATH.exists()}')
print(f'Splits exist: {SPLIT_PATH.exists()}')
print(f'Images exist: {IMAGE_DIR.exists()}')

## 2. Verify Split Indices

In [None]:
# Load the persisted split indices together with their metadata block.
with open(SPLIT_PATH) as f:
    split_data = json.load(f)
meta = split_data['meta']

n_total = meta['total_galaxies']
print(f"Total galaxies: {n_total:,}")
print(f"Seed: {meta['seed']}")
print(f"Training cuts applied: {meta['apply_training_cuts']}")
print("\nSplit sizes:")
for split, size in meta['split_sizes'].items():
    share_pct = size / n_total * 100
    print(f"  {split:6s}: {size:>8,} ({share_pct:.1f}%)")

# Each split's IDs as a set; the three sets must be pairwise disjoint.
train_set, val_set, test_set = (set(split_data[key]) for key in ('train', 'val', 'test'))
pairwise_checks = (
    (train_set, val_set, 'Train/Val overlap!'),
    (train_set, test_set, 'Train/Test overlap!'),
    (val_set, test_set, 'Val/Test overlap!'),
)
for left, right, message in pairwise_checks:
    assert left.isdisjoint(right), message
print('\nNo overlap between splits (verified)')

## 3. Load Catalog Subsets

In [None]:
# Materialize the catalog rows belonging to each split, in train/val/test order.
split_frames = {
    name: EuclidDataset.load_split(CATALOG_PATH, SPLIT_PATH, name)
    for name in ('train', 'val', 'test')
}
train_df, val_df, test_df = (split_frames[name] for name in ('train', 'val', 'test'))

for label, frame in (('Train: ', train_df), ('Val:   ', val_df), ('Test:  ', test_df)):
    print(f'{label}{len(frame):,}')

## 4. Morphology Schema

In [None]:
schema = MorphologySchema.default()

print(f'Number of output values: {schema.num_outputs}')

# Each question owns a contiguous [start:end) slice of the flat output vector.
print('\nQuestions and index slices:')
for question, bounds in schema.question_slices.items():
    start, end = bounds
    print(f'  [{start:2d}:{end:2d}] {question}: {schema.questions[question]}')

print('\nFull column list:')
for position, column_name in enumerate(schema.columns):
    print(f'  [{position:2d}] {column_name}')

## 5. Instantiate Dataset

**Note**: This cell requires downloaded images. If images are not yet downloaded,
run `python scripts/download_images.py` first (~3.8 GB download).

If images are not available, we can still validate the target/mask logic
by creating the dataset without transforms and skipping image loading.

In [None]:
# Check if images are available. Walk the (potentially very large) image tree
# once and reuse the count, instead of traversing it a second time.
n_images = sum(1 for _ in IMAGE_DIR.rglob('*.jpg')) if IMAGE_DIR.exists() else 0
images_available = n_images > 0

if images_available:
    print(f'Found {n_images:,} images in {IMAGE_DIR}')
else:
    print(f'No images found in {IMAGE_DIR}')
    print('Run: python scripts/download_images.py')
    print('\nContinuing with target/mask validation only...')

In [None]:
# Create datasets — transforms only if images are available
train_transform = get_transforms('train') if images_available else None
val_transform = get_transforms('val') if images_available else None

train_ds = EuclidDataset(train_df, IMAGE_DIR, schema, train_transform)
val_ds = EuclidDataset(val_df, IMAGE_DIR, schema, val_transform)
test_ds = EuclidDataset(test_df, IMAGE_DIR, schema, val_transform)

print(f'Train dataset: {len(train_ds):,} samples')
print(f'Val dataset:   {len(val_ds):,} samples')
print(f'Test dataset:  {len(test_ds):,} samples')
print(f'Output dim:    {schema.num_outputs}')

## 6. Validate Target & Mask Structure

In [None]:
# Validate the pre-computed target and mask tensors
print(f'Target tensor shape: {train_ds.targets.shape}')
print(f'Mask tensor shape:   {train_ds.masks.shape}')

# Check per-question validity rates
print(f'\nPer-question validity (% of training samples):')
for question, (start, end) in schema.question_slices.items():
    valid_pct = train_ds.masks[:, start].mean().item() * 100
    print(f'  {question:25s}: {valid_pct:5.1f}% valid')

# Verify fractions sum to 1 for valid questions
print(f'\nFraction sum check (should be ~1.0 where valid):')
for question, (start, end) in schema.question_slices.items():
    mask = train_ds.masks[:, start] > 0
    if mask.sum() > 0:
        sums = train_ds.targets[mask, start:end].sum(dim=1)
        print(f'  {question:25s}: mean={sums.mean():.4f}, std={sums.std():.4f}')

## 7. Visualize Sample Images

Load a few images and display them with their morphology labels.
(Skip this section if images are not yet downloaded.)

In [None]:
if not images_available:
    print('Skipping visualization — images not downloaded yet.')
else:
    # Eval-time transforms only, so the panels show un-augmented galaxies.
    viz_ds = EuclidDataset(train_df, IMAGE_DIR, schema, get_transforms('val'))

    fig, axes = plt.subplots(3, 5, figsize=(18, 11))

    # Fixed-seed draw of 15 distinct galaxies; `indices` is reused by later cells.
    rng = np.random.default_rng(42)
    indices = rng.choice(len(viz_ds), 15, replace=False)

    # Per-channel ImageNet statistics used to undo normalization for display
    # (`mean`/`std` are reused by later cells).
    # NOTE(review): assumes get_transforms normalizes with these exact values.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for ax, idx in zip(axes.flat, indices):
        img, targets, mask = viz_ds[idx]

        # CHW normalized tensor -> HWC array in [0, 1] for imshow.
        rgb = (img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
        ax.imshow(rgb)
        ax.axis('off')

        # First three outputs are read as smooth / featured / problem fractions
        # (assumes the schema lists that question first — TODO confirm).
        s_frac, f_frac, p_frac = (targets[k].item() for k in range(3))
        ax.set_title(f'S={s_frac:.2f} F={f_frac:.2f} P={p_frac:.2f}', fontsize=9)

    fig.suptitle('Sample Galaxies (S=Smooth, F=Featured, P=Problem)', fontsize=14)
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'data_sample_galaxies.pdf', bbox_inches='tight')
    plt.show()

In [None]:
if images_available:
    # Show augmentation effects on a single galaxy.
    # Reuses `indices`, `mean` and `std` from the sample-visualization cell.
    aug_ds = EuclidDataset(train_df, IMAGE_DIR, schema, get_transforms('train'))

    fig, axes = plt.subplots(2, 5, figsize=(18, 7))
    fig.suptitle('Same Galaxy Under 10 Random Augmentations', fontsize=14)

    idx = indices[0]  # Pick one galaxy
    for ax in axes.flat:
        # Each __getitem__ call re-samples the random training transform.
        img, _, _ = aug_ds[idx]
        img_vis = (img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
        ax.imshow(img_vis)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'data_augmentation_examples.pdf', bbox_inches='tight')
    plt.show()
else:
    # Match the other image-dependent cells, which report when they are skipped.
    print('Skipping augmentation preview — images not downloaded yet.')

## 8. Test DataLoader

In [None]:
if images_available:
    loader = DataLoader(
        train_ds, batch_size=32, shuffle=True,
        num_workers=0,  # Use 0 for notebook; 4 in scripts
        # Pinned memory only speeds up host->GPU copies; without CUDA it just warns.
        pin_memory=torch.cuda.is_available(),
    )

    batch_img, batch_targets, batch_mask = next(iter(loader))
    # Cast once so reductions work whether the mask is bool, int or float.
    batch_mask_f = batch_mask.float()

    print(f'Image batch:  {batch_img.shape}  {batch_img.dtype}')
    print(f'Target batch: {batch_targets.shape}  {batch_targets.dtype}')
    print(f'Mask batch:   {batch_mask.shape}  {batch_mask.dtype}')
    print(f'\nPixel range: [{batch_img.min():.3f}, {batch_img.max():.3f}]')
    print(f'Target range: [{batch_targets.min():.3f}, {batch_targets.max():.3f}]')
    print(f'Mask sum per sample (mean): {batch_mask_f.sum(dim=1).mean():.1f} / {schema.num_outputs}')

    # Structural checks backing the "PASSED" message below.
    assert batch_img.ndim == 4, f'Expected (B, C, H, W) images, got {tuple(batch_img.shape)}'
    assert batch_img.shape[0] == batch_targets.shape[0] == batch_mask.shape[0], 'Batch size mismatch'
    assert batch_targets.shape == batch_mask.shape, 'Target/mask shape mismatch'
    assert batch_targets.shape[1] == schema.num_outputs, 'Target width != schema.num_outputs'

    # Simulate masked loss computation
    fake_pred = torch.randn_like(batch_targets)
    per_output_loss = (fake_pred - batch_targets) ** 2
    # Clamp the denominator so a batch with no valid outputs yields 0, not NaN.
    masked_loss = (per_output_loss * batch_mask_f).sum() / batch_mask_f.sum().clamp(min=1e-8)
    assert torch.isfinite(masked_loss), 'Masked loss is not finite'
    print(f'\nSimulated masked MSE loss: {masked_loss.item():.4f}')
    print('DataLoader test PASSED')
else:
    print('Skipping DataLoader test — images not downloaded yet.')
    print('\nTarget/mask tensors are ready. Once images are downloaded,')
    print('the pipeline will be fully operational.')

## 9. TTA Transforms Preview

In [None]:
# Build the TTA pipeline once; its size is reported whether or not images exist.
tta_transforms = get_tta_transforms()

if images_available:
    from PIL import Image as PILImage

    print(f'Number of TTA views: {len(tta_transforms)}')

    # Load one raw image. The on-disk naming scheme is rebuilt by hand here.
    # NOTE(review): must stay in sync with EuclidDataset's path logic.
    row = train_df.iloc[indices[0]]
    tile = row['tile_index']
    obj_id = str(row['object_id']).replace('-', 'NEG')
    fname = f'{tile}_{obj_id}_gz_arcsinh_vis_y.jpg'
    img_path = IMAGE_DIR / str(tile) / fname
    # Close the file handle promptly; convert() returns an independent image.
    with PILImage.open(img_path) as pil_img:
        raw_img = pil_img.convert('RGB')

    n_views = len(tta_transforms)
    tta_labels = ['Original', 'H-Flip', 'V-Flip', '90°', '180°', '270°', 'H-Flip+90°']
    if len(tta_labels) != n_views:
        # The hand-written labels describe the default 7-view set; fall back to
        # generic labels rather than silently mislabeling or dropping views.
        tta_labels = [f'View {i}' for i in range(n_views)]

    # squeeze=False keeps a 2-D axes array even when there is a single view.
    fig, axes = plt.subplots(1, n_views, figsize=(3 * n_views, 3), squeeze=False)

    for ax, tfm, label in zip(axes[0], tta_transforms, tta_labels):
        img_t = tfm(raw_img)
        img_vis = (img_t * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
        ax.imshow(img_vis)
        ax.set_title(label, fontsize=10)
        ax.axis('off')

    fig.suptitle('Test-Time Augmentation Views', fontsize=13)
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'data_tta_views.pdf', bbox_inches='tight')
    plt.show()
else:
    print(f'TTA configured with {len(tta_transforms)} views')
    print('Visual preview available after image download.')

## 10. Summary

### Pipeline Status

In [None]:
# Final status report: which pipeline inputs exist and whether training can start.
rule = '=' * 60
print(rule)
print('DATA PIPELINE STATUS')
print(rule)
print(f'\n  Catalog:        {"OK" if CATALOG_PATH.exists() else "MISSING"} ({CATALOG_PATH.name})')
print(f'  Splits:         {"OK" if SPLIT_PATH.exists() else "MISSING"} ({SPLIT_PATH.name})')
# Only point at the download script when images are actually missing.
images_status = 'OK' if images_available else 'PENDING (run download_images.py)'
print(f'  Images:         {images_status}')
print(f'  Schema:         {schema.num_outputs} outputs across {len(schema.questions)} questions')
print(f'  Train samples:  {len(train_ds):,}')
print(f'  Val samples:    {len(val_ds):,}')
print(f'  Test samples:   {len(test_ds):,}')
print(f'  TTA views:      {len(get_tta_transforms())}')
print(f'\n  Ready for training: {"YES" if images_available else "NO (need images)"}')
print(rule)

if not images_available:
    print('\nNext step: python scripts/download_images.py')
    print('  This downloads ~3.8 GB of VIS+Y galaxy cutouts from Zenodo.')
    print('  After download, re-run this notebook to validate the full pipeline.')