#### Imports and Global Settings

In [15]:
import pandas as pd
import matplotlib.pyplot as plt
from mplsoccer import Pitch
from joblib import load 
from statsbombpy import sb

from ipywidgets import Dropdown, VBox, Output
import matplotlib.pyplot as plt

import ast
from tqdm import tqdm


import warnings
warnings.filterwarnings("ignore")

# StatsBomb's own catalogue, which carries the numeric competition/season IDs
df_sb_competitions = sb.competitions()

# Local competition list: drop the trailing TOTAL row, attach the StatsBomb IDs
# by (competition name, season name), then keep and rename the useful columns.
# The merge is a left join, so rows without a StatsBomb match keep missing IDs.
df_competitions = (
    pd.read_csv("../task1_xg/data/df_competitions.csv")
    .iloc[:-1]
    .merge(
        df_sb_competitions,
        how="left",
        left_on=["Competition", "Season"],
        right_on=["competition_name", "season_name"],
    )
    .loc[:, ["Competition", "Season", "competition_id", "season_id"]]
    .rename(columns={"competition_id": "Competition_ID", "season_id": "Season_ID"})
)

# DS4 shot dataset
df_ds4 = pd.read_csv("../task1_xg/data/DS4.csv")

# Sanity checks on what was loaded
print("Competitions dataset shape:", df_competitions.shape)
print("\nColumns:", df_competitions.columns.tolist())
print("\nCompetitions available:", df_competitions['Competition'].unique())
print("\nDS4 shape:", df_ds4.shape)


Competitions dataset shape: (75, 4)

Columns: ['Competition', 'Season', 'Competition_ID', 'Season_ID']

Competitions available: ['FIFA World Cup' 'Champions League' 'La Liga' 'Copa del Rey'
 'North American League' 'FIFA U20 World Cup' 'Liga Profesional' 'Serie A'
 'UEFA Europa League' 'Premier League' '1. Bundesliga' 'Ligue 1'
 "FA Women's Super League" 'NWSL' "Women's World Cup" 'UEFA Euro'
 'Indian Super league' "UEFA Women's Euro" 'African Cup of Nations'
 'Major League Soccer' 'Copa America']

DS4 shape: (86833, 29)


In [16]:
def _parse_location(raw):
    """
    Turn one raw location cell into a coordinate sequence.

    Columns read back from CSV hold the coordinates as text (e.g. "[102.3, 41.0]"),
    so strings are parsed with ast.literal_eval; real lists/tuples are passed through.
    Returns None for missing, unparsable or too-short values.
    """
    if isinstance(raw, str):
        try:
            raw = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            return None
    if isinstance(raw, (list, tuple)) and len(raw) >= 2:
        return raw
    return None


# Load extra shot info (interpretable metadata)
shots_extra = pd.read_csv("../task1_xg/data/shots_df.csv", low_memory=False)

# Rename 'id' in shots_extra to 'event_id' to match DS4
shots_extra = shots_extra.rename(columns={"id": "event_id"})

# Keep only the interpretable columns from shots_df
shots_extra = shots_extra[[
    "event_id",
    "location", "shot_end_location", "shot_freeze_frame", "shot_outcome",
    "player", "team",
    "minute", "second"
]]

# Reduce DS4 to the identifiers and the target value
df_ds4 = df_ds4[["event_id", "match_id", "target_xg"]]

# Merge with interpretable info
df_ds4 = df_ds4.merge(shots_extra, on="event_id", how="left")

# Unpack location and shot_end_location into x, y.
# Each column is parsed once; the previous isinstance(loc, list) test never
# matched the CSV strings, which left all four coordinate columns empty.
shot_start = df_ds4["location"].apply(_parse_location)
shot_end = df_ds4["shot_end_location"].apply(_parse_location)

df_ds4["loc_x"] = shot_start.apply(lambda loc: loc[0] if loc is not None else None)
df_ds4["loc_y"] = shot_start.apply(lambda loc: loc[1] if loc is not None else None)

df_ds4["end_shot_x"] = shot_end.apply(lambda loc: loc[0] if loc is not None else None)
df_ds4["end_shot_y"] = shot_end.apply(lambda loc: loc[1] if loc is not None else None)

# Drop raw location columns (no longer needed)
df_ds4 = df_ds4.drop(columns=["location", "shot_end_location"])

print("DS4 final shape:", df_ds4.shape)
print("DS4 final columns:", df_ds4.columns.tolist())

DS4 final shape: (86833, 13)
DS4 final columns: ['event_id', 'match_id', 'target_xg', 'shot_freeze_frame', 'shot_outcome', 'player', 'team', 'minute', 'second', 'loc_x', 'loc_y', 'end_shot_x', 'end_shot_y']


#### Load all info about competitions, seasons, matches and shots

In [17]:
def load_competitions():
    """
    List the distinct competition names in df_competitions, sorted ascending.
    """
    names = list(df_competitions["Competition"].unique())
    names.sort()
    return names


def load_seasons(competition: str):
    """
    List the distinct seasons recorded for `competition`, sorted ascending.
    """
    is_competition = df_competitions["Competition"] == competition
    seasons = df_competitions.loc[is_competition, "Season"].unique()
    return sorted(seasons)

def load_matches(competition: str, season: str):
    """
    Load matches for a given competition and season using statsbombpy.

    Parameters
    ----------
    competition : str
        Competition name as it appears in df_competitions["Competition"].
    season : str
        Season name as it appears in df_competitions["Season"].

    Returns
    -------
    The object returned by sb.matches (a DataFrame of matches), or None when
    the pair is unknown, has no StatsBomb IDs, or yields no matches. A short
    message is printed in each of those cases.
    """
    rows = df_competitions.loc[
        (df_competitions["Competition"] == competition)
        & (df_competitions["Season"] == season)
    ]

    if rows.empty:
        print(f"No matches found for {competition} - {season}.")
        return None

    comp_id = rows["Competition_ID"].iloc[0]
    season_id = rows["Season_ID"].iloc[0]

    # The IDs come from a left merge, so a competition/season without a
    # StatsBomb counterpart carries NaN here instead of an ID.
    if pd.isna(comp_id) or pd.isna(season_id):
        print(f"No StatsBomb IDs available for {competition} - {season}.")
        return None

    # Missing values elsewhere in the column turn the IDs into floats (9.0);
    # cast back to plain ints before handing them to statsbombpy.
    matches = sb.matches(competition_id=int(comp_id), season_id=int(season_id))
    if matches.empty:
        print(f"No matches returned by StatsBomb for {competition} - {season}.")
        return None

    return matches

def load_shots(match_id: int):
    """
    Return a copy of the df_ds4 rows whose match_id equals `match_id`.
    When the match has no shots, print a hint and return None.
    """
    in_match = df_ds4["match_id"] == match_id
    match_shots = df_ds4[in_match].copy()

    if not match_shots.empty:
        return match_shots

    print(f"No shots found for match_id={match_id}. "
          "Try selecting another match or competition.")
    return None


#### Plot the field

In [18]:
def plot_shot(x, y, players=None):
    """
    Draw a StatsBomb pitch with the shooter at (x, y) and a goalkeeper marker.

    `players` is accepted but not used by the current body. The goalkeeper
    marker is an example placed at the fixed point (120, 40).
    """
    pitch = Pitch(pitch_type='statsbomb', pitch_length=120, pitch_width=80, pitch_color="grass", line_color="white")
    fig, ax = pitch.draw(figsize=(8, 5))

    # (x, y, marker, color): shooter first, then the example goalkeeper
    markers = (
        (x, y, '*', 'red'),
        (120, 40, '^', 'blue'),
    )
    for marker_x, marker_y, shape, colour in markers:
        pitch.scatter(marker_x, marker_y, s=200, marker=shape, color=colour, ax=ax)

    plt.show()

#### Loading the models for xG prediction

In [19]:
# Trained estimators, grouped by the dataset whose feature layout they expect.
# NOTE(review): joblib.load unpickles the files, so only load model files from a trusted source.
models_ds3 = {
    "Linear Regression": load("../task1_xg/models/model_linear_regression.pkl"),
}
models_ds4 = {
    model_name: load(model_path)
    for model_name, model_path in {
        "Random Forest": "../task1_xg/models/model_rf.pkl",
        "XGBoost": "../task1_xg/models/model_xgboost.pkl",
        "Neural Net": "../task1_xg/models/model_nn.pkl",
    }.items()
}

# Datasets the models were built on; used here to recover the feature column order
ds3_for_prediction = pd.read_csv("../task1_xg/data/DS3.csv")
ds4_for_prediction = pd.read_csv("../task1_xg/data/DS4.csv")

# Feature columns = every column except the target and the identifier columns
train_columns_ds3 = ds3_for_prediction.columns[
    ~ds3_for_prediction.columns.isin(["target_xg", "shot_id", "match_id", "event_id"])
].tolist()
train_columns_ds4 = ds4_for_prediction.columns[
    ~ds4_for_prediction.columns.isin(["target_xg", "shot_id", "match_id", "event_id"])
].tolist()

def compute_xg_models(shot_row):
    """
    Run every loaded model on one shot and collect the predictions.

    `shot_row` must expose the columns in train_columns_ds3 and
    train_columns_ds4 plus "target_xg". Returns a dict mapping each model
    name to a float prediction, with the dataset value under "True_xG".
    """
    def _single_row_frame(columns):
        # One-row frame with the columns in the order the models were trained on
        return pd.DataFrame([shot_row[columns]], columns=columns)

    predictions = {}

    # DS3 feature layout (Linear Regression)
    features_ds3 = _single_row_frame(train_columns_ds3)
    for label, estimator in models_ds3.items():
        predictions[label] = float(estimator.predict(features_ds3)[0])

    # DS4 feature layout (RF, XGB, NN)
    features_ds4 = _single_row_frame(train_columns_ds4)
    for label, estimator in models_ds4.items():
        # Models exposing predict_proba report the second column of the
        # probability output; the others report the raw prediction.
        if hasattr(estimator, "predict_proba"):
            raw_value = estimator.predict_proba(features_ds4)[0, 1]
        else:
            raw_value = estimator.predict(features_ds4)[0]
        predictions[label] = float(raw_value)

    # Value stored in the dataset
    predictions["True_xG"] = shot_row["target_xg"]
    return predictions


## DEMO

In [29]:
# Scratch check: list the columns sb.matches returns for one competition/season
# pair (IDs 9 and 281). Presumably used to confirm the fields the demo relies on
# (match_week, kick_off, home_team, away_team, match_id) -- TODO confirm or remove.
sb.matches(9, 281).columns

Index(['match_id', 'match_date', 'kick_off', 'competition', 'season',
       'home_team', 'away_team', 'home_score', 'away_score', 'match_status',
       'match_status_360', 'last_updated', 'last_updated_360', 'match_week',
       'competition_stage', 'stadium', 'referee', 'home_managers',
       'away_managers', 'data_version', 'shot_fidelity_version',
       'xy_fidelity_version'],
      dtype='object')

In [31]:
# Output area that receives the pitch plot and the printed text
output = Output()

# --- Initial dropdowns ---
# Only the competition dropdown starts enabled; the others are unlocked
# step by step by the observers attached further down.
comp_dd = Dropdown(
    description="Competition:",
    options=["Select competition...", *load_competitions()],
)

season_dd = Dropdown(description="Season:", options=["Select season..."], disabled=True)
match_dd = Dropdown(description="Match:", options=["Select match..."], disabled=True)
shot_dd = Dropdown(description="Shot:", options=["Select shot..."], disabled=True)

# --- Update logic ---
def update_seasons(change):
    """
    Observer for the competition dropdown.

    Fills and enables the season dropdown when a real competition is chosen,
    otherwise resets it to its disabled placeholder. The match and shot
    dropdowns are reset in both cases.
    """
    chosen = change["new"]
    has_competition = bool(chosen) and chosen != "Select competition..."

    if has_competition:
        season_dd.options = ["Select season..."] + list(load_seasons(chosen))
    else:
        season_dd.options = ["Select season..."]
    season_dd.disabled = not has_competition

    # Cascade reset of the dependent dropdowns
    for dropdown, placeholder in ((match_dd, "Select match..."), (shot_dd, "Select shot...")):
        dropdown.options = [placeholder]
        dropdown.disabled = True

def update_matches(change):
    """
    Observer for the season dropdown.

    When a real season is chosen, loads its matches, lists them ordered by
    match week and kick-off, and enables the match dropdown. If no matches
    are available the dropdown shows "No matches available". The shot
    dropdown is reset in every case.
    """
    if change["new"] and change["new"] != "Select season...":
        matches = load_matches(comp_dd.value, change["new"])
        # load_matches returns None when nothing is found, so check before
        # touching the frame; sort into a new frame rather than in place.
        if matches is not None and not matches.empty:
            ordered = matches.sort_values(by=["match_week", "kick_off"])
            # Include the match week in the label when available
            match_dd.options = ["Select match..."] + [
                f"MW {row.get('match_week','?')} - {row['home_team']} vs {row['away_team']} (ID {row['match_id']})"
                for _, row in ordered.iterrows()
            ]
        else:
            match_dd.options = ["No matches available"]
        match_dd.disabled = False
        # Cascade reset
        shot_dd.options = ["Select shot..."]
        shot_dd.disabled = True
    else:
        match_dd.options = ["Select match..."]
        match_dd.disabled = True
        shot_dd.options = ["Select shot..."]
        shot_dd.disabled = True

def update_shots(change):
    """
    Observer for the match dropdown.

    When a match label (one containing "vs") is chosen, lists that match's
    shots and enables the shot dropdown; otherwise resets it.

    Each shot label starts with the shot's position within the match's shots
    ("<position>: Minute <minute>:<second> - <player>"). update_demo reads the
    integer before the first ':' and uses it with .iloc, so the prefix must be
    the positional index.
    """
    if change["new"] and "vs" in change["new"]:
        # Extract match_id from the label text -> "... (ID xxx)"
        match_id = int(change["new"].split("ID")[-1].strip(") "))
        shots_df = load_shots(match_id)
        if shots_df is not None and not shots_df.empty:
            shot_dd.options = ["Select shot..."] + [
                f"{position}: Minute {row['minute']}:{row['second']} - {row['player']}"
                for position, (_, row) in enumerate(shots_df.iterrows())
            ]
        else:
            shot_dd.options = ["No shots available"]
        shot_dd.disabled = False
    else:
        shot_dd.options = ["Select shot..."]
        shot_dd.disabled = True

def _selected_shot_position(label):
    """Return the 0-based position of the selected shot, or None if unknown."""
    # Labels of the form "<position>: ..." carry the position explicitly.
    prefix = label.split(":", 1)[0].strip()
    if prefix.isdigit():
        return int(prefix)
    # Otherwise rely on the dropdown itself: option 0 is the placeholder,
    # so the selected shot sits at index - 1.
    if shot_dd.index is None:
        return None
    return shot_dd.index - 1


def _shot_features(shot_row, feature_source, feature_columns):
    """Build the one-row model input for a shot, in training column order."""
    # Use the row directly when it already carries every feature column.
    if all(col in shot_row.index for col in feature_columns):
        return pd.DataFrame([shot_row[feature_columns]], columns=feature_columns)
    # Otherwise fetch the features from the prediction dataset by event_id.
    # NOTE(review): assumes feature_source has an 'event_id' column -- DS4 does;
    # confirm for DS3. A missing column raises KeyError, reported by the caller.
    matched = feature_source.loc[
        feature_source["event_id"] == shot_row["event_id"], feature_columns
    ]
    if matched.empty:
        raise KeyError(f"event_id {shot_row['event_id']} not found in feature dataset")
    return matched.iloc[[0]]


def update_demo(change):
    """
    Observer for the shot dropdown: render the selected shot into `output`.

    Draws the pitch with the shot location and the freeze-frame players, then
    prints the player, each model's prediction, the dataset xG and the shot
    outcome. Placeholder selections ("Select ...", "No ...") only clear the output.
    """
    with output:
        output.clear_output()
        selection = change["new"]
        if not selection or "Select" in selection or "No " in selection:
            return

        # --- Retrieve shot ---
        try:
            match_id = int(match_dd.value.split("ID")[-1].strip(") "))
        except (AttributeError, ValueError):
            print("⚠️ Could not determine the selected match")
            return

        shots_df = load_shots(match_id)
        if shots_df is None or shots_df.empty:
            print("⚠️ No shots found for this match in DS4")
            return

        shot_idx = _selected_shot_position(selection)
        if shot_idx is None or not 0 <= shot_idx < len(shots_df):
            print("⚠️ Could not identify the selected shot")
            return

        shot_row = shots_df.iloc[shot_idx]
        x, y = shot_row["loc_x"], shot_row["loc_y"]

        # --- Plot pitch ---
        fig, ax = plt.subplots(figsize=(8, 6))
        pitch = Pitch(pitch_type="statsbomb", line_color="black")
        pitch.draw(ax=ax)
        if pd.notna(x) and pd.notna(y):
            ax.scatter(x, y, c="red", s=120, marker="*", label="Shot")
        else:
            print("⚠️ Shot location not available for this shot")

        # --- Parse and plot freeze frame for this shot only ---
        freeze_raw = shot_row["shot_freeze_frame"]
        freeze = None
        if isinstance(freeze_raw, str):
            # The freeze frame is stored as text in the CSV
            try:
                freeze = ast.literal_eval(freeze_raw)
            except (ValueError, SyntaxError) as e:
                print("⚠️ Could not parse freeze_frame:", e)
        elif isinstance(freeze_raw, list):
            freeze = freeze_raw

        if isinstance(freeze, list):
            for p in freeze:
                location = p.get("location") if isinstance(p, dict) else None
                if not location or len(location) < 2:
                    continue  # skip entries without usable coordinates
                px, py = location[0], location[1]
                # NOTE(review): StatsBomb freeze frames appear to mark the keeper
                # via position.name rather than a 'keeper' flag -- confirm.
                position = p.get("position")
                is_keeper = bool(p.get("keeper", False)) or (
                    isinstance(position, dict) and position.get("name") == "Goalkeeper"
                )
                if p.get("teammate", False):
                    color = "blue"
                elif is_keeper:
                    color = "green"
                else:
                    color = "black"
                ax.scatter(px, py, c=color, s=80, alpha=0.8)

        plt.show()
        plt.close(fig)  # release the figure so repeated selections do not pile up

        # --- Compute model predictions ---
        preds = {}

        # Linear Regression (DS3) – only if model available
        if "Linear Regression" in models_ds3:
            try:
                X3 = _shot_features(shot_row, ds3_for_prediction, train_columns_ds3)
                preds["Linear Regression"] = float(models_ds3["Linear Regression"].predict(X3)[0])
            except Exception as e:
                preds["Linear Regression"] = f"⚠️ Feature mismatch ({e})"
        else:
            preds["Linear Regression"] = "⚠️ Model not available"

        # RF, XGB, NN (DS4) – only if models available
        for name in ["Random Forest", "XGBoost", "Neural Net"]:
            if name in models_ds4:
                try:
                    X4 = _shot_features(shot_row, ds4_for_prediction, train_columns_ds4)
                    model = models_ds4[name]
                    if hasattr(model, "predict_proba"):
                        preds[name] = float(model.predict_proba(X4)[0, 1])
                    else:
                        preds[name] = float(model.predict(X4)[0])
                except Exception as e:
                    preds[name] = f"⚠️ Feature mismatch ({e})"
            else:
                preds[name] = "⚠️ Model not available"

        # True xG (from dataset)
        preds["True xG"] = shot_row.get("target_xg", "N/A")

        # --- Print info ---
        print(f"Player: {shot_row['player']} ({shot_row['team']})")
        for k, v in preds.items():
            print(f"{k}: {v}")
        print("Shot outcome:", shot_row["shot_outcome"])




# --- Attach observers ---
# Each dropdown drives the next stage of the cascade through its "value" trait.
for dropdown, handler in (
    (comp_dd, update_seasons),
    (season_dd, update_matches),
    (match_dd, update_shots),
    (shot_dd, update_demo),
):
    dropdown.observe(handler, names="value")

# --- Display all dropdowns together + output ---
VBox([comp_dd, season_dd, match_dd, shot_dd, output])


VBox(children=(Dropdown(description='Competition:', options=('Select competition...', '1. Bundesliga', 'Africa…