# PokerNet Training Starter Notebook

Small starter notebook for training your `PokerNet` model.

What this notebook includes:
- device setup
- model instantiation
- dummy batch shape sanity check
- a tiny synthetic dataset + dataloader
- one training loop skeleton (policy + value loss)
- hand/game state handling (with hand-state reset example)

> Replace the synthetic dataset with your real poker data pipeline when ready.


In [27]:
import torch

print(torch.__version__)
# ROCm builds also use the torch.cuda namespace.
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device count:", torch.cuda.device_count())
    print("device 0:", torch.cuda.get_device_name(0))

2.10.0+cu128
cuda available: False


In [19]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

print("Torch version:", torch.__version__)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Torch version: 2.10.0+cu128
Device: cpu


## 1) Import your model

Option A (recommended): put your model code in a file, e.g. `poker_model.py`, then import it.

```python
from poker_model import PokerNet
```

Option B: paste your `PokerNet` class (and helper classes) into a cell below.


In [20]:
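# Project-local import; adjust the module path to wherever your PokerNet is defined.
# Only the name PokerNet is bound here, so the `nn` alias for torch.nn stays intact.
from nn.nn import PokerNet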

In [21]:
# --- Hyperparameters / dimensions (adjust to your setup) ---
BETS_IN_DIM = 128
HAND_STATE_DIM = 32
GAME_STATE_DIM = 32
NUM_ACTIONS = 3  # fold / call-check / raise-intent

BATCH_SIZE = 32
EPOCHS = 100
LR = 1e-3

# Model init example (matches your PokerNet signature)
model = PokerNet(
    bets_in_dim=BETS_IN_DIM,
    hand_state_dim=HAND_STATE_DIM,
    game_state_dim=GAME_STATE_DIM,
    state_mode="simple",  # or 'branched'
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
policy_criterion = nn.CrossEntropyLoss()  # expects raw logits + class indices
value_criterion = nn.MSELoss()  # example regression loss

print(model.__class__.__name__)

PokerNet


## 2) Quick shape sanity check (dummy batch)


In [22]:
with torch.no_grad():
    B = 8
    cards = torch.randn(B, 4, 4, 13, device=DEVICE)
    bets = torch.randn(B, BETS_IN_DIM, device=DEVICE)
    hand_state = torch.zeros(B, HAND_STATE_DIM, device=DEVICE)
    game_state = torch.zeros(B, GAME_STATE_DIM, device=DEVICE)

    action_logits, value, next_hand_state, next_game_state = model(
        cards, bets, hand_state, game_state
    )

print("action_logits:", tuple(action_logits.shape))  # [B, 3]
print("value:", tuple(value.shape))  # [B, 1]
print("next_hand_state:", tuple(next_hand_state.shape))
print("next_game_state:", tuple(next_game_state.shape))

probs = torch.softmax(action_logits, dim=1)
print("probs row sums (should be ~1):", probs.sum(dim=1)[:3])

action_logits: (8, 3)
value: (8, 1)
next_hand_state: (8, 32)
next_game_state: (8, 32)
probs row sums (should be ~1): tensor([1.0000, 1.0000, 1.0000])


## 3) Synthetic dataset (replace later)

This dataset returns:
- `cards`: `[4,4,13]`
- `bets`: `[bets_in_dim]`
- `target_action`: class index in `{0,1,2}`
- `target_value`: scalar
- `new_hand`: boolean flag (if `True`, reset hand state before using sample)

It also returns `sample_id`, useful if you later want persistent per-sample state storage.
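
When samples are dictionaries like this, `DataLoader`'s default collation batches each field separately. The following is a small self-contained illustration of that behavior; `TinyDictDataset` is a hypothetical miniature dataset with arbitrary sizes, not the dataset used in this notebook.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TinyDictDataset(Dataset):
    """Hypothetical miniature dataset with the same kinds of fields as listed above."""

    def __len__(self):
        return 6

    def __getitem__(self, idx):
        return {
            "sample_id": idx,                        # plain Python int
            "cards": torch.zeros(4, 4, 13),          # float tensor
            "target_action": torch.tensor(idx % 3),  # 0-dim long tensor
            "new_hand": torch.tensor(idx == 0),      # 0-dim bool tensor
        }


batch = next(iter(DataLoader(TinyDictDataset(), batch_size=4, shuffle=False)))
for name, value in batch.items():
    # Every field gains a leading batch dimension of size 4.
    print(name, tuple(value.shape), value.dtype)
```

Each value in `batch` is a tensor with a leading batch dimension, so `batch["new_hand"]` can be used directly as a boolean row mask.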


In [23]:
class SyntheticPokerDataset(Dataset):
    def __init__(self, n_samples=512, bets_in_dim=128, num_actions=3):
        self.n_samples = n_samples
        self.cards = torch.randn(n_samples, 4, 4, 13)
        self.bets = torch.randn(n_samples, bets_in_dim)
        # Random class indices in [0, num_actions)
        self.target_action = torch.randint(
            0, num_actions, (n_samples,), dtype=torch.long
        )
        self.target_value = torch.randn(n_samples, 1)
        # Randomly mark ~10% as new-hand boundaries (demo only)
        self.new_hand = torch.rand(n_samples) < 0.10

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return {
            "sample_id": idx,
            "cards": self.cards[idx],
            "bets": self.bets[idx],
            "target_action": self.target_action[idx],
            "target_value": self.target_value[idx],
            "new_hand": self.new_hand[idx],
        }


train_ds = SyntheticPokerDataset(n_samples=4096, bets_in_dim=BETS_IN_DIM)
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=False
)  # keep ordered for state demo
len(train_ds), len(train_loader)

(4096, 128)

## 4) Minimal training loop skeleton

Notes:
- `action_logits` are used directly in `CrossEntropyLoss` (no softmax before loss)
- rows of `hand_state` are zeroed wherever `new_hand` is `True`; `game_state` is kept
- `next_*_state` are detached before feeding back to avoid backprop across batches/episodes
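
The first note can be checked numerically. The sketch below uses made-up logits and targets (not `PokerNet` outputs) to show that `nn.CrossEntropyLoss` on raw logits equals log-softmax followed by negative log-likelihood, and that applying softmax first changes the result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up raw scores for 3 actions and their target class indices.
logits = torch.tensor([[4.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
targets = torch.tensor([0, 1])

ce = nn.CrossEntropyLoss()(logits, targets)
# Equivalent formulation: log-softmax followed by negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
# Passing probabilities instead applies softmax twice and gives a different loss.
double = nn.CrossEntropyLoss()(torch.softmax(logits, dim=1), targets)

print(torch.allclose(ce, manual))  # True
print(torch.allclose(ce, double))  # False
```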

This is a *starter pattern*; for real poker training, you'll likely manage state per table/seat/episode rather than globally.
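
The reset-and-detach pattern from the notes above can be sketched in isolation. This is a minimal illustration under stated assumptions: `toy_step` and `weight` are hypothetical stand-ins for a model's state update, the sizes are arbitrary, and `torch.where` is used here as an out-of-place alternative to cloning and index assignment.

```python
import torch

B, H = 4, 3  # arbitrary batch size and state width for the demo

# Hypothetical stand-in for a learnable recurrent update.
weight = torch.ones(H, H, requires_grad=True)


def toy_step(x, state):
    return torch.tanh(x @ weight + state)


state = torch.ones(B, H)  # pretend this was carried over from the previous batch
new_hand = torch.tensor([True, False, False, True])

# Out-of-place masked reset: rows flagged as new hands become zero, others are kept.
state = torch.where(new_hand.unsqueeze(1), torch.zeros_like(state), state)

next_state = toy_step(torch.randn(B, H), state)

# Detach before carrying the state forward so later backward passes stop here.
carried = next_state.detach()

print(state[:, 0].tolist())                             # [0.0, 1.0, 1.0, 0.0]
print(next_state.requires_grad, carried.requires_grad)  # True False
```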


In [24]:
def train_one_epoch(
    model, loader, optimizer, policy_criterion, value_criterion, device
):
    model.train()
    total_loss = 0.0
    total_policy = 0.0
    total_value = 0.0
    total_correct = 0
    total_count = 0

    # Demo global states (batch-aligned); for real environments you'll manage these per sequence/table/player.
    hand_state = None
    game_state = None

    for batch in loader:
        cards = batch["cards"].to(device).float()
        bets = batch["bets"].to(device).float()
        target_action = batch["target_action"].to(device).long()
        target_value = batch["target_value"].to(device).float()
        new_hand = batch["new_hand"].to(device)

        B = cards.shape[0]

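        # (Re)initialize both states on the first batch or when the batch size
        # changes (e.g. a smaller final batch).
        if hand_state is None or hand_state.shape[0] != B: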
            hand_state = torch.zeros(B, HAND_STATE_DIM, device=device)
            game_state = torch.zeros(B, GAME_STATE_DIM, device=device)

        # Reset hand state where a new hand starts
        # Keep game_state (longer-term memory)
        if new_hand.any():
            hand_state = hand_state.clone()
            hand_state[new_hand] = 0.0

        action_logits, value, next_hand_state, next_game_state = model(
            cards, bets, hand_state, game_state
        )

        policy_loss = policy_criterion(action_logits, target_action)
        value_loss = value_criterion(value, target_value)
        loss = policy_loss + 0.5 * value_loss  # tune weight as needed

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            preds = action_logits.argmax(dim=1)
            total_correct += (preds == target_action).sum().item()
            total_count += B
            total_loss += loss.item() * B
            total_policy += policy_loss.item() * B
            total_value += value_loss.item() * B

            # Feed state forward to next batch (demo pattern)
            hand_state = next_hand_state.detach()
            game_state = next_game_state.detach()

    metrics = {
        "loss": total_loss / max(total_count, 1),
        "policy_loss": total_policy / max(total_count, 1),
        "value_loss": total_value / max(total_count, 1),
        "acc": total_correct / max(total_count, 1),
    }
    return metrics

In [25]:
for epoch in range(1, EPOCHS + 1):
    metrics = train_one_epoch(
        model, train_loader, optimizer, policy_criterion, value_criterion, DEVICE
    )
    print(
        f"Epoch {epoch:02d} | "
        f"loss={metrics['loss']:.4f} | "
        f"policy={metrics['policy_loss']:.4f} | "
        f"value={metrics['value_loss']:.4f} | "
        f"acc={metrics['acc']:.3f}"
    )

Epoch 01 | loss=1.5999 | policy=1.0995 | value=1.0008 | acc=0.346
Epoch 02 | loss=1.5962 | policy=1.0978 | value=0.9967 | acc=0.348
Epoch 03 | loss=1.5758 | policy=1.0921 | value=0.9674 | acc=0.382
Epoch 04 | loss=1.5159 | policy=1.0710 | value=0.8897 | acc=0.431
Epoch 05 | loss=1.4117 | policy=1.0268 | value=0.7698 | acc=0.476
Epoch 06 | loss=1.3098 | policy=0.9687 | value=0.6823 | acc=0.531
Epoch 07 | loss=1.2098 | policy=0.8938 | value=0.6320 | acc=0.582
Epoch 08 | loss=1.1242 | policy=0.8229 | value=0.6025 | acc=0.629
Epoch 09 | loss=1.0318 | policy=0.7519 | value=0.5600 | acc=0.668
Epoch 10 | loss=0.9319 | policy=0.6698 | value=0.5242 | acc=0.720
Epoch 11 | loss=0.8440 | policy=0.5923 | value=0.5034 | acc=0.758
Epoch 12 | loss=0.7878 | policy=0.5426 | value=0.4905 | acc=0.777
Epoch 13 | loss=0.7204 | policy=0.4830 | value=0.4747 | acc=0.803
Epoch 14 | loss=0.6640 | policy=0.4326 | value=0.4629 | acc=0.827
Epoch 15 | loss=0.6424 | policy=0.4147 | value=0.4553 | acc=0.828
Epoch 16 |

## 5) Inference example

The model returns raw logits, so apply softmax *outside* the model to get action probabilities.


In [26]:
model.eval()
with torch.no_grad():
    cards = torch.randn(1, 4, 4, 13, device=DEVICE)
    bets = torch.randn(1, BETS_IN_DIM, device=DEVICE)
    hand_state = torch.zeros(1, HAND_STATE_DIM, device=DEVICE)
    game_state = torch.zeros(1, GAME_STATE_DIM, device=DEVICE)

    action_logits, value, next_hand_state, next_game_state = model(
        cards, bets, hand_state, game_state
    )
    probs = torch.softmax(action_logits, dim=1)
    action = probs.argmax(dim=1)

print("probs:", probs.cpu())
print("chosen action index:", action.item())
print("value:", value.item())

probs: tensor([[5.1270e-04, 9.9949e-01, 3.3736e-07]])
chosen action index: 1
value: -1.2747899293899536
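
Taking the `argmax` always plays the single most likely action. If a stochastic policy is wanted instead, one option is to sample from the logits with `torch.distributions.Categorical`. The logits below are made-up values standing in for `action_logits`; this is a sketch of the sampling step, not part of `PokerNet`.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Made-up logits for one state: [fold, call/check, raise-intent]
logits = torch.tensor([[0.2, 1.5, -0.3]])

greedy = logits.argmax(dim=1)      # deterministic choice
dist = Categorical(logits=logits)  # normalizes the logits internally
sampled = dist.sample()            # shape [1], index drawn from the distribution
log_prob = dist.log_prob(sampled)  # log-probability of the sampled action

print("greedy:", greedy.item())
print("sampled:", sampled.item(), "log_prob:", round(log_prob.item(), 4))
print("probs match softmax:", torch.allclose(dist.probs, torch.softmax(logits, dim=1)))
```

Sampling from logits avoids a separate softmax call, and `log_prob` is the quantity policy-gradient methods typically need if you later move beyond supervised targets.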


## Next steps for real training

Replace the synthetic dataset with your real poker pipeline and decide how to manage state:
- **hand_state reset** on new hand
- **game_state carryover** across hands (same opponent/session), or reset between sessions
- if using shuffled supervised samples, store state by episode/table/player rather than passing batch-to-batch
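
One way to store state by episode/table/player is a small dictionary-backed cache keyed by an identifier. The sketch below is an assumption about how that could look: `StateStore`, its method names, and the `(table, seat)` keys are hypothetical and not part of `PokerNet`.

```python
import torch


class StateStore:
    """Hypothetical per-key cache of detached recurrent states."""

    def __init__(self, state_dim):
        self.state_dim = state_dim
        self._states = {}

    def gather(self, keys, device="cpu"):
        # Stack one state row per key; unseen keys start from zeros.
        rows = [self._states.get(k, torch.zeros(self.state_dim)) for k in keys]
        return torch.stack(rows).to(device)

    def scatter(self, keys, states):
        # Store detached CPU copies so no autograd graph is kept between steps.
        for k, row in zip(keys, states.detach().cpu()):
            self._states[k] = row.clone()

    def reset(self, key):
        # Forget a key, e.g. when a hand or session ends.
        self._states.pop(key, None)


store = StateStore(state_dim=4)
keys = [("table1", 0), ("table1", 1)]

state = store.gather(keys)          # zeros on first use, shape [2, 4]
store.scatter(keys, state + 1.0)    # pretend the model produced an updated state
store.reset(("table1", 1))          # seat 1 starts fresh

print(store.gather(keys))           # row 0 is all ones, row 1 is all zeros
```

With a store like this, a shuffled batch would call `gather` with each sample's key before the forward pass and `scatter` with the returned `next_*_state` afterwards, using one store per state type if hand and game state have different lifetimes.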
