In [13]:
!pip install pymahjong

Collecting pymahjong
  Downloading pymahjong-1.0.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_31_x86_64.whl.metadata (2.5 kB)
Downloading pymahjong-1.0.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_31_x86_64.whl (413 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.5/413.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymahjong
Successfully installed pymahjong-1.0.4


In [74]:
# Section 1: Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
from pymahjong import MahjongEnv

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

OBS_CHANNELS = 93
HEIGHT = 34
WIDTH = 1
NUM_ACTIONS = 136


In [70]:
def encode_game_state(obs):
    """
    Converts pymahjong (93, 34) obs into CNN input format: (1, 93, 34, 1)
    """
    obs = np.array(obs, dtype=np.float32)  # (93, 34)
    return torch.tensor(obs).unsqueeze(0).unsqueeze(-1).to(device)  # (1, 93, 34, 1)


In [66]:
class MahjongCNNBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(OBS_CHANNELS, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Flatten()
        )
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        dummy = torch.zeros((1, OBS_CHANNELS, HEIGHT, WIDTH))
        return self.conv(dummy).shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscreteHead(nn.Module):
    def __init__(self, base, output_size):
        super().__init__()
        self.base = base
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.head(self.base(x))


In [75]:
class TrainingBot:
    def __init__(self, pid, models, logs):
        self.pid = pid
        self.models = models
        self.logs = logs

    def act(self, obs, valid_actions, reward=0.0):
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        logits = self.models["action"](state)  # shape: (1, N) where N = model output (e.g. 136)
        logits = logits.squeeze(0)

        # Only use logits for valid actions
        valid_logits = logits[valid_actions]
        probs = F.softmax(valid_logits, dim=0).detach().cpu().numpy()
        action_index = np.random.choice(len(valid_actions), p=probs)
        action = valid_actions[action_index]

        self.logs.append((state.squeeze(0), action, reward))
        return action


In [87]:
def train_pg(model, optimizer, log_data, gamma=1.0):
    """
    Train a policy using REINFORCE loss.
    log_data: list of (state, action, reward)
    """
    if not log_data:
        return 0.0

    states, actions, rewards = zip(*log_data)
    states = torch.stack(states).to(device)
    actions = torch.tensor(actions, dtype=torch.long).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float32).to(device)

    # Use baseline to reduce variance
    baseline = rewards.mean()
    advantages = rewards - baseline
    print("\n🔍 DEBUG TRAINING BATCH:")
    print(f"Rewards: {rewards}")
    print(f"Baseline: {baseline}")
    print(f"Advantages: {advantages}")

    for i, (s, a, r) in enumerate(log_data[:5]):
        print(f"Sample {i}: action={a}, reward={r}")

    # Forward
    logits = model(states)  # shape: (batch, num_actions)
    log_probs = F.log_softmax(logits, dim=1)
    selected_log_probs = log_probs[torch.arange(len(actions)), actions]

    # REINFORCE loss: maximize reward-weighted log-prob
    loss = -torch.mean(advantages * selected_log_probs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


In [88]:
# Init model + optimizer
base = MahjongCNNBase().to(device)
models = {
    "action": DiscreteHead(base, NUM_ACTIONS).to(device)
}
optimizers = {
    name: torch.optim.Adam(models[name].parameters(), lr=1e-4)
    for name in models
}

NUM_EPISODES = 1000
SAVE_EVERY = 100

for episode in range(1, NUM_EPISODES + 1):
    env = MahjongEnv()
    env.reset()
    logs = []

    bots = [
        TrainingBot(pid, models, logs if pid == 0 else [])
        for pid in range(4)
    ]

    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        valid = env.get_valid_actions()
        action = bots[pid].act(obs, valid)
        env.step(pid, action)

    # Get final reward for REINFORCE
    payoffs = env.get_payoffs()  # [p0, p1, p2, p3]
    for i in range(len(logs)):
        state, action, _ = logs[i]
        logs[i] = (state, action, payoffs[0])

    # Train
    loss = train_pg(models["action"], optimizers["action"], logs)
    print(f"[Episode {episode}] Reward: {payoffs[0]:.1f} | Loss: {loss:.4f}")

    if episode % SAVE_EVERY == 0:
        os.makedirs(f"checkpoints/ep_{episode}", exist_ok=True)
        torch.save(models["action"].state_dict(), f"checkpoints/ep_{episode}/action.pt")
        print(f"✓ Model saved at ep {episode}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        -1000., -1000., -1000., -1000.])
Baseline: -1000.0
Advantages: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Sample 0: action=30, reward=-1000.0
Sample 1: action=40, reward=-1000.0
Sample 2: action=9, reward=-1000.0
Sample 3: action=41, reward=-1000.0
Sample 4: action=29, reward=-1000.0
[Episode 566] Reward: -1000.0 | Loss: -0.0000

🔍 DEBUG TRAINING BATCH:
Rewards: tensor([-1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.,
        -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000., -1000.,
        -1000., -1000., -1000., -1000., -1000.])
Baseline: -1000.0
Advantages: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Sample 0: action=4, reward=-1000.0
Sample 1: action=29, reward=-1000.0
Sample 2: action=8, reward=-1000.0
Sample 3: action=27, reward=-1000.0
Sample 4: action=44, rewar