This project implements a transformer-based poker agent that:
- Understands poker hand history using tokenized sequences and self-attention
- Models opponents through separate encoding branches and attention pooling
- Adapts across hands using persistent opponent memory (fingerprints)
- Trains efficiently with mixed-precision (AMP) on GPU or CPU
- Trains against configurable villain archetypes (tight, loose, aggressive, tilt)
```
 POKER HAND          OPPONENT ACTIONS
     ↓                      ↓
[Hero Encoder]      [Opponent Encoder]
     ↓                      ↓
  160-dim            [Attention Pool]
     ↓                      ↓
     └──────[Concatenate]───┘
                ↓
  [Policy Head (16 actions)]
                ↓
         ACTION LOGITS
```
Key numbers:
- Model size: 4M parameters (~16MB weights)
- Hero encoder: 4 layers, 8 attention heads, 160-dim
- Opponent encoder: 2 layers, 4 heads, 64-dim
- Inference: ~50ms per batch (GPU)
- Typical accuracy: 55-70% on unseen opponent mixtures
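The two-branch design above can be sketched in PyTorch. Dimensions follow the key numbers listed (160-dim hero branch with 4 layers / 8 heads, 64-dim opponent branch with 2 layers / 4 heads, 16 actions); the class and variable names here are illustrative, not the project's actual modules.

```python
import torch
import torch.nn as nn

class TwoBranchPolicy(nn.Module):
    """Illustrative sketch of the hero/opponent two-branch architecture."""
    def __init__(self, vocab=512, d_hero=160, d_opp=64, num_actions=16):
        super().__init__()
        self.hero_emb = nn.Embedding(vocab, d_hero)
        self.hero_enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_hero, nhead=8, batch_first=True),
            num_layers=4)
        self.opp_emb = nn.Embedding(vocab, d_opp)
        self.opp_enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_opp, nhead=4, batch_first=True),
            num_layers=2)
        # Attention pooling: a learned query scores opponent tokens.
        self.pool_q = nn.Parameter(torch.randn(d_opp))
        self.policy = nn.Linear(d_hero + d_opp, num_actions)

    def forward(self, hand_tokens, opp_tokens):
        h = self.hero_enc(self.hero_emb(hand_tokens))[:, 0]  # hero summary
        o = self.opp_enc(self.opp_emb(opp_tokens))           # (B, T, 64)
        attn = torch.softmax(o @ self.pool_q, dim=-1)        # (B, T)
        fingerprint = (attn.unsqueeze(-1) * o).sum(dim=1)    # (B, 64)
        return self.policy(torch.cat([h, fingerprint], dim=-1))

model = TwoBranchPolicy()
logits = model(torch.randint(0, 512, (2, 32)), torch.randint(0, 512, (2, 16)))
print(logits.shape)  # torch.Size([2, 16])
```

A model of this size lands in the ~4M-parameter range quoted above, dominated by the four 160-dim hero layers.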
This project uses uv for dependency management (install it from https://docs.astral.sh/uv/ if you don't have it).
```bash
# Clone and navigate
git clone https://github.com/Lucadmin/poker-attention.git
cd poker-attention

# Install dependencies (creates a virtual environment automatically)
uv sync

# Test the Python environment
uv run python -c "import torch; print(f'PyTorch {torch.__version__}')"

# Test the frontend
cd frontend && npm install && cd ..
```

Train from generated data or a pre-generated dataset.
**Option A: Train from a pre-generated dataset (faster)**
```bash
# Step 1: Generate a benchmark-aligned dataset (one-time)
uv run python -m poker_attention.training.data_generator \
  --out training_data_16.npz \
  --num-actions 16 \
  --num-players 8 \
  --opponent-archetypes all \
  --num-examples 50000 \
  --seed 0

# Step 2: Train from that dataset on GPU with AMP
uv run python -m poker_attention.training.train_supervised \
  --data training_data_16.npz \
  --num-actions 16 \
  --device cuda \
  --amp \
  --epochs 3 \
  --batch-size 32 \
  --save model_supervised.pt
```

**Option B: On-the-fly generation + training (slower)**
```bash
uv run python -m poker_attention.training.train_supervised \
  --generate \
  --num-actions 16 \
  --num-players 8 \
  --opponent-archetypes all \
  --num-examples 20000 \
  --epochs 3 \
  --batch-size 128 \
  --device cuda \
  --amp \
  --seed 0 \
  --save model_supervised.pt
```

With tilt opponents:
```bash
uv run python -m poker_attention.training.train_supervised \
  --generate \
  --num-actions 16 \
  --num-players 8 \
  --opponent-archetypes all \
  --tilt-fluctuation 0.08 \
  --tilt-randomness 0.03 \
  --tilt-retain 0.90 \
  --tilt-baseline 0.50 \
  --num-examples 20000 \
  --epochs 3 \
  --batch-size 128 \
  --device cuda \
  --amp \
  --save model_supervised.pt
```

Train using PPO with self-play and a population of opponents.
```bash
# CPU smoke test
uv run python -m poker_attention.training.train_rl \
  --total-updates 2 \
  --rollout-hands 10 \
  --num-players 6

# Full CUDA training
uv run python -m poker_attention.training.train_rl \
  --device cuda \
  --amp \
  --total-updates 200 \
  --rollout-hands 200 \
  --opponent-archetypes tight,loose \
  --save model_rl.pt
```

Evaluate model performance with optional memory ablation.
```bash
# Standard evaluation (with opponent memory)
uv run python -m poker_attention.cli.eval \
  --checkpoint model_supervised.pt \
  --num-hands 200 \
  --num-players 8 \
  --opponent-archetypes all

# Memory ablation (how much does opponent memory help?)
uv run python -m poker_attention.cli.eval \
  --checkpoint model_supervised.pt \
  --num-hands 200 \
  --num-players 8 \
  --opponent-archetypes all \
  --ablate-memory

# Custom opponent mix
uv run python -m poker_attention.cli.eval \
  --checkpoint model_supervised.pt \
  --num-hands 500 \
  --opponent-archetypes tight,loose,aggressive \
  --num-players 6
```

```bash
# Play interactively
uv run python -m poker_attention.cli.play --checkpoint model_supervised.pt

# Inspect model tokens/embeddings
uv run python -m poker_attention.cli.tokens --checkpoint model_supervised.pt

# Demo (batch inference)
uv run python -m poker_attention.cli.demo --checkpoint model_supervised.pt

# Infer on custom scenarios
uv run python -m poker_attention.cli.infer \
  --checkpoint model_supervised.pt \
  --hand "AK" \
  --position "UTG" \
  --num-opponents 3
```

```bash
# Monitor training live with TensorBoard
tensorboard --logdir runs

# Generate PNG reports from saved runs
uv run python -m poker_attention.cli.plot_runs --run-dir runs/<RUN_DIR>

# Generalization analysis
uv run python -m poker_attention.cli.generalize --checkpoint model_supervised.pt
```

Note: All training/eval CLIs write artifacts to `runs/<timestamp>/` and print a `run_dir:` line that survives terminal loss or crashes.
Run the FastAPI web server for remote inference and interactive sessions.
```bash
# Development server (with auto-reload)
uv run python -m poker_attention.web.server --reload --port 8000

# Production server
uv run python -m poker_attention.web.server --host 0.0.0.0 --port 8000
```

Server features:
- REST API for inference on poker hands
- WebSocket support for multi-hand sessions
- Opponent memory state management
- Real-time action recommendations
API endpoints:
- `POST /api/predict` — Get action logits for a poker state
- `WS /ws/session` — Interactive poker session
- `GET /health` — Health check
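A minimal Python client sketch for the REST endpoint follows; the request fields mirror the `infer` CLI flags above, but the actual `/api/predict` schema is an assumption here, so check the server code before relying on it.

```python
import json

# Illustrative request body; the real /api/predict schema may differ.
payload = {
    "checkpoint": "model_supervised.pt",
    "hand": "AK",
    "position": "UTG",
    "num_opponents": 3,
}
body = json.dumps(payload)
print(body)

# With the server running on localhost:8000, send it with httpx (or requests):
#   import httpx
#   resp = httpx.post("http://localhost:8000/api/predict", content=body,
#                     headers={"Content-Type": "application/json"})
#   print(resp.json())
```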
The React + TypeScript web interface for visualization and interactive play.
```bash
cd frontend

# Install dependencies
npm install

# Start dev server (hot reload)
npm run dev
# Runs on http://localhost:5173

# Lint
npm run lint

# Build for production
npm run build
```

```bash
cd frontend

# Build static assets
npm run build

# Preview production build locally
npm run preview
```

Frontend features:
- Real-time poker table visualization
- Interactive action selection
- Opponent adaptation visualization
- Attention weight heatmaps
- Game history and statistics
```
poker-attention/
├── README.md                  # This file
├── pyproject.toml             # Python dependencies & config
│
├── src/poker_attention/
│   ├── models/
│   │   ├── transformer_opponent_model.py  # Main architecture
│   │   ├── opponent_encoder.py            # Opponent branch
│   │   ├── opponent_pooling.py            # Attention pooling
│   │   ├── token_embedding.py             # Token embeddings
│   │   └── actor_critic.py                # Actor-critic for RL
│   │
│   ├── training/
│   │   ├── train_supervised.py  # Supervised learning (main script)
│   │   ├── train_rl.py          # PPO self-play training
│   │   ├── train_opponent.py    # Opponent model training
│   │   ├── data_generator.py    # Generate training examples
│   │   └── league.py            # Population/league management
│   │
│   ├── cli/
│   │   ├── eval.py        # Evaluation script (BB/100, ablations)
│   │   ├── infer.py       # Single inference
│   │   ├── play.py        # Interactive play
│   │   ├── demo.py        # Batch demo inference
│   │   ├── tokens.py      # Token inspection
│   │   ├── generalize.py  # Generalization analysis
│   │   └── plot_runs.py   # Generate PNG reports
│   │
│   ├── agents/
│   │   ├── model_policy.py           # Model-based policy wrapper
│   │   ├── opponent_policies.py      # Fixed opponent archetypes
│   │   ├── opponent_memory_store.py  # Cross-hand opponent memory
│   │   └── model_league.py           # Multi-model tournament
│   │
│   ├── envs/
│   │   ├── tokenization.py   # Token encoding/decoding
│   │   ├── legal_actions.py  # Action masking
│   │   └── __init__.py       # PokerKit wrappers
│   │
│   ├── evaluation/
│   │   └── run_logger.py  # TensorBoard + artifact logging
│   │
│   ├── visualization/
│   │   └── (attention, embedding viz)
│   │
│   └── web/
│       ├── app.py      # FastAPI app
│       ├── server.py   # Server entry point
│       ├── session.py  # Session management
│       └── events.py   # WebSocket events
│
├── frontend/
│   ├── package.json
│   ├── src/
│   │   ├── App.tsx      # Main app component
│   │   ├── components/  # React components
│   │   └── hooks/       # Custom React hooks
│   ├── tsconfig.json
│   ├── vite.config.ts
│   └── tailwind.config.js
│
├── tests/
│   └── (unit & integration tests)
│
├── old_data/
│   ├── docs/            # Architecture documentation
│   ├── rl_checkpoints/  # RL model snapshots
│   └── (legacy training data)
│
└── .github/
    └── copilot-instructions.md
```
Edit `ModelConfig` in `src/poker_attention/models/transformer_opponent_model.py`:

```python
ModelConfig(
    d_model=160,              # Hero embedding dimension
    n_heads=8,                # Hero attention heads
    n_layers=4,               # Hero transformer depth
    num_actions=16,           # Action space (must match data!)
    max_seq_len=512,          # Max poker hand length
    opponent_max_seq_len=128, # Max opponent action history
)
```

The opponent encoder is hardcoded but can be customized in `opponent_encoder.py`.
Common settings for `train_supervised.py`:

```
--lr 0.0003          # Learning rate (conservative!)
--weight-decay 0.01  # L2 regularization
--grad-clip 1.0      # Gradient clipping threshold
--epochs 3           # Number of epochs
--batch-size 32      # Batch size (increase if you have more VRAM)
--device cuda        # cuda or cpu
--amp                # Mixed-precision training
--seed 0             # Random seed
```

RL hyperparameters for `train_rl.py`:

```
--ppo-lr 0.0003       # Actor learning rate
--ppo-value-lr 0.001  # Critic learning rate
--ppo-clip-ratio 0.2  # PPO clip range
--ppo-epochs 3        # PPO optimization passes
--ppo-batch-size 32   # PPO batch size
--gamma 0.99          # Discount factor
--gae-lambda 0.95     # GAE lambda
--rollout-hands 200   # Hands per rollout
--total-updates 200   # Total PPO updates
```

Data generation settings (for `data_generator.py` and `--generate` mode):

```
--num-examples 50000      # Training examples
--num-players 8           # Players per hand
--num-actions 16          # Discrete action bins
--opponent-archetypes all # Opponent types: tight, loose, aggressive, tilt, all
--use-memory              # Enable cross-hand opponent memory
--tilt-fluctuation 0.08   # Tilt volatility
--tilt-randomness 0.03    # Tilt stochasticity
--tilt-retain 0.90        # Tilt persistence
--tilt-baseline 0.50      # Tilt baseline aggression
```

```bash
# 1. Generate data once (reusable)
uv run python -m poker_attention.training.data_generator \
  --out training_data.npz \
  --num-examples 50000 \
  --num-actions 16 \
  --opponent-archetypes all

# 2. Train the model
uv run python -m poker_attention.training.train_supervised \
  --data training_data.npz \
  --num-actions 16 \
  --device cuda \
  --amp \
  --epochs 3 \
  --save model_v1.pt

# 3. Evaluate with memory
uv run python -m poker_attention.cli.eval \
  --checkpoint model_v1.pt \
  --num-hands 500

# 4. Ablate memory (measure opponent learning)
uv run python -m poker_attention.cli.eval \
  --checkpoint model_v1.pt \
  --num-hands 500 \
  --ablate-memory

# 5. Monitor results
tensorboard --logdir runs
```

All runs save to `runs/<timestamp>/`:
```
runs/2025-01-24_12-34-56_supervised_16actions/
├── events.out.tfevents.123...  # TensorBoard events
├── config.json                 # Hyperparameters
├── metrics.jsonl               # Per-epoch metrics
├── model_checkpoint.pt         # Final model weights
└── log.txt                     # Training log
```
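Since `metrics.jsonl` stores one JSON object per line, runs are easy to post-process without TensorBoard. A small sketch (the field names `epoch` and `loss` are assumptions about the log schema, not guaranteed keys):

```python
import json
from pathlib import Path

def load_metrics(run_dir: str) -> list[dict]:
    """Parse a run's metrics.jsonl into a list of per-epoch dicts."""
    lines = Path(run_dir, "metrics.jsonl").read_text().splitlines()
    return [json.loads(line) for line in lines if line.strip()]

# Demo with a synthetic run directory (field names are hypothetical):
demo = Path("demo_run")
demo.mkdir(exist_ok=True)
(demo / "metrics.jsonl").write_text(
    '{"epoch": 1, "loss": 1.92}\n{"epoch": 2, "loss": 1.41}\n')
metrics = load_metrics("demo_run")
print(metrics[-1]["loss"])  # 1.41
```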
Poker events are encoded as structured integer tokens rather than natural-language text:
- Card tokens: Encode player, card rank, suit
- Action tokens: Encode player, street, action type, bet size
- Marker tokens: Encode table state transitions
Benefits: Smaller model (~4M params), captures poker structure, efficient computation.
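As an illustration of structured tokens, a card token might pack player, rank, and suit into a single integer. The bit layout below is hypothetical, not the project's actual encoding in `envs/tokenization.py`:

```python
# Hypothetical card-token layout: player id | 4-bit rank | 4-bit suit.
RANKS = "23456789TJQKA"
SUITS = "cdhs"

def encode_card(player: int, rank: str, suit: str) -> int:
    return (player << 8) | (RANKS.index(rank) << 4) | SUITS.index(suit)

def decode_card(token: int) -> tuple[int, str, str]:
    return token >> 8, RANKS[(token >> 4) & 0xF], SUITS[token & 0xF]

tok = encode_card(player=2, rank="A", suit="s")
print(decode_card(tok))  # (2, 'A', 's')
```

Because the structure lives in the token itself, the embedding table stays small and the model never has to learn spelling.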
The opponent encoder learns a 64-dimensional "fingerprint" that compresses opponent behavior. This fingerprint:
- Updates across hands using persistent memory
- Feeds into attention pooling to weight opponent importance
- Enables rapid adaptation to new opponents
- Is independent of the hero's hand
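The pooling step can be sketched as learned-query attention over per-opponent summaries. Shapes follow the 64-dim figure above; the class and variable names are illustrative:

```python
import torch
import torch.nn as nn

class FingerprintPool(nn.Module):
    """Learned-query attention pooling: opponent summaries -> one fingerprint."""
    def __init__(self, d_opp: int = 64):
        super().__init__()
        self.query = nn.Parameter(torch.randn(d_opp))

    def forward(self, opp_states: torch.Tensor) -> torch.Tensor:
        # opp_states: (batch, num_opponents, d_opp)
        scores = opp_states @ self.query         # (batch, num_opponents)
        weights = torch.softmax(scores, dim=-1)  # per-opponent importance
        return (weights.unsqueeze(-1) * opp_states).sum(dim=1)  # (batch, d_opp)

pool = FingerprintPool()
fingerprint = pool(torch.randn(4, 7, 64))  # 7 opponents at an 8-player table
print(fingerprint.shape)  # torch.Size([4, 64])
```

The softmax weights are what make "opponent importance" inspectable: a table-dominating maniac can receive most of the pooling mass.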
Both hero and opponent branches receive gradients from the same loss signal (cross-entropy):

```
Loss
 ├─→ Hero path:     hand understanding improves
 └─→ Opponent path: opponent modeling improves
```
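A minimal sketch of this shared-loss setup, with toy dimensions and illustrative layer names standing in for the real branches:

```python
import torch
import torch.nn as nn

# Toy two-branch model: a single loss sends gradients into both branches.
hero_branch = nn.Linear(8, 160)
opp_branch = nn.Linear(8, 64)
policy_head = nn.Linear(160 + 64, 16)

hand, opp = torch.randn(32, 8), torch.randn(32, 8)
logits = policy_head(torch.cat([hero_branch(hand), opp_branch(opp)], dim=-1))
loss = nn.functional.cross_entropy(logits, torch.randint(0, 16, (32,)))
loss.backward()

# Both branches received gradients from the single cross-entropy loss.
print(hero_branch.weight.grad is not None, opp_branch.weight.grad is not None)
# True True
```

No auxiliary opponent-prediction loss is needed: modeling opponents is useful exactly insofar as it improves the hero's action choice.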
Managed by uv (pyproject.toml):
- PyTorch (with CUDA support)
- PokerKit (poker environment)
- NumPy (numerical computing)
- TensorBoard (training visualization)
- FastAPI + Uvicorn (web server)
- Rich (CLI formatting)
- Scikit-learn, UMAP, Matplotlib (analysis & viz)
Managed by npm (frontend/package.json):
- React 19 + TypeScript
- Vite (build tool)
- Tailwind CSS (styling)
- Vite React Plugin (HMR)
Happy training! 🎰
```bash
uv run python -m poker_attention.cli.eval \
  --checkpoint tmp_model_16_benchmarktrain.pt \
  --num-hands 500 \
  --num-players 8 \
  --opponent-archetypes all \
  --ablate-memory \
  --num-runs 10
```

```bash
uv run python -m poker_attention.training.train_rl \
  --init-checkpoint rl_checkpoints_bigtest/rl_actor_000010.pt \
  --num-actions 16 \
  --device cuda --opponent-device cuda --amp \
  --save-dir rl_checkpoints --league-dir rl_league \
  --snapshot-every 10 \
  --time-limit-minutes 30
```

This checks whether opponent context changes the logits and/or the chosen action (masked argmax):

```bash
uv run python -m poker_attention.cli.eval \
  --checkpoint tmp_model_16_benchmarktrain.pt \
  --num-hands 200 \
  --num-players 8 \
  --opponent-archetypes all \
  --diagnose-memory-sensitivity \
  --num-runs 10
```
## Web demo (FastAPI backend + React/TypeScript/Tailwind frontend)
This repo includes an interactive demo UI that streams:
- the live action log (seat-by-seat)
- the hero policy probabilities at decision time
- the live token stream in a side panel
### 1) Install backend dependencies
Using `uv` (recommended):
```bash
uv sync --extra web
```

### 2) Start the backend

```bash
uv run python -m poker_attention.web.server --reload --port 8000
```

### 3) Start the frontend

```bash
cd frontend
npm install
npm run dev
```

Open the UI at http://localhost:5173.
Enter a checkpoint path relative to the repo root (e.g. `rl_league/league_final.pt`) and click Connect.
Notes:
- The backend runs multiple hands per session (`hands`, default 25).
- With "Pause on hero" enabled, the backend stops at each hero decision until you press Continue.



