In [None]:
import os, random, pandas as pd, torch
from torch.utils.data import TensorDataset

# ---- config ----
# Project root on the mounted Drive; the artifact folder below it holds the
# week_XX_clean_blitz.parquet inputs and receives every .pt tensor written here.
PROJ = "/content/drive/MyDrive/bdb25-blitz"
ART = f"{PROJ}/artifacts"
os.makedirs(ART, exist_ok=True)

# Per-defender model inputs, in the order they are packed on the last tensor axis.
features = [
    "x_clean",
    "y_clean",
    "v_x",
    "v_y",
    "depth_to_los",
    "o_to_los_cos",
    "creep_depth_mean",
    "creep_lat_mean",
    "pre_speed_mean",
]

# Label column.
target_column = "blitz"

# Number of defenders kept per frame by prepare_frame_data_blitz
# (must stay in sync with that helper's default).
K_DEF = 8

# Columns every weekly parquet must provide. Other columns are left in place;
# this list is only used for validation.
NEEDED = [
    "frameUniqueId",
    "frameId",
    "frameType",
    "defensiveTeam",
    "defense",
    target_column,
    *features,
]

def load_week_df(week: int) -> pd.DataFrame:
    """Read one week's cleaned parquet and return only the labelled rows.

    Args:
        week: Week number; formatted as two digits in the file name.

    Returns:
        A copy of the week's frame with rows lacking ``target_column`` removed.

    Raises:
        KeyError: If any column listed in ``NEEDED`` is absent from the file.
    """
    frame = pd.read_parquet(f"{ART}/week_{week:02d}_clean_blitz.parquet")

    # Validate the schema up front so a stale parquet fails with a clear message.
    available = set(frame.columns)
    missing = [name for name in NEEDED if name not in available]
    if missing:
        raise KeyError(f"Week {week:02d} missing columns: {missing}. "
                       "Ensure process_week_data added LOS/disguise/labels first.")

    # Some plays may have no label after upstream filtering; drop those rows.
    labelled = frame.dropna(subset=[target_column])
    return labelled.copy()

# NOTE: this calls the new defense-only packer you wrote earlier
# def prepare_frame_data_blitz(df, features, target_column, K=8) -> (torch.Tensor[N,K,F], torch.Tensor[N])

# Build leave-one-week-out tensors.
#
# For each eval week W this writes:
#   features_val_weekWW.pt / targets_val_weekWW.pt            (week W itself)
#   features_train_weekWW_shard_wXX.pt / targets_train_...    (one shard per other week X)
#
# The packed tensors of a week do not depend on which fold they are used in,
# so every week is loaded and packed exactly once and then re-used for all
# folds (9 loads/packs instead of 81). File names are unchanged.
EVAL_WEEKS = range(1, 10)

packed_by_week = {}   # week -> (features [N,K,F], targets [N])
for wk in EVAL_WEEKS:
    wk_df = load_week_df(wk)
    wk_X, wk_y = prepare_frame_data_blitz(wk_df, features, target_column, K=K_DEF)
    del wk_df  # release the DataFrame before loading the next week
    # The helper signals "nothing to pack" with None; also treat a zero-length
    # tensor as empty so random.randrange below can never receive 0.
    if wk_X is None or len(wk_X) == 0:
        print(f"[pack] week {wk:02d}: empty after packing; skipping.")
        continue
    packed_by_week[wk] = (wk_X, wk_y)

for week_eval in EVAL_WEEKS:
    print(f"\n=== Eval week {week_eval:02d} ===")

    # ------- Validation (single week) -------
    if week_eval not in packed_by_week:
        print(f"[val] week {week_eval:02d}: no frames after packing; skipping.")
        continue
    val_features, val_targets = packed_by_week[week_eval]

    # quick random sample check (first defender slot of one random frame)
    ridx = random.randrange(len(val_features))
    print(f"[val] shape={val_features.shape}  sample[{ridx}][0]={val_features[ridx][0]}")

    torch.save(val_features, f"{ART}/features_val_week{week_eval:02d}.pt")
    torch.save(val_targets,  f"{ART}/targets_val_week{week_eval:02d}.pt")

    # ------- Training (all other weeks) saved as shards -------
    shard_feats, shard_tgts = [], []
    for wk, (trX, trY) in packed_by_week.items():
        if wk == week_eval:
            continue

        fpath = f"{ART}/features_train_week{week_eval:02d}_shard_w{wk:02d}.pt"
        tpath = f"{ART}/targets_train_week{week_eval:02d}_shard_w{wk:02d}.pt"
        torch.save(trX, fpath)
        torch.save(trY, tpath)
        shard_feats.append(fpath)
        shard_tgts.append(tpath)

        rtr = random.randrange(len(trX))
        print(f"[train shard w{wk:02d}] shape={trX.shape}  sample[{rtr}][0]={trX[rtr][0]}")

    print(f"Saved {len(shard_feats)} train shards for eval week {week_eval:02d} in {ART}")

    # Shards are intentionally kept separate so training can stream them.
    # If a single train tensor is ever needed, torch.load the paths in
    # shard_feats / shard_tgts and torch.cat them along dim=0.



=== Eval week 01 ===
[val] shape=torch.Size([8501, 8, 9])  sample[3381][0]=tensor([ 3.6000e+01,  2.2100e+01, -1.1547e+00, -4.7876e-01,  4.2000e-01,
        -9.2375e-01, -2.5000e-03, -2.5000e-03,  0.0000e+00])
[train shard w02] shape=torch.Size([7759, 8, 9])  sample[6899][0]=tensor([107.6100,  31.8100,   1.5057,  -2.0083,   1.6500,   0.5999,   0.2375,
         -0.3100,   2.5275])
[train shard w03] shape=torch.Size([8541, 8, 9])  sample[2730][0]=tensor([ 6.0400e+01,  3.7890e+01, -1.5985e-02,  1.2020e-02,  5.4000e-01,
        -7.9927e-01,  0.0000e+00,  2.5000e-03,  2.5000e-02])
[train shard w04] shape=torch.Size([7348, 8, 9])  sample[2332][0]=tensor([ 4.4950e+01,  2.4640e+01, -1.0345e+00, -1.0690e-01,  3.1000e-01,
        -9.9470e-01, -1.0000e-02, -2.5000e-03,  7.2500e-02])
[train shard w05] shape=torch.Size([7982, 8, 9])  sample[5018][0]=tensor([ 6.9840e+01,  2.8200e+01, -5.4999e-01,  2.3998e-03,  8.8000e-01,
        -9.9999e-01, -2.0000e-03,  0.0000e+00,  1.2000e-02])
[train shard w06]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlitzTransformer(nn.Module):
    """
    Transformer encoder over the K defender slots of a frame, pooled to one logit.

    Input:  x [B, K, F*]  (K defenders, F* can vary; we fix to feature_len internally)
    Output: logits [B]    (raw scores, not passed through a sigmoid)

    Pipeline: width fix -> non-finite sanitising -> BatchNorm over features ->
    per-defender embedding -> TransformerEncoder -> masked mean + masked max
    pooling -> MLP head.

    Args:
        feature_len: Feature width the BatchNorm and first Linear are built for.
        model_dim: Embedding / encoder width D.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability used in the embedding, encoder and head.
    """
    def __init__(self,
                 feature_len=9,
                 model_dim=128,
                 num_heads=4,
                 num_layers=2,
                 dim_feedforward=512,
                 dropout=0.1):
        super().__init__()
        self.feature_len = feature_len

        # Normalises each of the F features across batch and defender slots.
        self.bn = nn.BatchNorm1d(feature_len)

        self.embed = nn.Sequential(
            nn.Linear(feature_len, model_dim),
            nn.ReLU(),
            nn.LayerNorm(model_dim),
            nn.Dropout(dropout),
        )

        enc_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Mean- and max-pooled vectors are concatenated (2D) and projected back to D.
        self.reduce = nn.Linear(model_dim * 2, model_dim)
        self.head = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(model_dim, model_dim // 4),
            nn.ReLU(),
            nn.LayerNorm(model_dim // 4),
            nn.Linear(model_dim // 4, 1),
        )

    @staticmethod
    def _fixF(x: torch.Tensor, F_target: int) -> torch.Tensor:
        """Pad with zeros or truncate to match F_target on the last dimension."""
        F_cur = x.shape[-1]
        if F_cur == F_target:
            return x
        if F_cur < F_target:
            pad = x.new_zeros(*x.shape[:-1], F_target - F_cur)
            return torch.cat([x, pad], dim=-1)
        return x[..., :F_target]

    def forward(self, x, mask=None):
        """
        Args:
            x: [B, K, F*] defender features; the last axis is padded/truncated
               to ``feature_len``.
            mask: Optional [B, K] tensor, non-zero for real defenders. When
                  omitted, a slot counts as real if any of its raw features
                  is non-zero.

        Returns:
            Tensor [B] of logits.
        """
        x = self._fixF(x, self.feature_len)               # [B,K,F*] -> [B,K,F]

        # A single NaN/Inf would poison BatchNorm's running statistics and turn
        # every later output into NaN, so non-finite entries are zeroed first.
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

        # The padding mask must come from the RAW features: after BN and the
        # embedding, an all-zero padded slot is no longer zero, so deriving the
        # mask later would mark every slot as real.
        if mask is None:
            mask = x.abs().sum(dim=-1) > 0                # [B,K]
        m = mask.to(dtype=x.dtype).unsqueeze(-1)          # [B,K,1]
        valid = m > 0                                     # [B,K,1] bool

        x = x.permute(0, 2, 1)                            # [B,F,K]
        x = self.bn(x)                                    # BN over feature dim (F)
        x = x.permute(0, 2, 1)                            # back to [B,K,F]

        emb = self.embed(x)                               # [B,K,D]
        h = self.enc(emb)                                 # [B,K,D]

        # Masked mean: padded slots contribute nothing; the clamp avoids 0/0.
        safe = m.sum(dim=1).clamp_min(1e-6)
        mean = (h * m).sum(dim=1) / safe                  # [B,D]

        # Masked max: padded slots are excluded with -inf. A frame with no real
        # slot falls back to zeros instead of propagating -inf into the head.
        mx = h.masked_fill(~valid, float("-inf")).amax(dim=1)   # [B,D]
        has_valid = valid.any(dim=1)                             # [B,1]
        mx = torch.where(has_valid, mx, torch.zeros_like(mx))

        pooled = torch.cat([mean, mx], dim=-1)            # [B,2D]
        pooled = self.reduce(pooled)                      # [B,D]
        logit  = self.head(pooled).squeeze(-1)            # [B]
        return logit


In [None]:

from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.optim import AdamW
# Disable pandas' chained-assignment check (SettingWithCopyWarning) for this session.
pd.options.mode.chained_assignment = None

# Colab-only: mount Google Drive at /content/drive, the root of the project
# paths used in this notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from torch.utils.data import TensorDataset, DataLoader

def ensure_feature_dim(x, F_expected):
    """Return ``x`` with its last dimension resized to ``F_expected``.

    A narrower tensor is right-padded with zeros (same dtype and device), a
    wider one is truncated, and an already matching tensor is returned as-is.
    """
    shortfall = F_expected - x.shape[-1]

    if shortfall > 0:
        filler = x.new_zeros((*x.shape[:-1], shortfall))
        return torch.cat((x, filler), dim=-1)

    if shortfall < 0:
        # Truncation discards features; regenerating consistent tensors is preferable.
        return x[..., :F_expected]

    return x

In [None]:
import os, glob, torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader, TensorDataset

# ---------- streaming dataset over shards ----------
class ShardStream(IterableDataset):
    """Streams samples from shard files (one sample at a time).

    Shards are read in the order given, one at a time, so only a single shard
    is held in memory. Samples are yielded in file order (no shuffling).

    Args:
        feat_paths: Paths to feature tensors saved with ``torch.save``.
        tgt_paths: Paths to the matching target tensors, in the same order as
            ``feat_paths``.
        device: Optional device to place the yielded tensors on.

    Yields:
        ``(features_i, target_i)`` pairs as float32 tensors.

    Raises:
        ValueError: While iterating, if a feature shard and its target shard
            hold a different number of rows.
    """
    def __init__(self, feat_paths, tgt_paths, device=None):
        assert len(feat_paths) == len(tgt_paths), "Mismatched shard counts"
        self.feat_paths = feat_paths
        self.tgt_paths  = tgt_paths
        self.device     = device

    def __iter__(self):
        for fp, tp in zip(self.feat_paths, self.tgt_paths):
            X = torch.load(fp, map_location="cpu").to(dtype=torch.float32)
            y = torch.load(tp, map_location="cpu").to(dtype=torch.float32)

            # A length mismatch would otherwise either be silently truncated
            # (extra targets) or surface as a bare IndexError (missing targets).
            if X.shape[0] != y.shape[0]:
                raise ValueError(
                    f"Shard row mismatch: {fp} has {X.shape[0]} rows, "
                    f"{tp} has {y.shape[0]}"
                )

            # Transfer the whole shard once rather than once per sample.
            if self.device:
                X = X.to(self.device, non_blocking=True)
                y = y.to(self.device, non_blocking=True)

            for i in range(X.shape[0]):
                yield X[i], y[i]

# ------------ training loop ------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters shared by every leave-one-week-out fold.
batch_size = 64
learning_rate = 2e-4
num_epochs = 10
early_stopping_patience = 5   # epochs without val-loss improvement before stopping

ART = "/content/drive/MyDrive/bdb25-blitz/artifacts"
weeks_train = [1,2,3,4,5,6,7,8,9]


def _require_finite(loss, phase, week, epoch):
    """Abort with a clear message as soon as a loss becomes NaN/Inf.

    Training through a non-finite loss only produces NaN weights, which would
    then be written to the checkpoints.
    """
    if not bool(torch.isfinite(loss)):
        raise RuntimeError(
            f"[week {week:02d}] non-finite {phase} loss at epoch {epoch + 1}. "
            "Check the feature tensors for NaN/Inf values before training."
        )


for week_eval in weeks_train:
    print(f"\n######## WEEK {week_eval:02d} ########")

    # collect train shards (sorted so feature/target files pair up by week)
    feat_paths = sorted(glob.glob(os.path.join(ART, f"features_train_week{week_eval:02d}_shard_w*.pt")))
    tgt_paths  = sorted(glob.glob(os.path.join(ART, f"targets_train_week{week_eval:02d}_shard_w*.pt")))
    if not feat_paths:
        raise FileNotFoundError(f"No train shards for eval week {week_eval:02d} in {ART}")
    if len(feat_paths) != len(tgt_paths):
        raise FileNotFoundError(
            f"Eval week {week_eval:02d}: {len(feat_paths)} feature shards but "
            f"{len(tgt_paths)} target shards in {ART}"
        )

    # load val
    vaX = torch.load(os.path.join(ART, f"features_val_week{week_eval:02d}.pt"), map_location="cpu").to(torch.float32)
    vaY = torch.load(os.path.join(ART, f"targets_val_week{week_eval:02d}.pt"),  map_location="cpu").to(torch.float32)

    # decide a single feature_len (BN & first linear need a fixed size).
    # pick the max feature width across val + all shards
    F_candidates = [vaX.shape[-1]]
    for fp in feat_paths:
        F_candidates.append(torch.load(fp, map_location="cpu").shape[-1])
    F_model = max(F_candidates)
    print(f"[week {week_eval:02d}] inferred feature_len (F_model) = {F_model}")

    # class imbalance across all shards
    pos = neg = 0
    for tp in tgt_paths:
        y = torch.load(tp, map_location="cpu").to(torch.float32)
        pos += int((y == 1).sum().item())
        neg += int((y == 0).sum().item())
    # Up-weight positives by the neg/pos ratio, never below 1.
    pos_weight = torch.tensor([max(1.0, neg / max(1, pos))], device=device)
    print(f"[week {week_eval:02d}] pos={pos} neg={neg} pos_weight={pos_weight.item():.2f}")

    # model / opt / loss
    model = BlitzTransformer(
        feature_len=F_model,   # <-- key: BN & first Linear expect this width
        model_dim=128, num_heads=4, num_layers=2,
        dim_feedforward=512, dropout=0.1
    ).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-2)

    # loaders (the train stream is an IterableDataset, so it is read in shard order)
    train_stream = ShardStream(feat_paths, tgt_paths, device=device)
    train_loader = DataLoader(train_stream, batch_size=batch_size, shuffle=False, num_workers=0)
    val_loader   = DataLoader(TensorDataset(vaX.to(device), vaY.to(device)),
                              batch_size=batch_size, shuffle=False, num_workers=0)

    # Per-fold checkpoint paths and early-stopping state. These must be set
    # once per week, NOT once per epoch, or patience can never accumulate and
    # "best" degenerates into "latest".
    ckpt_best = os.path.join(ART, f"best_model_week{week_eval:02d}.pth")
    ckpt_last = os.path.join(ART, f"last_model_week{week_eval:02d}.pth")
    best_val, no_improve = float("inf"), 0
    saved_any = False

    # train/eval
    for epoch in range(num_epochs):
        # train
        model.train(); run = 0.0; n = 0
        first_batch_logged = epoch > 0   # log batch shapes on the first epoch only
        for xb, yb in train_loader:
            if not first_batch_logged:
                print(f"[train] xb shape {tuple(xb.shape)}  (B,K,F*), model.feature_len={model.feature_len}")
                first_batch_logged = True
            optimizer.zero_grad()
            logits = model(xb)                  # model pads/truncates internally to feature_len
            loss = criterion(logits, yb)
            _require_finite(loss, "train", week_eval, epoch)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            run += loss.item()*xb.size(0); n += xb.size(0)
        train_loss = run / max(1, n)

        # val
        model.eval(); vrun = 0.0; vn = 0; correct = 0
        first_val_logged = epoch > 0
        with torch.no_grad():
            for xb, yb in val_loader:
                if not first_val_logged:
                    print(f"[val]   xb shape {tuple(xb.shape)}  (B,K,F*), model.feature_len={model.feature_len}")
                    first_val_logged = True
                logits = model(xb)
                vloss = criterion(logits, yb)
                # Without this check NaN logits all threshold to class 0 and
                # report a deceptively high accuracy on imbalanced data.
                _require_finite(vloss, "val", week_eval, epoch)
                vrun += vloss.item()*xb.size(0); vn += xb.size(0)
                preds = (torch.sigmoid(logits) >= 0.5).long()
                correct += (preds == yb.long()).sum().item()
        val_loss = vrun / max(1, vn)
        val_acc  = correct / max(1, vn)
        print(f"Epoch {epoch+1:02d}  train {train_loss:.4f}  val {val_loss:.4f}  acc {val_acc:.3f}")

        if val_loss < best_val:
            best_val, no_improve = val_loss, 0
            torch.save(model.state_dict(), ckpt_best)
            saved_any = True
        else:
            no_improve += 1
            if no_improve >= early_stopping_patience:
                print("Early stopping.")
                break

    # after the epoch loop finishes (always save a fallback)
    torch.save(model.state_dict(), ckpt_last)
    print(f"Saved fallback checkpoint: {ckpt_last}")
    if not saved_any:
        # also mirror as "best" to simplify downstream code
        torch.save(model.state_dict(), ckpt_best)
        print(f"No improvement checkpoint found; mirrored last -> {ckpt_best}")



######## WEEK 01 ########
[week 01] inferred feature_len (F_model) = 9
[week 01] pos=182 neg=59264 pos_weight=325.63
[train] xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
[val]   xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
Epoch 01  train nan  val nan  acc 0.996
Saved fallback checkpoint: /content/drive/MyDrive/bdb25-blitz/artifacts/last_model_week01.pth
No improvement checkpoint found; mirrored last -> /content/drive/MyDrive/bdb25-blitz/artifacts/best_model_week01.pth
[train] xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
[val]   xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
Epoch 02  train nan  val nan  acc 0.996
Saved fallback checkpoint: /content/drive/MyDrive/bdb25-blitz/artifacts/last_model_week01.pth
No improvement checkpoint found; mirrored last -> /content/drive/MyDrive/bdb25-blitz/artifacts/best_model_week01.pth
[train] xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
[val]   xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
Epoch 03  train nan  val n