# LARUN Continuous Training (Auto-Resume)

**Resumes training from the last checkpoint every time it runs — schedule it to continue automatically after your GPU quota resets.**

This notebook:
- Saves a checkpoint every `CHECKPOINT_EVERY` epochs (default: 10)
- Caches fetched data (survives restarts)
- Auto-resumes from last checkpoint
- Optionally uploads the exported models to the Hugging Face Hub (disabled by default; set `UPLOAD_TO_HUB = True`)

---

## Setup for Auto-Restart

### Kaggle:
1. **Settings** → **Accelerator** → **GPU T4 x2**
2. **Save & Run All** (creates a version)
3. **Schedule** → Set to run **Weekly** or **Monthly**

### Manual restart:
Just run the notebook again - it will resume from checkpoint!

---
*Larun Engineering*

In [None]:
# ============================================================
# CONFIGURATION - Adjust these for your needs
# ============================================================

# --- Dataset size (larger targets give a better model but a longer fetch) ---
TARGET_PLANETS = 300          # planet samples to collect in total
TARGET_NON_PLANETS = 300      # non-planet samples to collect in total

# --- Training ---
TOTAL_EPOCHS = 200            # epochs summed over all resumed sessions
BATCH_SIZE = 64               # 64-128 suits a T4
INPUT_SIZE = 1024             # samples per light curve
MAX_WORKERS = 12              # threads used for parallel downloads

# --- Checkpointing / publishing ---
CHECKPOINT_EVERY = 10         # epochs between checkpoints
UPLOAD_TO_HUB = False         # push exported models to the HuggingFace Hub?
HF_REPO = "your-username/larun-model"  # target repo when uploading

# --- Paths (Kaggle persistent storage when running on Kaggle) ---
import os
KAGGLE_MODE = os.path.exists('/kaggle')
BASE_DIR = '/kaggle/working' if KAGGLE_MODE else './larun_output'
if not KAGGLE_MODE:
    os.makedirs(BASE_DIR, exist_ok=True)

CHECKPOINT_DIR = f'{BASE_DIR}/checkpoints'
DATA_CACHE = f'{BASE_DIR}/data_cache.npz'
MODEL_DIR = f'{BASE_DIR}/models'

for _output_dir in (CHECKPOINT_DIR, MODEL_DIR):
    os.makedirs(_output_dir, exist_ok=True)

_mode_label = 'Kaggle' if KAGGLE_MODE else 'Local'
print(f"Mode: {_mode_label}")
print(f"Base dir: {BASE_DIR}")
print(f"Target: {TARGET_PLANETS} planets + {TARGET_NON_PLANETS} non-planets")

In [None]:
# ============================================================
# STEP 1: Check GPU and install dependencies
# ============================================================
!nvidia-smi

!pip install -q lightkurve astroquery tqdm huggingface_hub

import numpy as np
import tensorflow as tf
import json
from datetime import datetime
from pathlib import Path

print(f"\nTensorFlow: {tf.__version__}")
print(f"GPUs: {tf.config.list_physical_devices('GPU')}")

In [None]:
# ============================================================
# STEP 2: Check for existing checkpoint
# ============================================================

def load_checkpoint():
    """Load the saved training state, if any.

    Reads ``training_state.json`` from ``CHECKPOINT_DIR`` and prints a short
    summary of it.

    Returns:
        The state dict written by ``save_checkpoint``, or ``None`` when no
        checkpoint exists or the file cannot be read/parsed as a JSON object
        (e.g. a file left truncated by a session that ended mid-write).
    """
    checkpoint_file = f'{CHECKPOINT_DIR}/training_state.json'
    if not os.path.exists(checkpoint_file):
        print("No checkpoint found - starting fresh")
        return None

    try:
        with open(checkpoint_file) as f:
            state = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # Treat an unreadable state file as "no checkpoint" instead of
        # aborting the whole notebook.
        print(f"⚠ Unreadable checkpoint ({err}) - starting fresh")
        return None
    if not isinstance(state, dict):
        print("⚠ Malformed checkpoint (not a JSON object) - starting fresh")
        return None

    print(f"✓ Found checkpoint from {state.get('timestamp', 'unknown')}")
    print(f"  Epochs completed: {state.get('epochs_completed', 0)}")
    print(f"  Best accuracy: {state.get('best_accuracy', 0)*100:.2f}%")
    return state

def save_checkpoint(state):
    """Persist ``state`` plus a timestamp to ``training_state.json``.

    The JSON is written to a temporary file and swapped in with
    ``os.replace``, so a session that dies mid-write cannot leave a truncated
    checkpoint behind. NumPy scalars (e.g. metric values) are stored as plain
    Python numbers. The caller's dict is not modified.

    Args:
        state: JSON-serialisable dict; must contain ``'epochs_completed'``.
    """
    def _to_builtin(obj):
        # json cannot serialise NumPy scalars such as np.float32 metrics.
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    payload = {**state, 'timestamp': datetime.now().isoformat()}
    checkpoint_file = f'{CHECKPOINT_DIR}/training_state.json'
    tmp_file = f'{checkpoint_file}.tmp'
    with open(tmp_file, 'w') as f:
        json.dump(payload, f, indent=2, default=_to_builtin)
    os.replace(tmp_file, checkpoint_file)
    print(f"✓ Checkpoint saved: epoch {payload['epochs_completed']}")

checkpoint = load_checkpoint()
START_EPOCH = checkpoint.get('epochs_completed', 0) if checkpoint else 0
BEST_ACCURACY = checkpoint.get('best_accuracy', 0) if checkpoint else 0

# The state file and the model weights are written separately. If the state
# reports progress but no weights file exists (e.g. the session ended before
# the first periodic checkpoint), the model-building step would create a fresh
# model while the epoch counter skips ahead - restart the count instead.
if START_EPOCH > 0 and not os.path.exists(f'{CHECKPOINT_DIR}/model_checkpoint.h5'):
    print(f"⚠ Checkpoint reports {START_EPOCH} epochs but no model weights were found - restarting from epoch 0")
    START_EPOCH = 0

print(f"\nWill train from epoch {START_EPOCH} to {TOTAL_EPOCHS}")

In [None]:
# ============================================================
# STEP 3: Load or fetch training data
# ============================================================

def _read_cache(allow_pickle=False):
    """Read all four cached arrays from DATA_CACHE into memory and close the file."""
    with np.load(DATA_CACHE, allow_pickle=allow_pickle) as data:
        return {
            'planet_flux': data['planet_flux'],
            'non_planet_flux': data['non_planet_flux'],
            'planet_targets': data['planet_targets'].tolist(),
            'non_planet_targets': data['non_planet_targets'].tolist()
        }


def load_cached_data():
    """Load the training-data cache written by ``save_cached_data``.

    Returns:
        dict with ``planet_flux`` / ``non_planet_flux`` (arrays as saved) and
        ``planet_targets`` / ``non_planet_targets`` (lists of target names),
        or ``None`` when no cache exists or it cannot be read (for example a
        partially written file).
    """
    import pickle
    import zipfile

    if not os.path.exists(DATA_CACHE):
        return None
    try:
        try:
            return _read_cache()
        except ValueError:
            if not zipfile.is_zipfile(DATA_CACHE):
                raise
            # Caches written by earlier versions of this notebook stored the
            # target names as object arrays, which np.load only reads with
            # allow_pickle. The file is produced locally by save_cached_data;
            # never point DATA_CACHE at an untrusted file.
            return _read_cache(allow_pickle=True)
    except (OSError, EOFError, KeyError, ValueError,
            zipfile.BadZipFile, pickle.UnpicklingError) as err:
        print(f"⚠ Could not read data cache ({err}) - ignoring it")
        return None


def save_cached_data(planet_flux, non_planet_flux, planet_targets, non_planet_targets):
    """Save fetched light curves and their target names to DATA_CACHE.

    Target names are stored as plain unicode arrays so the cache loads without
    ``allow_pickle``. The archive is written to a temporary file and swapped in
    with ``os.replace``, so an interrupted session never leaves a truncated
    cache behind.
    """
    tmp_path = f'{DATA_CACHE}.tmp'
    # Writing through a file object stops np.savez from appending '.npz'.
    with open(tmp_path, 'wb') as fh:
        np.savez(fh,
                 planet_flux=planet_flux,
                 non_planet_flux=non_planet_flux,
                 planet_targets=np.array(planet_targets, dtype=str),
                 non_planet_targets=np.array(non_planet_targets, dtype=str))
    os.replace(tmp_path, DATA_CACHE)
    print(f"✓ Data cached: {len(planet_flux)} planets, {len(non_planet_flux)} non-planets")

cached_data = load_cached_data()
if not cached_data:
    # Nothing usable on disk: both classes must be downloaded.
    print("No cached data - will fetch from NASA")
    NEED_MORE_PLANETS = True
    NEED_MORE_NON_PLANETS = True
else:
    n_cached_planets = len(cached_data['planet_flux'])
    n_cached_non_planets = len(cached_data['non_planet_flux'])
    print(f"✓ Loaded cached data:")
    print(f"  Planets: {n_cached_planets}")
    print(f"  Non-planets: {n_cached_non_planets}")
    # Only top up the classes that are still below their targets.
    NEED_MORE_PLANETS = n_cached_planets < TARGET_PLANETS
    NEED_MORE_NON_PLANETS = n_cached_non_planets < TARGET_NON_PLANETS

In [None]:
# ============================================================
# STEP 4: Fetch more data if needed
# ============================================================

if NEED_MORE_PLANETS or NEED_MORE_NON_PLANETS:
    import lightkurve as lk
    from astroquery.nasa_exoplanet_archive import NasaExoplanetArchive
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from tqdm.notebook import tqdm
    import warnings
    warnings.filterwarnings('ignore')

    def fetch_lightcurve(args):
        """Download one light curve and fit it to INPUT_SIZE samples.

        Args:
            args: ``(target, label)``; ``target`` is passed to
                ``lk.search_lightcurve`` and ``label`` is echoed back.

        Returns:
            ``{'flux': float32 array of length INPUT_SIZE, 'label': label,
            'target': target}``, or ``None`` when nothing is found, the
            cleaned series is empty, or any download/processing step fails.
        """
        target, label = args
        try:
            search = lk.search_lightcurve(target, mission=['TESS', 'Kepler'])
            if len(search) == 0:
                return None
            lc = search[0].download(quality_bitmask='default')
            lc = lc.remove_nans().normalize().remove_outliers(sigma=3)
            flux = lc.flux.value
            if len(flux) == 0:
                # Median-padding an empty series would give an all-NaN sample.
                return None

            if len(flux) < INPUT_SIZE:
                flux = np.pad(flux, (0, INPUT_SIZE - len(flux)), mode='median')
            else:
                # Keep the central INPUT_SIZE samples.
                start = (len(flux) - INPUT_SIZE) // 2
                flux = flux[start:start + INPUT_SIZE]

            return {'flux': flux.astype(np.float32), 'label': label, 'target': target}
        except Exception:
            # Best-effort: any archive/network/processing failure just skips
            # this target. (A bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit.)
            return None

    def collect_samples(candidates, label, flux_list, target_list, goal, desc):
        """Fetch `candidates` in parallel, appending results until `goal` is met.

        Successful fluxes/targets are appended to `flux_list`/`target_list` in
        place. Once the goal is reached, futures that have not started yet are
        cancelled; otherwise leaving the `with` block would still wait for
        every submitted download to finish.
        """
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            futures = [executor.submit(fetch_lightcurve, (c, label)) for c in candidates]
            try:
                for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
                    result = future.result()
                    if result:
                        flux_list.append(result['flux'])
                        target_list.append(result['target'])
                        if len(flux_list) >= goal:
                            break
            finally:
                for future in futures:
                    future.cancel()  # no-op for futures already running/done

    # Initialize from cache or empty
    planet_flux = list(cached_data['planet_flux']) if cached_data else []
    planet_targets = cached_data['planet_targets'] if cached_data else []
    non_planet_flux = list(cached_data['non_planet_flux']) if cached_data else []
    non_planet_targets = cached_data['non_planet_targets'] if cached_data else []

    # Fetch more planets if needed
    if len(planet_flux) < TARGET_PLANETS:
        print(f"\nFetching planets ({len(planet_flux)} → {TARGET_PLANETS})...")

        planets_table = NasaExoplanetArchive.query_criteria(
            table="pscomppars",
            select="hostname,disc_facility",
            where="disc_facility like '%TESS%' or disc_facility like '%Kepler%'"
        )
        all_hosts = list(set(planets_table['hostname'].data.tolist()))

        # Skip already fetched (set lookup instead of scanning the list)
        fetched_hosts = set(planet_targets)
        new_hosts = [h for h in all_hosts if h not in fetched_hosts]
        np.random.shuffle(new_hosts)
        needed = TARGET_PLANETS - len(planet_flux)

        # Submit twice as many hosts as needed because some fetches return None.
        collect_samples(new_hosts[:needed * 2], 1, planet_flux, planet_targets,
                        TARGET_PLANETS, "Planets")
        print(f"✓ Now have {len(planet_flux)} planet samples")

    # Fetch more non-planets if needed
    if len(non_planet_flux) < TARGET_NON_PLANETS:
        print(f"\nFetching non-planets ({len(non_planet_flux)} → {TARGET_NON_PLANETS})...")

        # NOTE(review): these TIC IDs are a fixed arithmetic sequence, not a
        # vetted non-planet catalogue; some may host planets - confirm the
        # label-0 assumption is acceptable.
        existing_tics = set(non_planet_targets)
        candidate_tics = (f"TIC {100000000 + i * 100}" for i in range(TARGET_NON_PLANETS * 5))
        new_tics = [tic for tic in candidate_tics if tic not in existing_tics]

        collect_samples(new_tics, 0, non_planet_flux, non_planet_targets,
                        TARGET_NON_PLANETS, "Non-planets")
        print(f"✓ Now have {len(non_planet_flux)} non-planet samples")

    # Save updated cache
    save_cached_data(
        np.array(planet_flux),
        np.array(non_planet_flux),
        planet_targets,
        non_planet_targets
    )
else:
    planet_flux = cached_data['planet_flux']
    non_planet_flux = cached_data['non_planet_flux']
    print("Using cached data (no fetch needed)")

In [None]:
# ============================================================
# STEP 5: Prepare training data
# ============================================================
from sklearn.model_selection import train_test_split

# Stack positives (label 1) and negatives (label 0), capped at the targets
X_planets = np.array(planet_flux)[:TARGET_PLANETS]
X_non_planets = np.array(non_planet_flux)[:TARGET_NON_PLANETS]
X = np.concatenate([X_planets, X_non_planets], axis=0)
y = np.concatenate(
    [np.ones(len(X_planets)), np.zeros(len(X_non_planets))]
).astype(np.int32)

# Standardize every light curve on its own (zero mean, unit variance), then
# add the trailing channel axis expected by Conv1D.
_row_mean = X.mean(axis=1, keepdims=True)
_row_std = X.std(axis=1, keepdims=True)
X = ((X - _row_mean) / (_row_std + 1e-8)).reshape(-1, INPUT_SIZE, 1).astype(np.float32)

# Stratified 80/20 split with a fixed seed
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"\nTraining Data:")
for _split_name, _count in (("Total", len(X)), ("Train", len(X_train)), ("Val", len(X_val))):
    print(f"  {_split_name}: {_count} samples")
print(f"  Class balance: {np.bincount(y_train)}")

In [None]:
# ============================================================
# STEP 6: Build or load model
# ============================================================
from tensorflow import keras
from tensorflow.keras import layers

# Check for existing model checkpoint
model_checkpoint = f'{CHECKPOINT_DIR}/model_checkpoint.h5'

# Create the distribution strategy up front and load OR build inside its
# scope, so a resumed model is replicated across the same devices as a
# freshly built one.
strategy = tf.distribute.MirroredStrategy()
print(f"Devices: {strategy.num_replicas_in_sync}")

with strategy.scope():
    if os.path.exists(model_checkpoint) and START_EPOCH > 0:
        print(f"Loading model from checkpoint (epoch {START_EPOCH})...")
        # The .h5 file also carries the compiled optimizer state.
        model = keras.models.load_model(model_checkpoint)
    else:
        print("Building new model...")
        model = keras.Sequential([
            keras.Input(shape=(INPUT_SIZE, 1)),

            layers.Conv1D(32, 7, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.MaxPooling1D(4),
            layers.Dropout(0.25),

            layers.Conv1D(64, 5, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.MaxPooling1D(4),
            layers.Dropout(0.25),

            layers.Conv1D(128, 3, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.GlobalAveragePooling1D(),
            layers.Dropout(0.5),

            layers.Dense(64, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(2, activation='softmax')
        ], name='larun_continuous')

        model.compile(
            optimizer=keras.optimizers.Adam(0.001),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

model.summary()

In [None]:
# ============================================================
# STEP 7: Custom callback for checkpointing
# ============================================================

class ContinuousTrainingCallback(keras.callbacks.Callback):
    """Track metrics, keep the best model, and write resumable checkpoints.

    Keras' ``epoch`` index is relative to the current ``fit`` call;
    ``start_epoch`` converts it to an absolute epoch count.

    Args:
        checkpoint_every: Save ``model_checkpoint.h5`` plus the training state
            whenever the absolute epoch count is a multiple of this.
        start_epoch: Epochs already completed before this ``fit`` call.
        best_accuracy: Best validation accuracy seen so far; defaults to the
            module-level ``BEST_ACCURACY`` restored from the checkpoint.
        history: Metric history from an earlier session (the ``history`` entry
            of a saved state), so stored curves cover every epoch.
    """

    def __init__(self, checkpoint_every=10, start_epoch=0, best_accuracy=None, history=None):
        super().__init__()
        self.checkpoint_every = checkpoint_every
        self.start_epoch = start_epoch
        self.best_accuracy = BEST_ACCURACY if best_accuracy is None else best_accuracy
        self.history = {'accuracy': [], 'val_accuracy': [], 'loss': [], 'val_loss': []}
        # Carry over previous sessions' metrics instead of starting empty.
        for key, values in (history or {}).items():
            if key in self.history:
                self.history[key] = [float(v) for v in values]

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        actual_epoch = self.start_epoch + epoch + 1

        # Track history
        for key, values in self.history.items():
            if key in logs:
                values.append(float(logs[key]))

        # Plain float so the value stays JSON-serialisable in the checkpoint.
        val_acc = float(logs.get('val_accuracy', 0))
        if val_acc > self.best_accuracy:
            self.best_accuracy = val_acc
            self.model.save(f'{MODEL_DIR}/best_model.h5')
            print(f"  ★ New best: {val_acc*100:.2f}%")

        # Checkpoint every N epochs: weights first, then the state describing them.
        if actual_epoch % self.checkpoint_every == 0:
            self.model.save(f'{CHECKPOINT_DIR}/model_checkpoint.h5')
            save_checkpoint({
                'epochs_completed': actual_epoch,
                'best_accuracy': self.best_accuracy,
                'last_val_accuracy': val_acc,
                'history': self.history,
            })

continuous_callback = ContinuousTrainingCallback(
    checkpoint_every=CHECKPOINT_EVERY,
    start_epoch=START_EPOCH,
    history=checkpoint.get('history') if checkpoint else None,
)

In [None]:
# ============================================================
# STEP 8: Train (with auto-resume)
# ============================================================

remaining_epochs = TOTAL_EPOCHS - START_EPOCH

if remaining_epochs <= 0:
    print(f"Training already complete! ({START_EPOCH}/{TOTAL_EPOCHS} epochs)")
    print(f"Best accuracy: {BEST_ACCURACY*100:.2f}%")
else:
    print(f"\n{'='*60}")
    print(f"TRAINING: Epochs {START_EPOCH + 1} to {TOTAL_EPOCHS}")
    print(f"{'='*60}\n")

    callbacks = [
        continuous_callback,
        keras.callbacks.EarlyStopping(
            patience=20,
            restore_best_weights=True,
            monitor='val_accuracy'
        ),
        keras.callbacks.ReduceLROnPlateau(
            factor=0.5,
            patience=10,
            min_lr=1e-6
        )
    ]

    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=remaining_epochs,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        verbose=1
    )

    # Final checkpoint. Save the weights currently in `model` together with
    # the state so the epoch count in training_state.json always matches
    # model_checkpoint.h5. Without this, a session ending between periodic
    # checkpoints (early stopping, or TOTAL_EPOCHS not a multiple of
    # CHECKPOINT_EVERY) would later resume from older weights - or a fresh
    # model - while skipping those epochs.
    final_epoch = START_EPOCH + len(history.history['accuracy'])
    model.save(f'{CHECKPOINT_DIR}/model_checkpoint.h5')
    save_checkpoint({
        'epochs_completed': final_epoch,
        'best_accuracy': continuous_callback.best_accuracy,
        'history': continuous_callback.history,
        'status': 'completed' if final_epoch >= TOTAL_EPOCHS else 'stopped_early'
    })

In [None]:
# ============================================================
# STEP 9: Evaluate and export
# ============================================================
import matplotlib.pyplot as plt

# Prefer the best model saved by the training callback; fall back to the
# in-memory model when no such file was written.
best_model_path = f'{MODEL_DIR}/best_model.h5'
if not os.path.exists(best_model_path):
    best_model = model
else:
    best_model = keras.models.load_model(best_model_path)
val_loss, val_acc = best_model.evaluate(X_val, y_val, verbose=0)

_rule = '=' * 60
print(f"\n{_rule}")
print("RESULTS")
print(_rule)
print(f"Best Validation Accuracy: {val_acc*100:.2f}%")
print(f"Validation Loss: {val_loss:.4f}")

In [None]:
# ============================================================
# STEP 10: Export TFLite models
# ============================================================

# Keras (HDF5) copy of the selected model
best_model.save(f'{MODEL_DIR}/larun_model.h5')

# Float32 TFLite model
converter = tf.lite.TFLiteConverter.from_keras_model(best_model)
tflite = converter.convert()
with open(f'{MODEL_DIR}/larun_model.tflite', 'wb') as f:
    f.write(tflite)

# Same converter with Optimize.DEFAULT. No representative dataset is set, so
# this is dynamic-range quantization: weights are stored as int8 while
# activations stay float. The "int8" file is therefore not a fully
# integer-quantized model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant = converter.convert()
with open(f'{MODEL_DIR}/larun_model_int8.tflite', 'wb') as f:
    f.write(tflite_quant)

_kib = 1024
print(f"\nModels exported to {MODEL_DIR}/")
print(f"  TFLite: {len(tflite)/_kib:.1f} KB")
print(f"  INT8: {len(tflite_quant)/_kib:.1f} KB")

In [None]:
# ============================================================
# STEP 11: Upload to HuggingFace Hub (optional)
# ============================================================

if UPLOAD_TO_HUB:
    from huggingface_hub import HfApi

    # The token is read from the HF_TOKEN environment variable.
    # NOTE: Kaggle Secrets are not exported as environment variables
    # automatically. On Kaggle, read the secret with
    # kaggle_secrets.UserSecretsClient and assign it to os.environ['HF_TOKEN']
    # before running this cell.
    hf_token = os.environ.get('HF_TOKEN')
    if hf_token:
        api = HfApi(token=hf_token)
        # upload_folder fails when the repo does not exist yet; create it on
        # the first run (no-op afterwards).
        api.create_repo(repo_id=HF_REPO, repo_type="model", exist_ok=True)
        api.upload_folder(
            folder_path=MODEL_DIR,
            repo_id=HF_REPO,
            repo_type="model",
            commit_message=f"Update model - accuracy {val_acc*100:.2f}%"
        )
        print(f"✓ Uploaded to HuggingFace: {HF_REPO}")
    else:
        print("Set the HF_TOKEN environment variable to enable upload (on Kaggle, export it from Secrets)")
else:
    print("HuggingFace upload disabled (set UPLOAD_TO_HUB=True to enable)")

In [None]:
# ============================================================
# STEP 12: Create download package
# ============================================================

import zipfile
from pathlib import Path

# Bundle the exported models into BASE_DIR/larun_trained_final.zip (the
# parent of MODEL_DIR). Pure-Python zipping works outside IPython and without
# a `zip` binary, and errors are no longer silently discarded. The archive is
# rebuilt from scratch on every run.
_models_path = Path(MODEL_DIR)
with zipfile.ZipFile(_models_path.parent / 'larun_trained_final.zip', 'w',
                     compression=zipfile.ZIP_DEFLATED) as _archive:
    for _pattern in ('*.h5', '*.tflite'):
        for _model_file in sorted(_models_path.glob(_pattern)):
            _archive.write(_model_file, arcname=_model_file.name)

# List all outputs
print(f"\n{'='*60}")
print("OUTPUT FILES")
print(f"{'='*60}")
for root, dirs, files in os.walk(BASE_DIR):
    for f in files:
        path = os.path.join(root, f)
        size = os.path.getsize(path) / 1024
        print(f"  {path}: {size:.1f} KB")

print(f"\n{'='*60}")
print("TRAINING COMPLETE!")
print(f"{'='*60}")
print(f"Accuracy: {val_acc*100:.2f}%")
print(f"\nTo resume training later, just run this notebook again!")
print(f"Checkpoints are saved and will be automatically loaded.")

---

## Auto-Restart Setup (Kaggle)

To continue training automatically on a schedule (for example, after your weekly GPU quota resets):

1. **Save this notebook** as a Kaggle Version
2. Go to **Scheduling** (in notebook settings)
3. Enable **Scheduled Running**
4. Set frequency: **Weekly** or **Monthly**

The notebook will:
- Load cached data (no re-fetching)
- Resume from last checkpoint
- Continue training where it left off
- Save new checkpoint when done

---
*Larun Engineering - Continuous Learning*