In [None]:
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import requests
from tqdm import tqdm

from utils.match_prediction import RAW_PRO_GAMES_DIR

# Parquet file holding the pro games, located inside RAW_PRO_GAMES_DIR
RAW_PRO_GAMES_FILE = Path(RAW_PRO_GAMES_DIR).joinpath("pro_games.parquet")
# Read the full dataset into memory; the code below works on this DataFrame
df = pd.read_parquet(RAW_PRO_GAMES_FILE)


def get_model_prediction(
    champion_ids: List[int], patch: str, timeout: float = 10.0
) -> Optional[float]:
    """Request a win probability for one draft from the local prediction API.

    Sends a POST to ``http://localhost:8000/predict`` with the champion ids,
    the patch string and a fixed ``numerical_elo`` of 0, and returns the
    ``win_probability`` field of the JSON response.

    Args:
        champion_ids: Champion ids of the game, sent as-is in the payload.
        patch: Patch identifier, sent as-is in the payload.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout a single stalled connection would block the caller
            indefinitely.

    Returns:
        The ``win_probability`` value from the response, or ``None`` if the
        request failed, the server answered with an error status, or the
        response body did not have the expected shape. Failures are reported
        on stdout rather than raised, so a batch run can continue.
    """
    url = "http://localhost:8000/predict"
    payload = {
        "champion_ids": champion_ids,
        "numerical_elo": 0,  # Pro play elo
        "patch": patch,
    }
    headers = {"X-API-Key": "example_token"}

    try:
        response = requests.post(
            url, json=payload, headers=headers, timeout=timeout
        )
        response.raise_for_status()
        return response.json()["win_probability"]
    except (requests.RequestException, KeyError, TypeError, ValueError) as e:
        # RequestException: connection errors, timeouts, HTTP error statuses.
        # ValueError: body is not valid JSON. KeyError / TypeError: JSON was
        # parsed (or the payload was built) but does not have the expected shape.
        print(f"Error getting prediction: {e}")
        return None


# Add or update model prediction columns
def update_model_predictions():
    """Query the model for every game in ``df`` and persist the results.

    Works on the module-level DataFrame ``df``:

    * creates the ``model_prediction`` and ``model_error`` columns (filled
      with ``-1.0``) when they are missing,
    * asks ``get_model_prediction`` for a prediction for each row,
    * stores the prediction and the absolute difference between the actual
      outcome (``team_100_win``) and the prediction,
    * writes the DataFrame back to ``RAW_PRO_GAMES_FILE``.

    Rows whose prediction request fails are left unchanged, so they keep
    either the ``-1.0`` placeholder or the value from a previous run.
    """
    # Initialize columns if they don't exist; -1.0 marks "no prediction yet".
    for column in ("model_prediction", "model_error"):
        if column not in df.columns:
            df[column] = -1.0

    prediction_col = df.columns.get_loc("model_prediction")
    error_col = df.columns.get_loc("model_error")

    # Get predictions for all games
    print("Getting model predictions...")
    # Rows are read and written by position. Mixing index labels with
    # positional access only works for a default 0..n-1 index and otherwise
    # reads the wrong row or raises.
    for pos in tqdm(range(len(df))):
        row = df.iloc[pos]
        patch = f"{row.gameVersionMajorPatch}.{str(row.gameVersionMinorPatch).zfill(2)}"
        prediction = get_model_prediction(row.champion_ids.tolist(), patch)

        if prediction is None:
            # Request failed; keep whatever is already stored for this row.
            continue

        df.iat[pos, prediction_col] = prediction
        # Absolute error between the actual outcome and the prediction
        actual = float(row.team_100_win)
        df.iat[pos, error_col] = abs(actual - prediction)

    # Save the updated dataset
    print("Saving updated dataset...")
    df.to_parquet(RAW_PRO_GAMES_FILE)
    print("Done!")


# Run the update
update_model_predictions()

# Show some statistics
# Rows without a successful prediction still hold the -1.0 placeholder in
# model_error; leave them out so they do not distort the summary.
scored_errors = df.loc[df["model_error"] >= 0, "model_error"]
print("\nModel Error Statistics:")
print(scored_errors.describe())
print(f"\nGames with error > 0.8: {int((scored_errors > 0.8).sum())}")
print(f"Games without a prediction: {len(df) - len(scored_errors)}")