🐍 Snake-MLP

A 36K parameter model that plays Snake — trained on a BFS expert bot in under 5 minutes.

Python 3.10+ PyTorch License: MIT

Why This Works

Large language models are terrible at Snake. They reason in words, not reflexes. They take 600ms to decide to turn left. By the time they respond, the snake is already dead.

This model takes ~1ms per decision on CPU. It doesn't think — it pattern-matches, exactly like a human playing on instinct after enough practice.

The approach is called behavioral cloning: run a BFS expert that always knows the shortest path to food, record every (game state → action) pair it produces, train a classifier on those pairs. The result is a model that has distilled the expert's decision-making into 36,355 numbers.


Benchmark

Agent                 Avg Score   Max Score   Latency
Snake-MLP (ours)      ~26         53+         ~1ms
Random agent          ~1          4           <1ms
GPT-4o (text-only)    n/a         n/a         ~800ms (dead on arrival)

Tested on a 20×20 grid, 50 evaluation episodes.


AI (left) vs BFS Expert (right)

Architecture

The model is a 3-layer MLP. That's it.

Input (11 features)
    │
Linear(11 → 256) + ReLU + Dropout(0.15)
    │
Linear(256 → 128) + ReLU
    │
Linear(128 → 3)
    │
Output: [straight, turn right, turn left]

Total parameters: 36,355
Model size on disk: ~145KB
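The diagram maps one-to-one onto a PyTorch module. The class name below is illustrative (the repo may structure it differently), but the layer shapes reproduce the stated parameter count exactly:

```python
import torch
import torch.nn as nn

class SnakeMLP(nn.Module):
    """The 3-layer MLP from the diagram above (class name is illustrative)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(11, 256), nn.ReLU(), nn.Dropout(0.15),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 3),   # logits: [straight, turn right, turn left]
        )

    def forward(self, x):
        return self.net(x)

model = SnakeMLP()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 36355 = (11*256 + 256) + (256*128 + 128) + (128*3 + 3)
```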

The 11 Features

Rather than feeding raw pixels, the state is compressed into 11 binary values:

Feature                         Description
danger_straight                 Collision if we continue forward
danger_right                    Collision if we turn right
danger_left                     Collision if we turn left
dir_left / right / up / down    Current heading (one-hot)
food_left / right / up / down   Relative food direction

No convolutions. No attention. No tokenization. Just 11 bits and a few matrix multiplies.
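A sketch of how those 11 bits might be computed. The helper names, the `collides` callback, and the y-grows-downward coordinate convention are all assumptions; snakenv.py may differ:

```python
import numpy as np

def get_state(head, direction, food, collides):
    """Build the 11-feature state vector.

    head/food are (x, y) grid cells, direction is one of 'L', 'R', 'U', 'D',
    and collides(cell) -> bool reports a wall or body hit. Names and the
    coordinate convention (y increases downward) are illustrative.
    """
    x, y = head
    cell_l, cell_r = (x - 1, y), (x + 1, y)
    cell_u, cell_d = (x, y - 1), (x, y + 1)
    dir_l, dir_r = direction == 'L', direction == 'R'
    dir_u, dir_d = direction == 'U', direction == 'D'
    state = [
        # danger straight / right / left, relative to the current heading
        (dir_r and collides(cell_r)) or (dir_l and collides(cell_l))
            or (dir_u and collides(cell_u)) or (dir_d and collides(cell_d)),
        (dir_u and collides(cell_r)) or (dir_d and collides(cell_l))
            or (dir_l and collides(cell_u)) or (dir_r and collides(cell_d)),
        (dir_d and collides(cell_r)) or (dir_u and collides(cell_l))
            or (dir_r and collides(cell_u)) or (dir_l and collides(cell_d)),
        # current heading (one-hot)
        dir_l, dir_r, dir_u, dir_d,
        # relative food direction
        food[0] < x, food[0] > x, food[1] < y, food[1] > y,
    ]
    return np.array(state, dtype=np.float32)
```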

Pipeline

Same 3-step structure as the DOOM paper:

1. collect   →   BFS expert plays 1000 games, logs (state, action) pairs
2. train     →   MLP trains on those pairs with cross-entropy loss (~50 epochs)
3. play      →   Trained model runs in real-time in a Pygame window

Quick Start

# 1. Clone and install
git clone https://github.com/yourusername/snake-mlp
cd snake-mlp
pip install -r requirements.txt

# 2. Train (collects data + trains in one command)
python train.py --episodes 1000 --epochs 50

# 3. Watch it play
python play.py

# 4. Side-by-side: AI vs Expert
python play.py --compare --fps 20

# 5. Adjust speed
python play.py --fps 30

GPU Training

If you have a CUDA GPU, the training loop automatically uses it:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Tip: bump batch_size to 4096 or higher when using GPU — the model is tiny and you need large batches to saturate the GPU and actually see a speedup.
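One way to wire that up, sketched with stand-in tensors in place of the collected pairs (the real train.py may organize its data loading differently):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in tensors with the shapes of the collected (state, action) pairs
X = torch.rand(100_000, 11)
y = torch.randint(0, 3, (100_000,))

# Large batches on GPU so the tiny model actually saturates the device
batch = 4096 if device.type == "cuda" else 512
loader = DataLoader(TensorDataset(X, y), batch_size=batch, shuffle=True,
                    pin_memory=(device.type == "cuda"))
```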


Project Structure

snake-mlp/
├── snakenv.py        # Headless Snake game environment, returns 11-feature state
├── data.py           # BFS expert bot, the data generator
├── train.py          # Collect → train → evaluate pipeline
├── play.py           # Pygame visualizer (--compare mode included)
├── requirements.txt
└── assets/           # Put your demo GIF and screenshots here

Training Details

Data collection: The BFS expert runs 1000 episodes on a 20×20 grid. Each episode generates ~200–1000 (state, action) transitions, so a full 1000-episode run produces roughly 200K–1M labeled pairs.
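The core of the expert is plain breadth-first search from the head to the food. A minimal sketch of that idea, with illustrative names (the repo's data.py may structure it differently):

```python
from collections import deque

def bfs_next_move(head, food, blocked, width, height):
    """First step on a shortest head -> food path, or None if food is
    unreachable. blocked is the set of cells occupied by the snake body.
    """
    queue = deque([(head, None)])   # (cell, first step taken from head)
    seen = {head}
    while queue:
        (x, y), first = queue.popleft()
        if (x, y) == food:
            return first
        for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
            nx, ny = nxt
            if (0 <= nx < width and 0 <= ny < height
                    and nxt not in seen and nxt not in blocked):
                seen.add(nxt)
                queue.append((nxt, first or nxt))  # propagate the first step
    return None
```

The expert then converts the returned cell into one of the three relative actions (straight / right / left) given the snake's current heading.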

Class imbalance: The expert mostly goes straight. To prevent the model from learning "always go straight", the training loss is weighted inversely by class frequency:

counts  = np.bincount(y, minlength=3)
weights = 1.0 / np.maximum(counts, 1)   # guard against an empty class
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor(weights / weights.sum(), dtype=torch.float32))

Training time: ~2–5 minutes on CPU, under a minute on GPU.
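Put together, the training step is a standard supervised loop. This sketch uses random stand-in data and a plain (unweighted) loss; train.py would load the collected pairs and plug in the weighted criterion above:

```python
import torch
import torch.nn as nn

# Random stand-in data; the real pipeline loads the recorded (state, action) pairs
X = torch.rand(1024, 11)
y = torch.randint(0, 3, (1024,))

model = nn.Sequential(nn.Linear(11, 256), nn.ReLU(), nn.Dropout(0.15),
                      nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):               # the real run uses ~50 epochs
    optimizer.zero_grad()
    loss = criterion(model(X), y)    # cross-entropy over the 3 action logits
    loss.backward()
    optimizer.step()
```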


Extending This

The same pipeline works for any game where:

  • The state can be captured as a fixed-length feature vector
  • Actions are discrete (a few choices per step)
  • A scripted expert or human demo is feasible

Good next targets: Flappy Bird (Gymnasium), Tetris (PyBoy), LunarLander (Gymnasium). Replace snakenv.py with your environment, rewrite data.py with a domain-appropriate heuristic, and train.py stays almost entirely the same.


License

MIT
