# Lab 2.1.2: Dataset Pipeline - Efficient Data Loading

**Module:** 2.1 - Deep Learning with PyTorch  
**Time:** 2 hours  
**Difficulty:** ⭐⭐ (Intermediate)

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Implement a custom `Dataset` class for local image folders
- [ ] Create data augmentation transforms using `torchvision.transforms`
- [ ] Configure `DataLoader` with optimal worker settings
- [ ] Benchmark and optimize data loading for DGX Spark
- [ ] Understand memory-mapped datasets for large data

---

## Prerequisites

- Completed: Lab 2.1.1 (Custom Module Lab)
- Knowledge of: Python iterators, image file formats

---

## Real-World Context

Data loading is often the **bottleneck** in deep learning training! If your GPU is waiting for data, you're wasting expensive compute.

Real scenarios where data pipelines matter:
- **Medical imaging**: Loading huge DICOM files (hundreds of MB each)
- **Autonomous vehicles**: Processing millions of driving images
- **Recommendation systems**: Streaming user interaction data
- **Video analysis**: Extracting frames efficiently

With DGX Spark's 128GB unified memory, we can load more data, but we need to do it efficiently!

---

## ELI5: What is a DataLoader?

> **Imagine a restaurant kitchen...** 🍳
>
> The **Dataset** is like your ingredients pantry - it knows where everything is and how to get it.
>
> The **DataLoader** is like your kitchen staff:
> - They go to the pantry, grab ingredients
> - They prepare multiple orders at once (batching)
> - Multiple cooks work in parallel (num_workers)
> - They shuffle orders to prevent bias (shuffle=True)
> - They have a prep station ready for the next order (prefetching)
>
> The chef (GPU) never has to wait - there's always a prepared batch ready!
>
> **In AI terms:** Dataset defines HOW to load individual samples. DataLoader handles batching, shuffling, and parallel loading.

---

## Part 1: Environment Setup

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

import os
import time
import tempfile
import numpy as np
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Tuple, List, Optional, Callable
from collections import defaultdict
import multiprocessing

# Check environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CPU cores: {multiprocessing.cpu_count()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

---

## Part 2: Creating a Custom Dataset

A map-style PyTorch `Dataset` typically implements three methods:
1. `__init__`: Set up the dataset (discover file paths, store transforms, etc.)
2. `__len__`: Return the total number of samples (used by `DataLoader` and its samplers)
3. `__getitem__`: Return a single sample for a given index
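Before the image example, here is a minimal sketch of a map-style dataset that implements only these three methods. `SquaresDataset` is a hypothetical toy class used purely for illustration: sample `i` pairs the input `i` with the target `i**2`. Wrapping it in a `DataLoader` shows the batching side described in the ELI5 above.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy map-style dataset: sample i is (tensor([i]), i**2)."""

    def __init__(self, n: int):
        # __init__: set up whatever is needed to locate samples (here, just a count)
        self.n = n

    def __len__(self) -> int:
        # __len__: total number of samples; samplers draw indices in range(len)
        return self.n

    def __getitem__(self, idx: int):
        # __getitem__: build exactly one sample for the given index
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return torch.tensor([float(idx)]), idx ** 2


toy = SquaresDataset(10)
loader = DataLoader(toy, batch_size=4, shuffle=False)
for x, y in loader:
    # Default collate stacks inputs into (B, 1) and turns int targets into a tensor
    print(x.shape, y.tolist())
```

With 10 samples and `batch_size=4`, the loader yields batches of 4, 4, and 2; the last batch is short because `drop_last` defaults to `False`.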

Let's start by creating a synthetic dataset for demonstration:

In [None]:
# Create a sample image folder structure for this tutorial
# Using a unique temp directory to avoid conflicts with user data

# Create unique temp directory for this session
SAMPLE_DATA_DIR = tempfile.mkdtemp(prefix='module06_dataset_')
print(f"Using temp directory: {SAMPLE_DATA_DIR}")

def create_sample_dataset(root_dir: str, num_images: int = 100) -> str:
    """
    Create a sample dataset of random-noise images for testing.

    Images are split evenly across three classes; any remainder of
    num_images / 3 is dropped.

    Structure:
        root_dir/
            birds/
                image_0000.jpg
                ...
            cats/
                image_0000.jpg
                image_0001.jpg
                ...
            dogs/
                ...
    """
    root = Path(root_dir)
    classes = ['cats', 'dogs', 'birds']
    per_class = num_images // len(classes)

    for cls in classes:
        class_dir = root / cls
        class_dir.mkdir(parents=True, exist_ok=True)

        for i in range(per_class):
            # Random 64x64 RGB noise; randint's upper bound is exclusive, so 256 covers 0-255
            img_array = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
            img = Image.fromarray(img_array)
            img.save(class_dir / f'image_{i:04d}.jpg')

    print(f"Created sample dataset at '{root_dir}' with {per_class * len(classes)} images")
    return root_dir

# Create the sample dataset in the unique temp directory
sample_data_path = create_sample_dataset(SAMPLE_DATA_DIR, num_images=300)

In [None]:
class ImageFolderDataset(Dataset):
    """
    Custom Dataset for loading images from a folder structure.
    
    Expected folder structure:
        root/
            class1/
                img1.jpg
                img2.jpg
            class2/
                img1.jpg
                ...
    
    Args:
        root_dir: Path to the root directory
        transform: Optional transform to apply to images
        extensions: Tuple of valid image extensions
    
    Example:
        >>> dataset = ImageFolderDataset('./data', transform=T.ToTensor())
        >>> image, label = dataset[0]
        >>> print(image.shape, label)
    """
    
    def __init__(
        self,
        root_dir: str,
        transform: Optional[Callable] = None,
        extensions: Tuple[str, ...] = ('.jpg', '.jpeg', '.png', '.bmp')
    ):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.extensions = extensions
        
        # Discover classes (subdirectories)
        self.classes = sorted([
            d.name for d in self.root_dir.iterdir()
            if d.is_dir()
        ])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        
        # Collect all image paths and labels
        self.samples = []
        for class_name in self.classes:
            class_dir = self.root_dir / class_name
            # sorted() gives the same sample order on every run and every instance
            for img_path in sorted(class_dir.iterdir()):
                if img_path.suffix.lower() in self.extensions:
                    self.samples.append((
                        str(img_path),
                        self.class_to_idx[class_name]
                    ))
        
        print(f"Found {len(self.samples)} images in {len(self.classes)} classes")
    
    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Get a single sample.
        
        Args:
            idx: Index of the sample
            
        Returns:
            Tuple of (image, label)
        """
        img_path, label = self.samples[idx]
        
        # Load image; the context manager closes the file handle right away
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        
        # Apply transforms
        if self.transform is not None:
            image = self.transform(image)
        
        return image, label
    
    def get_class_name(self, label: int) -> str:
        """Get the class name for a given label."""
        return self.classes[label]


# Create and test our dataset
simple_transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
])

dataset = ImageFolderDataset(SAMPLE_DATA_DIR, transform=simple_transform)

# Test it
image, label = dataset[0]
print(f"\nSample image shape: {image.shape}")
print(f"Sample label: {label} ({dataset.get_class_name(label)})")
print(f"Image dtype: {image.dtype}")
print(f"Image range: [{image.min():.3f}, {image.max():.3f}]")

### What Just Happened?

1. We scanned the directory to find all classes (subdirectories)
2. We collected all image paths and their corresponding labels
3. When `__getitem__` is called, we load the image on-demand (lazy loading)
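4. The transform resizes each image to 32×32 and converts it to a float tensor of shape (3, 32, 32) with values in [0, 1]

**Key insight:** Only the list of file paths is kept in memory; each image is read from disk when it is requested.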
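File-per-sample lazy loading works well for images. For very large collections of fixed-size records, a memory-mapped file is a common alternative: `np.memmap` maps the file into virtual memory, and the OS reads in only the pages you actually index. The sketch below is a hypothetical `MemmapImageDataset`; it assumes all images were first written to one raw `uint8` file with shape `(N, C, H, W)` and that the labels fit in memory.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class MemmapImageDataset(Dataset):
    """Hypothetical dataset reading fixed-size uint8 images through np.memmap.

    Assumes the file holds N images stored contiguously as uint8, shape (N, C, H, W).
    """

    def __init__(self, path: str, labels, shape):
        self.path = path
        self.shape = tuple(shape)  # (N, C, H, W)
        self.labels = torch.as_tensor(labels, dtype=torch.long)
        self._data = None          # mapping is opened lazily, once per process

    def _array(self) -> np.memmap:
        # Open on first use so each process creates its own mapping
        if self._data is None:
            self._data = np.memmap(self.path, dtype=np.uint8, mode="r", shape=self.shape)
        return self._data

    def __getstate__(self):
        # When pickled (e.g. for spawn-based workers), send the path, not the mapped array
        state = self.__dict__.copy()
        state["_data"] = None
        return state

    def __len__(self) -> int:
        return self.shape[0]

    def __getitem__(self, idx: int):
        # np.array(...) copies just this one record out of the mapping
        img = torch.from_numpy(np.array(self._array()[idx])).float() / 255.0
        return img, int(self.labels[idx])


# Usage: write a small raw file, then read one record back through the mapping
n, c, h, w = 8, 3, 16, 16
path = os.path.join(tempfile.mkdtemp(), "images.u8")
rng = np.random.default_rng(0)
arr = rng.integers(0, 256, size=(n, c, h, w), dtype=np.uint8)
arr.tofile(path)

mm_dataset = MemmapImageDataset(path, labels=list(range(n)), shape=(n, c, h, w))
img, label = mm_dataset[5]
print(img.shape, label)
```

Because the mapping is dropped from the pickled state, `DataLoader` workers started with the `spawn` method reopen the file themselves instead of receiving a copy of the whole array.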

---

## Part 3: Data Augmentation Transforms

Data augmentation creates variations of training images to:
1. **Prevent overfitting** - model sees more diverse examples
2. **Improve robustness** - model handles real-world variations
3. **Increase effective dataset size** - especially important for small datasets

In [None]:
# Training transforms with augmentation
train_transform = T.Compose([
    # Resize to slightly larger than target
    T.Resize(40),
    
    # Random crop to target size (32x32 for CIFAR-like)
    T.RandomCrop(32, padding=4),
    
    # Random horizontal flip (50% chance)
    T.RandomHorizontalFlip(p=0.5),
    
    # Color jitter - randomly adjust brightness, contrast, saturation
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    
    # Random rotation (small angles)
    T.RandomRotation(degrees=15),
    
    # Convert to tensor (scales to [0, 1])
    T.ToTensor(),
    
    # Normalize using ImageNet statistics
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation/Test transforms - NO augmentation!
val_transform = T.Compose([
    T.Resize(32),
    T.CenterCrop(32),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

print("Training transform:")
print(train_transform)

print("\nValidation transform:")
print(val_transform)

In [None]:
# Visualize augmentations
# train_transform ends with Normalize, which makes colors unreadable in imshow,
# so visualize the same augmentations without that final step
viz_transform = T.Compose(train_transform.transforms[:-1])

def show_augmentations(image_path: str, transform: T.Compose, n_samples: int = 6):
    """Show the original image next to n_samples augmented versions of it.

    `transform` must return a (C, H, W) tensor with values in [0, 1]
    (i.e. it must not include Normalize).
    """
    original = Image.open(image_path).convert('RGB')

    n_cols = (n_samples + 2) // 2  # ceil((n_samples + 1) / 2) panels per row
    fig, axes = plt.subplots(2, n_cols, figsize=(15, 6))
    axes = axes.flatten()

    # Show original
    axes[0].imshow(original)
    axes[0].set_title('Original')

    # Show augmented versions (each call draws new random parameters)
    for i in range(1, n_samples + 1):
        augmented = transform(original)
        # (C, H, W) tensor -> (H, W, C) array for imshow
        img = augmented.permute(1, 2, 0).numpy()
        axes[i].imshow(np.clip(img, 0, 1))
        axes[i].set_title(f'Augmented {i}')

    # Hide ticks on every panel, including unused ones
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show augmentations on a sample image
sample_path = Path(SAMPLE_DATA_DIR) / 'cats' / 'image_0000.jpg'
show_augmentations(str(sample_path), viz_transform)

### Advanced Transform: RandAugment

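RandAugment (Cubuk et al., Google Research) applies `num_ops` transforms drawn at random from a fixed pool (shears, translations, rotation, color and contrast changes, posterize, solarize, and others), all at one shared `magnitude`. That leaves only two hyperparameters to tune.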

In [None]:
# RandAugment is available in newer torchvision versions
try:
    from torchvision.transforms import autoaugment
    
    randaugment_transform = T.Compose([
        T.Resize(40),
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        autoaugment.RandAugment(num_ops=2, magnitude=9),  # Apply 2 random transforms
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    print("RandAugment available!")
except ImportError:
    print("RandAugment not available in this torchvision version")
    randaugment_transform = train_transform

---

## Part 4: DataLoader Configuration

The DataLoader handles:
- **Batching**: Combining samples into batches
- **Shuffling**: Randomizing order each epoch
- **Parallel loading**: Using multiple worker processes
- **Prefetching**: Loading next batches while GPU works

### Key Parameters:
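- `batch_size`: How many samples per batch
- `num_workers`: Number of worker processes loading data in parallel (0 = load in the main process)
- `pin_memory`: Place batches in page-locked (pinned) host memory, which speeds up CPU-to-GPU copies and allows `non_blocking=True` transfers
- `prefetch_factor`: Number of batches each worker loads in advance (only valid when `num_workers > 0`)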

In [None]:
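# Create two views of the same images: augmented for training, deterministic for validation
from torch.utils.data import Subset

train_dataset = ImageFolderDataset(SAMPLE_DATA_DIR, transform=train_transform)
val_dataset = ImageFolderDataset(SAMPLE_DATA_DIR, transform=val_transform)
assert train_dataset.samples == val_dataset.samples  # same file order in both views

# Split indices once (seeded for reproducibility), then apply them to each view
train_size = int(0.8 * len(train_dataset))
indices = torch.randperm(len(train_dataset), generator=torch.Generator().manual_seed(42)).tolist()
train_subset = Subset(train_dataset, indices[:train_size])
val_subset = Subset(val_dataset, indices[train_size:])  # no augmentation on validation

print(f"Training samples: {len(train_subset)}")
print(f"Validation samples: {len(val_subset)}")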

In [None]:
# Basic DataLoader
train_loader = DataLoader(
    train_subset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,  # Faster GPU transfer
    drop_last=True,   # Drop incomplete final batch
)

val_loader = DataLoader(
    val_subset,
    batch_size=64,  # Can use larger batch for validation (no gradients)
    shuffle=False,  # No need to shuffle validation data
    num_workers=4,
    pin_memory=True,
)

# Test the DataLoader
batch = next(iter(train_loader))
images, labels = batch

print(f"Batch images shape: {images.shape}")
print(f"Batch labels shape: {labels.shape}")
print(f"Number of batches per epoch: {len(train_loader)}")

---

## Part 5: Benchmarking Data Loading

Let's find the optimal `num_workers` and `batch_size` for DGX Spark!

In [None]:
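def benchmark_dataloader(
    dataset: Dataset,
    batch_size: int,
    num_workers: int,
    num_batches: int = 50
) -> float:
    """
    Benchmark DataLoader throughput.

    Times up to `num_batches` full batches, including the copy to `device`.
    A small dataset may run out of batches first; the average uses the
    number of batches actually processed.

    Returns:
        Average time per batch in seconds
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,                      # every timed batch has batch_size samples
        persistent_workers=num_workers > 0,  # warmup and timed pass share the same workers
    )

    # Warmup: start the workers and fill their prefetch queues
    for i, _ in enumerate(loader):
        if i >= 5:
            break

    # Benchmark
    n_batches = 0
    start = time.perf_counter()
    for images, labels in loader:
        # Include the host-to-device copy in the measurement
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        n_batches += 1
        if n_batches >= num_batches:
            break

    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending async copies before stopping the clock
    elapsed = time.perf_counter() - start

    if n_batches == 0:
        raise ValueError(f"dataset too small for batch_size={batch_size}")
    return elapsed / n_batches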


# Benchmark different configurations
print("Benchmarking DataLoader configurations...")
print("="*60)

worker_counts = [0, 2, 4, 8]
batch_sizes = [32, 64, 128]

results = defaultdict(dict)

for batch_size in batch_sizes:
    for num_workers in worker_counts:
        try:
            avg_time = benchmark_dataloader(train_dataset, batch_size, num_workers)
            throughput = batch_size / avg_time
            results[batch_size][num_workers] = throughput
            print(f"batch_size={batch_size:3d}, workers={num_workers}: "
                  f"{avg_time*1000:.2f}ms/batch, {throughput:.0f} samples/sec")
        except Exception as e:
            print(f"batch_size={batch_size:3d}, workers={num_workers}: Error - {e}")

print("="*60)

In [None]:
# Visualize benchmark results
if results:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    for batch_size in batch_sizes:
        if batch_size in results:
            workers = sorted(results[batch_size].keys())
            throughputs = [results[batch_size][w] for w in workers]
            ax.plot(workers, throughputs, marker='o', label=f'batch_size={batch_size}')
    
    ax.set_xlabel('Number of Workers')
    ax.set_ylabel('Throughput (samples/sec)')
    ax.set_title('DataLoader Throughput Benchmark')
    ax.legend()
    ax.grid(True)
    
    plt.tight_layout()
    plt.show()

### DGX Spark Optimization Tips

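1. **`num_workers`**: Start with 4-8 on DGX Spark. Each worker is a separate process with its own copy of the dataset object and its own prefetched batches, so too many workers waste memory and compete for CPU time.

2. **`batch_size`**: With 128GB unified memory, you can fit larger batches. Start at 64-128 and increase while watching throughput; larger batches may also need a retuned learning rate.

3. **`pin_memory=True`**: Use this when training on a GPU. It places batches in page-locked host memory, which typically speeds up CPU-to-GPU copies and lets `.to(device, non_blocking=True)` overlap the copy with computation.

4. **`prefetch_factor`**: Batches loaded in advance per worker (effectively 2 by default; only valid when `num_workers > 0`). Try 4 if the GPU is waiting on data.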

5. **`persistent_workers=True`**: Keep workers alive between epochs (faster but uses more memory).

In [None]:
# Optimized DataLoader for DGX Spark
def create_optimized_loaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    batch_size: int = 128,
    num_workers: int = 4
) -> Tuple[DataLoader, DataLoader]:
    """
    Create optimized DataLoaders for DGX Spark.
    
    Args:
        train_dataset: Training dataset
        val_dataset: Validation dataset
        batch_size: Batch size (default: 128, increase for DGX Spark)
        num_workers: Number of data loading workers (default: 4)
    
    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=True if num_workers > 0 else False,
        prefetch_factor=2 if num_workers > 0 else None,
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,  # Larger batch for validation
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True if num_workers > 0 else False,
    )
    
    return train_loader, val_loader


# Create optimized loaders
train_loader, val_loader = create_optimized_loaders(
    train_subset, val_subset, batch_size=64, num_workers=4
)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")

### Understanding the Beta Distribution for Mixup

**Mixup** is a data augmentation technique that creates new training samples by blending two existing samples. The key choice in Mixup is the mixing ratio λ (lambda).

We sample λ from a **Beta distribution**: `λ ~ Beta(α, α)`

The **Beta distribution** is a probability distribution on [0, 1], which makes it a natural fit for mixing ratios:
- When α = 1.0: Uniform distribution (λ is equally likely anywhere in [0, 1])
- When α < 1.0: U-shaped (λ tends toward 0 or 1, so little mixing)
- When α > 1.0: Bell-shaped, centered at 0.5 (more aggressive mixing)

**Common Mixup settings:**
- α = 0.2: Mild mixing (most samples stay close to one of the originals)
- α = 0.4: Moderate mixing
- α = 1.0: Strong mixing (uniform blending)
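The shapes listed above follow directly from the Beta density. With both parameters equal to α, the density, mean, and variance of λ are:

$$
p(\lambda) \propto \lambda^{\alpha-1}(1-\lambda)^{\alpha-1}, \qquad
\mathbb{E}[\lambda] = \tfrac{1}{2}, \qquad
\operatorname{Var}[\lambda] = \frac{\alpha^2}{(2\alpha)^2(2\alpha+1)} = \frac{1}{4(2\alpha+1)}
$$

For α < 1 the exponent α − 1 is negative, so the density grows without bound near λ = 0 and λ = 1 (the U shape); at α = 1 it is constant; for α > 1 it vanishes at the endpoints and peaks at 0.5. For example, Var[λ] = 1/(4 · 1.4) ≈ 0.179 for α = 0.2 versus 1/12 ≈ 0.083 for α = 1.0: the small-α setting has the larger spread because its mass sits near the endpoints, so most draws stay close to one of the two original images.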

In [None]:
# Visualize the Beta distribution for different alpha values
import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(12, 3))

alphas = [0.2, 0.5, 1.0]
for ax, alpha in zip(axes, alphas):
    # Sample 1000 values from Beta(alpha, alpha)
    samples = np.random.beta(alpha, alpha, 1000)
    ax.hist(samples, bins=30, density=True, alpha=0.7, edgecolor='black')
    ax.set_xlabel('λ (mixing ratio)')
    ax.set_ylabel('Density')
    ax.set_title(f'Beta(α={alpha}, α={alpha})')
    ax.set_xlim(0, 1)

plt.suptitle('Beta Distribution: How Mixup Samples λ', fontsize=12)
plt.tight_layout()
plt.show()

# Quick demo of np.random.beta()
print("Sample λ values with α=0.2:")
for i in range(5):
    lam = np.random.beta(0.2, 0.2)
    print(f"  λ = {lam:.3f} → {lam*100:.1f}% of image1, {(1-lam)*100:.1f}% of image2")

---

## ✋ Try It Yourself: Exercise

Create a custom dataset that loads images and applies **Mixup** augmentation.

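Mixup creates new training samples by combining two images and their labels:
```
mixed_image = lam * image1 + (1 - lam) * image2
mixed_label = lam * onehot(label1) + (1 - lam) * onehot(label2)
```

The label line uses one-hot vectors, since integer class indices can't be averaged. In practice you return both labels plus `lam` and mix the two losses instead, which gives the same cross-entropy.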

**Requirements:**
1. Implement a `MixupDataset` wrapper class
2. The class should wrap any existing dataset
3. Implement mixup in `__getitem__`

<details>
<summary>üí° Hint</summary>

```python
import numpy as np
from torch.utils.data import Dataset

class MixupDataset(Dataset):
    def __init__(self, dataset, alpha=0.2):
        self.dataset = dataset
        self.alpha = alpha

    def __len__(self):
        # Same length as the wrapped dataset
        return len(self.dataset)
        
    def __getitem__(self, idx):
        # Get first sample
        img1, label1 = self.dataset[idx]
        
        # Get random second sample
        idx2 = np.random.randint(len(self.dataset))
        img2, label2 = self.dataset[idx2]
        
        # Sample lambda from Beta distribution
        lam = np.random.beta(self.alpha, self.alpha)
        
        # Mix!
        mixed_img = lam * img1 + (1 - lam) * img2
        # For labels, return both and lambda for loss computation
        return mixed_img, (label1, label2, lam)
```
</details>

In [None]:
# YOUR CODE HERE: Implement MixupDataset
class MixupDataset(Dataset):
    """
    Dataset wrapper that applies Mixup augmentation.
    
    Mixup creates new training samples by combining two images:
        mixed_image = lambda * image1 + (1 - lambda) * image2
        
    Args:
        dataset: Base dataset to wrap
        alpha: Beta distribution parameter (default: 0.2)
    
    Returns from __getitem__:
        Tuple of (mixed_image, (label1, label2, lambda))
    """
    
    def __init__(self, dataset: Dataset, alpha: float = 0.2) -> None:
        """Initialize with a base dataset and mixup alpha parameter."""
        # TODO: Store the dataset and alpha
        # self.dataset = ...
        # self.alpha = ...
        raise NotImplementedError("Implement __init__: store dataset and alpha")
    
    def __len__(self) -> int:
        """Return the length of the wrapped dataset."""
        # TODO: Return len(self.dataset)
        raise NotImplementedError("Implement __len__: return length of wrapped dataset")
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Tuple[int, int, float]]:
        """
        Get a mixup sample.
        
        Steps:
        1. Get first sample from dataset[idx]
        2. Get random second sample from dataset
        3. Sample lambda from Beta(alpha, alpha) distribution
        4. Mix images: lam * img1 + (1 - lam) * img2
        5. Return mixed_img, (label1, label2, lam)
        
        Returns:
            Tuple of (mixed_image, (label1, label2, lambda))
        """
        # TODO: Implement the mixup logic
        # img1, label1 = self.dataset[idx]
        # idx2 = np.random.randint(len(self.dataset))
        # img2, label2 = self.dataset[idx2]
        # lam = np.random.beta(self.alpha, self.alpha)
        # mixed_img = lam * img1 + (1 - lam) * img2
        # return mixed_img, (label1, label2, lam)
        raise NotImplementedError("Implement __getitem__: apply mixup augmentation")

# Test your implementation (uncomment after implementing)
# mixup_dataset = MixupDataset(train_dataset, alpha=0.2)
# mixed_img, (label1, label2, lam) = mixup_dataset[0]
# print(f"Mixed image shape: {mixed_img.shape}")
# print(f"Labels: {label1}, {label2}, lambda={lam:.3f}")
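Once `MixupDataset` works, the default collate function turns each batch's `(label1, label2, lam)` tuples into three tensors of shape `(B,)`, with `lam` stored as `float64`. The training loss then mixes two cross-entropies per sample. `mixup_loss` below is a hypothetical helper sketched for this lab, not a PyTorch API:

```python
import torch
import torch.nn.functional as F


def mixup_loss(logits: torch.Tensor,
               label1: torch.Tensor,
               label2: torch.Tensor,
               lam: torch.Tensor) -> torch.Tensor:
    """Hypothetical Mixup cross-entropy with one lambda per sample.

    logits: (B, num_classes); label1/label2: (B,) class indices;
    lam: (B,) mixing ratios as collated from (label1, label2, lam) tuples.
    """
    # Default collate turns Python floats into float64; match the logits
    lam = lam.to(device=logits.device, dtype=logits.dtype)
    loss1 = F.cross_entropy(logits, label1, reduction="none")
    loss2 = F.cross_entropy(logits, label2, reduction="none")
    # Mixing the two losses equals cross-entropy against the mixed one-hot target
    return (lam * loss1 + (1 - lam) * loss2).mean()


# Usage with a fake batch of 4 samples and 3 classes
torch.manual_seed(0)
logits = torch.randn(4, 3)
label1 = torch.tensor([0, 1, 2, 0])
label2 = torch.tensor([1, 1, 0, 2])
lam = torch.tensor([0.9, 0.5, 0.1, 1.0], dtype=torch.float64)
print(mixup_loss(logits, label1, label2, lam))
```

`reduction="none"` matters here: because each sample has its own λ, the per-sample losses must be weighted before averaging.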

---

## Common Mistakes

### Mistake 1: Applying augmentation to validation data

```python
# ❌ Wrong - augmenting validation data adds noise to metrics
val_transform = T.Compose([
    T.RandomHorizontalFlip(),  # NO!
    T.ToTensor(),
])

# ✅ Right - validation should be deterministic
val_transform = T.Compose([
    T.Resize(32),
    T.CenterCrop(32),  # Deterministic crop
    T.ToTensor(),
])
```

### Mistake 2: Wrong number of workers

```python
# ❌ Wrong - too many workers on limited system
loader = DataLoader(dataset, num_workers=32)  # CPU has only 8 cores

# ✅ Right - cap by available cores
import multiprocessing
num_workers = min(4, multiprocessing.cpu_count())
loader = DataLoader(dataset, num_workers=num_workers)
```
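`multiprocessing.cpu_count()` counts every core on the machine, including ones this process may not be allowed to use. On Linux, `os.sched_getaffinity(0)` returns the set of CPUs the process can actually run on. `suggest_num_workers` is a hypothetical helper sketching that idea; the cap of 8 and reserving one core for the main process are assumptions to tune by benchmarking:

```python
import os


def suggest_num_workers(cap: int = 8) -> int:
    """Hypothetical helper: pick a DataLoader worker count from usable CPUs.

    The default cap of 8 is an assumption, not a PyTorch default.
    """
    try:
        # Linux: respects CPU affinity masks (e.g. set via taskset)
        usable = len(os.sched_getaffinity(0))
    except AttributeError:
        # macOS/Windows lack sched_getaffinity; fall back to the total count
        usable = os.cpu_count() or 1
    # Leave one core for the main (training) process
    return max(0, min(cap, usable - 1))


print(suggest_num_workers())
```

CPU quotas set through cgroups (for example `docker run --cpus`) are not reflected in the affinity mask, so inside containers a manual cap may still be needed.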

### Mistake 3: Forgetting `pin_memory` on GPU

```python
# ❌ Wrong - slower GPU transfer
loader = DataLoader(dataset, batch_size=64)

# ✅ Right - faster GPU transfer
loader = DataLoader(dataset, batch_size=64, pin_memory=True)
```

### Mistake 4: Not handling normalization correctly

```python
# ❌ Wrong - statistics that match neither the data nor the pretrained model
transform = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Generic

# ✅ Right - use statistics that match your setup
# For ImageNet-pretrained models (the stats the weights were trained with):
transform = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# For CIFAR-10 (per-channel mean/std of the training set):
transform = T.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616])
```
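When you train from scratch rather than fine-tuning a pretrained model, the statistics should come from your own training split. `channel_mean_std` is a hypothetical helper that computes them in one pass; it assumes images are already `ToTensor()` outputs in [0, 1] and not yet normalized:

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def channel_mean_std(dataset: Dataset, batch_size: int = 256):
    """Hypothetical helper: per-channel mean and (population) std over all pixels.

    Assumes dataset[i] returns (image, label) with image a (C, H, W) float
    tensor in [0, 1], i.e. after ToTensor() but without Normalize().
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    n_pixels = 0
    sum_c = sum_sq_c = None
    for images, _ in loader:
        # (B, C, H, W) -> (C, B*H*W); float64 keeps the running sums accurate
        x = images.double().transpose(0, 1).flatten(1)
        if sum_c is None:
            sum_c = torch.zeros(x.shape[0], dtype=torch.float64)
            sum_sq_c = torch.zeros_like(sum_c)
        sum_c += x.sum(dim=1)
        sum_sq_c += (x * x).sum(dim=1)
        n_pixels += x.shape[1]
    mean = sum_c / n_pixels
    std = (sum_sq_c / n_pixels - mean ** 2).clamp_min(0).sqrt()
    return mean.float(), std.float()


# Usage on random data; on real data, pass the training split only
fake = TensorDataset(torch.rand(64, 3, 8, 8), torch.zeros(64, dtype=torch.long))
mean, std = channel_mean_std(fake)
print(mean, std)
```

Compute the statistics on the training split only, then reuse the same values for the validation and test transforms.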

---

## Checkpoint

You've learned:
- ✅ How to create custom `Dataset` classes
- ✅ Data augmentation with `torchvision.transforms`
- ✅ Configuring `DataLoader` for optimal performance
- ✅ Benchmarking data loading throughput
- ✅ DGX Spark-specific optimizations

---

## Challenge (Optional)

Implement a **Prefetching DataLoader** that:
1. Runs data loading in a background thread
2. Keeps 3-5 batches ready at all times
3. Transfers batches to GPU asynchronously

This can further improve training speed by overlapping data loading with GPU computation.
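As a starting point for steps 1 and 2, here is a sketch of a background-thread prefetcher built on `queue.Queue`. `BackgroundPrefetcher` is a hypothetical name, and the asynchronous GPU copy of step 3 is deliberately left out:

```python
import queue
import threading
from typing import Iterable, Iterator


class _Failure:
    """Wraps an exception raised in the producer thread."""

    def __init__(self, exc: BaseException):
        self.exc = exc


_DONE = object()  # sentinel marking the end of the source iterable


class BackgroundPrefetcher:
    """Hypothetical: iterate `source` while a daemon thread keeps up to `depth` items ready."""

    def __init__(self, source: Iterable, depth: int = 3):
        self.source = source
        self.depth = depth

    def __iter__(self) -> Iterator:
        # Bounded queue: put() blocks once `depth` items are waiting
        q: queue.Queue = queue.Queue(maxsize=self.depth)

        def producer():
            try:
                for item in self.source:
                    q.put(item)
                q.put(_DONE)
            except Exception as exc:
                # Hand the error to the consumer instead of dying silently
                q.put(_Failure(exc))

        threading.Thread(target=producer, daemon=True).start()
        while True:
            item = q.get()
            if item is _DONE:
                return
            if isinstance(item, _Failure):
                raise item.exc
            yield item


print(list(BackgroundPrefetcher(range(10), depth=3)))
```

The bounded queue does the buffering: at most `depth` batches sit in memory at once. If the consumer stops early, the daemon thread stays blocked on `put()` until the program exits, which a full solution should handle. Since a `DataLoader` with `num_workers > 0` already prefetches in worker processes, the gain is likely largest with `num_workers=0` or once the thread also performs the GPU copy.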

---

## Further Reading

- [PyTorch Data Loading Tutorial](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)
- [torchvision.transforms Documentation](https://pytorch.org/vision/stable/transforms.html)
- [RandAugment Paper](https://arxiv.org/abs/1909.13719)
- [Mixup Paper](https://arxiv.org/abs/1710.09412)

In [None]:
# Cleanup
import shutil
import gc

# Remove temp sample data directory
if os.path.exists(SAMPLE_DATA_DIR):
    shutil.rmtree(SAMPLE_DATA_DIR)
    print(f"Removed temp directory: {SAMPLE_DATA_DIR}")

# Release unreferenced objects first, then free cached GPU memory
gc.collect()
torch.cuda.empty_cache()

print("Cleanup complete!")