# PaST - PPO (Colab/T4): CNN+DeepSets or CWE

> If you want to run the **CWE** backbone (`cwe_sparse`), you must clone a repo/branch that includes the CWE changes.

This notebook trains one PPO variant, then evaluates the resulting checkpoint with Greedy, SGBS, EAS, or SGBS+EAS.

**Typical flow**
1) Install deps (Colab)
2) Train: `python -m PaST.train_ppo ...`
3) Auto-pick latest checkpoint
4) Evaluate: `python -m PaST.run_eval_eas_* ...`

In [1]:
import sys
from pathlib import Path

ROOT = Path.cwd()
print("CWD:", ROOT)
print("Python:", sys.version)

# Make sure imports like `import PaST` work
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Basic sanity: torch + GPU
import torch

print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

CWD: /content
Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
torch: 2.9.0+cu126
cuda available: True
gpu: Tesla T4


In [2]:
# Colab installs (safe to re-run).
# If you're not on Colab and already have deps, you can skip this cell.
#
# IMPORTANT: The default upstream repo may not include the CWE variant yet.
# If you want to run CWE, point these to the branch where you added it.
REPO_URL = "https://github.com/Abdellahbado/PaST"  # TODO: change to your fork if needed
REPO_BRANCH = "main"  # TODO: change to your CWE branch if needed

!rm -rf PaST
!git clone -b "$REPO_BRANCH" "$REPO_URL"

!pip -q install -r PaST/requirements.txt
!pip -q install pandas matplotlib

Cloning into 'PaST'...
remote: Enumerating objects: 356, done.
remote: Counting objects: 100% (27/27), done.
remote: Compressing objects: 100% (22/22), done.
remote: Total 356 (delta 10), reused 16 (delta 5), pack-reused 329 (from 1)
Receiving objects: 100% (356/356), 125.83 MiB | 39.72 MiB/s, done.
Resolving deltas: 100% (178/178), done.


## Train (long-run-friendly hyperparams)

Defaults here are chosen to **keep learning for a long time**:
- Cosine LR decay with a non-trivial end LR (`lr_end_factor=0.1`)
- Cosine entropy decay (explore early, still some exploration late)
- Target KL to stop overly large updates from collapsing the policy
- Curriculum enabled (helps early stability): starts on *small* horizons then gradually introduces larger ones, while also annealing the epsilon-constraint slack range

Adjust `TOTAL_ENV_STEPS` depending on runtime budget.
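
As a rough illustration of the first bullet, the sketch below computes a cosine-decayed learning rate that ends at `lr_end_factor` times the base value. It assumes a standard half-cosine interpolation over training progress. The function name `cosine_lr`, its arguments, and the example base rate are hypothetical, and the schedule implemented in `PaST` may differ in detail.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, lr_end_factor: float = 0.1) -> float:
    """Cosine-decay base_lr toward base_lr * lr_end_factor (illustrative, not PaST's code)."""
    progress = min(max(step / total_steps, 0.0), 1.0)  # clamp progress to [0, 1]
    end_lr = base_lr * lr_end_factor
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 at the start, 0 at the end
    return end_lr + (base_lr - end_lr) * cosine


# Inspect the schedule at the start, midpoint, and end of a hypothetical run.
for step in (0, 500, 1000):
    print(step, cosine_lr(step, total_steps=1000, base_lr=3e-4))
```

With a non-zero end factor the learning rate never reaches zero, so the policy can keep improving late in a long run.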

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory in Drive
import os
drive_checkpoint_dir = '/content/drive/MyDrive/PaST_checkpoints'
os.makedirs(drive_checkpoint_dir, exist_ok=True)
print(f"Checkpoint directory: {drive_checkpoint_dir}")

In [None]:
import sys
import torch

# Pick ONE variant_id:
# - q_sequence_cwe_ctx13 (CWE + CTX13, T4 optimized)
#
VARIANT_ID = "q_sequence_cwe_ctx13"
SEED = 0
OUT_DIR = "/content/drive/MyDrive/PaST_checkpoints/runs_q_sequence"

# Main training knobs (T4-friendly Q-Learning)
NUM_ROUNDS = 50
EPISODES_PER_ROUND = 2048   # High throughput collection on T4 (16 batches of 128)
COLLECTION_BATCH_SIZE = 128 # Maximize GPU vectorization (Parallel envs)
TRAIN_BATCH_SIZE = 128      # Train batch size
EPOCHS_PER_ROUND = 4        # Supervised epochs per DAgger round

LR = 1e-4
WD = 1e-5

cmd = [
    sys.executable,
    "-m",
    "PaST.train_q_sequence",
    "--variant_id",
    VARIANT_ID,
    "--seed",
    str(SEED),
    "--device",
    "cuda" if torch.cuda.is_available() else "cpu",
    "--output_dir",
    OUT_DIR,
    # DAgger Loop
    "--num_rounds",
    str(NUM_ROUNDS),
    "--episodes_per_round",
    str(EPISODES_PER_ROUND),
    "--warmup_rounds",
    "5",
    
    # Batching & Parallelization
    "--collection_batch_size",
    str(COLLECTION_BATCH_SIZE),
    "--batch_size",
    str(TRAIN_BATCH_SIZE),
    "--num_dataloader_workers",
    "2",  # 2 workers pre-fetch data for the GPU trainer
    "--num_collection_workers",
    "1",  # 1 is optimal for GPUBatchSequenceEnv (Vectorized on GPU > Multiprocessing)
    
    # Optimization
    "--num_epochs_per_round",
    str(EPOCHS_PER_ROUND),
    "--learning_rate",
    str(LR),
    "--weight_decay",
    str(WD),
    
    # Policy Strategy
    "--completion_policy",
    "mix",
    "--completion_prob_start",
    "1.0",
]

print(" ".join(cmd))

/usr/bin/python3 -m PaST.train_q_sequence --variant_id q_sequence_cwe_ctx13 --seed 0 --device cuda --output_dir runs_colab --num_rounds 50 --episodes_per_round 2048 --warmup_rounds 5 --collection_batch_size 128 --batch_size 128 --num_dataloader_workers 2 --num_collection_workers 1 --num_epochs_per_round 4 --learning_rate 0.0001 --weight_decay 1e-05 --completion_policy mix --completion_prob_start 1.0
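
Joining arguments with a plain space works for the command above, but it would break if an argument contained a space, for example an output directory under a Drive folder whose name has spaces. A minimal sketch using the standard library's `shlex.join` (Python 3.8+) follows; the argument list and path are hypothetical.

```python
import shlex

# Hypothetical arguments; the path deliberately contains a space.
example_cmd = [
    "python", "-m", "PaST.train_q_sequence",
    "--output_dir", "/content/drive/MyDrive/PaST checkpoints/runs",
]

quoted = shlex.join(example_cmd)  # quotes only the arguments that need it
print(quoted)

# Round trip: splitting the quoted string recovers the original arguments.
assert shlex.split(quoted) == example_cmd
```

In this notebook, `shlex.join(cmd)` could be substituted wherever `" ".join(cmd)` is used to build a shell command.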


In [None]:
# Start training
!{' '.join(cmd)}

Using device: cuda
CPU setup: os.cpu_count=2 | torch_threads=2 | collection_workers=1 | dataloader_workers=2
Model parameters: 466,562

Q-Sequence Training: q_sequence_cwe_ctx13
Output: runs_colab/q_sequence_cwe_ctx13_s0_20260122_153027
Warmup rounds (SPT completion): 5
Completion policy: mix | mix_prob_start=1.0 | mix_prob_end=1.0

Round 1/50 [SPT]: Collecting... 100071 transitions in 307.3s
  Training (100000 in buffer)... loss=549.9055, mae=550.40, listwise=0.0000 in 99.7s
Round 2/50 [SPT]: Collecting... 100847 transitions in 321.5s
  Training (100000 in buffer)... loss=359.1475, mae=359.65, listwise=0.0000 in 100.8s
Round 3/50 [SPT]: Collecting... 100072 transitions in 324.6s
  Training (100000 in buffer)... loss=353.6341, mae=354.13, listwise=0.0000 in 100.9s
Round 4/50 [SPT]: Collecting... 100179 transitions in 326.5s
  Training (100000 in buffer)... loss=268.9048, mae=269.40, listwise=0.0000 in 101.1s
Round 5/50 [SPT]: Collecting... 102595 transitions in 325.5s
  Training (10000


## Find the latest checkpoint
This grabs the most recent `latest.pt` under the output directory.

In [None]:
import glob
from pathlib import Path

ckpts = glob.glob(f"{OUT_DIR}/**/checkpoints/latest.pt", recursive=True)
if not ckpts:
    raise FileNotFoundError(f"No latest.pt found under {OUT_DIR}/")

ckpts_sorted = sorted(ckpts, key=lambda p: Path(p).stat().st_mtime)
CKPT = ckpts_sorted[-1]
print("Using checkpoint:", CKPT)

FileNotFoundError: No latest.pt found under runs_colab/

## Evaluate (Greedy / SGBS / EAS / SGBS+EAS)
Choose the script based on the variant family:
- `run_eval_eas_family_q4_beststart` for `ppo_family_*_beststart_*`
- `run_eval_eas_duration_aware` for `ppo_duration_aware_*`
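
The mapping above can be sketched as a small lookup that fails loudly when a variant matches neither pattern, rather than silently falling back to one of the scripts. The helper `pick_eval_module` and the example variant name are hypothetical; the patterns and module names are taken from the list above.

```python
from fnmatch import fnmatchcase

# (variant pattern, evaluation module) pairs from the list above.
EVAL_MODULES = [
    ("ppo_family_*_beststart_*", "PaST.run_eval_eas_family_q4_beststart"),
    ("ppo_duration_aware_*", "PaST.run_eval_eas_duration_aware"),
]


def pick_eval_module(variant_id: str) -> str:
    """Return the module whose pattern matches variant_id, or raise ValueError."""
    for pattern, module in EVAL_MODULES:
        if fnmatchcase(variant_id, pattern):  # case-sensitive glob match
            return module
    raise ValueError(f"No evaluation script listed for variant {variant_id!r}")


# Hypothetical variant name, used only to exercise the lookup.
print(pick_eval_module("ppo_duration_aware_example"))
```

A variant such as `q_sequence_cwe_ctx13` matches neither pattern, so this helper would raise instead of picking a script for it.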

In [None]:
EVAL_SEED = 42
NUM_INSTANCES = 16
SCALE = "small"  # small|medium|large
EPS_STEPS = 5

METHOD = "sgbs_eas"  # eas|sgbs_eas
BETA = "4"
GAMMA = "4"
MAX_ITERS = 50
EAS_LR = 0.003
EAS_IL = 0.01
SAMPLES_PER_ITER = 32

if "duration_aware" in VARIANT_ID:
    eval_mod = "PaST.run_eval_eas_duration_aware"
else:
    eval_mod = "PaST.run_eval_eas_family_q4_beststart"

eval_cmd = [
    sys.executable,
    "-m",
    eval_mod,
    "--checkpoint",
    CKPT,
    "--variant_id",
    VARIANT_ID,
    "--eval_seed",
    str(EVAL_SEED),
    "--num_instances",
    str(NUM_INSTANCES),
    "--scale",
    SCALE,
    "--epsilon_steps",
    str(EPS_STEPS),
    "--method",
    METHOD,
    "--beta",
    BETA,
    "--gamma",
    GAMMA,
    "--max_iterations",
    str(MAX_ITERS),
    "--eas_lr",
    str(EAS_LR),
    "--eas_il_weight",
    str(EAS_IL),
    "--samples_per_iter",
    str(SAMPLES_PER_ITER),
]
print(" ".join(eval_cmd))

In [None]:
!{' '.join(eval_cmd)}