# Transient Classification from Light Curves

This notebook demonstrates the classification of astronomical transients and variable stars using deep learning on time-series photometry (light curves).

## The Science of Transient Astronomy

**Transient astronomy** studies objects that change brightness over time. These include:

### Explosive Transients (Single Events)

| Type | Physical Origin | Timescale | Scientific Importance |
|------|----------------|-----------|----------------------|
| **Type Ia Supernova** | White dwarf thermonuclear explosion | ~15-20 day rise, months-long decline | Cosmological distance ladder, dark energy discovery |
| **Type II Supernova** | Massive star core collapse | ~100 days plateau | Nucleosynthesis, neutron star formation |
| **Kilonova** | Neutron star merger | Days to ~1 week | Heavy element production (gold, platinum), gravitational-wave counterparts |
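| **TDE** (tidal disruption event) | Star tidally disrupted by a supermassive black hole | Weeks to months | Black hole demographics, accretion physics |
| **SLSN** (superluminous supernova) | Uncertain (magnetar spin-down or circumstellar interaction proposed) | Months | Extreme stellar explosions, early universe |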

### Variable Sources (Periodic or Stochastic)

| Type | Physical Origin | Period | Scientific Importance |
|------|----------------|--------|----------------------|
| **RR Lyrae** | Radial pulsation | 0.2-1 day | Distance indicators, old stellar populations |
| **Cepheids** | Radial pulsation | 1-100 days | Primary distance ladder rung |
| **Eclipsing Binaries** | Stellar eclipses | Hours to days | Stellar masses and radii |
| **AGN** | Black hole accretion | Stochastic | Supermassive black holes, galaxy evolution |
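
For the periodic classes in the table, a common first step is to fold the observation times on a trial period so that many cycles overlap. The sketch below is illustrative only: `phase_fold` is a hypothetical helper, and the 0.5-day sinusoid is a toy stand-in for an RR Lyrae signal.

```python
import numpy as np

def phase_fold(times, period, t0=0.0):
    """Map observation times onto phase in [0, 1) for a trial period."""
    return ((np.asarray(times, dtype=float) - t0) / period) % 1.0

# Toy periodic source (assumed 0.5-day period), sampled at irregular times.
rng = np.random.default_rng(0)
times = np.sort(rng.uniform(0.0, 30.0, 200))
true_period = 0.5
flux = 500.0 + 100.0 * np.sin(2 * np.pi * times / true_period)

# Folding on the correct period collapses ~60 cycles onto a single one.
phase = phase_fold(times, true_period)
order = np.argsort(phase)
folded_phase, folded_flux = phase[order], flux[order]
```

Folding on a wrong period would instead scatter the points, which is why period searches scan many trial periods.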

---

In [None]:
# Core imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

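# Add the project root (parent of the notebook directory) to the import path
# so that the `src` package can be imported
sys.path.insert(0, str(Path('.').resolve().parent))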

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder

# Our modules
from src.preprocessing import LightCurvePreprocessor
from src.models import TransientCNN1D, TransientLSTM, HybridCNNLSTM, create_classical_pipeline
from src.visualization import plot_light_curve, plot_confusion_matrix, plot_training_history

# Settings
plt.style.use('seaborn-v0_8-whitegrid')
Path('../images').mkdir(parents=True, exist_ok=True)  # Figures below are saved here
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

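# Plot colors per passband index (LSST ugrizy, ordered blue to red)
PASSBAND_COLORS = {
    0: '#8b5cf6',  # u band, drawn in violet
    1: '#3b82f6',  # g band, drawn in blue
    2: '#22c55e',  # r band, drawn in green
    3: '#f97316',  # i band, drawn in orange
    4: '#ef4444',  # z band, drawn in red
    5: '#7f1d1d',  # y band, drawn in dark red
}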
PASSBAND_NAMES = {0: 'u', 1: 'g', 2: 'r', 3: 'i', 4: 'z', 5: 'y'}

## 1. Understanding Light Curves

A **light curve** is a time series of brightness measurements. In modern surveys:

- **Multi-band photometry**: Measurements in multiple filters (colors)
- **Non-uniform sampling**: Observations depend on weather, moon phase, telescope scheduling
- **Varying uncertainties**: Brighter objects have lower relative errors
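
To make the last point concrete, here is a small sketch under an assumed noise model: photon (Poisson) variance plus a constant background term, the same `sqrt(flux + 100)` form used by the synthetic generator later in this notebook. The flux values are arbitrary.

```python
import numpy as np

# Assumed noise model: photon (Poisson) variance plus a constant background term.
flux = np.array([10.0, 100.0, 1000.0, 10000.0])
background_variance = 100.0
flux_err = np.sqrt(flux + background_variance)
relative_err = flux_err / flux

# Absolute error grows with flux, but the fractional error shrinks.
for f, err, rel in zip(flux, flux_err, relative_err):
    print(f"flux={f:8.0f}  err={err:7.2f}  relative={rel:.3f}")
```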

### Why Multi-Band Matters

Different transient types have characteristic **colors** (flux ratios between bands):

- **Kilonovae**: Redden rapidly (high opacity of freshly synthesized heavy elements)
- **Type Ia SNe**: Blue at peak, reddens over time
- **AGN**: Relatively constant color with varying brightness
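
A color can be computed from the flux ratio of two bands as a magnitude difference. The sketch below assumes both fluxes share a common zero point (so the zero-point terms cancel); `color_index` is a hypothetical helper and the flux values are made up.

```python
import numpy as np

def color_index(flux_blue, flux_red):
    """Magnitude difference m_blue - m_red; larger values mean a redder source."""
    ratio = np.asarray(flux_blue, dtype=float) / np.asarray(flux_red, dtype=float)
    return -2.5 * np.log10(ratio)

# Three toy epochs: g brighter than r, equal, then g fainter than r.
g_flux = np.array([800.0, 500.0, 200.0])
r_flux = np.array([600.0, 500.0, 400.0])
g_minus_r = color_index(g_flux, r_flux)
print(g_minus_r)  # negative = bluer, zero = neutral, positive = redder
```

A source that reddens over time, as described for Type Ia supernovae above, shows this quantity increasing from epoch to epoch.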

In [None]:
# Generate synthetic light curve data
np.random.seed(42)

def generate_light_curve(class_id, object_id):
    """Generate a synthetic light curve for a given transient class."""
    n_obs = np.random.randint(60, 150)
    mjd = np.sort(np.random.uniform(59000, 60000, n_obs))
    passbands = np.random.choice([0, 1, 2, 3, 4, 5], n_obs)
    
    # Class-specific light curve shapes
    if class_id == 0:  # SNIa
        peak_mjd = np.random.uniform(59300, 59700)
        t = mjd - peak_mjd
        peak_flux = np.random.uniform(500, 2000)
        flux = peak_flux * np.exp(-0.5 * (t/15)**2) * (t < 100)
        flux += peak_flux * 0.3 * np.exp(-t/40) * (t > 0) * (t < 100)
        
    elif class_id == 1:  # SNII
        peak_mjd = np.random.uniform(59300, 59700)
        t = mjd - peak_mjd
        peak_flux = np.random.uniform(300, 1500)
        flux = peak_flux * np.exp(-0.5 * (t/20)**2) * (t < 50)
        flux += peak_flux * 0.5 * (t >= 50) * (t < 100)  # Plateau
        flux *= np.exp(-t/100) * (t > 0)
        
    elif class_id == 2:  # Kilonova
        peak_mjd = np.random.uniform(59300, 59700)
        t = mjd - peak_mjd
        peak_flux = np.random.uniform(200, 800)
        flux = peak_flux * np.exp(-t/2) * (t > 0) * (t < 20)  # Very fast!
        
    elif class_id == 3:  # RR Lyrae
        period = np.random.uniform(0.4, 0.9)
        amplitude = np.random.uniform(100, 300)
        flux = 500 + amplitude * np.sin(2 * np.pi * mjd / period)
        
    elif class_id == 4:  # AGN
        base_flux = np.random.uniform(200, 800)
        flux = base_flux + 100 * np.cumsum(np.random.randn(n_obs)) / np.sqrt(n_obs)
        
    else:  # Eclipsing Binary
        period = np.random.uniform(0.5, 5)
        phase = (mjd % period) / period
        flux = 800 - 200 * (np.abs(phase - 0.5) < 0.1).astype(float)  # Primary eclipse at phase 0.5
        # Secondary eclipse at phase 0; phase wraps, so include values just below 1 as well
        flux -= 100 * (np.minimum(phase, 1 - phase) < 0.05).astype(float)
    
    flux = np.maximum(flux, 0)
    flux_err = np.sqrt(flux + 100) * np.random.uniform(0.8, 1.2, n_obs)
    # Injected scatter is half of flux_err, so the reported uncertainties are conservative
    flux += np.random.randn(n_obs) * flux_err * 0.5
    
    # Color dependence
    color_offset = (passbands - 2) * 50 * np.random.uniform(0.5, 1.5)
    flux += color_offset
    flux = np.maximum(flux, 1)
    
    return pd.DataFrame({
        'object_id': object_id,
        'mjd': mjd,
        'passband': passbands,
        'flux': flux,
        'flux_err': flux_err
    })

# Generate dataset
class_names = ['SNIa', 'SNII', 'Kilonova', 'RRLyrae', 'AGN', 'EclipsingBinary']
n_per_class = 100

light_curves = []
metadata = []

object_id = 0
for class_id, class_name in enumerate(class_names):
    for _ in range(n_per_class):
        lc = generate_light_curve(class_id, object_id)
        light_curves.append(lc)
        metadata.append({
            'object_id': object_id,
            'target': class_id,
            'class_name': class_name
        })
        object_id += 1

lc_df = pd.concat(light_curves, ignore_index=True)
meta_df = pd.DataFrame(metadata)

print(f"Total light curves: {len(meta_df)}")
print(f"Total observations: {len(lc_df)}")
print(f"\nClass distribution:")
print(meta_df['class_name'].value_counts())

In [None]:
# Visualize example light curves for each class
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, class_name in enumerate(class_names):
    ax = axes[i]
    class_objects = meta_df[meta_df['class_name'] == class_name]['object_id'].values[:1]
    
    for obj_id in class_objects:
        obj_lc = lc_df[lc_df['object_id'] == obj_id]
        
        for pb in sorted(obj_lc['passband'].unique()):
            pb_data = obj_lc[obj_lc['passband'] == pb].sort_values('mjd')
            ax.errorbar(
                pb_data['mjd'], pb_data['flux'], 
                yerr=pb_data['flux_err'],
                fmt='o', color=PASSBAND_COLORS[pb], 
                label=PASSBAND_NAMES[pb], alpha=0.7,
                markersize=4, capsize=2
            )
    
    ax.set_xlabel('MJD')
    ax.set_ylabel('Flux')
    ax.set_title(class_name, fontsize=12, fontweight='bold')
    ax.legend(loc='best', fontsize=8)

plt.suptitle('Example Light Curves by Transient Class', fontsize=14)
plt.tight_layout()
plt.savefig('../images/transient_examples.png', dpi=150, bbox_inches='tight')
plt.show()

## 2. Data Preprocessing

Light curves require special preprocessing:

1. **Interpolation to regular grid**: Neural networks expect fixed-size input
2. **Per-band processing**: Each filter has different sensitivity
3. **Normalization**: Put objects of very different brightness on a comparable scale
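
One possible normalization is sketched below: divide the whole object by its peak absolute flux, using a single scale factor for all bands so that colors (flux ratios between bands) are preserved. This is an illustration only and not necessarily what `LightCurvePreprocessor` does; `normalize_by_peak` is a hypothetical helper.

```python
import numpy as np

def normalize_by_peak(flux_grid, eps=1e-8):
    """Scale a (n_bands, n_time) array so its largest absolute value is ~1.

    One scale factor is shared by all bands, so band-to-band flux
    ratios are left unchanged.
    """
    flux_grid = np.asarray(flux_grid, dtype=float)
    scale = np.max(np.abs(flux_grid))
    return flux_grid / (scale + eps)

# The same light-curve shape at two brightness levels (20x apart).
bright = np.array([[100.0, 400.0, 200.0],
                   [50.0, 200.0, 100.0]])
faint = bright / 20.0

# After normalization the two objects are indistinguishable in shape.
print(np.allclose(normalize_by_peak(bright), normalize_by_peak(faint), atol=1e-6))
```

Normalizing each band independently is an alternative, but it discards the color information discussed earlier.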

### Why Interpolation?

Real observations are **irregularly sampled**:
- Weather cancellations
- Moon phase (can't observe near bright moon)
- Telescope scheduling (multiple programs share time)

The CNN and LSTM models below expect fixed-size, regularly sampled input, so we interpolate each band onto a fixed time grid.
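
The idea can be sketched with plain linear interpolation per band. This is a minimal illustration under stated assumptions, not the implementation in `src.preprocessing`: `interpolate_to_grid` is a hypothetical helper, bands with no observations are left as zeros, and times outside a band's coverage take the nearest edge value (the default behavior of `np.interp`).

```python
import numpy as np

def interpolate_to_grid(mjd, flux, passband, n_bands=6, n_time_bins=100):
    """Linearly interpolate each band onto a shared, evenly spaced time grid."""
    mjd = np.asarray(mjd, dtype=float)
    flux = np.asarray(flux, dtype=float)
    passband = np.asarray(passband)
    grid = np.linspace(mjd.min(), mjd.max(), n_time_bins)
    out = np.zeros((n_bands, n_time_bins))
    for band in range(n_bands):
        mask = passband == band
        if not mask.any():
            continue  # Band never observed: leave zeros
        order = np.argsort(mjd[mask])  # np.interp needs increasing x values
        out[band] = np.interp(grid, mjd[mask][order], flux[mask][order])
    return grid, out

# Toy object: 80 irregular epochs cycling through the six bands.
rng = np.random.default_rng(1)
mjd = np.sort(rng.uniform(59000.0, 59100.0, 80))
passband = np.arange(80) % 6
flux = 500.0 + 200.0 * np.exp(-0.5 * ((mjd - 59040.0) / 10.0) ** 2)

grid, interpolated = interpolate_to_grid(mjd, flux, passband)
print(interpolated.shape)  # (n_bands, n_time_bins)
```

Linear interpolation is the simplest choice; Gaussian-process interpolation is a common alternative when uncertainties should be propagated.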

In [None]:
# Preprocess light curves
preprocessor = LightCurvePreprocessor(n_time_bins=100, passbands=[0,1,2,3,4,5])

object_ids = meta_df['object_id'].values
labels = meta_df['target'].values

# Process all light curves
X_interp = []
X_features = []

for obj_id in object_ids:
    result = preprocessor.preprocess(lc_df, object_id=obj_id)
    X_interp.append(result['interpolated'])
    X_features.append(result['features'])

X_interp = np.array(X_interp)  # Shape: (n_objects, n_bands, n_time)
X_features = np.array(X_features)  # Shape: (n_objects, n_features)

print(f"Interpolated shape: {X_interp.shape}")
print(f"Features shape: {X_features.shape}")

In [None]:
# Visualize interpolated light curve
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Original (irregular sampling)
ax = axes[0]
obj_id = 0
obj_lc = lc_df[lc_df['object_id'] == obj_id]
for pb in sorted(obj_lc['passband'].unique()):
    pb_data = obj_lc[obj_lc['passband'] == pb].sort_values('mjd')
    ax.scatter(pb_data['mjd'], pb_data['flux'], 
               color=PASSBAND_COLORS[pb], label=PASSBAND_NAMES[pb], s=20)
ax.set_xlabel('MJD')
ax.set_ylabel('Flux')
ax.set_title('Original Light Curve (Irregular Sampling)')
ax.legend()

# Interpolated (regular grid)
ax = axes[1]
for pb in range(6):
    ax.plot(X_interp[obj_id, pb], color=PASSBAND_COLORS[pb], 
            label=PASSBAND_NAMES[pb], linewidth=1.5)
ax.set_xlabel('Time Bin')
ax.set_ylabel('Normalized Flux')
ax.set_title('Interpolated Light Curve (Regular Grid)')
ax.legend()

plt.tight_layout()
plt.savefig('../images/lc_preprocessing.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Model Architecture Comparison

We compare three approaches for light curve classification:

### 1D-CNN: Local Feature Detection
- Detects **local temporal patterns**: rise shape, peak shape, decline rate
- Translation-invariant: can detect features at any time
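
The translation property can be shown without any deep learning library. In the sketch below, a fixed kernel stands in for one learned convolutional filter (an assumption for illustration): the filter response moves with the feature, and a global max over time gives the same value wherever the feature occurs.

```python
import numpy as np

def conv1d_valid(signal, kernel):
    """Sliding dot product of kernel and signal (a conv layer without bias)."""
    return np.correlate(signal, kernel, mode="valid")

# A toy 'peak' feature placed early and late in an otherwise flat series.
feature = np.array([1.0, 2.0, 4.0, 2.0, 1.0])
early = np.zeros(50)
early[10:15] = feature
late = np.zeros(50)
late[30:35] = feature

response_early = conv1d_valid(early, feature)
response_late = conv1d_valid(late, feature)

# The response peak shifts with the feature; its height does not change.
print(np.argmax(response_early), np.argmax(response_late))
print(response_early.max() == response_late.max())
```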

### LSTM: Sequential Dependencies
- Models **long-range dependencies**: early behavior predicts late behavior
- Memory cells preserve information across long sequences
- When run bidirectionally, each time step sees both past and future context
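
The sketch below shows the tensor shapes involved when a bidirectional `nn.LSTM` consumes light curves stored as `(batch, n_bands, n_time)`. It is a shape illustration only; the layer sizes are arbitrary, and the actual `TransientLSTM` in `src.models` may be organized differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Light curves as stored in this notebook: (batch, n_bands, n_time).
batch = torch.randn(4, 6, 100)

# nn.LSTM with batch_first=True expects (batch, time, features),
# so the band axis becomes the feature axis.
sequence = batch.permute(0, 2, 1)

lstm = nn.LSTM(input_size=6, hidden_size=16, batch_first=True, bidirectional=True)
outputs, (h_n, c_n) = lstm(sequence)

# outputs: per-time-step states, forward and backward halves concatenated.
# h_n: final state of each direction, shape (2, batch, hidden_size).
summary = torch.cat([h_n[0], h_n[1]], dim=1)
print(outputs.shape, summary.shape)
```

A classifier head would typically take `summary` (or a pooled version of `outputs`) as its input.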

### Hybrid CNN-LSTM: Best of Both
- CNN extracts local features
- LSTM models relationships between features over time

### Why NOT 2D-CNN?

Light curves are **1D time series**, not images!

- 2D-CNNs would require first converting each light curve into an artificial image (e.g., a spectrogram)
- That extra transformation can obscure the temporal ordering the models rely on
- 1D-CNNs directly process the sequential nature of the data

In [None]:
# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X_interp, labels, test_size=0.2, stratify=labels, random_state=42
)

# Convert to tensors
X_train_torch = torch.tensor(X_train, dtype=torch.float32)
X_test_torch = torch.tensor(X_test, dtype=torch.float32)
y_train_torch = torch.tensor(y_train, dtype=torch.long)
y_test_torch = torch.tensor(y_test, dtype=torch.long)

# Data loaders
train_dataset = TensorDataset(X_train_torch, y_train_torch)
test_dataset = TensorDataset(X_test_torch, y_test_torch)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")

In [None]:
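# Define training function
def train_model(model, train_loader, test_loader, n_epochs=30, lr=0.001):
    """Train a model and return it together with its per-epoch history.

    Note: `test_loader` doubles as the validation set here, so the
    'val_*' entries are measured on the held-out test split.
    """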
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    
    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, pred = torch.max(outputs, 1)
            train_total += batch_y.size(0)
            train_correct += (pred == batch_y).sum().item()
        
        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                
                val_loss += loss.item()
                _, pred = torch.max(outputs, 1)
                val_total += batch_y.size(0)
                val_correct += (pred == batch_y).sum().item()
        
        history['train_loss'].append(train_loss / len(train_loader))
        history['val_loss'].append(val_loss / len(test_loader))
        history['train_acc'].append(train_correct / train_total)
        history['val_acc'].append(val_correct / val_total)
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Val Acc = {history['val_acc'][-1]:.4f}")
    
    return model, history

In [None]:
# Train 1D-CNN
print("Training 1D-CNN...")
cnn_model = TransientCNN1D(n_classes=6, n_bands=6, n_time=100)
cnn_model, cnn_history = train_model(cnn_model, train_loader, test_loader, n_epochs=30)

print(f"\n1D-CNN Final Validation Accuracy: {cnn_history['val_acc'][-1]:.4f}")

In [None]:
# Train LSTM
print("Training LSTM...")
lstm_model = TransientLSTM(n_classes=6, n_bands=6, hidden_size=128)
lstm_model, lstm_history = train_model(lstm_model, train_loader, test_loader, n_epochs=30)

print(f"\nLSTM Final Validation Accuracy: {lstm_history['val_acc'][-1]:.4f}")

In [None]:
# Train Hybrid CNN-LSTM
print("Training Hybrid CNN-LSTM...")
hybrid_model = HybridCNNLSTM(n_classes=6, n_bands=6)
hybrid_model, hybrid_history = train_model(hybrid_model, train_loader, test_loader, n_epochs=30)

print(f"\nHybrid CNN-LSTM Final Validation Accuracy: {hybrid_history['val_acc'][-1]:.4f}")

In [None]:
# Compare models
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Validation loss
ax = axes[0]
ax.plot(cnn_history['val_loss'], label='1D-CNN', linewidth=2)
ax.plot(lstm_history['val_loss'], label='LSTM', linewidth=2)
ax.plot(hybrid_history['val_loss'], label='Hybrid CNN-LSTM', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Loss')
ax.set_title('Model Comparison: Validation Loss')
ax.legend()

# Validation accuracy
ax = axes[1]
ax.plot(cnn_history['val_acc'], label='1D-CNN', linewidth=2)
ax.plot(lstm_history['val_acc'], label='LSTM', linewidth=2)
ax.plot(hybrid_history['val_acc'], label='Hybrid CNN-LSTM', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Accuracy')
ax.set_title('Model Comparison: Validation Accuracy')
ax.legend()

plt.tight_layout()
plt.savefig('../images/transient_model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

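# Summary: report the final epoch, since those are the weights that are kept.
# Taking the maximum over epochs would be optimistic because the validation
# data is also the test split.
print("\nModel Comparison Summary (final epoch):")
print(f"  1D-CNN:          {cnn_history['val_acc'][-1]:.4f}")
print(f"  LSTM:            {lstm_history['val_acc'][-1]:.4f}")
print(f"  Hybrid CNN-LSTM: {hybrid_history['val_acc'][-1]:.4f}")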

## 4. Detailed Evaluation

### Interpretation of Confusion Matrix

Common confusions in transient classification:

- **SNIa ↔ SNII**: Both are supernovae with broadly similar rise-and-decline shapes, despite different physics
- **RRLyrae ↔ Eclipsing Binary**: Both are periodic, but their folded shapes differ
- **Kilonova confusion**: They evolve so quickly that only a few observations sample each event
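
To read such confusions off the matrix, it helps to recall how it is built: rows are true classes, columns are predicted classes (the convention `sklearn.metrics.confusion_matrix` also uses). The toy labels below are invented for illustration, and `confusion_counts` is a hypothetical helper.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes):
    """Count predictions; rows are true classes, columns are predicted classes."""
    matrix = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(matrix, (y_true, y_pred), 1)
    return matrix

# Toy labels: classes 0 and 1 are confused with each other, class 2 is clean.
y_true = np.array([0, 0, 0, 1, 1, 1, 2, 2])
y_pred = np.array([0, 1, 0, 1, 0, 1, 2, 2])

cm = confusion_counts(y_true, y_pred, n_classes=3)
recall = np.diag(cm) / cm.sum(axis=1)  # Fraction of each true class recovered
print(cm)
print(recall)
```

Off-diagonal entries that are large in both `cm[i, j]` and `cm[j, i]` indicate a mutual confusion such as the SNIa ↔ SNII pair.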

In [None]:
# Evaluate the Hybrid CNN-LSTM (swap in cnn_model or lstm_model if one scored higher above)
hybrid_model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(DEVICE)
        outputs = hybrid_model(batch_x)
        _, pred = torch.max(outputs, 1)
        all_preds.extend(pred.cpu().numpy())
        all_labels.extend(batch_y.numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Classification report
print("Classification Report (Hybrid CNN-LSTM):")
print(classification_report(all_labels, all_preds, target_names=class_names))

In [None]:
# Confusion matrix
fig = plot_confusion_matrix(all_labels, all_preds, class_names=class_names)
plt.savefig('../images/transient_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Classical ML Comparison

Deep learning isn't always necessary. Classical ML with hand-crafted features often works well:

**Advantages of Classical ML:**
- Faster training
- More interpretable
- Works with less data

**Advantages of Deep Learning:**
- Learns features automatically
- Can model complex patterns
- Scales to large datasets
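
As an illustration of what "hand-crafted features" can mean, the sketch below computes a few summary statistics from a single-band light curve. The feature set is an assumption chosen for the example and is not necessarily what `LightCurvePreprocessor` returns; it also assumes no two observations share the same time.

```python
import numpy as np

def summary_features(mjd, flux):
    """A few simple shape statistics for a single-band light curve."""
    order = np.argsort(mjd)
    t, f = np.asarray(mjd, dtype=float)[order], np.asarray(flux, dtype=float)[order]
    slopes = np.diff(f) / np.diff(t)  # Assumes strictly distinct times
    centered = f - f.mean()
    std = f.std()
    return {
        "amplitude": f.max() - f.min(),
        "std": std,
        "skewness": np.mean(centered ** 3) / std ** 3,
        "max_rise_rate": slopes.max(),
        "max_decline_rate": slopes.min(),
    }

# Toy comparison: a single burst versus a sinusoidal variable.
t = np.linspace(0.0, 100.0, 51)
burst = 1000.0 * np.exp(-0.5 * ((t - 30.0) / 10.0) ** 2)
sine = 500.0 + 100.0 * np.sin(2 * np.pi * t / 7.0)

burst_features = summary_features(t, burst)
sine_features = summary_features(t, sine)
print(burst_features["skewness"], sine_features["skewness"])
```

A burst sits near its baseline most of the time, so its flux distribution is strongly right-skewed, while a sinusoid is close to symmetric; tree-based classifiers can split on exactly this kind of difference.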

In [None]:
# Classical ML with extracted features
X_feat_train, X_feat_test, y_feat_train, y_feat_test = train_test_split(
    X_features, labels, test_size=0.2, stratify=labels, random_state=42
)

# Handle NaN values
X_feat_train = np.nan_to_num(X_feat_train, nan=0.0)
X_feat_test = np.nan_to_num(X_feat_test, nan=0.0)

# Train Random Forest
rf_pipeline = create_classical_pipeline('random_forest')
rf_pipeline.fit(X_feat_train, y_feat_train)
rf_score = rf_pipeline.score(X_feat_test, y_feat_test)

print(f"Random Forest Accuracy: {rf_score:.4f}")

# Train Gradient Boosting
gb_pipeline = create_classical_pipeline('gradient_boosting')
gb_pipeline.fit(X_feat_train, y_feat_train)
gb_score = gb_pipeline.score(X_feat_test, y_feat_test)

print(f"Gradient Boosting Accuracy: {gb_score:.4f}")

In [None]:
# Final comparison
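# Use final-epoch accuracy for the neural networks so that every model is
# scored once on the test split, like the classical pipelines
results = {
    '1D-CNN': cnn_history['val_acc'][-1],
    'LSTM': lstm_history['val_acc'][-1],
    'Hybrid CNN-LSTM': hybrid_history['val_acc'][-1],
    'Random Forest': rf_score,
    'Gradient Boosting': gb_score
}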

fig, ax = plt.subplots(figsize=(10, 6))

models = list(results.keys())
scores = list(results.values())
colors = ['#3498db', '#2ecc71', '#9b59b6', '#e74c3c', '#f39c12']

bars = ax.barh(models, scores, color=colors)

for bar, score in zip(bars, scores):
    ax.text(score + 0.01, bar.get_y() + bar.get_height()/2,
            f'{score:.3f}', va='center', fontsize=11)

ax.set_xlabel('Accuracy', fontsize=12)
ax.set_title('Transient Classification: Model Comparison', fontsize=14)
ax.set_xlim(0, 1.1)

plt.tight_layout()
plt.savefig('../images/transient_final_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Conclusions

### Key Findings

1. **Deep learning can classify transients** from multi-band light curves, at least on this synthetic dataset
2. **Hybrid CNN-LSTM** can perform best by combining local and global features; check the comparison above for the current run
3. **Classical ML with good features** can be competitive, especially with limited data
4. **Rare transients** (kilonovae) are challenging because they evolve quickly and, in real surveys, have few labeled examples

### Scientific Implications

Automated classification enables:
- **Real-time alerts** for follow-up observations
- **Large-scale surveys** (LSST is expected to produce up to ~10 million alerts per night)
- **Discovery of rare events** that would be missed manually

### Limitations

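- **Synthetic data**: Real light curves are more complex than the simple analytic shapes used here
- **Imbalanced classes**: This notebook uses 100 objects per class, whereas in real surveys some transients are very rare
- **Redshift effects**: Distant objects are fainter and time-dilated; neither effect is modeled here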
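
One common way to handle class imbalance is to weight the loss inversely to class frequency. The sketch below uses the "balanced" heuristic `n_samples / (n_classes * count)`; the class counts are invented, `balanced_class_weights` is a hypothetical helper, and it assumes every class appears at least once.

```python
import numpy as np

def balanced_class_weights(labels, n_classes):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = np.bincount(labels, minlength=n_classes)
    return len(labels) / (n_classes * counts)

# Toy imbalanced label set: 90 common, 9 uncommon, 1 rare object.
labels = np.array([0] * 90 + [1] * 9 + [2] * 1)
weights = balanced_class_weights(labels, n_classes=3)
print(weights)  # Rare classes receive proportionally larger weights
```

In PyTorch, such weights could be supplied as a float tensor through the `weight` argument of `nn.CrossEntropyLoss`.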

### Future Directions

- Use the PLAsTiCC simulations (14 training classes) or real ZTF data
- Incorporate host galaxy information
- Attention mechanisms for interpretability
- Active learning for rare transient discovery