# FDS Challenge Notebook

## 0. Version 2 summary

This version changes the features, and changes the model up. 

Specifically, we wanted to move avway from the calculation of raw damage, as it's not guaranteed to be a fitting estimate of the actual damage, which depends on the drop in hp_pct after attacks.

For example, it could be the case that a move with accuracy = 1.0 may miss because the attacker is paralyzed, or under the effect of status conditions.

Moreover, we refined how we treat the [effects] in player_pokemon_state, by counting the occurrences of each available effect, just like how we did in the previous versions with status conditions. 

We wanted to compare scores on stacked models with different final estimator, so we added a stacked and calibrated classifier with GradientBoost. 
Moreover, we wanted to compare raw predictions with blended predicted probabilities, and indeed, the created files have some lines that differ.

Let's see how this modification changes the score (hopefully, increasing it)

## 1. Loading the Data

In [None]:
import json
import pandas as pd
import os

# --- Define the path to our data ---
COMPETITION_NAME = 'fds-pokemon-battles-prediction-2025'
DATA_PATH = os.path.join('/kaggle/input', COMPETITION_NAME)
train_file_path = os.path.join(DATA_PATH, 'train.jsonl')
test_file_path = os.path.join(DATA_PATH, 'test.jsonl')

def load_data(file_path):
    data = []
    print(f"Loading data from '{file_path}'...")
    try:
        with open(file_path, 'r') as f:
            for line in f:
                data.append(json.loads(line))
        print(f"Successfully loaded {len(data)} battles.")
    except FileNotFoundError:
        print(f"ERROR: Could not find the file at '{file_path}'.")
    return data

train_data = load_data(train_file_path)
test_data = load_data(test_file_path) # Load the test data as well


Loading data from 'train.jsonl'...
Successfully loaded 10000 battles.
Loading data from 'test.jsonl'...
Successfully loaded 5000 battles.


## 2. Complete Pokèmon Dataframe

In [2]:
import pandas as pd

# Gen 1 types in lowercase
types = [
    "normal", "fire", "water", "electric", "grass", "ice", "fighting", "poison",
    "ground", "flying", "psychic", "bug", "rock", "ghost", "dragon", "notype",
]

# Type effectiveness values
type_chart = {
    "normal":   {"rock": 0.5, "ghost": 0.0},
    "fire":     {"grass": 2.0, "ice": 2.0, "bug": 2.0, "rock": 0.5, "fire": 0.5, "water": 0.5},
    "water":    {"fire": 2.0, "ground": 2.0, "rock": 2.0, "water": 0.5, "grass": 0.5},
    "electric": {"water": 2.0, "flying": 2.0, "electric": 0.5, "grass": 0.5, "ground": 0.0},
    "grass":    {"water": 2.0, "ground": 2.0, "rock": 2.0, "fire": 0.5, "grass": 0.5, "flying": 0.5, "bug": 0.5},
    "ice":      {"grass": 2.0, "ground": 2.0, "flying": 2.0, "dragon": 2.0, "fire": 0.5, "water": 0.5},
    "fighting": {"normal": 2.0, "rock": 2.0, "ice": 2.0, "bug": 0.5, "psychic": 0.5, "ghost": 0.0},
    "poison":   {"grass": 2.0, "bug": 2.0, "poison": 0.5, "ground": 0.5, "rock": 0.5, "ghost": 0.5},
    "ground":   {"fire": 2.0, "electric": 2.0, "poison": 2.0, "rock": 2.0, "bug": 0.5, "flying": 0.0},
    "flying":   {"grass": 2.0, "fighting": 2.0, "bug": 2.0, "electric": 0.5, "rock": 0.5},
    "psychic":  {"fighting": 2.0, "poison": 2.0, "psychic": 0.5},
    "bug":      {"grass": 2.0, "poison": 2.0, "psychic": 2.0, "fire": 0.5, "fighting": 0.5, "flying": 0.5, "ghost": 0.5},
    "rock":     {"fire": 2.0, "ice": 2.0, "flying": 2.0, "bug": 2.0, "fighting": 0.5, "ground": 0.5},
    "ghost":    {"psychic": 0.0, "ghost": 2.0, "normal": 0.0},
    "dragon":   {"dragon": 2.0},
    "notype":   {}
}

# Create full chart with default 1.0 (neutral)
df_typechart = pd.DataFrame(index=types, columns=types).fillna(1.0)

# Apply effectiveness values
for attacker, defenders in type_chart.items():
    for defender, value in defenders.items():
        df_typechart.loc[attacker, defender] = value

  df_typechart = pd.DataFrame(index=types, columns=types).fillna(1.0)


In [3]:
print(df_typechart.head())

          normal  fire  water  electric  grass  ice  fighting  poison  ground  \
normal       1.0   1.0    1.0       1.0    1.0  1.0       1.0     1.0     1.0   
fire         1.0   0.5    0.5       1.0    2.0  2.0       1.0     1.0     1.0   
water        1.0   2.0    0.5       1.0    0.5  1.0       1.0     1.0     2.0   
electric     1.0   1.0    2.0       0.5    0.5  1.0       1.0     1.0     0.0   
grass        1.0   0.5    2.0       1.0    0.5  1.0       1.0     1.0     2.0   

          flying  psychic  bug  rock  ghost  dragon  notype  
normal       1.0      1.0  1.0   0.5    0.0     1.0     1.0  
fire         1.0      1.0  2.0   0.5    1.0     1.0     1.0  
water        1.0      1.0  1.0   2.0    1.0     1.0     1.0  
electric     2.0      1.0  1.0   1.0    1.0     1.0     1.0  
grass        0.5      1.0  0.5   2.0    1.0     1.0     1.0  


In [4]:
def compute_stat(base: int, lvl: int, hp_bool: bool = False) -> int:
    max_DVs = 15
    max_EVs = 65535
    square_term = np.sqrt(max_EVs) / 4

    if hp_bool:
        return int(((base + max_DVs) * 2 + square_term) * lvl // 100 + lvl + 10)
    return int(((base + max_DVs) * 2 + square_term) * lvl // 100 + 5)

In [5]:
import json
import pandas as pd

import numpy as np

def compute_stat(base: int, lvl: int, hp_bool: bool = False) -> int:
    """
    Computes the stat value for a Pokémon given its base stat, level, and whether it's HP.
    (formulas: https://www.pokemaniablog.com/2017/11/11/CalculatingHP.html)
    """
    max_DVs = 15
    max_EVs = 65535
    square_term = np.sqrt(max_EVs) / 4

    if hp_bool:
        return int(((base + max_DVs) * 2 + square_term) * lvl // 100 + lvl + 10)
    return int(((base + max_DVs) * 2 + square_term) * lvl // 100 + 5)


def extract_unique_pokemon_no_ids(jsonl_path: str) -> pd.DataFrame:
    """
    Extracts a clean list of unique Pokémon with full base stats from the dataset.
    - Includes Pokémon from p1_team_details, p2_lead_details, and p2_pokemon_state.
    - Removes rows with all-zero stats if the Pokémon appears elsewhere with valid stats.
    - Removes duplicates across battles: only one row per Pokémon name.
    - Drops battle_id column.
    """
    rows = []

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            battle = json.loads(line)

            # --- p1 team Pokémon ---
            for p in battle.get("p1_team_details", []):
                rows.append({
                    "name": p.get("name", "unknown"),
                    "base_hp": p.get("base_hp", 0),
                    "base_atk": p.get("base_atk", 0),
                    "base_def": p.get("base_def", 0),
                    "base_spa": p.get("base_spa", 0),
                    "base_spd": p.get("base_spd", 0),
                    "base_spe": p.get("base_spe", 0),
                    "type_1": p.get("types", "notype")[0],
                    "type_2": p.get("types", "notype")[1],
                    "lvl": p.get("level", 0),
                    "hp": compute_stat(p.get("base_hp", 0), p.get("level", 0), hp_bool=True),
                    "atk": compute_stat(p.get("base_atk", 0), p.get("level", 0)),
                    "def": compute_stat(p.get("base_def", 0), p.get("level", 0)),
                    "spa": compute_stat(p.get("base_spa", 0), p.get("level", 0)),
                    "spd": compute_stat(p.get("base_spd", 0), p.get("level", 0)),
                    "spe": compute_stat(p.get("base_spe", 0), p.get("level", 0)),
                })

            # --- p2 lead details ---
            lead_details = battle.get("p2_lead_details")
            if lead_details:
                rows.append({
                    "name": lead_details.get("name", "unknown"),
                    "base_hp": lead_details.get("base_hp", 0),
                    "base_atk": lead_details.get("base_atk", 0),
                    "base_def": lead_details.get("base_def", 0),
                    "base_spa": lead_details.get("base_spa", 0),
                    "base_spd": lead_details.get("base_spd", 0),
                    "base_spe": lead_details.get("base_spe", 0),
                    "type_1": lead_details.get("types", "notype")[0],
                    "type_2": lead_details.get("types", "notype")[1],
                    "lvl": lead_details.get("level", 0),
                    "hp": compute_stat(lead_details.get("base_hp", 0), lead_details.get("level", 0), hp_bool=True),
                    "atk": compute_stat(lead_details.get("base_atk", 0), lead_details.get("level", 0)),
                    "def": compute_stat(lead_details.get("base_def", 0), lead_details.get("level", 0)),
                    "spa": compute_stat(lead_details.get("base_spa", 0), lead_details.get("level", 0)),
                    "spd": compute_stat(lead_details.get("base_spd", 0), lead_details.get("level", 0)),
                    "spe": compute_stat(lead_details.get("base_spe", 0), lead_details.get("level", 0)),
                })

            # --- p2 team Pokémon from timeline (unique per battle) ---
            seen = set()
            for turn in battle.get("battle_timeline", []):
                p2 = turn.get("p2_pokemon_state")
                if not p2:
                    continue
                name = p2.get("name", "unknown")
                if name in seen:
                    continue
                seen.add(name)
                rows.append({
                    "name": name,
                    "base_hp": p2.get("base_hp", 0),
                    "base_atk": p2.get("base_atk", 0),
                    "base_def": p2.get("base_def", 0),
                    "base_spa": p2.get("base_spa", 0),
                    "base_spd": p2.get("base_spd", 0),
                    "base_spe": p2.get("base_spe", 0),
                    "type_1": p2.get("types", "notype")[0],
                    "type_2": p2.get("types", "notype")[1],
                    "lvl": p2.get("level", 0),
                    "hp": compute_stat(p.get("base_hp", 0), p.get("level", 0), hp_bool=True),
                    "atk": compute_stat(p.get("base_atk", 0), p.get("level", 0)),
                    "def": compute_stat(p.get("base_def", 0), p.get("level", 0)),
                    "spa": compute_stat(p.get("base_spa", 0), p.get("level", 0)),
                    "spd": compute_stat(p.get("base_spd", 0), p.get("level", 0)),
                    "spe": compute_stat(p.get("base_spe", 0), p.get("level", 0)),
                })

    df = pd.DataFrame(rows)
    

    # --- Remove zero-stat rows if name appears elsewhere with valid stats ---
    stat_cols = ["base_hp", "base_atk", "base_def", "base_spa", "base_spd", "base_spe", "lvl"]
    zero_mask = (df[stat_cols] == 0).all(axis=1)
    valid_names = set(df.loc[~zero_mask, "name"])
    df = df.loc[~(zero_mask & df["name"].isin(valid_names))]

    # --- Drop duplicates: keep only one row per Pokémon name with level 100 ---
    df = df.sort_values(by=["name", "lvl"], ascending=[True, False])
    df = df.drop_duplicates(subset=["name"], keep="first")
    df = df.reset_index(drop=True)

    return df


In [6]:
pokemon_df_train = extract_unique_pokemon_no_ids(train_file_path)

In [7]:
print("\n All Pokémon entries:")
display(pokemon_df_train)


 All Pokémon entries:


Unnamed: 0,name,base_hp,base_atk,base_def,base_spa,base_spd,base_spe,type_1,type_2,lvl,hp,atk,def,spa,spd,spe
0,alakazam,55,50,45,135,135,120,notype,psychic,100,313,198,188,368,368,338
1,articuno,90,85,100,125,125,85,flying,ice,100,383,268,298,348,348,268
2,chansey,250,5,5,105,105,50,normal,notype,100,703,108,108,308,308,198
3,charizard,78,84,78,85,85,100,fire,flying,100,359,266,254,268,268,298
4,cloyster,50,95,180,85,85,70,ice,water,100,303,288,458,268,268,238
5,dragonite,91,134,95,100,100,80,dragon,flying,100,385,366,288,298,298,258
6,exeggutor,95,95,85,125,125,55,grass,psychic,100,393,288,268,348,348,208
7,gengar,60,65,60,130,130,110,ghost,poison,100,323,228,218,358,358,318
8,golem,80,110,130,55,55,45,ground,rock,100,363,318,358,208,208,188
9,jolteon,65,65,60,110,110,130,electric,notype,100,333,228,218,318,318,358


In [8]:
import json
import pandas as pd

def extract_pokemon_in_play(jsonl_path: str) -> pd.DataFrame:
    """
    Extracts the list of Pokémon seen in each battle for both players.
    """
    rows = []

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            battle = json.loads(line)
            battle_id = battle.get("battle_id", "unknown")

            # --- p1 seen Pokémon ---
            p1_seen = []
            p1_seen_set = set()
            for turn in battle.get("battle_timeline", []):
                p1_state = turn.get("p1_pokemon_state")
                if not p1_state:
                    continue
                name = p1_state.get("name", "unknown")
                if name not in p1_seen_set:
                    p1_seen_set.add(name)
                    p1_seen.append(name)

            # --- p2 seen Pokémon ---
            p2_seen = []
            p2_seen_set = set()
            for turn in battle.get("battle_timeline", []):
                p2_state = turn.get("p2_pokemon_state")
                if not p2_state:
                    continue
                name = p2_state.get("name", "unknown")
                if name not in p2_seen_set:
                    p2_seen_set.add(name)
                    p2_seen.append(name)

            # --- Add p1 row ---
            p1_row = {
                "battle_id": battle_id,
                "player": "p1",
                "num_seen_pokemon": len(p1_seen)
            }
            for i in range(6):
                p1_row[f"pokemon_{i+1}"] = p1_seen[i] if i < len(p1_seen) else ""
            rows.append(p1_row)

            # --- Add p2 row ---
            p2_row = {
                "battle_id": battle_id,
                "player": "p2",
                "num_seen_pokemon": len(p2_seen)
            }
            for i in range(6):
                p2_row[f"pokemon_{i+1}"] = p2_seen[i] if i < len(p2_seen) else ""
            rows.append(p2_row)

    return pd.DataFrame(rows)


In [9]:
seen_pokemons = extract_pokemon_in_play(train_file_path)
display(seen_pokemons.head())

Unnamed: 0,battle_id,player,num_seen_pokemon,pokemon_1,pokemon_2,pokemon_3,pokemon_4,pokemon_5,pokemon_6
0,0,p1,4,starmie,exeggutor,chansey,snorlax,,
1,0,p2,4,exeggutor,starmie,snorlax,chansey,,
2,1,p1,6,jynx,snorlax,exeggutor,tauros,chansey,slowbro
3,1,p2,6,alakazam,chansey,snorlax,exeggutor,starmie,tauros
4,2,p1,3,exeggutor,snorlax,chansey,,,


In [10]:
#create a function that, for every turn, checks the 'status' field of p1 and p2 pokemon_state and extracts any status conditions (like 'par', 'slp', etc.) that are present.
def extract_triggered_statuses(jsonl_path: str) -> pd.DataFrame:
    """
    Extracts status conditions triggered during battles from the dataset.
    - For each turn, checks the 'status' field of p1 and p2 pokemon_state.
    - Records and extracts any status conditions.
    """
    data = set()

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            battle = json.loads(line)
            timeline = battle.get("battle_timeline", [])

            for turn in timeline:
                for player_key in ["p1", "p2"]:
                    pokemon_state = turn.get(f"{player_key}_pokemon_state", {})
                    status = pokemon_state.get("status")

                    if status is not None:
                        data.add(status)

    return list(data)


In [11]:
extract_triggered_statuses(train_file_path)

['psn', 'nostatus', 'par', 'fnt', 'frz', 'tox', 'slp', 'brn']

In [12]:
def extract_triggered_effects(jsonl_path: str) -> pd.DataFrame:
    """
    Extracts unique triggered effects during battles from the dataset.
    - For each turn, checks the 'effects' field of p1 and p2 pokemon_state.
    - Records unique effect names.
    """
    data = set()

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            battle = json.loads(line)
            timeline = battle.get("battle_timeline", [])

            for turn in timeline:
                for player_key in ["p1", "p2"]:
                    pokemon_state = turn.get(f"{player_key}_pokemon_state", {})
                    effects_triggered = pokemon_state.get("effects", [])

                    for effect in effects_triggered:
                        data.add(effect)

    return list(data)

In [13]:
extract_triggered_effects(train_file_path)

['noeffect',
 'substitute',
 'clamp',
 'firespin',
 'confusion',
 'wrap',
 'typechange',
 'reflect']

In [14]:
import math


def make_moves_df(jsonl_path: str, pokemon_df: pd.DataFrame, typechart: pd.DataFrame, verbose: bool) -> pd.DataFrame:
    import json
    import pandas as pd

    move_rows = []

    #Boosts and relative multipliers available for atk, def, spa, spd, spe
    boost_multipliers = {
        -6: 0.25,
        -5: 0.28,
        -4: 0.33,
        -3: 0.4,
        -2: 0.5,
        -1: 0.66,
        0: 1.0,
        1: 1.5,
        2: 2.0,
        3: 2.5,
        4: 3.0,
        5: 3.5,
        6: 4.0,
    }

    # Normalize Pokémon data
    pokemon_df_copy = pokemon_df.copy()
    pokemon_df_copy.columns = pokemon_df_copy.columns.str.lower().str.strip()

    if "name" in pokemon_df_copy.columns:
        pokemon_df_copy["name"] = pokemon_df_copy["name"].str.lower().str.strip()
        pokemon_df_copy.set_index("name", inplace=True)
    else:
        raise ValueError(f"'name' column missing. Available columns: {pokemon_df_copy.columns.tolist()}")

    # Parse battles
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            battle = json.loads(line)
            battle_id = battle.get("battle_id")
            timeline = battle.get("battle_timeline", [])

            for turn_index, turn in enumerate(timeline):
                turn_data = {}
                turn_moves = []

                for side in ["p1", "p2"]:
                    opponent = "p2" if side == "p1" else "p1"
                    move = turn.get(f"{side}_move_details")
                    if not move:
                        continue

                    atk_boosts = turn.get(f"{side}_pokemon_state").get("boosts", {
                    "atk": 0,
                    "def": 0,
                    "spa": 0,
                    "spd": 0,
                    "spe": 0
                })

                    def_boosts = turn.get(f"{opponent}_pokemon_state").get("boosts", {
                    "atk": 0,
                    "def": 0,
                    "spa": 0,
                    "spd": 0,
                    "spe": 0
                })

                    atk_pk_name = turn.get(f"{side}_pokemon_state").get("name").lower().strip()
                    def_pk_name = turn.get(f"{opponent}_pokemon_state").get("name").lower().strip()

                    atk_spe = pokemon_df_copy.loc[atk_pk_name, "spe"]
                    def_spe = pokemon_df_copy.loc[def_pk_name, "spe"]

                    def_hp_pct = turn.get(f"{opponent}_pokemon_state").get("hp_pct")

                    def_t1 = pokemon_df_copy.loc[def_pk_name, "type_1"]
                    def_t2 = pokemon_df_copy.loc[def_pk_name, "type_2"]
                    def_stat = pokemon_df_copy.loc[def_pk_name, "def"] if move.get("category").lower() == "physical" else pokemon_df_copy.loc[def_pk_name, "spd"]

                    atk_t1 = pokemon_df_copy.loc[atk_pk_name, "type_1"]
                    atk_t2 = pokemon_df_copy.loc[atk_pk_name, "type_2"]
                    atk_stat = pokemon_df_copy.loc[atk_pk_name, "atk"] if move.get("category").lower() == "physical" else pokemon_df_copy.loc[atk_pk_name, "spa"]

                    
                    if move.get("category").lower() == "physical":
                        def_stat = def_stat * boost_multipliers.get(def_boosts.get("def"), 1.0)
                        atk_stat = atk_stat * boost_multipliers.get(atk_boosts.get("atk"), 1.0)
                    elif move.get("category").lower() == "special":
                        def_stat = def_stat * boost_multipliers.get(def_boosts.get("spd"), 1.0)
                        atk_stat = atk_stat * boost_multipliers.get(atk_boosts.get("spa"), 1.0)

                    atk_spe = atk_spe * boost_multipliers.get(atk_boosts.get("spe"), 1.0)
                    def_spe = def_spe * boost_multipliers.get(def_boosts.get("spe"), 1.0)

                    if move.get("category").lower() == "physical" and move.get("name").lower() == "reflect":
                        def_stat *= 2
                    elif move.get("category").lower() == "special" and move.get("name").lower() == "lightscreen":
                        def_stat *= 2

                    # --- Step 2: Explosion/Selfdestruct halving ---
                    if move.get("name").lower() in ["explosion", "selfdestruct"]:
                        def_stat = max(1, def_stat // 2)

                    # --- Step 3: Clamp if A or D > 255 ---
                    if atk_stat > 255 or def_stat > 255:
                        atk_stat = math.floor(atk_stat / 4)
                        def_stat = max(1, math.floor(def_stat / 4))
                    
                    move_mul = (typechart.loc[move.get("type").lower(), def_t1] * 
                                typechart.loc[move.get("type").lower(), def_t2]) if move.get("category").lower() != "status" else 1.0

                    turn_data[side] = {
                        "priority": move.get("priority", 0),
                        "speed": atk_spe
                    }

                    turn_moves.append({
                        "battle_id": battle_id,
                        "turn": turn_index,
                        "attacker": side,
                        "atk_pk": atk_pk_name,
                        "atk_t1": atk_t1,
                        "atk_t2": atk_t2,
                        "name": move.get("name"),
                        "move_type": move.get("type").lower(),
                        "category": move.get("category").lower(),
                        "base_power": move.get("base_power"),
                        "accuracy": move.get("accuracy"),
                        "priority": move.get("priority"),
                        "defender": opponent,
                        "def_pk": def_pk_name,
                        "def_t1": def_t1,
                        "def_t2": def_t2,
                        "stab": 1 if move.get("type").lower() in [atk_t1, atk_t2] else 0,

                        "se_move": 1 if move_mul == 2.0 else 0,
                        "pe_move": 1 if move_mul == 0.5 else 0,
                        "ne_move": 1 if move_mul == 0.0 else 0,
                        "ko": 1 if def_hp_pct == 0.0 else 0,

                        # ['fnt', 'brn', 'par', 'psn', 'slp', 'frz', 'tox']
                        "def_nostatus": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "nostatus" else 0,
                        "def_fnt": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "fnt" else 0,
                        "def_par": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "par" else 0,
                        "def_slp": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "slp" else 0,
                        "def_frz": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "frz" else 0,
                        "def_brn": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "brn" else 0,
                        "def_psn": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "psn" else 0,
                        "def_tox": 1 if turn.get(f"{opponent}_pokemon_state").get("status") == "tox" else 0,

                        "atk_nostatus": 1 if turn.get(f"{side}_pokemon_state").get("status") == "nostatus" else 0,
                        "atk_fnt": 1 if turn.get(f"{side}_pokemon_state").get("status") == "fnt" else 0,
                        "atk_par": 1 if turn.get(f"{side}_pokemon_state").get("status") == "par" else 0,
                        "atk_slp": 1 if turn.get(f"{side}_pokemon_state").get("status") == "slp" else 0,
                        "atk_frz": 1 if turn.get(f"{side}_pokemon_state").get("status") == "frz" else 0,
                        "atk_brn": 1 if turn.get(f"{side}_pokemon_state").get("status") == "brn" else 0,
                        "atk_psn": 1 if turn.get(f"{side}_pokemon_state").get("status") == "psn" else 0,
                        "atk_tox": 1 if turn.get(f"{side}_pokemon_state").get("status") == "tox" else 0,

                        "atk_advantage": 1 if ((
                            ((typechart.loc[atk_t1, def_t1] * typechart.loc[atk_t1, def_t2] >= 2.0) or (typechart.loc[atk_t2, def_t1] * typechart.loc[atk_t2, def_t2] >= 2.0)) or
                            ((0.0 <= typechart.loc[def_t1, atk_t1] * typechart.loc[def_t1, atk_t2] <= 0.5) or 
                             (0.0 <= typechart.loc[def_t2, atk_t1] * typechart.loc[def_t2, atk_t2] <= 0.5))) and move.get("category").lower() != "status"
                        ) else 0,
                        

                        # ['clamp', 'typechange', 'confusion', 'wrap', 'substitute', 'firespin', 'noeffect', 'reflect']
                        "atk_clamp": 1 if any(x in ["clamp"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_tc": 1 if any(x in ["typechange"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_confusion": 1 if any(x in ["confusion"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_wrap": 1 if any(x in ["wrap"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_reflect": 1 if any(x in ["reflect"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_substitute": 1 if any(x in ["substitute"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_firespin": 1 if any(x in ["firespin"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,
                        "atk_noeffect": 1 if any(x in ["noeffect"] for x in turn.get(f"{side}_pokemon_state").get("effects", [])) else 0,

                        "def_clamp": 1 if any(x in ["clamp"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_tc": 1 if any(x in ["typechange"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_confusion": 1 if any(x in ["confusion"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_wrap": 1 if any(x in ["wrap"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_reflect": 1 if any(x in ["reflect"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_substitute": 1 if any(x in ["substitute"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_firespin": 1 if any(x in ["firespin"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        "def_noeffect": 1 if any(x in ["noeffect"] for x in turn.get(f"{opponent}_pokemon_state").get("effects", [])) else 0,
                        
                        f"boosted_{side}_atk": atk_boosts.get("atk", 0),
                        f"boosted_{side}_def": atk_boosts.get("def", 0),
                        f"boosted_{side}_spa": atk_boosts.get("spa", 0),
                        f"boosted_{side}_spd": atk_boosts.get("spd", 0),
                        f"boosted_{side}_spe": atk_boosts.get("spe", 0),
                        f"boosted_{opponent}_atk": def_boosts.get("atk", 0),
                        f"boosted_{opponent}_def": def_boosts.get("def", 0),
                        f"boosted_{opponent}_spa": def_boosts.get("spa", 0),
                        f"boosted_{opponent}_spd": def_boosts.get("spd", 0),
                        f"boosted_{opponent}_spe": def_boosts.get("spe", 0),
                        "atk_hp_pct": turn.get(f"{side}_pokemon_state").get("hp_pct"),
                        "def_hp_pct": turn.get(f"{opponent}_pokemon_state").get("hp_pct"),
                    })                                                     

                # Decide who attacks first
                p1 = turn_data.get("p1", {"priority": 0, "speed": 0})
                p2 = turn_data.get("p2", {"priority": 0, "speed": 0})

                if p1["priority"] > p2["priority"]:
                    first = "p1"
                elif p2["priority"] > p1["priority"]:
                    first = "p2"
                else:
                    first = "p1" if p1["speed"] > p2["speed"] else "p2" if p2["speed"] > p1["speed"] else "tie"

                # Assign first attacker to all moves in this turn
                for row in turn_moves:
                    row["first_attacker"] = first
                    move_rows.append(row)

    # Create DataFrame
    moves_df = pd.DataFrame(move_rows)
    moves_df["name"] = moves_df["name"].str.lower().str.strip()

    # --- Check for duplicates and NaN values ---

    if verbose:
        print("Checking for duplicates and NaN values...")

        if moves_df.columns.duplicated().any():
            print("Duplicate columns found:")
            print(moves_df[moves_df.columns[moves_df.columns.duplicated(keep=False)]])

        if moves_df.isnull().values.any():
            print("NaN values found:")
            print(moves_df[moves_df.isnull().any(axis=1)])

    return moves_df


In [15]:
moves_df_train = make_moves_df(train_file_path, pokemon_df_train, df_typechart, verbose=True)

Checking for duplicates and NaN values...


In [16]:
display(moves_df_train[[def_col for def_col in moves_df_train.columns if def_col.startswith("atk_")]].head())

Unnamed: 0,atk_pk,atk_t1,atk_t2,atk_nostatus,atk_fnt,atk_par,atk_slp,atk_frz,atk_brn,atk_psn,...,atk_advantage,atk_clamp,atk_tc,atk_confusion,atk_wrap,atk_reflect,atk_substitute,atk_firespin,atk_noeffect,atk_hp_pct
0,starmie,psychic,water,1,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,1,1.0
1,exeggutor,grass,psychic,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0.221374
2,starmie,psychic,water,1,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,1,1.0
3,starmie,psychic,water,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,1.0
4,chansey,normal,notype,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0.876245


In [17]:
def compute_actual_damage(moves_df: pd.DataFrame, pokemon_df: pd.DataFrame) -> pd.DataFrame:
    """
    Computes actual damage dealt for each move by comparing defender's HP before and after the move.
    Tracks switch-ins and fills in first_attacker using priority and speed comparisons.
    Uses 'spe' from pokemon_df for speed.
    Adds:
        - hp_pct_drop: actual drop in HP percentage from previous turn
        - p1_switched: 1 if p1 switched in this turn
        - p2_switched: 1 if p2 switched in this turn
        - first_attacker: player who attacked first in this turn
    """
    moves_df_c = moves_df.copy()
    moves_df_c["hp_pct_drop"] = 0.0
    moves_df_c["p1_switched"] = 0
    moves_df_c["p2_switched"] = 0
    moves_df_c["first_attacker"] = ""

    moves_df_c.sort_values(by=["battle_id", "turn"], inplace=True)
    last_hp = {}

    for (battle_id, turn), turn_df in moves_df_c.groupby(["battle_id", "turn"]):
        p1_switched = 0
        p2_switched = 0
        first_attacker = ""

        if len(turn_df) == 1:
            attacker = turn_df.iloc[0]["attacker"]
            switched_player = "p2" if attacker == "p1" else "p1"
            if switched_player == "p1":
                p1_switched = 1
            else:
                p2_switched = 1
            first_attacker = attacker
        else:
            # Both players attacked — compare priority first
            p1_row = turn_df[turn_df["attacker"] == "p1"].iloc[0]
            p2_row = turn_df[turn_df["attacker"] == "p2"].iloc[0]

            p1_priority = p1_row["priority"]
            p2_priority = p2_row["priority"]

            if p1_priority > p2_priority:
                first_attacker = "p1"
            elif p2_priority > p1_priority:
                first_attacker = "p2"
            else:
                # Equal priority — compare speed using 'spe'
                p1_speed = pokemon_df.loc[
                    (pokemon_df["name"] == p1_row["atk_pk"]),
                    "spe"
                ].values
                p2_speed = pokemon_df.loc[
                    (pokemon_df["name"] == p2_row["atk_pk"]),
                    "spe"
                ].values

                p1_speed = p1_speed[0] if len(p1_speed) > 0 else 0
                p2_speed = p2_speed[0] if len(p2_speed) > 0 else 0

                if p1_speed > p2_speed:
                    first_attacker = "p1"
                elif p2_speed > p1_speed:
                    first_attacker = "p2"
                else:
                    first_attacker = "p1"  # default fallback

        for idx, row in turn_df.iterrows():
            player_key = (row["battle_id"], row["defender"], row["def_pk"])
            hp_now = row["def_hp_pct"]
            hp_prev = last_hp.get(player_key, 1.0)
            damage = max(0.0, hp_prev - hp_now)

            moves_df_c.at[idx, "hp_pct_drop"] = damage
            moves_df_c.at[idx, "p1_switched"] = p1_switched
            moves_df_c.at[idx, "p2_switched"] = p2_switched
            moves_df_c.at[idx, "first_attacker"] = first_attacker

            last_hp[player_key] = hp_now

    # Final check for missing values
    if moves_df_c.isnull().values.any():
        print("NaN values found:")
        print(moves_df_c[moves_df_c.isnull().any(axis=1)])

    if moves_df_c.duplicated().any():
        print("Duplicate rows found:")
        print(moves_df_c[moves_df_c.duplicated(keep=False)])

    return moves_df_c


In [18]:
moves_df_train = compute_actual_damage(moves_df_train, pokemon_df_train)

## 3. Feature engineering (finally)

In [19]:
def compute_category_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    counts = (
        moves_df.groupby(["battle_id", "attacker", "category"])
        .size()
        .unstack(fill_value=0)
        .rename(columns={
            "status": "num_status_moves",
            "physical": "num_physical_moves",
            "special": "num_special_moves"
        })
        .reset_index()
    )

    pivoted = (
        counts.pivot(index="battle_id", columns="attacker")
        .sort_index(axis=1)
        .reset_index()
    )

    # Fill missing values with 0 (it means no moves of that category were used)
    pivoted = pivoted.fillna(0)

    pivot_cols = pivoted.columns.drop("battle_id")
    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivot_cols]

    # Add null moves as 30 - (phys + special + status)
    pivoted["p1_null_moves"] = (
        30 - (
            pivoted.get("p1_num_physical_moves", 0) +
            pivoted.get("p1_num_special_moves", 0) +
            pivoted.get("p1_num_status_moves", 0)
        )
    )
    pivoted["p2_null_moves"] = (
        30 - (
            pivoted.get("p2_num_physical_moves", 0) +
            pivoted.get("p2_num_special_moves", 0) +
            pivoted.get("p2_num_status_moves", 0)
        )
    )

    # Check nan values
    if pivoted.isnull().values.any():
        print("NaN values found in category counts:")
        print(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted


In [20]:
def compute_boosts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    for stat in ["atk", "def", "spa", "spd", "spe"]:
        df[f"boosted_{stat}"] = df.apply(
            lambda row: row[f"boosted_p1_{stat}"] if row["player"] == "p1" else row[f"boosted_p2_{stat}"],
            axis=1
        )

    # Group and sum boosts per player
    boosts_summary = (
        df.groupby(["battle_id", "player"])[[f"boosted_{stat}" for stat in ["atk", "def", "spa", "spd", "spe"]]]
        .sum()
        .reset_index()
        .rename(columns={f"boosted_{stat}": f"total_boosts_{stat}" for stat in ["atk", "def", "spa", "spd", "spe"]})
    )

    pivoted = (
        boosts_summary.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted = pivoted.fillna(0)

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check nan values
    if pivoted.isnull().values.any():
        print("NaN values found in boosts summary:")
        print(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted


In [21]:
def compute_first_attacker_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    first_attacks = (
        df[df["first_attacker"] == df["player"]]
        .groupby(["battle_id", "player"])
        .size()
        .reset_index()
        .rename(columns={0: "num_first_attacks"})
    )

    pivoted = (
        first_attacks.pivot(index="battle_id", columns="player", values="num_first_attacks")
        .reset_index()
        .fillna(0)
        .rename(columns={
            "p1": "p1_num_first_attacks",
            "p2": "p2_num_first_attacks"
        })
    )

    pivoted = pivoted.fillna(0) # Fill missing values with 0, would mean no (first) attacks

    # Check nan values
    if pivoted.isnull().values.any():
        print("NaN values found in first attacker counts:")
        print(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted


In [22]:
def compute_hit_miss_stats(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    # --- Hits: hp_pct_drop > 0 ---
    hits = (
        df[df["hp_pct_drop"] > 0]
        .groupby(["battle_id", "player"])
        .size()
        .reset_index()
        .rename(columns={0: "num_hits"})
    )

    # --- Misses: hp_pct_drop == 0 and category is physical/special ---
    misses = (
        df[(df["hp_pct_drop"] == 0) & (df["category"].isin(["physical", "special"]))]
        .groupby(["battle_id", "player"])
        .size()
        .reset_index()
        .rename(columns={0: "num_misses"})
    )

    # --- Accurate misses: accuracy == 1.0 and still missed ---
    acc_misses = (
        df[(df["hp_pct_drop"] == 0) & (df["accuracy"] == 1.0) & (df["category"].isin(["physical", "special"]))]
        .groupby(["battle_id", "player"])
        .size()
        .reset_index()
        .rename(columns={0: "num_misses_on_accurate_moves"})
    )

    merged = (
        hits
        .merge(misses, on=["battle_id", "player"], how="outer")
        .merge(acc_misses, on=["battle_id", "player"], how="outer")
        .fillna(0)
    )

    pivoted = (
        merged.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted = pivoted.fillna(0) # Fill missing values with 0, would mean no hits/misses (screening for battles where a player had no moves)

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check NaN values
    if pivoted.isnull().values.any():
        print("NaN values found in hit/miss stats:")
        display(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted


In [23]:
def compute_in_play_mean_var(pokemon_df: pd.DataFrame, teams_df: pd.DataFrame) -> pd.DataFrame:
    stat_cols = ["hp", "atk", "def", "spa", "spd", "spe"]

    long_teams = teams_df.melt(
        id_vars=["battle_id", "player", "num_seen_pokemon"],
        value_vars=[f"pokemon_{i+1}" for i in range(6)],
        var_name="slot",
        value_name="name"
    )
    long_teams["name"] = long_teams["name"].str.lower().str.strip()
    long_teams = long_teams[long_teams["name"] != ""]

    # Ensure 'name' is the index in pokemon_df
    if "name" in pokemon_df.columns:
        pokemon_df["name"] = pokemon_df["name"].astype(str).str.lower().str.strip()
        pokemon_df = pokemon_df.set_index("name")

    # Normalize team names
    long_teams["name"] = long_teams["name"].astype(str).str.lower().str.strip()

    long_teams = long_teams.merge(pokemon_df[stat_cols], left_on="name", right_index=True, how="left")
    long_teams[stat_cols] = long_teams[stat_cols].fillna(0)

    # Compute mean and variance per player per battle
    mean_df = (
        long_teams.groupby(["battle_id", "player"])[stat_cols]
        .mean()
        .reset_index()
        .rename(columns={col: f"mean_{col}" for col in stat_cols})
    )

    var_df = (
        long_teams.groupby(["battle_id", "player"])[stat_cols]
        .var()
        .reset_index()
        .rename(columns={col: f"var_{col}" for col in stat_cols})
    )

    merged = mean_df.merge(var_df, on=["battle_id", "player"], how="outer").fillna(0)

    pivoted = (
        merged.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check NaN values
    if pivoted.isnull().values.any():
        print("NaN values found in in-play mean/var stats:")
        display(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted


In [24]:
def compute_effectiveness_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    df["super_effective"] = df["se_move"]
    df["not_effective"] = df["pe_move"]
    df["neutral"] = 1 - (df["se_move"] + df["pe_move"] + df["ne_move"])

    eff_counts = (
        df.groupby(["battle_id", "player"])[["super_effective", "neutral", "not_effective"]]
        .sum()
        .reset_index()
        .rename(columns={
            "super_effective": "num_super_effective",
            "neutral": "num_neutral",
            "not_effective": "num_not_effective"
        })
    )

    pivoted = (
        eff_counts.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted = pivoted.fillna(0) # Fill missing values with 0, would mean no moves from any player

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check NaN values
    if pivoted.isnull().values.any():
        print("NaN values found in effectiveness counts:")
        print(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted


In [25]:
def compute_switch_and_regen_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    # Assign switch flag from p1_switched / p2_switched
    df["is_switch"] = df.apply(
        lambda row: row["p1_switched"] if row["player"] == "p1" else row["p2_switched"],
        axis=1
    )

    # Sort and track previous HP percentage
    df = df.sort_values(by=["battle_id", "turn"])
    df["prev_atk_hp_pct"] = df.groupby(["battle_id", "player"])["atk_hp_pct"].shift(1)

    # Regeneration = HP increased without switching
    df["has_regenerated"] = (
        (df["atk_hp_pct"] > df["prev_atk_hp_pct"]) &
        (df["prev_atk_hp_pct"].notna()) &
        (~df["is_switch"])
    )

    switch_counts = (
        df.groupby("battle_id")[["p1_switched", "p2_switched"]]
        .sum()
        .reset_index()
        .rename(columns={
            "p1_switched": "p1_num_switches",
            "p2_switched": "p2_num_switches"
        })
    )

    regen_counts = (
        df.groupby(["battle_id", "player"])["has_regenerated"]
        .sum()
        .reset_index()
        .rename(columns={"has_regenerated": "num_regenerations"})
    )

    regen_pivot = (
        regen_counts.pivot(index="battle_id", columns="player", values="num_regenerations")
        .reset_index()
        .fillna(0)
        .rename(columns={
            "p1": "p1_num_regenerations",
            "p2": "p2_num_regenerations"
        })
    )

    final = switch_counts.merge(regen_pivot, on="battle_id", how="outer")

    # check NaN values
    if final.isnull().values.any():
        print("NaN values found in switch and regeneration counts:")
        print(final[final.isnull().any(axis=1)])

    return final

In [26]:
def compute_status_infliction_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    status_types = ["frz", "par", "slp", "brn", "psn", "fnt", "tox", "nostatus"]
    status_frames = []

    for status in status_types:
        status_counts = (
            df.groupby(["battle_id", "player"])[f"def_{status}"]
            .sum()
            .reset_index()
            .rename(columns={f"def_{status}": f"num_{status}"})
        )
        status_frames.append(status_counts)

    merged = status_frames[0]
    for frame in status_frames[1:]:
        merged = merged.merge(frame, on=["battle_id", "player"], how="outer")

    pivoted = (
        merged.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted = pivoted.fillna(0) # Fill missing values with 0, would mean no status infllicted or all null moves

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check NaN values
    if pivoted.isnull().values.any():
        print("NaN values found in status infliction counts:")
        print(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted

In [27]:
def compute_effect_application_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    inflicted_effects = ["clamp", "tc", "confusion", "wrap", "firespin"]
    self_effects = ["reflect", "substitute", "noeffect"]

    effect_frames = []

    # --- Inflicted effects (have effect on opponent) ---
    for effect in inflicted_effects:
        effect_counts = (
            df.groupby(["battle_id", "player"])[f"def_{effect}"]
            .sum()
            .reset_index()
            .rename(columns={f"def_{effect}": f"num_{effect}_inflicted"})
        )
        effect_frames.append(effect_counts)

    # --- Self effects (have effect on attacker) ---
    for effect in self_effects:
        effect_counts = (
            df.groupby(["battle_id", "player"])[f"atk_{effect}"]
            .sum()
            .reset_index()
            .rename(columns={f"atk_{effect}": f"num_{effect}_applied"})
        )
        effect_frames.append(effect_counts)

    merged = effect_frames[0]
    for frame in effect_frames[1:]:
        merged = merged.merge(frame, on=["battle_id", "player"], how="outer")

    pivoted = (
        merged.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted = pivoted.fillna(0) # Fill missing values with 0, would mean no move that applied effect

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check NaN values
    if pivoted.isnull().values.any():
        print("NaN values found in effect application counts:")
        display(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted

In [28]:
def compute_ko_advantage_stab_counts(moves_df: pd.DataFrame) -> pd.DataFrame:
    df = moves_df.copy()
    df["player"] = df["attacker"]

    # --- Total KOs inflicted ---
    ko_df = (
        df.groupby(["battle_id", "player"])["ko"]
        .sum()
        .reset_index()
        .rename(columns={"ko": "ko_count"})
    )

    # --- Type advantage moves used ---
    adv_df = (
        df.groupby(["battle_id", "player"])["atk_advantage"]
        .sum()
        .reset_index()
        .rename(columns={"atk_advantage": "num_advantage_moves"})
    )

    # --- STAB moves used ---
    stab_df = (
        df.groupby(["battle_id", "player"])["stab"]
        .sum()
        .reset_index()
        .rename(columns={"stab": "num_stab_moves"})
    )

    merged = (
        ko_df
        .merge(adv_df, on=["battle_id", "player"], how="outer")
        .merge(stab_df, on=["battle_id", "player"], how="outer")
        .fillna(0)
    )

    pivoted = (
        merged.pivot(index="battle_id", columns="player")
        .sort_index(axis=1)
        .reset_index()
    )

    pivoted = pivoted.fillna(0) # Fill missing values with 0, would mean no KOs, advantage moves, or STAB moves (because no moves used)

    pivoted.columns = ["battle_id"] + [f"{player}_{feature}" for feature, player in pivoted.columns[1:]]

    # Check NaN values
    if pivoted.isnull().values.any():
        print("NaN values found in KO, advantage, and STAB counts:")
        display(pivoted[pivoted.isnull().any(axis=1)])

    return pivoted

In [29]:
def add_stat_and_behavior_diffs(battle_df: pd.DataFrame) -> pd.DataFrame:
    df = battle_df.copy()

    # --- Mean stat differences ---
    df["mean_spe_diff"] = df["p1_mean_spe"] - df["p2_mean_spe"]
    df["mean_atk_diff"] = df["p1_mean_atk"] - df["p2_mean_atk"]
    df["mean_def_diff"] = df["p1_mean_def"] - df["p2_mean_def"]
    df["mean_spa_diff"] = df["p1_mean_spa"] - df["p2_mean_spa"]
    df["mean_spd_diff"] = df["p1_mean_spd"] - df["p2_mean_spd"]
    df["mean_hp_diff"] = df["p1_mean_hp"] - df["p2_mean_hp"]

    # --- Variance stat differences ---
    df["var_atk_diff"] = df["p1_var_atk"] - df["p2_var_atk"]
    df["var_def_diff"] = df["p1_var_def"] - df["p2_var_def"]
    df["var_spa_diff"] = df["p1_var_spa"] - df["p2_var_spa"]
    df["var_spd_diff"] = df["p1_var_spd"] - df["p2_var_spd"]
    df["var_spe_diff"] = df["p1_var_spe"] - df["p2_var_spe"]
    df["var_hp_diff"] = df["p1_var_hp"] - df["p2_var_hp"]

    # --- Behavioral differences ---
    df["diff_num_switches"] = df["p1_num_switches"] - df["p2_num_switches"]

    ## Check NaN values
    if df.isnull().values.any():
        print("NaN values found in stat and behavior differences:")
        display(df[df.isnull().any(axis=1)])

    return df

In [30]:
def generate_battle_features(jsonl_path: str, moves_df: pd.DataFrame, teams_df: pd.DataFrame, pokemon_df: pd.DataFrame) -> pd.DataFrame:
    # Call each feature function
    category_counts = compute_category_counts(moves_df)
    boosts_df = compute_boosts(moves_df)
    hits_misses_df = compute_hit_miss_stats(moves_df)
    stats_df = compute_in_play_mean_var(pokemon_df, teams_df)
    effectiveness_df = compute_effectiveness_counts(moves_df)
    switch_regen_df = compute_switch_and_regen_counts(moves_df)
    status_df = compute_status_infliction_counts(moves_df)
    effects_df = compute_effect_application_counts(moves_df)
    ko_adv_stab_df = compute_ko_advantage_stab_counts(moves_df)
    first_attacker_df = compute_first_attacker_counts(moves_df)

    # Merge all features
    battle_df = (
        category_counts
        .merge(boosts_df, on="battle_id", how="outer")
        .merge(hits_misses_df, on="battle_id", how="outer")
        .merge(stats_df, on="battle_id", how="outer")
        .merge(effectiveness_df, on="battle_id", how="outer")
        .merge(switch_regen_df, on="battle_id", how="outer")
        .merge(status_df, on="battle_id", how="outer")
        .merge(effects_df, on="battle_id", how="outer")
        .merge(ko_adv_stab_df, on="battle_id", how="outer")
        .merge(first_attacker_df, on="battle_id", how="outer")
    )

    #Add stat and behavior diffs
    battle_df = add_stat_and_behavior_diffs(battle_df)

    # Step 4: Load labels from JSONL
    labels = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            battle = json.loads(line)
            labels.append({
                "battle_id": battle.get("battle_id"),
                "player_won": battle.get("player_won", None)
            })
    labels_df = pd.DataFrame(labels)

    # Step 5: Merge labels
    battle_df = battle_df.merge(labels_df, on="battle_id", how="left")

    # Check NaN values
    if battle_df.isnull().values.any():
        print("NaN values found in final battle features:")
        display(battle_df[battle_df.isnull().any(axis=1)])

    # Check duplicates
    if battle_df.duplicated().any():
        print("Duplicate rows found in final battle features:")
        display(battle_df[battle_df.duplicated(keep=False)])

    return battle_df


train_df = generate_battle_features(train_file_path, moves_df_train, seen_pokemons, pokemon_df_train)

  pivot_cols = pivoted.columns.drop("battle_id")


In [31]:
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 180)
display(train_df.head())
print(train_df.columns.tolist())

# see dimensionality of final training dataframe
print(f"Training DataFrame shape: {train_df.shape}")


Unnamed: 0,battle_id,p1_num_physical_moves,p2_num_physical_moves,p1_num_special_moves,p2_num_special_moves,p1_num_status_moves,p2_num_status_moves,p1_null_moves,p2_null_moves,p1_total_boosts_atk,p2_total_boosts_atk,p1_total_boosts_def,p2_total_boosts_def,p1_total_boosts_spa,p2_total_boosts_spa,p1_total_boosts_spd,p2_total_boosts_spd,p1_total_boosts_spe,p2_total_boosts_spe,p1_num_hits,p2_num_hits,p1_num_misses,p2_num_misses,p1_num_misses_on_accurate_moves,p2_num_misses_on_accurate_moves,p1_mean_atk,p2_mean_atk,p1_mean_def,p2_mean_def,p1_mean_hp,p2_mean_hp,p1_mean_spa,p2_mean_spa,p1_mean_spd,p2_mean_spd,p1_mean_spe,p2_mean_spe,p1_var_atk,p2_var_atk,p1_var_def,p2_var_def,p1_var_hp,p2_var_hp,p1_var_spa,p2_var_spa,p1_var_spd,p2_var_spd,p1_var_spe,p2_var_spe,p1_num_neutral,p2_num_neutral,p1_num_not_effective,p2_num_not_effective,p1_num_super_effective,p2_num_super_effective,p1_num_switches,p2_num_switches,p1_num_regenerations,p2_num_regenerations,p1_num_brn,p2_num_brn,p1_num_fnt,p2_num_fnt,p1_num_frz,p2_num_frz,p1_num_nostatus,p2_num_nostatus,p1_num_par,p2_num_par,p1_num_psn,p2_num_psn,p1_num_slp,p2_num_slp,p1_num_tox,p2_num_tox,p1_num_clamp_inflicted,p2_num_clamp_inflicted,p1_num_confusion_inflicted,p2_num_confusion_inflicted,p1_num_firespin_inflicted,p2_num_firespin_inflicted,p1_num_noeffect_applied,p2_num_noeffect_applied,p1_num_reflect_applied,p2_num_reflect_applied,p1_num_substitute_applied,p2_num_substitute_applied,p1_num_tc_inflicted,p2_num_tc_inflicted,p1_num_wrap_inflicted,p2_num_wrap_inflicted,p1_ko_count,p2_ko_count,p1_num_advantage_moves,p2_num_advantage_moves,p1_num_stab_moves,p2_num_stab_moves,p1_num_first_attacks,p2_num_first_attacks,mean_spe_diff,mean_atk_diff,mean_def_diff,mean_spa_diff,mean_spd_diff,mean_hp_diff,var_atk_diff,var_def_diff,var_spa_diff,var_spd_diff,var_spe_diff,var_hp_diff,diff_num_switches,player_won
0,0,1.0,8.0,15.0,4.0,11.0,4.0,3.0,14.0,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,0.0,12.0,7.0,4.0,5.0,4.0,5.0,240.5,240.5,218.0,218.0,485.5,485.5,295.5,295.5,295.5,295.5,223.0,223.0,8625.0,8625.0,5733.333333,5733.333333,27891.666667,27891.666667,2491.666667,2491.666667,2491.666667,2491.666667,5366.666667,5366.666667,17.0,14.0,5.0,1.0,5.0,1.0,2,13,3.0,4.0,0.0,0.0,1.0,0.0,11.0,0.0,10.0,11.0,4.0,4.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,25.0,16.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,6.0,2.0,16.0,10.0,20.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-9.094947e-13,0.0,0.0,0.0,0.0,-11,True
1,1,10.0,19.0,11.0,1.0,2.0,3.0,7.0,7.0,0.0,0.0,0.0,0.0,0.0,-2.0,0.0,-2.0,0.0,0.0,13.0,13.0,8.0,7.0,4.0,6.0,243.0,243.0,229.666667,224.666667,449.666667,434.666667,278.0,298.0,278.0,298.0,221.333333,258.0,6190.0,6190.0,6256.666667,4546.666667,19786.666667,23096.666667,2080.0,3200.0,2080.0,3200.0,4506.666667,6200.0,16.0,19.0,5.0,2.0,2.0,2.0,6,6,5.0,4.0,0.0,0.0,0.0,2.0,0.0,0.0,18.0,15.0,2.0,4.0,0.0,0.0,3.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,23.0,23.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,6.0,6.0,20.0,14.0,18.0,11.0,-36.666667,0.0,5.0,-20.0,-20.0,15.0,0.0,1710.0,-1120.0,-1120.0,-1693.333333,-3310.0,0,True
2,2,6.0,0.0,5.0,12.0,16.0,10.0,3.0,8.0,0.0,0.0,0.0,0.0,5.0,0.0,5.0,0.0,0.0,0.0,8.0,10.0,5.0,4.0,5.0,3.0,238.0,208.0,201.333333,200.5,539.666667,423.0,294.666667,318.0,294.666667,318.0,188.0,293.0,12900.0,6200.0,6933.333333,5558.333333,24233.333333,35133.333333,3733.333333,3533.333333,3733.333333,3533.333333,700.0,4100.0,25.0,20.0,1.0,0.0,0.0,2.0,3,8,2.0,2.0,0.0,0.0,0.0,1.0,0.0,0.0,13.0,8.0,4.0,7.0,0.0,0.0,10.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18.0,10.0,9.0,0.0,0.0,12.0,0.0,0.0,0.0,0.0,0.0,1.0,6.0,9.0,11.0,0.0,8.0,22.0,-105.0,30.0,0.833333,-23.333333,-23.333333,116.666667,6700.0,1375.0,200.0,200.0,-3400.0,-10900.0,-5,True
3,3,4.0,16.0,11.0,4.0,8.0,5.0,7.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,10.0,5.0,10.0,5.0,5.0,244.0,300.5,218.0,253.0,465.0,405.5,318.0,278.0,318.0,278.0,236.0,253.0,6830.0,291.666667,4300.0,900.0,23020.0,6291.666667,2900.0,3133.333333,2900.0,3133.333333,4720.0,5166.666667,19.0,23.0,1.0,0.0,3.0,2.0,5,3,5.0,3.0,0.0,0.0,0.0,1.0,0.0,0.0,18.0,8.0,0.0,16.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,23.0,21.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,3.0,4.0,9.0,11.0,9.0,19.0,-17.0,-56.5,-35.0,40.0,40.0,59.5,6538.333333,3400.0,-233.333333,-233.333333,-446.666667,16728.333333,2,True
4,4,13.0,2.0,6.0,11.0,7.0,13.0,4.0,4.0,0.0,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,0.0,0.0,14.0,11.0,5.0,3.0,5.0,2.0,242.0,252.0,216.0,232.0,457.0,459.0,298.0,284.0,298.0,284.0,244.0,242.0,7730.0,7130.0,5120.0,5280.0,25130.0,24430.0,4000.0,2530.0,4000.0,2530.0,6280.0,5830.0,15.0,26.0,9.0,0.0,2.0,0.0,4,4,8.0,9.0,0.0,0.0,0.0,1.0,0.0,0.0,7.0,18.0,18.0,6.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,26.0,24.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,9.0,5.0,7.0,6.0,23.0,7.0,2.0,-10.0,-16.0,14.0,14.0,-2.0,600.0,-160.0,1470.0,1470.0,450.0,700.0,0,True


['battle_id', 'p1_num_physical_moves', 'p2_num_physical_moves', 'p1_num_special_moves', 'p2_num_special_moves', 'p1_num_status_moves', 'p2_num_status_moves', 'p1_null_moves', 'p2_null_moves', 'p1_total_boosts_atk', 'p2_total_boosts_atk', 'p1_total_boosts_def', 'p2_total_boosts_def', 'p1_total_boosts_spa', 'p2_total_boosts_spa', 'p1_total_boosts_spd', 'p2_total_boosts_spd', 'p1_total_boosts_spe', 'p2_total_boosts_spe', 'p1_num_hits', 'p2_num_hits', 'p1_num_misses', 'p2_num_misses', 'p1_num_misses_on_accurate_moves', 'p2_num_misses_on_accurate_moves', 'p1_mean_atk', 'p2_mean_atk', 'p1_mean_def', 'p2_mean_def', 'p1_mean_hp', 'p2_mean_hp', 'p1_mean_spa', 'p2_mean_spa', 'p1_mean_spd', 'p2_mean_spd', 'p1_mean_spe', 'p2_mean_spe', 'p1_var_atk', 'p2_var_atk', 'p1_var_def', 'p2_var_def', 'p1_var_hp', 'p2_var_hp', 'p1_var_spa', 'p2_var_spa', 'p1_var_spd', 'p2_var_spd', 'p1_var_spe', 'p2_var_spe', 'p1_num_neutral', 'p2_num_neutral', 'p1_num_not_effective', 'p2_num_not_effective', 'p1_num_super_ef

## 4. Training model

In [32]:
from sklearn.experimental import enable_halving_search_cv

from sklearn.model_selection import train_test_split, cross_val_score, cross_validate, RandomizedSearchCV, GridSearchCV, StratifiedKFold, HalvingGridSearchCV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from xgboost.sklearn import XGBClassifier
import numpy as np
from catboost import CatBoostClassifier
from sklearn.ensemble import StackingClassifier

In [33]:
X = train_df.drop(columns=["battle_id", "player_won"])
y = train_df["player_won"]

# Split
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

In [34]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

In [35]:
from sklearn.linear_model import LogisticRegression
from sklearn.experimental import enable_halving_search_cv  # must be imported before
from sklearn.model_selection import HalvingGridSearchCV

# Define parameter grid
logreg_param_grid = {
    "C": [0.01, 0.1, 1, 10],
    "penalty": ["l2"],
    "solver": ["liblinear", "lbfgs"]
}

# Model
logreg = LogisticRegression(random_state=42, max_iter=1000)

# Halving Grid Search
logreg_search = HalvingGridSearchCV(
    estimator=logreg,
    param_grid=logreg_param_grid,
    scoring="accuracy",
    cv=5,
    verbose=1,
    n_jobs=-1,
    random_state=42
)

from sklearn.ensemble import RandomForestClassifier

# Define parameter grid
rf_param_grid = {
    "n_estimators": [100, 200],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5],
    "min_samples_leaf": [1, 2],
    "max_features": ["sqrt", "log2"]
}

# Model
rf = RandomForestClassifier(random_state=42)

# Halving Grid Search
rf_search = HalvingGridSearchCV(
    estimator=rf,
    param_grid=rf_param_grid,
    scoring="accuracy",
    cv=5,
    verbose=0,
    n_jobs=-1,
    random_state=42
)


logreg_search.fit(X_train, y_train)
rf_search.fit(X_train, y_train)

print("Logistic Regression best score:", logreg_search.best_score_)
print("Best params:", logreg_search.best_params_)

print("Random Forest best score:", rf_search.best_score_)
print("Best params:", rf_search.best_params_)


n_iterations: 2
n_required_iterations: 2
n_possible_iterations: 2
min_resources_: 2666
max_resources_: 8000
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 8
n_resources: 2666
Fitting 5 folds for each of 8 candidates, totalling 40 fits
----------
iter: 1
n_candidates: 3
n_resources: 7998
Fitting 5 folds for each of 3 candidates, totalling 15 fits
Logistic Regression best score: 0.8145090681676047
Best params: {'C': 0.01, 'penalty': 'l2', 'solver': 'liblinear'}
Random Forest best score: 0.8001251564455568
Best params: {'max_depth': 10, 'max_features': 'log2', 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 200}


In [36]:
# Define parameter distributions
param_grid = {
    "n_estimators": [100, 200],
    "max_depth": [3, 4, 5],
    "learning_rate": [0.05, 0.1],
    "subsample": [0.8, 1.0],
    "colsample_bytree": [0.8, 1.0],
    "min_child_weight": [1, 3],
    "gamma": [0, 0.1, 0.2],
    "reg_alpha": [0, 0.1, 0.5],
    "reg_lambda": [1, 1.5, 2]
}

# Set up the model
xgb = XGBClassifier(use_label_encoder=False, 
                    eval_metric="logloss", 
                    random_state=42, 
                    objective="binary:logistic"
                    )

random_search = HalvingGridSearchCV(
    estimator=xgb,
    param_grid=param_grid,
    scoring="accuracy",
    cv=5,
    verbose=0,
    n_jobs=-1,
    random_state=42
)

random_search.fit(X_train, y_train)

# Check train dataset dimensions
print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")

# 4. Results
print("Best parameters:", random_search.best_params_)
print("Best CV score:", random_search.best_score_)

Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)


X_train shape: (8000, 111)
y_train shape: (8000,)
Best parameters: {'colsample_bytree': 1.0, 'gamma': 0.2, 'learning_rate': 0.05, 'max_depth': 4, 'min_child_weight': 3, 'n_estimators': 200, 'reg_alpha': 0.5, 'reg_lambda': 1, 'subsample': 0.8}
Best CV score: 0.8084362139917696


In [37]:
best_xgb = random_search.best_estimator_
best_xgb.fit(X_train, y_train)
y_pred = best_xgb.predict(X_val)
print(classification_report(y_val, y_pred))

model=best_xgb

Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)


              precision    recall  f1-score   support

       False       0.82      0.84      0.83      1000
        True       0.83      0.81      0.82      1000

    accuracy                           0.82      2000
   macro avg       0.82      0.82      0.82      2000
weighted avg       0.82      0.82      0.82      2000



In [39]:
# --- StackingClassfier with XGB, RandomForest, LogisticRegression as final estimator ---
from catboost import CatBoostClassifier
from sklearn.ensemble import StackingClassifier
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier

# --- Tune a GradientBoost model for final metamodel ---
gb_metamodel = GradientBoostingClassifier(
    n_estimators=300,
    learning_rate=0.05,
    max_depth=4,
    random_state=42
)
param_grid_gb = {
    "n_estimators": [200, 300, 400],
    "learning_rate": [0.01, 0.05, 0.1],
    "max_depth": [3, 4, 5],
    "subsample": [0.8, 1.0],
    "min_samples_split": [2, 5],
    "min_samples_leaf": [1, 2]
}
gb_search = HalvingGridSearchCV(
    estimator=gb_metamodel,
    param_grid=param_grid_gb,
    scoring="accuracy",
    cv=5,
    verbose=1,
    n_jobs=-1,
    random_state=42
)
gb_search.fit(X_train, y_train)
print("Gradient Boosting best score:", gb_search.best_score_)
print("Best params:", gb_search.best_params_)
gb_metamodel = gb_search.best_estimator_

# --- Define CatBoost model ---
catboost_model = CatBoostClassifier(
    iterations=500,
    learning_rate=0.05,
    depth=6,
    verbose=0,  # suppress training output
    loss_function='Logloss',
    random_seed=42
)
param_grid_cb = {
    "iterations": [300, 500, 700],
    "learning_rate": [0.01, 0.05, 0.1],
    "depth": [4, 6, 8]
}
cb_search = HalvingGridSearchCV(
    estimator=catboost_model,
    param_grid=param_grid_cb,
    scoring="accuracy",
    cv=5,
    verbose=1,
    n_jobs=-1,
    random_state=42
)
cb_search.fit(X_train, y_train)
print("CatBoost best score:", cb_search.best_score_)
print("Best params:", cb_search.best_params_)
catboost_model = cb_search.best_estimator_

n_iterations: 5
n_required_iterations: 5
n_possible_iterations: 5
min_resources_: 98
max_resources_: 8000
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 216
n_resources: 98
Fitting 5 folds for each of 216 candidates, totalling 1080 fits
----------
iter: 1
n_candidates: 72
n_resources: 294
Fitting 5 folds for each of 72 candidates, totalling 360 fits
----------
iter: 2
n_candidates: 24
n_resources: 882
Fitting 5 folds for each of 24 candidates, totalling 120 fits
----------
iter: 3
n_candidates: 8
n_resources: 2646
Fitting 5 folds for each of 8 candidates, totalling 40 fits
----------
iter: 4
n_candidates: 3
n_resources: 7938
Fitting 5 folds for each of 3 candidates, totalling 15 fits
Gradient Boosting best score: 0.8189035916824198
Best params: {'learning_rate': 0.05, 'max_depth': 4, 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 400, 'subsample': 0.8}
n_iterations: 4
n_required_iterations: 4
n_possible_iterations: 4
min_resources_: 296
max_res

In [40]:
# --- Stacking Classifier with XGB, RandomForest, CatBoost as base estimators and LogisticRegression as final estimator ---
stacked_model2 = StackingClassifier(
    estimators=[
        ("xgb", best_xgb),
        ("rf", rf_search.best_estimator_),
        ("catboost", catboost_model),
    ],
    final_estimator=LogisticRegression(**logreg_search.best_params_),
    cv=5,
    passthrough=True,
    n_jobs=-1
)
stacked_model2.fit(X_train, y_train)

# Evaluate
y_pred = stacked_model2.predict(X_val)

print("Validation Accuracy with Logistic Regression as final estimator:", accuracy_score(y_val, y_pred))

# --- Same stacking classiier, but with tuned GradientBoosting as final estimator ---
stacked_model3 = StackingClassifier(
    estimators=[
        ("xgb", best_xgb),
        ("rf", rf_search.best_estimator_),
        ("catboost", catboost_model),
    ],
    final_estimator=gb_metamodel,
    cv=5,
    passthrough=True,
    n_jobs=-1
)
stacked_model3.fit(X_train, y_train)

# Evaluate
y_pred = stacked_model3.predict(X_val)

print("Validation Accuracy with Gradient Boosting as final estimator:", accuracy_score(y_val, y_pred))

Validation Accuracy with Logistic Regression as final estimator: 0.8215
Validation Accuracy with Gradient Boosting as final estimator: 0.825


## Checking feature importance

In [41]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import f1_score, roc_auc_score
import numpy as np

# Calibrate the stacking classifier
calibrated_stack = CalibratedClassifierCV(estimator=stacked_model2, 
                                          method="sigmoid", 
                                          cv=5,
                                          #ensemble="auto",
                                        )
calibrated_stack.fit(X_train, y_train)

# Predict calibrated probabilities
calibrated_probs = calibrated_stack.predict_proba(X_val)[:, 1]

# Tune decision threshold for F1 score
best_thresh = 0.5
best_score = 0

for t in np.linspace(0.3, 0.7, 100):
    preds = (calibrated_probs > t).astype(int)
    score = f1_score(y_val, preds)
    if score > best_score:
        best_score = score
        best_thresh = t

print("Best threshold:", best_thresh)
print("Best F1 score:", best_score)

# Include accuracy at best threshold
final_preds = (calibrated_probs > best_thresh).astype(int)
final_accuracy = accuracy_score(y_val, final_preds)
print("Final Accuracy at best threshold (CatBoost):", final_accuracy)

# ROC AUC score
roc_auc = roc_auc_score(y_val, calibrated_probs)
print("ROC AUC Score (CatBoost):", roc_auc)


Best threshold: 0.3686868686868687
Best F1 score: 0.8323040380047506
Final Accuracy at best threshold (CatBoost): 0.8235
ROC AUC Score (CatBoost): 0.890923


In [42]:
# Same steps for stacked_model3 with Gradient Boosting as final estimator
calibrated_stack_gb = CalibratedClassifierCV(estimator=stacked_model3, 
                                          method="sigmoid", 
                                          cv=5,
                                          #ensemble="auto",
                                        )
calibrated_stack_gb.fit(X_train, y_train)

calibrated_probs_gb = calibrated_stack_gb.predict_proba(X_val)[:, 1]

best_thresh_gb = 0.5
best_score_gb = 0
for t in np.linspace(0.3, 0.7, 100):
    preds = (calibrated_probs_gb > t).astype(int)
    score = f1_score(y_val, preds)
    if score > best_score_gb:
        best_score_gb = score
        best_thresh_gb = t
print("Best threshold (GB):", best_thresh_gb)
print("Best F1 score (GB):", best_score_gb)

final_preds_gb = (calibrated_probs_gb > best_thresh_gb).astype(int)
final_accuracy_gb = accuracy_score(y_val, final_preds_gb)
print("Final Accuracy at best threshold (GB):", final_accuracy_gb)

roc_auc_gb = roc_auc_score(y_val, calibrated_probs_gb)
print("ROC AUC Score (GB):", roc_auc_gb)

Best threshold (GB): 0.44141414141414137
Best F1 score (GB): 0.8292682926829268
Final Accuracy at best threshold (GB): 0.8285
ROC AUC Score (GB): 0.896761


## 5. Creating the Submission File

In [45]:
# Prepare test features
pokemon_df_test = extract_unique_pokemon_no_ids(test_file_path)
seen_pokemons_test = extract_pokemon_in_play(test_file_path)
moves_df_test = make_moves_df(test_file_path, pokemon_df_test, df_typechart, verbose=False)
moves_df_test_update = compute_actual_damage(moves_df_test, pokemon_df_test)
test_df = generate_battle_features(test_file_path, moves_df_test_update, seen_pokemons_test, pokemon_df_test)

X_test = test_df.drop(columns=["battle_id", "player_won"])

# Make predictions
print("Generating predictions on the test set...")
test_predictions_cs2 = calibrated_stack_gb.predict(X_test)

# Create submission DataFrame
submission_df_cs2 = pd.DataFrame({
    "battle_id": test_df["battle_id"],
    "player_won": test_predictions_cs2
})

# Save to final CSV
submission_df_cs2.to_csv("submission_version_2.csv", index=False)
print("\n'submission_version_2.csv' file created successfully!")

  pivot_cols = pivoted.columns.drop("battle_id")


NaN values found in final battle features:


Unnamed: 0,battle_id,p1_num_physical_moves,p2_num_physical_moves,p1_num_special_moves,p2_num_special_moves,p1_num_status_moves,p2_num_status_moves,p1_null_moves,p2_null_moves,p1_total_boosts_atk,p2_total_boosts_atk,p1_total_boosts_def,p2_total_boosts_def,p1_total_boosts_spa,p2_total_boosts_spa,p1_total_boosts_spd,p2_total_boosts_spd,p1_total_boosts_spe,p2_total_boosts_spe,p1_num_hits,p2_num_hits,p1_num_misses,p2_num_misses,p1_num_misses_on_accurate_moves,p2_num_misses_on_accurate_moves,p1_mean_atk,p2_mean_atk,p1_mean_def,p2_mean_def,p1_mean_hp,p2_mean_hp,p1_mean_spa,p2_mean_spa,p1_mean_spd,p2_mean_spd,p1_mean_spe,p2_mean_spe,p1_var_atk,p2_var_atk,p1_var_def,p2_var_def,p1_var_hp,p2_var_hp,p1_var_spa,p2_var_spa,p1_var_spd,p2_var_spd,p1_var_spe,p2_var_spe,p1_num_neutral,p2_num_neutral,p1_num_not_effective,p2_num_not_effective,p1_num_super_effective,p2_num_super_effective,p1_num_switches,p2_num_switches,p1_num_regenerations,p2_num_regenerations,p1_num_brn,p2_num_brn,p1_num_fnt,p2_num_fnt,p1_num_frz,p2_num_frz,p1_num_nostatus,p2_num_nostatus,p1_num_par,p2_num_par,p1_num_psn,p2_num_psn,p1_num_slp,p2_num_slp,p1_num_tox,p2_num_tox,p1_num_clamp_inflicted,p2_num_clamp_inflicted,p1_num_confusion_inflicted,p2_num_confusion_inflicted,p1_num_firespin_inflicted,p2_num_firespin_inflicted,p1_num_noeffect_applied,p2_num_noeffect_applied,p1_num_reflect_applied,p2_num_reflect_applied,p1_num_substitute_applied,p2_num_substitute_applied,p1_num_tc_inflicted,p2_num_tc_inflicted,p1_num_wrap_inflicted,p2_num_wrap_inflicted,p1_ko_count,p2_ko_count,p1_num_advantage_moves,p2_num_advantage_moves,p1_num_stab_moves,p2_num_stab_moves,p1_num_first_attacks,p2_num_first_attacks,mean_spe_diff,mean_atk_diff,mean_def_diff,mean_spa_diff,mean_spd_diff,mean_hp_diff,var_atk_diff,var_def_diff,var_spa_diff,var_spd_diff,var_spe_diff,var_hp_diff,diff_num_switches,player_won
0,0,3.0,10.0,15.0,3.0,2.0,11.0,10.0,6.0,0.0,40.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0,7.0,10.0,6.0,10.0,1.0,236.0,246.000000,208.000000,208.000000,455.000000,477.000000,294.000000,288.000000,294.000000,288.000000,242.000000,234.000000,6470.0,7220.000000,3550.000000,3350.000000,25570.000000,21780.000000,3980.000000,1250.000000,3980.000000,1250.000000,5830.000000,5680.000000,9.0,22.0,6.0,0.0,5.0,2.0,9,5,4.0,4.0,0.0,0.0,0.0,4.0,0.0,4.0,9.0,7.0,10.0,4.0,0.0,0.0,1.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,20.0,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,10.0,6.0,13.0,7.0,8.0,21.0,8.000000,-10.000000,0.000000,6.000000,6.000000,-22.000000,-750.000000,200.000000,2730.000000,2730.000000,150.000000,3790.000000,4,
1,1,19.0,3.0,3.0,8.0,5.0,1.0,3.0,18.0,82.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,14.0,4.0,8.0,7.0,2.0,7.0,297.5,276.333333,298.000000,249.666667,346.000000,406.333333,295.500000,289.666667,295.500000,289.666667,273.000000,243.000000,3241.0,3416.666667,12333.333333,2696.666667,1276.000000,6226.666667,425.000000,4896.666667,425.000000,4896.666667,3300.000000,7110.000000,24.0,4.0,2.0,0.0,0.0,6.0,2,17,3.0,2.0,0.0,0.0,0.0,1.0,0.0,0.0,14.0,9.0,5.0,2.0,0.0,0.0,8.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,27.0,12.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,13.0,9.0,5.0,12.0,22.0,7.0,30.000000,21.166667,48.333333,5.833333,5.833333,-60.333333,-175.666667,9636.666667,-4471.666667,-4471.666667,-3810.000000,-4950.666667,-15,
2,2,12.0,13.0,5.0,4.0,8.0,8.0,5.0,5.0,0.0,12.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,14.0,7.0,3.0,10.0,0.0,10.0,272.0,253.000000,284.000000,271.333333,427.000000,399.666667,260.000000,281.333333,260.000000,281.333333,234.000000,253.000000,9130.0,6910.000000,16830.000000,15626.666667,25330.000000,22746.666667,2370.000000,3186.666667,2370.000000,3186.666667,2880.000000,3830.000000,22.0,24.0,3.0,0.0,0.0,1.0,4,4,2.0,4.0,0.0,0.0,0.0,1.0,2.0,0.0,10.0,24.0,12.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,23.0,16.0,0.0,0.0,0.0,0.0,0.0,0.0,8.0,2.0,0.0,1.0,5.0,6.0,12.0,16.0,16.0,13.0,-19.000000,19.000000,12.666667,-21.333333,-21.333333,27.333333,2220.000000,1203.333333,-816.666667,-816.666667,-950.000000,2583.333333,0,
3,3,2.0,15.0,10.0,8.0,16.0,1.0,2.0,6.0,0.0,0.0,0.0,0.0,-2.0,-1.0,-2.0,-1.0,0.0,0.0,7.0,13.0,5.0,10.0,4.0,9.0,198.0,272.000000,181.333333,282.000000,476.333333,367.000000,314.666667,282.000000,314.666667,282.000000,231.333333,278.000000,8100.0,1730.000000,6533.333333,10530.000000,39433.333333,7930.000000,933.333333,3030.000000,933.333333,3030.000000,2433.333333,6400.000000,26.0,15.0,0.0,0.0,2.0,9.0,2,6,12.0,3.0,0.0,18.0,0.0,0.0,0.0,0.0,12.0,6.0,14.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0,25.0,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,5.0,18.0,5.0,8.0,22.0,-46.666667,-74.000000,-100.666667,32.666667,32.666667,109.333333,6370.000000,-3996.666667,-2096.666667,-2096.666667,-3966.666667,31503.333333,-4,
4,4,11.0,11.0,8.0,10.0,5.0,3.0,6.0,6.0,0.0,0.0,0.0,0.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,10.0,13.0,9.0,8.0,7.0,4.0,252.0,254.666667,232.000000,261.333333,459.000000,433.000000,284.000000,291.333333,284.000000,291.333333,242.000000,239.666667,7130.0,6066.666667,5280.000000,13186.666667,24430.000000,23600.000000,2530.000000,3066.666667,2530.000000,3066.666667,5830.000000,4336.666667,20.0,14.0,0.0,4.0,4.0,6.0,4,4,6.0,6.0,0.0,0.0,1.0,3.0,0.0,0.0,16.0,15.0,5.0,6.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,6.0,0.0,0.0,0.0,0.0,18.0,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,3.0,2.0,4.0,13.0,14.0,14.0,14.0,2.333333,-2.666667,-29.333333,-7.333333,-7.333333,26.000000,1063.333333,-7906.666667,-536.666667,-536.666667,1493.333333,830.000000,0,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,4995,1.0,5.0,19.0,7.0,5.0,12.0,5.0,6.0,0.0,0.0,0.0,0.0,0.0,-3.0,0.0,-3.0,0.0,0.0,9.0,7.0,11.0,5.0,10.0,2.0,218.0,283.000000,216.000000,270.500000,423.000000,398.000000,346.000000,305.500000,346.000000,305.500000,266.000000,260.500000,5000.0,166.666667,5470.000000,158.333333,25750.000000,2166.666667,520.000000,2825.000000,520.000000,2825.000000,3970.000000,3091.666667,17.0,21.0,6.0,1.0,2.0,2.0,4,5,4.0,3.0,0.0,0.0,0.0,2.0,0.0,0.0,13.0,13.0,12.0,4.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,25.0,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,9.0,4.0,15.0,21.0,18.0,11.0,5.500000,-65.000000,-54.500000,40.500000,40.500000,25.000000,4833.333333,5311.666667,-2305.000000,-2305.000000,878.333333,23583.333333,-1,
4996,4996,9.0,9.0,13.0,1.0,4.0,15.0,4.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,11.0,8.0,11.0,2.0,9.0,2.0,240.0,230.500000,222.000000,220.500000,447.000000,443.000000,278.000000,313.000000,278.000000,313.000000,272.000000,260.500000,6770.0,7625.000000,4880.000000,6491.666667,27130.000000,30866.666667,1750.000000,2966.666667,1750.000000,2966.666667,7780.000000,4425.000000,23.0,24.0,2.0,0.0,1.0,1.0,4,5,3.0,5.0,0.0,0.0,0.0,2.0,0.0,0.0,18.0,8.0,8.0,11.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,26.0,25.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,20.0,19.0,19.0,11.0,11.500000,9.500000,1.500000,-35.000000,-35.000000,4.000000,-855.000000,-1611.666667,-1216.666667,-1216.666667,3355.000000,-3736.666667,-1,
4997,4997,8.0,9.0,4.0,9.0,5.0,5.0,13.0,7.0,0.0,0.0,0.0,0.0,-3.0,0.0,-3.0,0.0,0.0,0.0,7.0,13.0,5.0,5.0,4.0,4.0,248.0,242.000000,229.666667,216.000000,436.333333,457.000000,296.333333,298.000000,296.333333,298.000000,254.666667,244.000000,5800.0,7730.000000,4256.666667,5120.000000,22626.666667,25130.000000,2936.666667,4000.000000,2936.666667,4000.000000,5626.666667,6280.000000,13.0,13.0,0.0,5.0,0.0,3.0,11,5,5.0,5.0,0.0,0.0,0.0,1.0,0.0,0.0,15.0,10.0,0.0,12.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17.0,23.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,3.0,9.0,9.0,9.0,15.0,13.0,10.666667,6.000000,13.666667,-1.666667,-1.666667,-20.666667,-1930.000000,-863.333333,-1063.333333,-1063.333333,-653.333333,-2503.333333,6,
4998,4998,7.0,9.0,7.0,6.0,11.0,9.0,5.0,6.0,0.0,0.0,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,0.0,7.0,8.0,7.0,7.0,6.0,7.0,243.0,228.000000,233.000000,224.000000,446.333333,417.000000,291.333333,312.000000,291.333333,312.000000,229.666667,278.000000,6190.0,6050.000000,5830.000000,5680.000000,20786.666667,26530.000000,3466.666667,2530.000000,3466.666667,2530.000000,6256.666667,4750.000000,20.0,14.0,5.0,5.0,0.0,5.0,3,4,6.0,6.0,0.0,0.0,0.0,2.0,0.0,0.0,13.0,14.0,7.0,5.0,0.0,0.0,5.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,25.0,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,4.0,7.0,16.0,8.0,6.0,22.0,-48.333333,15.000000,9.000000,-20.666667,-20.666667,29.333333,140.000000,150.000000,936.666667,936.666667,1506.666667,-5743.333333,-1,


Generating predictions on the test set...

'submission_version_2.csv' file created successfully!
