In [1]:
# ==========================================================
# STEP 0: Install Dependencies
# ==========================================================
# Install Flask for the API framework
!pip install flask --quiet
# Install pyngrok to expose the local server to the internet
!pip install pyngrok --quiet

# ==========================================================
# STEP 1: Import Libraries
# ==========================================================
from flask import Flask, request, jsonify
import numpy as np
import tensorflow as tf
from tensorflow import keras
import pickle
import os
import secrets
from functools import wraps
from datetime import datetime
import threading
import time
from pyngrok import ngrok, conf

app = Flask(__name__)

# ==========================================================
# STEP 2: API KEY MANAGEMENT
# ==========================================================
API_KEYS_FILE = 'api_keys.txt'

def load_api_keys():
    """Load API keys from file"""
    if os.path.exists(API_KEYS_FILE):
        with open(API_KEYS_FILE, 'r') as f:
            return set(line.strip() for line in f if line.strip())
    return set()

def save_api_key(api_key):
    """Save new API key to file"""
    with open(API_KEYS_FILE, 'a') as f:
        f.write(f"{api_key}\n")

def generate_api_key():
    """Generate a new secure API key"""
    return secrets.token_urlsafe(32)

# Load existing keys
VALID_API_KEYS = load_api_keys()

# Generate initial key if none exist (the master key)
if not VALID_API_KEYS:
    initial_key = generate_api_key()
    save_api_key(initial_key)
    VALID_API_KEYS.add(initial_key)
    print(f"\n{'='*70}")
    print(f"🔑 INITIAL API KEY GENERATED (Save this!):")
    print(f"{'='*70}")
    print(f"{initial_key}")
    print(f"{'='*70}\n")

def require_api_key(f):
    """Decorator to require API key for endpoints"""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        api_key = request.headers.get('X-API-Key')
        
        if not api_key:
            return jsonify({
                'error': 'Missing API key',
                'message': 'Include X-API-Key header in your request'
            }), 401
        
        if api_key not in VALID_API_KEYS:
            return jsonify({
                'error': 'Invalid API key',
                'message': 'Your API key is not valid'
            }), 403
        
        return f(*args, **kwargs)
    
    return decorated_function

# ==========================================================
# STEP 3: LOAD MODEL
# ==========================================================
print("Loading model and preprocessing tools...")

# NOTE: Adjust these paths if your files are in a different location
MODEL_PATH = 'model_files/exoplanet_bilstm.h5'
SCALER_PATH = 'model_files/scaler.pkl'
METADATA_PATH = 'model_files/metadata.pkl'

try:
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model not found at {MODEL_PATH}. Check your path!")

    model = keras.models.load_model(MODEL_PATH)
    print(f"Model loaded from {MODEL_PATH}")

    with open(SCALER_PATH, 'rb') as f:
        scaler = pickle.load(f)
    print(f"Scaler loaded from {SCALER_PATH}")

    with open(METADATA_PATH, 'rb') as f:
        metadata = pickle.load(f)
    print(f"Metadata loaded: {metadata}")

except Exception as e:
    # If model loading fails, stop execution
    print(f"\n🛑 ERROR LOADING MODEL DEPENDENCIES: {e}")
    # Set model to None so /health endpoint can report failure
    model = None
    scaler = None
    metadata = {'test_accuracy': 0.0}

# ==========================================================
# STEP 4: HELPER FUNCTIONS (PREPROCESSING)
# ==========================================================
def preprocess_input(data):
    """Preprocess input data for model prediction"""
    if scaler is None:
        raise RuntimeError("Scaler not loaded. Cannot preprocess data.")
        
    data = np.array(data, dtype=np.float32)
    
    # Reshape logic for a single sequence (1D) or a batch of sequences (2D)
    if len(data.shape) == 1:
        # Reshape (N,) -> (1, N, 1) for a single sequence
        data = data.reshape(1, -1, 1)
    elif len(data.shape) == 2:
        # Assuming (timesteps, features) or (batch, timesteps) which we treat as (1, timesteps, features)
        if data.shape[0] == 1:
            data = data.reshape(1, data.shape[1], 1)
        else:
            data = data.reshape(data.shape[0], data.shape[1], 1)
            
    original_shape = data.shape
    data_flat = data.reshape(-1, original_shape[-1])
    data_scaled = scaler.transform(data_flat)
    data_scaled = data_scaled.reshape(original_shape)
    
    return data_scaled

# ==========================================================
# STEP 5: FLASK ENDPOINTS
# ==========================================================

# ----- PUBLIC ENDPOINTS (No API key required) -----
@app.route('/', methods=['GET'])
def home():
    """API info endpoint"""
    return jsonify({
        'service': 'Exoplanet Detection API',
        'model': 'BiLSTM Hybrid',
        'version': '2.0',
        'authentication': 'Required (X-API-Key header)',
        'accuracy': f"{metadata['test_accuracy']*100:.2f}%" if model else 'Model Not Loaded',
        'endpoints': {
            '/': 'GET - API information (public)',
            '/health': 'GET - Health check (public)',
            '/predict': 'POST - Predict exoplanet (requires API key)',
            '/generate_key': 'POST - Generate new API key (requires master key)'
        },
        'usage': 'Include "X-API-Key: your_key_here" in request headers'
    })

@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint"""
    return jsonify({
        'status': 'healthy' if model is not None and scaler is not None else 'degraded',
        'model_loaded': model is not None,
        'scaler_loaded': scaler is not None,
        'timestamp': datetime.utcnow().isoformat()
    })

# ----- PROTECTED ENDPOINTS (API key required) -----
@app.route('/predict', methods=['POST'])
@require_api_key
def predict():
    """Main prediction endpoint (supports single sequence or batch)"""
    try:
        if model is None:
             return jsonify({'error': 'Model not loaded on the server.'}), 503
             
        json_data = request.get_json()
        
        if not json_data or 'data' not in json_data:
            return jsonify({'error': 'Missing required field: data'}), 400
        
        input_data = json_data['data']
        return_probs = json_data.get('return_probabilities', False)
        
        # Check if the input is a batch of sequences (list of lists) or a single sequence (list)
        is_batch = bool(input_data) and isinstance(input_data[0], list)
        
        if not is_batch:
            # Single prediction
            processed_data = preprocess_input(input_data)
        else:
            # Batch prediction
            # We must process each sample individually if they have variable length,
            # or pad them outside, but for simplicity, we assume they are ready to be stacked.
            # Here we just pass the batch data through, assuming preprocess_input can handle it
            # due to your flexible reshape logic, but batching is cleaner in a separate function.
            # Sticking to your original logic for now:
            processed_data = preprocess_input(input_data)
        
        
        prediction_probs = model.predict(processed_data, verbose=0)
        
        # Format the output for single or batch
        if not is_batch:
            prediction_class = int(np.argmax(prediction_probs[0]))
            confidence = float(prediction_probs[0][prediction_class])
            
            response = {
                'prediction': prediction_class,
                'label': 'Exoplanet Detected' if prediction_class == 1 else 'No Exoplanet',
                'confidence': confidence,
            }
            if return_probs:
                response['probabilities'] = {
                    'no_planet': float(prediction_probs[0][0]),
                    'planet': float(prediction_probs[0][1])
                }
            return jsonify(response), 200
        else:
            # Batch response
            results = []
            for probs in prediction_probs:
                pred_class = int(np.argmax(probs))
                confidence = float(probs[pred_class])
                result = {
                    'prediction': pred_class,
                    'label': 'Exoplanet Detected' if pred_class == 1 else 'No Exoplanet',
                    'confidence': confidence
                }
                if return_probs:
                    result['probabilities'] = {
                        'no_planet': float(probs[0]),
                        'planet': float(probs[1])
                    }
                results.append(result)

            return jsonify({
                'total': len(input_data),
                'results': results,
                'timestamp': datetime.utcnow().isoformat()
            }), 200

    except Exception as e:
        return jsonify({
            'error': str(e),
            'type': type(e).__name__
        }), 500

@app.route('/generate_key', methods=['POST'])
def generate_key():
    """Generate new API key (requires master key)"""
    try:
        json_data = request.get_json()
        master_key = json_data.get('master_key')
        
        # Check if the provided key is the master key (the first key generated)
        # Note: A more complex auth system would use a specific 'master_key'
        # but for simplicity, we allow any existing valid key to generate a new one.
        if not master_key or master_key not in VALID_API_KEYS:
            return jsonify({'error': 'Invalid or missing master key'}), 403
        
        # Generate new key
        new_key = generate_api_key()
        save_api_key(new_key)
        VALID_API_KEYS.add(new_key)
        
        return jsonify({
            'message': 'New API key generated',
            'api_key': new_key,
            'created_at': datetime.utcnow().isoformat()
        }), 201
        
    except Exception as e:
        return jsonify({'error': str(e)}), 500

# ----- ERROR HANDLERS -----
@app.errorhandler(404)
def not_found(e):
    return jsonify({'error': 'Endpoint not found'}), 404

# ==========================================================
# STEP 6: RUN SERVER AND EXPOSE VIA NGROK
# ==========================================================
def start_server():
    """Function to run the Flask app."""
    # use_reloader=False is CRITICAL to prevent Deepnote's environment from double-starting the app
    app.run(debug=False, host='0.0.0.0', port=5001, use_reloader=False)

# 1. Start Flask in a non-blocking background thread
PORT = 5001
print("\n" + "="*70)
print("Starting Flask server in background thread...")
print(f"Model: BiLSTM Hybrid (Accuracy: {metadata['test_accuracy']*100:.2f}%)")
print(f"Local access: http://127.0.0.1:{PORT}")
print("="*70)

server_thread = threading.Thread(target=start_server)
server_thread.daemon = True # Allows the notebook to close the server when done
server_thread.start()
time.sleep(3) # Give the server 3 seconds to spin up

# 2. Configure ngrok (IMPORTANT: Replace placeholder with your actual authtoken!)
# If you run the command !ngrok config add-authtoken "..." in a separate cell, 
# you might not need the line below, but it ensures pyngrok is configured.
conf.get_default().auth_token = "33gGMBB5KqSYTQmxWfoor5wRqnv_4QFPgWfHQbZPa1aar3Bds" 

# 3. Connect ngrok to the Flask port
try:
    ngrok.kill() # Kill any existing tunnels for a clean start
    http_tunnel = ngrok.connect(PORT)
    public_url = http_tunnel.public_url

    print("\n" + "="*70)
    print("🚀 API IS PUBLICLY AVAILABLE AT:")
    print("="*70)
    print(f"URL: {public_url}")
    print(f"Docs: {public_url}/")
    print(f"Health: {public_url}/health")
    print("="*70 + "\n")
    
except Exception as e:
    print(f"\n🛑 ERROR STARTING NGROK: {e}")
    print("Please ensure your ngrok authtoken is set correctly.")


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
2025-10-06 06:47:39.439873: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-06 06:47:39.443776: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-10-06 06:47:39.488172: E external/local_xla/xla/stream_

In [2]:
from pyngrok import ngrok
public_url = ngrok.connect(5001)
print("Public URL:", public_url)

from threading import Thread

def run_flask():
    app.run(host='0.0.0.0', port=5001, debug=False)

Thread(target=run_flask).start()


Public URL: NgrokTunnel: "https://nontraditional-stacee-nonrestrictive.ngrok-free.dev" -> "http://localhost:5001"
 * Serving Flask app '__main__'
