In [None]:
# app/main.py
import bisect
import json
import logging
import os
import tempfile
import time
from datetime import datetime
from typing import List, Optional, Dict, Any

import pandas as pd
from fastapi import FastAPI, HTTPException, BackgroundTasks, Query
from pydantic import BaseModel, Field

# External provider functions from nsepython
# make sure you have `nsepython` installed in your env:
# pip install nsepython
from nsepython import option_chain, nse_quote

# ---------------------------
# Models
# ---------------------------

class IndexPriceResponse(BaseModel):
    """Response body of ``GET /index-price``: the latest quote for an index."""
    symbol: str  # normalized index symbol that was queried
    lastPrice: float  # last traded value, parsed to float
    pChange: float  # provider's ``pChange`` value (presumably percent change)
    change: float  # provider's ``change`` value
    timestamp: str  # provider-supplied time string; format is not normalized

class StockPriceResponse(BaseModel):
    """Response body of ``GET /stock-price``: the latest quote for an NSE equity.

    Price fields may be ``None`` when the provider omits them.
    """
    symbol: str  # symbol as reported by the provider
    companyName: Optional[str] = None
    lastPrice: Optional[float] = None
    pChange: Optional[float] = None  # provider's ``pChange`` (presumably percent change)
    change: Optional[float] = None  # provider's ``change`` value
    timestamp: str  # provider-supplied time string; format is not normalized

class FetchOptionsRequest(BaseModel):
    """Request body of ``POST /fetch``: save the nearest-expiry chain for an index."""
    index: str = Field(..., description="Index symbol, e.g. NIFTY or BANKNIFTY")
    # Strikes to keep on each side of the ATM strike (validated to 1..500).
    num_strikes: int = Field(25, gt=0, le=500)

class FetchExpiryRequest(BaseModel):
    """Request body of ``POST /fetch/expiry``: save the chain for one specific expiry."""
    index: str = Field(..., description="Index symbol, e.g. NIFTY or BANKNIFTY")
    # Compared verbatim against the provider's expiry list, so it must match
    # one of the strings returned by GET /expiries.
    expiry: str = Field(..., description="Expiry date exactly as listed by GET /expiries")
    num_strikes: int = Field(25, gt=0, le=500,
                             description="Strikes to keep on each side of the ATM strike")

class FetchResultMeta(BaseModel):
    """Metadata describing one saved option-chain snapshot.

    Returned by the fetch endpoints, written as a JSON sidecar next to each
    saved CSV, and echoed back by ``GET /analytics``.
    """
    createdAtUTC: str  # ISO-8601 string from datetime.utcnow() (naive, no offset)
    indexName: str  # normalized index symbol the snapshot belongs to
    nearestExpiry: Optional[str] = None  # expiry of the saved rows
    selectedExpiry: Optional[str] = None  # explicitly selected expiry, when recorded
    underlyingValue: Optional[float] = None  # underlying value reported with the chain
    atmStrike: Optional[int] = None  # strike closest to underlyingValue
    selectedStrikesRange: Optional[List[int]] = None  # [lowest, highest] strike kept
    totalStrikesFetched: Optional[int] = None  # number of rows saved to the CSV

class AnalyticsResponse(BaseModel):
    """Response body of ``GET /analytics``: metrics computed from the latest saved snapshot."""
    meta: FetchResultMeta  # metadata of the snapshot the metrics were computed from
    pcr: Dict[str, float]  # keys 'pcr_by_oi' and 'pcr_by_volume'
    top_oi: Dict[str, List[Dict[str, Any]]]  # 'resistance_strikes' (CE OI) / 'support_strikes' (PE OI)
    max_pain: Dict[str, Any]  # 'max_pain_strike' and 'max_loss_value'

# ---------------------------
# App & config
# ---------------------------

app = FastAPI(
    title="Option Chain API",
    version="1.0",
    description="Fetch option chains from NSE and return analytics (PCR / MaxPain / Top OI).",
)

# Directory holding saved CSV/JSON snapshots; override with OPTION_OUTPUT_DIR.
# Created at import time so the fetch and analytics endpoints can rely on it.
OUTPUT_DIR = os.getenv("OPTION_OUTPUT_DIR", "option_chain_data")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------
# Helpers (adapted from the original script): DataFrame flattening, atomic writes, symbol normalization
# ---------------------------

def _expand_side(df: pd.DataFrame, side: str) -> pd.DataFrame:
    valid_rows = df[df[side].apply(lambda x: isinstance(x, dict))]
    if valid_rows.empty:
        return pd.DataFrame()
    side_data = valid_rows[side].apply(pd.Series)
    side_data = side_data.add_prefix(f'{side}_')
    return side_data

def _atomic_write_csv(df: pd.DataFrame, target_path: str):
    # write to temp file then atomically replace
    dirpath = os.path.dirname(target_path)
    os.makedirs(dirpath, exist_ok=True)
    with tempfile.NamedTemporaryFile(mode="w", dir=dirpath, delete=False, suffix=".csv") as tmp:
        tmp_name = tmp.name
        df.to_csv(tmp_name, index=False)
    os.replace(tmp_name, target_path)

def _atomic_write_json(obj: dict, target_path: str):
    dirpath = os.path.dirname(target_path)
    os.makedirs(dirpath, exist_ok=True)
    with tempfile.NamedTemporaryFile(mode="w", dir=dirpath, delete=False, suffix=".json", encoding="utf-8") as tmp:
        tmp_name = tmp.name
        json.dump(obj, tmp, indent=2)
    os.replace(tmp_name, target_path)

def _normalize_index_name(index: str) -> str:
    if not index:
        return ""
    s = index.strip().upper()
    if s in ("NIFTY50", "NIFTY", "NSEI"):
        return "NIFTY"
    if s in ("BANKNIFTY", "NSEBANK"):
        return "BANKNIFTY"
    return s

# ---------------------------
# Analytical helpers (original script logic): PCR, top open-interest strikes, max pain
# ---------------------------

def calculate_pcr(df: pd.DataFrame) -> dict:
    """Compute put/call ratios across the whole chain in ``df``.

    Each ratio is sum(PE column) / sum(CE column), with NaN counted as 0, and is
    rounded to 2 decimals. A ratio stays 0.0 when either column is missing or
    the CE total is not positive.

    Returns:
        ``{'pcr_by_oi': ..., 'pcr_by_volume': ...}``
    """
    def ratio(metric: str) -> float:
        put_col, call_col = f'PE_{metric}', f'CE_{metric}'
        if put_col not in df.columns or call_col not in df.columns:
            return 0.0
        put_total = df[put_col].fillna(0).sum()
        call_total = df[call_col].fillna(0).sum()
        if call_total > 0:
            return round(put_total / call_total, 2)
        return 0.0

    return {
        'pcr_by_oi': ratio('openInterest'),
        'pcr_by_volume': ratio('totalTradedVolume'),
    }

def find_high_oi_strikes(df: pd.DataFrame, top_n: int = 5) -> dict:
    """List the ``top_n`` strikes with the largest call and put open interest.

    Returns:
        ``{'resistance_strikes': [...], 'support_strikes': [...]}``. Each entry is
        a record ``{'strikePrice': ..., '<CE|PE>_openInterest': ...}`` with NaN
        replaced by 0. Calls feed ``resistance_strikes`` and puts feed
        ``support_strikes``. A side whose OI column is missing stays empty.
    """
    results = {'resistance_strikes': [], 'support_strikes': []}
    sides = (
        ('resistance_strikes', 'CE_openInterest'),
        ('support_strikes', 'PE_openInterest'),
    )
    for result_key, oi_col in sides:
        if oi_col not in df.columns:
            continue
        leaders = df.nlargest(top_n, oi_col)
        results[result_key] = leaders[['strikePrice', oi_col]].fillna(0).to_dict('records')
    return results

def calculate_max_pain(df: pd.DataFrame) -> dict:
    """Find the max-pain strike.

    The max-pain strike is the settlement price at which the aggregate intrinsic
    payout to option holders (the writers' loss) is smallest. For every
    candidate settlement price S taken from the strikes in ``df``:

        payout(S) = sum_K CE_OI(K) * max(S - K, 0) + sum_K PE_OI(K) * max(K - S, 0)

    A call struck at K pays out when S is above K; a put pays out when S is
    below K. Rows with a missing strike or open interest are ignored for that
    side. A side whose ``*_openInterest`` column is absent contributes nothing.
    Last prices play no part in the calculation.

    Returns:
        ``{'max_pain_strike': int | None, 'max_loss_value': int}``. The strike is
        ``None`` (and the loss 0) when no strikes are available. Ties resolve
        to the lowest strike.
    """
    if 'strikePrice' not in df.columns:
        return {'max_pain_strike': None, 'max_loss_value': 0}
    strikes = sorted(df['strikePrice'].dropna().unique())
    if not strikes:
        return {'max_pain_strike': None, 'max_loss_value': 0}

    def _side(oi_col: str):
        # (strike, open interest) pairs for one side; computed once, outside the strike loop.
        if oi_col not in df.columns:
            return pd.Series(dtype=float), pd.Series(dtype=float)
        data = df[['strikePrice', oi_col]].dropna()
        return data['strikePrice'].astype(float), data[oi_col].astype(float)

    ce_strikes, ce_oi = _side('CE_openInterest')
    pe_strikes, pe_oi = _side('PE_openInterest')

    best_strike, best_payout = None, None
    for settle in strikes:
        call_payout = ((settle - ce_strikes).clip(lower=0) * ce_oi).sum()
        put_payout = ((pe_strikes - settle).clip(lower=0) * pe_oi).sum()
        payout = call_payout + put_payout
        # Strict '<' keeps the first (lowest) strike on ties.
        if best_payout is None or payout < best_payout:
            best_strike, best_payout = settle, payout
    return {'max_pain_strike': int(best_strike), 'max_loss_value': int(best_payout)}

# ---------------------------
# Core fetch + save (refactored)
# ---------------------------

def _prepare_option_chain_df(resp: dict, expiry: str) -> pd.DataFrame:
    """Flatten the NSE option-chain payload into a table for a single expiry.

    Args:
        resp: Raw payload from ``option_chain``; must contain ``records.data``.
        expiry: Expiry to keep, compared verbatim with each row's ``expiryDate``.

    Returns:
        DataFrame with ``strikePrice`` (numeric), ``expiryDate`` and the flattened
        ``CE_*`` / ``PE_*`` columns, one row per matching record (presumably one
        per strike). A record lacking one side has NaN in that side's columns.

    Raises:
        RuntimeError: if the payload is malformed, empty, or has no rows for ``expiry``.
    """
    if not (isinstance(resp, dict) and 'records' in resp and 'data' in resp['records']):
        raise RuntimeError("Invalid response structure from NSE.")
    df_full = pd.DataFrame(resp['records']['data'])
    if df_full.empty:
        raise RuntimeError("No option chain data returned by NSE.")
    for column in ('strikePrice', 'expiryDate'):
        if column not in df_full.columns:
            raise RuntimeError(f"Column '{column}' missing from NSE response.")
    df_full['strikePrice'] = pd.to_numeric(df_full['strikePrice'], errors='coerce')
    df = df_full[df_full['expiryDate'] == expiry]
    if df.empty:
        raise RuntimeError(f"No data for expiry {expiry}")

    parts = [df[['strikePrice', 'expiryDate']]]
    for side in ('CE', 'PE'):
        # A side may be missing entirely (no column) or only for some strikes.
        if side in df.columns:
            side_data = _expand_side(df, side)
            if not side_data.empty:
                parts.append(side_data)
    # Join on the original row index so each CE/PE record stays on its own
    # strike. Resetting each part's index before concatenating would shift rows
    # whenever some strike lacks one side.
    return pd.concat(parts, axis=1).reset_index(drop=True)

def _nearest_strike_index(strikes: List[float], value: float) -> int:
    """Return the index, in ascending non-empty ``strikes``, of the strike closest to ``value``.

    When ``value`` lies exactly between two strikes, the higher one wins. Values
    outside the strike range clamp to the first or last strike.
    """
    idx = bisect.bisect_left(strikes, value)
    if idx >= len(strikes):
        # value is above every strike: bisect_left points one past the end.
        return len(strikes) - 1
    if idx > 0 and abs(strikes[idx - 1] - value) < abs(strikes[idx] - value):
        idx -= 1
    return idx


def _select_strikes_and_save(df_processed: pd.DataFrame, resp: dict, index_name: str, expiry: str, num_strikes: int) -> FetchResultMeta:
    """Keep ``num_strikes`` strikes on each side of ATM, persist them, and return metadata.

    The ATM strike is the strike closest to ``resp['records']['underlyingValue']``.
    The selected rows, sorted by strike, are written atomically to
    ``OUTPUT_DIR/<index>_option_chain_<expiry>_<local timestamp>.csv``. A JSON
    sidecar with the same base name holds the returned metadata.

    Raises:
        RuntimeError: if ``df_processed`` contains no strike prices.
    """
    # A missing or null underlyingValue is treated as 0.
    underlying_value = float(resp['records'].get('underlyingValue') or 0)
    strikes = sorted(df_processed['strikePrice'].dropna().unique())
    if not strikes:
        raise RuntimeError("No strikes found after processing")
    atm_index = _nearest_strike_index(strikes, underlying_value)
    low_index = max(0, atm_index - num_strikes)
    high_index = min(len(strikes) - 1, atm_index + num_strikes)
    selected_strikes = strikes[low_index:high_index + 1]
    df_final = (
        df_processed[df_processed['strikePrice'].isin(selected_strikes)]
        .sort_values(['strikePrice'])
        .reset_index(drop=True)
    )

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # Make the expiry safe to embed in a file name.
    safe_expiry = str(expiry).replace(' ', '_').replace('/', '-')
    base_filename = f"{index_name.lower()}_option_chain_{safe_expiry}_{timestamp}"
    csv_path = os.path.join(OUTPUT_DIR, f"{base_filename}.csv")
    meta_path = os.path.join(OUTPUT_DIR, f"{base_filename}.json")
    _atomic_write_csv(df_final, csv_path)

    metadata = {
        'createdAtUTC': datetime.utcnow().isoformat(),
        'indexName': index_name,
        # 'nearestExpiry' is filled with the saved expiry even when it was chosen
        # explicitly, for existing readers. 'selectedExpiry' names it unambiguously.
        'nearestExpiry': expiry,
        'selectedExpiry': expiry,
        'underlyingValue': underlying_value,
        'atmStrike': int(strikes[atm_index]),
        'selectedStrikesRange': [int(selected_strikes[0]), int(selected_strikes[-1])],
        'totalStrikesFetched': int(len(df_final)),
    }
    _atomic_write_json(metadata, meta_path)
    return FetchResultMeta(**metadata)

def fetch_and_save_option_chain(index_name: str, num_strikes_around_atm: int = 25) -> FetchResultMeta:
    """Fetch ``index_name``'s option chain, save its nearest expiry, and return metadata.

    The nearest expiry is the first entry of ``records.expiryDates`` (this
    presumes the provider lists expiries in ascending order; TODO confirm).

    Raises:
        RuntimeError: if the payload is malformed or lists no expiries, or if a
            downstream processing step fails.
    """
    start_time = time.perf_counter()
    resp = option_chain(index_name)
    records = resp.get('records') if isinstance(resp, dict) else None
    if not isinstance(records, dict):
        raise RuntimeError("Invalid response structure from NSE.")
    expiries = records.get('expiryDates') or []
    if not expiries:
        raise RuntimeError("No expiries in NSE response.")
    nearest_expiry = expiries[0]
    df_processed = _prepare_option_chain_df(resp, nearest_expiry)
    meta = _select_strikes_and_save(df_processed, resp, index_name, nearest_expiry, num_strikes_around_atm)
    elapsed = time.perf_counter() - start_time
    # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
    logging.getLogger(__name__).info(
        "Saved option chain for %s expiry %s in %.2fs", index_name, nearest_expiry, elapsed)
    return meta

def fetch_specific_expiry_option_chain(index_name: str, expiry_date: str, num_strikes_around_atm: int = 25) -> FetchResultMeta:
    """Fetch ``index_name``'s option chain, save the ``expiry_date`` slice, and return metadata.

    Raises:
        HTTPException: 422 when ``expiry_date`` is not among the listed expiries.
        RuntimeError: if the payload is malformed, or if a downstream processing step fails.
    """
    start_time = time.perf_counter()
    resp = option_chain(index_name)
    records = resp.get('records') if isinstance(resp, dict) else None
    if not isinstance(records, dict):
        raise RuntimeError("Invalid response structure from NSE.")
    # `or []` guards against a null expiryDates, which would break the `in` test.
    expiries = records.get('expiryDates') or []
    if expiry_date not in expiries:
        raise HTTPException(status_code=422, detail=f"Expiry '{expiry_date}' not available. Available: {expiries}")
    df_processed = _prepare_option_chain_df(resp, expiry_date)
    meta = _select_strikes_and_save(df_processed, resp, index_name, expiry_date, num_strikes_around_atm)
    elapsed = time.perf_counter() - start_time
    # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
    logging.getLogger(__name__).info(
        "Saved option chain for %s expiry %s in %.2fs", index_name, expiry_date, elapsed)
    return meta

# ---------------------------
# Provider small wrappers
# ---------------------------

def get_available_expiries(index_name: str) -> List[str]:
    """Return the expiry strings the provider lists for ``index_name``.

    Any failure is logged and reported as an empty list; the caller maps an
    empty result to a 404.
    """
    try:
        resp = option_chain(index_name)
        return resp['records'].get('expiryDates') or []
    except Exception as e:
        # Use the stdlib logger: FastAPI application objects have no `.logger`
        # attribute, so logging through the app would raise and hide this fallback.
        logging.getLogger(__name__).error("get_available_expiries error: %s", e)
        return []

def fetch_index_price(index_name: str) -> dict:
    """Fetch the latest quote for an index and shape it for ``IndexPriceResponse``.

    Raises:
        HTTPException: 404 when the quote is empty or lacks ``lastPrice``;
            500 for any other failure (detail is the error text).
    """
    def _to_float(value) -> float:
        # lastPrice may contain thousands separators; parse every numeric
        # field the same way.
        return float(str(value).replace(',', ''))

    try:
        quote = nse_quote(index_name)
        if not quote or 'lastPrice' not in quote:
            raise HTTPException(status_code=404, detail=f"No data for index {index_name}")
        return {
            'symbol': index_name,
            'lastPrice': _to_float(quote['lastPrice']),
            'pChange': _to_float(quote.get('pChange', 0)),
            'change': _to_float(quote.get('change', 0)),
            'timestamp': quote.get('secDate', datetime.now().strftime("%d %b %Y %H:%M:%S")),
        }
    except HTTPException:
        raise
    except Exception as e:
        # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
        logging.getLogger(__name__).exception("fetch_index_price error")
        raise HTTPException(status_code=500, detail=str(e)) from e

def fetch_stock_price(stock_symbol: str) -> dict:
    """Fetch the latest quote for an NSE equity and shape it for ``StockPriceResponse``.

    Price fields are ``None`` when absent or unparseable. The symbol falls back
    to ``stock_symbol`` when the provider omits it, because the response model
    requires a string.

    Raises:
        HTTPException: 404 when the quote, its ``info`` or its ``priceInfo`` is
            missing; 500 for any other failure (detail is the error text).
    """
    def _optional_float(value) -> Optional[float]:
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    try:
        quote = nse_quote(stock_symbol)
        info = quote.get('info') if isinstance(quote, dict) else None
        price_info = quote.get('priceInfo') if isinstance(quote, dict) else None
        if not info or not price_info:
            raise HTTPException(status_code=404, detail=f"No data for stock {stock_symbol}")
        metadata = quote.get('metadata') or {}
        return {
            'symbol': info.get('symbol') or stock_symbol,
            'companyName': info.get('companyName'),
            'lastPrice': _optional_float(price_info.get('lastPrice')),
            'pChange': _optional_float(price_info.get('pChange')),
            'change': _optional_float(price_info.get('change')),
            'timestamp': metadata.get('lastUpdateTime') or datetime.now().strftime("%d-%b-%Y %H:%M:%S"),
        }
    except HTTPException:
        raise
    except Exception as e:
        # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
        logging.getLogger(__name__).exception("fetch_stock_price error")
        raise HTTPException(status_code=500, detail=str(e)) from e

# ---------------------------
# REST endpoints
# ---------------------------

@app.get("/expiries", response_model=List[str])
def api_get_expiries(index: str = Query(..., description="Index symbol, e.g. NIFTY")):
    """List the expiry dates available for an index; 404 when none are returned."""
    normalized = _normalize_index_name(index)
    available = get_available_expiries(normalized)
    if available:
        return available
    raise HTTPException(status_code=404, detail=f"No expiries found for {normalized}")

@app.get("/index-price", response_model=IndexPriceResponse)
def api_index_price(index: str = Query(..., description="Index symbol, e.g. NIFTY")):
    """Return the latest quote for an index; aliases such as NIFTY50 are normalized first."""
    quote = fetch_index_price(_normalize_index_name(index))
    return IndexPriceResponse(**quote)

@app.get("/stock-price", response_model=StockPriceResponse)
def api_stock_price(symbol: str = Query(..., description="Stock symbol (NSE), e.g. RELIANCE")):
    """Return the latest quote for an NSE equity; the symbol is upper-cased first."""
    normalized_symbol = symbol.upper()
    return StockPriceResponse(**fetch_stock_price(normalized_symbol))

@app.post("/fetch", response_model=FetchResultMeta, status_code=201)
def api_fetch_options(request: FetchOptionsRequest, background_tasks: BackgroundTasks):
    """Fetch and persist the nearest-expiry chain for ``request.index``; return its metadata.

    The fetch runs synchronously inside the request. ``background_tasks`` is
    currently unused and kept only so the endpoint signature stays stable.
    Unexpected failures become a 500 whose detail is the error text.
    """
    idx = _normalize_index_name(request.index)
    try:
        return fetch_and_save_option_chain(idx, request.num_strikes)
    except HTTPException:
        raise
    except Exception as e:
        # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
        logging.getLogger(__name__).exception("api_fetch_options error")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/fetch/expiry", response_model=FetchResultMeta, status_code=201)
def api_fetch_options_expiry(req: FetchExpiryRequest):
    """Fetch and persist the option chain for one specific expiry; return its metadata.

    The 422 raised by ``fetch_specific_expiry_option_chain`` for an unknown
    expiry passes through unchanged. Any other failure becomes a 500 whose
    detail is the error text.
    """
    idx = _normalize_index_name(req.index)
    try:
        return fetch_specific_expiry_option_chain(idx, req.expiry, req.num_strikes)
    except HTTPException:
        raise
    except Exception as e:
        # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
        logging.getLogger(__name__).exception("api_fetch_options_expiry error")
        raise HTTPException(status_code=500, detail=str(e)) from e

def _latest_saved_csv(idx: str) -> Optional[str]:
    """Return the path of the most recently written option-chain CSV for ``idx``, or ``None``.

    Only files in ``OUTPUT_DIR`` named ``<idx lower>_option_chain_*.csv`` are considered.
    """
    prefix = f"{idx.lower()}_option_chain_"
    candidates = [
        os.path.join(OUTPUT_DIR, name)
        for name in os.listdir(OUTPUT_DIR)
        if name.startswith(prefix) and name.endswith('.csv')
    ]
    if not candidates:
        return None
    # Saved names put the expiry text before the timestamp, so a name sort orders
    # by expiry rather than recency. Pick by modification time; the name breaks ties.
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))


def _load_saved_meta(csv_path: str, idx: str) -> Dict[str, Any]:
    """Load the JSON sidecar saved next to ``csv_path`` and fill required defaults.

    A missing sidecar is tolerated. An unreadable or malformed one is logged
    and ignored, so analytics can still be served.
    """
    meta_path = os.path.splitext(csv_path)[0] + '.json'
    meta_obj: Dict[str, Any] = {}
    try:
        with open(meta_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        if isinstance(loaded, dict):
            meta_obj = loaded
    except FileNotFoundError:
        pass
    except (OSError, ValueError):
        logging.getLogger(__name__).warning("Ignoring unreadable metadata file %s", meta_path, exc_info=True)
    meta_obj.setdefault('createdAtUTC', datetime.utcnow().isoformat())
    # FetchResultMeta requires indexName; without a sidecar it would fail validation.
    meta_obj.setdefault('indexName', idx)
    return meta_obj


@app.get("/analytics", response_model=AnalyticsResponse)
def api_analytics_for_latest(index: str = Query(...), limit: int = Query(500, gt=0, le=5000)):
    """
    Read the latest saved CSV for the index and compute analytics.

    Only the first ``limit`` rows of the CSV are used to compute PCR, the top-5
    open-interest strikes and max pain. Responds 404 when no snapshot has been
    saved for the index, and 500 when the CSV cannot be read.
    """
    idx = _normalize_index_name(index)
    csv_path = _latest_saved_csv(idx)
    if csv_path is None:
        raise HTTPException(status_code=404, detail=f"No saved option-chain CSVs found for {idx}")
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        # Use the stdlib logger: FastAPI application objects have no `.logger` attribute.
        logging.getLogger(__name__).exception("Failed to read CSV for analytics")
        raise HTTPException(status_code=500, detail="Failed to read saved CSV") from e
    if limit:
        df = df.head(limit)
    pcr = calculate_pcr(df)
    top_oi = find_high_oi_strikes(df, top_n=5)
    max_pain = calculate_max_pain(df)
    meta = FetchResultMeta(**_load_saved_meta(csv_path, idx))
    return AnalyticsResponse(meta=meta, pcr=pcr, top_oi=top_oi, max_pain=max_pain)

# ---------------------------
# Simple health endpoint
# ---------------------------

@app.get("/health")
def health():
    """Liveness probe: report ``ok`` together with the current UTC time in ISO-8601 form."""
    now_utc = datetime.utcnow().isoformat()
    return {"status": "ok", "time": now_utc}

# ---------------------------
# Run example if run directly
# ---------------------------
if __name__ == "__main__":
    # Development entry point: serve on localhost:8000 with auto-reload.
    # The import string assumes this module lives at app/main.py.
    import uvicorn

    uvicorn.run("app.main:app", reload=True, host="127.0.0.1", port=8000)
