In [None]:
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
from pathlib import Path

In [None]:
# Fetch the 'mnist_784' dataset (version 1) from OpenML.
mnist = fetch_openml(
    'mnist_784',
    version=1,
    as_frame=False,
    parser='auto',
)

# Feature matrix and labels; labels are cast to int so they can be compared
# against the integer digit values used when filtering subsets.
X, y = mnist.data, mnist.target.astype(int)

In [None]:
# Seed passed to the sampler so repeated runs pick the same rows.
RANDOM_SEED = 42

# Number of samples drawn for each two-digit subset.
N_SAMPLES = 5000

# One entry per binary task: (first digit, second digit, output folder name).
subsets = [(0, 1, 'digits_0v1'), (3, 8, 'digits_3v8')]

In [None]:
# Directories that receive a copy of every subset. Both are relative to the
# current working directory of the kernel.
# NOTE(review): the two roots look like the same project folder reached from
# different working directories -- confirm whether both copies are needed.
OUTPUT_ROOTS = (Path('../../data/raw'), Path('data/raw'))


def _validate_subset(X_subset, y_subset, d0, d1, name, n_samples):
    """Sanity-check one sampled two-digit subset and print a short report.

    Parameters
    ----------
    X_subset : np.ndarray
        Sampled feature matrix; must have ``n_samples`` rows and 784 columns.
    y_subset : np.ndarray
        Integer labels aligned row-by-row with ``X_subset``.
    d0, d1 : int
        The two digit classes the subset must contain (order does not matter).
    name : str
        Subset name, used only in the printed report.
    n_samples : int
        Number of rows the subset is expected to contain.

    Raises
    ------
    AssertionError
        If any shape, NaN, value-range, label or dtype check fails.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Validating {name}")
    print(banner)

    # 1. Shapes
    print(f"✓ X shape: {X_subset.shape}")
    print(f"✓ y shape: {y_subset.shape}")
    assert X_subset.shape[0] == y_subset.shape[0], "X and y have different lengths!"
    assert X_subset.shape[0] == n_samples, f"Expected {n_samples} samples!"
    assert X_subset.shape[1] == 784, "Expected 784 features (28×28)!"

    # 2. NaN values
    assert not np.isnan(X_subset).any(), "X contains NaN values!"
    assert not np.isnan(y_subset).any(), "y contains NaN values!"
    print("✓ No NaN values")

    # 3. Pixel value range (each reduction is computed once and reused)
    pixel_min = X_subset.min()
    pixel_max = X_subset.max()
    assert pixel_min >= 0, f"Pixel values below 0! Min: {pixel_min}"
    assert pixel_max <= 255, f"Pixel values above 255! Max: {pixel_max}"
    print(f"✓ Pixel range: [{pixel_min}, {pixel_max}]")

    # 4. Labels. np.unique returns sorted values, so the expected labels are
    # normalised the same way; the check then works for either digit order.
    unique_labels = np.unique(y_subset)
    expected_labels = np.unique([d0, d1])
    assert np.array_equal(unique_labels, expected_labels), \
        f"Unexpected labels! Got {unique_labels}, expected {expected_labels}"
    print(f"✓ Labels: {unique_labels}")

    # 5. Class balance
    class_counts = np.bincount(y_subset)
    count_d0 = class_counts[d0]
    count_d1 = class_counts[d1]
    print("✓ Class distribution:")
    print(f"    Digit {d0}: {count_d0} samples ({count_d0/n_samples*100:.1f}%)")
    print(f"    Digit {d1}: {count_d1} samples ({count_d1/n_samples*100:.1f}%)")

    # Stratified sampling preserves the class proportions of the filtered
    # source data; it does not force a 50/50 split. A warning here therefore
    # reflects how often each digit occurs in the source, not a sampling bug.
    balance_ratio = min(count_d0, count_d1) / max(count_d0, count_d1)
    if balance_ratio < 0.9:
        print(f"  ⚠ Warning: Classes imbalanced (ratio: {balance_ratio:.2f})")
    else:
        print(f"  ✓ Well balanced (ratio: {balance_ratio:.2f})")

    # 6. Data types
    print(f"✓ Data types: X={X_subset.dtype}, y={y_subset.dtype}")

    # X can be int or float (both valid)
    assert X_subset.dtype in [np.int64, np.int32, np.float64, np.float32], \
        f"Unexpected X dtype: {X_subset.dtype}"

    # y should be int
    assert y_subset.dtype in [np.int64, np.int32], \
        f"Unexpected y dtype: {y_subset.dtype}"

    print(banner)
    print(f"✓ All validation checks passed for {name}!")
    print(f"{banner}\n")


for d0, d1, name in subsets:
    # Keep only the rows that belong to the two digits of this subset
    mask = (y == d0) | (y == d1)
    X_filtered = X[mask]
    y_filtered = y[mask]

    # Draw N_SAMPLES rows; stratify keeps the class proportions of the
    # filtered data, and the fixed seed makes the draw repeatable.
    X_subset, _, y_subset, _ = train_test_split(
        X_filtered,
        y_filtered,
        train_size=N_SAMPLES,
        stratify=y_filtered,
        random_state=RANDOM_SEED,
    )

    # Validate BEFORE writing anything, so a failed check never leaves
    # unvalidated arrays on disk.
    _validate_subset(X_subset, y_subset, d0, d1, name, N_SAMPLES)

    # Write the raw arrays to every configured destination
    for root in OUTPUT_ROOTS:
        folder = root / name
        folder.mkdir(parents=True, exist_ok=True)
        np.save(folder / 'X_raw.npy', X_subset)
        np.save(folder / 'y_raw.npy', y_subset)

    print(f"Saved {name}: {len(X_subset)} samples")
    print(f"  Class distribution: {np.bincount(y_subset)}")