# 🌊 XGBoost Training API Notebook - Heatwave/Flood Prediction

This notebook exposes XGBoost model training as a **REST API** that can be triggered by Airflow or other orchestration tools.

**Features:**
- Flask API with `/train` endpoint for triggering training
- Public URL via ngrok for remote access from Airflow
- Loads training data from HDFS via Cloudflare tunnels
- Saves trained models back to HDFS

**API Endpoints:**
- `GET /health` - Health check
- `POST /train` - Train one model (JSON body: `target`, `test_days`, `async`)
- `POST /train/all` - Train the `heatwave` and `flood_proxy` models in sequence
- `GET /status` - Get training status
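
The `/train` body rules (`target`, `test_days`, and the optional `async` flag that the route's docstring also accepts) can be summarised as a pure function. This is a minimal sketch, assuming the defaults from section 2 (`heatwave`, 365 days); `parse_train_request` is a hypothetical name that does not exist in the notebook, and the positive-integer check on `test_days` is an added assumption.

```python
VALID_TARGETS = ("heatwave", "flood_proxy")


def parse_train_request(body, default_target="heatwave", default_test_days=365):
    """Validate a /train JSON body; return (params, error_message)."""
    body = body or {}  # a missing body means "use the defaults"
    target = body.get("target", default_target)
    if target not in VALID_TARGETS:
        return None, f"Invalid target: {target}. Must be 'heatwave' or 'flood_proxy'"
    try:
        # JSON clients sometimes send numbers as strings, e.g. "365"
        test_days = int(body.get("test_days", default_test_days))
    except (TypeError, ValueError):
        return None, "test_days must be an integer"
    if test_days <= 0:
        return None, "test_days must be positive"
    return {"target": target, "test_days": test_days,
            "async": bool(body.get("async", False))}, None


print(parse_train_request(None))
print(parse_train_request({"target": "flood_proxy", "test_days": "30", "async": True}))
print(parse_train_request({"target": "drought"}))
```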

**Workflow:**
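1. Run sections 1-6 to install dependencies, set the configuration, and define the HDFS client, training functions, and Flask routes
2. Run the first cell of section 7 to open the ngrok tunnel, then the second cell to start the Flask server (it blocks the notebook)
3. Copy the printed ngrok URL into the Airflow DAG configuration
4. Airflow triggers training with an HTTP request to the API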

## 1. Install Required Dependencies

In [None]:
!pip install -q xgboost pandas numpy scikit-learn pyarrow requests joblib matplotlib seaborn flask pyngrok

## 2. Configuration

In [None]:
# ============================================
# CONFIGURATION - CLOUDFLARE TUNNEL URLs
# ============================================

# HDFS via Cloudflare Tunnels
# Get URLs from: docker logs cloudflared-hdfs-namenode / cloudflared-hdfs-datanode
HDFS_NAMENODE_URL = "https://lab-jurisdiction-arrives-sys.trycloudflare.com"      # NameNode tunnel
HDFS_DATANODE_URL = "https://scholar-march-certification-tests.trycloudflare.com" # DataNode tunnel
HDFS_USER = "root"

# HDFS Paths
HDFS_FEATURES_PATH = "/data/processed/features.parquet"  # Input features parquet
HDFS_MODELS_DIR = "/models"                               # Output directory for models

# Training Configuration (defaults, can be overridden via API)
DEFAULT_TARGET = "heatwave"        # Options: "heatwave" or "flood_proxy"
DEFAULT_TEST_DAYS = 365            # Number of days for test split

# XGBoost Parameters (matching train_xgb.py)
XGB_PARAMS = {
    'objective': 'binary:logistic',
    'eval_metric': 'aucpr',
    'eta': 0.05,
    'max_depth': 6,
    'subsample': 0.8,
}
NUM_BOOST_ROUNDS = 1000
EARLY_STOPPING_ROUNDS = 50

# API Configuration
API_PORT = 5000
NGROK_AUTH_TOKEN = ""  # ngrok authtoken from your ngrok dashboard; recent ngrok agents require one to open a tunnel
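
`XGB_PARAMS` leaves out `scale_pos_weight` on purpose: `train_xgb_model` (section 5) computes it from the training labels as the negative-to-positive ratio, floored at 1. The arithmetic as a standalone sketch; the function name and the 950/50 label split are made up for illustration.

```python
def scale_pos_weight(labels):
    """Negative-to-positive ratio, floored at 1 (mirrors the notebook's formula)."""
    positives = sum(labels)
    negatives = len(labels) - positives
    # max(1, positives) avoids division by zero when there are no positive labels
    return max(1, negatives / max(1, positives))


# Hypothetical example: 950 non-heatwave days and 50 heatwave days
labels = [0] * 950 + [1] * 50
print(scale_pos_weight(labels))  # 19.0
```

The floor at 1 means the weight can only up-weight the positive class; a training set in which positives are the majority is left unweighted.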

## 3. Import Libraries

In [None]:
import os
import io
import pickle
import joblib
import requests
import threading
import time
from datetime import datetime
from urllib.parse import urlparse

import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.metrics import (
    average_precision_score, 
    roc_auc_score, 
    precision_recall_fscore_support,
)

from flask import Flask, request, jsonify
from pyngrok import ngrok

print("✅ Libraries imported successfully!")
print(f"   XGBoost version: {xgb.__version__}")

## 4. HDFS Client Setup

In [None]:
class HDFSManager:
    """Manages HDFS operations via WebHDFS through Cloudflare tunnels.
    
    Uses separate tunnels for NameNode and DataNode to handle WebHDFS redirects.
    """
    
    def __init__(self, namenode_url, datanode_url, user):
        self.namenode_url = namenode_url.rstrip('/')
        self.datanode_url = datanode_url.rstrip('/')
        self.user = user
        self.session = requests.Session()
        print("✅ HDFS Manager initialized")
        print(f"   NameNode: {self.namenode_url}")
        print(f"   DataNode: {self.datanode_url}")
    
    def _namenode_api(self, path, op, **params):
        """Build NameNode WebHDFS URL."""
        url = f"{self.namenode_url}/webhdfs/v1{path}?op={op}&user.name={self.user}"
        for k, v in params.items():
            url += f"&{k}={v}"
        return url
    
    def _redirect_to_datanode(self, redirect_url):
        """Convert NameNode redirect URL to use DataNode tunnel."""
        parsed = urlparse(redirect_url)
        return f"{self.datanode_url}{parsed.path}?{parsed.query}" if parsed.query else f"{self.datanode_url}{parsed.path}"
    
    def read_parquet(self, hdfs_path):
        """Read parquet file from HDFS into pandas DataFrame."""
        print(f"üìñ Reading parquet from: {hdfs_path}")
        
        # Step 1: Request OPEN from NameNode (returns 307 redirect to DataNode)
        open_url = self._namenode_api(hdfs_path, "OPEN")
        response = self.session.get(open_url, allow_redirects=False, timeout=60)
        
        if response.status_code == 307:
            # Step 2: Get redirect and convert to DataNode tunnel URL
            redirect_url = response.headers.get('Location')
            datanode_url = self._redirect_to_datanode(redirect_url)
            print(f"   Fetching from DataNode tunnel...")
            
            # Step 3: Download from DataNode tunnel
            response = self.session.get(datanode_url, timeout=300)
            response.raise_for_status()
        elif response.status_code == 200:
            pass  # Direct response
        else:
            raise Exception(f"Failed to open file: {response.status_code} - {response.text[:500]}")
        
        # Parse parquet
        df = pd.read_parquet(io.BytesIO(response.content))
        print(f"✅ Loaded DataFrame with shape: {df.shape}")
        return df
    
    def save_model_joblib(self, model, hdfs_path, metadata=None):
        """Save model using joblib to HDFS (matching train_xgb.py format)."""
        print(f"üíæ Saving model to: {hdfs_path}")
        
        # Create model package with metadata
        model_package = {
            'model': model,
            'metadata': metadata or {},
        }
        
        # Serialize using joblib (same as train_xgb.py)
        buffer = io.BytesIO()
        joblib.dump(model_package, buffer)
        model_bytes = buffer.getvalue()
        
        # Step 1: CREATE request to NameNode (returns 307 redirect)
        create_url = self._namenode_api(hdfs_path, "CREATE", overwrite="true")
        response = self.session.put(create_url, allow_redirects=False, timeout=60)
        
        if response.status_code == 307:
            # Step 2: Upload to DataNode tunnel
            redirect_url = response.headers.get('Location')
            datanode_url = self._redirect_to_datanode(redirect_url)
            
            response = self.session.put(
                datanode_url,
                data=model_bytes,
                headers={'Content-Type': 'application/octet-stream'},
                timeout=300
            )
            response.raise_for_status()
        else:
            raise Exception(f"Unexpected response: {response.status_code} - {response.text[:500]}")
        
        print(f"✅ Model saved! Size: {len(model_bytes) / 1024:.2f} KB")
        return hdfs_path
    
    def load_model_joblib(self, hdfs_path):
        """Load joblib model from HDFS."""
        print(f"üìñ Loading model from: {hdfs_path}")
        
        open_url = self._namenode_api(hdfs_path, "OPEN")
        response = self.session.get(open_url, allow_redirects=False, timeout=60)
        
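        if response.status_code == 307:
            redirect_url = response.headers.get('Location')
            datanode_url = self._redirect_to_datanode(redirect_url)
            response = self.session.get(datanode_url, timeout=300)
            response.raise_for_status()
        elif response.status_code != 200:
            # Fail loudly instead of handing an HTML/JSON error page to joblib
            raise Exception(f"Failed to open model: {response.status_code} - {response.text[:500]}")
        
        model_package = joblib.load(io.BytesIO(response.content))
        print("✅ Model loaded!")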
        return model_package
    
    def list_files(self, hdfs_path):
        """List files in HDFS directory."""
        try:
            url = self._namenode_api(hdfs_path, "LISTSTATUS")
            response = self.session.get(url, timeout=30)
            response.raise_for_status()
            data = response.json()
            files = [f['pathSuffix'] for f in data.get('FileStatuses', {}).get('FileStatus', [])]
            return files
        except Exception as e:
            print(f"Error listing {hdfs_path}: {e}")
            return []
    
    def mkdir(self, hdfs_path):
        """Create directory in HDFS."""
        url = self._namenode_api(hdfs_path, "MKDIRS")
        response = self.session.put(url, timeout=30)
        response.raise_for_status()  # surface HTTP errors before parsing the JSON body
        return response.json().get('boolean', False)

# Initialize HDFS Manager with both tunnel URLs
hdfs_manager = HDFSManager(HDFS_NAMENODE_URL, HDFS_DATANODE_URL, HDFS_USER)
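
`HDFSManager` builds WebHDFS query strings by hand and then moves the DataNode redirect onto the DataNode tunnel host. The sketch below shows the same two steps with `urllib.parse`, which also encodes parameter values (the hand-built version would break on a value containing `&` or a space). The function names, example hosts, and the sample DataNode `Location` header are hypothetical, chosen only to illustrate the rewrite.

```python
from urllib.parse import urlencode, urlparse, urlunparse


def webhdfs_url(base_url, hdfs_path, op, user, **params):
    """Build a WebHDFS REST URL with properly encoded query parameters."""
    query = urlencode({"op": op, "user.name": user, **params})
    return f"{base_url.rstrip('/')}/webhdfs/v1{hdfs_path}?{query}"


def rebase_onto_tunnel(location, tunnel_url):
    """Keep the redirect's path and query, but take scheme and host from the tunnel URL."""
    redirect = urlparse(location)
    tunnel = urlparse(tunnel_url)
    return urlunparse((tunnel.scheme, tunnel.netloc, redirect.path, "", redirect.query, ""))


print(webhdfs_url("https://namenode.example.com/", "/models/xgb.joblib", "CREATE", "root", overwrite="true"))

# Hypothetical Location header returned by the NameNode for an OPEN request
location = ("http://datanode:9864/webhdfs/v1/data/processed/features.parquet"
            "?op=OPEN&namenoderpcaddress=namenode:9000&offset=0")
print(rebase_onto_tunnel(location, "https://datanode.example.com"))
```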

## 5. Training Functions

In [None]:
def prepare_data(df, target='heatwave'):
    """Prepare features and target for training (matching train_xgb.py)."""
    df = df.sort_values('Date').copy()
    df = df[~df[target].isna()].copy()
    
    drop_cols = ['Date', 'District', 'heatwave', 'flood_proxy', 'Latitude', 'Longitude']
    feature_cols = [c for c in df.columns if c not in drop_cols]
    
    X = df[feature_cols].copy()
    y = df[target].astype(int)
    X = X.fillna(-999)
    
    return X, y, feature_cols


def time_train_test_split(df, test_size_days=365):
    """Split data by time - last N days as test (matching train_xgb.py)."""
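    df = df.copy()
    df['Date'] = pd.to_datetime(df['Date'])  # the Timedelta arithmetic below needs datetime64 values
    df = df.sort_values('Date')
    max_date = df['Date'].max()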
    cutoff = max_date - pd.Timedelta(days=test_size_days)
    
    train = df[df['Date'] < cutoff]
    test = df[df['Date'] >= cutoff]
    
    return train, test


def train_xgb_model(hdfs_manager, target='heatwave', test_days=365):
    """Full XGBoost training pipeline (matching train_xgb.py logic).
    
    Args:
        hdfs_manager: HDFSManager instance
        target: 'heatwave' or 'flood_proxy'
        test_days: Number of days for test set
    
    Returns:
        dict with model, metrics, and metadata
    """
    print(f"\n{'='*60}")
    print(f"üéØ TRAINING XGBoost FOR: {target.upper()}")
    print(f"{'='*60}")
    
    # Load data from HDFS
    print("üì• Loading data from HDFS...")
    df = hdfs_manager.read_parquet(HDFS_FEATURES_PATH)
    
    # Time-based split
    train_df, test_df = time_train_test_split(df, test_size_days=test_days)
    print(f"üìÖ Train: {len(train_df)} samples, Test: {len(test_df)} samples")
    
    # Prepare data
    X_train, y_train, feature_cols = prepare_data(train_df, target=target)
    X_test, y_test, _ = prepare_data(test_df, target=target)
    
    # Create DMatrix
    dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=feature_cols)
    dtest = xgb.DMatrix(X_test, label=y_test, feature_names=feature_cols)
    
    # Calculate scale_pos_weight (matching train_xgb.py)
    scale_pos_weight = max(1, (len(y_train) - y_train.sum()) / max(1, y_train.sum()))
    
    # Parameters
    params = XGB_PARAMS.copy()
    params['scale_pos_weight'] = scale_pos_weight
    
    # Train
    print("üöÄ Training model...")
    evals = [(dtrain, 'train'), (dtest, 'test')]
    model = xgb.train(
        params, dtrain,
        num_boost_round=NUM_BOOST_ROUNDS,
        evals=evals,
        early_stopping_rounds=EARLY_STOPPING_ROUNDS,
        verbose_eval=50
    )
    
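    # Evaluate (matching train_xgb.py output). xgb.train returns the model from the last
    # round, so restrict prediction to the best round found by early stopping.
    y_pred_proba = model.predict(dtest, iteration_range=(0, model.best_iteration + 1))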
    y_pred = (y_pred_proba >= 0.5).astype(int)
    
    pr_auc = average_precision_score(y_test, y_pred_proba)
    roc_auc = roc_auc_score(y_test, y_pred_proba)
    prec, rec, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='binary', zero_division=0)
    
    print(f"\nPR AUC: {pr_auc:.4f}, ROC AUC: {roc_auc:.4f}")
    print(f"Precision {prec:.3f}, Recall {rec:.3f}, F1 {f1:.3f}")
    
    metrics = {
        'pr_auc': float(pr_auc),
        'roc_auc': float(roc_auc),
        'precision': float(prec),
        'recall': float(rec),
        'f1': float(f1)
    }
    
    metadata = {
        'target': target,
        'feature_columns': feature_cols,
        'training_params': params,
        'best_iteration': model.best_iteration,
        'metrics': metrics,
        'train_samples': len(X_train),
        'test_samples': len(X_test),
        'saved_at': datetime.now().isoformat(),
        'xgboost_version': xgb.__version__
    }
    
    # Save to HDFS
    hdfs_manager.mkdir(HDFS_MODELS_DIR)
    hdfs_path = f"{HDFS_MODELS_DIR}/xgb_{target}_model.joblib"
    hdfs_manager.save_model_joblib(model, hdfs_path, metadata=metadata)
    print(f"üíæ Saved to: {hdfs_path}")
    
    return {
        'model': model,
        'metrics': metrics,
        'metadata': metadata,
        'hdfs_path': hdfs_path
    }

print("✅ Training functions defined!")
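
`time_train_test_split` puts rows with `Date >= cutoff` into the test set, so the test window covers `test_size_days + 1` calendar days (the cutoff day is included). The stdlib-only sketch below mirrors that rule on ten hypothetical daily dates; `time_split` is an illustrative name, not part of the notebook.

```python
from datetime import date, timedelta


def time_split(dates, test_size_days):
    """Mirror the notebook's rule: train < cutoff <= test, with cutoff = max_date - test_size_days."""
    cutoff = max(dates) - timedelta(days=test_size_days)
    ordered = sorted(dates)  # chronological order, no shuffling
    train = [d for d in ordered if d < cutoff]
    test = [d for d in ordered if d >= cutoff]
    return train, test


days = [date(2024, 1, 1) + timedelta(days=i) for i in range(10)]  # Jan 1 to Jan 10
train, test = time_split(days, test_size_days=3)
print(len(train), len(test))  # 6 4
print(test[0], test[-1])      # 2024-01-07 2024-01-10
```

Switching the conditions to `d <= cutoff` for train and `d > cutoff` for test would make the window exactly `test_size_days` days long. With several districts per date the row counts grow, but the date boundary stays the same.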

## 6. Flask API Setup

In [None]:
# Global state for tracking training
training_state = {
    'is_training': False,
    'current_target': None,
    'last_result': None,
    'last_error': None,
    'started_at': None,
    'completed_at': None
}

# Initialize Flask app
app = Flask(__name__)


@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint."""
    return jsonify({
        'status': 'healthy',
        'service': 'xgb-training-api',
        'timestamp': datetime.now().isoformat()
    })


@app.route('/status', methods=['GET'])
def status():
    """Get current training status."""
    return jsonify({
        'is_training': training_state['is_training'],
        'current_target': training_state['current_target'],
        'started_at': training_state['started_at'],
        'completed_at': training_state['completed_at'],
        'last_result': training_state['last_result'],
        'last_error': training_state['last_error']
    })


def run_training(target, test_days):
    """Background training function."""
    global training_state
    
    try:
        training_state['is_training'] = True
        training_state['current_target'] = target
        training_state['started_at'] = datetime.now().isoformat()
        training_state['last_error'] = None
        
        # Run training
        result = train_xgb_model(hdfs_manager, target=target, test_days=test_days)
        
        training_state['last_result'] = {
            'target': target,
            'metrics': result['metrics'],
            'hdfs_path': result['hdfs_path'],
            'best_iteration': result['metadata']['best_iteration']
        }
        training_state['completed_at'] = datetime.now().isoformat()
        
    except Exception as e:
        training_state['last_error'] = str(e)
        print(f"❌ Training error: {e}")
    
    finally:
        training_state['is_training'] = False
        training_state['current_target'] = None


@app.route('/train', methods=['POST'])
def train():
    """Trigger model training.
    
    Request body (JSON):
        - target: 'heatwave' or 'flood_proxy' (default: 'heatwave')
        - test_days: Number of days for test set (default: 365)
        - async: If true, return immediately (default: false)
    """
    if training_state['is_training']:
        return jsonify({
            'status': 'busy',
            'message': f"Training already in progress for {training_state['current_target']}",
            'started_at': training_state['started_at']
        }), 409
    
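    # Parse request; silent=True returns None instead of raising an HTTP error
    # when the body is missing or not JSON
    data = request.get_json(silent=True) or {}
    target = data.get('target', DEFAULT_TARGET)
    async_mode = data.get('async', False)
    
    # Validate target
    if target not in ['heatwave', 'flood_proxy']:
        return jsonify({
            'status': 'error',
            'message': f"Invalid target: {target}. Must be 'heatwave' or 'flood_proxy'"
        }), 400
    
    # Validate test_days (clients may send it as a string such as "365")
    try:
        test_days = int(data.get('test_days', DEFAULT_TEST_DAYS))
    except (TypeError, ValueError):
        return jsonify({
            'status': 'error',
            'message': "test_days must be an integer"
        }), 400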
    
    if async_mode:
        # Start training in background thread
        thread = threading.Thread(target=run_training, args=(target, test_days))
        thread.start()
        
        return jsonify({
            'status': 'started',
            'message': f'Training started for {target}',
            'target': target,
            'test_days': test_days,
            'started_at': datetime.now().isoformat()
        })
    else:
        # Synchronous training
        run_training(target, test_days)
        
        if training_state['last_error']:
            return jsonify({
                'status': 'error',
                'message': training_state['last_error']
            }), 500
        
        return jsonify({
            'status': 'completed',
            'result': training_state['last_result']
        })


@app.route('/train/all', methods=['POST'])
def train_all():
    """Train models for all targets (heatwave and flood_proxy)."""
    if training_state['is_training']:
        return jsonify({
            'status': 'busy',
            'message': f"Training already in progress for {training_state['current_target']}"
        }), 409
    
    data = request.get_json(silent=True) or {}  # tolerate a POST with no JSON body
    test_days = data.get('test_days', DEFAULT_TEST_DAYS)
    
    results = {}
    
    for target in ['heatwave', 'flood_proxy']:
        run_training(target, test_days)
        
        if training_state['last_error']:
            results[target] = {'status': 'error', 'error': training_state['last_error']}
        else:
            results[target] = {'status': 'completed', 'result': training_state['last_result']}
    
    return jsonify({
        'status': 'completed',
        'results': results
    })


print("✅ Flask API configured!")
print("   Endpoints:")
print("   - GET  /health  - Health check")
print("   - GET  /status  - Training status")
print("   - POST /train   - Train single model")
print("   - POST /train/all - Train all models")
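
Because `train()` checks `is_training` but only `run_training` sets it, two requests handled concurrently (the server runs with `threaded=True`) can both pass the check and start overlapping runs. One way to close that gap, assuming a single server process, is a non-blocking `threading.Lock` acquired in the request thread and released by the worker. `try_start_training` and `training_lock` are hypothetical names and are not wired into the routes above.

```python
import threading

training_lock = threading.Lock()


def try_start_training(job, *args):
    """Run job(*args) in a background thread only if no other run holds the lock."""
    if not training_lock.acquire(blocking=False):
        return None  # another run is in progress; a route would answer 409 here

    def worker():
        try:
            job(*args)
        finally:
            training_lock.release()  # release even if training raises

    thread = threading.Thread(target=worker)
    thread.start()
    return thread


# Demo with a fake long-running job
started = threading.Event()
finish = threading.Event()


def fake_training(target):
    started.set()
    finish.wait()  # simulate a long training run


first = try_start_training(fake_training, "heatwave")
started.wait()
second = try_start_training(fake_training, "flood_proxy")
print(first is not None, second is None)  # True True
finish.set()
first.join()
```

`threading.Lock` (unlike `RLock`) may be released by a thread other than the one that acquired it, which is what lets the worker release it. The synchronous mode could call the job directly while holding the same lock.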

## 7. Start API Server with ngrok

In [None]:
# Set ngrok auth token if provided
if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Start ngrok tunnel; connect() returns an NgrokTunnel object, so keep only its public URL string
tunnel = ngrok.connect(API_PORT)
public_url = tunnel.public_url
print("=" * 60)
print("🌐 XGBoost TRAINING API SERVER")
print("=" * 60)
print(f"\n📡 Public URL: {public_url}")
print(f"   Local URL:  http://localhost:{API_PORT}")
print("\n📋 API Endpoints:")
print(f"   GET  {public_url}/health")
print(f"   GET  {public_url}/status")
print(f"   POST {public_url}/train")
print(f"   POST {public_url}/train/all")
print("\n⚠️  Copy the public URL into the Airflow DAG configuration!")
print("=" * 60)

# Store URL for reference
XGB_API_URL = public_url

In [None]:
# Run Flask server (this cell blocks - run this last!)
print("üöÄ Starting Flask server...")
print("   Press Ctrl+C or restart runtime to stop")
print("-" * 60)

# Run Flask app
app.run(port=API_PORT, threaded=True)

## 8. API Usage Examples

### From Airflow (Python):
```python
import requests

XGB_API_URL = "https://xxxx.ngrok.io"  # Replace with actual ngrok URL

# Train heatwave model
response = requests.post(f"{XGB_API_URL}/train", json={
    "target": "heatwave",
    "test_days": 365
})
print(response.json())

# Train all models
response = requests.post(f"{XGB_API_URL}/train/all", json={})  # send a JSON body so Flask can parse it
print(response.json())
```
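
With `"async": true`, `/train` returns immediately and the caller has to poll `/status`. A minimal polling loop, assuming the fields returned by the `/status` route in section 6; `wait_for_training` and its parameters are hypothetical, and the status fetcher and `sleep` are injected so the loop can be exercised without a live server.

```python
import time


def wait_for_training(fetch_status, poll_seconds=30, timeout_seconds=3600, sleep=time.sleep):
    """Poll a /status-style dict until training stops; return last_result or raise.

    fetch_status: zero-argument callable returning the /status JSON as a dict.
    sleep: injectable so tests can skip real waiting.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        # Wait before each poll so the background thread has time to set is_training
        sleep(poll_seconds)
        status = fetch_status()
        if not status["is_training"]:
            if status["last_error"]:
                raise RuntimeError(f"Training failed: {status['last_error']}")
            return status["last_result"]
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Training still running after {timeout_seconds} s")
```

In a DAG task this might be called as `wait_for_training(lambda: requests.get(f"{XGB_API_URL}/status", timeout=30).json())` right after POSTing `{"target": "heatwave", "async": true}` to `/train`. Waiting before the first poll is a heuristic, not a guarantee: `run_training` sets `is_training` from inside the new thread, so a very early poll could still see the previous run's result.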

### From curl:
```bash
# Health check
curl https://xxxx.ngrok.io/health

# Train heatwave model
curl -X POST https://xxxx.ngrok.io/train \
  -H "Content-Type: application/json" \
  -d '{"target": "heatwave", "test_days": 365}'

# Train all models
curl -X POST https://xxxx.ngrok.io/train/all \
  -H "Content-Type: application/json" \
  -d '{}'
```