# LARUN TinyML - Lightning.ai Training Notebook

**Train the LARUN exoplanet detection model on a free Lightning.ai GPU**

Created by: Padmanaban Veeraragavalu (Larun Engineering)

---

## Lightning.ai Setup:
1. Go to https://lightning.ai
2. Create free account
3. **Studios** → **New Studio** → Select **GPU**
4. Upload this notebook, or clone the project repository from GitHub
5. Run all cells in order (Restart Kernel → Run All); trained models and plots are written to `~/larun_output/`

**Lightning Benefits:**
- 22 hours/month of free GPU time (free-tier quota at the time of writing — check the current plan limits)
- Persistent storage
- Pre-configured ML environment
- Easy sharing & collaboration

In [None]:
# Step 1: Check environment
!nvidia-smi

import torch
import tensorflow as tf

print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"TensorFlow: {tf.__version__}")
print(f"TF GPUs: {tf.config.list_physical_devices('GPU')}")

In [None]:
# Step 2: Install astronomy packages
!pip install -q lightkurve astroquery tqdm

import lightkurve as lk
print(f"Lightkurve: {lk.__version__}")

In [None]:
# Step 3: Configuration
import os

import tensorflow as tf

# Training parameters
NUM_PLANETS = 150        # max number of exoplanet hosts kept after shuffling (Step 4)
NUM_NON_PLANETS = 150    # number of non-planet light curves to collect (Step 6)
EPOCHS = 100
BATCH_SIZE = 64
INPUT_SIZE = 1024        # samples per light curve fed to the model
MAX_WORKERS = 10         # concurrent download threads
SEED = 42

# Seed Python's `random`, NumPy's global RNG and TensorFlow in one call so that
# np.random.* draws, weight initialisation and dropout repeat across Restart & Run All.
# GPU kernels may still introduce small run-to-run differences.
tf.keras.utils.set_random_seed(SEED)

# Output directory (Lightning.ai persistent storage); set LARUN_OUTPUT_DIR to override.
OUTPUT_DIR = os.path.expanduser(os.environ.get('LARUN_OUTPUT_DIR', '~/larun_output'))
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Output directory: {OUTPUT_DIR}")
print(f"Config: {NUM_PLANETS} planets, {NUM_NON_PLANETS} non-planets, {EPOCHS} epochs")

In [None]:
# Step 4: Fetch exoplanet hosts from NASA
import numpy as np
from astroquery.nasa_exoplanet_archive import NasaExoplanetArchive
import warnings
# NOTE: this silences *all* warnings for the rest of the kernel session (including the
# lightkurve and TensorFlow cells below), not only the archive query.
warnings.filterwarnings('ignore')

# Fixed seed so the selected host subset is identical on every Restart & Run All
# (same value as the random_state used for the train/val split in Step 7).
HOST_SAMPLE_SEED = 42

print("Querying NASA Exoplanet Archive...")

try:
    planets_table = NasaExoplanetArchive.query_criteria(
        table="pscomppars",
        select="hostname,disc_facility",
        where="disc_facility like '%TESS%' or disc_facility like '%Kepler%'"
    )
    # De-duplicate hosts with several planets, drop empty/masked (None) names, and sort:
    # plain set() iteration order depends on per-process string hashing, which would make
    # the seeded shuffle below non-reproducible.
    planet_hosts = sorted({h for h in planets_table['hostname'].data.tolist() if h})
except Exception as e:
    print(f"Archive query failed: {e}")
    # Fallback to known hosts
    planet_hosts = [
        "TOI-700", "TRAPPIST-1", "Kepler-186", "Kepler-442", "GJ 357",
        "LHS 1140", "K2-18", "Kepler-22", "Kepler-452", "Kepler-62"
    ]

rng = np.random.default_rng(HOST_SAMPLE_SEED)
rng.shuffle(planet_hosts)
planet_hosts = planet_hosts[:NUM_PLANETS]
print(f"Found {len(planet_hosts)} exoplanet hosts")

In [None]:
# Step 5: Parallel data fetching
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm
import lightkurve as lk

def fetch_lightcurve(args, input_size=None):
    """Download one light curve and cut or pad it to a fixed length.

    Args:
        args: ``(target, label)`` tuple — a target string passed to
            ``lk.search_lightcurve`` (e.g. a host name or ``"TIC <id>"``) and the
            class label to attach to the result.
        input_size: length of the returned flux array; defaults to the notebook-level
            ``INPUT_SIZE`` when None.

    Returns:
        ``{'flux': float32 array of length input_size, 'label': label, 'target': target}``,
        or ``None`` if the search found nothing, nothing survived cleaning, or any step
        raised. This is deliberately best-effort: one bad target must not abort the
        whole parallel fetch.
    """
    import logging

    target, label = args
    size = INPUT_SIZE if input_size is None else input_size
    try:
        search = lk.search_lightcurve(target, mission=['TESS', 'Kepler'])
        if len(search) == 0:
            return None

        # Only the first search result is downloaded.
        lc = search[0].download(quality_bitmask='default')
        lc = lc.remove_nans().normalize().remove_outliers(sigma=3)
        flux = lc.flux.value
        if len(flux) == 0:
            # Nothing left after cleaning; np.pad(mode='median') cannot extend an empty array.
            return None

        # Force a fixed length (no interpolation): pad short series at the end with their
        # own median, centre-crop long ones.
        if len(flux) < size:
            flux = np.pad(flux, (0, size - len(flux)), mode='median')
        else:
            start = (len(flux) - size) // 2
            flux = flux[start:start + size]

        return {'flux': flux.astype(np.float32), 'label': label, 'target': target}
    except Exception as exc:
        # `Exception`, not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate.
        # Failures stay quiet by default; enable DEBUG logging on 'larun.fetch' to see them.
        logging.getLogger('larun.fetch').debug('fetch failed for %s: %r', target, exc)
        return None

# Fetch planet hosts
print(f"Fetching {len(planet_hosts)} exoplanet hosts...")
planet_data = []

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    tasks = [(host, 1) for host in planet_hosts]
    futures = [executor.submit(fetch_lightcurve, t) for t in tasks]

    # Collect in submission order rather than completion order: downloads still run in
    # parallel, but planet_data's order (and therefore the seeded train/val split, which
    # permutes by position) no longer depends on network timing.
    for future in tqdm(futures, total=len(futures)):
        result = future.result()
        if result:
            planet_data.append(result)

print(f"✓ Got {len(planet_data)} planet host light curves")

In [None]:
# Step 6: Fetch non-planet stars
print(f"Fetching {NUM_NON_PLANETS} non-planet stars...")

# Request 5x as many candidates as needed: fetch_lightcurve returns None for targets it
# cannot fetch, and collection stops as soon as NUM_NON_PLANETS have succeeded.
# NOTE(review): these IDs are an arithmetic sequence, not a vetted non-host sample — some
# may be planet hosts or TOIs, which would add label noise. TODO confirm / filter.
non_planet_tics = [f"TIC {100000000 + i*100}" for i in range(NUM_NON_PLANETS * 5)]
non_planet_data = []

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    tasks = [(tic, 0) for tic in non_planet_tics]
    futures = [executor.submit(fetch_lightcurve, t) for t in tasks]

    # Walk futures in submission order so the selected stars are the same on every run.
    for future in tqdm(futures, total=len(futures)):
        result = future.result()
        if result:
            non_planet_data.append(result)
        if len(non_planet_data) >= NUM_NON_PLANETS:
            # Leaving the `with` block calls shutdown(wait=True), which would otherwise
            # wait for every queued download; cancel the ones that have not started yet.
            for pending in futures:
                pending.cancel()
            break

non_planet_data = non_planet_data[:NUM_NON_PLANETS]
print(f"✓ Got {len(non_planet_data)} non-planet light curves")

In [None]:
# Step 7: Prepare training data
from sklearn.model_selection import train_test_split

all_data = planet_data + non_planet_data
print(f"Total samples: {len(all_data)}")

# Fail early with a readable message. An empty fetch would otherwise surface as an AxisError
# in the normalisation below, and the stratified split needs >= 2 samples of each class.
assert planet_data and non_planet_data, (
    f"Need light curves for both classes, got {len(planet_data)} planet hosts and "
    f"{len(non_planet_data)} non-planets — check the fetch steps above."
)

X = np.array([d['flux'] for d in all_data])
y = np.array([d['label'] for d in all_data])

class_counts = np.bincount(y, minlength=2)  # index 0 = non-planet, 1 = planet host
assert class_counts.min() >= 2, (
    f"Stratified split needs >= 2 samples per class, got {class_counts.tolist()}"
)

# Normalize: per-light-curve z-score (epsilon guards against flat, zero-variance series).
X = (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + 1e-8)
X = X.reshape(-1, INPUT_SIZE, 1).astype(np.float32)

# Split
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Train: {len(X_train)}, Val: {len(X_val)}")
print(f"Class balance: {np.bincount(y_train)}")

# Save data
np.savez(f'{OUTPUT_DIR}/training_data.npz',
         X_train=X_train, y_train=y_train,
         X_val=X_val, y_val=y_val)
print(f"✓ Data saved to {OUTPUT_DIR}/training_data.npz")

In [None]:
# Step 8: Build and compile model (training happens in Step 9)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 1D CNN over the (INPUT_SIZE, 1) flux series: two conv stages that each downsample by 4,
# a third conv stage collapsed by global average pooling, then a small dense head with a
# 2-way softmax.
stack = [keras.Input(shape=(INPUT_SIZE, 1))]

# Downsampling conv stages, given as (filters, kernel width).
for n_filters, width in ((32, 7), (64, 5)):
    stack.extend([
        layers.Conv1D(n_filters, width, padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling1D(4),
        layers.Dropout(0.25),
    ])

# Final conv stage plus classification head.
stack.extend([
    layers.Conv1D(128, 3, padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.GlobalAveragePooling1D(),
    layers.Dropout(0.5),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(2, activation='softmax'),
])

model = keras.Sequential(stack, name='larun_lightning')

# Integer class labels (0/1) with a softmax output -> sparse categorical cross-entropy.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=keras.optimizers.Adam(0.001),
    metrics=['accuracy'],
)

model.summary()

In [None]:
# Step 9: Train
# None of these callbacks sets `monitor`, so each uses its Keras default.
early_stopping = keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True)
lr_on_plateau = keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=7, min_lr=1e-6)
# NOTE(review): newer Keras releases favour `.keras` over `.h5` for full-model
# checkpoints — confirm the installed version accepts this path.
best_checkpoint = keras.callbacks.ModelCheckpoint(f'{OUTPUT_DIR}/larun_best.h5', save_best_only=True)
callbacks = [early_stopping, lr_on_plateau, best_checkpoint]

print("Training...")
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Step 10: Evaluate and visualize
import matplotlib.pyplot as plt

val_loss, val_acc = model.evaluate(X_val, y_val, verbose=0)

# "Final" = the weights the model holds now (EarlyStopping with restore_best_weights may
# have rolled back to an earlier epoch); "Best" = highest per-epoch val_accuracy logged.
print(f"\nFinal Accuracy: {val_acc*100:.2f}%")
print(f"Best Accuracy: {max(history.history['val_accuracy'])*100:.2f}%")

# Learning curves: one panel per metric, train vs. validation, labelled so the figure
# stands on its own when skimmed or opened from the saved PNG.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric, title in zip(axes, ('accuracy', 'loss'), ('Accuracy', 'Loss')):
    ax.plot(history.history[metric], label='Train')
    ax.plot(history.history[f'val_{metric}'], label='Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()

fig.tight_layout()
fig.savefig(f'{OUTPUT_DIR}/training_history.png')
plt.show()

In [None]:
# Step 11: Export models
import numpy as np
import tensorflow as tf

# Save Keras
model.save(f'{OUTPUT_DIR}/larun_model.h5')

# TFLite (float32)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite = converter.convert()
with open(f'{OUTPUT_DIR}/larun_model.tflite', 'wb') as f:
    f.write(tflite)

# Quantized.
# Optimize.DEFAULT on its own gives only dynamic-range quantization (int8 weights, float
# activations), which does not match the "int8" name for a TinyML target. Supplying
# calibration samples lets the converter quantize activations too; per the TFLite
# post-training quantization guide, ops without an integer kernel fall back to float, and
# the model's input/output tensors stay float32 unless inference_input_type /
# inference_output_type are set.
def representative_dataset(samples=X_train, limit=100):
    """Yield up to `limit` single-sample float32 batches from the training set."""
    for sample in samples[:limit]:
        yield [sample[np.newaxis].astype(np.float32)]

# Fresh converter so the float export above is not affected by quantization settings.
quant_converter = tf.lite.TFLiteConverter.from_keras_model(model)
quant_converter.optimizations = [tf.lite.Optimize.DEFAULT]
quant_converter.representative_dataset = representative_dataset
tflite_quant = quant_converter.convert()
with open(f'{OUTPUT_DIR}/larun_model_int8.tflite', 'wb') as f:
    f.write(tflite_quant)

print(f"Models saved to {OUTPUT_DIR}/")
print(f"  TFLite: {len(tflite)/1024:.1f} KB")
print(f"  INT8: {len(tflite_quant)/1024:.1f} KB")

In [None]:
# Step 12: Create download package
import glob
import os
import shutil
import zipfile

# Artifacts to bundle; stored by basename, as when zipping from inside OUTPUT_DIR.
ARCHIVE_PATTERNS = ('*.h5', '*.tflite', '*.npz', '*.png')
archive_path = os.path.join(OUTPUT_DIR, 'larun_trained.zip')

# Pure-Python zip: no dependency on a `zip` binary, safe for paths with spaces, and mode
# 'w' rebuilds the archive on re-run instead of appending to a stale one.
with zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for pattern in ARCHIVE_PATTERNS:
        for path in sorted(glob.glob(os.path.join(OUTPUT_DIR, pattern))):
            archive.write(path, arcname=os.path.basename(path))

print(f"\n{'='*50}")
print("TRAINING COMPLETE!")
print(f"{'='*50}")
print(f"Accuracy: {val_acc*100:.2f}%")
print(f"\nDownload: {OUTPUT_DIR}/larun_trained.zip")
print("\nIn Lightning Studio: File Browser → Navigate to ~/larun_output/")