# Building Custom DataModules

**File Location:** `notebooks/02_datamodules_and_metrics/04_building_datamodules.ipynb`

## Introduction

This notebook shows how to build a `LightningDataModule` for four kinds of data: vision, NLP, tabular, and time series. Each example uses synthetic data and covers data loading, preprocessing, train/validation/test splitting, and basic validation checks.

## DataModule Architecture Fundamentals

### Core DataModule Structure

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from typing import Optional, Tuple
import numpy as np
from pathlib import Path

class BaseDataModule(pl.LightningDataModule):
    """Template for all DataModules"""
    
    def __init__(
        self,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False
    ):
        super().__init__()
        self.save_hyperparameters()
        
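    def prepare_data(self):
        """Download data, tokenize, etc. Called from a single process (once per node by default) in distributed training"""
        # This is where you download datasets and do one-time preprocessing.
        # Avoid assigning state here (self.x = ...): other processes will not see it.
        pass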
    
    def setup(self, stage: Optional[str] = None):
        """Set up datasets. Called on every GPU in distributed training"""
        # This is where you create train/val/test splits
        # stage can be 'fit', 'validate', 'test', or 'predict'
        pass
    
    def train_dataloader(self):
        """Return training DataLoader"""
        pass
    
    def val_dataloader(self):
        """Return validation DataLoader"""
        pass
    
    def test_dataloader(self):
        """Return test DataLoader"""  
        pass
    
    def predict_dataloader(self):
        """Return prediction DataLoader"""
        pass

print("✓ Base DataModule template defined")
```
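The constructor arguments in the template map directly onto `DataLoader` keyword arguments. One interaction is worth guarding against: `DataLoader` raises a `ValueError` when `persistent_workers=True` is combined with `num_workers=0`. The following is a minimal sketch of such a guard; `make_loader_kwargs` is a hypothetical helper name, not part of PyTorch or Lightning.

```python
from typing import Any, Dict


def make_loader_kwargs(
    batch_size: int = 32,
    num_workers: int = 0,
    pin_memory: bool = True,
    persistent_workers: bool = False,
) -> Dict[str, Any]:
    """Build keyword arguments for torch.utils.data.DataLoader (hypothetical helper)."""
    return {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        # Persistent workers only make sense when worker processes exist
        "persistent_workers": persistent_workers and num_workers > 0,
    }


kwargs = make_loader_kwargs(num_workers=0, persistent_workers=True)
print(kwargs)
```

Inside a DataModule, the result could be passed along as `DataLoader(self.train_dataset, shuffle=True, **kwargs)`, so every loader shares the same settings.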

## Vision DataModule

### Custom Vision Dataset and DataModule

```python
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

class SyntheticVisionDataset(Dataset):
    """Synthetic image dataset for demonstration"""
    
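    def __init__(self, num_samples: int, image_size: Tuple[int, int] = (32, 32), 
                 num_classes: int = 10, transform=None, seed: int = 42):
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_classes = num_classes
        self.transform = transform
        
        # Generate synthetic data once with a local generator, so creating a
        # dataset does not reset the global RNG. Datasets built with the same
        # seed draw from the same random stream; pass a different seed per split
        # if the splits must not share samples.
        generator = torch.Generator().manual_seed(seed)
        self.data = torch.rand(num_samples, 3, *image_size, generator=generator) * 255
        self.targets = torch.randint(0, num_classes, (num_samples,), generator=generator)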
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        image = self.data[idx]
        target = self.targets[idx]
        
        # Convert to PIL for transforms
        image = transforms.ToPILImage()(image.byte())
        
        if self.transform:
            image = self.transform(image)
        
        return image, target

class VisionDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_train_samples: int = 5000,
        num_val_samples: int = 1000,
        num_test_samples: int = 1000,
        image_size: Tuple[int, int] = (32, 32),
        num_classes: int = 10,
        batch_size: int = 32,
        num_workers: int = 2,
        pin_memory: bool = True,
        augment: bool = True
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # Define transforms. Augmentations are added conditionally instead of
        # being swapped for identity lambdas, because lambdas cannot be pickled
        # when DataLoader workers start with the "spawn" method (Windows, macOS).
        augmentations = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ] if augment else []
        
        self.train_transforms = transforms.Compose([
            transforms.Resize(image_size),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.val_transforms = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            self.train_dataset = SyntheticVisionDataset(
                num_samples=self.hparams.num_train_samples,
                image_size=self.hparams.image_size,
                num_classes=self.hparams.num_classes,
                transform=self.train_transforms
            )
            self.val_dataset = SyntheticVisionDataset(
                num_samples=self.hparams.num_val_samples,
                image_size=self.hparams.image_size,
                num_classes=self.hparams.num_classes,
                transform=self.val_transforms
            )
        
        if stage == "test" or stage is None:
            self.test_dataset = SyntheticVisionDataset(
                num_samples=self.hparams.num_test_samples,
                image_size=self.hparams.image_size,
                num_classes=self.hparams.num_classes,
                transform=self.val_transforms
            )
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            drop_last=True  # For batch norm stability
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory
        )

# Test the vision datamodule
vision_dm = VisionDataModule(
    num_train_samples=1000,
    num_val_samples=200,
    batch_size=16,
    augment=True
)

vision_dm.setup("fit")

# Test a batch
train_loader = vision_dm.train_dataloader()
batch = next(iter(train_loader))
images, labels = batch

print(f"✓ Vision DataModule created")
print(f"Image batch shape: {images.shape}")
print(f"Labels batch shape: {labels.shape}")
print(f"Image range: [{images.min():.3f}, {images.max():.3f}]")
print(f"Unique labels in batch: {torch.unique(labels).tolist()}")
```
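The transforms in this section normalize with the ImageNet channel statistics, which do not describe uniformly random synthetic images. If dataset-specific statistics are preferred, they can be estimated from the training images. The sketch below uses NumPy and assumes the images are available as an `(N, C, H, W)` array scaled to [0, 1]; `channel_stats` is an illustrative helper, not a torchvision function.

```python
from typing import Tuple

import numpy as np


def channel_stats(images: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Per-channel mean and std of an (N, C, H, W) array with values in [0, 1]."""
    # Reduce over the sample, height, and width axes; keep the channel axis
    mean = images.mean(axis=(0, 2, 3))
    std = images.std(axis=(0, 2, 3))
    return mean, std


# Stand-in for training images: uniform noise, like the synthetic dataset
rng = np.random.default_rng(0)
train_images = rng.random((256, 3, 32, 32))

mean, std = channel_stats(train_images)
print("mean:", mean.round(3), "std:", std.round(3))
```

The resulting values could replace the `mean` and `std` arguments of `transforms.Normalize`. Only the training split should be used for the estimate, so validation and test data do not influence preprocessing.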

## NLP DataModule

### Text Processing and NLP DataModule

```python
from collections import Counter
import string
import random

class TextDataset(Dataset):
    """Simple text classification dataset"""
    
    def __init__(self, texts, labels, vocab, max_length=100):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Tokenize and convert to indices
        tokens = text.lower().translate(str.maketrans('', '', string.punctuation)).split()
        indices = [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]
        
        # Pad or truncate
        if len(indices) > self.max_length:
            indices = indices[:self.max_length]
        else:
            indices.extend([self.vocab['<PAD>']] * (self.max_length - len(indices)))
        
        return torch.tensor(indices, dtype=torch.long), torch.tensor(label, dtype=torch.long)

class NLPDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_samples: int = 5000,
        vocab_size: int = 1000,
        max_length: int = 100,
        min_freq: int = 2,
        batch_size: int = 32,
        num_workers: int = 0
    ):
        super().__init__()
        self.save_hyperparameters()
        self.vocab = None
        
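    def prepare_data(self):
        """Generate synthetic text data and build vocabulary"""
        # Caution: this assigns state (self.texts, self.labels, self.vocab) to
        # keep the notebook simple. In multi-process training, Lightning calls
        # prepare_data() from one process only, which is why setup() below
        # re-runs it when the vocabulary is missing.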
        # Generate synthetic text data
        random.seed(42)
        
        # Simple vocabulary
        base_words = [
            'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with',
            'for', 'as', 'was', 'on', 'are', 'you', 'all', 'not', 'can', 'had',
            'good', 'bad', 'great', 'terrible', 'amazing', 'awful', 'wonderful', 'horrible',
            'love', 'hate', 'like', 'dislike', 'enjoy', 'despise', 'adore', 'loathe',
            'movie', 'film', 'book', 'story', 'show', 'series', 'episode', 'scene'
        ]
        
        # Generate synthetic reviews
        positive_templates = [
            "this {} is really good and amazing",
            "I love this {} it is wonderful",
            "great {} with excellent quality",
            "amazing {} highly recommend"
        ]
        
        negative_templates = [
            "this {} is terrible and awful",
            "I hate this {} it is horrible",
            "bad {} with poor quality", 
            "awful {} do not recommend"
        ]
        
        # Generate texts and labels
        self.texts = []
        self.labels = []
        
        for _ in range(self.hparams.num_samples):
            # Pick a sentiment, then fill a matching template
            is_positive = random.random() > 0.5
            template = random.choice(positive_templates if is_positive else negative_templates)
            item = random.choice(['movie', 'book', 'show', 'film'])
            text = template.format(item)
            # Add some random words
            extra_words = random.choices(base_words, k=random.randint(3, 8))
            text += " " + " ".join(extra_words)
            self.texts.append(text)
            self.labels.append(1 if is_positive else 0)
        
        # Build vocabulary
        all_tokens = []
        for text in self.texts:
            tokens = text.lower().translate(str.maketrans('', '', string.punctuation)).split()
            all_tokens.extend(tokens)
        
        token_counts = Counter(all_tokens)
        
        # Create vocab with special tokens
        self.vocab = {'<PAD>': 0, '<UNK>': 1}
        idx = 2
        for token, count in token_counts.most_common():
            if count >= self.hparams.min_freq and idx < self.hparams.vocab_size:
                self.vocab[token] = idx
                idx += 1
        
        print(f"✓ Generated {len(self.texts)} text samples")
        print(f"✓ Built vocabulary with {len(self.vocab)} tokens")
        
    def setup(self, stage: Optional[str] = None):
        if self.vocab is None:
            self.prepare_data()
            
        # Split data
        train_size = int(0.8 * len(self.texts))
        val_size = int(0.1 * len(self.texts))
        test_size = len(self.texts) - train_size - val_size
        
        indices = list(range(len(self.texts)))
        random.seed(42)
        random.shuffle(indices)
        
        train_indices = indices[:train_size]
        val_indices = indices[train_size:train_size + val_size]
        test_indices = indices[train_size + val_size:]
        
        if stage == "fit" or stage is None:
            train_texts = [self.texts[i] for i in train_indices]
            train_labels = [self.labels[i] for i in train_indices]
            self.train_dataset = TextDataset(train_texts, train_labels, self.vocab, self.hparams.max_length)
            
            val_texts = [self.texts[i] for i in val_indices]
            val_labels = [self.labels[i] for i in val_indices]
            self.val_dataset = TextDataset(val_texts, val_labels, self.vocab, self.hparams.max_length)
        
        if stage == "test" or stage is None:
            test_texts = [self.texts[i] for i in test_indices]
            test_labels = [self.labels[i] for i in test_indices]
            self.test_dataset = TextDataset(test_texts, test_labels, self.vocab, self.hparams.max_length)
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers
        )

# Test NLP DataModule
nlp_dm = NLPDataModule(num_samples=1000, vocab_size=500, max_length=50, batch_size=16)
nlp_dm.setup("fit")

# Test a batch
train_loader = nlp_dm.train_dataloader()
batch = next(iter(train_loader))
texts, labels = batch

print(f"✓ NLP DataModule created")
print(f"Text batch shape: {texts.shape}")
print(f"Labels batch shape: {labels.shape}")
print(f"Vocabulary size: {len(nlp_dm.vocab)}")
print(f"Sample text indices: {texts[0][:10].tolist()}")
print(f"Sample label: {labels[0].item()}")
```
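The vocabulary logic (frequency threshold, size cap, padding and truncation) can be tried in isolation on a handful of sentences. The sketch below builds the vocabulary from training texts only, a common way to keep validation and test tokens from influencing preprocessing; `build_vocab` and `encode` are illustrative names, and tokenization is reduced to lowercasing and whitespace splitting.

```python
from collections import Counter
from typing import Dict, List


def build_vocab(texts: List[str], min_freq: int = 2, max_size: int = 1000) -> Dict[str, int]:
    """Map frequent tokens to indices; 0 and 1 are reserved for <PAD> and <UNK>."""
    counts = Counter(token for text in texts for token in text.lower().split())
    vocab = {"<PAD>": 0, "<UNK>": 1}
    # most_common() is sorted by count, so we can stop at the first rare token
    for token, count in counts.most_common():
        if count < min_freq or len(vocab) >= max_size:
            break
        vocab[token] = len(vocab)
    return vocab


def encode(text: str, vocab: Dict[str, int], max_length: int) -> List[int]:
    """Convert text to a fixed-length list of token indices."""
    ids = [vocab.get(token, vocab["<UNK>"]) for token in text.lower().split()]
    ids = ids[:max_length]  # truncate
    return ids + [vocab["<PAD>"]] * (max_length - len(ids))  # pad


train_texts = ["good movie", "good book", "bad movie"]
vocab = build_vocab(train_texts, min_freq=2)
encoded = encode("good film movie", vocab, max_length=5)
print(vocab)
print(encoded)
```

Here `film` never appears in the training texts, so it maps to `<UNK>`, and the two trailing positions are filled with `<PAD>`.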

## Tabular DataModule

### Structured Data Processing

```python
import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split

class TabularDataset(Dataset):
    """Dataset for tabular data"""
    
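    def __init__(self, features, targets):
        # Features are always float32. Integer targets (class ids) become int64
        # for classification losses; anything else is treated as a float target.
        self.features = torch.as_tensor(features, dtype=torch.float32)
        target_dtype = torch.long if np.issubdtype(targets.dtype, np.integer) else torch.float32
        self.targets = torch.as_tensor(targets, dtype=target_dtype)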
        
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

class TabularDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_samples: int = 10000,
        num_features: int = 20,
        num_classes: int = 3,
        task_type: str = "classification",  # "classification" or "regression"
        batch_size: int = 64,
        num_workers: int = 0,
        normalize: bool = True
    ):
        super().__init__()
        self.save_hyperparameters()
        self.scaler = StandardScaler() if normalize else None
        
    def prepare_data(self):
        """Generate synthetic tabular data"""
        np.random.seed(42)
        
        # Generate features with different distributions
        features = []
        
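        # Split the feature budget so the three groups always add up to
        # num_features (// 2 plus two // 4 groups would drop columns whenever
        # num_features is not a multiple of 4, e.g. 15 -> 7 + 3 + 3 = 13)
        n_normal = self.hparams.num_features // 2
        n_uniform = self.hparams.num_features // 4
        n_exp = self.hparams.num_features - n_normal - n_uniform
        
        # Normal features
        normal_features = np.random.normal(0, 1, (self.hparams.num_samples, n_normal))
        features.append(normal_features)
        
        # Uniform features
        uniform_features = np.random.uniform(-2, 2, (self.hparams.num_samples, n_uniform))
        features.append(uniform_features)
        
        # Exponential features
        exp_features = np.random.exponential(1, (self.hparams.num_samples, n_exp))
        features.append(exp_features)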
        
        self.features = np.concatenate(features, axis=1)
        
        if self.hparams.task_type == "classification":
            # Create targets with some relationship to features
            feature_sum = np.sum(self.features[:, :5], axis=1)
            self.targets = np.digitize(feature_sum, 
                                     bins=np.percentile(feature_sum, 
                                                      [100/self.hparams.num_classes * i 
                                                       for i in range(1, self.hparams.num_classes)]))
            self.targets = np.clip(self.targets, 0, self.hparams.num_classes - 1)
        else:  # regression
            # Create continuous target
            weights = np.random.normal(0, 0.5, self.hparams.num_features)
            self.targets = np.dot(self.features, weights) + np.random.normal(0, 0.1, self.hparams.num_samples)
        
        print(f"✓ Generated tabular data: {self.features.shape}")
        print(f"✓ Task type: {self.hparams.task_type}")
        if self.hparams.task_type == "classification":
            print(f"✓ Class distribution: {np.bincount(self.targets)}")
        else:
            print(f"✓ Target range: [{self.targets.min():.3f}, {self.targets.max():.3f}]")
    
    def setup(self, stage: Optional[str] = None):
        if not hasattr(self, 'features'):
            self.prepare_data()
        
        # Split data
        X_train, X_temp, y_train, y_temp = train_test_split(
            self.features, self.targets, test_size=0.3, random_state=42, 
            stratify=self.targets if self.hparams.task_type == "classification" else None
        )
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, test_size=0.5, random_state=42,
            stratify=y_temp if self.hparams.task_type == "classification" else None
        )
        
        # Normalize features. The scaler is always fitted on the training split
        # only; the split is deterministic (random_state=42), so refitting on a
        # later setup() call reproduces the same statistics.
        if self.hparams.normalize:
            X_train = self.scaler.fit_transform(X_train)
            X_val = self.scaler.transform(X_val)
            X_test = self.scaler.transform(X_test)
        
        if stage == "fit" or stage is None:
            self.train_dataset = TabularDataset(X_train, y_train)
            self.val_dataset = TabularDataset(X_val, y_val)
        
        if stage == "test" or stage is None:
            self.test_dataset = TabularDataset(X_test, y_test)
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers
        )

# Test Tabular DataModule
tabular_dm = TabularDataModule(
    num_samples=5000,
    num_features=15,
    num_classes=4,
    task_type="classification",
    batch_size=32
)

tabular_dm.setup("fit")

# Test a batch
train_loader = tabular_dm.train_dataloader()
batch = next(iter(train_loader))
features, targets = batch

print(f"✓ Tabular DataModule created")
print(f"Features batch shape: {features.shape}")
print(f"Targets batch shape: {targets.shape}")
print(f"Feature range: [{features.min():.3f}, {features.max():.3f}]")
```
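The classification targets in this section come from percentile binning: a continuous score is cut at evenly spaced percentiles, which yields roughly balanced classes. A small standalone sketch of that step, using a toy score vector; `quantile_labels` is an illustrative name.

```python
import numpy as np


def quantile_labels(scores: np.ndarray, num_classes: int) -> np.ndarray:
    """Assign class ids 0..num_classes-1 by cutting scores at evenly spaced percentiles."""
    # num_classes - 1 interior cut points, e.g. 33.3% and 66.7% for 3 classes
    cut_points = np.percentile(
        scores, [100 / num_classes * i for i in range(1, num_classes)]
    )
    # digitize returns, for each score, the index of the bin it falls into
    return np.digitize(scores, bins=cut_points)


scores = np.arange(12, dtype=float)
labels = quantile_labels(scores, num_classes=3)
print(labels)
print(np.bincount(labels))
```

Because the cut points depend on the data, they should be computed once and reused if new samples need labels under the same scheme.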

## Time Series DataModule

### Sequential Data Processing

```python
class TimeSeriesDataset(Dataset):
    """Dataset for time series data"""
    
    def __init__(self, sequences, targets, sequence_length):
        self.sequences = sequences
        self.targets = targets
        self.sequence_length = sequence_length
        
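    def __len__(self):
        # Number of complete windows of length sequence_length
        return len(self.sequences) - self.sequence_length + 1
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx:idx + self.sequence_length]
        # targets[i] is assumed to hold the value that follows sequences[i]
        # (as built in prepare_data below), so the value right after the
        # window's last step sits at index idx + sequence_length - 1
        target = self.targets[idx + self.sequence_length - 1]
        return torch.FloatTensor(sequence), torch.FloatTensor([target])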

class TimeSeriesDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_samples: int = 10000,
        num_features: int = 5,
        sequence_length: int = 50,
        batch_size: int = 32,
        num_workers: int = 0,
        normalize: bool = True
    ):
        super().__init__()
        self.save_hyperparameters()
        self.scaler = StandardScaler() if normalize else None
        
    def prepare_data(self):
        """Generate synthetic time series data"""
        np.random.seed(42)
        
        # Generate time series with trend, seasonality, and noise
        time_steps = np.arange(self.hparams.num_samples)
        
        # Multiple features with different patterns
        series_data = []
        
        for feature_idx in range(self.hparams.num_features):
            # Trend component
            trend = 0.001 * time_steps + np.random.normal(0, 0.1)
            
            # Seasonal component
            seasonal = np.sin(2 * np.pi * time_steps / 365) * (0.5 + feature_idx * 0.2)
            
            # Weekly pattern
            weekly = np.sin(2 * np.pi * time_steps / 7) * 0.3
            
            # Noise
            noise = np.random.normal(0, 0.2, self.hparams.num_samples)
            
            # Combine components
            feature_series = trend + seasonal + weekly + noise
            series_data.append(feature_series)
        
        self.data = np.column_stack(series_data)
        
        # Create targets (next value prediction)
        self.targets = self.data[1:, 0]  # Predict first feature
        self.data = self.data[:-1]  # Remove last sample
        
        print(f"✓ Generated time series data: {self.data.shape}")
        print(f"✓ Sequence length: {self.hparams.sequence_length}")
        print(f"✓ Target range: [{self.targets.min():.3f}, {self.targets.max():.3f}]")
        
    def setup(self, stage: Optional[str] = None):
        if not hasattr(self, 'data'):
            self.prepare_data()
        
        # Time series split (no shuffling to preserve temporal order)
        train_size = int(0.7 * len(self.data))
        val_size = int(0.15 * len(self.data))
        
        train_data = self.data[:train_size]
        train_targets = self.targets[:train_size]
        
        val_data = self.data[train_size:train_size + val_size]
        val_targets = self.targets[train_size:train_size + val_size]
        
        test_data = self.data[train_size + val_size:]
        test_targets = self.targets[train_size + val_size:]
        
        # Normalize inputs with statistics from the training split only.
        # Targets are left in their original units.
        if self.hparams.normalize:
            train_data = self.scaler.fit_transform(train_data)
            val_data = self.scaler.transform(val_data)
            test_data = self.scaler.transform(test_data)
        
        if stage == "fit" or stage is None:
            self.train_dataset = TimeSeriesDataset(
                train_data, train_targets, self.hparams.sequence_length
            )
            self.val_dataset = TimeSeriesDataset(
                val_data, val_targets, self.hparams.sequence_length
            )
        
        if stage == "test" or stage is None:
            self.test_dataset = TimeSeriesDataset(
                test_data, test_targets, self.hparams.sequence_length
            )
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,  # Safe here: each sample is a self-contained window, so order inside a window is preserved
            num_workers=self.hparams.num_workers
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers
        )

# Test Time Series DataModule
ts_dm = TimeSeriesDataModule(
    num_samples=5000,
    num_features=3,
    sequence_length=30,
    batch_size=16
)

ts_dm.setup("fit")

# Test a batch
train_loader = ts_dm.train_dataloader()
batch = next(iter(train_loader))
sequences, targets = batch

print(f"✓ Time Series DataModule created")
print(f"Sequence batch shape: {sequences.shape}")
print(f"Targets batch shape: {targets.shape}")
print(f"Sequence range: [{sequences.min():.3f}, {sequences.max():.3f}]")
```
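Window and target alignment is easy to get wrong by one step, so it helps to check it on a series whose values equal their positions. The sketch below builds sliding windows for next-value prediction from a single 1-D series; `make_windows` is an illustrative helper and assumes the target is the value immediately after each window.

```python
from typing import Tuple

import numpy as np


def make_windows(series: np.ndarray, sequence_length: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return (windows, targets) where targets[i] is the value right after windows[i]."""
    num_windows = len(series) - sequence_length
    windows = np.stack(
        [series[i:i + sequence_length] for i in range(num_windows)]
    )
    # The value following window i sits at position i + sequence_length
    targets = series[sequence_length:]
    return windows, targets


series = np.arange(6.0)  # 0, 1, 2, 3, 4, 5
X, y = make_windows(series, sequence_length=3)
print(X)
print(y)
```

Each row of `X` ends one step before the matching entry of `y`, which is the property to verify whenever inputs and targets are shifted separately.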

## DataModule Best Practices

### Data Validation and Health Checks

```python
def validate_datamodule(dm: pl.LightningDataModule, stage: str = "fit"):
    """Comprehensive DataModule validation"""
    print(f"=== Validating DataModule: {type(dm).__name__} ===")
    
    # Setup the datamodule
    dm.setup(stage)
    
    # Check dataloaders exist
    loaders = []
    if hasattr(dm, 'train_dataloader') and stage in ["fit", None]:
        train_loader = dm.train_dataloader()
        loaders.append(("train", train_loader))
        
    if hasattr(dm, 'val_dataloader') and stage in ["fit", "validate", None]:
        val_loader = dm.val_dataloader()
        loaders.append(("val", val_loader))
        
    if hasattr(dm, 'test_dataloader') and stage in ["test", None]:
        test_loader = dm.test_dataloader()
        loaders.append(("test", test_loader))
    
    print(f"✓ Found {len(loaders)} dataloaders")
    
    # Validate each dataloader
    for loader_name, loader in loaders:
        print(f"\n--- {loader_name.upper()} Loader ---")
        
        # Basic properties
        print(f"Dataset size: {len(loader.dataset)}")
        print(f"Batch size: {loader.batch_size}")
        print(f"Number of batches: {len(loader)}")
        
        # Test first batch
        try:
            batch = next(iter(loader))
            if isinstance(batch, (list, tuple)) and len(batch) == 2:
                inputs, targets = batch
                print(f"Input shape: {inputs.shape}")
                print(f"Target shape: {targets.shape}")
                print(f"Input dtype: {inputs.dtype}")
                print(f"Target dtype: {targets.dtype}")
                
                # Check for NaN/Inf
                if torch.isnan(inputs).any():
                    print("⚠ WARNING: Found NaN in inputs")
                if torch.isinf(inputs).any():
                    print("⚠ WARNING: Found Inf in inputs")
                    
                # Check data range
                print(f"Input range: [{inputs.min():.3f}, {inputs.max():.3f}]")
                
                if targets.dtype in [torch.long, torch.int]:
                    print(f"Target classes: {torch.unique(targets).tolist()}")
                else:
                    print(f"Target range: [{targets.min():.3f}, {targets.max():.3f}]")
                    
            else:
                print(f"Batch structure: {type(batch)}")
                
        except Exception as e:
            print(f"❌ Error loading batch: {e}")
            continue
    
    print("\n✓ DataModule validation completed")

# Validate all our datamodules
print("=== DataModule Validation ===")
datamodules = [
    ("Vision", vision_dm),
    ("NLP", nlp_dm), 
    ("Tabular", tabular_dm),
    ("TimeSeries", ts_dm)
]

for name, dm in datamodules:
    print(f"\n{name} DataModule:")
    try:
        validate_datamodule(dm)
        print("✅ PASSED")
    except Exception as e:
        print(f"❌ FAILED: {e}")
```
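Batch-level checks do not reveal whether the same sample ended up in more than one split. When a DataModule splits by index, a set intersection is enough to detect that. A minimal sketch, assuming the split indices are available as sequences of integers; `find_split_overlaps` is a hypothetical helper.

```python
from typing import Dict, Sequence


def find_split_overlaps(splits: Dict[str, Sequence[int]]) -> Dict[str, int]:
    """Return the number of shared indices for every pair of splits that overlaps."""
    names = list(splits)
    overlaps = {}
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            shared = set(splits[first]) & set(splits[second])
            if shared:
                overlaps[f"{first}/{second}"] = len(shared)
    return overlaps


clean = find_split_overlaps(
    {"train": range(0, 8), "val": range(8, 9), "test": range(9, 10)}
)
leaky = find_split_overlaps({"train": [0, 1, 2, 3], "val": [3, 4], "test": [5]})
print("clean:", clean)
print("leaky:", leaky)
```

An empty result means the splits are disjoint. Running the check after calling `setup()` more than once also catches splits that change between calls.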

### Memory and Performance Optimization

```python
def benchmark_datamodule(dm: pl.LightningDataModule, num_batches: int = 10):
    """Benchmark DataModule loading performance"""
    import time
    
    print(f"=== Benchmarking {type(dm).__name__} ===")
    
    dm.setup("fit")
    train_loader = dm.train_dataloader()
    
    # Warmup
    for i, batch in enumerate(train_loader):
        if i >= 2:  # Just warmup
            break
    
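    # Benchmark: time each iteration from the end of the previous one, so the
    # measurement includes the time spent waiting for the DataLoader.
    # The first timed batch also pays for creating a fresh iterator.
    batch_times = []
    start_time = time.perf_counter()
    batch_start = start_time
    
    for i, batch in enumerate(train_loader):
        # Simulate some processing
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            inputs, targets = batch
            _ = inputs.float().mean()  # .float() because mean() rejects integer tensors (e.g. token ids)
        
        batch_end = time.perf_counter()
        batch_times.append(batch_end - batch_start)
        batch_start = batch_end
        
        if i >= num_batches - 1:
            break
    
    end_time = time.perf_counter()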
    
    total_time = end_time - start_time
    avg_batch_time = np.mean(batch_times)
    
    print(f"Total time: {total_time:.3f}s")
    print(f"Average batch time: {avg_batch_time*1000:.1f}ms")
    print(f"Batches per second: {1/avg_batch_time:.1f}")
    print(f"Samples per second: {len(batch[0])/avg_batch_time:.0f}")
    
    return {
        'total_time': total_time,
        'avg_batch_time': avg_batch_time,
        'batches_per_sec': 1/avg_batch_time
    }

# Benchmark our datamodules
print("=== Performance Benchmarking ===")
for name, dm in datamodules:
    print(f"\n{name} DataModule:")
    try:
        results = benchmark_datamodule(dm, num_batches=5)
    except Exception as e:
        print(f"❌ Benchmark failed: {e}")
```
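When a loader looks slow, it is useful to know whether the time goes into fetching batches or into processing them. The sketch below separates the two for any iterable of batches; `time_fetch_and_compute` and `slow_batches` are illustrative names, and the generator with `time.sleep` only stands in for a real `DataLoader`.

```python
import time
from typing import Any, Callable, Dict, Iterable, Iterator, List


def time_fetch_and_compute(
    batches: Iterable[Any], step: Callable[[Any], Any], max_batches: int = 10
) -> Dict[str, float]:
    """Accumulate time spent waiting for batches and time spent in step(batch)."""
    fetch_s, compute_s, count = 0.0, 0.0, 0
    iterator = iter(batches)
    while count < max_batches:
        t0 = time.perf_counter()
        try:
            batch = next(iterator)  # waiting for data
        except StopIteration:
            break
        t1 = time.perf_counter()
        step(batch)  # processing the batch
        t2 = time.perf_counter()
        fetch_s += t1 - t0
        compute_s += t2 - t1
        count += 1
    return {"batches": count, "fetch_s": fetch_s, "compute_s": compute_s}


def slow_batches(num_batches: int, delay_s: float) -> Iterator[List[int]]:
    """Stand-in for a loader whose batches each take delay_s to produce."""
    for i in range(num_batches):
        time.sleep(delay_s)
        yield [i] * 4


stats = time_fetch_and_compute(slow_batches(5, 0.01), step=sum, max_batches=3)
print(stats)
```

If `fetch_s` dominates, raising `num_workers` or simplifying per-sample transforms is usually the first thing to try; if `compute_s` dominates, the loader is not the bottleneck.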

# Advanced DataModule Patterns

## Multi-Modal DataModule

```python
class MultiModalDataset(Dataset):
    """Dataset that combines vision and text data"""
    
    def __init__(self, images, texts, labels, image_transform=None, text_vocab=None, max_text_length=50):
        self.images = images
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_vocab = text_vocab or {}
        self.max_text_length = max_text_length
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # Process image
        image = self.images[idx]
        if self.image_transform:
            image = self.image_transform(image)
        
        # Process text
        text = self.texts[idx]
        tokens = text.lower().split()
        text_indices = [self.text_vocab.get(token, self.text_vocab.get('<UNK>', 0)) for token in tokens]
        
        # Pad or truncate text
        if len(text_indices) > self.max_text_length:
            text_indices = text_indices[:self.max_text_length]
        else:
            text_indices.extend([self.text_vocab.get('<PAD>', 0)] * (self.max_text_length - len(text_indices)))
        
        text_tensor = torch.tensor(text_indices, dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        
        return {
            'image': image,
            'text': text_tensor,
            'label': label
        }

class MultiModalDataModule(pl.LightningDataModule):
    """DataModule for multi-modal vision + text data"""
    
    def __init__(
        self,
        num_samples: int = 1000,
        image_size: Tuple[int, int] = (64, 64),
        vocab_size: int = 500,
        max_text_length: int = 50,
        num_classes: int = 5,
        batch_size: int = 16,
        num_workers: int = 0
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # Image transforms
        self.image_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    def prepare_data(self):
        """Generate synthetic multi-modal data"""
        torch.manual_seed(42)
        np.random.seed(42)
        
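        # Generate synthetic images as uint8 in [0, 255]. ToPILImage rescales float
        # tensors by 255, so float pixel values would have to stay in [0, 1].
        self.images = torch.randint(
            0, 256, (self.hparams.num_samples, 3, *self.hparams.image_size), dtype=torch.uint8
        )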
        
        # Generate synthetic texts
        base_words = ['cat', 'dog', 'bird', 'fish', 'car', 'bike', 'tree', 'house', 'red', 'blue', 'big', 'small']
        self.texts = []
        for _ in range(self.hparams.num_samples):
            num_words = np.random.randint(5, 15)
            text = ' '.join(np.random.choice(base_words, num_words))
            self.texts.append(text)
        
        # Generate labels
        self.labels = np.random.randint(0, self.hparams.num_classes, self.hparams.num_samples)
        
        # Build text vocabulary
        all_words = set()
        for text in self.texts:
            all_words.update(text.split())
        
        self.vocab = {'<PAD>': 0, '<UNK>': 1}
        for i, word in enumerate(sorted(all_words)[:self.hparams.vocab_size-2]):
            self.vocab[word] = i + 2
        
        print(f"✓ Generated {len(self.images)} multi-modal samples")
        print(f"✓ Vocabulary size: {len(self.vocab)}")
        
    def setup(self, stage: Optional[str] = None):
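        # NOTE: Lightning advises against assigning state in prepare_data(), because in
        # distributed training it does not run on every process. This fallback call
        # makes sure each process has the synthetic data before splitting.
        if not hasattr(self, 'images'):
            self.prepare_data()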
        
        # Split data
        train_size = int(0.8 * self.hparams.num_samples)
        val_size = int(0.1 * self.hparams.num_samples)
        
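        # Use a fixed-seed permutation so every call to setup() (one per stage) reproduces
        # the same split; an unseeded shuffle can leak training samples into the test set
        indices = np.random.default_rng(42).permutation(self.hparams.num_samples).tolist()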
        
        train_indices = indices[:train_size]
        val_indices = indices[train_size:train_size + val_size]
        test_indices = indices[train_size + val_size:]
        
        if stage == "fit" or stage is None:
            train_images = [self.images[i] for i in train_indices]
            train_texts = [self.texts[i] for i in train_indices]
            train_labels = [self.labels[i] for i in train_indices]
            
            self.train_dataset = MultiModalDataset(
                train_images, train_texts, train_labels, 
                self.image_transform, self.vocab, self.hparams.max_text_length
            )
            
            val_images = [self.images[i] for i in val_indices]
            val_texts = [self.texts[i] for i in val_indices]
            val_labels = [self.labels[i] for i in val_indices]
            
            self.val_dataset = MultiModalDataset(
                val_images, val_texts, val_labels,
                self.image_transform, self.vocab, self.hparams.max_text_length
            )
        
        if stage == "test" or stage is None:
            test_images = [self.images[i] for i in test_indices]
            test_texts = [self.texts[i] for i in test_indices]
            test_labels = [self.labels[i] for i in test_indices]
            
            self.test_dataset = MultiModalDataset(
                test_images, test_texts, test_labels,
                self.image_transform, self.vocab, self.hparams.max_text_length
            )
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            collate_fn=self._collate_fn
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            collate_fn=self._collate_fn
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            collate_fn=self._collate_fn
        )
    
    def _collate_fn(self, batch):
        """Custom collate function for multi-modal data"""
        images = torch.stack([item['image'] for item in batch])
        texts = torch.stack([item['text'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        
        return {
            'image': images,
            'text': texts,
            'label': labels
        }

# Test Multi-Modal DataModule
multimodal_dm = MultiModalDataModule(
    num_samples=500,
    image_size=(32, 32),
    vocab_size=100,
    max_text_length=20,
    num_classes=3,
    batch_size=8
)

multimodal_dm.setup("fit")

# Test a batch
train_loader = multimodal_dm.train_dataloader()
batch = next(iter(train_loader))

print("✓ Multi-Modal DataModule created")
print(f"Image batch shape: {batch['image'].shape}")
print(f"Text batch shape: {batch['text'].shape}")
print(f"Labels batch shape: {batch['label'].shape}")
print(f"Sample text: {batch['text'][0][:10].tolist()}")
```
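
Because `setup()` runs once per stage, the index split has to come out identical on every call; otherwise samples used for training can reappear in the test set. A minimal sketch of a reproducible split, assuming a hypothetical helper `split_indices` and a fixed seed of 42 (neither is part of the notebook's classes):

```python
import numpy as np

def split_indices(num_samples, train_frac=0.8, val_frac=0.1, seed=42):
    """Return disjoint train/val/test index lists that are identical on every call."""
    # A local Generator keeps the split independent of the global NumPy RNG state
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_samples).tolist()

    train_end = int(train_frac * num_samples)
    val_end = train_end + int(val_frac * num_samples)
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]

train_idx, val_idx, test_idx = split_indices(500)
print(len(train_idx), len(val_idx), len(test_idx))
```

Calling the helper from both the `"fit"` and `"test"` branches of `setup()` would give each stage the same partition, regardless of how many random numbers were drawn in between.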

## Data Augmentation and Transforms

```python
class AugmentationDataModule(pl.LightningDataModule):
    """DataModule with advanced augmentation strategies"""
    
    def __init__(
        self,
        base_datamodule: pl.LightningDataModule,
        augment_prob: float = 0.8,
        mixup_alpha: float = 0.2,
        cutmix_prob: float = 0.5,
        use_mixup: bool = False,
        use_cutmix: bool = False
    ):
        super().__init__()
        self.base_dm = base_datamodule
        self.save_hyperparameters(ignore=['base_datamodule'])
        
    def setup(self, stage: Optional[str] = None):
        self.base_dm.setup(stage)
        
        if stage == "fit" or stage is None:
            # Wrap datasets with augmentation
            self.train_dataset = AugmentedDataset(
                self.base_dm.train_dataset,
                use_mixup=self.hparams.use_mixup,
                use_cutmix=self.hparams.use_cutmix,
                mixup_alpha=self.hparams.mixup_alpha,
                cutmix_prob=self.hparams.cutmix_prob
            )
            self.val_dataset = self.base_dm.val_dataset
        
        if stage == "test" or stage is None:
            self.test_dataset = self.base_dm.test_dataset
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.base_dm.hparams.batch_size,
            shuffle=True,
            num_workers=self.base_dm.hparams.num_workers,
            collate_fn=self._augmented_collate_fn
        )
    
    def val_dataloader(self):
        return self.base_dm.val_dataloader()
    
    def test_dataloader(self):
        return self.base_dm.test_dataloader()
    
    def _augmented_collate_fn(self, batch):
        """Collate function with batch-level augmentations"""
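        # Standard collation
        if isinstance(batch[0], (list, tuple)):
            inputs = torch.stack([item[0] for item in batch])
            # as_tensor also accepts plain int labels, which torch.stack would reject
            targets = torch.stack([torch.as_tensor(item[1]) for item in batch])
        else:
            return torch.utils.data.dataloader.default_collate(batch)
        
        # Apply batch augmentations. This collate function is only attached to the
        # training DataLoader, so no train/eval check is needed (a LightningDataModule
        # has no `training` attribute).
        # NOTE: augmented batches return targets as a (y_a, y_b, lam) tuple, not a tensor.
        if self.hparams.use_mixup and np.random.rand() < 0.5:
            inputs, targets = self._mixup(inputs, targets)
        elif self.hparams.use_cutmix and np.random.rand() < self.hparams.cutmix_prob:
            inputs, targets = self._cutmix(inputs, targets)
            
        return inputs, targets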
    
    def _mixup(self, x, y):
        """Apply MixUp augmentation"""
        lam = np.random.beta(self.hparams.mixup_alpha, self.hparams.mixup_alpha)
        batch_size = x.size(0)
        index = torch.randperm(batch_size)
        
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        
        return mixed_x, (y_a, y_b, lam)
    
    def _cutmix(self, x, y):
        """Apply CutMix augmentation"""
        lam = np.random.beta(1.0, 1.0)
        rand_index = torch.randperm(x.size(0))
        
        # Get random box (x is laid out as N x C x H x W)
        H, W = x.size(2), x.size(3)
        cut_rat = np.sqrt(1. - lam)
        cut_h = int(H * cut_rat)
        cut_w = int(W * cut_rat)
        
        cy = np.random.randint(H)
        cx = np.random.randint(W)
        
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        
        # Paste the box from the permuted batch, then recompute lam from the clipped box area
        x[:, :, bby1:bby2, bbx1:bbx2] = x[rand_index, :, bby1:bby2, bbx1:bbx2]
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
        
        return x, (y, y[rand_index], lam)

class AugmentedDataset(Dataset):
    """Pass-through wrapper; MixUp/CutMix are applied per batch in the collate function.
    
    Per-sample augmentations could be added in __getitem__.
    """
    
    def __init__(self, base_dataset, use_mixup=False, use_cutmix=False, 
                 mixup_alpha=0.2, cutmix_prob=0.5):
        self.base_dataset = base_dataset
        self.use_mixup = use_mixup
        self.use_cutmix = use_cutmix
        self.mixup_alpha = mixup_alpha
        self.cutmix_prob = cutmix_prob
        
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        return self.base_dataset[idx]

print("✓ Advanced augmentation DataModule patterns defined")
```
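
When MixUp or CutMix fires, the collate function above returns targets as a `(y_a, y_b, lam)` tuple instead of a single tensor, so the training step has to handle both shapes. A minimal sketch of how the loss could be computed, assuming a classification model trained with cross-entropy; `mixed_loss` is a hypothetical helper, not part of Lightning:

```python
import torch
import torch.nn.functional as F

def mixed_loss(logits, targets):
    """Cross-entropy that accepts plain targets or a (y_a, y_b, lam) tuple."""
    if isinstance(targets, tuple):
        y_a, y_b, lam = targets
        lam = float(lam)  # lam may arrive as a NumPy scalar
        # Weight the two label sets by the same coefficient used to mix the inputs
        return lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
logits = torch.randn(8, 3)        # stand-in for model outputs
y = torch.randint(0, 3, (8,))
perm = torch.randperm(8)

plain = mixed_loss(logits, y)
mixed = mixed_loss(logits, (y, y[perm], 0.7))
print(plain.item(), mixed.item())
```

With `lam = 1.0` the mixed loss reduces to the ordinary cross-entropy, which is a quick sanity check when wiring this into a `training_step`.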

## Error Handling and Recovery

```python
class RobustDataModule(pl.LightningDataModule):
    """DataModule with comprehensive error handling"""
    
    def __init__(self, base_datamodule: pl.LightningDataModule, max_retries: int = 3):
        super().__init__()
        self.base_dm = base_datamodule
        self.max_retries = max_retries
        self.failed_samples = set()
        
    def setup(self, stage: Optional[str] = None):
        try:
            self.base_dm.setup(stage)
            
            # Wrap datasets with error handling
            if hasattr(self.base_dm, 'train_dataset'):
                self.train_dataset = RobustDataset(self.base_dm.train_dataset, self.max_retries)
            if hasattr(self.base_dm, 'val_dataset'):
                self.val_dataset = RobustDataset(self.base_dm.val_dataset, self.max_retries)
            if hasattr(self.base_dm, 'test_dataset'):
                self.test_dataset = RobustDataset(self.base_dm.test_dataset, self.max_retries)
                
        except Exception as e:
            print(f"❌ Setup failed: {e}")
            # Implement fallback strategy
            self._create_fallback_datasets()
    
    def _create_fallback_datasets(self):
        """Create minimal fallback datasets when main setup fails"""
        print("⚠ Creating fallback datasets...")
        fallback_data = torch.randn(100, 3, 32, 32)
        fallback_targets = torch.randint(0, 10, (100,))
        
        fallback_dataset = torch.utils.data.TensorDataset(fallback_data, fallback_targets)
        
        self.train_dataset = fallback_dataset
        self.val_dataset = fallback_dataset
        self.test_dataset = fallback_dataset
        
    def train_dataloader(self):
        return self._create_robust_dataloader(self.train_dataset, shuffle=True)
    
    def val_dataloader(self):
        return self._create_robust_dataloader(self.val_dataset, shuffle=False)
    
    def test_dataloader(self):
        return self._create_robust_dataloader(self.test_dataset, shuffle=False)
    
    def _create_robust_dataloader(self, dataset, shuffle=False):
        """Create DataLoader with error handling"""
        try:
            return DataLoader(
                dataset,
                batch_size=getattr(self.base_dm.hparams, 'batch_size', 32),
                shuffle=shuffle,
                num_workers=0,  # Safer for debugging
                collate_fn=self._robust_collate_fn
            )
        except Exception as e:
            print(f"❌ DataLoader creation failed: {e}")
            # Return minimal dataloader
            return DataLoader(dataset, batch_size=1, shuffle=False)
    
    def _robust_collate_fn(self, batch):
        """Collate function with error recovery"""
        try:
            return torch.utils.data.dataloader.default_collate(batch)
        except Exception as e:
            print(f"⚠ Collate error: {e}, using fallback")
            # Filter out problematic samples
            valid_batch = []
            for item in batch:
                try:
                    # Test if item can be processed
                    _ = torch.utils.data.dataloader.default_collate([item])
                    valid_batch.append(item)
                except Exception:
                    continue
            
            if valid_batch:
                return torch.utils.data.dataloader.default_collate(valid_batch)
            else:
                # Return dummy batch
                return torch.randn(1, 3, 32, 32), torch.tensor([0])

class RobustDataset(Dataset):
    """Dataset wrapper with error handling and recovery"""
    
    def __init__(self, base_dataset, max_retries=3):
        self.base_dataset = base_dataset
        self.max_retries = max_retries
        self.failed_indices = set()
        
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        original_idx = idx
        n = len(self.base_dataset)
        
        for attempt in range(self.max_retries):
            # Skip known failed samples, scanning the dataset at most once
            # so the loop terminates even if every index has failed
            skipped = 0
            while idx in self.failed_indices and skipped < n:
                idx = (idx + 1) % n
                skipped += 1
            if idx in self.failed_indices:
                break  # no readable sample left
            
            try:
                return self.base_dataset[idx]
            except Exception as e:
                print(f"⚠ Sample {idx} failed (attempt {attempt + 1}): {e}")
                self.failed_indices.add(idx)
                idx = (idx + 1) % n
        
        # Return fallback sample
        print(f"❌ All retries failed for sample {original_idx}, returning fallback")
        return self._get_fallback_sample()
    
    def _get_fallback_sample(self):
        """Generate a fallback sample when all else fails"""
        # Create a simple fallback based on expected data structure
        try:
            sample = self.base_dataset[0]
            if isinstance(sample, (list, tuple)) and len(sample) == 2:
                data, target = sample
                fallback_data = torch.zeros_like(data)
                return fallback_data, target
        except Exception:
            pass
        
        # Ultimate fallback
        return torch.zeros(3, 32, 32), torch.tensor(0)

print("✓ Robust error handling patterns defined")
```
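
The core idea behind `RobustDataset` is to substitute a neighboring sample when one index fails, while bounding the search so it always terminates. The sketch below isolates that logic with plain Python objects. `FlakyDataset` and `NeighborFallback` are hypothetical names used only for illustration, and this variant raises instead of returning a zero tensor so that failures stay visible:

```python
class FlakyDataset:
    """Toy dataset whose listed indices raise on access (stand-in for corrupt files)."""

    def __init__(self, size, bad_indices):
        self.size = size
        self.bad_indices = set(bad_indices)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if idx in self.bad_indices:
            raise IOError(f"cannot read sample {idx}")
        return idx * 10


class NeighborFallback:
    """Return the next readable sample, trying at most max_retries indices."""

    def __init__(self, base, max_retries=3):
        self.base = base
        self.max_retries = max_retries
        self.failed = set()

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        n = len(self.base)
        for _ in range(self.max_retries):
            try:
                return self.base[idx]
            except Exception:
                self.failed.add(idx)
                idx = (idx + 1) % n  # move on to the neighboring index
        raise RuntimeError(f"no readable sample within {self.max_retries} attempts")


ds = NeighborFallback(FlakyDataset(5, bad_indices=[1, 2]))
print(ds[0], ds[1], sorted(ds.failed))
```

Whether to raise or to return a placeholder is a design choice: placeholders keep long runs alive, while raising avoids silently training on fabricated samples.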

## DataModule Testing Framework

```python
class DataModuleTestSuite:
    """Comprehensive testing suite for DataModules"""
    
    def __init__(self, datamodule: pl.LightningDataModule):
        self.dm = datamodule
        self.test_results = {}
    
    def run_all_tests(self):
        """Run comprehensive test suite"""
        print(f"=== Testing {type(self.dm).__name__} ===")
        
        tests = [
            ("Setup", self._test_setup),
            ("DataLoader Creation", self._test_dataloader_creation),
            ("Batch Loading", self._test_batch_loading),
            ("Data Consistency", self._test_data_consistency),
            ("Memory Usage", self._test_memory_usage),
            ("Performance", self._test_performance),
            ("Error Handling", self._test_error_handling)
        ]
        
        for test_name, test_func in tests:
            print(f"\n--- {test_name} ---")
            try:
                result = test_func()
                self.test_results[test_name] = {"status": "PASSED", "result": result}
                print(f"✅ {test_name}: PASSED")
            except Exception as e:
                self.test_results[test_name] = {"status": "FAILED", "error": str(e)}
                print(f"❌ {test_name}: FAILED - {e}")
        
        self._print_summary()
        return self.test_results
    
    def _test_setup(self):
        """Test DataModule setup"""
        stages = ["fit", "validate", "test"]
        results = {}
        
        for stage in stages:
            try:
                self.dm.setup(stage)
                results[stage] = "OK"
            except Exception as e:
                results[stage] = f"Error: {e}"
        
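        # Surface failures so the suite does not report PASSED when a stage raised
        failures = {stage: msg for stage, msg in results.items() if msg != "OK"}
        if failures:
            raise RuntimeError(f"setup() failed for stages: {failures}")
        
        return results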
    
    def _test_dataloader_creation(self):
        """Test DataLoader creation"""
        self.dm.setup("fit")
        self.dm.setup("test")  # test_dataloader() needs the test split as well
        
        loaders = {}
        if hasattr(self.dm, 'train_dataloader'):
            loaders['train'] = len(self.dm.train_dataloader())
        if hasattr(self.dm, 'val_dataloader'):
            loaders['val'] = len(self.dm.val_dataloader())
        if hasattr(self.dm, 'test_dataloader'):
            loaders['test'] = len(self.dm.test_dataloader())
            
        return loaders
    
    def _test_batch_loading(self):
        """Test batch loading from all dataloaders"""
        self.dm.setup("fit")
        results = {}
        
        loaders = []
        if hasattr(self.dm, 'train_dataloader'):
            loaders.append(('train', self.dm.train_dataloader()))
        if hasattr(self.dm, 'val_dataloader'):
            loaders.append(('val', self.dm.val_dataloader()))
        
        for name, loader in loaders:
            batch = next(iter(loader))
            if isinstance(batch, (list, tuple)):
                results[name] = {
                    'batch_size': len(batch[0]),
                    'input_shape': batch[0].shape if hasattr(batch[0], 'shape') else str(type(batch[0])),
                    'target_shape': batch[1].shape if hasattr(batch[1], 'shape') else str(type(batch[1]))
                }
            else:
                results[name] = {'type': str(type(batch))}
        
        return results
    
    def _test_data_consistency(self):
        """Test data consistency across epochs"""
        self.dm.setup("fit")
        train_loader = self.dm.train_dataloader()
        
        # Get first batch from two different iterations
        batch1 = next(iter(train_loader))
        batch2 = next(iter(train_loader))
        
        if isinstance(batch1, (list, tuple)) and isinstance(batch2, (list, tuple)):
            return {
                'shapes_consistent': batch1[0].shape == batch2[0].shape,
                'dtypes_consistent': batch1[0].dtype == batch2[0].dtype,
                'batch_sizes_equal': len(batch1[0]) == len(batch2[0])
            }
        return {'status': 'Cannot test - unexpected batch format'}
    
    def _test_memory_usage(self):
        """Test memory usage during data loading"""
        import psutil  # third-party dependency: pip install psutil
        import os
        
        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB
        
        self.dm.setup("fit")
        train_loader = self.dm.train_dataloader()
        
        # Load several batches
        for i, batch in enumerate(train_loader):
            if i >= 5:  # Load 5 batches
                break
        
        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory
        
        return {
            'initial_memory_mb': round(initial_memory, 2),
            'final_memory_mb': round(final_memory, 2),
            'memory_increase_mb': round(memory_increase, 2)
        }
    
    def _test_performance(self):
        """Test loading performance"""
        import time
        
        self.dm.setup("fit")
        train_loader = self.dm.train_dataloader()
        
        # Warmup
        next(iter(train_loader))
        
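        # Time batch loading (perf_counter is monotonic and better suited to short intervals)
        start_time = time.perf_counter()
        batch_count = 0
        
        for batch in train_loader:
            batch_count += 1
            if batch_count >= 10:
                break
        
        end_time = time.perf_counter()
        # Clamp to avoid division by zero when the measured interval rounds to zero
        avg_batch_time = max((end_time - start_time) / batch_count, 1e-9)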
        
        return {
            'batches_tested': batch_count,
            'avg_batch_time_ms': round(avg_batch_time * 1000, 2),
            'estimated_batches_per_sec': round(1 / avg_batch_time, 2)
        }
    
    def _test_error_handling(self):
        """Test error handling capabilities"""
        results = {}
        
        # Test with invalid stage
        try:
            self.dm.setup("invalid_stage")
            results['invalid_stage'] = "No error raised"
        except Exception:
            results['invalid_stage'] = "Error properly handled"
        
        # Test multiple setups
        try:
            self.dm.setup("fit")
            self.dm.setup("fit")  # Second setup
            results['multiple_setup'] = "OK"
        except Exception as e:
            results['multiple_setup'] = f"Error: {e}"
        
        return results
    
    def _print_summary(self):
        """Print test summary"""
        print("\n=== Test Summary ===")
        passed = sum(1 for result in self.test_results.values() if result['status'] == 'PASSED')
        total = len(self.test_results)
        
        print(f"Tests Passed: {passed}/{total}")
        
        if passed == total:
            print("🎉 All tests passed!")
        else:
            print("⚠ Some tests failed. Check results above.")

# Test all our DataModules
datamodules_to_test = [
    vision_dm,
    nlp_dm,
    tabular_dm,
    ts_dm,
    multimodal_dm
]

print("=== DataModule Testing Framework ===")
for dm in datamodules_to_test:
    test_suite = DataModuleTestSuite(dm)
    results = test_suite.run_all_tests()
    print(f"\n{'='*50}")
```
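
`run_all_tests()` returns a dict keyed by test name, each entry carrying a `status` field, which makes it straightforward to fail a CI job when any check fails. A minimal sketch, assuming the result format shown above; `assert_all_passed` is a hypothetical helper and the sample dict is made up for the demonstration:

```python
def assert_all_passed(test_results):
    """Raise AssertionError listing every test whose status is not PASSED."""
    failed = {
        name: entry.get("error", "unknown error")
        for name, entry in test_results.items()
        if entry.get("status") != "PASSED"
    }
    if failed:
        details = "; ".join(f"{name}: {msg}" for name, msg in failed.items())
        raise AssertionError(f"{len(failed)} DataModule test(s) failed: {details}")


# Made-up results in the same shape that DataModuleTestSuite.run_all_tests() returns
sample_results = {
    "Setup": {"status": "PASSED", "result": {"fit": "OK"}},
    "Batch Loading": {"status": "FAILED", "error": "train_dataset is missing"},
}

try:
    assert_all_passed(sample_results)
except AssertionError as err:
    print(err)
```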

## Summary

### Key Takeaways

This notebook covered building LightningDataModules for several data types and scenarios:

**Core DataModule Concepts:**
- **Structure**: `prepare_data()` for one-time work such as downloads (run on a single process), `setup()` for building datasets and splits (run on every process), and one dataloader method per stage
- **Best Practices**: Proper data splits, normalization, and hyperparameter saving
- **Flexibility**: Support for different stages (fit, validate, test, predict)

**Data Type Implementations:**
- **Vision DataModules**: Image preprocessing, augmentation, and batch handling
- **NLP DataModules**: Text tokenization, vocabulary building, and sequence padding
- **Tabular DataModules**: Feature scaling, categorical encoding, and structured data handling
- **Time Series DataModules**: Sequential data processing and temporal splitting
- **Multi-Modal DataModules**: Combined vision and text data processing

**Advanced Patterns:**
- **Augmentation Integration**: MixUp, CutMix, and batch-level augmentations
- **Error Handling**: Robust data loading with fallback mechanisms
- **Performance Optimization**: Memory management and loading speed optimization
- **Testing Framework**: Comprehensive validation of DataModule functionality

**Production Considerations:**
- Memory usage monitoring and optimization
- Error recovery and graceful degradation
- Performance benchmarking and bottleneck identification
- Comprehensive testing and validation

**DataModule Benefits:**
- **Reproducibility**: Consistent data processing across experiments
- **Modularity**: Reusable data loading logic independent of model code
- **Distributed Training**: `prepare_data()` runs on one process and `setup()` on every process, so the same DataModule works for single- and multi-GPU runs
- **Experimentation**: Easy swapping of different data processing strategies
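
For the distributed case, each process must see a different shard of the data. Lightning's Trainer normally handles this by injecting a `DistributedSampler` into the dataloaders a DataModule returns. The sketch below shows the sharding itself with plain PyTorch. Passing `num_replicas` and `rank` explicitly is an assumption made here so the example runs in a single process without initializing a process group:

```python
import torch
from torch.utils.data import TensorDataset, DistributedSampler

dataset = TensorDataset(torch.arange(10))

# Simulate two processes; in a real DDP run the rank comes from the launcher
shards = []
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    shards.append(list(sampler))

print(shards)
```

Each rank receives a disjoint, interleaved subset of indices, which is why DataModule code generally does not need to partition data by hand.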

The DataModule pattern is essential for building maintainable, scalable ML pipelines in PyTorch Lightning, providing a clean separation between data processing and model training logic.