# Light Curve Data Exploration and Preprocessing

This notebook demonstrates how to:
1. Load and explore Kepler/TESS light curve data
2. Apply preprocessing steps and cut light curves into fixed-duration segments
3. Visualize light curves, preprocessing results, and data augmentations
4. Prepare train/validation/test datasets for model training

In [None]:
import sys
# Make the project's ../src modules importable. The path is relative to the
# kernel's working directory, so this assumes the notebook is launched from
# its own folder — adjust if running from elsewhere.
sys.path.append('../src')

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Project-local modules (resolved through the ../src entry added above).
from preprocessing import LightCurvePreprocessor, DataAugmenter
from utils import (
    plot_light_curve, 
    plot_multiple_light_curves,
    get_kepler_file_info,
    create_data_splits
)
# NOTE(review): plot_multiple_light_curves and create_data_splits are not
# used in the cells shown here — remove them if nothing else needs them.

%matplotlib inline
# NOTE(review): 'seaborn-v0_8-*' style names only exist in recent matplotlib
# releases — confirm the installed version provides this style.
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Load and Inspect Raw Data

In [None]:
# Single example light-curve file used by the exploration cells below.
fits_file = '../data/raw/kplr001234567-2009131105131_llc.fits'

# Show the file's header metadata; a missing file only prints a hint.
try:
    info = get_kepler_file_info(fits_file)
except FileNotFoundError:
    print("Example file not found. Please add your FITS files to data/raw/")
else:
    print("File Metadata:")
    for meta_key, meta_value in info.items():
        print(f"  {meta_key}: {meta_value}")

In [None]:
# Load the raw light curve with a default-configured preprocessor so the
# unprocessed data can be inspected before any cleaning is applied.
preprocessor = LightCurvePreprocessor()

# Only the load itself is guarded: a missing file prints a hint, while
# unexpected errors in the reporting/plotting below still surface.
try:
    df = preprocessor.load_fits(fits_file)
except FileNotFoundError:
    print("Please add FITS files to continue")
else:
    print(f"Loaded {len(df)} data points")
    # The time zero-point is mission-specific: Kepler 'kplr*' files use
    # BKJD (BJD - 2454833) and TESS files use BTJD (BJD - 2457000). A fixed
    # 'BTJD' label is therefore wrong for the Kepler example, so the range is
    # reported in plain days.
    # NOTE(review): confirm whether load_fits rebases time to a common system.
    print(f"Time range: {df['time'].min():.2f} to {df['time'].max():.2f} days")
    print(f"Flux range: {df['flux'].min():.2f} to {df['flux'].max():.2f}")

    # Plot the raw light curve
    plot_light_curve(
        df['time'].values,
        df['flux'].values,
        title="Raw Light Curve",
    )

## 2. Apply Preprocessing Steps

In [None]:
# Rebuild the preprocessor with explicit settings for the training pipeline.
# NOTE(review): 90-day segments at a 30-minute cadence would hold
# 90 * 24 * 60 / 30 = 4320 points, which appears to be where the padding
# length used in the split section comes from — keep them in sync.
preprocessor = LightCurvePreprocessor(
    sigma_threshold=3.0,
    rolling_window=50,
    savgol_window=101,
    savgol_poly=3,
    max_gap_days=2.0,
    segment_duration_days=90.0,
    cadence_minutes=30.0,
)

# Segment the example light curve (overlap=0.0 between consecutive segments).
# On a missing file, fall back to an empty list so later cells skip plotting.
try:
    segments, timestamps = preprocessor.preprocess(fits_file, segment=True, overlap=0.0)
except FileNotFoundError:
    print("Please add FITS files to continue")
    segments = []
else:
    print(f"Created {len(segments)} segments")
    print(f"Segment lengths: {[len(seg) for seg in segments[:5]]}...")

## 3. Visualize Preprocessing Results

In [None]:
# Plot up to six preprocessed segments, one per row.
if len(segments) > 0:
    n_display = min(6, len(segments))

    # squeeze=False always returns a 2-D array of Axes, so a single row
    # needs no special-case wrapping.
    fig, axes = plt.subplots(n_display, 1, figsize=(14, 3*n_display), squeeze=False)

    for row, ax in enumerate(axes[:, 0]):
        seg = segments[row]
        ax.plot(seg, linewidth=0.5, color='black')
        ax.set_title(f'Segment {row+1} ({len(seg)} points)')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Normalized Flux')
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

## 4. Data Augmentation Examples

In [None]:
if len(segments) > 0:
    augmenter = DataAugmenter()
    original = segments[0]

    # One augmented copy per transform, each derived from the same original
    # segment; the calls stay in a fixed order.
    aug_noise = augmenter.add_noise(original)
    aug_shift = augmenter.time_shift(original)
    aug_scale = augmenter.amplitude_scale(original)
    aug_warp = augmenter.time_warp(original)
    # NOTE(review): no seed is set here; if DataAugmenter draws random numbers
    # these panels will differ on every run — seed it if reproducibility matters.

    # (series, line colour, panel title) for each row of the comparison figure.
    panels = [
        (original, 'black', 'Original'),
        (aug_noise, 'blue', 'With Added Noise'),
        (aug_shift, 'green', 'Time Shifted'),
        (aug_scale, 'red', 'Amplitude Scaled'),
        (aug_warp, 'purple', 'Time Warped'),
    ]

    fig, axes = plt.subplots(5, 1, figsize=(14, 12))
    for ax, (series, color, title) in zip(axes, panels):
        ax.plot(series, linewidth=0.5, color=color)
        ax.set_title(title)
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Flux')
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

## 5. Batch Processing and Data Preparation

In [None]:
# Locate every raw light-curve file and make sure the output folder exists.
raw_data_dir = Path('../data/raw')
processed_dir = Path('../data/processed')
processed_dir.mkdir(parents=True, exist_ok=True)

# Get all FITS and CSV files. Path.glob yields files in an arbitrary,
# filesystem-dependent order, so sort them: otherwise the "first N files"
# processed below (and hence the resulting dataset) can differ between
# machines or runs.
fits_files = sorted(raw_data_dir.glob('*.fits'))
csv_files = sorted(raw_data_dir.glob('*.csv'))
all_files = fits_files + csv_files

print(f"Found {len(all_files)} files to process")
print(f"  FITS: {len(fits_files)}")
print(f"  CSV: {len(csv_files)}")

In [None]:
# Process the raw files into segments (example with mock labels).
MAX_FILES = 10  # limit for a quick demo run; set to None to process every file

all_segments = []
all_labels = []  # You need to provide actual labels
all_filenames = []
failed_files = []

for file_path in all_files[:MAX_FILES]:
    print(f"Processing {file_path.name}...")
    try:
        # A distinct name avoids overwriting the Section 2 `segments` that the
        # visualization/augmentation cells plot. Indexing [0] (instead of
        # unpacking into `_`) also leaves IPython's output history intact.
        file_segments = preprocessor.preprocess(str(file_path), segment=True)[0]
    except Exception as e:
        # Best-effort batch run: report and remember the failure, then move on.
        print(f"  Error ({type(e).__name__}): {e}")
        failed_files.append(file_path.name)
        continue

    # TODO: Add actual label logic here
    # Placeholder heuristic: every segment of a file is labelled transit (1)
    # if the filename contains 'transit', otherwise non-transit (0).
    label = 1 if 'transit' in file_path.name.lower() else 0

    # Keep the three lists aligned: one entry per segment.
    all_segments.extend(file_segments)
    all_labels.extend([label] * len(file_segments))
    all_filenames.extend([file_path.name] * len(file_segments))

    print(f"  Created {len(file_segments)} segments")

print(f"\nTotal segments: {len(all_segments)}")
if failed_files:
    print(f"Failed files ({len(failed_files)}): {', '.join(failed_files)}")

## 6. Create Train/Val/Test Splits

In [None]:
if len(all_segments) > 0:
    # Fixed model input length: 4320 = 90 days * 24 h * 60 min / 30-min cadence,
    # presumably one full segment at the Section 2 preprocessor settings.
    # Update both together if either changes.
    max_len = 4320

    def pad_segment(seg, length, pad_value=0.0):
        """Truncate or right-pad a 1-D segment to exactly ``length`` points.

        Args:
            seg: 1-D array-like of flux values.
            length: Target number of points.
            pad_value: Constant appended when ``seg`` is shorter than ``length``.

        Returns:
            A float ndarray of shape ``(length,)``.
        """
        # NOTE(review): padding with 0.0 assumes the preprocessed flux is
        # centred on zero. If it is normalised around 1 instead, zero padding
        # looks like a deep flux drop (a fake transit) — confirm.
        seg = np.asarray(seg, dtype=float)
        if len(seg) >= length:
            return seg[:length]
        return np.pad(seg, (0, length - len(seg)), mode='constant', constant_values=pad_value)

    flux_array = np.array([pad_segment(seg, max_len) for seg in all_segments])
    labels_array = np.array(all_labels)
    filenames_array = np.array(all_filenames)

    # Split by source file, not by segment. Segments cut from the same light
    # curve share the same star and instrument systematics; letting them land
    # in both train and test leaks information and inflates test scores.
    rng = np.random.default_rng(42)
    unique_files = np.array(sorted(set(all_filenames)))
    rng.shuffle(unique_files)

    n_files = len(unique_files)
    n_train_files = max(1, int(0.7 * n_files))  # never leave the training set empty
    n_val_files = int(0.15 * n_files)

    train_files = unique_files[:n_train_files]
    val_files = unique_files[n_train_files:n_train_files + n_val_files]
    test_files = unique_files[n_train_files + n_val_files:]

    # Shuffle segment order within each split so saved arrays are not grouped by file.
    train_idx = rng.permutation(np.flatnonzero(np.isin(filenames_array, train_files)))
    val_idx = rng.permutation(np.flatnonzero(np.isin(filenames_array, val_files)))
    test_idx = rng.permutation(np.flatnonzero(np.isin(filenames_array, test_files)))

    # Every segment must land in exactly one split.
    assert len(train_idx) + len(val_idx) + len(test_idx) == len(flux_array)

    # Save processed data: train_data.npz / val_data.npz / test_data.npz
    for split_name, split_idx in (('train', train_idx), ('val', val_idx), ('test', test_idx)):
        np.savez(
            processed_dir / f'{split_name}_data.npz',
            flux=flux_array[split_idx],
            labels=labels_array[split_idx],
        )

    print(f"Saved processed data to {processed_dir}")
    print(f"  Train: {len(train_idx)} samples from {len(train_files)} files")
    print(f"  Val: {len(val_idx)} samples from {len(val_files)} files")
    print(f"  Test: {len(test_idx)} samples from {len(test_files)} files")
    if len(val_idx) == 0 or len(test_idx) == 0:
        print("  Warning: too few source files for non-empty val/test splits; add more files.")

## 7. Data Statistics

In [None]:
if len(all_segments) > 0:
    # Lengths of the preprocessed segments before padding/truncation.
    segment_lengths = [len(seg) for seg in all_segments]

    print("Segment Length Statistics:")
    print(f"  Mean: {np.mean(segment_lengths):.1f}")
    print(f"  Median: {np.median(segment_lengths):.1f}")
    print(f"  Min: {np.min(segment_lengths)}")
    print(f"  Max: {np.max(segment_lengths)}")

    # Histogram against the fixed model input length (max_len is set in the
    # train/val/test split cell).
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(segment_lengths, bins=50, edgecolor='black')
    ax.set_xlabel('Segment Length')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of Segment Lengths')
    ax.axvline(max_len, color='r', linestyle='--', label=f'Target Length ({max_len})')
    ax.legend()
    ax.grid(alpha=0.3)
    plt.show()

    # Label distribution, computed directly from the collected labels.
    label_values = np.asarray(all_labels)
    unique, counts = np.unique(label_values, return_counts=True)
    print("\nLabel Distribution:")
    for label, count in zip(unique, counts):
        print(f"  Class {label}: {count} ({100*count/len(label_values):.1f}%)")

    # Name only the classes actually present. A fixed two-name list broadcasts
    # a single count onto both bars when one class is missing, which the
    # filename-based placeholder labels often produce.
    class_names = {0: 'Non-Transit', 1: 'Transit'}
    bar_labels = [class_names.get(int(label), f'Class {label}') for label in unique]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(bar_labels, counts)
    ax.set_ylabel('Count')
    ax.set_title('Class Distribution')
    ax.grid(alpha=0.3, axis='y')
    plt.show()

## Next Steps

1. Add your FITS/CSV files to `data/raw/`
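2. Replace the placeholder label assignment in Section 5 with labels from your own dataset or catalog. It currently marks a file as a transit only if its filename contains "transit".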
3. Run this notebook to create processed data
4. Run `python src/main.py --config configs/config.yaml` to train the model
5. Use `src/evaluate.py` to assess model performance

## Training Command

After preprocessing, train your model:

```bash
cd /path/to/project
python src/main.py --config configs/config.yaml
```