# Expected Goals (xG) Demo Notebook

This notebook provides an **interactive demo** for visualizing football shots and comparing **expected goals (xG) predictions** from multiple models against the true xG value provided by **StatsBomb Open Data**.

Users can select a **competition, season, match, and specific shot** through dropdown menus, after which the pitch visualization is displayed, showing the shooter’s location, the shot trajectory, and the freeze frame with the positions of teammates, opponents, and the goalkeeper. 

The notebook then computes and compares the xG values produced by the previously trained models: a baseline Linear Regression trained on DS3, and a Random Forest, an XGBoost model, and a Neural Network trained on DS4. The true xG value from the dataset is shown alongside these predictions.

The workflow consists of loading the pre-trained models and feature datasets, navigating interactively through the dropdowns to select the desired shot, visualizing the shot in its tactical context, and finally running all available models to generate predictions and compare them with the ground truth.

## Imports and Global Settings

This section prepares the competitions dataset: it loads the local file, drops the trailing TOTAL row, and merges the result with the StatsBomb competitions table so that each competition and season is associated with its official IDs. Only the columns needed later are retained.
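Because the merge uses `how="left"`, a competition/season pair whose names do not match the StatsBomb spelling exactly keeps its row but ends up with missing IDs. The following is a minimal sketch of how such rows could be detected with the `indicator=True` option of `DataFrame.merge`. The frames `local` and `remote` and all their values are toy data made up for illustration; they are not taken from the real files.

```python
import pandas as pd

# Toy stand-ins for df_competitions and sb.competitions() (illustrative values only)
local = pd.DataFrame({
    "Competition": ["La Liga", "Serie A", "Indian Super league"],
    "Season": ["2015/2016", "2015/2016", "2021/2022"],
})
remote = pd.DataFrame({
    "competition_name": ["La Liga", "Serie A"],
    "season_name": ["2015/2016", "2015/2016"],
    "competition_id": [1, 2],
    "season_id": [10, 10],
})

merged = local.merge(
    remote,
    left_on=["Competition", "Season"],
    right_on=["competition_name", "season_name"],
    how="left",
    indicator=True,  # adds a "_merge" column: "both" or "left_only"
)

# Rows without a counterpart on the right side keep NaN IDs
unmatched = merged.loc[merged["_merge"] == "left_only", ["Competition", "Season"]]
print(unmatched)
```

Rows listed in `unmatched` would later fail when their IDs are passed to the StatsBomb API, so checking them early can save a confusing error in the demo.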

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
from mplsoccer import Pitch
from joblib import load 
from statsbombpy import sb
import numpy as np
import ast
import os
from IPython.display import display, HTML

from ipywidgets import Dropdown, VBox, HBox, Output, Label

import warnings
warnings.filterwarnings("ignore")

# Load competitions dataset
df_competitions = pd.read_csv("../task1_xg/data/df_competitions.csv")
df_competitions = df_competitions.iloc[:-1]   # remove last row (TOTAL)

# StatsBomb competitions with IDs
df_sb_competitions = sb.competitions()

# Merge to add IDs of StatsBomb competitions and seasons
df_competitions = df_competitions.merge(
    df_sb_competitions,
    left_on=["Competition", "Season"],
    right_on=["competition_name", "season_name"],
    how="left"
)

# Keep only relevant columns and rename
df_competitions = df_competitions[
    ["Competition", "Season", "competition_id", "season_id"]
].rename(columns={
    "competition_id": "Competition_ID",
    "season_id": "Season_ID"
})

# Quick checks
print("Competitions dataset shape:", df_competitions.shape)
print("\nColumns:", df_competitions.columns.tolist())
print("\nCompetitions available:", df_competitions['Competition'].unique())

Competitions dataset shape: (75, 4)

Columns: ['Competition', 'Season', 'Competition_ID', 'Season_ID']

Competitions available: ['FIFA World Cup' 'Champions League' 'La Liga' 'Copa del Rey'
 'North American League' 'FIFA U20 World Cup' 'Liga Profesional' 'Serie A'
 'UEFA Europa League' 'Premier League' '1. Bundesliga' 'Ligue 1'
 "FA Women's Super League" 'NWSL' "Women's World Cup" 'UEFA Euro'
 'Indian Super league' "UEFA Women's Euro" 'African Cup of Nations'
 'Major League Soccer' 'Copa America']


The cells below integrate additional shot information from the raw dataset into DS4 to make the data interpretable.

While DS4 contains normalized features for modeling, the original shot and end-shot locations in `shots_df` are preserved in their unnormalized form. These values are unpacked into explicit numerical coordinates, ensuring that **visualizations** are based on the actual dimensions rather than normalized values. Unnecessary raw fields are then removed to keep the dataset concise.
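The unpacking step relies on `ast.literal_eval` to turn the stringified coordinate lists back into Python lists. Below is a minimal sketch of a helper that parses each string only once and tolerates missing values. The name `parse_xy` is hypothetical and not part of the notebook, and the sample coordinates are made up.

```python
import ast
import math

def parse_xy(raw):
    """Return (x, y) as floats from a string such as "[101.5, 38.2]".

    A third coordinate (z), if present, is ignored.
    Missing values (None or NaN) yield (nan, nan) instead of raising.
    """
    if raw is None or (isinstance(raw, float) and math.isnan(raw)):
        return (math.nan, math.nan)
    coords = ast.literal_eval(raw)  # evaluates Python literals only, never arbitrary code
    return (float(coords[0]), float(coords[1]))

print(parse_xy("[101.5, 38.2]"))       # (101.5, 38.2)
print(parse_xy("[120.0, 41.3, 1.2]"))  # (120.0, 41.3)
```

Applied to a DataFrame column, a helper like this could replace the two separate `apply` calls per field, since each string would be evaluated once rather than twice.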

In [4]:
# Load DS4 dataset
df_ds4 = pd.read_csv("../task1_xg/data/DS4.csv")

print("DS4 shape:", df_ds4.shape)
df_ds4.columns

DS4 shape: (86833, 29)


Index(['event_id', 'match_id', 'player_id', 'minute', 'second', 'period',
       'shot_type', 'shot_technique', 'shot_body_part', 'play_pattern',
       'under_pressure', 'shot_first_time', 'shot_one_on_one', 'target_xg',
       'loc_x', 'loc_y', 'end_shot_x', 'end_shot_y', 'end_shot_z',
       'end_shot_z_available', 'shot_from_set_piece', 'distance_to_goal',
       'angle_to_goal', 'gender', 'role', 'num_players_between',
       'closest_defender_dist', 'goalkeeper_positioning', 'free_proj_goal'],
      dtype='object')

In [5]:
# Load extra shot info
shots_extra = pd.read_csv("../task1_xg/data/shots_df.csv", low_memory=False)

# Rename 'id' in shots_extra to 'event_id' to match DS4
shots_extra = shots_extra.rename(columns={"id": "event_id"})

# Keep only the interpretable columns from shots_df
# location and shot_end_location come from shots_df because they are not normalized there
shots_extra = shots_extra[[
    "event_id",
    "location", "shot_end_location", "shot_freeze_frame", "shot_outcome",
    "player", "team",
    "minute", "second"
]]

# Reduce DS4 to just what we need for the analysis
df_ds4 = df_ds4[["event_id", "match_id", "target_xg"]]

# Merge with interpretable info
df_ds4 = df_ds4.merge(shots_extra, on="event_id", how="left")

# Unpack location safely
df_ds4["loc_x"] = df_ds4["location"].apply(
    lambda loc: float(ast.literal_eval(loc)[0])
)
df_ds4["loc_y"] = df_ds4["location"].apply(
    lambda loc: float(ast.literal_eval(loc)[1])
)

# Unpack shot_end_location safely
df_ds4["end_shot_x"] = df_ds4["shot_end_location"].apply(
    lambda loc: float(ast.literal_eval(loc)[0])
)
df_ds4["end_shot_y"] = df_ds4["shot_end_location"].apply(
    lambda loc: float(ast.literal_eval(loc)[1])
)

# Drop raw location lists (if not needed anymore)
df_ds4 = df_ds4.drop(columns=["location", "shot_end_location"])

print("DS4 final shape:", df_ds4.shape)
print("DS4 final columns:", df_ds4.columns.tolist())

DS4 final shape: (86833, 13)
DS4 final columns: ['event_id', 'match_id', 'target_xg', 'shot_freeze_frame', 'shot_outcome', 'player', 'team', 'minute', 'second', 'loc_x', 'loc_y', 'end_shot_x', 'end_shot_y']


## Load all info about competitions, seasons, matches and shots

This cell defines the helper functions that load competitions, seasons, matches, and shots. Their results populate the dropdown menus in the demo.

Competitions and seasons are retrieved from the prepared dataset, while matches are queried through the StatsBomb API using the correct identifiers. Shots are filtered directly from DS4, which already contains enriched and interpretable shot information.

In [6]:
def load_competitions():
    """
    Return unique competition names from df_competitions
    """
    return sorted(df_competitions["Competition"].unique())


def load_seasons(competition: str):
    """
    Return available seasons for a given competition
    """
    seasons = df_competitions.loc[df_competitions["Competition"] == competition, "Season"].unique()

    return sorted(seasons)

def load_matches(competition: str, season: str):
    """
    Load matches for a given competition and season using statsbombpy.
    If no matches are found, return None.
    """

    # Retrieve all rows for the given competition and season
    rows = df_competitions.loc[
        (df_competitions["Competition"] == competition)
        & (df_competitions["Season"] == season)
    ]

    # Check if any rows were found
    if rows.empty:
        print(f"No matches found for {competition} - {season}.")
        return None

    # Get competition and season IDs
    comp_id = rows["Competition_ID"].iloc[0]
    season_id = rows["Season_ID"].iloc[0]

    # Retrieve matches from StatsBomb
    matches = sb.matches(competition_id=comp_id, season_id=season_id)

    # Check if any matches were found
    if matches.empty:
        print(f"No matches returned for {competition} - {season}")
        return None

    return matches

def load_shots(match_id: int):
    """
    Filter shots from DS4 dataset for the given match_id.
    If no shots are found, print a message and return None.
    """

    # Filter shots from DS4 dataset for the given match_id
    shots = df_ds4[df_ds4["match_id"] == match_id].copy()

    # Check if any shots were found
    if shots.empty:
        print(f"No shots found for match_id={match_id}. "
              "Try selecting another match or competition.")
        return None
    
    return shots


## Loading the available models for xG prediction

This section loads the trained models and prepares the datasets for prediction.  

Models trained on **DS3** (Linear Regression) and **DS4** (Random Forest, XGBoost, Neural Network) are loaded from disk if available.  

The corresponding datasets are also reloaded to extract the feature columns necessary for making predictions.  

A dedicated function, `compute_xg_models`, is defined to compute the **xG value** of a given shot by generating predictions from all available models.  

The function handles different cases, including missing models and feature mismatches.

Results are returned as a dictionary containing the true value and the predictions from each model.
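The case handling described above can be sketched with stand-in models. The classes `ProbModel`, `ValueModel`, and `BrokenModel` are hypothetical placeholders for a classifier exposing `predict_proba`, a regressor exposing only `predict`, and a model that rejects its input. The helper name `predict_xg` is also an assumption. Unlike the notebook, which clips only the Linear Regression output, this sketch clips every prediction to the range [0, 1].

```python
def predict_xg(model, features):
    """Return a rounded xG in [0, 1], or an error string if prediction fails."""
    try:
        if hasattr(model, "predict_proba"):
            # Classifier-style output: one row of [P(no goal), P(goal)]
            value = model.predict_proba(features)[0][1]
        else:
            # Regressor-style output: one value per input row
            value = model.predict(features)[0]
        # Clip to the valid probability range before rounding
        return round(min(max(float(value), 0.0), 1.0), 2)
    except Exception as exc:
        return f"Feature mismatch ({exc})"


class ProbModel:
    def predict_proba(self, features):
        return [[0.8, 0.2] for _ in features]


class ValueModel:
    def predict(self, features):
        return [1.3 for _ in features]  # out of range on purpose


class BrokenModel:
    def predict(self, features):
        raise ValueError("wrong number of features")


row = [[0.5]]  # one shot with a single dummy feature
print(predict_xg(ProbModel(), row))    # 0.2
print(predict_xg(ValueModel(), row))   # 1.0
print(predict_xg(BrokenModel(), row))  # Feature mismatch (wrong number of features)
```

Returning a string on failure, rather than raising, mirrors the notebook's approach: one failing model does not prevent the others from being displayed.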


In [7]:
# Models trained on DS3
models_ds3 = {}
if os.path.exists("../task1_xg/models/model_linear_regression.pkl"):
    models_ds3["Linear Regression"] = load("../task1_xg/models/model_linear_regression.pkl")

# Models trained on DS4
models_ds4 = {}
if os.path.exists("../task1_xg/models/model_rf.pkl"):
    models_ds4["Random Forest"] = load("../task1_xg/models/model_rf.pkl")
if os.path.exists("../task1_xg/models/model_xgboost.pkl"):
    models_ds4["XGBoost"] = load("../task1_xg/models/model_xgboost.pkl")
if os.path.exists("../task1_xg/models/model_nn.pkl"):
    models_ds4["Neural Net"] = load("../task1_xg/models/model_nn.pkl")

print("Loaded DS3 models:", list(models_ds3.keys()))
print("Loaded DS4 models:", list(models_ds4.keys()))

Loaded DS3 models: ['Linear Regression']
Loaded DS4 models: ['Random Forest', 'XGBoost', 'Neural Net']


In [8]:
# Load datasets to get feature columns
ds3_for_prediction = pd.read_csv("../task1_xg/data/DS3.csv")
ds4_for_prediction = pd.read_csv("../task1_xg/data/DS4.csv")

# Get training columns (event_id is kept for now so compute_xg_models can look up the shot; it is dropped before prediction)
train_columns_ds3 = [c for c in ds3_for_prediction.columns if c not in ["target_xg", "match_id", "player_id"]]
train_columns_ds4 = [c for c in ds4_for_prediction.columns if c not in ["target_xg", "match_id", "player_id"]]

# Keep only relevant columns for prediction + "event_id"
ds3_for_prediction = ds3_for_prediction[train_columns_ds3]
ds4_for_prediction = ds4_for_prediction[train_columns_ds4]

In [9]:
def compute_xg_models(shot_row):
    results = {}
    event_id = shot_row["event_id"]

    # True value
    true_val = shot_row.get("target_xg", "N/A")
    results["True_xG"] = round(float(true_val), 2) if pd.notna(true_val) else "N/A"

    # DS3 (Linear Regression)
    if "Linear Regression" in models_ds3:
        try:
            # Get the shot and its features for prediction
            row3 = ds3_for_prediction.loc[ds3_for_prediction["event_id"] == event_id].copy()

            # Check if the row is empty
            if not row3.empty:

                # Drop the event_id column and prepare the feature matrix
                row3 = row3.drop(columns=["event_id"])
                X3 = row3[[c for c in train_columns_ds3 if c in row3.columns]].astype(float)

                # Get the model
                model = models_ds3["Linear Regression"]

                # Get prediction
                results["Linear Regression"] = round(float(np.clip(model.predict(X3)[0], 0, 1)), 2)
            else:
                results["Linear Regression"] = "Shot not found in DS3"
        except Exception as e:
            results["Linear Regression"] = f"Feature mismatch ({e})"
    else:
        results["Linear Regression"] = "Model not available, try to train it and save in /task1_xg/models/"

    # DS4 (RF, XGB, NN)
    for name in ["Random Forest", "XGBoost", "Neural Net"]:
        if name in models_ds4:
            try:
                # Get the shot and its features for prediction
                row4 = ds4_for_prediction.loc[ds4_for_prediction["event_id"] == event_id].copy()

                # Check if the row is empty
                if not row4.empty:

                    # Drop the event_id column and prepare the feature matrix
                    row4 = row4.drop(columns=["event_id"])
                    X4 = row4[[c for c in train_columns_ds4 if c in row4.columns]].astype(float)

                    # Get the model
                    model = models_ds4[name]

                    # Get prediction (manage different output formats)
                    if hasattr(model, "predict_proba"):
                        results[name] = round(float(model.predict_proba(X4)[0, 1]), 2)
                    else:
                        results[name] = round(float(model.predict(X4)[0]), 2)
                else:
                    results[name] = "Shot not found in DS4"
            except Exception as e:
                results[name] = f"Feature mismatch ({e})"
        else:
            results[name] = "Model not available, try to train it and save in /task1_xg/models/"

    return results


## Shot Visualization with Freeze-Frame Context

This cell defines the function `plot_shot()`, which is used to **create a StatsBomb-style pitch and visualize a selected shot** together with the corresponding freeze-frame context.  

The visualization includes:  

- **Shooter position** (highlighted in lime green)  

- **Freeze-frame players**: 

  - Goalkeeper (yellow)  

  - Teammates (sky blue)  

  - Opponents (orange-red) 

- **Shot trajectory** (dashed red line from shooter to end location)  

- **Legend** automatically cleaned to avoid duplicates  
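The legend cleanup relies on a dictionary keeping a single entry per key. Here is a minimal sketch of that behaviour, with plain strings standing in for matplotlib artist handles (the labels and handle names are illustrative only):

```python
# Labels as they would come back from ax.get_legend_handles_labels():
# one entry per plotted artist, so repeated roles appear several times.
labels = ["Opponent", "Teammate", "Opponent", "Goalkeeper", "Teammate", "Shooter"]
handles = [f"handle_{i}" for i in range(len(labels))]  # placeholders for artists

# A dict keeps the first position of each key and the last value assigned to it
by_label = dict(zip(labels, handles))

print(list(by_label.keys()))    # ['Opponent', 'Teammate', 'Goalkeeper', 'Shooter']
print(list(by_label.values()))  # ['handle_2', 'handle_4', 'handle_3', 'handle_5']
```

Since all markers with the same label share the same style, it does not matter which of the duplicate handles survives.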


In [10]:
# Plotting function
def plot_shot(x, y, end_x=None, end_y=None, players=None):
    """
    Create a new pitch and plot the selected shot with freeze-frame context.
    """
    pitch = Pitch(pitch_type='statsbomb', pitch_length=120, pitch_width=80,
                  pitch_color="grass", line_color="white")
    fig, ax = pitch.draw(figsize=(8, 5))

    # Freeze-frame players
    if players:

        # Plot each player from the freeze-frame context
        for p in players:

            # Get player location
            px, py = p["location"]

            # Goalkeeper (recognize by position)
            if p.get("position", {}).get("id") == 1 or \
            p.get("position", {}).get("name", "").lower() == "goalkeeper":
                facecolor, edgecolor, marker, label = "yellow", "black", "o", "Goalkeeper"

            # Teammates
            elif p.get("teammate"):
                facecolor, edgecolor, marker, label = "skyblue", "black", "o", "Teammate"

            # Opponents
            else:
                facecolor, edgecolor, marker, label = "orangered", "black", "o", "Opponent"

            # Plot the players
            pitch.scatter(
                px, py, s=100, marker=marker,
                facecolor=facecolor, edgecolor=edgecolor,
                linewidth=1.2, ax=ax, alpha=1, label=label
            )

    # Shooter
    pitch.scatter(x, y, s=100, marker='o',
                  facecolor="lime", edgecolor="black", linewidth=1.2,
                  ax=ax, label="Shooter")

    # Shot trajectory
    if end_x is not None and end_y is not None:
        ax.plot([x, end_x], [y, end_y], color="red", linestyle="--",
                linewidth=2, label="Shot line")

    # Legend (clean duplicates for each player)
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc="upper left")

    return fig


## Shot Information and Model Predictions

This cell defines a helper function `show_predictions()` used to **display the details of a selected shot in a clean and styled format**.  
It shows:  
- **Player and team information**  

- **True xG** value (from StatsBomb data)  

- **Predictions from all available models** in a well-formatted table 
 
- **Shot outcome**, highlighted in **green if it is a goal** and in **red otherwise**

In [11]:
def show_predictions(shot_row, preds):
    """Pretty print predictions and info with styled HTML."""

    # Player info
    player_info = f"<span style='font-weight:bold; color:#222;'>PLAYER:</span> {shot_row['player']} ({shot_row['team']})"

    # Outcome styling
    outcome_text = shot_row["shot_outcome"].upper()
    if outcome_text == "GOAL":
        outcome = f"<span style='font-weight:bold; color:#222;'>SHOT OUTCOME:</span> <span style='color:green; font-weight:bold;'>{outcome_text}</span>"
    else:
        outcome = f"<span style='font-weight:bold; color:#222;'>SHOT OUTCOME:</span> <span style='color:#B22222; font-weight:bold;'>{outcome_text}</span>"

    # True_xG
    true_xg_val = preds.get("True_xG", None)
    if isinstance(true_xg_val, (int, float)):
        true_xg_str = f"<span style='color:#006400;'>{true_xg_val:.2f}</span>"
    else:
        true_xg_str = f"<span style='color:#B22222; font-weight:bold;'>NOT AVAILABLE</span>"

    # Table rows
    rows = ""
    for model, value in preds.items():
        if model == "True_xG":
            continue  # skip, shown above
        if isinstance(value, str):  # error message: missing model, missing shot, or feature mismatch
            rows += f"""
                <tr>
                    <td style='padding:6px 12px; border:1px solid #ddd; text-align:center;'>{model}</td>
                    <td style='padding:6px 12px; border:1px solid #ddd; color:#B22222; font-weight:bold; text-align:center;'>{value.upper()}</td>
                </tr>
            """
        else:
            rows += f"""
                <tr>
                    <td style='padding:6px 12px; border:1px solid #ddd; text-align:center;'>{model}</td>
                    <td style='padding:6px 12px; border:1px solid #ddd; text-align:center;'>{value:.2f}</td>
                </tr>
            """

    # Final HTML
    html = f"""
    <div style="font-family:Arial, sans-serif; font-size:14px; line-height:1.6; color:#333;">
        <p style="margin:4px 0;">{player_info}</p>
        <p style="margin:4px 0; font-weight:bold; color:#444;">TRUE_XG: {true_xg_str}</p>
        
        <table style="border-collapse:collapse; margin:10px 0; width:100%; max-width:400px;">
            <thead>
                <tr style="background:#f2f2f2;">
                    <th style="padding:6px 12px; border:1px solid #ddd; text-align:center;">Model</th>
                    <th style="padding:6px 12px; border:1px solid #ddd; text-align:center;">Prediction</th>
                </tr>
            </thead>
            <tbody>
                {rows}
            </tbody>
        </table>
        
        <p style="margin:4px 0;">{outcome}</p>
    </div>
    """
    display(HTML(html))

## DEMO

This section defines the interactive logic for navigating through competitions, seasons, matches, and individual shots.  
Dropdown menus are dynamically updated based on the user’s selections, ensuring a step-by-step flow from:

competition → season → match → shot

When a shot is selected, the corresponding freeze-frame and shot trajectory are visualized on a pitch, and predictions from the trained models are displayed.  

Four update functions handle this flow:  

- **update_seasons**: populates the season dropdown after selecting a competition.  

- **update_matches**: loads matches for the chosen season and updates the match dropdown.  

- **update_shots**: lists all available shots for the selected match.  

- **update_demo**: displays the shot context on the pitch and outputs the model predictions alongside the actual shot outcome.  
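The match dropdown carries the match ID inside its label, in the form `(ID 12345)`, and the update functions recover it by splitting the string. A slightly more defensive alternative, sketched under the assumption that the label format stays as built in `update_matches`, is a regular expression. The helper name `extract_match_id` and the sample label are hypothetical.

```python
import re

# Matches the "(ID 12345)" suffix at the end of a match label
_MATCH_ID_RE = re.compile(r"\(ID\s+(\d+)\)\s*$")

def extract_match_id(label):
    """Return the integer match ID from a dropdown label, or None if absent."""
    if not isinstance(label, str):
        return None
    found = _MATCH_ID_RE.search(label)
    return int(found.group(1)) if found else None

print(extract_match_id("MW 3 - Home FC vs Away FC (ID 12345)"))  # 12345
print(extract_match_id("Select match..."))                       # None
print(extract_match_id("No matches available"))                  # None
```

Because the pattern is anchored to the end of the label, a team name that happens to contain the letters "ID" would not be mistaken for the identifier.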

In [None]:
# Output containers
output = Output()         # text
pitch_output = Output()   # field

# Dropdowns
comp_dd = Dropdown(options=["Select competition..."] + load_competitions(), description="Competition:")
season_dd = Dropdown(options=["Select season..."], description="Season:", disabled=True)            # Disabled before choosing competition
match_dd = Dropdown(options=["Select match..."], description="Match:", disabled=True)               # Disabled before choosing season
shot_dd = Dropdown(options=["Select shot..."], description="Shot:", disabled=True)                  # Disabled before choosing match    


# Update logic
def update_seasons(change):
    """Update seasons when a competition is selected"""

    # Load seasons for the selected competition
    # change["new"] is the selected competition because it comes from the dropdown
    if change["new"] and change["new"] != "Select competition...":

        # Load seasons for the selected competition
        seasons = load_seasons(change["new"])

        # Update the season dropdown and enable it
        season_dd.options = ["Select season..."] + list(seasons)
        season_dd.disabled = False

        # Update match and shot dropdowns
        match_dd.options, shot_dd.options = ["Select match..."], ["Select shot..."]
        match_dd.disabled, shot_dd.disabled = True, True
    else:
        # Reset all dropdowns
        season_dd.options, match_dd.options, shot_dd.options = ["Select season..."], ["Select match..."], ["Select shot..."]
        season_dd.disabled, match_dd.disabled, shot_dd.disabled = True, True, True


def update_matches(change):
    """Update matches when a season is selected"""

    # Load matches for the selected season
    # change["new"] is the selected season because it comes from the dropdown
    if change["new"] and change["new"] != "Select season...":

        # Load matches for the selected season
        matches = load_matches(comp_dd.value, change["new"])

        # Update the match dropdown and enable it if matches are found
        if matches is not None and not matches.empty:
            matches.sort_values(by=["match_week", "kick_off"], inplace=True)    # Sort matches by week and kick-off time
            
            match_dd.options = ["Select match..."] + [
                f"MW {row.get('match_week','?')} - {row['home_team']} vs {row['away_team']} (ID {row['match_id']})"
                for _, row in matches.iterrows()
            ]
        else:
            # No matches found
            match_dd.options = ["No matches available"]
        match_dd.disabled, shot_dd.disabled = False, True
        shot_dd.options = ["Select shot..."]
    else:

        # Reset all dropdowns
        match_dd.options, shot_dd.options = ["Select match..."], ["Select shot..."]
        match_dd.disabled, shot_dd.disabled = True, True


def update_shots(change):
    """Update shot options when a match is selected"""

    # Load shots for the selected match
    # change["new"] is the selected match because it comes from the dropdown
    if change["new"] and "ID" in change["new"]:

        # Get the match ID from the selected match string
        match_id = int(change["new"].split("ID")[-1].strip(") "))

        # Load shots for the selected match by match ID
        shots_df = df_ds4[df_ds4["match_id"] == match_id].copy()

        # Check if shots_df is not None and not empty
        if shots_df is not None and not shots_df.empty:

            # Update the shot dropdown and enable it
            shot_dd.options = [("Select shot...", None)] + [
                (f"Minute {row['minute']}:{row['second']} - {row['player']} ({row['team']})", str(row["event_id"]))
                for _, row in shots_df.iterrows()
            ]
        else:
            # No shots found
            shot_dd.options = [("This match has no shots", None)]
        shot_dd.disabled = False
    else:
        # Reset shot dropdown
        shot_dd.options = [("Select shot...", None)]
        shot_dd.disabled = True


def update_demo(change):
    """When a shot is selected, update pitch and predictions"""

    # Clear previous outputs
    with output:
        output.clear_output()
    with pitch_output:
        pitch_output.clear_output()

    # Get the event ID from the selected shot (None for placeholder entries)
    event_id = change["new"]
    if event_id is None:
        return

    # Retrieve the shots of the selected match
    match_id = int(match_dd.value.split("ID")[-1].strip(") "))
    shots = df_ds4[df_ds4["match_id"] == match_id]
    if shots.empty:
        with output:
            print("No shots found for this match in DS4")
        return

    # Get the shot row; stop if the selected event is not part of this match
    selected = shots.loc[shots["event_id"] == event_id]
    if selected.empty:
        with output:
            print("Selected shot not found in DS4")
        return
    shot_row = selected.iloc[0]

    # Position of the shot and its end position
    x, y = shot_row["loc_x"], shot_row["loc_y"]
    end_x, end_y = shot_row["end_shot_x"], shot_row["end_shot_y"]

    # Parse freeze frame
    freeze = None
    if pd.notna(shot_row["shot_freeze_frame"]):
        try:
            freeze = ast.literal_eval(shot_row["shot_freeze_frame"])
        except Exception as e:
            with output:
                print("Could not parse freeze_frame:", e)

    # Draw new pitch for this shot
    fig = plot_shot(x, y, end_x=end_x, end_y=end_y, players=freeze)
    with pitch_output:
        display(fig)
        plt.close(fig)   # avoids double printing

    # Compute predictions
    preds = compute_xg_models(shot_row)

    # Show predictions (the text output was already cleared above)
    with output:
        show_predictions(shot_row, preds)

In [13]:

# Attach observers
comp_dd.observe(update_seasons, names="value")
season_dd.observe(update_matches, names="value")
match_dd.observe(update_shots, names="value")
shot_dd.observe(update_demo, names="value")

# Layout
ui = HBox(
    [VBox([comp_dd, season_dd, match_dd, shot_dd, output], layout={'padding': '10px'}),
     pitch_output]
)
ui


HBox(children=(VBox(children=(Dropdown(description='Competition:', options=('Select competition...', '1. Bunde…