# Data Augmentation and Cross-Validation Analysis\n\nThis notebook analyzes data augmentation techniques and cross-validation strategies for exoplanet detection. We measure how different augmentation methods affect model performance and evaluate the models with star-level cross-validation, so that results are not inflated by data leakage.\n\n## Objectives\n1. **Augmentation Analysis**: Compare traditional vs. physics-informed augmentation\n2. **Cross-Validation**: Implement star-level stratified CV so that no star contributes light curves to both the training and validation sets (prevents data leakage)\n3. **Ablation Studies**: Evaluate individual augmentation components\n4. **Synthetic Transit Validation**: Verify the quality of physics-informed synthetic transit generation\n5. **Performance Impact**: Quantify the improvements from each augmentation strategy

In [None]:
# Setup and imports\nimport sys\nimport os\nfrom pathlib import Path\nimport warnings\nwarnings.filterwarnings('ignore')\n\n# Add src to path\nsys.path.insert(0, str(Path.cwd().parent / 'src'))\n\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nfrom tqdm import tqdm\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\nfrom sklearn.model_selection import StratifiedKFold, GroupKFold\nfrom sklearn.metrics import classification_report, confusion_matrix\nimport plotly.graph_objects as go\nimport plotly.express as px\nfrom plotly.subplots import make_subplots\n\n# Set style and seed\nplt.style.use('seaborn-v0_8')\nsns.set_palette('husl')\n\nprint(\"Environment setup complete!\")\nprint(f\"PyTorch version: {torch.__version__}\")\nprint(f\"CUDA available: {torch.cuda.is_available()}\")

In [None]:
# Import our modules\nfrom data.augmentation import (\n    create_standard_augmentation_pipeline,\n    create_conservative_augmentation_pipeline,\n    create_physics_aware_augmentation_pipeline,\n    TimeJitter, AmplitudeScaling, GaussianNoise, RandomMasking,\n    FrequencyFiltering, TimeWarping, MixUp, TransitPreservingAugmentation\n)\nfrom data.dataset import LightCurveDataset, AugmentedLightCurveDataset, collate_fn\nfrom models.cnn import ExoplanetCNN\nfrom training.trainer import ExoplanetTrainer, create_optimizer, create_scheduler\nfrom training.metrics import MetricsCalculator\nfrom preprocessing.synthetic_injection import SyntheticTransitInjector\nfrom utils.reproducibility import set_seed\n\nset_seed(42)\nprint(\"All modules imported successfully!\")

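## 1. Cross-Validation Strategy\n\nWe implement star-level stratified cross-validation to prevent data leakage. Whole stars, not individual light curves, are assigned to folds, so the model is never validated on a star it saw during training. Folds are stratified on one label per star (the label of that star's first light curve), which keeps the class balance similar across folds.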

In [None]:
class StarLevelCrossValidator:
    """
    Star-level stratified K-fold cross-validator.

    All samples that share a star (group) id are placed in the same fold, so
    no star ever appears in both the training and validation indices of a
    split. Folds are stratified on one label per star: the label of that
    star's first sample in input order.

    Exposes the same ``split`` / ``get_n_splits`` interface as scikit-learn
    splitters.
    """

    def __init__(self, n_splits=5, random_state=42):
        # n_splits: number of folds.
        # random_state: seed for the shuffled StratifiedKFold over stars, so
        # the fold assignment is reproducible.
        self.n_splits = n_splits
        self.random_state = random_state

    def split(self, X, y, groups):
        """
        Yield ``(train_idx, val_idx)`` sample-index arrays with no star overlap.

        Parameters
        ----------
        X : ignored
            Accepted for scikit-learn API compatibility; not used.
        y : array-like, 1-D
            Per-sample labels. Lists and pandas Series are accepted; they are
            converted with ``np.asarray`` so indexing is always positional.
        groups : array-like, 1-D
            Star identifier of every sample (same length as ``y``).

        Yields
        ------
        train_idx, val_idx : np.ndarray of int
            Ascending positional indices of the training / validation samples.

        Raises
        ------
        ValueError
            If ``y`` and ``groups`` differ in length. Errors raised by
            ``StratifiedKFold`` (e.g. too few stars per class) propagate.
        """
        y = np.asarray(y)
        groups = np.asarray(groups)
        if y.shape[0] != groups.shape[0]:
            raise ValueError(
                f"y and groups must have the same length, "
                f"got {y.shape[0]} and {groups.shape[0]}"
            )

        # One vectorised pass gives: the sorted unique stars, the index of each
        # star's first sample, and, per sample, the position of its star in
        # unique_groups. This replaces a per-star boolean mask over all
        # samples (O(n_stars * n_samples)).
        unique_groups, first_idx, group_of_sample = np.unique(
            groups, return_index=True, return_inverse=True
        )
        # Stratification label per star = label of its first sample.
        group_labels = y[first_idx]

        # Stratified K-fold over stars, not samples.
        skf = StratifiedKFold(
            n_splits=self.n_splits, shuffle=True, random_state=self.random_state
        )

        for train_group_idx, val_group_idx in skf.split(unique_groups, group_labels):
            # Expand star positions back to sample indices.
            train_idx = np.flatnonzero(np.isin(group_of_sample, train_group_idx))
            val_idx = np.flatnonzero(np.isin(group_of_sample, val_group_idx))
            yield train_idx, val_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of folds; arguments are ignored (sklearn API)."""
        return self.n_splits

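## 2. Augmentation Comparison Framework\n\nWe define the strategies under comparison (no augmentation, plus the standard, conservative and physics-aware pipelines) and a helper that trains a fresh model with a given strategy and records its validation metrics.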

In [None]:
def create_augmentation_strategies():
    """
    Build the augmentation strategies compared in this notebook.

    Returns
    -------
    dict
        Maps a strategy name to an augmentation pipeline built by the
        ``data.augmentation`` factories. ``'no_augmentation'`` maps to
        ``None``, meaning the model is trained on un-augmented light curves.
    """
    return {
        'no_augmentation': None,
        'standard': create_standard_augmentation_pipeline(),
        'conservative': create_conservative_augmentation_pipeline(),
        'physics_aware': create_physics_aware_augmentation_pipeline(),
    }


def evaluate_augmentation_strategy(
    strategy_name,
    augmentation_pipeline,
    train_data,
    train_labels,
    train_groups,
    val_data,
    val_labels,
    device,
    epochs=10,
    batch_size=32,
    val_batch_size=64,
    learning_rate=0.001,
    input_channels=2,
    sequence_length=2048,
    seed=42,
):
    """
    Train a fresh ExoplanetCNN with one augmentation strategy and collect validation results.

    Parameters
    ----------
    strategy_name : str
        Label used in log output and in the trainer's experiment name
        (``"aug_<strategy_name>"``).
    augmentation_pipeline : object or None
        Pipeline passed to ``AugmentedLightCurveDataset``. With ``None``, a plain
        ``LightCurveDataset`` is used (no augmentation).
    train_data, train_labels : array-like
        Training light curves and labels.
    train_groups : array-like
        Star id per training sample. Not used in this function; it is kept
        for interface compatibility. Star-level separation must already be
        enforced by the caller's train/validation split.
    val_data, val_labels : array-like
        Validation light curves and labels; never augmented.
    device : str or torch.device
        Forwarded to ``ExoplanetTrainer``.
    epochs : int
        Training epochs; also the cosine scheduler's ``T_max`` and the
        trainer's ``patience``.
    batch_size, val_batch_size : int
        DataLoader batch sizes for training and validation.
    learning_rate : float
        AdamW learning rate.
    input_channels, sequence_length : int
        ``ExoplanetCNN`` constructor arguments.
    seed : int or None
        Passed to ``set_seed`` before any data loader, model or augmentation is
        built, so every strategy starts from the same initial weights and
        shuffling. Pass ``None`` to skip reseeding.

    Returns
    -------
    dict
        ``strategy`` (the name), ``final_metrics`` (last entry of
        ``history['val_metrics']``, or ``{}`` if that is empty), ``history``
        (as returned by ``trainer.train``) and ``best_f1``
        (``trainer.best_val_f1``).
    """
    print(f"\nEvaluating {strategy_name} augmentation...")

    # Reseed per strategy. Without this, model initialisation, batch order and
    # augmentation randomness depend on how many strategies ran before. Metric
    # differences would then mix augmentation effects with run-to-run noise.
    if seed is not None:
        # Same reproducibility helper the setup cell uses; imported locally so
        # this cell does not rely on that cell's namespace.
        from utils.reproducibility import set_seed
        set_seed(seed)

    # Only the training set is augmented; validation always sees raw data.
    if augmentation_pipeline is None:
        train_dataset = LightCurveDataset(train_data, train_labels)
    else:
        train_dataset = AugmentedLightCurveDataset(
            train_data, train_labels,
            augmentation_pipeline=augmentation_pipeline,
        )
    val_dataset = LightCurveDataset(val_data, val_labels)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=0,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=0,
    )

    model = ExoplanetCNN(input_channels=input_channels, sequence_length=sequence_length)

    # NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this assumes
    # ExoplanetCNN ends in a sigmoid. Confirm this; if the model returns logits,
    # use nn.BCEWithLogitsLoss instead.
    criterion = nn.BCELoss()
    optimizer = create_optimizer(model, 'adamw', learning_rate=learning_rate)
    scheduler = create_scheduler(optimizer, 'cosine', T_max=epochs)

    trainer = ExoplanetTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        experiment_name=f"aug_{strategy_name}",
    )

    # patience == epochs presumably keeps early stopping from triggering, so
    # each strategy trains for the full budget. Confirm against ExoplanetTrainer.
    history = trainer.train(epochs=epochs, patience=epochs, verbose=False)

    val_history = history['val_metrics']
    final_metrics = val_history[-1] if val_history else {}

    return {
        'strategy': strategy_name,
        'final_metrics': final_metrics,
        'history': history,
        'best_f1': trainer.best_val_f1,
    }