### The idea here is

In [1]:
import torch
from torch.utils.data import Dataset

class PossessionDataset(Dataset):
    def __init__(self, sequences, cond_vecs, T):
        """
        sequences: list of dicts, each dict has keys:
            'type', 'sz', 'ez', 'out', 'dt', 'term'  -> each is a list of int IDs (len <= T)
        cond_vecs: torch.FloatTensor [N, C]
        """
        self.seqs = sequences
        self.cond = cond_vecs
        self.T = T

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        s = self.seqs[i]
        # pad to T
        def pad(x, pad_id=0):
            x = x[:self.T]
            if len(x) < self.T:
                x = x + [pad_id] * (self.T - len(x))
            return torch.tensor(x, dtype=torch.long)

        type_ids = pad(s["type"], 0)
        sz_ids   = pad(s["sz"],   0)
        ez_ids   = pad(s["ez"],   0)
        out_ids  = pad(s["out"],  0)
        dt_ids   = pad(s["dt"],   0)
        term_ids = pad(s["term"], 0)

        # mask: real tokens are non-pad in type stream
        mask = (type_ids != 0).float()

        # teacher forcing: decoder predicts next step
        x = torch.stack([type_ids, sz_ids, ez_ids, out_ids, dt_ids, term_ids], dim=-1)  # [T, 6]
        return x, self.cond[i], mask


def masked_ce(logits, target, mask):
    # logits: [B,T,V], target: [B,T], mask: [B,T]
    B,T,V = logits.shape
    loss = F.cross_entropy(logits.reshape(B*T, V), target.reshape(B*T), reduction="none")
    loss = loss.reshape(B,T) * mask
    return loss.sum() / (mask.sum() + 1e-8)

def kld(mu, logv):
    return -0.5 * torch.mean(1 + logv - mu.pow(2) - logv.exp())


The history saving thread hit an unexpected error (OperationalError('unable to open database file')).History will not be written to the database.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SeqCVAE(nn.Module):
    def __init__(self, n_types, n_zones, n_out, n_dt, n_term,
                 emb=32, hidden=256, zdim=32, cdim=12):
        super().__init__()

        # embeddings for each categorical stream
        self.type_emb = nn.Embedding(n_types, emb, padding_idx=0)
        self.sz_emb   = nn.Embedding(n_zones, emb, padding_idx=0)
        self.ez_emb   = nn.Embedding(n_zones, emb, padding_idx=0)
        self.out_emb  = nn.Embedding(n_out,   emb, padding_idx=0)
        self.dt_emb   = nn.Embedding(n_dt,    emb, padding_idx=0)
        self.term_emb = nn.Embedding(n_term,  emb, padding_idx=0)

        in_dim = emb * 6

        # encoder
        self.enc_rnn = nn.GRU(in_dim + cdim, hidden, batch_first=True)
        self.to_mu   = nn.Linear(hidden, zdim)
        self.to_logv = nn.Linear(hidden, zdim)

        # decoder
        self.dec_rnn = nn.GRU(in_dim + cdim + zdim, hidden, batch_first=True)

        # heads: predict each categorical field
        self.h_type = nn.Linear(hidden, n_types)
        self.h_sz   = nn.Linear(hidden, n_zones)
        self.h_ez   = nn.Linear(hidden, n_zones)
        self.h_out  = nn.Linear(hidden, n_out)
        self.h_dt   = nn.Linear(hidden, n_dt)
        self.h_term = nn.Linear(hidden, n_term)

    def embed_step(self, x):
        # x: [B, T, 6] -> embeddings concat [B, T, 6*emb]
        t, sz, ez, out, dt, term = x.unbind(dim=-1)
        e = torch.cat([
            self.type_emb(t),
            self.sz_emb(sz),
            self.ez_emb(ez),
            self.out_emb(out),
            self.dt_emb(dt),
            self.term_emb(term),
        ], dim=-1)
        return e

    def encode(self, x, c):
        B, T, _ = x.shape
        e = self.embed_step(x)
        c_rep = c.unsqueeze(1).expand(B, T, c.shape[-1])
        inp = torch.cat([e, c_rep], dim=-1)
        _, h = self.enc_rnn(inp)     # h: [1, B, hidden]
        h = h.squeeze(0)
        mu = self.to_mu(h)
        logv = self.to_logv(h)
        return mu, logv

    def reparam(self, mu, logv):
        std = torch.exp(0.5 * logv)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x_in, z, c):
        # teacher forcing: x_in is shifted right sequence (B,T,6)
        B, T, _ = x_in.shape
        e = self.embed_step(x_in)
        c_rep = c.unsqueeze(1).expand(B, T, c.shape[-1])
        z_rep = z.unsqueeze(1).expand(B, T, z.shape[-1])
        inp = torch.cat([e, c_rep, z_rep], dim=-1)
        out, _ = self.dec_rnn(inp)  # [B,T,hidden]
        return {
            "type": self.h_type(out),
            "sz":   self.h_sz(out),
            "ez":   self.h_ez(out),
            "out":  self.h_out(out),
            "dt":   self.h_dt(out),
            "term": self.h_term(out),
        }

    def forward(self, x, c):
        mu, logv = self.encode(x, c)
        z = self.reparam(mu, logv)

        # shift right for decoder input
        x_in = x.clone()
        x_in[:, 1:] = x[:, :-1]      # teacher forcing
        x_in[:, 0] = 0               # first token PAD (or a START token if you add one)

        logits = self.decode(x_in, z, c)
        return logits, mu, logv


In [3]:
#Even larger loader
import json
from pathlib import Path
import pandas as pd
import numpy as np


def _safe_id(x):
    if isinstance(x, dict) and "id" in x:
        return x["id"]
    return np.nan

def _safe_name(x):
    if isinstance(x, dict) and "name" in x:
        return x["name"]
    return None

def _safe_bool(e: dict, key: str, default=False) -> bool:
    v = e.get(key, default)
    return bool(v) if v is not None else bool(default)

def _parse_timestamp_to_sec(ts: str):
    """
    StatsBomb timestamp usually looks like '00:12:34.123' (HH:MM:SS.sss).
    Returns seconds within the current period; if parsing fails returns np.nan.
    """
    if not isinstance(ts, str):
        return np.nan
    try:
        # split hh:mm:ss(.ms)
        hh, mm, ss = ts.split(":")
        return float(hh) * 3600.0 + float(mm) * 60.0 + float(ss)
    except Exception:
        return np.nan

def _period_offset_seconds(period):
    """
    Simple offsets for regulation time.
    Period 1: 0
    Period 2: 45*60
    Period 3: 90*60 (ET1)
    Period 4: 105*60 (ET2)
    Unknown -> 0
    """
    try:
        p = int(period)
    except Exception:
        return 0.0
    if p == 1:
        return 0.0
    if p == 2:
        return 45.0 * 60.0
    if p == 3:
        return 90.0 * 60.0
    if p == 4:
        return 105.0 * 60.0
    return 0.0

def _generic_outcome_for_event(e: dict):
    """
    Returns (outcome_name, success_bool_or_nan) using StatsBomb conventions.
    - Many nested outcome fields exist; missing often implies 'Complete/Success'.
    - For event types without a notion of outcome, returns (None, np.nan).
    """
    t = _safe_name(e.get("type"))

    # PASS
    if t == "Pass" and isinstance(e.get("pass"), dict):
        out = e["pass"].get("outcome")
        if isinstance(out, dict):
            return out.get("name"), False
        # missing outcome => completed
        return "Complete", True

    # SHOT
    if t == "Shot" and isinstance(e.get("shot"), dict):
        out = e["shot"].get("outcome")
        if isinstance(out, dict):
            # success is ambiguous (goal vs on target etc). keep np.nan for boolean.
            return out.get("name"), np.nan
        return None, np.nan

    # DRIBBLE
    if t == "Dribble" and isinstance(e.get("dribble"), dict):
        out = e["dribble"].get("outcome")
        if isinstance(out, dict):
            name = out.get("name")
            # StatsBomb uses "Complete"/"Incomplete"
            if name is not None:
                return name, (name.lower() == "complete")
        return None, np.nan

    # DUEL
    if t == "Duel" and isinstance(e.get("duel"), dict):
        out = e["duel"].get("outcome")
        if isinstance(out, dict):
            name = out.get("name")
            # Often "Won"/"Lost"/"Success In Play"/etc
            if name is not None:
                low = name.lower()
                if low in ("won", "success", "success in play", "success out"):
                    return name, True
                if low in ("lost", "failure"):
                    return name, False
            return name, np.nan
        return None, np.nan

    # INTERCEPTION
    if t == "Interception" and isinstance(e.get("interception"), dict):
        out = e["interception"].get("outcome")
        if isinstance(out, dict):
            name = out.get("name")
            if name is not None:
                low = name.lower()
                if low in ("won", "success"):
                    return name, True
                if low in ("lost", "failure"):
                    return name, False
            return name, np.nan
        # If no outcome, treat as success-ish
        return "Won", True

    # BALL RECOVERY
    if t == "Ball Recovery" and isinstance(e.get("ball_recovery"), dict):
        fail = e["ball_recovery"].get("recovery_failure")
        if fail is True:
            return "Failure", False
        if fail is False:
            return "Success", True
        return None, np.nan

    # MISCONTROL (always “bad touch” in spirit)
    if t == "Miscontrol":
        return "Miscontrol", False

    # CLEARANCE (no explicit outcome)
    if t == "Clearance":
        return None, np.nan

    # PRESSURE (no explicit success)
    if t == "Pressure":
        return None, np.nan

    # FOUL COMMITTED / WON are separate event types (no outcome field)
    if t in ("Foul Committed", "Foul Won"):
        return None, np.nan

    # DEFAULT
    return None, np.nan

def flatten_events_for_match(sb_data_root: Path, match_row: dict) -> pd.DataFrame:
    match_id = match_row["match_id"]
    p = sb_data_root / "events" / f"{match_id}.json"
    ev = json.loads(p.read_text(encoding="utf-8"))

    rows = []
    for e in ev:
        loc = e.get("location", None)
        x = loc[0] if isinstance(loc, list) and len(loc) >= 2 else np.nan
        y = loc[1] if isinstance(loc, list) and len(loc) >= 2 else np.nan

        # end locations (pass/carry/shot)
        endx = endy = np.nan
        pass_length = np.nan
        pass_subtype = None

        # extra pass fields (lightweight, useful later)
        pass_height = None
        pass_cross = False
        pass_body_part = None
        pass_outcome = None
        pass_recipient_id = np.nan
        pass_recipient_name = None

        # carry distance (computed)
        carry_length = np.nan

        # shot extras
        shot_endx = shot_endy = np.nan
        shot_outcome = None
        shot_xg = np.nan
        shot_body_part = None
        shot_type = None

        # duel subtype
        duel_type = None
        duel_outcome = None

        # generic outcome
        generic_outcome, success = _generic_outcome_for_event(e)

        # time handling
        ts = e.get("timestamp", None)
        t_in_period = _parse_timestamp_to_sec(ts)
        period = e.get("period", np.nan)
        t_abs = (
            t_in_period + _period_offset_seconds(period)
            if not (isinstance(t_in_period, float) and np.isnan(t_in_period))
            else np.nan
        )

        # pass/carry details
        if isinstance(e.get("pass"), dict):
            pe = e["pass"]
            end = pe.get("end_location", None)
            if isinstance(end, list) and len(end) >= 2:
                endx, endy = end[0], end[1]
            pass_length = pe.get("length", np.nan)
            pass_subtype = _safe_name(pe.get("type"))
            pass_height = _safe_name(pe.get("height"))
            pass_cross = bool(pe.get("cross", False))
            pass_body_part = _safe_name(pe.get("body_part"))

            out = pe.get("outcome")
            pass_outcome = _safe_name(out) if isinstance(out, dict) else None

            rec = pe.get("recipient")
            pass_recipient_id = _safe_id(rec)
            pass_recipient_name = _safe_name(rec)

        elif isinstance(e.get("carry"), dict):
            ce = e["carry"]
            end = ce.get("end_location", None)
            if isinstance(end, list) and len(end) >= 2:
                endx, endy = end[0], end[1]
            # compute carry length if we have both points
            if not (np.isnan(x) or np.isnan(y) or np.isnan(endx) or np.isnan(endy)):
                carry_length = float(np.hypot(endx - x, endy - y))

        # shot details
        if isinstance(e.get("shot"), dict):
            se = e["shot"]
            out = se.get("outcome")
            shot_outcome = _safe_name(out) if isinstance(out, dict) else None

            end = se.get("end_location", None)
            if isinstance(end, list) and len(end) >= 2:
                shot_endx, shot_endy = end[0], end[1]

            # StatsBomb xG field in open data
            shot_xg = se.get("statsbomb_xg", np.nan)

            shot_body_part = _safe_name(se.get("body_part"))
            shot_type = _safe_name(se.get("type"))

        # duel details
        if isinstance(e.get("duel"), dict):
            de = e["duel"]
            duel_type = _safe_name(de.get("type"))
            duel_outcome = _safe_name(de.get("outcome")) if isinstance(de.get("outcome"), dict) else None

        rows.append({
            "match_id": match_id,
            "competition_id": match_row["competition"]["competition_id"] if isinstance(match_row.get("competition"), dict) else match_row.get("competition_id"),
            "season_id": match_row["season"]["season_id"] if isinstance(match_row.get("season"), dict) else match_row.get("season_id"),
            "competition_name": match_row.get("competition", {}).get("competition_name", None) if isinstance(match_row.get("competition"), dict) else None,
            "season_name": match_row.get("season", {}).get("season_name", None) if isinstance(match_row.get("season"), dict) else None,

            # NEW: stable event keys
            "event_id": e.get("id", None),
            "event_index": e.get("index", np.nan),

            "type": _safe_name(e.get("type")),
            "play_pattern": _safe_name(e.get("play_pattern")),

            # NEW: player
            "player_id": _safe_id(e.get("player")),
            "player_name": _safe_name(e.get("player")),

            "team_id": _safe_id(e.get("team")),
            "team_name": _safe_name(e.get("team")),
            "possession": e.get("possession", np.nan),
            "possession_team_id": _safe_id(e.get("possession_team")),
            "possession_team_name": _safe_name(e.get("possession_team")),

            "minute": e.get("minute", np.nan),
            "second": e.get("second", np.nan),
            "timestamp": ts,
            "duration": e.get("duration", np.nan),
            "period": period,

            # NEW: convenient absolute time in seconds (for dt bins)
            "t_in_period_sec": t_in_period,
            "t_abs_sec": t_abs,

            # locations
            "x": x, "y": y,
            "endx": endx, "endy": endy,

            # pass
            "pass_length": pass_length,
            "pass_subtype": pass_subtype,
            "pass_height": pass_height,
            "pass_cross": pass_cross,
            "pass_body_part": pass_body_part,
            "pass_outcome": pass_outcome,
            "pass_recipient_id": pass_recipient_id,
            "pass_recipient_name": pass_recipient_name,

            # carry
            "carry_length": carry_length,

            # shot
            "shot_endx": shot_endx,
            "shot_endy": shot_endy,
            "shot_outcome": shot_outcome,
            "shot_xg": shot_xg,
            "shot_body_part": shot_body_part,
            "shot_type": shot_type,

            # duel
            "duel_type": duel_type,
            "duel_outcome": duel_outcome,

            # NEW: pressure flags (event-level)
            "under_pressure": _safe_bool(e, "under_pressure", False),
            "counterpress": _safe_bool(e, "counterpress", False),

            # NEW: generic outcome for CVAE v1
            "outcome": generic_outcome,
            "success": success,
        })

    return pd.DataFrame(rows)


In [4]:
import os


def load_competitions(sb_data_root: Path) -> pd.DataFrame:
    comp_path = sb_data_root / "competitions.json"
    comps = json.loads(comp_path.read_text(encoding="utf-8"))
    return pd.DataFrame(comps)

TARGET = [
    ("England", "Premier League"),
   # ("Spain", "La Liga"),
    #("Italy", "Serie A"),
    #("Germany", "1. Bundesliga"),
]

def pick_competitions_1516(comps):
    selected = []

    for country, comp in TARGET:
        sel = comps[
            (comps["country_name"] == country) &
            (comps["competition_name"] == comp) &
            (comps["season_name"] == "2015/2016")
        ]

        if sel.empty:
            raise ValueError(f"Missing: {country} {comp} 2015/2016")

        selected.append(sel.iloc[0])

    return pd.DataFrame(selected)
def load_matches(sb_data_root: Path, competition_id: int, season_id: int) -> pd.DataFrame:
    p = sb_data_root / "matches" / str(competition_id) / f"{season_id}.json"
    matches = json.loads(p.read_text(encoding="utf-8"))
    return pd.DataFrame(matches)
def load_all_events_1516(sb_data_root: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    comps = load_competitions(sb_data_root)
    picked = pick_competitions_1516(comps)

    all_matches = []
    for _, r in picked.iterrows():
        m = load_matches(sb_data_root, int(r["competition_id"]), int(r["season_id"]))
        # enrich for convenience
        m["competition_name"] = r["competition_name"]
        m["season_name"] = r["season_name"]
        all_matches.append(m)

    matches_df = pd.concat(all_matches, ignore_index=True)

    # Load events
    event_dfs = []
    for _, mr in matches_df.iterrows():
        event_dfs.append(flatten_events_for_match(sb_data_root, mr.to_dict()))

    events_df = pd.concat(event_dfs, ignore_index=True)

    # create a "league" label that matches your normalization bucket
    events_df["league_season"] = events_df["competition_name"].fillna("") + " | " + events_df["season_name"].fillna("")
    matches_df["league_season"] = matches_df["competition_name"].fillna("") + " | " + matches_df["season_name"].fillna("")

    return comps, matches_df, events_df
DATA_ROOT = Path(os.environ["EXJOBB_DATA"])
sb_root = DATA_ROOT / "open-data-master" / "data"
comps = pd.read_json(sb_root / "competitions.json")
comps, matches_df, events_df = load_all_events_1516(sb_root)


In [5]:
print(events_df.shape)
print(events_df.columns.tolist())
print(events_df.head())


(1313783, 49)
['match_id', 'competition_id', 'season_id', 'competition_name', 'season_name', 'event_id', 'event_index', 'type', 'play_pattern', 'player_id', 'player_name', 'team_id', 'team_name', 'possession', 'possession_team_id', 'possession_team_name', 'minute', 'second', 'timestamp', 'duration', 'period', 't_in_period_sec', 't_abs_sec', 'x', 'y', 'endx', 'endy', 'pass_length', 'pass_subtype', 'pass_height', 'pass_cross', 'pass_body_part', 'pass_outcome', 'pass_recipient_id', 'pass_recipient_name', 'carry_length', 'shot_endx', 'shot_endy', 'shot_outcome', 'shot_xg', 'shot_body_part', 'shot_type', 'duel_type', 'duel_outcome', 'under_pressure', 'counterpress', 'outcome', 'success', 'league_season']
   match_id  competition_id  season_id competition_name season_name  \
0   3754058               2         27   Premier League   2015/2016   
1   3754058               2         27   Premier League   2015/2016   
2   3754058               2         27   Premier League   2015/2016   
3   375

In [6]:
print(events_df.isna().mean().sort_values(ascending=False).head(20))


shot_endy              0.992458
shot_endx              0.992458
shot_type              0.992458
shot_body_part         0.992458
shot_xg                0.992458
shot_outcome           0.992458
duel_outcome           0.988244
duel_type              0.975422
pass_subtype           0.937896
pass_outcome           0.934358
carry_length           0.789197
pass_recipient_id      0.740958
pass_recipient_name    0.740958
pass_body_part         0.740791
pass_height            0.719422
pass_length            0.719422
success                0.688529
outcome                0.672208
endx                   0.508619
endy                   0.508619
dtype: float64


In [7]:
mid = events_df["match_id"].sample(1).iloc[0]
m = events_df[events_df.match_id == mid].copy()

m = m.sort_values(["period","t_in_period_sec","event_index"])

print("Non-monotonic within period:",
      (m.groupby("period")["t_in_period_sec"].diff().fillna(0) < 0).sum())

print("Non-monotonic absolute:",
      (m["t_abs_sec"].diff().fillna(0) < 0).sum())


Non-monotonic within period: 0
Non-monotonic absolute: 1


In [8]:
print("Unique possessions in match:", m["possession"].nunique())

# Check that possession doesn’t mix teams
pos_team_check = (
    m.groupby("possession")["possession_team_id"]
     .nunique()
     .max()
)
print("Max teams per possession:", pos_team_check)


Unique possessions in match: 216
Max teams per possession: 1


In [9]:
print(events_df["type"].value_counts().head(15))

print("Pass success rate:",
      events_df[events_df.type=="Pass"]["success"].mean())

print("Dribble success rate:",
      events_df[events_df.type=="Dribble"]["success"].mean())

print("Shot outcomes:")
print(events_df[events_df.type=="Shot"]["shot_outcome"].value_counts())


type
Pass              368619
Ball Receipt*     340324
Carry             276949
Pressure          115402
Ball Recovery      40943
Duel               32290
Clearance          21645
Block              14839
Dribble            13721
Goal Keeper        11777
Miscontrol         10786
Dispossessed       10520
Shot                9908
Foul Committed      9512
Foul Won            9112
Name: count, dtype: int64
Pass success rate: 0.7660457003030229
Dribble success rate: 0.5866919320749217
Shot outcomes:
shot_outcome
Off T               3197
Blocked             2880
Saved               2209
Goal                 988
Wayward              396
Post                 170
Saved Off Target      45
Saved to Post         23
Name: count, dtype: int64


In [10]:
events_df = events_df.sort_values(["match_id","period","t_in_period_sec","event_index"])

events_df["dt"] = (
    events_df.groupby("match_id")["t_abs_sec"]
    .diff()
)

print(events_df["dt"].describe())


count    1.313403e+06
mean     1.638085e+00
std      5.557648e+00
min     -5.555650e+02
25%      0.000000e+00
50%      7.040000e-01
75%      1.560000e+00
max      4.143830e+02
Name: dt, dtype: float64


In [None]:
df = events_df.sort_values(["match_id","minute","second","event_index"]).copy()

df["t_abs_fix"] = 60*df["minute"].fillna(0) + df["second"].fillna(0)
df["dt_fix"] = df.groupby("match_id")["t_abs_fix"].diff()

bad = df[df["dt_fix"] < 0]
print("Bad rows:", len(bad))

cols = ["match_id","event_index","type","period","minute","second","timestamp","t_abs_sec","t_abs_fix","dt_fix"]
print(bad[cols].head(20))
