In [1]:
import numpy as np
import h5py
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn.utils import resample
import os

In [2]:
def load_and_preprocess_mnist(target_samples=200000, train_split=150000, valid_split=50000):
    """
    Load MNIST dataset, perform bootstrap resampling, and save to HDF5 format.
    
    Parameters:
    - target_samples: Total number of samples to generate (default: 200,000)
    - train_split: Number of training samples (default: 150,000)
    - valid_split: Number of validation samples (default: 50,000)
    """
    
    print("Loading MNIST dataset...")
    # Define transform to convert PIL Image to tensor and then to numpy
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    
    # Load the original MNIST dataset using PyTorch
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, 
        download=True, transform=transform
    )
    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, 
        download=True, transform=transform
    )
    
    # Convert to numpy arrays
    x_train_orig = train_dataset.data.numpy()
    y_train_orig = train_dataset.targets.numpy()
    x_test_orig = test_dataset.data.numpy()
    y_test_orig = test_dataset.targets.numpy()
    
    # Combine training and test sets to have all 70,000 samples
    x_all = np.concatenate([x_train_orig, x_test_orig], axis=0)
    y_all = np.concatenate([y_train_orig, y_test_orig], axis=0)
    
    print(f"Original MNIST dataset size: {x_all.shape[0]} samples")
    print(f"Image shape: {x_all.shape[1:]} pixels")
    
    # Check if we have enough samples or need bootstrap resampling
    if x_all.shape[0] >= target_samples:
        print(f"Sufficient samples available. Randomly selecting {target_samples} samples...")
        indices = np.random.choice(x_all.shape[0], size=target_samples, replace=False)
        x_resampled = x_all[indices]
        y_resampled = y_all[indices]
    else:
        print(f"Bootstrap resampling {target_samples} samples from {x_all.shape[0]} original samples...")
        # Bootstrap resampling with replacement
        x_resampled, y_resampled = resample(
            x_all, y_all, 
            n_samples=target_samples, 
            replace=True, 
            random_state=42
        )
    print(f"Resampled dataset size: {x_resampled.shape[0]} samples")
    
    # Convert to float32 and normalize to [0, 1] range
    x_resampled = x_resampled.astype(np.float32) / 255.0
    
    # Randomly shuffle the resampled data
    indices = np.random.permutation(target_samples)
    x_resampled = x_resampled[indices]
    y_resampled = y_resampled[indices]
    
    # Split into training and validation sets
    x_train = x_resampled[:train_split]
    y_train = y_resampled[:train_split]
    x_valid = x_resampled[train_split:train_split + valid_split]
    y_valid = y_resampled[train_split:train_split + valid_split]
    
    print(f"Training set: {x_train.shape[0]} samples")
    print(f"Validation set: {x_valid.shape[0]} samples")
    
    # Save to HDF5 file
    output_filename = 'mnist_resampled.h5'
    print(f"Saving to HDF5 file: {output_filename}")
    
    with h5py.File(output_filename, 'w') as f:
        # Create datasets with proper shapes: N_images, N_pix, N_pix
        f.create_dataset('train', data=x_train, dtype=np.float32, compression='gzip')
        f.create_dataset('valid', data=x_valid, dtype=np.float32, compression='gzip')
        
        # Also save labels (optional, but often useful)
        f.create_dataset('train_labels', data=y_train, dtype=np.int32)
        f.create_dataset('valid_labels', data=y_valid, dtype=np.int32)
        
        # Add metadata
        f.attrs['description'] = 'MNIST dataset with bootstrap resampling'
        f.attrs['total_samples'] = target_samples
        f.attrs['train_samples'] = train_split
        f.attrs['valid_samples'] = valid_split
        f.attrs['image_shape'] = x_train.shape[1:]
        f.attrs['data_type'] = 'float32'
        f.attrs['normalized'] = 'True (0-1 range)'
    
    print(f"Successfully saved HDF5 file with shape:")
    print(f"  - train: {x_train.shape}")
    print(f"  - valid: {x_valid.shape}")
    
    return output_filename

In [3]:
def verify_hdf5_file(filename, keys=('train', 'valid')):
    """
    Print a summary of the datasets and metadata stored in an HDF5 file.

    Lists the top-level keys. For each name in ``keys`` present in the file, it
    prints shape, dtype and min/max/mean. It then prints all file-level attributes.

    Parameters:
    - filename: Path of the HDF5 file to inspect.
    - keys: Dataset names to summarise (default: ('train', 'valid')); names not
      present in the file are skipped.
    """
    print(f"\nVerifying HDF5 file: {filename}")

    with h5py.File(filename, 'r') as f:
        print("Available keys:", list(f.keys()))

        for key in keys:
            if key not in f:
                continue
            dataset = f[key]
            print(f"{key} dataset:")
            print(f"  Shape: {dataset.shape}")
            print(f"  Dtype: {dataset.dtype}")
            # Read (and decompress) the data once and reuse it for every
            # statistic, instead of a full pass from disk per statistic.
            values = np.asarray(dataset[()])
            if values.size == 0:
                # min()/max() raise on empty arrays, so skip the statistics.
                print("  (empty dataset: no statistics)")
                continue
            print(f"  Min value: {values.min():.4f}")
            print(f"  Max value: {values.max():.4f}")
            print(f"  Mean value: {values.mean():.4f}")

        # Print metadata
        print("\nMetadata:")
        for key, value in f.attrs.items():
            print(f"  {key}: {value}")

def load_from_hdf5(filename):
    """
    Read the image arrays (and labels, if stored) back out of an HDF5 file.

    Returns:
    - (x_train, x_valid, y_train, y_valid). ``x_*`` come from the 'train' /
      'valid' datasets. ``y_*`` come from 'train_labels' / 'valid_labels' and
      are None when the file has no such dataset.
    """
    print(f"\nExample: Loading data from {filename}")

    with h5py.File(filename, 'r') as h5:
        def read_optional(name):
            # Labels are optional; an absent dataset maps to None.
            return h5[name][:] if name in h5 else None

        x_train, x_valid = (h5[name][:] for name in ('train', 'valid'))
        y_train, y_valid = (read_optional(name) for name in ('train_labels', 'valid_labels'))

    for label, array in (("training", x_train), ("validation", x_valid)):
        print(f"Loaded {label} data: {array.shape}")

    return x_train, x_valid, y_train, y_valid

In [4]:
if __name__ == "__main__":
    # Seed NumPy's global RNG for reproducibility.
    np.random.seed(42)

    # Build the resampled dataset and write it to disk.
    output_file = load_and_preprocess_mnist(
        target_samples=200000, train_split=150000, valid_split=50000
    )

    # Sanity-check the written file, then demonstrate reading it back.
    verify_hdf5_file(output_file)
    x_train, x_valid, y_train, y_valid = load_from_hdf5(output_file)

    size_mb = os.path.getsize(output_file) / (1024 ** 2)
    print("\nPreprocessing complete!")
    print(f"HDF5 file saved as: {output_file}")
    print(f"File size: {size_mb:.2f} MB")

Loading MNIST dataset...
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 67.1MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 2.23MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 19.9MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 24.0MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Original MNIST dataset size: 70000 samples
Image shape: (28, 28) pixels
Bootstrap resampling 200000 samples from 70000 original samples...
Resampled dataset size: 200000 samples
Training set: 150000 samples
Validation set: 50000 samples
Saving to HDF5 file: mnist_resampled.h5
Successfully saved HDF5 file with shape:
  - train: (150000, 28, 28)
  - valid: (50000, 28, 28)

Verifying HDF5 file: mnist_resampled.h5
Available keys: ['train', 'train_labels', 'valid', 'valid_labels']
train dataset:
  Shape: (150000, 28, 28)
  Dtype: float32
  Min value: 0.0000
  Max value: 1.0000
  Mean value: 0.1310
valid dataset:
  Shape: (50000, 28, 28)
  Dtype: float32
  Min value: 0.0000
  Max value: 1.0000
  Mean value: 0.1309

Metadata:
  data_type: float32
  description: MNIST dataset with bootstrap resampling
  image_shape: [28 28]
  normalized: True (0-1 range)
  total_samples: 200000
  train_samples: 150000
  valid_samples: 5