## Data visualization

In [1]:
# Open .pkl data

import pickle
import os
import numpy as np
import matplotlib.pyplot as plt

file_path = 'data/trajectories.pkl'
if os.path.exists(file_path):
    # SECURITY: pickle.load can execute arbitrary code while deserialising,
    # so only open trajectory files that come from a trusted source.
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
else:
    # `data` stays undefined in this branch.  Report it here, otherwise the
    # first visible symptom is a NameError in whichever cell uses `data` next.
    print(f"Warning: '{file_path}' not found - `data` was not loaded.")



In [8]:
## Get the different types of events
# Each trajectory is iterated directly (no per-trajectory `traj[:]` copy) and
# the distinct 'action' labels are gathered into a set in a single pass.
evnt_types = {event['action'] for traj in data for event in traj}

In [25]:

from __future__ import annotations

import argparse
import json
import math
import pickle
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
import torch

# --------------------------------------------------------------------------- #
#                         Field & action constants                            #
# --------------------------------------------------------------------------- #
# Pitch dimensions; build_state() allocates its grids as (FIELD_X, FIELD_Y).
FIELD_X: int = 104  # metres (length)
FIELD_Y: int = 68   # metres (width)

# Raw event label -> action-class index.  Several labels share one class:
# pass/cross -> 0, dribble/take_on -> 1, shot -> 2.
ACTION_ID = {
    "pass": 0, "cross": 0,
    "dribble": 1, "take_on": 1,
    "shot": 2,
}

# One-hot length.  Derived from the mapping so that adding a new action class
# cannot leave this constant out of sync with ACTION_ID.
N_ACTIONS: int = max(ACTION_ID.values()) + 1

# --------------------------------------------------------------------------- #
#                      Core encoding helpers                                  #
# --------------------------------------------------------------------------- #

def encode_action(evt: Dict[str, Any]) -> np.ndarray:
    """Encode one event's action as a flat float32 vector.

    Layout: ``[one-hot(N_ACTIONS), x_start, y_start, x_end, y_end]``, i.e.
    7 values when there are 3 action classes and positions are 2-D.

    Parameters
    ----------
    evt
        Event record.  Reads ``evt["action"]`` (a key of ``ACTION_ID``),
        ``evt["ball_start"]`` and, optionally, ``evt["ball_end"]``.

    Returns
    -------
    np.ndarray
        1-D float32 array: the one-hot action class followed by the start
        and end ball positions.

    Raises
    ------
    ValueError
        If ``evt["action"]`` is not a key of ``ACTION_ID``.
    """
    one_hot = np.zeros(N_ACTIONS, dtype=np.float32)
    aid = ACTION_ID.get(evt["action"], None)
    if aid is None:
        raise ValueError(f"Unknown action: {evt['action']}")
    one_hot[aid] = 1.0

    start = np.asarray(evt["ball_start"], dtype=np.float32)

    # The end position falls back to the start position when it is missing.
    # An explicit None is treated like a missing key: converting None to
    # float32 gives a 0-d NaN array, which np.concatenate refuses to join.
    raw_end = evt.get("ball_end")
    end = start if raw_end is None else np.asarray(raw_end, dtype=np.float32)

    return np.concatenate([one_hot, start, end])


def build_state(evt: Dict[str, Any]) -> np.ndarray:
    """Rasterise one event into a float32 array of shape (4, FIELD_X, FIELD_Y).

    Channels along the first axis:
      0 - teammate occupancy: 1.0 in every cell holding a player whose
          ``"teammate"`` entry is truthy, 0.0 elsewhere
      1 - opponent occupancy: same, for all other players
      2 - Euclidean distance (in grid cells) from each cell to
          ``evt["ball_start"]``
      3 - ``arctan2(dy, dx) / pi`` of each cell relative to the ball,
          i.e. the angle scaled to [-1, 1]

    Players are read from ``evt["player_loc"].values[0]``, iterated as records
    with ``"location"`` (x, y) and ``"teammate"`` entries.
    NOTE(review): presumably ``player_loc`` is a pandas object wrapping the
    per-event player list - confirm against the data loader.
    """
    grid_shape = (FIELD_X, FIELD_Y)
    teammates = np.zeros(grid_shape, dtype=np.float32)
    opponents = np.zeros(grid_shape, dtype=np.float32)

    # Occupancy: clamp each player onto the grid, then mark the nearest cell.
    for player in evt["player_loc"].values[0]:
        px, py = player["location"]
        ix = int(round(np.clip(px, 0, FIELD_X - 1)))
        iy = int(round(np.clip(py, 0, FIELD_Y - 1)))
        layer = teammates if player["teammate"] else opponents
        layer[ix, iy] = 1.0

    # Per-cell offsets from the ball; the column/row vectors broadcast to the
    # full (FIELD_X, FIELD_Y) grid in the expressions below.
    ball_x, ball_y = evt["ball_start"]
    x_off = np.arange(FIELD_X)[:, np.newaxis] - ball_x  # (FIELD_X, 1)
    y_off = np.arange(FIELD_Y)[np.newaxis, :] - ball_y  # (1, FIELD_Y)

    distance = np.sqrt(x_off * x_off + y_off * y_off, dtype=np.float32)
    bearing = np.arctan2(y_off, x_off, dtype=np.float32) / math.pi

    channels = [teammates, opponents, distance, bearing]
    return np.stack(channels, axis=0).astype(np.float32)


# --------------------------------------------------------------------------- #
#                   Reward & return computation                               #
# --------------------------------------------------------------------------- #

def reward_fn(evt: Dict[str, Any]) -> float:
    """Sparse proxy reward: 1.0 for a shot with a truthy outcome, else 0.0.

    A missing ``"outcome"`` key counts as a failed shot.
    NOTE(review): any truthy ``"outcome"`` value earns the reward - confirm
    the field is boolean-like rather than e.g. a status string.
    """
    if evt["action"] != "shot":
        return 0.0
    if evt.get("outcome", False):
        return 1.0
    return 0.0


def discount_cumsum(rewards: List[float], gamma: float) -> List[float]:
    """Return the discounted return-to-go for every position in ``rewards``.

    ``out[t] = rewards[t] + gamma * out[t + 1]``, with the value past the
    final position taken as 0.0.  The result has the same length as the
    input; an empty input gives an empty list.
    """
    running = 0.0
    returns_reversed: List[float] = []
    # Walk backwards so each step reuses the return of the step after it.
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns_reversed.append(running)
    returns_reversed.reverse()
    return returns_reversed


# --------------------------------------------------------------------------- #
#                        Processing pipeline                                  #
# --------------------------------------------------------------------------- #

def process_trajectories(data: List[List[Dict[str, Any]]],
                         gamma: float = 0.99) -> List[List[Dict[str, Any]]]:
    """Convert raw trajectories into per-step training records.

    Parameters
    ----------
    data
        Trajectories, each an indexable sequence of event records accepted by
        ``build_state``, ``encode_action`` and ``reward_fn``.
    gamma
        Discount factor used for the Monte-Carlo returns.

    Returns
    -------
    list of list of dict
        One list of steps per input trajectory, in input order.  Each step
        holds ``state`` and ``action`` (torch tensors), ``reward``,
        ``next_state`` / ``next_action`` (``None`` on the last step),
        ``done`` (``True`` only on the last step), ``traj_id``, ``t`` and
        ``G`` (the discounted return from that step onwards).  An empty
        trajectory yields an empty step list, so output indices stay aligned
        with ``data``.
    """
    processed: List[List[Dict[str, Any]]] = []

    for traj_id, traj in enumerate(data):
        # Encode states & actions first; reward and next_* are filled below.
        steps: List[Dict[str, Any]] = [
            {
                "state": torch.tensor(build_state(evt)),
                "action": torch.tensor(encode_action(evt)),
                "reward": 0.0,
                "next_state": None,
                "next_action": None,
                "done": False,
                "traj_id": traj_id,
                "t": t,
            }
            for t, evt in enumerate(traj)
        ]

        if not steps:
            # Nothing to link or reward.  Without this guard, steps[-1]
            # below would raise IndexError on an empty trajectory.
            processed.append(steps)
            continue

        # Link every step to its successor; the last step keeps None.
        for current, following in zip(steps, steps[1:]):
            current["next_state"] = following["state"]
            current["next_action"] = following["action"]

        # Only the terminal step is flagged done and carries the reward.
        last = steps[-1]
        last["done"] = True
        last["reward"] = reward_fn(traj[-1])

        # Monte-Carlo returns computed from the per-step rewards.
        returns = discount_cumsum([s["reward"] for s in steps], gamma)
        for s, G in zip(steps, returns):
            s["G"] = G

        processed.append(steps)
    return processed

In [None]:
# Sanity check: build the state array for the first event of the first
# trajectory (as the cell's last expression, the notebook echoes the result).
build_state(data[0][0])

In [32]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

def plot_soccer_state(state: np.ndarray,
                      teammate_thresh: float = 0.5,
                      opponent_thresh: float = 0.5):
    """
    Visualise a channels-first state tensor (4×104×68 by default).

    The pitch is drawn at the size of the state's grid (axes 1 and 2), so
    the outline always matches the cells the players and ball are read from.

    Parameters
    ----------
    state : np.ndarray
        Tensor with channels:
          0 – teammate occupancy   (binary/sparse)
          1 – opponent occupancy   (binary/sparse)
          2 – distance-to-ball map (float, 0 at ball)
          3 – angle-to-ball map    (unused here)
    teammate_thresh, opponent_thresh : float
        Value above which a grid-cell is considered occupied.

    Raises
    ------
    ValueError
        If ``state`` is not 3-D or has fewer than the three channels read here.
    """
    shape = tuple(np.shape(state))
    if np.ndim(state) != 3 or shape[0] < 3:
        raise ValueError(
            "state must have shape (channels >= 3, length, width); "
            f"got {shape}"
        )

    # ------------------------------------------------------------------ #
    # 1. Unpack channels & locate entities
    teammates = state[0] > teammate_thresh
    opponents = state[1] > opponent_thresh

    # player coordinates (x = length, y = width)
    tx, ty = np.where(teammates)
    ox, oy = np.where(opponents)

    # ball = cell of minimum distance
    bx, by = np.unravel_index(state[2].argmin(), state[2].shape)

    # ------------------------------------------------------------------ #
    # 2. Draw the pitch, sized from the grid rather than hard-coded
    L, W = shape[1], shape[2]
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.add_patch(Rectangle((0, 0), L, W, fill=False, lw=2, color="black"))
    ax.plot([L / 2, L / 2], [0, W], color="black")       # half-way line

    # ------------------------------------------------------------------ #
    # 3. Scatter entities
    ax.scatter(tx, ty, c="blue",  s=40, label="Teammate")
    ax.scatter(ox, oy, c="red",   s=40, label="Opponent")
    ax.scatter(bx, by, c="orange", s=120, edgecolors="black",
               marker="o", label="Ball")

    # Aesthetics -------------------------------------------------------- #
    ax.set_xlim(0, L)
    ax.set_ylim(0, W)
    ax.set_aspect("equal")
    ax.invert_yaxis()                    # origin top-left like tracking data
    ax.set_xlabel("Metres (length)")
    ax.set_ylabel("Metres (width)")
    ax.set_title("Soccer state snapshot")
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the state built from the event at index 6 of the first trajectory.
plot_soccer_state(build_state(data[0][6]))