In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlitzTransformer(nn.Module):
    """
    Transformer encoder over a set of K defenders, pooled to one logit per sample.

    Input:  x [B, K, F*]  (K defenders, F* can vary; we fix to feature_len internally)
    Output: logits [B]

    Pipeline: pad/truncate features -> BatchNorm1d over the feature axis ->
    per-defender embedding MLP -> TransformerEncoder -> masked mean + masked max
    pooling over K -> linear reduce -> MLP head.

    Args:
        feature_len: Feature width the model operates on; inputs are padded with
            zeros or truncated on the last axis to this width.
        model_dim: Embedding / encoder width (D).
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability used in the embedding, encoder and head.
    """
    def __init__(self,
                 feature_len=9,
                 model_dim=128,
                 num_heads=4,
                 num_layers=2,
                 dim_feedforward=512,
                 dropout=0.1):
        super().__init__()
        self.feature_len = feature_len

        # Normalises each of the F feature channels (expects [B, F, K]).
        self.bn = nn.BatchNorm1d(feature_len)

        self.embed = nn.Sequential(
            nn.Linear(feature_len, model_dim),
            nn.ReLU(),
            nn.LayerNorm(model_dim),
            nn.Dropout(dropout),
        )

        enc_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Fuses the concatenated [mean, max] pooled vector (2D) back to D.
        self.reduce = nn.Linear(model_dim * 2, model_dim)
        self.head = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(model_dim, model_dim // 4),
            nn.ReLU(),
            nn.LayerNorm(model_dim // 4),
            nn.Linear(model_dim // 4, 1),
        )

    @staticmethod
    def _fixF(x: torch.Tensor, F_target: int) -> torch.Tensor:
        """Pad with zeros or truncate to match F_target on the last dimension."""
        F_cur = x.shape[-1]
        if F_cur == F_target:
            return x
        if F_cur < F_target:
            pad = x.new_zeros(*x.shape[:-1], F_target - F_cur)
            return torch.cat([x, pad], dim=-1)
        return x[..., :F_target]

    def forward(self, x, mask=None):
        """
        Compute one logit per sample.

        Args:
            x: Tensor [B, K, F*]; the last axis is padded/truncated to feature_len.
            mask: Optional [B, K] validity mask (non-zero / True = real defender).
                Bool or float. When omitted, a defender slot counts as valid if
                its raw feature vector is not all zeros.

        Returns:
            Tensor [B] of raw logits (no sigmoid applied).
        """
        x = self._fixF(x, self.feature_len)               # [B,K,F*] -> [B,K,F]

        if mask is None:
            # The mask must come from the RAW features: after the embedding
            # (Linear bias + LayerNorm) an all-zero row is no longer zero, so
            # testing the embedded tensor would mark every slot as valid.
            mask = x.abs().sum(dim=-1) > 0                # [B,K]

        # BatchNorm1d normalises dim 1, so move features there and back.
        x = x.permute(0, 2, 1)                            # [B,F,K]
        x = self.bn(x)                                    # BN over feature dim (F)
        x = x.permute(0, 2, 1)                            # back to [B,K,F]

        x = self.embed(x)                                 # [B,K,D]
        h = self.enc(x)                                   # [B,K,D]

        # Cast so bool masks work in the arithmetic below.
        m = mask.to(dtype=h.dtype).unsqueeze(-1)          # [B,K,1]
        valid = m.sum(dim=1)                              # [B,1]
        mean = (h * m).sum(dim=1) / valid.clamp_min(1e-6)  # [B,D]
        mx = (h + (1.0 - m) * (-1e9)).amax(dim=1)         # [B,D]
        # A sample with no valid slot would otherwise pool to roughly -1e9.
        mx = torch.where(valid > 0, mx, torch.zeros_like(mx))
        pooled = torch.cat([mean, mx], dim=-1)            # [B,2D]
        pooled = self.reduce(pooled)                      # [B,D]
        logit = self.head(pooled).squeeze(-1)             # [B]
        return logit


In [None]:

from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.optim import AdamW
# Disable pandas' chained-assignment check (None turns the warning off entirely).
pd.options.mode.chained_assignment = None

# Colab-only: mount Google Drive at /content/drive; the artifact paths used
# later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from torch.utils.data import TensorDataset, DataLoader

def ensure_feature_dim(x, F_expected):
    """Return `x` with its last axis exactly `F_expected` wide.

    Narrower inputs are right-padded with zeros (same dtype/device as `x`),
    wider inputs are truncated, and inputs that already match are returned
    unchanged (the same object).
    """
    missing = F_expected - x.shape[-1]
    if missing == 0:
        return x
    if missing < 0:
        # Too wide: drop the trailing features (not ideal, better to
        # regenerate the data with a consistent width).
        return x[..., :F_expected]
    filler = x.new_zeros(*x.shape[:-1], missing)
    return torch.cat([x, filler], dim=-1)

In [None]:
import os, glob, torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader, TensorDataset

# ---------- streaming dataset over shards ----------
class ShardStream(IterableDataset):
    """Iterable dataset that yields (features, target) pairs shard by shard.

    Feature and target shard files are paired by position. Each shard is
    loaded onto the CPU, cast to float32, and then emitted one sample at a
    time; if `device` is set, each sample is moved there before being yielded.
    """

    def __init__(self, feat_paths, tgt_paths, device=None):
        assert len(feat_paths) == len(tgt_paths), "Mismatched shard counts"
        self.feat_paths = feat_paths
        self.tgt_paths = tgt_paths
        self.device = device

    @staticmethod
    def _read_float32(path):
        """Load one shard file onto the CPU as a float32 tensor."""
        return torch.load(path, map_location="cpu").to(dtype=torch.float32)

    def __iter__(self):
        for feat_file, tgt_file in zip(self.feat_paths, self.tgt_paths):
            feats = self._read_float32(feat_file)      # [N,K,F*]
            targets = self._read_float32(tgt_file)     # [N]
            for idx in range(feats.shape[0]):
                sample = feats[idx]
                label = targets[idx]
                if self.device:
                    sample = sample.to(self.device, non_blocking=True)
                    label = label.to(self.device, non_blocking=True)
                yield sample, label

# ------------ training loop ------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
learning_rate = 2e-4
num_epochs = 10
early_stopping_patience = 5

ART = "/content/drive/MyDrive/bdb25-blitz/artifacts"
weeks_train = [1,2,3,4,5,6,7,8,9]

# One independent model is trained per evaluation week; each week has its own
# train shards, validation tensors and checkpoints under ART.
for week_eval in weeks_train:
    print(f"\n######## WEEK {week_eval:02d} ########")

    # collect train shards
    feat_paths = sorted(glob.glob(os.path.join(ART, f"features_train_week{week_eval:02d}_shard_w*.pt")))
    tgt_paths  = sorted(glob.glob(os.path.join(ART, f"targets_train_week{week_eval:02d}_shard_w*.pt")))
    if not feat_paths:
        raise FileNotFoundError(f"No train shards for eval week {week_eval:02d} in {ART}")

    # load val
    vaX = torch.load(os.path.join(ART, f"features_val_week{week_eval:02d}.pt"), map_location="cpu").to(torch.float32)
    vaY = torch.load(os.path.join(ART, f"targets_val_week{week_eval:02d}.pt"),  map_location="cpu").to(torch.float32)

    # NOTE(review): the recorded run reports NaN train/val loss from epoch 1.
    # Non-finite feature values are a plausible cause (TODO confirm upstream);
    # they are zeroed here so one bad value cannot poison BatchNorm statistics.
    n_bad_val = int((~torch.isfinite(vaX)).sum().item())
    if n_bad_val:
        print(f"[week {week_eval:02d}] WARNING: zeroed {n_bad_val} non-finite validation feature values")
        vaX = torch.nan_to_num(vaX, nan=0.0, posinf=0.0, neginf=0.0)

    # decide a single feature_len (BN & first linear need a fixed size).
    # pick the max feature width across val + all shards
    F_candidates = [vaX.shape[-1]]
    for fp in feat_paths:
        F_candidates.append(torch.load(fp, map_location="cpu").shape[-1])
    F_model = max(F_candidates)
    print(f"[week {week_eval:02d}] inferred feature_len (F_model) = {F_model}")

    # class imbalance across all shards
    pos = neg = 0
    for tp in tgt_paths:
        y = torch.load(tp, map_location="cpu").to(torch.float32)
        pos += int((y == 1).sum().item())
        neg += int((y == 0).sum().item())
    # Up-weight positives by the neg/pos ratio (never below 1).
    pos_weight = torch.tensor([max(1.0, neg / max(1, pos))], device=device)
    print(f"[week {week_eval:02d}] pos={pos} neg={neg} pos_weight={pos_weight.item():.2f}")

    # model / opt / loss
    model = BlitzTransformer(
        feature_len=F_model,   # <-- key: BN & first Linear expect this width
        model_dim=128, num_heads=4, num_layers=2,
        dim_feedforward=512, dropout=0.1
    ).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-2)

    # loaders
    train_stream = ShardStream(feat_paths, tgt_paths, device=device)
    train_loader = DataLoader(train_stream, batch_size=batch_size, shuffle=False, num_workers=0)
    val_loader   = DataLoader(TensorDataset(vaX.to(device), vaY.to(device)),
                              batch_size=batch_size, shuffle=False, num_workers=0)

    # Per-week bookkeeping. This must sit OUTSIDE the epoch loop: resetting it
    # every epoch made every epoch look like an improvement and disabled
    # early stopping.
    ckpt_best = os.path.join(ART, f"best_model_week{week_eval:02d}.pth")
    ckpt_last = os.path.join(ART, f"last_model_week{week_eval:02d}.pth")
    best_val, no_improve = float("inf"), 0
    saved_any = False

    # train/eval
    for epoch in range(num_epochs):
        # train
        model.train()
        run = 0.0
        n = 0
        bad_inputs = 0      # non-finite feature values zeroed this epoch
        skipped = 0         # batches skipped because the loss was non-finite
        first_batch_logged = False
        for xb, yb in train_loader:
            if not first_batch_logged:
                print(f"[train] xb shape {tuple(xb.shape)}  (B,K,F*), model.feature_len={model.feature_len}")
                first_batch_logged = True
            finite = torch.isfinite(xb)
            if not bool(finite.all()):
                bad_inputs += int((~finite).sum().item())
                xb = torch.nan_to_num(xb, nan=0.0, posinf=0.0, neginf=0.0)
            optimizer.zero_grad()
            logits = model(xb)                  # model pads/truncates internally to feature_len
            loss = criterion(logits, yb)
            if not bool(torch.isfinite(loss)):
                # Never step on a NaN/inf loss: it would corrupt every weight.
                skipped += 1
                continue
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            run += loss.item() * xb.size(0)
            n += xb.size(0)
        train_loss = run / max(1, n)
        if bad_inputs or skipped:
            print(f"[train] WARNING: zeroed {bad_inputs} non-finite feature values; "
                  f"skipped {skipped} batches with non-finite loss")

        # val
        model.eval()
        vrun = 0.0
        vn = 0
        correct = 0
        first_val_logged = False
        with torch.no_grad():
            for xb, yb in val_loader:
                if not first_val_logged:
                    print(f"[val]   xb shape {tuple(xb.shape)}  (B,K,F*), model.feature_len={model.feature_len}")
                    first_val_logged = True
                logits = model(xb)
                vloss = criterion(logits, yb)
                vrun += vloss.item() * xb.size(0)
                vn += xb.size(0)
                preds = (torch.sigmoid(logits) >= 0.5).long()
                correct += (preds == yb.long()).sum().item()
        val_loss = vrun / max(1, vn)
        val_acc  = correct / max(1, vn)
        print(f"Epoch {epoch+1:02d}  train {train_loss:.4f}  val {val_loss:.4f}  acc {val_acc:.3f}")

        # A NaN val_loss compares False here, so it counts as "no improvement".
        if val_loss < best_val:
            best_val, no_improve = val_loss, 0
            torch.save(model.state_dict(), ckpt_best)
            saved_any = True
        else:
            no_improve += 1
            if no_improve >= early_stopping_patience:
                print("Early stopping.")
                break

    # after the epoch loop finishes (always save a fallback), including when
    # the loop was left early via `break`
    torch.save(model.state_dict(), ckpt_last)
    print(f"Saved fallback checkpoint: {ckpt_last}")
    if not saved_any:
        # also mirror as "best" to simplify downstream code
        torch.save(model.state_dict(), ckpt_best)
        print(f"No improvement checkpoint found; mirrored last -> {ckpt_best}")



######## WEEK 01 ########
[week 01] inferred feature_len (F_model) = 9
[week 01] pos=182 neg=59264 pos_weight=325.63
[train] xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
[val]   xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
Epoch 01  train nan  val nan  acc 0.996
Saved fallback checkpoint: /content/drive/MyDrive/bdb25-blitz/artifacts/last_model_week01.pth
No improvement checkpoint found; mirrored last -> /content/drive/MyDrive/bdb25-blitz/artifacts/best_model_week01.pth
[train] xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
[val]   xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
Epoch 02  train nan  val nan  acc 0.996
Saved fallback checkpoint: /content/drive/MyDrive/bdb25-blitz/artifacts/last_model_week01.pth
No improvement checkpoint found; mirrored last -> /content/drive/MyDrive/bdb25-blitz/artifacts/best_model_week01.pth
[train] xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
[val]   xb shape (64, 8, 9)  (B,K,F*), model.feature_len=9
Epoch 03  train nan  val n