# 🌊 LSTM Training API Notebook - Heatwave/Flood Prediction

This notebook exposes LSTM model training as a **REST API** that can be triggered by Airflow or other orchestration tools.

**Features:**
- Flask API with `/train` endpoint for triggering training
- Public URL via ngrok for remote access from Airflow
- Loads training data from HDFS via Cloudflare tunnels
- Saves trained PyTorch models back to HDFS

**API Endpoints:**
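- `GET /health` - Health check
- `POST /train` - Trigger training (JSON params: `target`, `timesteps`, `epochs`, `async`)
- `POST /train/all` - Train all targets sequentially (JSON params: `timesteps`, `epochs`)
- `GET /status` - Get training status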

**Workflow:**
1. Run cells 1-6 to set up dependencies, config, and functions
2. Run the API cell to start the server
3. Copy the ngrok URL and configure in Airflow
4. Airflow triggers training via API call

## 1. Install Required Dependencies

In [None]:
# IPython shell magic: quietly (-q) install the packages imported by the cells below.
!pip install -q torch pandas numpy scikit-learn pyarrow requests flask pyngrok

## 2. Configuration

In [None]:
# ============================================
# CONFIGURATION - CLOUDFLARE TUNNEL URLs
# ============================================

# HDFS via Cloudflare Tunnels
# One tunnel per WebHDFS role: the NameNode answers metadata requests and
# redirects data requests, which are then replayed against the DataNode tunnel.
# NOTE(review): these look like ephemeral trycloudflare.com hostnames —
# presumably they must be updated whenever the tunnels are restarted; confirm.
HDFS_NAMENODE_URL = "https://lab-jurisdiction-arrives-sys.trycloudflare.com"
HDFS_DATANODE_URL = "https://scholar-march-certification-tests.trycloudflare.com"
HDFS_USER = "root"  # sent as the WebHDFS `user.name` query parameter

# HDFS Paths
HDFS_FEATURES_PATH = "/data/processed/features.parquet"  # training input (parquet)
HDFS_MODELS_DIR = "/models"  # checkpoints are written as <dir>/lstm_<target>.pt

# LSTM Training Configuration (defaults, can be overridden via API)
DEFAULT_TARGET = "heatwave"
DEFAULT_TIMESTEPS = 14  # sliding-window length, in rows per district
DEFAULT_EPOCHS = 20
HIDDEN_SIZE = 64   # LSTM hidden state size
FC_SIZE = 32       # width of the fully connected layer after the LSTM
BATCH_SIZE = 128
TEST_SIZE = 0.2    # fraction of sequences held out by train_test_split

# API Configuration
API_PORT = 5001  # Different port from XGBoost API
NGROK_AUTH_TOKEN = ""  # empty string means ngrok.set_auth_token() is skipped

## 3. Import Libraries

In [None]:
import io
import os
import pickle
import threading
from datetime import datetime
from urllib.parse import quote, urlencode, urlparse

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score

from flask import Flask, request, jsonify
from pyngrok import ngrok

# Pick the compute device once at import time: CUDA when torch reports it as
# available, otherwise CPU. The training code moves the model and batches here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è Using device: {device}")
print("‚úÖ Libraries imported successfully!")

## 4. HDFS Client Setup

In [None]:
class HDFSManager:
    """Manage HDFS operations over WebHDFS exposed through Cloudflare tunnels.

    Data operations (OPEN / CREATE) are first sent to the NameNode tunnel with
    redirects disabled. When the NameNode answers with a 307, the path and
    query of its ``Location`` header are replayed against the DataNode tunnel,
    i.e. the redirect's own host is replaced by ``datanode_url``.
    """

    def __init__(self, namenode_url, datanode_url, user):
        """Store the tunnel base URLs (trailing '/' stripped) and open an HTTP session.

        Args:
            namenode_url: Base URL of the NameNode tunnel.
            datanode_url: Base URL of the DataNode tunnel.
            user: Value sent as the WebHDFS ``user.name`` query parameter.
        """
        self.namenode_url = namenode_url.rstrip('/')
        self.datanode_url = datanode_url.rstrip('/')
        self.user = user
        self.session = requests.Session()
        print(f"‚úÖ HDFS Manager initialized")
        print(f"   NameNode: {self.namenode_url}")
        print(f"   DataNode: {self.datanode_url}")

    def _namenode_api(self, path, op, **params):
        """Build a NameNode WebHDFS URL for ``op`` on ``path``.

        The path and every query value are percent-encoded so that spaces,
        '&', '#', etc. cannot corrupt the URL. Extra ``params`` are appended
        after ``op`` and ``user.name`` in the order given.
        """
        query = urlencode(
            {'op': op, 'user.name': self.user, **params},
            quote_via=quote,  # encode spaces as %20 rather than '+'
        )
        return f"{self.namenode_url}/webhdfs/v1{quote(path)}?{query}"

    def _redirect_to_datanode(self, redirect_url):
        """Rewrite a NameNode redirect URL so it targets the DataNode tunnel."""
        parsed = urlparse(redirect_url)
        rewritten = f"{self.datanode_url}{parsed.path}"
        if parsed.query:
            rewritten += f"?{parsed.query}"
        return rewritten

    def _datanode_url_from(self, response):
        """Return the DataNode-tunnel URL for a 307 response.

        Raises:
            RuntimeError: If the response carries no ``Location`` header.
        """
        location = response.headers.get('Location')
        if not location:
            raise RuntimeError("NameNode redirect is missing the Location header")
        return self._redirect_to_datanode(location)

    def _read_bytes(self, hdfs_path):
        """Download ``hdfs_path`` and return its raw bytes.

        Shared by read_parquet() and load_model_torch() so both apply the same
        redirect handling and fail loudly instead of parsing an error body.

        Raises:
            RuntimeError: If the NameNode answers with neither 307 nor 200.
            requests.HTTPError: If the DataNode download fails.
        """
        # Step 1: Request OPEN from NameNode (returns 307 redirect to DataNode)
        open_url = self._namenode_api(hdfs_path, "OPEN")
        response = self.session.get(open_url, allow_redirects=False, timeout=60)

        if response.status_code == 307:
            # Step 2: Convert the redirect into a DataNode tunnel URL
            datanode_url = self._datanode_url_from(response)
            print(f"   Fetching from DataNode tunnel...")

            # Step 3: Download from DataNode tunnel
            response = self.session.get(datanode_url, timeout=300)
            response.raise_for_status()
        elif response.status_code != 200:
            raise RuntimeError(f"Failed to open file: {response.status_code} - {response.text[:500]}")

        return response.content

    def read_parquet(self, hdfs_path):
        """Read a parquet file from HDFS into a pandas DataFrame."""
        print(f"üìñ Reading parquet from: {hdfs_path}")
        df = pd.read_parquet(io.BytesIO(self._read_bytes(hdfs_path)))
        print(f"‚úÖ Loaded DataFrame with shape: {df.shape}")
        return df

    def save_model_torch(self, model_dict, hdfs_path):
        """Serialize ``model_dict`` with torch.save and upload it to ``hdfs_path``.

        Existing files are overwritten (``overwrite=true``).

        Returns:
            The ``hdfs_path`` that was written.

        Raises:
            RuntimeError: If the NameNode does not answer CREATE with a 307,
                or the redirect has no ``Location`` header.
            requests.HTTPError: If the DataNode upload fails.
        """
        print(f"üíæ Saving model to: {hdfs_path}")

        # Serialize model using torch.save to bytes
        buffer = io.BytesIO()
        torch.save(model_dict, buffer)
        model_bytes = buffer.getvalue()

        # Step 1: CREATE request to NameNode (returns 307 redirect)
        create_url = self._namenode_api(hdfs_path, "CREATE", overwrite="true")
        response = self.session.put(create_url, allow_redirects=False, timeout=60)

        if response.status_code != 307:
            raise RuntimeError(f"Unexpected response: {response.status_code} - {response.text[:500]}")

        # Step 2: Upload to DataNode tunnel
        response = self.session.put(
            self._datanode_url_from(response),
            data=model_bytes,
            headers={'Content-Type': 'application/octet-stream'},
            timeout=300,
        )
        response.raise_for_status()

        print(f"‚úÖ Model saved! Size: {len(model_bytes) / 1024:.2f} KB")
        return hdfs_path

    def load_model_torch(self, hdfs_path):
        """Download a torch checkpoint from HDFS and deserialize it onto the CPU.

        Raises:
            RuntimeError: If the NameNode answers with neither 307 nor 200.
            requests.HTTPError: If the DataNode download fails.
        """
        print(f"üìñ Loading model from: {hdfs_path}")

        payload = self._read_bytes(hdfs_path)
        # SECURITY: torch.load unpickles the payload; only load checkpoints
        # from an HDFS location you trust.
        model_dict = torch.load(io.BytesIO(payload), map_location='cpu')
        print(f"‚úÖ Model loaded!")
        return model_dict

    def list_files(self, hdfs_path):
        """Return the entry names under ``hdfs_path``; best-effort.

        Any failure is printed and reported as an empty list rather than raised.
        """
        try:
            url = self._namenode_api(hdfs_path, "LISTSTATUS")
            response = self.session.get(url, timeout=30)
            response.raise_for_status()
            data = response.json()
            files = [f['pathSuffix'] for f in data.get('FileStatuses', {}).get('FileStatus', [])]
            return files
        except Exception as e:
            print(f"Error listing {hdfs_path}: {e}")
            return []

    def mkdir(self, hdfs_path):
        """Create ``hdfs_path`` via MKDIRS.

        Returns:
            The ``boolean`` field of the JSON reply, or False when it is absent.
        """
        url = self._namenode_api(hdfs_path, "MKDIRS")
        response = self.session.put(url, timeout=30)
        return response.json().get('boolean', False)

# Initialize HDFS Manager with both tunnel URLs
# Module-level instance built from the configuration constants; constructing it
# only opens a requests.Session and prints the URLs (no request is sent yet).
hdfs_manager = HDFSManager(HDFS_NAMENODE_URL, HDFS_DATANODE_URL, HDFS_USER)

## 5. LSTM Model & Training Functions

In [None]:
class LSTMModel(nn.Module):
    """Binary classifier: one LSTM layer feeding a small two-layer head.

    The LSTM is built with ``batch_first=True``; the head maps the final
    hidden state through ``fc1 -> ReLU -> fc2`` to a single raw logit per
    sequence (no sigmoid is applied here).
    """

    def __init__(self, input_size, hidden_size=64, fc_size=32):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, fc_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(fc_size, 1)

    def forward(self, x):
        """Return one logit per input sequence (trailing size-1 dim squeezed)."""
        # Only the final hidden state is used; per-step outputs and the cell
        # state are discarded.
        _, (hidden, _) = self.lstm(x)
        last_layer_hidden = hidden[-1]
        logits = self.fc2(self.relu(self.fc1(last_layer_hidden)))
        return logits.squeeze(-1)


def build_sequences(df, timesteps=14, target='heatwave'):
    """Build per-district sliding-window sequences (matching train_lstm.py).

    Rows are sorted by ``District`` then ``Date``. For every district, each
    window of ``timesteps`` consecutive rows becomes one sample whose label is
    the ``target`` value of the row immediately after the window. Districts
    with ``timesteps`` rows or fewer contribute nothing.

    Args:
        df: DataFrame with ``Date``, ``District`` and ``target`` columns; all
            columns other than Date, District, heatwave and flood_proxy are
            used as features.
        timesteps: Window length; must be >= 1.
        target: Name of the label column. Missing labels count as 0.

    Returns:
        ``(X, y, feature_cols)`` where ``X`` is float32 with shape
        ``(n_samples, timesteps, n_features)`` (NaN/inf replaced by 0), ``y``
        is float32 with shape ``(n_samples,)``, and ``feature_cols`` lists the
        feature column names in the order used along the last axis of ``X``.
        When no window can be built, ``X`` still has three dimensions, with
        ``n_samples == 0``.

    Raises:
        ValueError: If ``timesteps`` is smaller than 1.
    """
    if timesteps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")

    ordered = df.sort_values(['District', 'Date'])
    exclude_cols = {'Date', 'District', 'heatwave', 'flood_proxy'}
    feature_cols = [c for c in ordered.columns if c not in exclude_cols]

    sequences, targets = [], []
    for _, group in ordered.groupby('District'):
        arr = group[feature_cols].values
        lab = group[target].fillna(False).astype(int).values
        # range() is empty when the district has <= timesteps rows.
        for i in range(timesteps, len(arr)):
            sequences.append(arr[i - timesteps:i])
            targets.append(lab[i])

    if sequences:
        X = np.array(sequences, dtype=np.float32)
    else:
        # Keep the rank stable so callers can always unpack a 3-D shape.
        X = np.empty((0, timesteps, len(feature_cols)), dtype=np.float32)
    y = np.array(targets, dtype=np.float32)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    return X, y, feature_cols


def _predict_probabilities(model, loader):
    """Run ``model`` over ``loader`` in eval mode; return (targets, sigmoid probabilities)."""
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for X_batch, y_batch in loader:
            probs = torch.sigmoid(model(X_batch.to(device)))
            all_preds.extend(probs.cpu().numpy())
            all_targets.extend(y_batch.numpy())
    return all_targets, all_preds


def _auc_or_zero(targets, preds):
    """Return ROC AUC, or 0.0 when ``targets`` holds fewer than two classes."""
    if len(set(targets)) > 1:
        return roc_auc_score(targets, preds)
    return 0.0


def train_lstm_model(hdfs_manager, target='heatwave', timesteps=14, epochs=20):
    """Full LSTM training pipeline (matching train_lstm.py logic).

    Loads the feature table from HDFS, builds sliding-window sequences, trains
    an LSTMModel with Adam + BCEWithLogitsLoss, keeps the weights of the epoch
    with the highest held-out ROC AUC, and uploads a checkpoint to
    ``{HDFS_MODELS_DIR}/lstm_{target}.pt``.

    NOTE(review): the split shuffles overlapping windows from the same
    district, and the same held-out split is used for model selection and
    for the reported AUC, so the reported value is presumably optimistic.
    Kept as-is to match train_lstm.py — confirm this is intended.

    Args:
        hdfs_manager: Object providing read_parquet(), mkdir() and
            save_model_torch().
        target: Label column to predict.
        timesteps: Sliding-window length.
        epochs: Number of training epochs.

    Returns:
        Dict with ``model`` (the trained LSTMModel), ``metrics``
        (``val_auc`` and ``best_auc`` floats) and ``hdfs_path``.

    Raises:
        ValueError: If no training sequence can be built from the data.
    """
    print(f"\n{'='*60}")
    print(f"üéØ TRAINING LSTM FOR: {target.upper()}")
    print(f"{'='*60}")

    # Load data
    print("üì• Loading data from HDFS...")
    df = hdfs_manager.read_parquet(HDFS_FEATURES_PATH)

    # Build sequences
    X, y, feature_cols = build_sequences(df, timesteps=timesteps, target=target)
    if len(X) == 0:
        raise ValueError(
            f"No training sequences could be built for target '{target}' "
            f"with timesteps={timesteps}"
        )
    print(f"üìä Sequences: {X.shape}, Positive: {y.sum():.0f} ({y.mean()*100:.2f}%)")

    # Split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, shuffle=True, random_state=42
    )

    n_features = X_train.shape[-1]

    # Scale: flatten (samples, timesteps) so each feature gets one mean/scale,
    # fitted on the training split only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train.reshape(-1, n_features)).reshape(X_train.shape)
    X_test_scaled = scaler.transform(X_test.reshape(-1, n_features)).reshape(X_test.shape)

    # DataLoaders
    train_loader = DataLoader(
        TensorDataset(torch.FloatTensor(X_train_scaled), torch.FloatTensor(y_train)),
        batch_size=BATCH_SIZE, shuffle=True
    )
    test_loader = DataLoader(
        TensorDataset(torch.FloatTensor(X_test_scaled), torch.FloatTensor(y_test)),
        batch_size=BATCH_SIZE, shuffle=False
    )

    # Model
    model = LSTMModel(n_features, HIDDEN_SIZE, FC_SIZE).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters())

    # Training
    print("üöÄ Training model...")
    best_auc = 0.0
    best_state = None

    for epoch in range(epochs):
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

        # Evaluate
        auc = _auc_or_zero(*_predict_probabilities(model, test_loader))
        if auc > best_auc:
            best_auc = auc
            # state_dict() returns references to the live parameters, so the
            # tensors must be cloned or later epochs would overwrite the
            # snapshot in place.
            best_state = {name: tensor.clone() for name, tensor in model.state_dict().items()}

        if (epoch + 1) % 5 == 0:
            print(f"   Epoch {epoch+1}/{epochs} - AUC: {auc:.4f}")

    # Load best (no epoch improved on 0.0 -> keep the final weights)
    if best_state is not None:
        model.load_state_dict(best_state)

    # Final metrics
    val_auc = _auc_or_zero(*_predict_probabilities(model, test_loader))
    print(f"\n‚úÖ Best Val AUC: {best_auc:.4f}")

    metrics = {'val_auc': float(val_auc), 'best_auc': float(best_auc)}

    # Save checkpoint: weights plus everything needed to rebuild the model and
    # reproduce the input scaling at inference time.
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'input_size': n_features,
        'hidden_size': HIDDEN_SIZE,
        'fc_size': FC_SIZE,
        'timesteps': timesteps,
        'feature_cols': feature_cols,
        'scaler_mean': scaler.mean_.tolist(),
        'scaler_scale': scaler.scale_.tolist(),
        'target': target,
        'metrics': metrics,
        'saved_at': datetime.now().isoformat(),
        'pytorch_version': torch.__version__
    }

    hdfs_manager.mkdir(HDFS_MODELS_DIR)
    hdfs_path = f"{HDFS_MODELS_DIR}/lstm_{target}.pt"
    hdfs_manager.save_model_torch(checkpoint, hdfs_path)
    print(f"üíæ Saved to: {hdfs_path}")

    return {'model': model, 'metrics': metrics, 'hdfs_path': hdfs_path}

print("‚úÖ LSTM model and training functions defined!")

## 6. Flask API Setup

In [None]:
# Global state for tracking training
# A single mutable dict shared by the request handlers (readers) and
# run_training() (writer). Timestamps are stored as ISO-8601 strings so the
# dict's values can be returned through jsonify unchanged.
training_state = {
    'is_training': False,     # True while run_training() is executing
    'current_target': None,   # target of the run in progress, else None
    'last_result': None,      # {'target', 'metrics', 'hdfs_path'} of the last successful run
    'last_error': None,       # message of the most recent failure, cleared when a run starts
    'started_at': None,       # ISO timestamp taken when the latest run began
    'completed_at': None      # ISO timestamp taken when a run finished successfully
}

# Initialize Flask app
app = Flask(__name__)

# Initialize HDFS Manager
# NOTE: this rebinds the module-level `hdfs_manager` already created in the
# HDFS client cell, using the same configuration constants.
hdfs_manager = HDFSManager(HDFS_NAMENODE_URL, HDFS_DATANODE_URL, HDFS_USER)


@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report service name, compute device and server time."""
    payload = {
        'status': 'healthy',
        'service': 'lstm-training-api',
        'device': str(device),
        'timestamp': datetime.now().isoformat(),
    }
    return jsonify(payload)


@app.route('/status', methods=['GET'])
def status():
    """Return a JSON snapshot of the shared training_state fields."""
    exposed_fields = (
        'is_training',
        'current_target',
        'started_at',
        'completed_at',
        'last_result',
        'last_error',
    )
    snapshot = {field: training_state[field] for field in exposed_fields}
    return jsonify(snapshot)


def run_training(target, timesteps, epochs):
    """Run one training job and record its outcome in ``training_state``.

    Used both synchronously by the request handlers and as a thread target.
    Exceptions are never propagated: a failure is stored in
    ``training_state['last_error']`` (always a non-empty string) and printed.
    On success ``last_result`` and ``completed_at`` are updated; on failure
    ``last_result`` keeps the previous successful result.

    Args:
        target: Label column to train for.
        timesteps: Sliding-window length passed to train_lstm_model().
        epochs: Number of epochs passed to train_lstm_model().
    """
    # The dict is mutated in place, never rebound, so no `global` is needed.
    training_state.update({
        'is_training': True,
        'current_target': target,
        'started_at': datetime.now().isoformat(),
        # Clear the previous run's completion time so /status never pairs a
        # new started_at with a stale completed_at.
        'completed_at': None,
        'last_error': None,
    })

    try:
        result = train_lstm_model(hdfs_manager, target=target, timesteps=timesteps, epochs=epochs)
    except Exception as e:
        # Prefix the exception type: str(e) alone can be empty, which callers
        # testing `if training_state['last_error']` would read as success.
        training_state['last_error'] = f"{type(e).__name__}: {e}"
        print(f"‚ùå Training error: {e}")
    else:
        training_state['last_result'] = {
            'target': target,
            'metrics': result['metrics'],
            'hdfs_path': result['hdfs_path']
        }
        training_state['completed_at'] = datetime.now().isoformat()
    finally:
        training_state['is_training'] = False
        training_state['current_target'] = None


@app.route('/train', methods=['POST'])
def train():
    """Trigger model training for one target.

    Optional JSON body fields: ``target`` ('heatwave' or 'flood_proxy'),
    ``timesteps`` and ``epochs`` (positive integers), ``async`` (truthy to
    train in a background thread and return immediately). An empty body uses
    the configured defaults.

    Responses:
        200 ``completed`` (sync) or ``started`` (async),
        400 for an invalid body or parameter,
        409 when a training run is already in progress,
        500 when synchronous training failed.
    """
    # NOTE: this check and the later state update are not atomic; two requests
    # arriving at the same instant could both pass it.
    if training_state['is_training']:
        return jsonify({
            'status': 'busy',
            'message': f"Training already in progress for {training_state['current_target']}"
        }), 409

    # silent=True: a missing or non-JSON body yields None instead of an
    # automatic 400/415, so a bare POST can fall back to the defaults.
    data = request.get_json(silent=True)
    if data is None:
        if request.get_data():
            return jsonify({
                'status': 'error',
                'message': 'Request body must be valid JSON sent as application/json'
            }), 400
        data = {}
    if not isinstance(data, dict):
        return jsonify({'status': 'error', 'message': 'Request body must be a JSON object'}), 400

    target = data.get('target', DEFAULT_TARGET)
    timesteps = data.get('timesteps', DEFAULT_TIMESTEPS)
    epochs = data.get('epochs', DEFAULT_EPOCHS)
    async_mode = data.get('async', False)

    if target not in ['heatwave', 'flood_proxy']:
        return jsonify({'status': 'error', 'message': f"Invalid target: {target}"}), 400

    # Reject bad numeric input up front instead of failing deep inside training
    # (bool is excluded explicitly because it is a subclass of int).
    for name, value in (('timesteps', timesteps), ('epochs', epochs)):
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            return jsonify({
                'status': 'error',
                'message': f"Invalid {name}: {value!r} (expected a positive integer)"
            }), 400

    if async_mode:
        # Mark the service busy before the thread starts so a /status poll
        # issued right after this response cannot observe an idle service.
        training_state['is_training'] = True
        training_state['current_target'] = target
        thread = threading.Thread(target=run_training, args=(target, timesteps, epochs))
        thread.start()
        return jsonify({
            'status': 'started',
            'message': f'Training started for {target}',
            'target': target
        })

    run_training(target, timesteps, epochs)

    if training_state['last_error']:
        return jsonify({'status': 'error', 'message': training_state['last_error']}), 500

    return jsonify({'status': 'completed', 'result': training_state['last_result']})


@app.route('/train/all', methods=['POST'])
def train_all():
    """Train models for all targets, one after another, synchronously.

    Optional JSON body fields: ``timesteps`` and ``epochs`` (positive
    integers). An empty body uses the configured defaults.

    Responses:
        200 with a per-target ``results`` mapping (each entry is either
        ``completed`` with its result or ``error`` with the message),
        400 for an invalid body or parameter,
        409 when a training run is already in progress.
    """
    if training_state['is_training']:
        return jsonify({'status': 'busy'}), 409

    # silent=True: a missing or non-JSON body yields None instead of an
    # automatic 400/415, so a bare POST can fall back to the defaults.
    data = request.get_json(silent=True)
    if data is None:
        if request.get_data():
            return jsonify({
                'status': 'error',
                'message': 'Request body must be valid JSON sent as application/json'
            }), 400
        data = {}
    if not isinstance(data, dict):
        return jsonify({'status': 'error', 'message': 'Request body must be a JSON object'}), 400

    timesteps = data.get('timesteps', DEFAULT_TIMESTEPS)
    epochs = data.get('epochs', DEFAULT_EPOCHS)

    # Reject bad numeric input up front instead of failing once per target
    # (bool is excluded explicitly because it is a subclass of int).
    for name, value in (('timesteps', timesteps), ('epochs', epochs)):
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            return jsonify({
                'status': 'error',
                'message': f"Invalid {name}: {value!r} (expected a positive integer)"
            }), 400

    results = {}
    for target in ['heatwave', 'flood_proxy']:
        run_training(target, timesteps, epochs)
        # run_training() clears last_error when it starts, so a value here
        # belongs to the run that just finished.
        if training_state['last_error']:
            results[target] = {'status': 'error', 'error': training_state['last_error']}
        else:
            results[target] = {'status': 'completed', 'result': training_state['last_result']}

    return jsonify({'status': 'completed', 'results': results})


print("‚úÖ Flask API configured!")
print("   Endpoints: /health, /status, /train, /train/all")

## 7. Start API Server with ngrok

In [None]:
# Set ngrok auth token if provided
if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Start ngrok tunnel
# Recent pyngrok releases return an NgrokTunnel object whose string form is a
# description, not the URL, so read its `public_url` attribute. The getattr
# fallback keeps older releases working, where connect() returned the URL string.
tunnel = ngrok.connect(API_PORT)
public_url = getattr(tunnel, 'public_url', tunnel)
print("=" * 60)
print("üåê LSTM TRAINING API SERVER")
print("=" * 60)
print(f"\nüì° Public URL: {public_url}")
print(f"   Local URL:  http://localhost:{API_PORT}")
print("\nüìã API Endpoints:")
print(f"   GET  {public_url}/health")
print(f"   GET  {public_url}/status")
print(f"   POST {public_url}/train")
print(f"   POST {public_url}/train/all")
print("\n‚ö†Ô∏è  Copy the public URL and configure in Airflow DAG!")
print("=" * 60)

# Plain-string base URL for use in later cells or client snippets.
LSTM_API_URL = str(public_url)

In [None]:
# Run Flask server (this cell blocks - run this last!)
print("üöÄ Starting Flask server...")
print("   Press Ctrl+C or restart runtime to stop")
print("-" * 60)

# threaded=True serves each request in its own thread, so /health and /status
# stay responsive while a synchronous /train request is still running.
app.run(port=API_PORT, threaded=True)

## 8. API Usage Examples

### From Airflow (Python):
```python
import requests

LSTM_API_URL = "https://xxxx.ngrok.io"  # Replace with actual ngrok URL

# Train heatwave model
response = requests.post(f"{LSTM_API_URL}/train", json={
    "target": "heatwave",
    "timesteps": 14,
    "epochs": 20
})
print(response.json())
```

### From curl:
```bash
curl -X POST https://xxxx.ngrok.io/train \
  -H "Content-Type: application/json" \
  -d '{"target": "heatwave", "timesteps": 14, "epochs": 20}'
```