# Football Diffusion: Colab Runner

This notebook assumes you have uploaded the entire repo to Colab or Google Drive (e.g., as a zip) and that the raw NFL Big Data Bowl 2023 CSVs are accessible, either in Google Drive or alongside the repo. It optionally mounts Drive, installs dependencies, and sets `PYTHONPATH`. It then lets you preprocess the data, train, evaluate, and visualize generated plays.

In [None]:
#@title Mount Google Drive (optional)
USE_DRIVE = True  #@param {type:"boolean"}
if USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_DIR = '/content/drive/MyDrive/dl_project'  # change if you put the repo elsewhere in Drive
else:
    BASE_DIR = '/content/dl_project'  # update if you unzip elsewhere

%cd $BASE_DIR

In [None]:
#@title Install dependencies
!pip install --quiet torch pytorch-lightning pandas numpy pyarrow tqdm scikit-learn tabulate pyyaml matplotlib

In [None]:
#@title Set paths and PYTHONPATH
# Derives every project path from BASE_DIR and makes the package under diffusion/src
# importable both from this kernel and from `!` subprocesses.
import os, sys, pathlib

repo_root = pathlib.Path(BASE_DIR)
diffusion_dir = repo_root / 'diffusion'
data_raw = repo_root / 'data' / 'nfl-big-data-bowl-2023'  # adjust if stored elsewhere
data_cache = repo_root / 'data' / 'cache'
artifacts_dir = repo_root / 'artifacts' / 'diffusion'
config_train = diffusion_dir / 'src' / 'football_diffusion' / 'config' / 'train.yaml'
config_eval = diffusion_dir / 'src' / 'football_diffusion' / 'config' / 'eval.yaml'

# PYTHONPATH is inherited by `!python` / `!bash` subprocesses; sys.path governs imports in this
# kernel. Put src first in both, but only once, so re-running the cell is idempotent and no
# dangling empty entry is appended when PYTHONPATH starts out unset.
src_path = str(diffusion_dir / 'src')
pythonpath_entries = [
    p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p and p != src_path
]
os.environ['PYTHONPATH'] = os.pathsep.join([src_path, *pythonpath_entries])
if src_path not in sys.path:
    sys.path.insert(0, src_path)
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # harmless on Colab

print('Repo root:', repo_root)
print('Raw data:', data_raw)
print('Cache dir:', data_cache)
print('Artifacts:', artifacts_dir)
print('PYTHONPATH:', os.environ['PYTHONPATH'])

## 1) Preprocess data
- Requires CSVs: `plays.csv`, `games.csv`, `players.csv`, `week*.csv` in `data/nfl-big-data-bowl-2023/` (or adjust `data_raw`).
- Generates `processed_plays.pkl` and `metadata.json` in `data/cache/`.

In [None]:
#@title Run preprocessing (skip if cache already built)
from pathlib import Path

data_cache.mkdir(parents=True, exist_ok=True)

!bash $diffusion_dir/scripts/preprocess.sh \
  $data_raw \
  $data_cache \
  $diffusion_dir/src/football_diffusion/config/default.yaml

## 2) Train diffusion model (optional)
- Uses cached data; writes checkpoints to `artifacts/diffusion/`.
- Set `MAX_EPOCHS` lower if you just want a smoke test.

In [None]:
#@title Train (set MAX_EPOCHS as needed)
MAX_EPOCHS = 5  #@param {type:"integer"}

!python $diffusion_dir/train_main.py \
  --config $config_train \
  --cache_dir $data_cache \
  --output_dir $artifacts_dir \
  --devices 1 \
  --max_epochs $MAX_EPOCHS

## 3) Evaluate
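- To evaluate a specific `.ckpt` (e.g., `last.ckpt`), enter its path in the checkpoint field of the cell below. Otherwise a checkpoint from `artifacts/diffusion/` is picked automatically.
- Samples with DDIM at 50 and 100 steps (`sample_steps=[50, 100]`) on the test split.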

In [None]:
#@title Evaluate checkpoint
# Runs run_evaluation on the test split for the selected checkpoint and prints the metrics.
import glob, os, json, sys, torch
from pathlib import Path

# Resolve diffusion_dir robustly so this cell also works when the path-setup cell was skipped.
if "diffusion_dir" not in globals():
    candidates = [Path.cwd(), Path.cwd() / "diffusion", Path("/content/dl_project/diffusion"), Path("/content/drive/MyDrive/dl_project/diffusion")]
    diffusion_dir = None
    for c in candidates:
        if (c / "src" / "football_diffusion").exists():
            diffusion_dir = c.resolve()
            break
    if diffusion_dir is None:
        raise RuntimeError("Could not locate diffusion directory; set diffusion_dir manually.")

# PYTHONPATH serves `!` subprocesses, sys.path serves imports in this kernel.
# Only prepend when missing so re-running the cell does not keep growing either path.
src_path = str(diffusion_dir / 'src')
pythonpath_entries = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
if src_path not in pythonpath_entries:
    os.environ['PYTHONPATH'] = os.pathsep.join([src_path, *pythonpath_entries])
if src_path not in sys.path:
    sys.path.insert(0, src_path)

ckpt_dir = Path(artifacts_dir) if 'artifacts_dir' in globals() else diffusion_dir.parent / 'artifacts' / 'diffusion'

# Leave empty to auto-select. The form field must hold a literal value, so the computed default
# lives in a separate variable below.
CHECKPOINT_OVERRIDE = ''  #@param {type:"string"}

# Ordered so `last.ckpt` comes last, and everything else by modification time. Plain lexicographic
# order would rank e.g. "epoch=9" above "epoch=10".
DEFAULT_CKPT = [
    str(ck)
    for ck in sorted(ckpt_dir.glob('*.ckpt'), key=lambda ck: (ck.name == 'last.ckpt', ck.stat().st_mtime))
]
CHECKPOINT = CHECKPOINT_OVERRIDE or (DEFAULT_CKPT[-1] if DEFAULT_CKPT else '')

if not CHECKPOINT:
    raise ValueError("No checkpoint found; train first or set CHECKPOINT_OVERRIDE above")
if not Path(CHECKPOINT).is_file():
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT}")
print('Using checkpoint:', CHECKPOINT)

from football_diffusion.eval.eval_diffusion import run_evaluation

device_pref = 'cuda' if torch.cuda.is_available() else ('mps' if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available() else 'cpu')
# Print before the (potentially long) evaluation so the device is visible while it runs.
print('Using device:', device_pref)

cache_dir = diffusion_dir.parent / 'data' / 'cache' if 'data_cache' not in globals() else data_cache
config_eval_path = diffusion_dir / 'src' / 'football_diffusion' / 'config' / 'eval.yaml' if 'config_eval' not in globals() else config_eval

results = run_evaluation(
    checkpoint_path=str(CHECKPOINT),
    cache_dir=str(cache_dir),
    config_path=str(config_eval_path),
    split='test',
    batch_size=8,
    num_samples=8,
    sample_steps=[50, 100],
    ddim=True,
    device=device_pref
)
print(json.dumps(results, indent=2))


## 4) Visualize trajectories
- Animate a ground-truth play and a generated sample.
- Uses the validation split of the cached data and the checkpoint chosen in the evaluate cell (run that cell first). Denormalizes the trajectories before plotting.


In [None]:
#@title Animate a sample play (ground truth vs generated)
# Samples one play from the diffusion model, conditioned on a validation play's context, and
# animates it next to that play's ground-truth trajectories.
import os, sys, yaml, torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import HTML

SAMPLE_IDX = 0  #@param {type:"integer"}
NUM_STEPS = 100  #@param {type:"integer"}
SEED = 0  #@param {type:"integer"}

# Resolve diffusion_dir robustly so this cell also works when the path-setup cell was skipped.
if "diffusion_dir" not in globals():
    candidates = [Path.cwd(), Path.cwd() / "diffusion", Path("/content/dl_project/diffusion"), Path("/content/drive/MyDrive/dl_project/diffusion")]
    diffusion_dir = None
    for c in candidates:
        if (c / "src" / "football_diffusion").exists():
            diffusion_dir = c.resolve()
            break
    if diffusion_dir is None:
        raise RuntimeError("Could not locate diffusion directory; set diffusion_dir manually.")

cache_dir = diffusion_dir.parent / "data" / "cache" if "data_cache" not in globals() else data_cache
config_train_path = diffusion_dir / "src" / "football_diffusion" / "config" / "train.yaml" if "config_train" not in globals() else config_train

# Only prepend when missing so re-running the cell does not keep growing either path.
src_path = str(diffusion_dir / "src")
pythonpath_entries = [p for p in os.environ.get("PYTHONPATH", "").split(os.pathsep) if p]
if src_path not in pythonpath_entries:
    os.environ["PYTHONPATH"] = os.pathsep.join([src_path, *pythonpath_entries])
if src_path not in sys.path:
    sys.path.insert(0, src_path)
from football_diffusion.data.dataset import FootballPlayDataset
from football_diffusion.training.train_diffusion import DiffusionLightningModule
from football_diffusion.viz.animate import animate_comparison

if "CHECKPOINT" not in globals() or not CHECKPOINT:
    raise ValueError("Set CHECKPOINT above (run the evaluate cell)")

with open(config_train_path) as f:
    cfg = yaml.safe_load(f)
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
)
print("Using device:", device)

ds = FootballPlayDataset(cache_dir / "processed_plays.pkl", cache_dir / "metadata.json", split="val")
if len(ds) == 0:
    raise ValueError("Dataset is empty; ensure preprocessing ran correctly.")
if not 0 <= SAMPLE_IDX < len(ds):
    raise IndexError(f"SAMPLE_IDX={SAMPLE_IDX} is out of range for the val split (size {len(ds)}).")
sample = ds[SAMPLE_IDX]
# Wrap the single play's context as a batch of one.
context_cat = [sample["context_categorical"]]
context_cont = sample["context_continuous"].unsqueeze(0).to(device)

module = DiffusionLightningModule.load_from_checkpoint(str(CHECKPOINT), config=cfg)
model = module.model.to(device)
model.eval()

# Seed torch's RNGs so re-running the cell redraws the same generated play
# (assumes the sampler draws its noise from torch's RNG).
torch.manual_seed(SEED)
with torch.no_grad():
    # NOTE(review): for a [T, P, F] tensor, `shape[0:3]` is the full shape, yet the result is
    # indexed with [0] below as if it had a batch dim. Confirm whether model.sample expects a
    # leading batch dim, i.e. (1, *sample["X"].shape).
    gen = model.sample(
        shape=sample["X"].shape[0:3],
        context_categorical=context_cat,
        context_continuous=context_cont,
        num_steps=NUM_STEPS,
        ddim=True,
        smooth=True,
    )

gt_np = ds.denormalize_tensor(sample["X"].numpy())  # [T, P, F]
gen_np = ds.denormalize_tensor(gen.cpu().numpy()[0])  # [T, P, F]

# The first two feature channels are plotted as the (x, y) positions.
gt_xy = gt_np[:, :, :2]
gen_xy = gen_np[:, :, :2]

anim, fig = animate_comparison([gt_xy, gen_xy], labels=["Ground Truth", "Generated"], interval=120)
anim_html = anim.to_jshtml()
# Close the figure so the inline backend does not also render a static frame under the animation
# (assumes animate_comparison creates it through pyplot).
plt.close(fig)
HTML(anim_html)


## 5) Load and inspect outputs
- `eval_results.json` is written next to the checkpoint.

In [None]:
#@title Show saved eval results (if present)
# Pretty-prints eval_results.json from the directory of the selected checkpoint, if it exists.
import json, pathlib

# Avoid a NameError when this cell runs before the evaluate cell has defined CHECKPOINT.
if not globals().get('CHECKPOINT'):
    print('No checkpoint selected; run the evaluate cell first.')
else:
    results_file = pathlib.Path(CHECKPOINT).parent / 'eval_results.json'
    if results_file.exists():
        # read_text opens and closes the file, unlike a bare open() passed to json.load.
        print(json.dumps(json.loads(results_file.read_text()), indent=2))
    else:
        print('No eval_results.json found; run evaluation first.')