In [1]:
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
from pathlib import Path

In [2]:
# Fetch MNIST from OpenML as plain NumPy arrays (as_frame=False).
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
# Labels are cast to int because the cells below compare them to digit ints
# and pass them to np.bincount.
X, y = mnist.data, mnist.target.astype(int)

In [3]:
# Sampling knobs shared by the subset-building cells below.
RANDOM_SEED = 42  # passed as random_state so the same subset is drawn on every run
N_SAMPLES = 500   # rows drawn per subset (train_size of the split)

# Subsets to build: each tuple lists the digits to keep, then the output folder name.
subsets = [(*range(4), "4-way_multiclass")]

In [4]:
for d0, d1, d2, d3, name in subsets:
    # Build, validate, and save one stratified MNIST subset.
    # NOTE(review): this cell is a near-copy of the 8-way cell further down;
    # consider a shared function parameterised by the digit tuple.
    digits = (d0, d1, d2, d3)

    # Keep only rows whose label is one of the selected digits.
    mask = np.isin(y, digits)
    X_filtered = X[mask]
    y_filtered = y[mask]

    # Draw N_SAMPLES rows. stratify=y_filtered keeps each digit's share of the
    # subset equal (up to rounding) to its share among the filtered rows; it
    # does NOT force equal class counts.
    X_subset, _, y_subset, _ = train_test_split(
        X_filtered,
        y_filtered,
        train_size=N_SAMPLES,
        stratify=y_filtered,
        random_state=RANDOM_SEED,
    )

    # ========== VALIDATION ==========
    # Runs before saving, so a subset that fails any check is never written to disk.
    print(f"\n{'='*60}")
    print(f"Validating {name}")
    print(f"{'='*60}")

    # 1. Check shapes
    print(f"✓ X shape: {X_subset.shape}")
    print(f"✓ y shape: {y_subset.shape}")
    assert X_subset.shape[0] == y_subset.shape[0], "X and y have different lengths!"
    assert X_subset.shape[0] == N_SAMPLES, f"Expected {N_SAMPLES} samples!"
    assert X_subset.shape[1] == 784, "Expected 784 features (28×28)!"

    # 2. Check for NaN or invalid values
    assert not np.isnan(X_subset).any(), "X contains NaN values!"
    assert not np.isnan(y_subset).any(), "y contains NaN values!"
    print(f"✓ No NaN values")

    # 3. Check pixel value range
    assert X_subset.min() >= 0, f"Pixel values below 0! Min: {X_subset.min()}"
    assert X_subset.max() <= 255, f"Pixel values above 255! Max: {X_subset.max()}"
    print(f"✓ Pixel range: [{X_subset.min()}, {X_subset.max()}]")

    # 4. Check labels: exactly the requested digits, no more and no fewer.
    unique_labels = np.unique(y_subset)
    expected_labels = np.sort(np.array(digits))
    assert np.array_equal(unique_labels, expected_labels), \
        f"Unexpected labels! Got {unique_labels}, expected {expected_labels}"
    print(f"✓ Labels: {unique_labels}")

    # 5. Check class balance
    class_counts = np.bincount(y_subset)
    print(f"✓ Class distribution:")
    for d in digits:
        print(f"    Digit {d}: {class_counts[d]} samples ({class_counts[d]/len(y_subset)*100:.1f}%)")

    # Ratio of the smallest to the largest class over ALL selected digits.
    # Stratification mirrors the source proportions, so a value somewhat
    # below 1 is expected; the warning only flags notable skew.
    selected_counts = class_counts[list(digits)]
    balance_ratio = selected_counts.min() / selected_counts.max()
    if balance_ratio < 0.9:
        print(f"  ⚠ Warning: Classes imbalanced (ratio: {balance_ratio:.2f})")
    else:
        print(f"  ✓ Well balanced (ratio: {balance_ratio:.2f})")

    # 6. Check data types
    print(f"✓ Data types: X={X_subset.dtype}, y={y_subset.dtype}")

    # X can be any integer or floating dtype (both valid)
    assert np.issubdtype(X_subset.dtype, np.integer) or np.issubdtype(X_subset.dtype, np.floating), \
        f"Unexpected X dtype: {X_subset.dtype}"

    # y should be an integer dtype
    assert np.issubdtype(y_subset.dtype, np.integer), \
        f"Unexpected y dtype: {y_subset.dtype}"

    print(f"{'='*60}")
    print(f"✓ All validation checks passed for {name}!")
    print(f"{'='*60}\n")

    # ========== SAVE ==========
    # Written exactly once, and only after every check above has passed.
    folder = Path('../../../data/multiclass/raw') / name
    folder.mkdir(parents=True, exist_ok=True)
    np.save(folder / 'X_raw.npy', X_subset)
    np.save(folder / 'y_raw.npy', y_subset)

    print(f"Saved {name}: {len(X_subset)} samples")
    print(f"  Class distribution: {np.bincount(y_subset)}")

Saved 4-way_multiclass: 500 samples
  Class distribution: [119 136 121 124]

Validating 4-way_multiclass
✓ X shape: (500, 784)
✓ y shape: (500,)
✓ No NaN values
✓ Pixel range: [0, 255]
✓ Labels: [0 1 2 3]
✓ Class distribution:
    Digit 0: 119 samples (23.8%)
    Digit 1: 136 samples (27.2%)
    Digit 2: 121 samples (24.2%)
    Digit 3: 124 samples (24.8%)
✓ Data types: X=int64, y=int64
✓ All validation checks passed for 4-way_multiclass!



In [5]:
# 8-way subset: digits 0 through 7, then the output folder name.
subsets2 = [(*range(8), "8-way_multiclass")]

In [6]:
for d0, d1, d2, d3, d4, d5, d6, d7, name in subsets2:
    # Build, validate, and save one stratified MNIST subset.
    # NOTE(review): this cell is a near-copy of the 4-way cell above;
    # consider a shared function parameterised by the digit tuple.
    digits = (d0, d1, d2, d3, d4, d5, d6, d7)

    # Keep only rows whose label is one of the selected digits.
    mask = np.isin(y, digits)
    X_filtered = X[mask]
    y_filtered = y[mask]

    # Draw N_SAMPLES rows. stratify=y_filtered keeps each digit's share of the
    # subset equal (up to rounding) to its share among the filtered rows; it
    # does NOT force equal class counts.
    X_subset, _, y_subset, _ = train_test_split(
        X_filtered,
        y_filtered,
        train_size=N_SAMPLES,
        stratify=y_filtered,
        random_state=RANDOM_SEED,
    )

    # ========== VALIDATION ==========
    # Runs before saving, so a subset that fails any check is never written to disk.
    print(f"\n{'='*60}")
    print(f"Validating {name}")
    print(f"{'='*60}")

    # 1. Check shapes
    print(f"✓ X shape: {X_subset.shape}")
    print(f"✓ y shape: {y_subset.shape}")
    assert X_subset.shape[0] == y_subset.shape[0], "X and y have different lengths!"
    assert X_subset.shape[0] == N_SAMPLES, f"Expected {N_SAMPLES} samples!"
    assert X_subset.shape[1] == 784, "Expected 784 features (28×28)!"

    # 2. Check for NaN or invalid values
    assert not np.isnan(X_subset).any(), "X contains NaN values!"
    assert not np.isnan(y_subset).any(), "y contains NaN values!"
    print(f"✓ No NaN values")

    # 3. Check pixel value range
    assert X_subset.min() >= 0, f"Pixel values below 0! Min: {X_subset.min()}"
    assert X_subset.max() <= 255, f"Pixel values above 255! Max: {X_subset.max()}"
    print(f"✓ Pixel range: [{X_subset.min()}, {X_subset.max()}]")

    # 4. Check labels: exactly the requested digits, no more and no fewer.
    unique_labels = np.unique(y_subset)
    expected_labels = np.sort(np.array(digits))
    assert np.array_equal(unique_labels, expected_labels), \
        f"Unexpected labels! Got {unique_labels}, expected {expected_labels}"
    print(f"✓ Labels: {unique_labels}")

    # 5. Check class balance
    class_counts = np.bincount(y_subset)
    print(f"✓ Class distribution:")
    for d in digits:
        print(f"    Digit {d}: {class_counts[d]} samples ({class_counts[d]/len(y_subset)*100:.1f}%)")

    # Ratio of the smallest to the largest class over ALL selected digits.
    # Stratification mirrors the source proportions, so a value somewhat
    # below 1 is expected; the warning only flags notable skew.
    selected_counts = class_counts[list(digits)]
    balance_ratio = selected_counts.min() / selected_counts.max()
    if balance_ratio < 0.9:
        print(f"  ⚠ Warning: Classes imbalanced (ratio: {balance_ratio:.2f})")
    else:
        print(f"  ✓ Well balanced (ratio: {balance_ratio:.2f})")

    # 6. Check data types
    print(f"✓ Data types: X={X_subset.dtype}, y={y_subset.dtype}")

    # X can be any integer or floating dtype (both valid)
    assert np.issubdtype(X_subset.dtype, np.integer) or np.issubdtype(X_subset.dtype, np.floating), \
        f"Unexpected X dtype: {X_subset.dtype}"

    # y should be an integer dtype
    assert np.issubdtype(y_subset.dtype, np.integer), \
        f"Unexpected y dtype: {y_subset.dtype}"

    print(f"{'='*60}")
    print(f"✓ All validation checks passed for {name}!")
    print(f"{'='*60}\n")

    # ========== SAVE ==========
    # Written exactly once, and only after every check above has passed.
    folder = Path('../../../data/multiclass/raw') / name
    folder.mkdir(parents=True, exist_ok=True)
    np.save(folder / 'X_raw.npy', X_subset)
    np.save(folder / 'y_raw.npy', y_subset)

    print(f"Saved {name}: {len(X_subset)} samples")
    print(f"  Class distribution: {np.bincount(y_subset)}")

Saved 8-way_multiclass: 500 samples
  Class distribution: [61 70 62 64 61 56 61 65]

Validating 8-way_multiclass
✓ X shape: (500, 784)
✓ y shape: (500,)
✓ No NaN values
✓ Pixel range: [0, 255]
✓ Labels: [0 1 2 3 4 5 6 7]
✓ Class distribution:
    Digit 0: 61 samples (12.2%)
    Digit 1: 70 samples (14.0%)
    Digit 2: 62 samples (12.4%)
    Digit 3: 64 samples (12.8%)
    Digit 4: 61 samples (12.2%)
    Digit 5: 56 samples (11.2%)
    Digit 6: 61 samples (12.2%)
    Digit 7: 65 samples (13.0%)
✓ Data types: X=int64, y=int64
✓ All validation checks passed for 8-way_multiclass!

