# Colab: PLE Data Prep + Training

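This notebook is intended to run in Google Colab (open it there directly or copy its cells into a new Colab notebook). When run locally instead, start Jupyter from the repository root so that `ple_experiment/` resolves. It will: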

1) Clone this repo and install dependencies
2) Prepare the Census-Income (KDD) dataset
3) Train the 2-level PLE model end-to-end

In [None]:
# Clone repo, install dependencies, and make src importable (Colab-friendly)
try:
    import google.colab  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

repo_url = 'https://github.com/allyoushawn/recsys_playground.git'
repo_dir = 'recsys_playground'

import os, sys
if IN_COLAB:
    if not os.path.exists(repo_dir):
        !git clone $repo_url
    %cd $repo_dir
    !pip -q install -r requirements.txt
    if os.path.exists('ple_experiment/requirements.txt'):
        !pip -q install -r ple_experiment/requirements.txt
    src_path = os.path.abspath('ple_experiment')
    if src_path not in sys.path:
        sys.path.insert(0, src_path)
else:
    repo_root = os.getcwd()
    src_path = os.path.join(repo_root, 'ple_experiment')
    if src_path not in sys.path:
        sys.path.insert(0, src_path)


## Prepare Census-Income (KDD)

In [None]:
# Generates X_*.npy, y_*.npy, and feature_meta.json under ./data/census_kdd
output_dir = './data/census_kdd'
!python ple_experiment/prepare_census_income.py \
        --output_dir $output_dir \
        --test_size 0.15 \
        --val_size 0.10 \
        --batch_size 4096 \
        --num_workers 2 \
        --onehot_min_freq 10 \
        --seed 42


## Train PLE (2 levels)

In [None]:
# Trains PLE and writes checkpoints + metrics to ./runs/ple_census
out_dir = './runs/ple_census'
!python ple_experiment/train_ple.py \
        --data_dir $output_dir \
        --out_dir $out_dir \
        --epochs 5 \
        --batch_size 4096 \
        --num_workers 2 \
        --lr 2e-3 \
        --weight_decay 1e-4 \
        --d_model 128 \
        --expert_hidden 256 \
        --num_levels 2 \
        --num_shared_experts 2 \
        --num_task_experts 2 \
        --dropout 0.1 \
        --w_income 1.0 \
        --w_never_married 1.0 \
        --use_pos_weight true \
        --grad_clip 1.0 \
        --mixed_precision true \
        --seed 42


In [None]:
# Inspect outputs: list the files written to `out_dir` (set by the training cell)
# and display the parsed test_report.json as the cell's rich output.
import json
import os

print('Artifacts under', out_dir)
print(sorted(os.listdir(out_dir)))
with open(os.path.join(out_dir, 'test_report.json'), encoding='utf-8') as f:
    rep = json.load(f)
rep
