A complete pipeline for training consumption prediction models using rational addiction theory. The system combines neural factorization machines with reinforcement learning to capture user behavior patterns.
Here's the typical workflow from raw data to trained model:
cd wide_deep/
python filt_data.py --input raw_data.csv --output_dir data/
python kfold_predict.py --input data/train.csv --output data/train_with_scores.csv
python personalize.py --input data/train_with_scores.csv --output user_params.csv

cd ../verl/addiction_theory_training/sft/
python prepare_sft_data.py --input ../../wide_deep/data/train_with_scores.csv --output sft_data/
bash run_kuairec_sft.sh

cd ../GRPO_personalize/
python prepare_grpo_data.py --input_dir ../../wide_deep/data/ --output_dir GRPO_data/
bash train_grpo.sh

cd ../eval/
python eval_by_session.py --model_path ../outputs/checkpoint/ --data ../data/test.parquet

Repository structure:

AddictSim/
├── wide_deep/ # Part 1: NFM & Personalization
│ ├── model.py # Neural Factorization Machine
│ ├── train.py # Training script
│ ├── kfold_predict.py # K-fold cross-validation
│ ├── personalize.py # Fit user parameters
│ ├── data_preprocessor.py # Feature preprocessing
│ ├── convert_kuairec.py # Dataset converter
│ ├── filt_data.py # Data filtering
│ ├── user_sampler.py # User sampling utility
│ └── val.py # Validation utilities
│
└── verl/addiction_theory_training/ # Part 2: RL Training
├── data_process/ # Data preprocessing
├── sft/ # Supervised fine-tuning
├── GRPO/ # Group Relative Policy Optimization
├── GRPO_personalize/ # Personalized GRPO
├── PPO/ # Proximal Policy Optimization
└── eval/ # Evaluation scripts
Part 1 (wide_deep/) trains a Neural Factorization Machine on user-item interactions and fits economic parameters for each user.
model.py - The Neural Factorization Machine. It uses bi-interaction pooling to capture pairwise feature interactions efficiently: in effect, a neural network that learns how features combine with each other.
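For intuition, here is a minimal sketch of bi-interaction pooling (illustrative only, not the actual code in model.py):

```python
import torch

def bi_interaction(emb: torch.Tensor) -> torch.Tensor:
    """Bi-interaction pooling over field embeddings.

    emb: (batch, num_fields, dim) embeddings of each feature field.
    Returns (batch, dim): 0.5 * ((sum_i e_i)^2 - sum_i e_i^2), i.e. the sum
    of all pairwise element-wise products e_i * e_j, computed in
    O(num_fields * dim) instead of O(num_fields^2 * dim).
    """
    square_of_sum = emb.sum(dim=1).pow(2)  # square of the summed embeddings
    sum_of_square = emb.pow(2).sum(dim=1)  # sum of the squared embeddings
    return 0.5 * (square_of_sum - sum_of_square)
```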
train.py - Main training script. It handles:
- Negative sampling (generates non-interaction examples; sketched below)
- Feature encoding (converts categorical features to numbers)
- Model training with dropout and regularization
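A minimal sketch of the negative-sampling step (a hypothetical helper; the real logic lives in train.py):

```python
import random

def sample_negatives(interactions, all_items, n_neg=4, seed=0):
    """For each observed (user, item) pair, draw n_neg items the user has
    never interacted with as negative examples. Assumes every user still
    has unseen items; n_neg=4 is an illustrative default."""
    rng = random.Random(seed)
    seen = {}
    for user, item in interactions:
        seen.setdefault(user, set()).add(item)
    negatives = []
    for user, _ in interactions:
        for _ in range(n_neg):
            candidate = rng.choice(all_items)
            while candidate in seen[user]:  # resample until unseen
                candidate = rng.choice(all_items)
            negatives.append((user, candidate))
    return negatives
```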
Usage:
python train.py

kfold_predict.py - Trains multiple models with k-fold cross-validation and generates out-of-fold NFM scores for each interaction. This avoids overfitting and yields more reliable scores.
Usage:
python kfold_predict.py --input data/input.csv --output data/output_with_scores.csv --k 5

personalize.py - The key script: fits individual utility-function parameters (a, b, c) for each user from their interaction history. These parameters capture how sensitive each user is to:
- Base consumption (a)
- Fatigue from overconsumption (b)
- Addiction/habit effects (c)
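A sketch of one way such per-user coefficients could be fitted; the actual procedure is in personalize.py, and both the least-squares setup and the use of NFM scores as a utility proxy are assumptions here:

```python
import numpy as np

def fit_user_params(C, S, utility_proxy):
    """Least-squares fit of U(C, S) = a*C + c*S*C - 0.5*b*C^2 for one user.

    C, S: per-interaction consumption and addiction stock (1-D arrays).
    utility_proxy: per-interaction utility signal (e.g. the NFM score;
    an assumption, not necessarily what personalize.py uses).
    Returns the fitted (a, b, c).
    """
    # Design-matrix columns correspond to a, b, c respectively.
    X = np.column_stack([C, -0.5 * np.square(C), S * C])
    (a, b, c), *_ = np.linalg.lstsq(X, utility_proxy, rcond=None)
    return a, b, c
```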
Usage:
python personalize.py --input train.csv --output user_params.csv --delta 0.9

data_preprocessor.py - Unified preprocessing for all features. Handles both categorical encoding and numerical binning.
convert_kuairec.py - Converts the KuaiRec dataset format to ours, mapping fields such as play_duration → watch_time.
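A sketch of the conversion; the field map below shows only the one mapping documented above, while convert_kuairec.py defines the full mapping:

```python
import pandas as pd

# Only play_duration -> watch_time is documented; the remaining fields
# are handled by convert_kuairec.py itself.
FIELD_MAP = {"play_duration": "watch_time"}

df = pd.read_csv("grpo_kuairec.csv")
df.rename(columns=FIELD_MAP).to_csv("converted.csv", index=False)
```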
Usage:
python convert_kuairec.py --input grpo_kuairec.csv --output converted.csv

filt_data.py - Filters users by interaction count and splits the data into train/test sets. Removes users with too few interactions (parameters can't be fitted) or too many (likely bots).
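A minimal sketch of the filtering idea (the column name and thresholds are assumptions; filt_data.py sets the real ones):

```python
import pandas as pd

def filter_users(df, min_n=20, max_n=5000):
    """Keep users whose interaction count lies within [min_n, max_n]:
    below it, parameters can't be fitted; above it, likely bots."""
    counts = df.groupby("user_id")["user_id"].transform("size")
    return df[counts.between(min_n, max_n)]
```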
Usage:
python filt_data.py --input raw_data.csv --output_dir filtered_data/

user_sampler.py - Randomly samples N users for testing or for creating smaller datasets.
Usage:
python user_sampler.py --input data.csv --output sampled.csv --n_users 100

val.py - Base validator for rational addiction theory. Handles session partitioning and addiction-stock calculations.
Part 2 (verl/addiction_theory_training/) uses the user parameters from Part 1 to train language models with reinforcement learning. The models learn to predict consumption values from session context.
The training has 5 stages:
- Data Processing - Structure the data for training
- Supervised Fine-tuning (SFT) - Train with labeled data
- GRPO Training - Optimize with group-based RL
- Personalized GRPO - Use user-specific parameters
- Evaluation - Measure performance
data_processor.py - Base processor for interaction data
- Partitions sessions (sketched below)
- Extracts features
- Computes addiction stock dynamics
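A sketch of the session-partitioning step from the list above; the inactivity-gap rule and the 30-minute threshold are assumptions, and data_processor.py defines the actual rule:

```python
def partition_sessions(timestamps, gap_seconds=1800):
    """Split a user's sorted interaction timestamps into sessions whenever
    the idle time between consecutive interactions exceeds gap_seconds."""
    if not timestamps:
        return []
    sessions, current = [], [timestamps[0]]
    for prev, ts in zip(timestamps, timestamps[1:]):
        if ts - prev > gap_seconds:  # gap too long: close the session
            sessions.append(current)
            current = []
        current.append(ts)
    sessions.append(current)
    return sessions
```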
kuairec_processor.py - Specialized for KuaiRec dataset
Usage:
python data_processor.py --input raw_data.csv --output processed/

prepare_sft_data.py - Creates the instruction-tuning format
- Builds prompt-response pairs
- Output format:
{"consumption_reward": 1.1139}\n<END>
run_kuairec_sft.sh - Launch SFT training with LoRA
Usage:
python prepare_sft_data.py --input ../data/train.csv --output sft_data/
bash run_kuairec_sft.sh

Group Relative Policy Optimization (GRPO) - generates multiple candidates per prompt and learns from their relative preferences.
Key files:
- reward_function.py - Computes intertemporal utility rewards
- session_batch_sampler.py - Batches data while preserving session order
- prepare_grpo_data.py - Converts SFT data to GRPO format
- train_grpo.sh - Main training script
Usage:
python prepare_grpo_data.py --input_dir ../data/ --output_dir GRPO_data/
bash train_grpo.sh

Important parameters:
- DISCOUNT_FACTOR (ρ) - How much to discount future rewards (default: 0.9)
- ROLLOUT_N - Number of candidates per prompt (default: 4)
- DECAY_RATE (δ) - How fast habits decay (default: 0.7)
- REINFORCEMENT_COEFF (α) - How strongly consumption builds habits (default: 1.0)
Personalized GRPO - Same as GRPO, but uses the individual user parameters fitted in Part 1.
Key differences:
- personalized_params_manager.py - Loads user-specific (a, b, c) values
- reward_function.py - Computes rewards using the personalized utilities
The model automatically loads parameters from data/delta0.9.csv (generated by Part 1).
Usage:
bash train_grpo.sh

Proximal Policy Optimization (PPO) - a simpler alternative to GRPO. It reuses most GRPO components but runs in single-candidate mode with GAE advantage estimation.
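For reference, a minimal sketch of GAE in its standard formulation (the gamma and lam values are illustrative, not the repo's defaults):

```python
def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    rewards, values: per-step rewards and value estimates (same length;
    the value after the final step is taken to be 0).
    """
    advantages, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        running = delta + gamma * lam * running  # discounted recursion
        advantages[t] = running
    return advantages
```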
Usage:
bash train_ppo.sh

eval.py - Standard evaluation
- Uses vLLM for fast inference
- Reports MAE, RMSE, correlation
eval_by_session.py - Session-level evaluation
- Aggregates predictions per session
- Better for understanding real-world performance
Usage:
python eval_by_session.py --model_path ../outputs/checkpoint/ --data ../data/test.parquet

The framework implements the Becker-Murphy rational addiction model. Here's the core idea:
Utility at time t:
U(C_t, S_t) = a·C_t + c·S_t·C_t - 0.5·b·C_t²

This says utility comes from consumption (a·C_t), boosted by habit effects (c·S_t·C_t), but with diminishing returns (-0.5·b·C_t²).
Habit dynamics:
S_{t+1} = (1 - δ)·S_t + α·C_t
Habits decay over time (δ) but are reinforced by consumption (α).
Total value of a session:
Total = Σ_t ρ^t · U(C_t, S_t)
Sum up utilities across time steps, with future rewards discounted by ρ.
Variables:
- C_t - Consumption at time t (we use log10(1 + watch_time))
- S_t - Addiction stock (habit level) at time t
- a, b, c - User-specific coefficients (fitted in Part 1)
- δ - Decay rate (how fast habits fade)
- α - Reinforcement coefficient (how much consumption builds habits)
- ρ - Discount factor (patience/farsightedness)
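Putting the three formulas together, a minimal sketch of a session's total value (the defaults match the documented δ, α, ρ; the initial stock S_0 = 0 is an assumption):

```python
def session_value(C, a, b, c, delta=0.7, alpha=1.0, rho=0.9):
    """Discounted total utility of one session under the model above.

    C: consumption values C_t, e.g. log10(1 + watch_time) per video.
    """
    S, total = 0.0, 0.0  # assumed initial addiction stock S_0 = 0
    for t, Ct in enumerate(C):
        U = a * Ct + c * S * Ct - 0.5 * b * Ct**2  # per-step utility
        total += rho**t * U                        # time discounting
        S = (1 - delta) * S + alpha * Ct           # habit stock update
    return total

# e.g. session_value([1.2, 0.8, 1.5], a=1.0, b=0.5, c=0.3)
```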
Personalization - Each user gets their own (a, b, c) parameters, fitted from their historical data. This captures individual differences in how people respond to content.
Session-Aware - The system respects natural viewing sessions: videos in the same session are never split across training batches, and rewards propagate backward through the session.
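A sketch of the idea behind session-aware batching (the session_id field and batch size are assumptions; session_batch_sampler.py implements the real sampler):

```python
from collections import defaultdict

def session_batches(records, sessions_per_batch=8):
    """Group time-ordered records by session and yield whole sessions per
    batch, so no session is ever split across batches."""
    sessions = defaultdict(list)
    for rec in records:  # insertion order preserves within-session order
        sessions[rec["session_id"]].append(rec)
    batch = []
    for session in sessions.values():
        batch.append(session)
        if len(batch) == sessions_per_batch:
            yield batch
            batch = []
    if batch:
        yield batch
```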
Efficient Training
- Uses LoRA for parameter-efficient fine-tuning
- Supports FSDP for distributed training
- Fast inference with vLLM
Flexible - Both GRPO (multi-candidate) and PPO (single-candidate) modes are supported. All parameters can be configured via shell scripts.
You can adjust behavior by editing the shell scripts or passing environment variables:
Reward function:
- base_weight - Base consumption weight (a)
- addiction_weight - Habit interaction weight (c)
- fatigue_weight - Diminishing-returns weight (b)
- discount_factor - Time discount (ρ)
- decay_rate - Habit decay (δ)
- reinforcement_coeff - Habit growth (α)
Training:
- LORA_RANK - LoRA rank (default: 16)
- BATCH_SIZE - Sessions per batch
- ROLLOUT_N - Candidates per prompt (4 for GRPO, 1 for PPO)
- SESSION_LEN - Max session length (default: 20)
# 1. Filter and prepare data
cd wide_deep/
python filt_data.py --input raw_data.csv --output_dir data/
# 2. Train NFM and get scores
python kfold_predict.py --input data/train.csv --output data/train_with_scores.csv
# 3. Fit user parameters
python personalize.py --input data/train_with_scores.csv --output ../verl/addiction_theory_training/data/delta0.9.csv --delta 0.9
# 4. Prepare and run SFT
cd ../verl/addiction_theory_training/sft/
python prepare_sft_data.py --input ../data/train.csv --output sft_data/
bash run_kuairec_sft.sh
# 5. Train with personalized GRPO
cd ../GRPO_personalize/
python prepare_grpo_data.py --input_dir ../data/ --output_dir GRPO_data/
bash train_grpo.sh
# 6. Evaluate
cd ../eval/
python eval_by_session.py --model_path ../outputs/checkpoint/ --data ../data/test.parquet