Course project demonstrating three distinct AI learning paradigms on the CartPole-v1 environment.
A cart sits on a frictionless track. A pole is hinged to the top of the cart. On every timestep you push the cart left or right. The goal is to keep the pole upright for as long as possible. An episode ends when:
- The pole angle exceeds ±12° from vertical, or
- The cart moves more than 2.4 units from centre, or
- 500 timesteps are reached (perfect score = 500)
The 4-dimensional state vector fed to every model:
| Index | Feature | Range |
|---|---|---|
| 0 | Cart position | ±2.4 |
| 1 | Cart velocity | unbounded |
| 2 | Pole angle (rad) | ±0.21 |
| 3 | Pole angular velocity | unbounded |
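A quick way to inspect this state vector yourself, as a minimal sketch using the standard Gymnasium API:

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=42)
# obs is a 4-element float array:
# [cart position, cart velocity, pole angle (rad), pole angular velocity]
print(obs)

# one step with a random action: 0 = push left, 1 = push right
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```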
Pure reinforcement learning. No human input. The agent uses trial-and-error and a neural network Q-function to learn which action maximises future reward.
Key DQN concepts implemented:
- Experience replay buffer — stores past (s, a, r, s′, done) tuples and samples random mini-batches to break temporal correlations
- Target network — a frozen copy of the Q-network updated every N steps; stabilises training
- Epsilon-greedy exploration — starts nearly random (ε ≈ 1.0) and decays to ε = 0.05 so the agent exploits what it has learned
- Custom Q-network architecture — configurable MLP hidden layers via --net-arch
- EvalCallback — evaluates the deterministic policy on a separate env every 5 000 steps, completely isolated from exploration noise
- TensorBoard logging — logs loss, Q-values, episode stats to ./tb_logs/
- Global seeds — random, numpy, torch, and SB3 all seeded for reproducible runs
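As a consolidated view, here is a minimal sketch of how these pieces fit together in Stable-Baselines3. The real cartpole_dqn.py exposes them through its CLI instead, so treat the values and wiring below as illustrative rather than the script's exact internals:

```python
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.utils import set_random_seed

set_random_seed(42)  # seeds Python's random, NumPy and torch in one call

env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

model = DQN(
    "MlpPolicy",
    env,
    learning_rate=1e-3,
    buffer_size=50_000,          # experience replay buffer capacity
    batch_size=32,
    gamma=0.99,
    exploration_fraction=0.1,    # fraction of training spent decaying epsilon
    exploration_final_eps=0.05,  # epsilon floor after decay
    target_update_interval=500,  # refresh the frozen target network every N steps
    policy_kwargs=dict(net_arch=[256, 256]),
    tensorboard_log="./tb_logs/",
    seed=42,
)

# deterministic evaluation on a separate env every 5 000 steps,
# isolated from exploration noise
eval_cb = EvalCallback(eval_env, best_model_save_path="./best_model/",
                       eval_freq=5_000, deterministic=True)

model.learn(total_timesteps=100_000, callback=eval_cb)
model.save("dqn_cartpole")
```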
Result achieved: 500/500 reward (the maximum possible) after ~95 000 steps. Training time: ~2.5 minutes on CPU.
An interactive Pygame game where a human plays CartPole with keyboard controls. Every (state, action) pair from every step is recorded and written to human_data.pkl and human_data.csv.
What the game does:
- Opens two windows: a stats panel (600×500 Pygame) and the CartPole physics window (Gymnasium)
- Stats panel shows: high score, rolling average score, episode counter, total steps recorded
- Live pole danger bar — a colour gauge (green → red) showing the current pole angle as a % of the ±12° failure limit, so you can see how critical the situation is in real time (see the sketch after this list)
- 3-2-1 countdown before each episode (2.9 s total — fast enough to not be annoying)
- Episode result screen (AMAZING / EXCELLENT / GOOD / NOT BAD / FELL) — waits for SPACE before continuing
- Data is saved after every episode — quitting early with Q loses nothing
- After all episodes, offers to immediately train an imitation model on your data and watch it play
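The danger bar boils down to a simple ratio of the current angle to the failure limit. A hypothetical sketch (the game's actual drawing code will differ):

```python
import math

ANGLE_LIMIT_RAD = math.radians(12)  # the ±12° pole-angle failure limit

def danger_fraction(pole_angle_rad):
    """How close the pole is to falling, as a value in [0, 1]."""
    return min(abs(pole_angle_rad) / ANGLE_LIMIT_RAD, 1.0)

def danger_colour(frac):
    """Linear green-to-red RGB gradient for the Pygame gauge."""
    return (int(255 * frac), int(255 * (1 - frac)), 0)
```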
Controls:
| Key | Action |
|---|---|
| A / LEFT | Push cart left |
| D / RIGHT | Push cart right |
| SPACE | Start next episode |
| Q / ESC | Quit |
Behavioural cloning — pure supervised learning on the recorded human gameplay. No reward signal. No environment interaction during training.
How it works:
human gameplay → (state, action) pairs → train classifier → policy: state → action
At inference, the CartPole env gives 4 numbers → the classifier predicts left or right → that action is taken. The model copies your decision-making, so its quality scales directly with how much data you provide and how well you play.
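A minimal sketch of that inference loop, assuming the fitted classifier (or scaler + classifier pipeline) is available as `policy` with a standard sklearn `.predict`:

```python
import gymnasium as gym
import numpy as np

def run_episode(policy, render=True):
    """Roll out one CartPole episode under a fitted sklearn classifier.

    policy.predict is assumed to map a 4-feature state to action
    0 (push left) or 1 (push right).
    """
    env = gym.make("CartPole-v1", render_mode="human" if render else None)
    obs, _ = env.reset()
    total_reward, done = 0.0, False
    while not done:
        action = int(policy.predict(np.asarray(obs).reshape(1, -1))[0])
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        done = terminated or truncated
    env.close()
    return total_reward
```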
Three classifiers:
| Model | Best for |
|---|---|
| neural_network | ≥300 samples; architecture auto-scales to data size |
| logistic_regression | Small datasets (<200 samples); more stable generalisation |
| random_forest | Good all-rounder; built-in feature importance |
Training pipeline:
- Loads human_data.pkl, extracts states + actions
- Data sufficiency check — warns if <200 samples, refuses training if <50; prints the majority-class baseline (always predict most common action) so you know the trivial benchmark
- StandardScaler — normalises all 4 features to zero mean/unit variance; fitted only on the training split to prevent data leakage into the test set
- Auto-scaled MLP architecture — prevents overfitting: (32,16) for <300 samples, (64,32) for <1 000, (128,64) for larger sets
- early_stopping=True — MLP training stops when validation loss plateaus rather than hammering a fixed iteration limit
- 5-fold cross-validation inside a sklearn Pipeline — each fold scales independently; gives reliable std accuracy estimates
- Prints baseline vs model accuracy with a pass/fail indicator
- Analysis plot: feature importance (or weights), human action distribution, MLP training loss curve
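A condensed sketch of the pipeline above, assuming the recorded data is already loaded as a `states` array (N×4 floats) and an integer `actions` array (N,); names are illustrative, not the script's exact internals:

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def build_mlp(n_samples):
    # auto-scaled architecture: smaller nets for smaller datasets
    if n_samples < 300:
        hidden = (32, 16)
    elif n_samples < 1_000:
        hidden = (64, 32)
    else:
        hidden = (128, 64)
    return MLPClassifier(hidden_layer_sizes=hidden, early_stopping=True,
                         random_state=42)

def train_and_report(states, actions):
    # majority-class baseline: accuracy of always predicting the common action
    baseline = np.bincount(actions).max() / len(actions)
    # the scaler lives inside the pipeline, so each CV fold fits its own
    # scaler on its own training split (no leakage into the held-out fold)
    pipe = make_pipeline(StandardScaler(), build_mlp(len(states)))
    scores = cross_val_score(pipe, states, actions, cv=5)
    print(f"baseline {baseline:.2%} | model {scores.mean():.2%} ± {scores.std():.2%}")
    return pipe.fit(states, actions)
```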
```bash
pip install -r requirements.txt
```
Step 1: Train the DQN agent (~2.5 min)
```bash
python cartpole_dqn.py --train --timesteps 100000
```
Step 2: Play the game and collect data (~5–10 min)
```bash
python cartpole_human.py --episodes 20
```
Play at least 20 episodes for the imitation model to be meaningful.
Step 3: Watch the trained DQN play
```bash
python cartpole_dqn.py --test
```
Step 4: Train imitation models and compare all three
```bash
python cartpole_imitation.py --model compare --evaluate
```
```bash
python cartpole_dqn.py --train \
    --timesteps 100000 \
    --lr 1e-3 \
    --buffer-size 50000 \
    --batch-size 32 \
    --gamma 0.99 \
    --exploration-fraction 0.1 \
    --final-eps 0.05 \
    --target-update 500 \
    --net-arch 256 256 \
    --seed 42

python cartpole_dqn.py --test --no-video

tensorboard --logdir tb_logs
```
```bash
python cartpole_human.py --episodes 20         # play 20 episodes
python cartpole_human.py --auto-demo           # skip prompt, auto-train AI
python cartpole_human.py --no-demo             # skip AI demo entirely
python cartpole_human.py --output my_data.pkl  # custom save path
```
```bash
python cartpole_imitation.py --model neural_network --evaluate
python cartpole_imitation.py --model logistic_regression --evaluate
python cartpole_imitation.py --model compare   # train all 3, pick best
python cartpole_imitation.py --model neural_network --save my_model.pkl
python cartpole_imitation.py --load my_model.pkl --evaluate --no-render
```
The plot shows:
- Faded line — raw per-episode reward (noisy due to epsilon exploration)
- Solid line — 50-episode rolling mean
- Shaded band — ±1 standard deviation
- Lower panel — epsilon decay from ~1.0 → 0.05
Deterministic policy (EvalCallback) reached 500/500 by ~95 000 steps.
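For reference, a curve like this can be produced from logged per-episode rewards roughly as follows; this is an assumed reconstruction, not the script's actual plotting code:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_rewards(rewards, window=50):
    rewards = np.asarray(rewards, dtype=float)
    # 50-episode rolling mean and matching per-window standard deviation
    rolling = np.convolve(rewards, np.ones(window) / window, mode="valid")
    x = np.arange(window - 1, len(rewards))
    std = np.array([rewards[i - window + 1:i + 1].std() for i in x])

    plt.plot(rewards, alpha=0.3, label="per-episode reward")  # faded raw line
    plt.plot(x, rolling, label=f"{window}-episode rolling mean")
    plt.fill_between(x, rolling - std, rolling + std, alpha=0.2, label="±1 std")
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.legend()
    plt.savefig("reward_curve.png")
```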
| File / Folder | Created by | Description |
|---|---|---|
| dqn_cartpole.zip | cartpole_dqn.py --train | Trained DQN model weights |
| best_model/ | cartpole_dqn.py --train | Best checkpoint (EvalCallback) |
| reward_curve.png | cartpole_dqn.py --train | Reward + epsilon decay plot |
| tb_logs/ | cartpole_dqn.py --train | TensorBoard event files |
| eval_logs/ | cartpole_dqn.py --train | Per-checkpoint evaluation results |
| human_data.pkl | cartpole_human.py | Recorded gameplay (binary) |
| human_data.csv | cartpole_human.py | Recorded gameplay (readable CSV) |
| imitation_analysis_*.png | cartpole_imitation.py | Feature importance + loss curve |
| videos/ | cartpole_dqn.py --test | MP4 recordings of test episodes |
```
NNRL Project/
    cartpole_dqn.py         # DQN agent: train, test, full hyperparameter CLI
    cartpole_human.py       # Interactive game: human play + data collection
    cartpole_imitation.py   # Behavioural cloning from recorded human data
    convert_to_gif.py       # Convert videos/ MP4s to GIF for sharing
    test_human_control.py   # Standalone test for Pygame/gym setup
    requirements.txt        # All Python dependencies with minimum versions
    README.md               # This file
    dqn_cartpole.zip        # Pre-trained DQN model (500/500 reward)
    reward_curve.png        # Training reward + epsilon decay plot
    imitation_analysis_neural_network.png
    human_data.pkl          # Example recorded gameplay (binary)
    human_data.csv          # Same data as readable CSV
    videos/                 # Recorded test episode MP4s
    best_model/             # Best DQN checkpoint (auto-saved during training)
    tb_logs/                # TensorBoard training logs
    .venv/                  # Python virtual environment
```
| | Human Play | Imitation Learning | DQN |
|---|---|---|---|
| Data source | Your keyboard | Your recordings | Environment reward signal |
| Training time | Instant (you play) | Seconds–minutes | ~2.5 min on CPU |
| Reward achieved | ~10–50 steps (beginner) | Mirrors your skill | 500/500 (optimal) |
| Learns from mistakes | You do, manually | No | Yes |
| Needs human demos | Yes (you are the demo) | Yes | No |
| Key ML concept | Manual policy | Supervised learning / classification | Off-policy RL + neural Q-function |
Python 3.8+ — install everything with:
```bash
pip install -r requirements.txt
```
Core packages: `gymnasium[classic-control]`, `stable-baselines3`, `torch`, `scikit-learn`, `pygame`, `matplotlib`, `tensorboard`, `rich`
