# desisky Quickstart Guide

This notebook demonstrates how to:
1. Load pre-trained models
2. Download and load DESI sky spectra data
3. Run inference
4. Visualize results

## Installation

```bash
pip install "desisky[cpu,data]"
```

In [None]:
import desisky
from desisky.data import SkySpecVAC
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

## 1. Load Pre-trained Broadband Model

The broadband model predicts surface brightness in V, g, r, and z photometric bands from observational metadata.

In [None]:
# Load the pre-trained broadband model
model, meta = desisky.io.load_model("broadband")

print("Model loaded successfully!")
print(f"Architecture: {meta['arch']}")
print(f"Model: {model}")

## 2. Download DESI Sky Spectra Data

The `SkySpecVAC` class provides a PyTorch-like interface for downloading and loading the DESI DR1 sky spectra Value-Added Catalog.

The dataset includes:
- **9,176 sky spectra** with 7,781 wavelength points each
- **Metadata** including observing conditions (airmass, seeing, sky brightness, moon/sun positions, etc.)

The first run downloads about 274 MB from the DESI public data release.

In [None]:
# Download and load the VAC
# download=True fetches the file only if it is not already present
# verify=True checks the file's SHA-256 checksum to confirm data integrity
vac = SkySpecVAC(version="v1.0", download=True, verify=True)

print(f"Data file location: {vac.filepath()}")
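
To make the `verify=True` step concrete, here is a minimal sketch of how a SHA-256 check on a downloaded file can work. It uses only Python's standard `hashlib`. The helper names `sha256_of_file` and `verify_file` are illustrative and not part of desisky, whose internal implementation may differ.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # Stream the file so a large download never sits fully in memory.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path, expected_hex):
    """Return True when the file's digest matches the published checksum."""
    return sha256_of_file(path) == expected_hex.lower()


# Demo on a throwaway file (illustrative only, not the VAC file).
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "demo.bin"
    demo.write_bytes(b"abc")
    ok = verify_file(demo, hashlib.sha256(b"abc").hexdigest())
    tampered = verify_file(demo, hashlib.sha256(b"abd").hexdigest())

print(ok, tampered)  # True False
```

A single changed byte produces a different digest, which is why a checksum comparison catches truncated or corrupted downloads.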

In [None]:
# Load wavelength, flux, and metadata
wavelength, flux, metadata = vac.load()

print(f"Wavelength shape: {wavelength.shape}")  # (7781,)
print(f"Flux shape: {flux.shape}")              # (9176, 7781)
print(f"Metadata shape: {metadata.shape}")      # (9176, 22)
print(f"\nAvailable metadata columns:\n{list(metadata.columns)}")
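
The shapes above suggest that each row of `flux` pairs with one row of `metadata`, so a condition on a metadata column can select spectra directly. The sketch below shows that pattern on small synthetic arrays (not real VAC data). With the real DataFrame, something like `metadata["MOONALT"].to_numpy() < 0` would presumably supply the mask, assuming that column holds the Moon altitude in degrees.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins with the same layout as the VAC arrays (NOT real data):
# one row of flux per spectrum, one metadata value per spectrum.
n_spectra, n_wave = 8, 5
flux_demo = rng.normal(size=(n_spectra, n_wave))
moon_alt_demo = np.array([-30.0, 12.0, -5.0, 40.0, -60.0, 3.0, -1.0, 25.0])

# Boolean mask over spectra: True where the Moon is below the horizon.
moon_down = moon_alt_demo < 0.0

# The mask indexes the first (spectrum) axis; the wavelength axis is untouched.
dark_flux = flux_demo[moon_down]
mean_dark_spectrum = dark_flux.mean(axis=0)

print(dark_flux.shape)           # (4, 5)
print(mean_dark_spectrum.shape)  # (5,)
```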

## 3. Explore the Data

Let's examine the metadata to understand what observing conditions are available.

In [None]:
# Show first few rows of metadata
metadata.head(10)

In [None]:
# Summary statistics
metadata[['AIRMASS', 'SEEING_ETC', 'SKY_MAG_G_SPEC', 'SKY_MAG_R_SPEC', 'MOONFRAC']].describe()

## 4. Visualize Sky Spectra

Plot some example spectra to see what the data looks like.

In [None]:
# Plot a few random spectra
fig, ax = plt.subplots(figsize=(12, 6))

# Select 5 random spectra
np.random.seed(42)
indices = np.random.choice(flux.shape[0], size=5, replace=False)

for idx in indices:
    ax.plot(wavelength, flux[idx], alpha=0.7, linewidth=0.8, 
            label=f"Spectrum {idx} (airmass={metadata.iloc[idx]['AIRMASS']:.2f})")

ax.set_xlabel("Wavelength (Å)", fontsize=12)
ax.set_ylabel("Flux", fontsize=12)
ax.set_title("Example DESI Sky Spectra", fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=9)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Zoom into Emission Lines

Sky spectra are dominated by emission lines. Let's zoom into a region to see them clearly.

In [None]:
# Zoom into 5500-6000 Å (includes the strong [O I] 5577 Å airglow line)
fig, ax = plt.subplots(figsize=(12, 6))

# Select wavelength range
mask = (wavelength > 5500) & (wavelength < 6000)
wave_zoom = wavelength[mask]

for idx in indices:
    flux_zoom = flux[idx, mask]
    ax.plot(wave_zoom, flux_zoom, alpha=0.7, linewidth=1.0,
            label=f"Spectrum {idx}")

ax.set_xlabel("Wavelength (Å)", fontsize=12)
ax.set_ylabel("Flux", fontsize=12)
ax.set_title("Sky Emission Lines (5500-6000 Å)", fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=9)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
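
The same boolean-mask windowing can also locate the brightest line in a region numerically. This is a minimal sketch on a synthetic spectrum (a flat continuum plus one Gaussian line placed at 5577 Å); the array names are illustrative, and real spectra would need more care with blended lines and noise.

```python
import numpy as np

# Synthetic spectrum (NOT real data): flat continuum plus one Gaussian line.
wave_demo = np.linspace(5500.0, 6000.0, 1001)  # 0.5 Angstrom sampling
line_center, line_sigma = 5577.0, 1.0
flux_demo = 1.0 + 20.0 * np.exp(-0.5 * ((wave_demo - line_center) / line_sigma) ** 2)

# Restrict to a narrow window with a boolean mask, as in the plotting cell.
window = (wave_demo > 5550.0) & (wave_demo < 5600.0)
wave_win, flux_win = wave_demo[window], flux_demo[window]

# Brightest pixel inside the window and its wavelength.
peak_index = int(np.argmax(flux_win))
peak_wave = float(wave_win[peak_index])

print(f"Peak at {peak_wave:.1f} Angstrom")
```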

## 6. Observing Conditions Distribution

Visualize the distribution of observing conditions in the dataset.

In [None]:
# Plot distributions of key observing parameters
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.ravel()

params = [
    ('AIRMASS', 'Airmass'),
    ('SEEING_ETC', 'Seeing (arcsec)'),
    ('SKY_MAG_R_SPEC', 'Sky Mag (r-band)'),
    ('MOONFRAC', 'Moon Illumination'),
    ('MOONALT', 'Moon Altitude (deg)'),
    ('EBV', 'E(B-V) Extinction')
]

for i, (col, label) in enumerate(params):
    axes[i].hist(metadata[col], bins=30, alpha=0.7, edgecolor='black')
    axes[i].set_xlabel(label, fontsize=10)
    axes[i].set_ylabel('Count', fontsize=10)
    axes[i].grid(True, alpha=0.3)

plt.suptitle('Distribution of Observing Conditions', fontsize=14, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

## 7. Model Inference Example

Run the broadband model on example inputs to predict V, g, r, z magnitudes.

**Note:** The input below is a placeholder. Meaningful predictions require the specific input features the model was trained on, supplied in the order the model expects.

In [None]:
# Example: Create dummy input (replace with actual features)
# The model expects 6 input features
example_input = jnp.ones((meta['arch']['in_size'],))

# Run inference
predicted_mags = model(example_input)

print(f"Input shape: {example_input.shape}")  # (6,)
print(f"Output shape: {predicted_mags.shape}")  # (4,)
print(f"\nPredicted magnitudes (V, g, r, z): {predicted_mags}")
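
Once the expected features are known, the input array can be assembled by stacking the relevant columns. The sketch below shows only the mechanics: the feature list and its order are hypothetical placeholders, and the values are synthetic. With the real DataFrame, `metadata[feature_columns].to_numpy(dtype="float32")` would likely do the same job.

```python
import numpy as np

# HYPOTHETICAL feature list: the real model's inputs and their order may differ.
feature_columns = ["AIRMASS", "MOONFRAC", "MOONALT"]

# Synthetic per-column values standing in for metadata columns (NOT real data).
table = {
    "AIRMASS": np.array([1.0, 1.3, 1.8]),
    "MOONFRAC": np.array([0.1, 0.5, 0.9]),
    "MOONALT": np.array([-20.0, 10.0, 45.0]),
}

# Stack columns side by side: one row per observation, one column per feature.
features = np.column_stack([table[name] for name in feature_columns]).astype(np.float32)

# For the real model this would be compared against meta['arch']['in_size'].
expected_in_size = len(feature_columns)
assert features.shape[1] == expected_in_size, "feature count must match the model input size"

print(features.shape)  # (3, 3)
```

Checking the column count against the model's input size before inference catches a missing or extra feature early, though it cannot detect features supplied in the wrong order.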

## 8. Batch Inference

Process multiple inputs at once using JAX's vectorization.

In [None]:
import jax

# Create batch of inputs
batch_size = 100
batch_inputs = jnp.ones((batch_size, meta['arch']['in_size']))

# Vectorize the model for batch processing
batch_model = jax.vmap(model)

# Run batch inference
batch_predictions = batch_model(batch_inputs)

print(f"Batch input shape: {batch_inputs.shape}")      # (100, 6)
print(f"Batch output shape: {batch_predictions.shape}")  # (100, 4)
print(f"\nFirst 5 predictions:\n{batch_predictions[:5]}")
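
To see what `jax.vmap` does without the pre-trained weights, here is a small self-contained sketch. `toy_model` is a hypothetical stand-in that maps 6 features to 4 outputs; it is not the broadband model. The sketch checks that the vectorized call matches calling the function once per row.

```python
import jax
import jax.numpy as jnp


def toy_model(x):
    """Hypothetical single-example model: 6 features -> 4 outputs."""
    weights = jnp.arange(24.0).reshape(4, 6) / 24.0
    return jnp.tanh(weights @ x)


batch = jnp.linspace(0.0, 1.0, 3 * 6).reshape(3, 6)

# vmap maps toy_model over the leading (batch) axis of its argument.
batched = jax.vmap(toy_model)(batch)

# Reference: call the model once per row and stack the results.
looped = jnp.stack([toy_model(row) for row in batch])

print(batched.shape)  # (3, 4)
print(bool(jnp.allclose(batched, looped, atol=1e-6)))
```

The function is written for a single example, and `vmap` adds the batch dimension, so no manual reshaping is needed inside the model.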

## Summary

This notebook demonstrated:
- ✅ Loading pre-trained models with `desisky.io.load_model()`
- ✅ Downloading DESI data with automatic verification
- ✅ Loading and exploring sky spectra
- ✅ Visualizing emission lines and observing conditions
- ✅ Running model inference (single and batch)

## Next Steps

- Train your own models on the VAC data
- Integrate with SpecSim for survey forecasting
- Explore correlations between observing conditions and sky brightness
- Build generative models for synthetic sky spectra