### The idea here is

In [6]:
# --- CVAE v1 (fixed) ---
# Fixes:
# 1) Separate vocab sizes for start_zone vs end_zone (sz != ez)
# 2) term loss only on END timesteps
# 3) Encoder ignores padding via pack_padded_sequence
# 4) Correct import order + minor cleanup

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class PossessionDataset(Dataset):
    """
    Fixed-length tensor view of tokenized possession sequences.

    Each item is ``(x, cond, mask, length)``:
      - x:      LongTensor [T, 6], fields in the order type, sz, ez, out, dt, term,
                truncated to T and right-padded with id 0.
      - cond:   FloatTensor [C], row ``i`` of ``cond_vecs``.
      - mask:   FloatTensor [T], 1.0 where the type id is non-zero (non-PAD).
      - length: int, number of non-PAD type positions.
    """

    def __init__(self, seq_ids, cond_vecs, T: int):
        """
        seq_ids: list of dicts with keys 'type','sz','ez','out','dt','term'
                 (sequences of int ids, variable length).
        cond_vecs: np.ndarray or torch.FloatTensor [N, C], row-aligned with seq_ids.
        T: fixed output length. Longer sequences are truncated, which can drop a
           trailing END token, so T should be >= the longest sequence.

        Raises ValueError if cond_vecs and seq_ids do not have the same number of rows.
        """
        self.seqs = seq_ids
        self.cond = torch.as_tensor(cond_vecs, dtype=torch.float32)
        self.T = int(T)
        if len(self.cond) != len(self.seqs):
            # A mismatch would silently pair sequences with the wrong condition rows.
            raise ValueError(
                f"cond_vecs has {len(self.cond)} rows but seq_ids has {len(self.seqs)} sequences"
            )

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        s = self.seqs[i]

        def pad(x, pad_id=0):
            # list() so tuple / ndarray inputs can be concatenated with the padding.
            x = list(x[: self.T])
            x += [pad_id] * (self.T - len(x))
            return torch.tensor(x, dtype=torch.long)

        type_ids = pad(s["type"], 0)
        sz_ids   = pad(s["sz"],   0)
        ez_ids   = pad(s["ez"],   0)
        out_ids  = pad(s["out"],  0)
        dt_ids   = pad(s["dt"],   0)
        term_ids = pad(s["term"], 0)

        # mask: real tokens are non-pad in the type stream
        mask = (type_ids != 0).float()
        length = int(mask.sum().item())

        x = torch.stack([type_ids, sz_ids, ez_ids, out_ids, dt_ids, term_ids], dim=-1)  # [T, 6]
        return x, self.cond[i], mask, length


def masked_ce(logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Token-averaged cross-entropy over the positions where ``mask`` is 1.

    logits: [B, T, V]; target: [B, T] class ids; mask: [B, T] with values in {0, 1}.
    Returns the scalar sum(CE * mask) / (sum(mask) + 1e-8), which is 0 for an
    all-zero mask.
    """
    batch, steps, n_classes = logits.shape
    per_token = F.cross_entropy(
        logits.reshape(batch * steps, n_classes),
        target.reshape(batch * steps),
        reduction="none",
    ).reshape(batch, steps)
    weighted = per_token * mask
    denom = mask.sum() + 1e-8
    return weighted.sum() / denom


def kld(mu: torch.Tensor, logv: torch.Tensor) -> torch.Tensor:
    """
    KL(N(mu, exp(logv)) || N(0, I)) for a diagonal Gaussian posterior.

    The per-element terms are averaged over *every* element (batch and latent
    dimensions alike) rather than summed over latent dims, so the value is
    1/zdim of the usual per-sample ELBO KL term.
    """
    per_elem = 1 + logv - mu.pow(2) - logv.exp()
    return per_elem.mean().mul(-0.5)


class SeqCVAE(nn.Module):
    """
    Conditional sequence VAE over tokenized possession steps.

    Each timestep carries six categorical fields (type, sz, ez, out, dt, term).
    They are embedded separately and concatenated into 6 * emb features per step.

    - Encoder: a GRU over [step embeddings, c] whose padded steps are packed away.
      Its final hidden state is mapped to the posterior mean ``mu`` and the
      log-variance ``logv`` of a ``zdim``-dimensional latent.
    - Decoder: a GRU driven by teacher forcing. At step t it sees the previous
      step's tokens (step 0 sees an all-PAD step) concatenated with c and z.
      Six linear heads produce per-field logits.

    Id 0 is the padding id of every field (``padding_idx=0`` on all embeddings).
    The attribute names below become the state_dict keys, so renaming them breaks
    loading of saved checkpoints.
    """

    def __init__(
        self,
        n_types: int,
        n_sz: int,
        n_ez: int,
        n_out: int,
        n_dt: int,
        n_term: int,
        emb: int = 32,
        hidden: int = 256,
        zdim: int = 32,
        cdim: int = 12,
    ):
        # n_*: vocabulary size (including PAD=0) of each field.
        # cdim must equal the width of the conditioning vector c.
        super().__init__()
        self.cdim = cdim
        self.zdim = zdim

        # embeddings
        self.type_emb = nn.Embedding(n_types, emb, padding_idx=0)
        self.sz_emb   = nn.Embedding(n_sz,   emb, padding_idx=0)
        self.ez_emb   = nn.Embedding(n_ez,   emb, padding_idx=0)
        self.out_emb  = nn.Embedding(n_out,  emb, padding_idx=0)
        self.dt_emb   = nn.Embedding(n_dt,   emb, padding_idx=0)
        self.term_emb = nn.Embedding(n_term, emb, padding_idx=0)

        in_dim = emb * 6

        # encoder (packed)
        self.enc_rnn = nn.GRU(in_dim + cdim, hidden, batch_first=True)
        self.to_mu   = nn.Linear(hidden, zdim)
        self.to_logv = nn.Linear(hidden, zdim)

        # decoder (teacher forcing, un-packed is fine: padded outputs are masked in the loss)
        self.dec_rnn = nn.GRU(in_dim + cdim + zdim, hidden, batch_first=True)

        # heads: one classifier per field, applied at every decoder step
        self.h_type = nn.Linear(hidden, n_types)
        self.h_sz   = nn.Linear(hidden, n_sz)
        self.h_ez   = nn.Linear(hidden, n_ez)
        self.h_out  = nn.Linear(hidden, n_out)
        self.h_dt   = nn.Linear(hidden, n_dt)
        self.h_term = nn.Linear(hidden, n_term)

    def embed_step(self, x: torch.Tensor) -> torch.Tensor:
        """Embed each of the six id fields and concatenate: [B,T,6] -> [B,T,6*emb]."""
        # x: [B,T,6]
        t, sz, ez, out, dt, term = x.unbind(dim=-1)
        return torch.cat(
            [
                self.type_emb(t),
                self.sz_emb(sz),
                self.ez_emb(ez),
                self.out_emb(out),
                self.dt_emb(dt),
                self.term_emb(term),
            ],
            dim=-1,
        )

    def encode(self, x: torch.Tensor, c: torch.Tensor, lengths: torch.Tensor):
        """
        Posterior parameters q(z | x, c).

        x: [B,T,6] ids; c: [B,cdim] conditioning, repeated at every step.
        lengths: [B] int, number of non-pad timesteps (>=1)
        Returns (mu, logv), each [B, zdim].
        """
        B, T, _ = x.shape
        e = self.embed_step(x)  # [B,T,6*emb]
        c_rep = c.unsqueeze(1).expand(B, T, c.shape[-1])
        inp = torch.cat([e, c_rep], dim=-1)  # [B,T,in+cdim]

        # pack so PAD timesteps don't influence the encoder hidden state;
        # pack_padded_sequence needs the lengths on the CPU
        lengths_cpu = lengths.detach().to("cpu")
        packed = pack_padded_sequence(inp, lengths_cpu, batch_first=True, enforce_sorted=False)
        _, h = self.enc_rnn(packed)  # h: [1,B,H] (single-layer GRU), back in input order
        h = h.squeeze(0)

        mu = self.to_mu(h)
        logv = self.to_logv(h)
        return mu, logv

    def reparam(self, mu: torch.Tensor, logv: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick: z = mu + eps * exp(logv / 2), eps ~ N(0, I)."""
        std = torch.exp(0.5 * logv)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x_in: torch.Tensor, z: torch.Tensor, c: torch.Tensor):
        """
        Run the decoder on (already shifted) inputs x_in: [B,T,6].

        z: [B,zdim] and c: [B,cdim] are repeated at every step.
        Returns a dict of logits keyed by field: each [B, T, n_<field>].
        """
        B, T, _ = x_in.shape
        e = self.embed_step(x_in)
        c_rep = c.unsqueeze(1).expand(B, T, c.shape[-1])
        z_rep = z.unsqueeze(1).expand(B, T, z.shape[-1])
        inp = torch.cat([e, c_rep, z_rep], dim=-1)
        out, _ = self.dec_rnn(inp)
        return {
            "type": self.h_type(out),
            "sz":   self.h_sz(out),
            "ez":   self.h_ez(out),
            "out":  self.h_out(out),
            "dt":   self.h_dt(out),
            "term": self.h_term(out),
        }

    def forward(self, x: torch.Tensor, c: torch.Tensor, lengths: torch.Tensor):
        """
        Encode, sample z, and decode with teacher forcing.

        Returns (logits dict, mu, logv). z is sampled on every call, including in eval mode.
        """
        mu, logv = self.encode(x, c, lengths)
        z = self.reparam(mu, logv)

        # shift-right for teacher forcing: step t is predicted from step t-1
        x_in = x.clone()
        x_in[:, 1:] = x[:, :-1]
        x_in[:, 0] = 0  # PAD as "start"

        logits = self.decode(x_in, z, c)
        return logits, mu, logv


def compute_loss(
    logits: dict,
    x: torch.Tensor,
    mask: torch.Tensor,
    type_end_id: int,
    mu: torch.Tensor,
    logv: torch.Tensor,
    beta: float = 1.0,
):
    """
    CVAE training objective.

    total = CE(type, sz, ez, out, dt over all real steps)
          + CE(term, only at steps whose target type == END)
          + beta * KL(q(z|x,c) || N(0, I))

    logits: dict of [B,T,V] tensors keyed by field name; x: [B,T,6] target ids;
    mask: [B,T] real-step mask. Returns (total, parts), where parts holds detached
    "main", "term" and "kld" scalars.
    """
    type_t, sz_t, ez_t, out_t, dt_t, term_t = x.unbind(dim=-1)

    loss_main = (
        masked_ce(logits["type"], type_t, mask) +
        masked_ce(logits["sz"],   sz_t,   mask) +
        masked_ce(logits["ez"],   ez_t,   mask) +
        masked_ce(logits["out"],  out_t,  mask) +
        masked_ce(logits["dt"],   dt_t,   mask)
    )

    # Terminal-reason loss only where the target is the END token. No emptiness
    # guard is needed: masked_ce returns 0 for an all-zero mask (its denominator
    # carries +1e-8), and an `if end_mask.sum() > 0` check would force a
    # device->host synchronization on every batch.
    end_mask = (type_t == type_end_id).float() * mask
    loss_term = masked_ce(logits["term"], term_t, end_mask)

    loss_kld = kld(mu, logv)

    total = loss_main + loss_term + beta * loss_kld
    return total, {"main": loss_main.detach(), "term": loss_term.detach(), "kld": loss_kld.detach()}


# --- Example training loop skeleton (minimal) ---
from tqdm import tqdm

def train_one_epoch(model, loader, optimizer, device, type2id, beta):
    """
    Run one optimization pass over ``loader``.

    Each batch is (x, c, mask, lengths). Gradients are clipped to a global norm of
    1.0 before each optimizer step. Returns the mean per-batch total loss.
    """
    model.train()
    end_id = type2id["END"]
    loss_sum = 0.0

    progress = tqdm(loader, desc=f"train beta={beta:.2f}", leave=False)
    for batch in progress:
        x, c, mask, lengths = (t.to(device) for t in batch)

        logits, mu, logv = model(x, c, lengths)
        loss, parts = compute_loss(logits, x, mask, end_id, mu, logv, beta=beta)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        step_loss = loss.item()
        loss_sum += step_loss
        progress.set_postfix(loss=step_loss, main=float(parts["main"]), kld=float(parts["kld"]))

    return loss_sum / max(1, len(loader))



# --- Instantiate correctly ---
# model = SeqCVAE(
#     n_types=len(type2id),
#     n_sz=len(sz2id),
#     n_ez=len(ez2id),
#     n_out=len(out2id),
#     n_dt=len(dt2id),
#     n_term=len(term2id),
#     emb=32, hidden=256, zdim=32, cdim=12
# )


In [7]:
#Even larger loader
import json
from pathlib import Path
import pandas as pd
import numpy as np


def _safe_id(x):
    if isinstance(x, dict) and "id" in x:
        return x["id"]
    return np.nan

def _safe_name(x):
    if isinstance(x, dict) and "name" in x:
        return x["name"]
    return None

def _safe_bool(e: dict, key: str, default=False) -> bool:
    v = e.get(key, default)
    return bool(v) if v is not None else bool(default)


def _generic_outcome_for_event(e: dict):
    """
    Returns (outcome_name, success_bool_or_nan) using StatsBomb conventions.
    - Many nested outcome fields exist; missing often implies 'Complete/Success'.
    - For event types without a notion of outcome, returns (None, np.nan).
    """
    t = _safe_name(e.get("type"))

    # PASS
    if t == "Pass" and isinstance(e.get("pass"), dict):
        out = e["pass"].get("outcome")
        if isinstance(out, dict):
            return out.get("name"), False
        # missing outcome => completed
        return "Complete", True

    # SHOT
    if t == "Shot" and isinstance(e.get("shot"), dict):
        out = e["shot"].get("outcome")
        if isinstance(out, dict):
            # success is ambiguous (goal vs on target etc). keep np.nan for boolean.
            return out.get("name"), np.nan
        return None, np.nan

    # DRIBBLE
    if t == "Dribble" and isinstance(e.get("dribble"), dict):
        out = e["dribble"].get("outcome")
        if isinstance(out, dict):
            name = out.get("name")
            # StatsBomb uses "Complete"/"Incomplete"
            if name is not None:
                return name, (name.lower() == "complete")
        return None, np.nan

    # DUEL
    if t == "Duel" and isinstance(e.get("duel"), dict):
        out = e["duel"].get("outcome")
        if isinstance(out, dict):
            name = out.get("name")
            # Often "Won"/"Lost"/"Success In Play"/etc
            if name is not None:
                low = name.lower()
                if low in ("won", "success", "success in play", "success out"):
                    return name, True
                if low in ("lost", "failure"):
                    return name, False
            return name, np.nan
        return None, np.nan

    # INTERCEPTION
    if t == "Interception" and isinstance(e.get("interception"), dict):
        out = e["interception"].get("outcome")
        if isinstance(out, dict):
            name = out.get("name")
            if name is not None:
                low = name.lower()
                if low in ("won", "success"):
                    return name, True
                if low in ("lost", "failure"):
                    return name, False
            return name, np.nan
        # If no outcome, treat as success-ish
        return "Won", True

    # BALL RECOVERY
    if t == "Ball Recovery" and isinstance(e.get("ball_recovery"), dict):
        fail = e["ball_recovery"].get("recovery_failure")
        if fail is True:
            return "Failure", False
        if fail is False:
            return "Success", True
        return None, np.nan

    # MISCONTROL (always “bad touch” in spirit)
    if t == "Miscontrol":
        return "Miscontrol", False

    # CLEARANCE (no explicit outcome)
    if t == "Clearance":
        return None, np.nan

    # PRESSURE (no explicit success)
    if t == "Pressure":
        return None, np.nan

    # FOUL COMMITTED / WON are separate event types (no outcome field)
    if t in ("Foul Committed", "Foul Won"):
        return None, np.nan

    # DEFAULT
    return None, np.nan

def flatten_events_for_match(sb_data_root: Path, match_row: dict) -> pd.DataFrame:
    """
    Flatten one match's StatsBomb event file into a DataFrame with one row per event.

    Reads ``sb_data_root / "events" / f"{match_id}.json"``. For every event it
    extracts ids, team/player/possession info, timing, start/end locations,
    pass/carry/shot/duel details, pressure flags, and the generic outcome from
    ``_generic_outcome_for_event``.

    Parameters
    ----------
    sb_data_root : Path
        StatsBomb open-data ``data`` directory (contains ``events/``).
    match_row : dict
        One match record, which must contain ``match_id``. Competition/season ids
        and names are read from nested ``competition`` / ``season`` dicts when
        present. Otherwise the ids fall back to flat ``competition_id`` /
        ``season_id`` keys and the names are None.

    Returns
    -------
    pd.DataFrame
        One row per event, in file order. Missing numeric values are np.nan and
        missing categorical values are None.

    Raises
    ------
    FileNotFoundError
        If the event file does not exist.
    """
    match_id = match_row["match_id"]
    p = sb_data_root / "events" / f"{match_id}.json"
    ev = json.loads(p.read_text(encoding="utf-8"))

    # Match-level metadata is identical for every event, so resolve it once here
    # instead of re-evaluating the same lookups inside the per-event loop.
    comp = match_row.get("competition")
    season = match_row.get("season")
    if isinstance(comp, dict):
        competition_id = comp["competition_id"]
        competition_name = comp.get("competition_name", None)
    else:
        competition_id = match_row.get("competition_id")
        competition_name = None
    if isinstance(season, dict):
        season_id = season["season_id"]
        season_name = season.get("season_name", None)
    else:
        season_id = match_row.get("season_id")
        season_name = None

    rows = []
    for e in ev:
        loc = e.get("location", None)
        x = loc[0] if isinstance(loc, list) and len(loc) >= 2 else np.nan
        y = loc[1] if isinstance(loc, list) and len(loc) >= 2 else np.nan

        # end locations (pass/carry)
        endx = endy = np.nan
        pass_length = np.nan
        pass_subtype = None

        # extra pass fields (lightweight, useful later)
        pass_height = None
        pass_cross = False
        pass_body_part = None
        pass_outcome = None
        pass_recipient_id = np.nan
        pass_recipient_name = None

        # carry distance (computed)
        carry_length = np.nan

        # shot extras
        shot_endx = shot_endy = np.nan
        shot_outcome = None
        shot_xg = np.nan
        shot_body_part = None
        shot_type = None

        # duel subtype
        duel_type = None
        duel_outcome = None

        # generic outcome
        generic_outcome, success = _generic_outcome_for_event(e)

        # time handling (robust: use minute/second)
        ts = e.get("timestamp", None)
        period = e.get("period", np.nan)

        minute = e.get("minute", np.nan)
        second = e.get("second", np.nan)

        if pd.notna(minute) and pd.notna(second):
            t_abs = float(minute) * 60.0 + float(second)
        else:
            t_abs = np.nan

        # Kept under this name for compatibility; it is the match-clock seconds
        # (same as t_abs), not the time elapsed within the current period.
        t_in_period = t_abs

        # pass/carry details
        if isinstance(e.get("pass"), dict):
            pe = e["pass"]
            end = pe.get("end_location", None)
            if isinstance(end, list) and len(end) >= 2:
                endx, endy = end[0], end[1]
            pass_length = pe.get("length", np.nan)
            pass_subtype = _safe_name(pe.get("type"))
            pass_height = _safe_name(pe.get("height"))
            pass_cross = bool(pe.get("cross", False))
            pass_body_part = _safe_name(pe.get("body_part"))

            out = pe.get("outcome")
            pass_outcome = _safe_name(out) if isinstance(out, dict) else None

            rec = pe.get("recipient")
            pass_recipient_id = _safe_id(rec)
            pass_recipient_name = _safe_name(rec)

        elif isinstance(e.get("carry"), dict):
            ce = e["carry"]
            end = ce.get("end_location", None)
            if isinstance(end, list) and len(end) >= 2:
                endx, endy = end[0], end[1]
            # compute carry length if we have both points
            if not (np.isnan(x) or np.isnan(y) or np.isnan(endx) or np.isnan(endy)):
                carry_length = float(np.hypot(endx - x, endy - y))

        # shot details
        if isinstance(e.get("shot"), dict):
            se = e["shot"]
            out = se.get("outcome")
            shot_outcome = _safe_name(out) if isinstance(out, dict) else None

            end = se.get("end_location", None)
            if isinstance(end, list) and len(end) >= 2:
                shot_endx, shot_endy = end[0], end[1]

            # StatsBomb xG field in open data
            shot_xg = se.get("statsbomb_xg", np.nan)

            shot_body_part = _safe_name(se.get("body_part"))
            shot_type = _safe_name(se.get("type"))

        # duel details
        if isinstance(e.get("duel"), dict):
            de = e["duel"]
            duel_type = _safe_name(de.get("type"))
            duel_outcome = _safe_name(de.get("outcome")) if isinstance(de.get("outcome"), dict) else None

        rows.append({
            "match_id": match_id,
            "competition_id": competition_id,
            "season_id": season_id,
            "competition_name": competition_name,
            "season_name": season_name,

            # stable event keys
            "event_id": e.get("id", None),
            "event_index": e.get("index", np.nan),

            "type": _safe_name(e.get("type")),
            "play_pattern": _safe_name(e.get("play_pattern")),

            # player
            "player_id": _safe_id(e.get("player")),
            "player_name": _safe_name(e.get("player")),

            "team_id": _safe_id(e.get("team")),
            "team_name": _safe_name(e.get("team")),
            "possession": e.get("possession", np.nan),
            "possession_team_id": _safe_id(e.get("possession_team")),
            "possession_team_name": _safe_name(e.get("possession_team")),

            "minute": minute,
            "second": second,
            "timestamp": ts,
            "duration": e.get("duration", np.nan),
            "period": period,

            # match-clock seconds (for dt bins); both names hold the same value
            "t_in_period_sec": t_in_period,
            "t_abs_sec": t_abs,

            # locations
            "x": x, "y": y,
            "endx": endx, "endy": endy,

            # pass
            "pass_length": pass_length,
            "pass_subtype": pass_subtype,
            "pass_height": pass_height,
            "pass_cross": pass_cross,
            "pass_body_part": pass_body_part,
            "pass_outcome": pass_outcome,
            "pass_recipient_id": pass_recipient_id,
            "pass_recipient_name": pass_recipient_name,

            # carry
            "carry_length": carry_length,

            # shot
            "shot_endx": shot_endx,
            "shot_endy": shot_endy,
            "shot_outcome": shot_outcome,
            "shot_xg": shot_xg,
            "shot_body_part": shot_body_part,
            "shot_type": shot_type,

            # duel
            "duel_type": duel_type,
            "duel_outcome": duel_outcome,

            # pressure flags (event-level)
            "under_pressure": _safe_bool(e, "under_pressure", False),
            "counterpress": _safe_bool(e, "counterpress", False),

            # generic outcome for CVAE v1
            "outcome": generic_outcome,
            "success": success,
        })

    return pd.DataFrame(rows)


In [8]:
import os


def load_competitions(sb_data_root: Path) -> pd.DataFrame:
    """Load ``competitions.json`` from the StatsBomb data root as a DataFrame."""
    with (sb_data_root / "competitions.json").open(encoding="utf-8") as fh:
        records = json.load(fh)
    return pd.DataFrame(records)

# (country_name, competition_name) pairs to load for the 2015/2016 season.
# Uncomment further leagues to widen the dataset.
TARGET = [
    ("England", "Premier League"),
    # ("Spain", "La Liga"),
    # ("Italy", "Serie A"),
    # ("Germany", "1. Bundesliga"),
]

def pick_competitions_1516(comps, targets=None, season_name="2015/2016"):
    """
    Select one competitions row per (country, competition) for a given season.

    Parameters
    ----------
    comps : pd.DataFrame
        Competitions table with country_name / competition_name / season_name columns.
    targets : iterable of (country_name, competition_name), optional
        Defaults to the module-level ``TARGET`` list.
    season_name : str
        Season label to match (default "2015/2016").

    Returns
    -------
    pd.DataFrame
        The first matching row for each target, in target order.

    Raises
    ------
    ValueError
        If a target has no row for ``season_name``.
    """
    if targets is None:
        targets = TARGET

    selected = []
    for country, comp in targets:
        sel = comps[
            (comps["country_name"] == country) &
            (comps["competition_name"] == comp) &
            (comps["season_name"] == season_name)
        ]

        if sel.empty:
            raise ValueError(f"Missing: {country} {comp} {season_name}")

        selected.append(sel.iloc[0])

    return pd.DataFrame(selected)
def load_matches(sb_data_root: Path, competition_id: int, season_id: int) -> pd.DataFrame:
    """Load ``matches/<competition_id>/<season_id>.json`` as a DataFrame (one row per match)."""
    match_file = sb_data_root.joinpath("matches", str(competition_id), f"{season_id}.json")
    with match_file.open(encoding="utf-8") as fh:
        return pd.DataFrame(json.load(fh))
def load_all_events_1516(sb_data_root: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load competitions, matches and flattened events for the selected 2015/2016 leagues.

    Returns ``(comps, matches_df, events_df)``. Both matches_df and events_df get a
    "league_season" column of the form "<competition_name> | <season_name>".
    """
    comps = load_competitions(sb_data_root)
    picked = pick_competitions_1516(comps)

    match_frames = []
    for _, comp_row in picked.iterrows():
        frame = load_matches(sb_data_root, int(comp_row["competition_id"]), int(comp_row["season_id"]))
        # enrich for convenience
        frame["competition_name"] = comp_row["competition_name"]
        frame["season_name"] = comp_row["season_name"]
        match_frames.append(frame)
    matches_df = pd.concat(match_frames, ignore_index=True)

    events_df = pd.concat(
        [flatten_events_for_match(sb_data_root, match.to_dict()) for _, match in matches_df.iterrows()],
        ignore_index=True,
    )

    def _league_season(frame):
        # label that matches the normalization bucket
        return frame["competition_name"].fillna("") + " | " + frame["season_name"].fillna("")

    events_df["league_season"] = _league_season(events_df)
    matches_df["league_season"] = _league_season(matches_df)

    return comps, matches_df, events_df
# Root of the StatsBomb open-data checkout, taken from the EXJOBB_DATA environment variable.
DATA_ROOT = Path(os.environ["EXJOBB_DATA"])
sb_root = DATA_ROOT / "open-data-master" / "data"
# load_all_events_1516 also loads and returns the competitions table.
comps, matches_df, events_df = load_all_events_1516(sb_root)


In [None]:
import functions
# Event-type sets for the v1 model.
# NOTE(review): only KEEP_TYPES is applied below. NONPLAY, DROP_TYPES and KEEP are
# not used anywhere in the visible cells. KEEP also lists fouls, Ball Recovery and
# Interception, which terminal_reason() has labels for but KEEP_TYPES filters out.
NONPLAY = {
    "Starting XI",
    "Half Start",
    "Half End",
    "Tactical Shift",
    "Injury Stoppage",
    "Referee Ball-Drop",
}

df = events_df.copy()

DROP_TYPES = {
    "Starting XI",
    "Half Start",
    "Half End",
    "Ball Receipt*",
    "Substitution",
    "Player On",
    "Player Off",
    "Bad Behaviour",
    "Own Goal For",
    "Own Goal Against",
    "Injury Stoppage",
    "Tactical Shift",
    "Referee Ball-Drop",
    "Pressure",  # v1 choice
}

# Event types that become sequence steps.
KEEP_TYPES = {"Pass", "Carry", "Dribble", "Shot", "Miscontrol", "Dispossessed"}
KEEP = KEEP_TYPES | {"Foul Won", "Foul Committed", "Ball Recovery", "Interception"}

df = df.loc[df["type"].isin(KEEP_TYPES)].copy()





# Chronological order within each match. `period` has to precede the clock: in
# StatsBomb data the second half starts again at minute 45, so first-half stoppage
# time (period 1, minute 45+) would otherwise interleave with the start of the
# second half (period 2, minute 45+) and corrupt both the order and `dt`.
df = df.sort_values(["match_id", "period", "minute", "second", "event_index"]).reset_index(drop=True)

# Seconds since the previous kept event in the same match. The first event gets 0,
# and negative jumps (e.g. the clock reset at half time) are clipped to 0.
df["dt"] = df.groupby("match_id")["t_abs_sec"].diff().fillna(0).clip(lower=0)

# bin it (tune later); right-closed bins: (-1,1], (1,3], (3,7], (7,15], (15,1e9]
bins = [-1, 1, 3, 7, 15, 1e9]
labels = ["0-1", "1-3", "3-7", "7-15", "15+"]
df["dt_bin"] = pd.cut(df["dt"], bins=bins, labels=labels)
df["dt_bin"] = df["dt_bin"].astype(str)  # string tokens for vocab building later

# Zone tokens for each event's start and end location. Zipping the columns avoids
# building one row Series per event (DataFrame.apply(axis=1)), which is much faster
# on a full league season and also works on an empty frame.
df["start_zone"] = [functions.get_zone(x, y) for x, y in zip(df["x"], df["y"])]

# Events without an end location (anything but pass/carry) get "NA_END".
df["end_zone"] = [
    functions.get_zone(ex, ey) if pd.notna(ex) and pd.notna(ey) else "NA_END"
    for ex, ey in zip(df["endx"], df["endy"])
]

def v1_outcome(row):
    """
    Outcome token for one event row (v1 vocabulary).

    - Pass / Dribble: "Complete" / "Incomplete" from the boolean ``success`` column,
      or "NA" when success is missing.
    - Shot: the StatsBomb shot outcome name, or "NA" when it is missing. Both None
      and NaN count as missing: pandas may store missing object values as NaN, and
      a float NaN token mixed with strings cannot be sorted when building the vocab.
    - Every other type: "NA".

    ``row`` may be a pandas Series or a plain dict.
    """
    t = row["type"]
    if t in ("Pass", "Dribble"):
        if pd.isna(row["success"]):
            return "NA"
        return "Complete" if bool(row["success"]) else "Incomplete"
    if t == "Shot":
        shot_outcome = row["shot_outcome"]
        return "NA" if pd.isna(shot_outcome) else shot_outcome
    return "NA"


# Per-event outcome token (see v1_outcome), evaluated row by row.
df["outcome_v1"] = df.apply(func=v1_outcome, axis="columns")

def terminal_reason(last_row):
    """
    Classify how a possession ended, based on its last kept event.

    Returns one of:
      - "shot":     the last event is a Shot.
      - "foul":     Foul Committed / Foul Won.
      - "turnover": a Pass/Carry/Dribble whose ``outcome_v1`` is
                    Incomplete/Failure/Lost, or a Dispossessed, Miscontrol,
                    Clearance, Interception, Ball Recovery or Duel event.
      - "other":    anything else (e.g. a completed pass as the last kept event).

    ``last_row`` may be a pandas Series or a plain dict (only ``[]`` and ``.get`` are used).

    NOTE(review): with the KEEP_TYPES filter used in this notebook, fouls,
    clearances, interceptions, ball recoveries and duels never reach this function,
    so the "foul" branch and those turnover types cannot occur today.
    """
    t = last_row["type"]
    if t == "Shot":
        return "shot"
    if t in ("Foul Committed", "Foul Won"):
        return "foul"
    # "Failure" must be one literal; it was previously split across two lines,
    # which is an unterminated string (SyntaxError).
    if t in ("Pass", "Carry", "Dribble") and last_row.get("outcome_v1") in ("Incomplete", "Failure", "Lost"):
        return "turnover"
    if t in ("Dispossessed", "Miscontrol", "Clearance", "Interception", "Ball Recovery", "Duel"):
        return "turnover"
    return "other"

def build_possession_sequences(df, max_T=40):
    """
    Turn the event table into one token sequence per (match, possession, team).

    Groups are visited in first-appearance order. Within each group, events are
    sorted by minute/second/event_index and restricted to those made by the
    possessing team (``team_id == possession_team_id``). Each event becomes a step
    dict with keys type/sz/ez/out/dt/term; ``term`` is "NA_TERM" for ordinary
    steps. Then an END step is appended: its ``term`` is ``terminal_reason`` of the
    last kept event and its other fields are "PAD".

    Sequences longer than ``max_T`` keep their first ``max_T - 1`` steps plus END.

    Returns
    -------
    (sequences, meta)
        sequences: list of lists of step dicts.
        meta: DataFrame with match_id / possession / possession_team_id, row-aligned
        with ``sequences`` (used later to join conditioning vectors).

    Raises
    ------
    ValueError
        If ``max_T < 1`` (there would be no room for the END step).
    """
    if max_T < 1:
        raise ValueError(f"max_T must be >= 1 to keep the END step, got {max_T}")

    sequences = []
    meta = []  # store keys for matching condition vectors later

    grp_cols = ["match_id", "possession", "possession_team_id"]
    for (mid, poss, ptid), g in df.groupby(grp_cols, sort=False):
        g = g.sort_values(["minute", "second", "event_index"])
        g = g[g["team_id"] == g["possession_team_id"]]
        if g.empty:
            continue

        # Zipping the columns yields the same token values as iterrows() without
        # building a Series per event.
        steps = [
            {"type": typ, "sz": sz, "ez": ez, "out": out, "dt": dt, "term": "NA_TERM"}
            for typ, sz, ez, out, dt in zip(
                g["type"], g["start_zone"], g["end_zone"], g["outcome_v1"], g["dt_bin"]
            )
        ]

        # END token carries the terminal reason; every group has at least one step
        # here, so sequences are never empty.
        steps.append({
            "type": "END",
            "sz": "PAD",
            "ez": "PAD",
            "out": "PAD",
            "dt": "PAD",
            "term": terminal_reason(g.iloc[-1]),
        })

        if len(steps) > max_T:
            steps = steps[:max_T - 1] + [steps[-1]]  # keep END

        sequences.append(steps)
        meta.append({"match_id": mid, "possession": poss, "possession_team_id": ptid})

    return sequences, pd.DataFrame(meta)

sequences, seq_meta = build_possession_sequences(df, max_T=40)

# Quick sanity check on the first possession.
print(f"Num sequences: {len(sequences)}")
print(f"Example sequence length: {len(sequences[0])}")
print(sequences[0][:5])
print(f"END step: {sequences[0][-1]}")


Num sequences: 70683
Example sequence length: 18
[{'type': 'Pass', 'sz': 'Center_Dead_Att', 'ez': 'Center_Dead_Att', 'out': 'Complete', 'dt': '0-1', 'term': 'NA_TERM'}, {'type': 'Pass', 'sz': 'Center_Dead_Att', 'ez': 'Center_Dead_Def', 'out': 'Complete', 'dt': '0-1', 'term': 'NA_TERM'}, {'type': 'Carry', 'sz': 'Center_Dead_Def', 'ez': 'Center_Dead_Def', 'out': 'NA', 'dt': '0-1', 'term': 'NA_TERM'}, {'type': 'Pass', 'sz': 'Center_Dead_Def', 'ez': 'Def_Pocket_Right', 'out': 'Complete', 'dt': '0-1', 'term': 'NA_TERM'}, {'type': 'Carry', 'sz': 'Def_Pocket_Right', 'ez': 'Def_Pocket_Right', 'out': 'NA', 'dt': '1-3', 'term': 'NA_TERM'}]
END step: {'type': 'END', 'sz': 'PAD', 'ez': 'PAD', 'out': 'PAD', 'dt': 'PAD', 'term': 'turnover'}


In [10]:
from collections import Counter

def build_vocab(values, add_pad=True):
    """
    Map each distinct value to a consecutive integer id, in sorted order.

    With ``add_pad`` the token "PAD" is always id 0 and the remaining values
    start at 1. Without it, ids start at 0 and "PAD" (if present) is an ordinary value.
    """
    tokens = sorted(set(values))
    if not add_pad:
        return {tok: i for i, tok in enumerate(tokens)}
    rest = [tok for tok in tokens if tok != "PAD"]
    return {"PAD": 0, **{tok: i for i, tok in enumerate(rest, start=1)}}

# Collect every step's token per field (in sequence order), then build the vocabs.
all_types, all_sz, all_ez, all_out, all_dt, all_term = (
    [s[field] for seq in sequences for s in seq]
    for field in ("type", "sz", "ez", "out", "dt", "term")
)

type2id = build_vocab(all_types)
sz2id   = build_vocab(all_sz)
ez2id   = build_vocab(all_ez)     # includes NA_END
out2id  = build_vocab(all_out)
dt2id   = build_vocab(all_dt)
term2id = build_vocab(all_term)   # includes NA_TERM + shot/turnover/etc

print("Vocab sizes:", *map(len, (type2id, sz2id, ez2id, out2id, dt2id, term2id)))

print("Example ids:", type2id["Pass"], sz2id["Center_Dead_Att"], ez2id["NA_END"])


Vocab sizes: 8 27 28 12 6 5
Example ids: 6 4 9


In [11]:
def seq_to_ids(seq):
    """
    Convert one sequence of step dicts into per-field integer id lists.

    Uses the module-level vocabularies. Returns a dict with keys
    type/sz/ez/out/dt/term. Raises KeyError on a token missing from its vocab.
    """
    field_vocab = (
        ("type", type2id),
        ("sz", sz2id),
        ("ez", ez2id),
        ("out", out2id),
        ("dt", dt2id),
        ("term", term2id),
    )
    return {field: [vocab[step[field]] for step in seq] for field, vocab in field_vocab}

seq_ids = list(map(seq_to_ids, sequences))
print(seq_ids[0]["type"][:10])
@torch.no_grad()
def eval_one_epoch(model, loader, device, type2id, beta=1.0):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns {"main", "term", "kld", "total"}: the per-batch loss components averaged
    over the epoch and weighted by batch size. This keeps a smaller final batch
    (drop_last=False) from counting as much as a full one. The latent is still
    sampled inside ``model.forward``, so the numbers are stochastic.
    """
    model.eval()
    end_id = type2id["END"]
    sums = {"main": 0.0, "term": 0.0, "kld": 0.0, "total": 0.0}
    n_seen = 0

    for x, c, mask, lengths in loader:
        x = x.to(device); c = c.to(device)
        mask = mask.to(device); lengths = lengths.to(device)

        logits, mu, logv = model(x, c, lengths)
        loss, parts = compute_loss(logits, x, mask, end_id, mu, logv, beta=beta)

        bs = x.shape[0]
        sums["total"] += loss.item() * bs
        for k in ("main", "term", "kld"):
            sums[k] += float(parts[k]) * bs
        n_seen += bs

    denom = max(1, n_seen)
    out = {k: sums[k] / denom for k in ("main", "term", "kld")}
    out["total"] = sums["total"] / denom
    return out


[6, 6, 1, 6, 1, 6, 1, 6, 1, 6]


In [12]:
from prem1516 import build_team_match_features_1516
# Per-team, per-match style features from the project helper in prem1516.
# The cells below use its match_id and team_id columns plus the feat_cols list
# (mean_width, directness, tempo, press_intensity, press_height_mean_x,
# mean_pass_length).
# NOTE(review): the SettingWithCopyWarnings in this cell's output appear to come
# from inside that helper; confirm it works on copies.
team_match = build_team_match_features_1516(sb_root) 


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = df[c].astype("Int64")
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["pass_width"] = abs(df["endy"] - 40)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = df[c].astype("Int64")
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer

In [13]:
# Per-team match features -> own/opponent feature pairs keyed by (match_id, team_id).
tm = team_match.copy()

# Opponent lookup: self-join the team rows on match_id and keep the pairs with
# different teams (each match should contain exactly 2 teams).
opp = tm[["match_id", "team_id"]].merge(
    tm[["match_id", "team_id"]],
    on="match_id",
    suffixes=("_own", "_opp"),
)
opp = opp.loc[opp["team_id_own"] != opp["team_id_opp"]].drop_duplicates()

# Attach the opponent's team id to every team row.
tm2 = tm.merge(
    opp.rename(columns={"team_id_own": "team_id", "team_id_opp": "opp_team_id"})[
        ["match_id", "team_id", "opp_team_id"]
    ],
    how="left",
    on=["match_id", "team_id"],
)

feat_cols = [
    "mean_width",
    "directness",
    "tempo",
    "press_intensity",
    "press_height_mean_x",
    "mean_pass_length",
]

assert tm2["opp_team_id"].notna().all(), "Some teams have no opponent mapped (check matches with != 2 teams)."

# Own side: features prefixed with own_.
own = tm2[["match_id", "team_id", "opp_team_id", *feat_cols]].copy()
own.columns = ["match_id", "team_id", "opp_team_id", *(f"own_{col}" for col in feat_cols)]

# Same rows seen from the other side: team_id plays the role of opp_team_id.
opp_feats = tm2[["match_id", "team_id", *feat_cols]].copy()
opp_feats.columns = ["match_id", "opp_team_id", *(f"opp_{col}" for col in feat_cols)]

pair = own.merge(opp_feats, how="left", on=["match_id", "opp_team_id"])

# sanity: every team row must have found its opponent's features
assert pair[[f"opp_{col}" for col in feat_cols]].notna().all().all(), "Missing opponent feature rows."

# Attach each possession's own/opponent match features (the conditioning vector c).
seq_meta2 = seq_meta.rename(columns={"possession_team_id": "team_id"}).copy()

# validate="many_to_one": each (match_id, team_id) may appear at most once in
# `pair`. A duplicate would silently add rows and misalign C with `sequences`.
# A left merge keeps the row order of seq_meta2, i.e. the order of `sequences`.
seq_with_c = seq_meta2.merge(
    pair,
    on=["match_id", "team_id"],
    how="left",
    validate="many_to_one",
)

cond_cols = [f"own_{c}" for c in feat_cols] + [f"opp_{c}" for c in feat_cols]

missing = seq_with_c[cond_cols].isna().any(axis=1).sum()
print("Sequences missing conditioning:", missing)

C = seq_with_c[cond_cols].astype(float).to_numpy()

# Column-wise standardization (NaN-aware).
# NOTE(review): the statistics are computed over all sequences, including the
# val/test split made later, and each team-match is weighted by its number of
# possessions.
mu = np.nanmean(C, axis=0)
sd = np.nanstd(C, axis=0)
sd[sd == 0] = 1.0

# Missing conditioning values are imputed with the column mean (0 after
# standardization), so no NaN can reach the model.
Cz = np.nan_to_num((C - mu) / sd, nan=0.0)

print("Cond shape:", Cz.shape)  # should be [num_sequences, 12]


Sequences missing conditioning: 0
Cond shape: (70683, 12)


In [14]:
# Distribution of steps per possession (END token included).
lengths = list(map(len, sequences))

print("Min length:", min(lengths))
print("Max length:", max(lengths))
print("Mean length:", np.mean(lengths))
print("Median length:", np.median(lengths))

# upper percentiles, to choose the padded length T
for p in (75, 90, 95, 99):
    print(f"{p}th percentile:", np.percentile(lengths, p))


Min length: 2
Max length: 40
Mean length: 9.894684719097944
Median length: 7.0
75th percentile: 13.0
90th percentile: 22.0
95th percentile: 29.0
99th percentile: 40.0


In [15]:
from torch.utils.data import DataLoader

# Padded sequence length; must be >= the max_T used in build_possession_sequences,
# otherwise the trailing END token is truncated away.
T = 40
from torch.utils.data import DataLoader, Subset

# Seeded random 80/10/10 split at the sequence level.
# NOTE(review): possessions from the same match can land in different splits, so
# val/test are not match-disjoint; consider splitting by match_id for a stricter
# evaluation.
N = len(seq_ids)
idx = np.arange(N)
rng = np.random.default_rng(0)
rng.shuffle(idx)

n_train = int(0.80 * N)
n_val   = int(0.10 * N)

train_idx = idx[:n_train]
val_idx   = idx[n_train:n_train + n_val]
test_idx  = idx[n_train + n_val:]

# Use the shared T rather than a second hard-coded 40 so the two cannot drift apart.
ds_all = PossessionDataset(seq_ids, Cz, T=T)

train_ds = Subset(ds_all, train_idx)
val_ds   = Subset(ds_all, val_idx)
test_ds  = Subset(ds_all, test_idx)

train_dl = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=0, drop_last=True)
val_dl   = DataLoader(val_ds,   batch_size=128, shuffle=False, num_workers=0, drop_last=False)
test_dl  = DataLoader(test_ds,  batch_size=128, shuffle=False, num_workers=0, drop_last=False)


# Prefer Apple MPS, then CUDA, then CPU.
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Vocab sizes come from the vocabularies built above. cdim must equal Cz.shape[1]
# (2 * len(feat_cols) = 12).
model = SeqCVAE(
    n_types=len(type2id),
    n_sz=len(sz2id),
    n_ez=len(ez2id),
    n_out=len(out2id),
    n_dt=len(dt2id),
    n_term=len(term2id),
    emb=32, hidden=256, zdim=32, cdim=12
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

# KL anneal: ramp beta 0->1 over first 5 epochs (0.0, 0.2, ..., 0.8, then 1.0)
# Validation uses the same beta, so val "total" is not comparable across warm-up epochs;
# compare "main"/"kld" instead.
# NOTE(review): nothing is saved until after the loop, so a crash mid-training (the
# captured output ends with a disk-full error in epoch 5) loses all progress.
for epoch in range(20):
    beta = min(1.0, epoch / 5.0)
    loss = train_one_epoch(model, train_dl, opt, device, type2id, beta=beta)
    print(epoch, "beta", beta, "loss", loss)
    val_stats = eval_one_epoch(model, val_dl, device, type2id, beta=beta)
    print(val_stats)


import torch

checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
    "type2id": type2id,
    "sz2id": sz2id,
    "ez2id": ez2id,
    "out2id": out2id,
    "dt2id": dt2id,
    "term2id": term2id,
    "config": {
        "emb": 32,
        "hidden": 256,
        "zdim": 32,
        "cdim": 12,
        "T": 40
    }
}

torch.save(checkpoint, "cvae_v1_prem1516.pt")

print("Model saved.")



device: mps


                                                                                                    

0 beta 0.0 loss 5.730668166327098
{'main': 4.4710031151771545, 'term': 0.023426896410195956, 'kld': 8.283684798649379, 'total': 4.494429997035435}


                                                                                                   

1 beta 0.2 loss 4.440718838957702
{'main': 3.8064785386834825, 'term': 0.014739617904914277, 'kld': 1.0172089274440492, 'total': 4.0246599572045465}


                                                                                                   

2 beta 0.4 loss 4.028829608104126
{'main': 3.5742353115762984, 'term': 0.010338394403723734, 'kld': 0.6481655878680093, 'total': 3.8438399178641185}


                                                                                                   

3 beta 0.6 loss 3.832817726394757
{'main': 3.3469179400375912, 'term': 0.00940993992636712, 'kld': 0.537487411073276, 'total': 3.6788203503404344}


                                                                                                   

4 beta 0.8 loss 3.723835875928537
{'main': 3.229816347360611, 'term': 0.008398077520333962, 'kld': 0.4832943986569132, 'total': 3.6248499623366763}


train beta=1.00:  92%|█████████▏| 405/441 [01:43<00:09,  3.82it/s, kld=0.428, loss=3.68, main=3.24]2026-02-19 22:07:53.363 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_07_53-4260486711” because the volume “Macintosh HD” is out of space.
2026-02-19 22:07:53.413 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_07_53-244826589” because the volume “Macintosh HD” is out of space.
                                                                                                   

5 beta 1.0 loss 3.6569960685003373
{'main': 3.1468958939824785, 'term': 0.006065905719359372, 'kld': 0.4369916926537241, 'total': 3.5899534778935567}


train beta=1.00:   4%|▎         | 16/441 [00:04<02:02,  3.47it/s, kld=0.455, loss=3.86, main=3.39]2026-02-19 22:08:10.854 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_08_10-164405893” because the volume “Macintosh HD” is out of space.
2026-02-19 22:08:10.859 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_08_10-2540499703” because the volume “Macintosh HD” is out of space.
2026-02-19 22:08:10.862 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_08_10-2563716296” because the volume “Macintosh HD” is out of space.
2026-02-19 22:08:10.866 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the fil

6 beta 1.0 loss 3.5300354687264717
{'main': 3.266918339899608, 'term': 0.009769203546706453, 'kld': 0.4574246534279415, 'total': 3.734112194606236}


train beta=1.00:  98%|█████████▊| 433/441 [02:15<00:02,  3.49it/s, kld=0.445, loss=3.26, main=2.81]2026-02-19 22:12:50.369 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_12_50-722991728” because the volume “Macintosh HD” is out of space.
2026-02-19 22:12:50.374 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_12_50-2410884025” because the volume “Macintosh HD” is out of space.
2026-02-19 22:12:50.378 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the file “mpsgraph-69772-2026-02-19_22_12_50-2550156280” because the volume “Macintosh HD” is out of space.
2026-02-19 22:12:50.385 python[69772:2371302] Error creating directory 
 The volume “Macintosh HD” is out of space. You can’t save the fi

7 beta 1.0 loss 3.434058467817415
{'main': 2.950323326247079, 'term': 0.006298692590332523, 'kld': 0.4624867205108915, 'total': 3.4191087441784993}


                                                                                                   

8 beta 1.0 loss 3.3429094231317915
{'main': 2.867559994970049, 'term': 0.007108658941433532, 'kld': 0.4884468588445868, 'total': 3.363115519285202}


                                                                                                   

9 beta 1.0 loss 3.271131559834729
{'main': 2.7655502344880785, 'term': 0.006333418306374889, 'kld': 0.5046376463557992, 'total': 3.2765212953090668}


                                                                                                   

10 beta 1.0 loss 3.2079108047917857
{'main': 2.6640693971088956, 'term': 0.005196100452199711, 'kld': 0.5289451341543879, 'total': 3.1982106353555406}


                                                                                                   

11 beta 1.0 loss 3.1538990830618236
{'main': 2.5944941341876984, 'term': 0.005349439980688787, 'kld': 0.5545392196093287, 'total': 3.1543827993529185}


                                                                                                   

12 beta 1.0 loss 3.1061171798749303
{'main': 2.54967268875667, 'term': 0.006875664991249713, 'kld': 0.5628752197538104, 'total': 3.119423602308546}


                                                                                                   

13 beta 1.0 loss 3.0586022757618876
{'main': 2.476993965251105, 'term': 0.004461514982592364, 'kld': 0.5770915118711335, 'total': 3.058546985898699}


                                                                                                   

14 beta 1.0 loss 3.0116471146780346
{'main': 2.47940747652735, 'term': 0.005811729513360271, 'kld': 0.5978370223726545, 'total': 3.083056228501456}


                                                                                                   

15 beta 1.0 loss 2.977761597860427
{'main': 2.369800865650177, 'term': 0.005341703438846578, 'kld': 0.6135297481502805, 'total': 2.9886723288467953}


                                                                                                   

16 beta 1.0 loss 2.9413297711586464
{'main': 2.338465017931802, 'term': 0.007194815894763451, 'kld': 0.6165324566619736, 'total': 2.962192267179489}


                                                                                                   

17 beta 1.0 loss 2.9105560752539947
{'main': 2.302781799009868, 'term': 0.00935386632357092, 'kld': 0.6326318361929485, 'total': 2.9447675006730214}


                                                                                                   

18 beta 1.0 loss 2.8751164650430483
{'main': 2.2657175106661662, 'term': 0.005067095650182247, 'kld': 0.6499272531696728, 'total': 2.9207118451595306}


                                                                                                   

19 beta 1.0 loss 2.8479600320327307
{'main': 2.2259805479219983, 'term': 0.005042003684788402, 'kld': 0.6591568227325167, 'total': 2.890179365873337}
Model saved.


In [16]:
def _invert_vocab(vocab):
    """Return the id -> token mapping for a token -> id vocabulary dict."""
    return {idx: tok for tok, idx in vocab.items()}


# Reverse lookups used to turn generated id sequences back into labels.
id2type = _invert_vocab(type2id)
id2sz = _invert_vocab(sz2id)
id2ez = _invert_vocab(ez2id)
id2out = _invert_vocab(out2id)
id2dt = _invert_vocab(dt2id)
id2term = _invert_vocab(term2id)
def pretty_print_seq(seq_ids_dict, max_rows=60):
    """Print a possession as a table, one event per line.

    Each row shows the step index, event type, start zone -> end zone,
    outcome, dt bin and termination label, decoded through the module-level
    ``id2*`` lookup tables.

    Args:
        seq_ids_dict: dict of parallel id lists keyed 'type', 'sz', 'ez',
            'out', 'dt', 'term'.
        max_rows: upper bound on the number of rows printed.

    Printing stops before the first PAD event, or right after an END event.
    """
    n_rows = min(len(seq_ids_dict["type"]), max_rows)
    for step in range(n_rows):
        event_type = id2type[seq_ids_dict["type"][step]]
        if event_type == "PAD":
            return

        start_zone = id2sz[seq_ids_dict["sz"][step]]
        end_zone = id2ez[seq_ids_dict["ez"][step]]
        outcome = id2out[seq_ids_dict["out"][step]]
        dt_bin = id2dt[seq_ids_dict["dt"][step]]
        terminal = id2term[seq_ids_dict["term"][step]]

        print(
            f"{step:02d}  {event_type:14s}  "
            f"{start_zone:18s} -> {end_zone:18s}  "
            f"out={outcome:10s}  dt={dt_bin:5s}  "
            f"term={terminal}"
        )

        if event_type == "END":
            return

@torch.no_grad()
def sample_categorical(logits, temperature=1.0):
    """Draw one index from the categorical distribution defined by ``logits``.

    Args:
        logits: 1-D tensor of unnormalised scores over a vocabulary, shape [V].
        temperature: softmax temperature. Values < 1 sharpen the distribution,
            values > 1 flatten it, and 0 means greedy decoding (argmax).

    Returns:
        The chosen index as a Python int.

    Raises:
        ValueError: if ``temperature`` is negative (dividing by a negative
            temperature would silently invert the distribution).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if temperature == 0:
        # logits / 0 gives inf/NaN probabilities and torch.multinomial would
        # raise; the T -> 0 limit of temperature sampling is the argmax.
        return int(torch.argmax(logits, dim=-1).item())
    if temperature != 1.0:
        logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

@torch.no_grad()
def generate_one(model, c_vec, T=40, temperature=1.0, device="cpu"):
    """Sample one possession from the CVAE prior, one event per step.

    A latent ``z ~ N(0, I)`` is drawn once. Then, for each timestep ``t``, the
    decoder is re-run on the events generated so far, shifted right by one
    with id 0 at position 0. Every stream is sampled from the logits at
    position ``t``. Generation stops after the first END event or after
    ``T`` steps. The full sequence is re-decoded at every step, which costs
    O(T^2) decoder positions; that is cheap at T=40.

    Args:
        model: trained model exposing ``zdim``, ``decode(x, z, c)`` and the
            ``train``/``eval`` mode switch. ``decode`` is assumed to return a
            dict of ``[1, T, V]`` logits keyed by stream name.
        c_vec: condition vector for one possession, shape [C] (numpy array
            or tensor).
        T: maximum number of events to generate.
        temperature: sampling temperature forwarded to ``sample_categorical``.
        device: device on which to build the inputs (should match the model).

    Returns:
        dict with keys 'type', 'sz', 'ez', 'out', 'dt', 'term'. Each value is a
        list of ``T`` ids; positions after the END event stay 0 (PAD).
    """
    pad_id = 0  # id 0 is PAD in every stream (x_gen below starts all zeros)
    end_id = type2id["END"]
    na_term_id = term2id.get("NA_TERM", 0)

    # Remember the caller's mode so that sampling mid-training does not leave
    # the model stuck in eval mode (dropout etc. disabled).
    was_training = model.training
    model.eval()
    try:
        # as_tensor accepts numpy rows and tensors alike. torch.tensor on an
        # existing tensor warns and forces a copy.
        c = torch.as_tensor(c_vec, dtype=torch.float32, device=device).unsqueeze(0)  # [1, C]
        z = torch.randn(1, model.zdim, device=device)

        # Generated events: one row of 6 stream ids per timestep, all PAD to start.
        x_gen = torch.zeros(1, T, 6, dtype=torch.long, device=device)

        for t in range(T):
            # Decoder input = generated history shifted right by one: PAD at
            # position 0 and PAD at every not-yet-generated position.
            x_in = torch.zeros_like(x_gen)
            x_in[:, 1:t + 1] = x_gen[:, :t]

            logits = model.decode(x_in, z, c)  # dict of [1, T, V]

            # The dataset mask excludes type id 0 (PAD), so the PAD type logit
            # is never trained as a target. Never sample it; otherwise a PAD
            # event could appear mid-sequence.
            type_logits = logits["type"][0, t].clone()
            type_logits[pad_id] = float("-inf")
            ty = sample_categorical(type_logits, temperature)
            sz = sample_categorical(logits["sz"][0, t], temperature)
            ez = sample_categorical(logits["ez"][0, t], temperature)
            out = sample_categorical(logits["out"][0, t], temperature)
            dt = sample_categorical(logits["dt"][0, t], temperature)

            # The termination label only matters on the END event.
            if ty == end_id:
                term = sample_categorical(logits["term"][0, t], temperature)
            else:
                term = na_term_id

            x_gen[0, t] = torch.tensor([ty, sz, ez, out, dt, term], device=device)

            if ty == end_id:
                break

        # Column order matches the stacking order used by PossessionDataset.
        streams = ("type", "sz", "ez", "out", "dt", "term")
        return {name: x_gen[0, :, k].tolist() for k, name in enumerate(streams)}
    finally:
        model.train(was_training)

# Build sampling inputs on whatever device holds the model's weights.
device = next(model.parameters()).device

# Draw one possession conditioned on the first context vector and print it.
j = 0
gen = generate_one(
    model,
    Cz[j],
    T=40,
    temperature=0.9,
    device=device,
)
pretty_print_seq(gen)


00  Pass            Wing_Left_Zone2    -> Wing_Left_Zone4     out=Complete    dt=7-15   term=NA_TERM
01  Carry           Wing_Left_Zone4    -> Wing_Left_Zone4     out=NA          dt=0-1    term=NA_TERM
02  Pass            Wing_Left_Zone4    -> Wing_Right_Zone4    out=Complete    dt=0-1    term=NA_TERM
03  Pass            Wing_Right_Zone4   -> Wing_Right_Zone4    out=Complete    dt=3-7    term=NA_TERM
04  Carry           Wing_Right_Zone4   -> Wing_Right_Zone4    out=NA          dt=0-1    term=NA_TERM
05  Pass            Wing_Right_Zone4   -> Wing_Right_Zone4    out=Complete    dt=0-1    term=NA_TERM
06  Carry           Wing_Right_Zone4   -> Att_Pocket_Right    out=NA          dt=1-3    term=NA_TERM
07  Pass            Att_Pocket_Right   -> Att_Pocket_Right    out=Incomplete  dt=3-7    term=NA_TERM
08  END             PAD                -> PAD                 out=PAD         dt=PAD    term=turnover
