In [None]:
import json
import os
from typing import List

import polars as pl
from loguru import logger

from soccerai.data.data import load_and_process_soccer_events


def load_chains(chains_path: str) -> List[List[int]]:
    """Load event-index chains from a JSON file.

    Args:
        chains_path: Path to a JSON file. Its top-level value is expected to be
            a list of chains, each chain being a list of event indices (this
            shape is assumed from the return annotation, not validated).

    Returns:
        The decoded JSON value, returned as-is.
    """
    # Decode explicitly as UTF-8 so loading does not depend on the platform's
    # default (locale-dependent) text encoding.
    with open(chains_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_rosters(rosters_path: str):
    """Read the CSV file at ``rosters_path`` and return it as a polars DataFrame."""
    return pl.read_csv(rosters_path)


def create_dataset(
    resources_path: str,
    event_data_path: str = "/home/soccerdata/FIFA_WorldCup_2022/Event Data",
) -> pl.DataFrame:
    """Build the labeled event dataset joined with per-event player rows.

    Events whose ``index`` appears in a positive chain get ``label == 1``.
    Events whose ``index`` appears in a negative chain get ``label == 0``.
    All other events are dropped. The labeled events are then inner-joined
    with the player rows on ``gameEventId`` and ``possessionEventId``.

    Args:
        resources_path: Directory containing ``accepted_pos_chains.json`` and
            ``accepted_neg_chains.json``.
        event_data_path: Directory passed to ``load_and_process_soccer_events``.
            Defaults to the previously hard-coded location, so existing
            callers are unaffected.

    Returns:
        The labeled events inner-joined with the player rows.
    """
    event_df, players_df = load_and_process_soccer_events(event_data_path)

    pos_chains = load_chains(os.path.join(resources_path, "accepted_pos_chains.json"))
    neg_chains = load_chains(os.path.join(resources_path, "accepted_neg_chains.json"))
    # Flatten the chains into flat lists of event indices.
    pos_indices = [idx for chain in pos_chains for idx in chain]
    neg_indices = [idx for chain in neg_chains for idx in chain]

    logger.debug(
        "Num pos indices: {}, Num neg indices: {}", len(pos_indices), len(neg_indices)
    )

    # The when/then chain below checks the positive set first, so an index that
    # appears in both sets would silently end up labeled 1. Surface it.
    overlap = set(pos_indices) & set(neg_indices)
    if overlap:
        logger.warning(
            "{} event indices appear in both positive and negative chains; "
            "they will be labeled positive",
            len(overlap),
        )

    labeled_event_df = event_df.with_columns(
        pl.when(pl.col("index").is_in(pos_indices))
        .then(1)
        .when(pl.col("index").is_in(neg_indices))
        .then(0)
        .otherwise(None)
        .alias("label")
    ).filter(pl.col("label").is_not_null())

    return labeled_event_df.join(
        players_df, on=["gameEventId", "possessionEventId"], how="inner", coalesce=True
    )


# Build the labeled dataset. The relative path is resolved against the current
# working directory (presumably the notebook's directory — confirm).
# Binding the path to ``resources_path`` also makes the commented-out
# ``load_rosters`` experiment below runnable if uncommented.
resources_path = "../soccerai/data/resources"
dataset_df = create_dataset(resources_path)
# Last expression of the cell, so the notebook still displays the result.
dataset_df
# rosters_df = load_rosters(os.path.join(resources_path, "rosters.csv"))
# NOTE(review): the next line is incomplete — ``acce`` is not defined anywhere visible.
# accepted_pos_chains.extend(acce)

# _, players_df = load_and_process_soccer_events(
#     "/home/soccerdata/FIFA_WorldCup_2022/Event Data"
# )
# players = players_df.filter(pl.col("index").is_in([0, 1]))

# enricher = PlayerVelocityEnricher(
#     "/home/soccerdata/FIFA_WorldCup_2022/Tracking Data"
# )
# augmented_players = enricher.add_velocity_per_player(players)
# print(augmented_players)