Setup environment

In [None]:
# Ensure src folder is importable
import sys
from pathlib import Path

project_root = Path.cwd().parent
sys.path.append(str(project_root))

# Auto-reload changes in .py files
%load_ext autoreload
%autoreload 2


Imports

In [None]:
from typing import List

import numpy as np
import pandas as pd
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from src.data.data_loader import CryptoDataLoader
from src.data.feature_engineering import FeatureEngineer
from src.models.lstm import LSTMModel
from src.models.transformer import TransformerModel


Initialize API app and globals

In [None]:
# FastAPI application object; the endpoint cells below register routes on it.
app = FastAPI(
    description="API for forecasting cryptocurrency prices",
    title="Crypto Price Forecasting API",
)

# Module-level state shared by the endpoints; load_model() assigns the
# model-related entries. Everything model/preprocessing related starts empty.
model = feature_scaler = target_scaler = None
feature_cols = input_size = None
# Default input window length and architecture name.
sequence_length, model_type = 60, 'lstm'

Define request/response schemas

In [None]:
# Upper bound on the requested horizon: `days` sizes the generated date range
# and the response list, so an unbounded value would let a single request
# allocate arbitrarily large results.
MAX_FORECAST_DAYS = 365


class PredictionRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Ticker symbol forwarded to the data loader (e.g. "BTC-USD"); must be non-empty.
    ticker: str = Field(..., min_length=1)
    # Forecast horizon in days. Values outside [1, MAX_FORECAST_DAYS] are
    # rejected by request validation (FastAPI answers 422) instead of reaching
    # pd.date_range, which raises on negative period counts.
    days: int = Field(1, ge=1, le=MAX_FORECAST_DAYS)


class PredictionResponse(BaseModel):
    """Response body for ``POST /predict``: one value per requested date."""

    ticker: str
    predictions: List[float]
    # Dates formatted as YYYY-MM-DD, aligned index-for-index with `predictions`.
    dates: List[str]


class ModelInfo(BaseModel):
    """Response body for ``GET /model/info``."""

    # NOTE(review): under Pydantic v2 a field named `model_type` collides with the
    # protected `model_` namespace and emits a UserWarning; once the Pydantic
    # version is pinned, consider `protected_namespaces=()` in the model config.
    model_type: str
    input_size: int
    sequence_length: int


Load model (startup simulation)

In [None]:
# Simulate loading a model (weights not loaded in this notebook)
def load_model(architecture: str = 'lstm', seq_len: int = 60, n_features: int = 20):
    """Build a forecasting model and publish it through the module globals.

    No trained weights are loaded: the network keeps its freshly initialised
    parameters, so its predictions are not meaningful yet.
    TODO: load a trained state dict and the fitted feature/target scalers.

    Args:
        architecture: ``'lstm'`` or ``'transformer'``; stored in ``model_type``.
        seq_len: Time steps per input window; stored in ``sequence_length``.
        n_features: Features per time step; stored in ``input_size``.

    Raises:
        ValueError: If ``architecture`` is unknown or a size is not positive.
    """
    # The scalers are declared here but not assigned yet (see TODO above).
    global model, feature_scaler, target_scaler, feature_cols, sequence_length, model_type, input_size

    # Validate up front: previously any value other than 'lstm' silently
    # produced a transformer.
    if architecture not in ('lstm', 'transformer'):
        raise ValueError(
            f"Unknown architecture {architecture!r}; expected 'lstm' or 'transformer'"
        )
    if seq_len < 1 or n_features < 1:
        raise ValueError("seq_len and n_features must be positive")

    if architecture == 'lstm':
        new_model = LSTMModel(
            input_size=n_features,
            hidden_size=64,
            num_layers=2,
            output_size=1,
            dropout=0.2
        )
    else:
        new_model = TransformerModel(
            input_size=n_features,
            d_model=64,
            nhead=4,
            num_encoder_layers=2,
            dim_feedforward=128,
            output_size=1
        )

    # Inference mode (e.g. disables the dropout configured above).
    new_model.eval()

    # Publish only after construction succeeded, so a failure above leaves the
    # previously loaded model and its metadata intact.
    model = new_model
    model_type = architecture
    sequence_length = seq_len
    input_size = n_features
    # Placeholder column names until the real training columns are loaded.
    feature_cols = [f'feature_{i}' for i in range(n_features)]

# Populate the module globals, then report what was built.
load_model()
print("Loaded {} model with input size {}".format(model_type, input_size))

Root endpoint

In [None]:
@app.get("/")
async def root():
    """Return a fixed message identifying the API."""
    body = {"message": "Crypto Price Forecasting API"}
    return body


Model info endpoint

In [None]:
@app.get("/model/info", response_model=ModelInfo)
async def get_model_info():
    """Describe the currently loaded model.

    Raises:
        HTTPException: 500 "Model not loaded" (same as ``/predict``) when
            ``load_model()`` has not run. Without this guard ``input_size`` is
            ``None`` and building ``ModelInfo`` fails with an opaque
            validation error.
    """
    if model is None or input_size is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    return ModelInfo(
        model_type=model_type,
        input_size=input_size,
        sequence_length=sequence_length
    )


Prediction endpoint

In [None]:
# Extra rows fetched beyond `sequence_length` so technical indicators have
# history to warm up on. NOTE(review): keeps the original +10 buffer; confirm it
# covers the longest indicator window used by FeatureEngineer.
FEATURE_WARMUP_DAYS = 10


def _latest_feature_window(data: pd.DataFrame, ticker: str) -> np.ndarray:
    """Turn raw price history into the model's most recent input window.

    Args:
        data: History returned by ``CryptoDataLoader.get_latest_data``.
        ticker: Symbol, used only in error messages.

    Returns:
        float32 array of shape ``(sequence_length, n_columns)``, with
        ``feature_scaler`` applied when one is loaded.

    Raises:
        HTTPException: 404 if too few complete rows remain after feature
            engineering; 500 if the column count does not match ``input_size``.
    """
    engineered = FeatureEngineer().add_technical_indicators(data)

    # Use the training column order when those columns are present; otherwise
    # fall back to every engineered column, as before.
    if feature_cols and set(feature_cols).issubset(engineered.columns):
        engineered = engineered[feature_cols]

    # Rows containing NaN (e.g. warm-up rows of rolling indicators — presumably
    # produced by FeatureEngineer; confirm) would make the prediction NaN, which
    # cannot be serialized to JSON. Drop them and re-check the window length.
    complete_rows = engineered.dropna()
    if len(complete_rows) < sequence_length:
        raise HTTPException(status_code=404, detail=f"Insufficient data for {ticker}")

    window = complete_rows.iloc[-sequence_length:].to_numpy(dtype=np.float32)
    if window.shape[1] != input_size:
        # Fail with a clear message instead of a shape error inside the model.
        raise HTTPException(
            status_code=500,
            detail=f"Model expects {input_size} features, got {window.shape[1]}",
        )

    if feature_scaler is not None:
        # Assumes a scikit-learn-style scaler exposing `.transform` — TODO confirm.
        window = np.asarray(feature_scaler.transform(window), dtype=np.float32)
    return window


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Forecast ``request.days`` daily prices for ``request.ticker``.

    Declared with plain ``def`` rather than ``async def`` so FastAPI runs it in
    its threadpool. The data fetch and the torch forward pass are synchronous
    calls that would otherwise block the event loop for all other requests.

    The model yields a single value per request. For ``days > 1`` that value is
    repeated for each future date (TODO: iterative multi-step forecasting).

    Raises:
        HTTPException: 500 if no model is loaded, the features do not match the
            model, or the output is not finite; 422 if ``days`` < 1; 404 if
            there is not enough usable history.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # A non-positive horizon would give an empty forecast or make
    # pd.date_range raise on a negative period count.
    if request.days < 1:
        raise HTTPException(status_code=422, detail="days must be >= 1")

    data = CryptoDataLoader().get_latest_data(
        request.ticker, days=sequence_length + FEATURE_WARMUP_DAYS
    )
    if data is None or len(data) < sequence_length:
        raise HTTPException(status_code=404, detail=f"Insufficient data for {request.ticker}")

    window = _latest_feature_window(data, request.ticker)
    input_tensor = torch.tensor(window).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        prediction = model(input_tensor).item()

    if target_scaler is not None:
        # Assumes a scikit-learn-style `.inverse_transform` — TODO confirm.
        prediction = float(np.ravel(target_scaler.inverse_transform([[prediction]]))[0])
    # NaN/inf cannot be encoded in a JSON response, so reject it explicitly.
    if not np.isfinite(prediction):
        raise HTTPException(status_code=500, detail="Model produced a non-finite prediction")

    # Dates strictly after the last observed row (assumes a datetime-like index).
    future_dates = pd.date_range(start=data.index[-1], periods=request.days + 1, freq='D')[1:]

    return PredictionResponse(
        ticker=request.ticker,
        predictions=[prediction] * request.days,
        dates=future_dates.strftime('%Y-%m-%d').tolist(),
    )


Run API server from notebook

In [None]:
# Uncomment the line below to run the API in the notebook (blocking)
# uvicorn.run(app, host="0.0.0.0", port=8000)

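⚠️ Note: `uvicorn.run()` blocks the notebook kernel until the server stops. Inside Jupyter it typically fails outright, because the kernel already runs an asyncio event loop. Run the server from a terminal instead, e.g. `uvicorn <module>:app --port 8000` against a `.py` file that defines `app`.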

Test prediction interactively

In [None]:
# Imported here rather than in the top cell: TestClient needs the optional
# `httpx` dependency, and only this cell uses it.
from fastapi.testclient import TestClient

client = TestClient(app)

response = client.post("/predict", json={"ticker": "BTC-USD", "days": 3})
# Show the status code as well: error responses (404/422/500) are JSON too and
# would otherwise look like an ordinary result.
print(f"HTTP {response.status_code}")
response.json()
