
🎰 Poker Attention: Fast-Adapting Opponent-Aware Poker AI


📋 Overview

This project implements a transformer-based poker agent that:

  • Understands poker hand history using tokenized sequences and self-attention
  • Models opponents through separate encoding branches and attention pooling
  • Adapts across hands using persistent opponent memory (fingerprints)
  • Trains efficiently with mixed-precision (AMP) on GPU or CPU
  • Learns against a roster of villain archetypes (tight, loose, aggressive, tilt)

(figure: villain archetypes)

Core Architecture

POKER HAND           OPPONENT ACTIONS
    ↓                     ↓
[Hero Encoder]       [Opponent Encoder]
    ↓                     ↓
  160-dim           [Attention Pool]
    ↓                     ↓
    └─────────[Concatenate]─────┘
                      ↓
              [Policy Head (16 actions)]
                      ↓
              ACTION LOGITS

Key numbers:

  • Model size: 4M parameters (~16MB weights)
  • Hero encoder: 4 layers, 8 attention heads, 160-dim
  • Opponent encoder: 2 layers, 4 heads, 64-dim
  • Inference: ~50ms per batch (GPU)
  • Typical accuracy: 55-70% on unseen opponent mixtures
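
To make the diagram concrete, here is a minimal PyTorch sketch of the two-branch design using the numbers above. Module and class names are illustrative, not the exact ones in src/poker_attention/models/:

import torch
import torch.nn as nn

class TwoBranchPolicy(nn.Module):
    """Illustrative two-branch model: hero encoder + pooled opponent encoder."""
    def __init__(self, vocab_size=512, d_hero=160, d_opp=64, num_actions=16):
        super().__init__()
        # Hero branch: 4 layers, 8 heads, 160-dim (matching "Key numbers")
        self.hero_embed = nn.Embedding(vocab_size, d_hero)
        self.hero_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_hero, nhead=8, batch_first=True), num_layers=4)
        # Opponent branch: 2 layers, 4 heads, 64-dim
        self.opp_embed = nn.Embedding(vocab_size, d_opp)
        self.opp_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_opp, nhead=4, batch_first=True), num_layers=2)
        # Attention pooling: one learned query attends over opponent tokens
        self.pool_query = nn.Parameter(torch.randn(1, 1, d_opp))
        self.pool_attn = nn.MultiheadAttention(d_opp, num_heads=4, batch_first=True)
        # Policy head over the concatenated [hero ; opponent] summary
        self.policy_head = nn.Linear(d_hero + d_opp, num_actions)

    def forward(self, hand_tokens, opp_tokens):
        hero = self.hero_encoder(self.hero_embed(hand_tokens))[:, -1]   # (B, 160)
        opp = self.opp_encoder(self.opp_embed(opp_tokens))              # (B, T, 64)
        query = self.pool_query.expand(opp.size(0), -1, -1)
        pooled, _ = self.pool_attn(query, opp, opp)                     # (B, 1, 64)
        fused = torch.cat([hero, pooled.squeeze(1)], dim=-1)            # (B, 224)
        return self.policy_head(fused)                                  # (B, 16) logits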

🚀 Quick Start

This project uses uv for dependency management (install it from https://docs.astral.sh/uv/).

Installation

# Clone and navigate
git clone https://github.com/Lucadmin/poker-attention.git
cd poker-attention

# Install dependencies (creates virtual environment automatically)
uv sync

Verify Installation

# Test the Python environment
uv run python -c "import torch; print(f'PyTorch {torch.__version__}')"

# Verify the frontend toolchain
cd frontend && npm install && cd ..

🏃 Running the Project

Backend / CLI

1. Supervised Training

Train from a pre-generated dataset or generate data on the fly.

Option A: Train from pre-generated dataset (faster)

# Step 1: Generate a benchmark-aligned dataset (one-time)
uv run python -m poker_attention.training.data_generator \
  --out training_data_16.npz \
  --num-actions 16 \
  --num-players 8 \
  --opponent-archetypes all \
  --num-examples 50000 \
  --seed 0

# Step 2: Train from that dataset on GPU with AMP
uv run python -m poker_attention.training.train_supervised \
  --data training_data_16.npz \
  --num-actions 16 \
  --device cuda \
  --amp \
  --epochs 3 \
  --batch-size 32 \
  --save model_supervised.pt

Option B: On-the-fly generation + training (slower)

uv run python -m poker_attention.training.train_supervised \
  --generate \
  --num-actions 16 \
  --num-players 8 \
  --opponent-archetypes all \
  --num-examples 20000 \
  --epochs 3 \
  --batch-size 128 \
  --device cuda \
  --amp \
  --seed 0 \
  --save model_supervised.pt

With tilt opponents:

uv run python -m poker_attention.training.train_supervised \
  --generate \
  --num-actions 16 \
  --num-players 8 \
  --opponent-archetypes all \
  --tilt-fluctuation 0.08 \
  --tilt-randomness 0.03 \
  --tilt-retain 0.90 \
  --tilt-baseline 0.50 \
  --num-examples 20000 \
  --epochs 3 \
  --batch-size 128 \
  --device cuda \
  --amp \
  --save model_supervised.pt

2. Reinforcement Learning Training

Train using PPO with self-play and a population of opponents.

# CPU smoke test
uv run python -m poker_attention.training.train_rl \
  --total-updates 2 \
  --rollout-hands 10 \
  --num-players 6

# Full CUDA training
uv run python -m poker_attention.training.train_rl \
  --device cuda \
  --amp \
  --total-updates 200 \
  --rollout-hands 200 \
  --opponent-archetypes tight,loose \
  --save model_rl.pt

3. Evaluation

Evaluate model performance with optional memory ablation.

# Standard evaluation (with opponent memory)
uv run python -m poker_attention.cli.eval \
  --checkpoint model_supervised.pt \
  --num-hands 200 \
  --num-players 8 \
  --opponent-archetypes all

# Memory ablation (how much does opponent memory help?)
uv run python -m poker_attention.cli.eval \
  --checkpoint model_supervised.pt \
  --num-hands 200 \
  --num-players 8 \
  --opponent-archetypes all \
  --ablate-memory

# Custom opponent mix
uv run python -m poker_attention.cli.eval \
  --checkpoint model_supervised.pt \
  --num-hands 500 \
  --opponent-archetypes tight,loose,aggressive \
  --num-players 6

4. Model Inference & Testing

# Play interactively
uv run python -m poker_attention.cli.play --checkpoint model_supervised.pt

# Inspect model tokens/embeddings
uv run python -m poker_attention.cli.tokens --checkpoint model_supervised.pt

# Demo (batch inference)
uv run python -m poker_attention.cli.demo --checkpoint model_supervised.pt

# Infer on custom scenarios
uv run python -m poker_attention.cli.infer \
  --checkpoint model_supervised.pt \
  --hand "AK" \
  --position "UTG" \
  --num-opponents 3

5. Visualization & Monitoring

# Monitor training live with TensorBoard
tensorboard --logdir runs

# Generate PNG reports from saved runs
uv run python -m poker_attention.cli.plot_runs --run-dir runs/<RUN_DIR>

# Generalization analysis
uv run python -m poker_attention.cli.generalize --checkpoint model_supervised.pt

Note: All training/eval CLIs write artifacts to runs/<timestamp>/ and print a run_dir: line, so results can be located even after a terminal disconnect or crash.


Backend API Server

Run the FastAPI web server for remote inference and interactive sessions.

# Development server (with auto-reload)
uv run python -m poker_attention.web.server --reload --port 8000

# Production server
uv run python -m poker_attention.web.server --host 0.0.0.0 --port 8000

Server features:

  • REST API for inference on poker hands
  • WebSocket support for multi-hand sessions
  • Opponent memory state management
  • Real-time action recommendations

API endpoints:

  • POST /api/predict — Get action logits for a poker state
  • WS /ws/session — Interactive poker session
  • GET /health — Health check
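
A minimal request sketch against POST /api/predict (stdlib only). The payload field names here are assumptions for illustration; check src/poker_attention/web/app.py for the actual schema:

import json
import urllib.request

# Hypothetical payload fields -- the real schema lives in web/app.py
payload = {"hand": "AK", "position": "UTG", "num_opponents": 3}
req = urllib.request.Request(
    "http://localhost:8000/api/predict",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))  # expected to contain action logits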

Frontend


The React + TypeScript web interface for visualization and interactive play.

Development

cd frontend

# Install dependencies
npm install

# Start dev server (hot reload)
npm run dev
# Runs on http://localhost:5173

# Lint
npm run lint

# Build for production
npm run build

Production Build

cd frontend

# Build static assets
npm run build

# Preview production build locally
npm run preview

Frontend features:

  • Real-time poker table visualization
  • Interactive action selection
  • Opponent adaptation visualization
  • Attention weight heatmaps
  • Game history and statistics

📊 Project Structure

poker-attention/
├── README.md                          # This file
├── pyproject.toml                     # Python dependencies & config
│
├── src/poker_attention/
│   ├── models/
│   │   ├── transformer_opponent_model.py  # Main architecture
│   │   ├── opponent_encoder.py            # Opponent branch
│   │   ├── opponent_pooling.py            # Attention pooling
│   │   ├── token_embedding.py             # Token embeddings
│   │   └── actor_critic.py                # Actor-critic for RL
│   │
│   ├── training/
│   │   ├── train_supervised.py        # Supervised learning (main script)
│   │   ├── train_rl.py                # PPO self-play training
│   │   ├── train_opponent.py          # Opponent model training
│   │   ├── data_generator.py          # Generate training examples
│   │   └── league.py                  # Population/league management
│   │
│   ├── cli/
│   │   ├── eval.py                    # Evaluation script (BB/100, ablations)
│   │   ├── infer.py                   # Single inference
│   │   ├── play.py                    # Interactive play
│   │   ├── demo.py                    # Batch demo inference
│   │   ├── tokens.py                  # Token inspection
│   │   ├── generalize.py              # Generalization analysis
│   │   └── plot_runs.py               # Generate PNG reports
│   │
│   ├── agents/
│   │   ├── model_policy.py            # Model-based policy wrapper
│   │   ├── opponent_policies.py       # Fixed opponent archetypes
│   │   ├── opponent_memory_store.py   # Cross-hand opponent memory
│   │   └── model_league.py            # Multi-model tournament
│   │
│   ├── envs/
│   │   ├── tokenization.py            # Token encoding/decoding
│   │   ├── legal_actions.py           # Action masking
│   │   └── __init__.py                # PokerKit wrappers
│   │
│   ├── evaluation/
│   │   └── run_logger.py              # TensorBoard + artifact logging
│   │
│   ├── visualization/
│   │   └── (attention, embedding viz)
│   │
│   └── web/
│       ├── app.py                     # FastAPI app
│       ├── server.py                  # Server entry point
│       ├── session.py                 # Session management
│       └── events.py                  # WebSocket events
│
├── frontend/
│   ├── package.json
│   ├── src/
│   │   ├── App.tsx                    # Main app component
│   │   ├── components/                # React components
│   │   └── hooks/                     # Custom React hooks
│   ├── tsconfig.json
│   ├── vite.config.ts
│   └── tailwind.config.js
│
├── tests/
│   └── (unit & integration tests)
│
├── old_data/
│   ├── docs/                          # Architecture documentation
│   ├── rl_checkpoints/                # RL model snapshots
│   └── (legacy training data)
│
└── .github/
    └── copilot-instructions.md

⚙️ Configuration & Parameters

Model Architecture

Edit ModelConfig in src/poker_attention/models/transformer_opponent_model.py:

ModelConfig(
    d_model=160,              # Hero embedding dimension
    n_heads=8,                # Hero attention heads
    n_layers=4,               # Hero transformer depth
    num_actions=16,           # Action space (must match data!)
    max_seq_len=512,          # Max poker hand length
    opponent_max_seq_len=128, # Max opponent action history
)

The opponent encoder's settings are hardcoded; customize them directly in opponent_encoder.py.
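
As a sanity check on the ~4M-parameter figure, the config can be instantiated and counted like this (the class name TransformerOpponentModel is inferred from the module name and may differ):

from poker_attention.models.transformer_opponent_model import (
    ModelConfig,
    TransformerOpponentModel,  # assumed class name; adjust if it differs
)

cfg = ModelConfig(d_model=160, n_heads=8, n_layers=4, num_actions=16)
model = TransformerOpponentModel(cfg)
print(sum(p.numel() for p in model.parameters()))  # expect roughly 4M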

Training Hyperparameters

Common settings for train_supervised.py:

--lr 0.0003           # Learning rate (conservative!)
--weight-decay 0.01   # L2 regularization
--grad-clip 1.0       # Gradient clipping threshold
--epochs 3            # Number of epochs
--batch-size 32       # Batch size (increase if more VRAM)
--device cuda         # cuda or cpu
--amp                 # Mixed precision training
--seed 0              # Random seed

RL hyperparameters for train_rl.py:

--ppo-lr 0.0003               # Actor learning rate
--ppo-value-lr 0.001          # Critic learning rate
--ppo-clip-ratio 0.2          # PPO clip range
--ppo-epochs 3                # PPO optimization passes
--ppo-batch-size 32           # PPO batch size
--gamma 0.99                  # Discount factor
--gae-lambda 0.95             # GAE lambda
--rollout-hands 200           # Hands per rollout
--total-updates 200           # Total PPO updates
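
--gamma and --gae-lambda drive the standard Generalized Advantage Estimation recursion; a minimal reference sketch (standard GAE, not the project's exact code):

import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Standard GAE; `values` carries one extra bootstrap entry at the end."""
    advantages = np.zeros(len(rewards))
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # TD error for step t, bootstrapped from the next value estimate
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages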

Data Generation

--num-examples 50000              # Training examples
--num-players 8                   # Players per hand
--num-actions 16                  # Discrete action bins
--opponent-archetypes all         # Opponent types: tight, loose, aggressive, tilt, all
--use-memory                      # Enable cross-hand opponent memory
--tilt-fluctuation 0.08           # Tilt volatility
--tilt-randomness 0.03            # Tilt stochasticity
--tilt-retain 0.90                # Tilt persistence
--tilt-baseline 0.50              # Tilt baseline aggression
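
One plausible reading of the tilt flags, offered purely as an assumption for intuition (data_generator.py may implement something different): retain pulls an AR(1)-style tilt state back toward baseline each hand, while fluctuation and randomness inject noise.

import random

def step_tilt(tilt, fluctuation=0.08, randomness=0.03, retain=0.90, baseline=0.50):
    # Assumed dynamics, not the project's actual implementation
    tilt = retain * tilt + (1.0 - retain) * baseline  # persistence toward baseline
    tilt += fluctuation * random.gauss(0.0, 1.0)      # hand-to-hand volatility
    tilt += randomness * random.uniform(-1.0, 1.0)    # extra stochastic jitter
    return min(max(tilt, 0.0), 1.0)                   # keep tilt in [0, 1]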

📈 Training & Evaluation Workflow

Typical Pipeline

# 1. Generate data once (reusable)
uv run python -m poker_attention.training.data_generator \
  --out training_data.npz \
  --num-examples 50000 \
  --num-actions 16 \
  --opponent-archetypes all

# 2. Train model
uv run python -m poker_attention.training.train_supervised \
  --data training_data.npz \
  --num-actions 16 \
  --device cuda \
  --amp \
  --epochs 3 \
  --save model_v1.pt

# 3. Evaluate with memory
uv run python -m poker_attention.cli.eval \
  --checkpoint model_v1.pt \
  --num-hands 500

# 4. Ablate memory (measure opponent learning)
uv run python -m poker_attention.cli.eval \
  --checkpoint model_v1.pt \
  --num-hands 500 \
  --ablate-memory

# 5. Monitor results
tensorboard --logdir runs

Output Artifacts

All runs save to runs/<timestamp>/:

runs/2025-01-24_12-34-56_supervised_16actions/
├── events.out.tfevents.123...    # TensorBoard events
├── config.json                   # Hyperparameters
├── metrics.jsonl                 # Per-epoch metrics
├── model_checkpoint.pt           # Final model weights
└── log.txt                       # Training log
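
Since metrics.jsonl holds one JSON object per epoch, it can be inspected without TensorBoard (the exact field names depend on the trainer):

import json
from pathlib import Path

run_dir = Path("runs/2025-01-24_12-34-56_supervised_16actions")
for line in (run_dir / "metrics.jsonl").read_text().splitlines():
    print(json.loads(line))  # one metrics dict per epoch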

🧠 Key Concepts

Token System

Poker events are tokenized (encoded as integers from a small, domain-specific vocabulary) rather than treated as natural-language text:

  • Card tokens: Encode player, card rank, suit
  • Action tokens: Encode player, street, action type, bet size
  • Marker tokens: Encode table state transitions

Benefits: Smaller model (~4M params), captures poker structure, efficient computation.

(figure: token design)
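
As an illustration of the idea (the real vocabulary lives in src/poker_attention/envs/tokenization.py), a card token can pack seat, rank, and suit into one integer:

RANKS = "23456789TJQKA"
SUITS = "cdhs"

def card_token(player: int, rank: str, suit: str) -> int:
    """Pack (player, rank, suit) into a single integer token id."""
    card = RANKS.index(rank) * len(SUITS) + SUITS.index(suit)  # 0..51
    return player * 52 + card  # disjoint id range per seat

print(card_token(0, "A", "s"))  # hero's ace of spades -> 51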

Opponent Fingerprinting

The opponent encoder learns a 64-dimensional "fingerprint" that compresses opponent behavior. This fingerprint:

  • Updates across hands using persistent memory
  • Feeds into attention pooling to weight opponent importance
  • Enables rapid adaptation to new opponents
  • Is independent of the hero's hand
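
A sketch of the persistence mechanism (the real logic is in src/poker_attention/agents/opponent_memory_store.py; the exponential-moving-average update here is an assumption for illustration):

import torch

class FingerprintStore:
    """Keeps one running 64-dim fingerprint per opponent seat across hands."""
    def __init__(self, dim: int = 64, momentum: float = 0.9):
        self.memory: dict[int, torch.Tensor] = {}
        self.momentum = momentum
        self.dim = dim

    def update(self, seat: int, fingerprint: torch.Tensor) -> torch.Tensor:
        prev = self.memory.get(seat, torch.zeros(self.dim))
        # Blend the fresh per-hand fingerprint into the persistent one
        self.memory[seat] = self.momentum * prev + (1 - self.momentum) * fingerprint
        return self.memory[seat]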

Training Signal

Both hero and opponent branches receive gradients from the same loss signal (CrossEntropy):

Loss ↓
  ├─→ Hero path: hand understanding improves
  └─→ Opponent path: opponent modeling improves
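
In code, this is a single backward pass through the fused graph; reusing the TwoBranchPolicy sketch from Core Architecture:

import torch
import torch.nn.functional as F

model = TwoBranchPolicy()  # from the sketch under "Core Architecture"
hand_tokens = torch.randint(0, 512, (32, 40))  # (batch, hand sequence)
opp_tokens = torch.randint(0, 512, (32, 20))   # (batch, opponent history)
target_actions = torch.randint(0, 16, (32,))

logits = model(hand_tokens, opp_tokens)         # (32, 16)
loss = F.cross_entropy(logits, target_actions)
loss.backward()  # gradients reach both the hero and opponent encoders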

📦 Dependencies

Python

Managed by uv (pyproject.toml):

  • PyTorch (with CUDA support)
  • PokerKit (poker environment)
  • NumPy (numerical computing)
  • TensorBoard (training visualization)
  • FastAPI + Uvicorn (web server)
  • Rich (CLI formatting)
  • Scikit-learn, UMAP, Matplotlib (analysis & viz)

Frontend

Managed by npm (frontend/package.json):

  • React 19 + TypeScript
  • Vite (build tool)
  • Tailwind CSS (styling)
  • Vite React Plugin (HMR)

🧪 Additional Recipes

Memory-ablation evaluation (repeated runs)

uv run python -m poker_attention.cli.eval \
  --checkpoint tmp_model_16_benchmarktrain.pt \
  --num-hands 500 \
  --num-players 8 \
  --opponent-archetypes all \
  --ablate-memory \
  --num-runs 10

PPO RL training (time-limited)

uv run python -m poker_attention.training.train_rl \
  --init-checkpoint rl_checkpoints_bigtest/rl_actor_000010.pt \
  --num-actions 16 \
  --device cuda --opponent-device cuda --amp \
  --save-dir rl_checkpoints --league-dir rl_league \
  --snapshot-every 10 \
  --time-limit-minutes 30

Diagnose whether memory changes decisions

This checks whether opponent context changes logits and/or the chosen action (masked argmax).

uv run python -m poker_attention.cli.eval \
  --checkpoint tmp_model_16_benchmarktrain.pt \
  --num-hands 200 \
  --num-players 8 \
  --opponent-archetypes all \
  --diagnose-memory-sensitivity \
  --num-runs 10
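
For intuition, the check amounts to comparing the masked argmax with and without opponent context. A hedged sketch reusing the TwoBranchPolicy model from Core Architecture (eval.py's internals may differ):

import torch

def masked_argmax(logits: torch.Tensor, legal: torch.Tensor) -> int:
    # Pick the best action among legal ones only
    return int(logits.masked_fill(~legal, float("-inf")).argmax())

model = TwoBranchPolicy()  # from the sketch under "Core Architecture"
model.eval()               # deterministic comparison (no dropout)
hand_tokens = torch.randint(0, 512, (1, 40))
opp_tokens = torch.randint(0, 512, (1, 20))

with torch.no_grad():
    with_mem = model(hand_tokens, opp_tokens)[0]
    without_mem = model(hand_tokens, torch.zeros_like(opp_tokens))[0]  # ablated context
legal = torch.ones(16, dtype=torch.bool)  # stand-in for the real legal-action mask
print("action changed:",
      masked_argmax(with_mem, legal) != masked_argmax(without_mem, legal))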

🌐 Web Demo (FastAPI backend + React/TypeScript/Tailwind frontend)

This repo includes an interactive demo UI that streams:
  • the live action log (seat-by-seat)
  • the hero policy probabilities at decision time
  • the live token stream in a side panel

1) Install backend dependencies

Using uv (recommended):

uv sync --extra web

2) Start the backend

uv run python -m poker_attention.web.server --reload --port 8000

3) Start the frontend

cd frontend
npm install
npm run dev

Open the UI at http://localhost:5173.

4) Connect

Enter a checkpoint path relative to repo root (e.g. rl_league/league_final.pt) and click Connect.

Notes:

  • The backend runs multiple hands per session (the hands setting, default 25).
  • With "Pause on hero" enabled, the backend stops at each hero decision until you press Continue.

Happy training! 🎰
