In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import pickle
import os
from PIL import Image
import cv2
import gc
import psutil

class DatasetPreprocessor:
    def __init__(self, data_dir='./data', img_size=(128, 128), batch_size=500):
        """
        Initialize dataset preprocessor for robustness study
        """
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        os.makedirs(data_dir, exist_ok=True)
        
        self.processed_dir = os.path.join(data_dir, 'processed')
        os.makedirs(self.processed_dir, exist_ok=True)
        
        # Create samples directory
        self.samples_dir = os.path.join(self.processed_dir, 'samples')
        os.makedirs(self.samples_dir, exist_ok=True)
        
        # Auto-adjust batch size based on available memory
        self._auto_adjust_batch_size()
    
    def _auto_adjust_batch_size(self):
        """Automatically adjust batch size based on available memory"""
        try:
            available_memory_gb = psutil.virtual_memory().available / (1024 ** 3)
            
            if available_memory_gb < 2:
                self.batch_size = min(self.batch_size, 50)
                print(f"[WARNING] Low memory detected ({available_memory_gb:.1f}GB). Batch size reduced to {self.batch_size}")
            elif available_memory_gb < 4:
                self.batch_size = min(self.batch_size, 100)
                print(f"[INFO] Moderate memory ({available_memory_gb:.1f}GB). Batch size set to {self.batch_size}")
            else:
                print(f"[OK] Sufficient memory ({available_memory_gb:.1f}GB). Batch size: {self.batch_size}")
        except:
            print("[INFO] Could not detect memory. Using default batch size.")
    
    def _check_memory_and_collect(self):
        """Check memory usage and collect garbage if needed"""
        try:
            memory_percent = psutil.virtual_memory().percent
            if memory_percent > 85:
                gc.collect()
                if memory_percent > 90:
                    print(f"[WARNING] High memory usage ({memory_percent:.1f}%). Forcing garbage collection...")
        except:
            gc.collect()
    
    def save_sample_images(self, images, labels, class_names, aug_type, dataset_name, n_samples=3):
        """Save sample images for a given augmentation type"""
        # Create subdirectory for this dataset
        dataset_samples_dir = os.path.join(self.samples_dir, dataset_name)
        os.makedirs(dataset_samples_dir, exist_ok=True)
        
        # Select random samples
        n_samples = min(n_samples, len(images))
        sample_indices = np.random.choice(len(images), n_samples, replace=False)
        
        # Create a figure with samples
        fig, axes = plt.subplots(1, n_samples, figsize=(4*n_samples, 4))
        if n_samples == 1:
            axes = [axes]
        
        for idx, sample_idx in enumerate(sample_indices):
            img = images[sample_idx]
            label = labels[sample_idx]
            class_name = class_names[label]
            
            # Display image
            if img.shape[-1] == 1:
                axes[idx].imshow(img.squeeze(), cmap='gray')
            else:
                axes[idx].imshow(img)
            
            axes[idx].set_title(f'Class: {class_name}', fontsize=10)
            axes[idx].axis('off')
        
        plt.suptitle(f'{dataset_name} - {aug_type}', fontsize=12, y=0.98)
        plt.tight_layout()
        
        # Save figure
        save_path = os.path.join(dataset_samples_dir, f'{aug_type}.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        
        print(f"  [OK] Saved {n_samples} samples: {save_path}")
        gc.collect()
    
    def resize_images_batch(self, images, target_size):
        """Resize images in batches to avoid memory issues"""
        total = len(images)
        resized = []
        
        # Use smaller sub-batches for resizing
        sub_batch_size = min(self.batch_size, 100)
        
        for i in range(0, total, sub_batch_size):
            batch_end = min(i + sub_batch_size, total)
            if i % (sub_batch_size * 10) == 0:
                print(f"  Resizing: {i}/{total} images processed")
                self._check_memory_and_collect()
            
            batch = images[i:batch_end]
            batch_resized = [cv2.resize(img, target_size) for img in batch]
            resized.extend(batch_resized)
            
            del batch, batch_resized
        
        gc.collect()
        return np.array(resized)
    
    def apply_rotation(self, images, angle):
        """Apply rotation augmentation in batches"""
        rotated = []
        sub_batch_size = min(self.batch_size, 100)
        
        for i in range(0, len(images), sub_batch_size):
            batch_end = min(i + sub_batch_size, len(images))
            batch = images[i:batch_end]
            
            batch_rotated = []
            for img in batch:
                if len(img.shape) == 3 and img.shape[-1] == 1:
                    img_2d = img.squeeze(axis=-1)
                    img_pil = Image.fromarray(img_2d.astype('uint8'))
                    rotated_pil = img_pil.rotate(angle, fillcolor=0)
                    rotated_np = np.expand_dims(np.array(rotated_pil), axis=-1)
                else:
                    img_pil = Image.fromarray(img.astype('uint8'))
                    rotated_pil = img_pil.rotate(angle, fillcolor=0)
                    rotated_np = np.array(rotated_pil)
                batch_rotated.append(rotated_np)
            
            rotated.extend(batch_rotated)
            del batch, batch_rotated
            
            if i % (sub_batch_size * 5) == 0:
                self._check_memory_and_collect()
        
        return np.array(rotated)
    
    def apply_noise(self, images):
        """Apply Gaussian noise augmentation in batches"""
        noisy = []
        sub_batch_size = min(self.batch_size, 100)
        
        for i in range(0, len(images), sub_batch_size):
            batch_end = min(i + sub_batch_size, len(images))
            batch = images[i:batch_end]
            
            # Generate noise for entire batch at once (more efficient)
            noise = np.random.normal(0, 25, batch.shape).astype(np.float32)
            batch_noisy = np.clip(batch.astype(np.float32) + noise, 0, 255).astype(np.uint8)
            
            noisy.extend(batch_noisy)
            del batch, noise, batch_noisy
            
            if i % (sub_batch_size * 5) == 0:
                self._check_memory_and_collect()
        
        return np.array(noisy)
    
    def apply_scaling(self, images, scale):
        """Apply scaling augmentation in batches"""
        scaled = []
        sub_batch_size = min(self.batch_size, 100)
        
        for i in range(0, len(images), sub_batch_size):
            batch_end = min(i + sub_batch_size, len(images))
            batch = images[i:batch_end]
            
            batch_scaled = []
            for img in batch:
                h, w = img.shape[:2]
                new_h, new_w = int(h * scale), int(w * scale)
                
                if len(img.shape) == 3 and img.shape[-1] == 1:
                    img_2d = img.squeeze(axis=-1)
                    img_pil = Image.fromarray(img_2d.astype('uint8'))
                    scaled_pil = img_pil.resize((new_w, new_h))
                    final_pil = Image.new('L', (w, h), color=0)
                    paste_x = max(0, (w - new_w) // 2)
                    paste_y = max(0, (h - new_h) // 2)
                    final_pil.paste(scaled_pil, (paste_x, paste_y))
                    scaled_np = np.expand_dims(np.array(final_pil), axis=-1)
                else:
                    img_pil = Image.fromarray(img.astype('uint8'))
                    scaled_pil = img_pil.resize((new_w, new_h))
                    final_pil = Image.new('RGB', (w, h), color=0)
                    paste_x = max(0, (w - new_w) // 2)
                    paste_y = max(0, (h - new_h) // 2)
                    final_pil.paste(scaled_pil, (paste_x, paste_y))
                    scaled_np = np.array(final_pil)
                
                batch_scaled.append(scaled_np)
            
            scaled.extend(batch_scaled)
            del batch, batch_scaled
            
            if i % (sub_batch_size * 5) == 0:
                self._check_memory_and_collect()
        
        return np.array(scaled)
    
    def apply_occlusion(self, images, occlusion_ratio=0.25):
        """Apply occlusion augmentation in batches"""
        occluded = []
        occlusion_size = int(min(self.img_size) * occlusion_ratio)
        sub_batch_size = min(self.batch_size, 100)
        
        for i in range(0, len(images), sub_batch_size):
            batch_end = min(i + sub_batch_size, len(images))
            batch = images[i:batch_end].copy()
            
            # Apply occlusion to batch
            for j in range(len(batch)):
                h, w = batch[j].shape[:2]
                x = np.random.randint(0, max(1, w - occlusion_size))
                y = np.random.randint(0, max(1, h - occlusion_size))
                batch[j][y:y+occlusion_size, x:x+occlusion_size] = 0
            
            occluded.extend(batch)
            del batch
            
            if i % (sub_batch_size * 5) == 0:
                self._check_memory_and_collect()
        
        return np.array(occluded)
    
    def create_augmented_test_sets(self, images, labels, class_names, dataset_name):
        """
        Build the original, four single-augmentation and one combined test set.

        Each variant is pickled to ``<processed_dir>/temp_test/<name>.pkl`` and
        a sample figure is saved as soon as it is produced. All variants are
        then loaded back from disk and returned.

        Args:
            images: Float image array with values on a 0-1 scale.
            labels: Array of labels aligned with ``images``.
            class_names: Label-to-name sequence, used for the sample figures.
            dataset_name: Used for the sample-figure directory and titles.

        Returns:
            Dict keyed by 'original', 'rotation_15', 'noise', 'scaling_0.8',
            'occlusion_25' and 'all_combined'; each value is a dict with
            'images' (float32 on a 0-1 scale for the augmented variants) and
            'labels'.
        """
        print(f"  Creating augmented test sets with {len(images)} samples...")

        # Round to the nearest integer before the uint8 cast. A plain astype
        # truncates, so a pixel that comes back as 199.99998 after the
        # /255 -> *255 round trip would silently become 199.
        images_uint8 = np.clip(np.rint(images * 255), 0, 255).astype('uint8')

        test_dir = os.path.join(self.processed_dir, 'temp_test')
        os.makedirs(test_dir, exist_ok=True)

        def _store(name, variant_images):
            # Pickle one variant, save its sample figure, then release memory
            with open(os.path.join(test_dir, f'{name}.pkl'), 'wb') as f:
                pickle.dump({'images': variant_images, 'labels': labels}, f)
            self.save_sample_images(variant_images, labels, class_names, f'{name}_test', dataset_name)
            gc.collect()

        # 1. Original dataset - save immediately
        print("    Creating and saving original test set...")
        _store('original', images)

        # 2-5. One augmentation at a time, saved immediately
        augmentations = [
            ('rotation_15', lambda x: self.apply_rotation(x, 15)),
            ('noise', self.apply_noise),
            ('scaling_0.8', lambda x: self.apply_scaling(x, 0.8)),
            ('occlusion_25', lambda x: self.apply_occlusion(x, 0.25)),
        ]

        for aug_name, aug_func in augmentations:
            print(f"    Creating and saving {aug_name} test set...")
            augmented = aug_func(images_uint8)
            _store(aug_name, augmented.astype('float32') / 255.0)
            del augmented

        # 6. All augmentations chained in the same order as above
        print("    Creating and saving combined augmentation test set...")
        combined = images_uint8
        for _, aug_func in augmentations:
            combined = aug_func(combined)
            gc.collect()
        _store('all_combined', combined.astype('float32') / 255.0)

        del combined, images_uint8
        gc.collect()

        # Load all datasets back (files written just above by this method)
        names = ['original'] + [name for name, _ in augmentations] + ['all_combined']
        augmented_datasets = {}
        for name in names:
            with open(os.path.join(test_dir, f'{name}.pkl'), 'rb') as f:
                augmented_datasets[name] = pickle.load(f)

        return augmented_datasets
    
    def create_augmented_train_set(self, images, labels, class_names, dataset_name):
        """
        Build the original, mixed-augmented and combined-augmented training sets.

        * mixed_augmented: the samples are shuffled and cut into five parts;
          the parts are left untouched, rotated, noised, scaled and occluded
          respectively (the last part also takes any remainder).
        * combined_augmented: every sample gets rotation, noise, scaling and
          occlusion in that order, processed in chunks of 5000.
        * original: the input as given.

        Each set is pickled to ``<processed_dir>/temp_train/<name>.pkl`` and a
        sample figure is saved; all three are then loaded back and returned.

        Args:
            images: Float image array with values on a 0-1 scale.
            labels: NumPy array of labels aligned with ``images``.
            class_names: Label-to-name sequence, used for the sample figures.
            dataset_name: Used for the sample-figure directory and titles.

        Returns:
            Dict keyed by 'original', 'mixed_augmented' and
            'combined_augmented'; each value is a dict with 'images' and
            'labels'.

        Raises:
            ValueError: If ``images`` is empty.
        """
        print(f"  Creating augmented training set with {len(images)} samples...")

        n_samples = len(images)
        if n_samples == 0:
            raise ValueError("Cannot build an augmented training set from zero images")

        # Round to the nearest integer before the uint8 cast; a plain astype
        # truncates and can shift pixel values down by one level.
        images_uint8 = np.clip(np.rint(images * 255), 0, 255).astype('uint8')

        # Split into 5 equal parts (drawn before any augmentation randomness)
        split_size = n_samples // 5
        indices = np.random.permutation(n_samples)

        train_dir = os.path.join(self.processed_dir, 'temp_train')
        os.makedirs(train_dir, exist_ok=True)

        def _store(name, set_images, set_labels):
            # Pickle one training set, save its sample figure, release memory
            with open(os.path.join(train_dir, f'{name}.pkl'), 'wb') as f:
                pickle.dump({'images': set_images, 'labels': set_labels}, f)
            self.save_sample_images(set_images, set_labels, class_names, f'{name}_train', dataset_name)
            gc.collect()

        # ---- mixed augmentation: one treatment per split ----
        split_steps = [
            ('original', None),
            ('rotation', lambda x: self.apply_rotation(x, 15)),
            ('noise', self.apply_noise),
            ('scaling', lambda x: self.apply_scaling(x, 0.8)),
            ('occlusion', lambda x: self.apply_occlusion(x, 0.25)),
        ]
        last_part = len(split_steps) - 1

        mixed_images = []
        mixed_labels = []
        for part, (step_name, aug_func) in enumerate(split_steps):
            print(f"    Processing {step_name} split...")
            start = part * split_size
            stop = (part + 1) * split_size if part < last_part else n_samples
            part_indices = indices[start:stop]
            if len(part_indices) == 0:
                # Fewer than five samples: nothing to do for this split
                continue

            part_images = images_uint8[part_indices]
            if aug_func is not None:
                part_images = aug_func(part_images)

            mixed_images.append(part_images)
            mixed_labels.append(labels[part_indices])
            gc.collect()

        # Convert to arrays and normalize
        mixed_images = np.concatenate(mixed_images, axis=0).astype('float32') / 255.0
        mixed_labels = np.concatenate(mixed_labels, axis=0)

        _store('mixed_augmented', mixed_images, mixed_labels)
        del mixed_images, mixed_labels
        gc.collect()

        # ---- combined augmentation: every step on every sample, in chunks ----
        print("    Creating combined augmentation training set...")
        combined_chunks = []
        chunk_size = 5000  # Process in smaller chunks

        for i in range(0, n_samples, chunk_size):
            chunk_end = min(i + chunk_size, n_samples)
            print(f"      Processing chunk {i}-{chunk_end}...")

            chunk = self.apply_rotation(images_uint8[i:chunk_end], 15)
            chunk = self.apply_noise(chunk)
            chunk = self.apply_scaling(chunk, 0.8)
            chunk = self.apply_occlusion(chunk, 0.25)

            combined_chunks.append(chunk.astype('float32') / 255.0)
            del chunk
            gc.collect()

        combined = np.concatenate(combined_chunks, axis=0)
        del combined_chunks, images_uint8
        gc.collect()

        _store('combined_augmented', combined, labels)
        del combined
        gc.collect()

        # ---- original ----
        _store('original', images, labels)

        # Load all datasets back (files written just above by this method)
        train_datasets = {}
        for name in ['original', 'mixed_augmented', 'combined_augmented']:
            with open(os.path.join(train_dir, f'{name}.pkl'), 'rb') as f:
                train_datasets[name] = pickle.load(f)

        return train_datasets
    
    def load_mnist(self):
        """Download MNIST via torchvision, resize it and divide pixels by 255.

        Returns:
            Dict with 'train_images' and 'test_images' (float32, with a
            trailing channel axis of length 1), 'train_labels', 'test_labels',
            'class_names' (the strings '0'-'9') and 'dataset_type'
            ('grayscale').
        """
        print("\n" + "="*60)
        print("Loading MNIST Dataset")
        print("="*60)

        # Download (or reuse) both splits
        trainset = torchvision.datasets.MNIST(root=self.data_dir, train=True, download=True)
        testset = torchvision.datasets.MNIST(root=self.data_dir, train=False, download=True)

        print("Converting to numpy arrays...")
        raw_train, train_labels = trainset.data.numpy(), trainset.targets.numpy()
        raw_test, test_labels = testset.data.numpy(), testset.targets.numpy()

        def _prepare(raw_images):
            # resize -> add channel axis -> scale to float32 by 1/255
            resized = self.resize_images_batch(raw_images, self.img_size)
            gc.collect()
            return np.expand_dims(resized, axis=-1).astype('float32') / 255.0

        print(f"Resizing to {self.img_size}...")
        train_ready = _prepare(raw_train)
        test_ready = _prepare(raw_test)

        del raw_train, raw_test
        gc.collect()

        print(f"[OK] MNIST loaded: {len(train_ready)} train, {len(test_ready)} test")
        return {
            'train_images': train_ready,
            'train_labels': train_labels,
            'test_images': test_ready,
            'test_labels': test_labels,
            'class_names': [str(digit) for digit in range(10)],
            'dataset_type': 'grayscale'
        }
    
    def load_cifar10(self):
        """Download CIFAR-10 via torchvision, resize it and divide pixels by 255.

        Returns:
            Dict with 'train_images' and 'test_images' (float32),
            'train_labels', 'test_labels', 'class_names' (the ten CIFAR-10
            class names) and 'dataset_type' ('color').
        """
        print("\n" + "="*60)
        print("Loading CIFAR-10 Dataset")
        print("="*60)

        # Download dataset
        trainset = torchvision.datasets.CIFAR10(
            root=self.data_dir, train=True, download=True
        )
        testset = torchvision.datasets.CIFAR10(
            root=self.data_dir, train=False, download=True
        )

        sub_batch_size = min(self.batch_size, 500)

        def _to_arrays(dataset, check_every, report_progress):
            # Convert a dataset of (image, label) items to two NumPy arrays.
            # Each item is fetched exactly once: indexing separately for the
            # image and for the label would do the per-item work twice.
            images = []
            targets = []
            total = len(dataset)
            for i in range(0, total, sub_batch_size):
                batch_end = min(i + sub_batch_size, total)
                for j in range(i, batch_end):
                    img, target = dataset[j]
                    images.append(np.array(img))
                    targets.append(target)
                if i % (sub_batch_size * check_every) == 0:
                    if report_progress:
                        print(f"  Processed {batch_end}/{total} samples")
                    self._check_memory_and_collect()
            return np.array(images), np.array(targets)

        print("Converting training data...")
        train_images, train_labels = _to_arrays(trainset, 10, True)

        print("Converting test data...")
        test_images, test_labels = _to_arrays(testset, 5, False)

        # Resize images, dropping each raw array as soon as it is no longer needed
        print(f"Resizing to {self.img_size}...")
        train_images_resized = self.resize_images_batch(train_images, self.img_size)
        del train_images
        gc.collect()

        test_images_resized = self.resize_images_batch(test_images, self.img_size)
        del test_images
        gc.collect()

        # Scale to float32 by 1/255
        train_images_resized = train_images_resized.astype('float32') / 255.0
        test_images_resized = test_images_resized.astype('float32') / 255.0

        cifar_class_names = [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck'
        ]

        cifar_data = {
            'train_images': train_images_resized,
            'train_labels': train_labels,
            'test_images': test_images_resized,
            'test_labels': test_labels,
            'class_names': cifar_class_names,
            'dataset_type': 'color'
        }

        gc.collect()

        print(f"[OK] CIFAR-10 loaded: {len(train_images_resized)} train, {len(test_images_resized)} test")
        return cifar_data
    
    def save_datasets_separately(self, dataset_name, original_data, train_datasets, test_datasets):
        """Save datasets separately to avoid memory issues"""
        print(f"\nSaving {dataset_name} datasets separately...")
        
        # Save original data
        original_path = os.path.join(self.processed_dir, f'{dataset_name}_original.pkl')
        with open(original_path, 'wb') as f:
            pickle.dump(original_data, f)
        print(f"[OK] Saved original data: {os.path.getsize(original_path) / (1024 * 1024):.2f} MB")
        
        # Save training datasets
        train_dir = os.path.join(self.processed_dir, f'{dataset_name}_train')
        os.makedirs(train_dir, exist_ok=True)
        
        for train_type, data in train_datasets.items():
            train_path = os.path.join(train_dir, f'{train_type}.pkl')
            with open(train_path, 'wb') as f:
                pickle.dump(data, f)
            print(f"[OK] Saved {train_type} train: {os.path.getsize(train_path) / (1024 * 1024):.2f} MB")
            gc.collect()
        
        # Save test datasets
        test_dir = os.path.join(self.processed_dir, f'{dataset_name}_test')
        os.makedirs(test_dir, exist_ok=True)
        
        for test_type, data in test_datasets.items():
            test_path = os.path.join(test_dir, f'{test_type}.pkl')
            with open(test_path, 'wb') as f:
                pickle.dump(data, f)
            print(f"[OK] Saved {test_type} test: {os.path.getsize(test_path) / (1024 * 1024):.2f} MB")
            gc.collect()
        
        # Save metadata
        metadata = {
            'dataset_name': dataset_name,
            'img_size': self.img_size,
            'train_types': list(train_datasets.keys()),
            'test_types': list(test_datasets.keys()),
            'original_shape': original_data['train_images'].shape[1:]
        }
        
        metadata_path = os.path.join(self.processed_dir, f'{dataset_name}_metadata.pkl')
        with open(metadata_path, 'wb') as f:
            pickle.dump(metadata, f)
        
        print(f"[OK] Saved metadata")
        
        # Clean up temp directories
        import shutil
        temp_train_dir = os.path.join(self.processed_dir, 'temp_train')
        temp_test_dir = os.path.join(self.processed_dir, 'temp_test')
        if os.path.exists(temp_train_dir):
            shutil.rmtree(temp_train_dir)
        if os.path.exists(temp_test_dir):
            shutil.rmtree(temp_test_dir)
    
    def load_dataset(self, dataset_name, dataset_type, aug_type):
        """Load specific dataset"""
        if dataset_type == 'original':
            path = os.path.join(self.processed_dir, f'{dataset_name}_original.pkl')
        elif dataset_type == 'train':
            path = os.path.join(self.processed_dir, f'{dataset_name}_train', f'{aug_type}.pkl')
        elif dataset_type == 'test':
            path = os.path.join(self.processed_dir, f'{dataset_name}_test', f'{aug_type}.pkl')
        
        with open(path, 'rb') as f:
            data = pickle.load(f)
        return data
    
    def visualize_augmentation_comparison(self, train_datasets, test_datasets, dataset_name, sample_idx=0):
        """Save a 2x6 figure comparing training variants (top) and test variants (bottom).

        The image at ``sample_idx`` of each variant is shown. The figure is
        written to ``<processed_dir>/<dataset_name>_augmentation_comparison.png``.

        Args:
            train_datasets: Mapping with keys 'original', 'mixed_augmented' and
                'combined_augmented', each holding an 'images' entry.
            test_datasets: Mapping with keys 'original', 'rotation_15', 'noise',
                'scaling_0.8', 'occlusion_25' and 'all_combined', each holding
                an 'images' entry.
            dataset_name: Used in the figure title and the file name.
            sample_idx: Index of the image shown for every variant.
        """
        n_cols = 6
        fig, axes = plt.subplots(2, n_cols, figsize=(18, 6))

        # One entry per figure row: (datasets, ((variant key, panel title), ...))
        rows = (
            (train_datasets, (
                ('original', 'Original Train'),
                ('mixed_augmented', 'Mixed Aug Train'),
                ('combined_augmented', 'Combined Aug Train'),
            )),
            (test_datasets, (
                ('original', 'Original Test'),
                ('rotation_15', 'Rotation 15°'),
                ('noise', 'Noise'),
                ('scaling_0.8', 'Scaling 0.8x'),
                ('occlusion_25', 'Occlusion 25%'),
                ('all_combined', 'All Combined'),
            )),
        )

        for row, (datasets, panels) in enumerate(rows):
            for col, (aug_type, title) in enumerate(panels):
                ax = axes[row, col]
                img = datasets[aug_type]['images'][sample_idx]
                if img.shape[-1] == 1:
                    ax.imshow(img.squeeze(), cmap='gray')
                else:
                    ax.imshow(img)
                ax.set_title(title, fontsize=10)
                ax.axis('off')

            # Hide the panels this row does not use
            for col in range(len(panels), n_cols):
                axes[row, col].axis('off')

        plt.suptitle(f'{dataset_name} - Training and Test Augmentations', fontsize=14, y=1.02)
        plt.tight_layout()

        save_path = os.path.join(self.processed_dir, f'{dataset_name}_augmentation_comparison.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"[OK] Saved comparison to {save_path}")
        plt.close()
        gc.collect()
    
    def print_dataset_statistics(self, train_datasets, test_datasets, dataset_name):
        """Print statistics for all datasets"""
        print(f"\n{'='*60}")
        print(f"{dataset_name.upper()} - Dataset Statistics")
        print(f"{'='*60}")
        
        print(f"\nTRAINING SETS:")
        for aug_type in train_datasets.keys():
            images = train_datasets[aug_type]['images']
            labels = train_datasets[aug_type]['labels']
            print(f"  {aug_type.upper():<20}: {len(images):>6} samples, shape: {images[0].shape}")
        
        print(f"\nTEST SETS:")
        for aug_type in test_datasets.keys():
            images = test_datasets[aug_type]['images']
            labels = test_datasets[aug_type]['labels']
            print(f"  {aug_type.upper():<20}: {len(images):>6} samples, shape: {images[0].shape}")
        
        print(f"\n{'='*60}")


def main():
    """
    Run the MNIST preprocessing pipeline for the robustness study.

    Loads MNIST, builds the augmented training and test sets, saves them,
    prints their statistics, renders the comparison figure from the saved
    files and prints a summary of what was generated.

    Returns:
        The DatasetPreprocessor used for the run, or None when a MemoryError
        or any other Exception was caught (the error is printed).
    """
    def banner(title):
        # Blank line, rule, title, rule
        print("\n" + "="*60)
        print(title)
        print("="*60)

    banner("ROBUSTNESS STUDY - DATASET PREPROCESSING")

    # The constructor may lower batch_size depending on available memory
    preprocessor = DatasetPreprocessor(
        data_dir='./data',
        img_size=(128, 128),
        batch_size=200
    )

    try:
        # ---- Step 1: load MNIST (grayscale) ----
        mnist_data = preprocessor.load_mnist()

        # ---- Step 2: build the augmented variants ----
        banner("Creating Augmented Datasets for MNIST")

        mnist_train_datasets = preprocessor.create_augmented_train_set(
            mnist_data['train_images'],
            mnist_data['train_labels'],
            mnist_data['class_names'],
            'mnist'
        )
        mnist_test_datasets = preprocessor.create_augmented_test_sets(
            mnist_data['test_images'],
            mnist_data['test_labels'],
            mnist_data['class_names'],
            'mnist'
        )

        # Persist right away, report, then drop the in-memory copies
        print("\nSaving MNIST datasets...")
        preprocessor.save_datasets_separately('mnist', mnist_data, mnist_train_datasets, mnist_test_datasets)
        preprocessor.print_dataset_statistics(mnist_train_datasets, mnist_test_datasets, 'MNIST')

        del mnist_data, mnist_train_datasets, mnist_test_datasets
        gc.collect()

        # ---- Visualizations, rebuilt from the files saved above ----
        banner("Creating Visualizations")

        print("\nCreating MNIST visualizations...")
        train_variants = ('original', 'mixed_augmented', 'combined_augmented')
        test_variants = ('original', 'rotation_15', 'noise', 'scaling_0.8', 'occlusion_25', 'all_combined')

        mnist_train_viz = {
            variant: preprocessor.load_dataset('mnist', 'train', variant)
            for variant in train_variants
        }
        mnist_test_viz = {
            variant: preprocessor.load_dataset('mnist', 'test', variant)
            for variant in test_variants
        }

        preprocessor.visualize_augmentation_comparison(mnist_train_viz, mnist_test_viz, 'MNIST')

        del mnist_train_viz, mnist_test_viz
        gc.collect()

        # ---- Completion summary ----
        banner("[OK] PREPROCESSING COMPLETE")

        summary_lines = (
            "\nGenerated datasets for robustness study:",
            "\nTRAINING SETS (3 versions per dataset):",
            "  1. original - Clean training data",
            "  2. mixed_augmented - Equal parts: original + rotation + noise + scaling + occlusion",
            "  3. combined_augmented - All augmentations applied to every sample",
            "\nTEST SETS (6 versions per dataset):",
            "  1. original - Clean test data",
            "  2. rotation_15 - 15° rotation",
            "  3. noise - Gaussian noise",
            "  4. scaling_0.8 - 0.8x scaling",
            "  5. occlusion_25 - 25% occlusion",
            "  6. all_combined - All augmentations combined",
            f"\nProcessed data saved in: {preprocessor.processed_dir}",
            f"Sample images saved in: {preprocessor.samples_dir}",
            "\n" + "="*60,
        )
        for line in summary_lines:
            print(line)

        return preprocessor

    except MemoryError as e:
        print(f"\n[ERROR] Memory error occurred: {e}")
        for hint in (
            "Suggestions:",
            "  1. Reduce img_size (e.g., to (96, 96) or (64, 64))",
            "  2. Close other applications to free memory",
            "  3. Process datasets one at a time by commenting out one dataset",
            "  4. Increase system swap/virtual memory",
        ):
            print(hint)
        return None
    except Exception as e:
        print(f"\n[ERROR] Error occurred: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Runs on success and on failure alike
        gc.collect()


# Run the preprocessing pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


ROBUSTNESS STUDY - DATASET PREPROCESSING
[OK] Sufficient memory (15.6GB). Batch size: 200

Loading MNIST Dataset
Converting to numpy arrays...
Resizing to (128, 128)...
  Resizing: 0/60000 images processed
  Resizing: 1000/60000 images processed
  Resizing: 2000/60000 images processed
  Resizing: 3000/60000 images processed
  Resizing: 4000/60000 images processed
  Resizing: 5000/60000 images processed
  Resizing: 6000/60000 images processed
  Resizing: 7000/60000 images processed
  Resizing: 8000/60000 images processed
  Resizing: 9000/60000 images processed
  Resizing: 10000/60000 images processed
  Resizing: 11000/60000 images processed
  Resizing: 12000/60000 images processed
  Resizing: 13000/60000 images processed
  Resizing: 14000/60000 images processed
  Resizing: 15000/60000 images processed
  Resizing: 16000/60000 images processed
  Resizing: 17000/60000 images processed
  Resizing: 18000/60000 images processed
  Resizing: 19000/60000 images processed
  Resizing: 20000/60000