In [13]:
!pip install pymahjong

Collecting pymahjong
  Downloading pymahjong-1.0.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_31_x86_64.whl.metadata (2.5 kB)
Downloading pymahjong-1.0.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_31_x86_64.whl (413 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.5/413.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymahjong
Successfully installed pymahjong-1.0.4


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
from pymahjong import MahjongEnv
from copy import deepcopy

# Set device and seed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)

# Constants
OBS_CHANNELS = 93
HEIGHT = 34
WIDTH = 1
NUM_ACTIONS = 34


In [18]:
def encode_game_state(obs):
    """Convert pymahjong obs (93×34) into CNN input shape (C, H, W)."""
    obs = np.array(obs, dtype=np.float32)
    return torch.tensor(obs).unsqueeze(0).permute(0, 2, 1).unsqueeze(-1)  # (1, 34, 93, 1)


In [19]:
class MahjongCNNBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(OBS_CHANNELS, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Flatten()
        )
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        dummy = torch.zeros((1, OBS_CHANNELS, HEIGHT, WIDTH))
        return self.conv(dummy).shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscardHead(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, NUM_ACTIONS)
        )

    def forward(self, x):
        return self.head(self.base(x))

class BinaryHead(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        return self.head(self.base(x)).squeeze(1)


In [20]:
class TrainingBot:
    def __init__(self, player_id, models, logs):
        self.pid = player_id
        self.models = models
        self.logs = logs

    def act(self, obs, legal_actions, last_reward):
        state = encode_game_state(obs).to(device)
        outputs = {k: self.models[k](state) for k in self.models}

        # Select discard (policy gradient + mask)
        discard_logits = outputs["discard"]
        discard_probs = F.softmax(discard_logits, dim=1).detach().cpu().numpy().flatten()

        mask = np.zeros(NUM_ACTIONS)
        mask[legal_actions] = 1
        discard_probs *= mask
        discard_probs /= discard_probs.sum() if discard_probs.sum() > 0 else 1

        action = np.random.choice(NUM_ACTIONS, p=discard_probs)
        self.logs["discard"].append((state.squeeze(0), action, last_reward))
        return action


In [21]:
def train_pg(model, optimizer, log_data, mode="softmax"):
    if not log_data:
        return 0.0
    states, actions, rewards = zip(*log_data)
    states = torch.stack(states).to(device)
    actions = torch.tensor(actions).to(device)
    rewards = torch.tensor(rewards).to(device)
    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-6)

    logits = model(states)
    if mode == "softmax":
        log_probs = F.log_softmax(logits, dim=1)
        selected = log_probs[range(len(actions)), actions]
    else:
        probs = torch.sigmoid(logits)
        selected = actions * torch.log(probs + 1e-6) + (1 - actions) * torch.log(1 - probs + 1e-6)

    loss = -torch.mean(rewards * selected)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


In [22]:
def load_frozen_model(checkpoint_dir):
    base = MahjongCNNBase().to(device)
    frozen = {
        "discard": DiscardHead(base).to(device),
        "riichi": BinaryHead(base).to(device),
        "pon": BinaryHead(base).to(device),
        "chi": BinaryHead(base).to(device),
    }
    for name in frozen:
        frozen[name].load_state_dict(torch.load(f"{checkpoint_dir}/{name}.pt"))
        frozen[name].eval()
    return frozen


In [24]:
# Hyperparameters
NUM_EPISODES = 1000
BATCH_SIZE = 32
SAVE_EVERY = 100
USE_CURRICULUM = False
CURRICULUM_PATH = "checkpoints/ep_500"

# Initialize models
base = MahjongCNNBase().to(device)
models = {
    "discard": DiscardHead(base).to(device),
    "riichi": BinaryHead(base).to(device),
    "pon": BinaryHead(base).to(device),
    "chi": BinaryHead(base).to(device),
}
optimizers = {
    name: torch.optim.Adam(models[name].parameters(), lr=1e-4)
    for name in models
}

# Load frozen opponent if curriculum is on
opponent_models = (
    load_frozen_model(CURRICULUM_PATH) if USE_CURRICULUM else models
)

for episode in range(1, NUM_EPISODES + 1):
    env = MahjongEnv()
    obs = env.reset()
    logs = {name: [] for name in models}

    # Curriculum: player 0 uses live model, others use frozen copy
    bots = []
    for pid in range(4):
        model_set = models if pid == 0 else opponent_models
        bots.append(TrainingBot(pid, model_set, logs if pid == 0 else {k: [] for k in models}))

    rewards = [0.0] * 4
    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        legal = env.get_valid_actions(pid)
        action = bots[pid].act(obs, legal, rewards[pid])
        _, reward, done, _ = env.step(action)
        rewards[pid] += reward

    # Train all heads from accumulated logs
    print(f"\n[Episode {episode}]")
    for name in models:
        loss = train_pg(
            model=models[name],
            optimizer=optimizers[name],
            log_data=logs[name],
            mode="softmax" if name == "discard" else "binary"
        )
        print(f"  {name} loss: {loss:.4f}")

    # Save model checkpoint
    if episode % SAVE_EVERY == 0:
        save_dir = f"checkpoints/ep_{episode}"
        os.makedirs(save_dir, exist_ok=True)
        for name in models:
            torch.save(models[name].state_dict(), f"{save_dir}/{name}.pt")
        print(f"  ✓ Checkpoint saved to {save_dir}")


RuntimeError: Given groups=1, weight of size [64, 93, 3, 1], expected input[1, 34, 93, 1] to have 93 channels, but got 34 channels instead