In [1]:
# !pip install pymahjong

In [1]:
# Section 1: Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
from pymahjong import MahjongEnv

# Seed every RNG the notebook draws from: torch (weight init), numpy
# (np.random.choice action sampling) and Python's `random` (RandomBot).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CNN input layout (channels, height, width): the (93, 34) observation gets a
# trailing width-1 axis in encode_game_state.
OBS_CHANNELS = 93
HEIGHT = 34
WIDTH = 1
# Width of the policy head's output layer; must exceed every action id the
# environment can report as valid (logits are indexed by action id).
NUM_ACTIONS = 136


In [2]:
def encode_game_state(obs):
    """Turn one pymahjong observation into a batched CNN input.

    A (93, 34) observation comes back as a float32 tensor of shape
    (1, 93, 34, 1) — a leading batch axis and a trailing width-1 axis —
    placed on the module-level ``device``.
    """
    as_array = np.array(obs, dtype=np.float32)
    batched = torch.tensor(as_array)[None, ..., None]
    return batched.to(device)


In [3]:
class MahjongCNNBase(nn.Module):
    """Convolutional trunk shared by the policy head(s).

    Two unpadded (3, 1) convolutions run along the HEIGHT axis of an
    (OBS_CHANNELS, HEIGHT, WIDTH) input, then the result is flattened into one
    feature vector per sample. ``output_dim`` records that vector's length.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(OBS_CHANNELS, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        # Attribute name `conv` and layer order fix the checkpoint keys
        # (conv.0.*, conv.2.*); keep them stable.
        self.conv = nn.Sequential(*layers)
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        """Return the flattened feature length by pushing one zero probe through."""
        probe = torch.zeros((1, OBS_CHANNELS, HEIGHT, WIDTH))
        flattened = self.conv(probe)
        return flattened.shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscreteHead(nn.Module):
    """MLP that maps trunk features to one logit per discrete action.

    ``base`` is any module exposing ``output_dim``. Its output goes through
    Linear(output_dim, 256) -> ReLU -> Linear(256, output_size).
    """

    def __init__(self, base, output_size):
        super().__init__()
        self.base = base
        hidden_width = 256
        # Sequential indices (head.0 / head.2) are part of the checkpoint format.
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, output_size),
        )

    def forward(self, x):
        features = self.base(x)
        return self.head(features)


In [4]:
class TrainingBot:
    """Self-play agent that samples actions from the shared policy network.

    All seats can be driven by the same ``models["action"]`` network; each
    bot appends its decisions to the ``logs`` list it was given.
    """

    def __init__(self, pid, models, logs):
        self.pid = pid          # seat index
        self.models = models    # dict of networks; "action" is the policy
        self.logs = logs        # receives (state, action, reward) tuples

    def act(self, obs, valid_actions, reward=0.0):
        """Sample one legal action and log the decision.

        Args:
            obs: raw observation for this seat (as from env.get_obs()).
            valid_actions: list of legal action ids; used to index the logits.
            reward: value stored with the transition (the training loop later
                overwrites it with the final payoff).

        Returns:
            The chosen element of ``valid_actions``.
        """
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        # Sampling only needs probabilities; train_pg re-runs the forward pass
        # on the logged states, so building an autograd graph here is wasted.
        with torch.no_grad():
            logits = self.models["action"](state).squeeze(0)
            # Restrict the distribution to the legal actions before sampling.
            probs = F.softmax(logits[valid_actions], dim=0).cpu().numpy()
        action_index = np.random.choice(len(valid_actions), p=probs)
        action = valid_actions[action_index]

        self.logs.append((state.squeeze(0), action, reward))
        return action


In [5]:
def train_pg(model, optimizer, log_data, gamma=1.0, baseline=0.0):
    """Run one REINFORCE update on a batch of logged decisions.

    Args:
        model: policy network mapping a stacked batch of states to
            (batch, num_actions) logits.
        optimizer: optimizer over ``model``'s parameters.
        log_data: list of (state, action, reward) tuples. ``reward`` is used
            directly as the return for that step (the training loop stores the
            episode's final payoff on every step). States must share a shape
            so they can be stacked.
        gamma: accepted for backward compatibility but unused — rewards are
            treated as already-computed returns.
        baseline: constant subtracted from every reward to form the advantage.
            The default 0.0 reproduces plain REINFORCE.

    Returns:
        The loss as a Python float, or 0.0 (with no update) when ``log_data``
        is empty.
    """
    if not log_data:
        return 0.0

    # Build the batch on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    states, actions, rewards = zip(*log_data)
    states = torch.stack(states).to(model_device)
    actions = torch.tensor(actions, dtype=torch.long, device=model_device)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=model_device)

    advantages = rewards - baseline

    logits = model(states)  # shape: (batch, num_actions)
    log_probs = F.log_softmax(logits, dim=1)
    # gather picks log pi(a_t | s_t) per row without a separate (CPU) row index.
    selected_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

    # NOTE(review): this log-softmax is over all actions, while the bots sample
    # from a softmax restricted to the valid actions, so the gradient is not
    # exactly that of the behaviour policy. Consider masking invalid actions.
    loss = -torch.mean(advantages * selected_log_probs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


In [7]:




# Init model + optimizer: one CNN trunk feeding the action-policy head,
# each named model getting its own Adam optimizer.
base = MahjongCNNBase().to(device)
models = {"action": DiscreteHead(base, NUM_ACTIONS).to(device)}
optimizers = {
    head_name: torch.optim.Adam(head_model.parameters(), lr=1e-4)
    for head_name, head_model in models.items()
}

NUM_EPISODES = 750000
SAVE_EVERY = 25000

# Build the environment once and reset() it per episode instead of
# constructing a fresh MahjongEnv for every one of the episodes.
env = MahjongEnv()

for episode in range(1, NUM_EPISODES + 1):
    env.reset()
    logs = []

    # Self-play: all four seats share the same policy, but only seat 0's
    # decisions are recorded (the other bots log into throwaway lists).
    bots = [
        TrainingBot(pid, models, logs if pid == 0 else [])
        for pid in range(4)
    ]

    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        valid = env.get_valid_actions()
        action = bots[pid].act(obs, valid)
        env.step(pid, action)

    # Monte-Carlo credit assignment: every seat-0 decision in the hand is
    # labelled with seat 0's final payoff.
    payoffs = env.get_payoffs()  # [p0, p1, p2, p3]
    logs = [(state, action, payoffs[0]) for state, action, _ in logs]

    # Train
    loss = train_pg(models["action"], optimizers["action"], logs)
    # print(f"[Episode {episode}] Reward: {payoffs[0]:.1f} | Loss: {loss:.4f}")

    if episode % SAVE_EVERY == 0:
        checkpoint_dir = f"checkpoints/ep_{episode}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(models["action"].state_dict(), f"{checkpoint_dir}/action.pt")
        print(f"✓ Model saved at ep {episode}")


✓ Model saved at ep 25000
✓ Model saved at ep 50000
✓ Model saved at ep 75000
✓ Model saved at ep 100000
✓ Model saved at ep 125000
✓ Model saved at ep 150000
✓ Model saved at ep 175000
✓ Model saved at ep 200000
✓ Model saved at ep 225000
✓ Model saved at ep 250000
✓ Model saved at ep 275000
✓ Model saved at ep 300000
✓ Model saved at ep 325000
✓ Model saved at ep 350000
✓ Model saved at ep 375000
✓ Model saved at ep 400000
✓ Model saved at ep 425000
✓ Model saved at ep 450000
✓ Model saved at ep 475000
✓ Model saved at ep 500000
✓ Model saved at ep 525000
✓ Model saved at ep 550000
✓ Model saved at ep 575000
✓ Model saved at ep 600000
✓ Model saved at ep 625000
✓ Model saved at ep 650000
✓ Model saved at ep 675000
✓ Model saved at ep 700000
✓ Model saved at ep 725000
✓ Model saved at ep 750000


In [103]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from pymahjong import MahjongEnv

# Evaluate on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Load trained model ---
class MahjongCNNBase(torch.nn.Module):
    """Convolutional trunk; must mirror the training architecture so the
    saved state_dict (keys conv.0.*, conv.2.*) loads.

    Refers to ``torch.nn`` explicitly because this cell imports ``torch`` but
    not ``torch.nn as nn``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(93, 64, kernel_size=(3, 1), padding=0),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            torch.nn.ReLU(),
            torch.nn.Flatten()
        )
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        """Flattened feature length for a (1, 93, 34, 1) input."""
        # Shape probe only: no gradients needed.
        with torch.no_grad():
            dummy = torch.zeros((1, 93, 34, 1))
            return self.conv(dummy).shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscreteHead(torch.nn.Module):
    """MLP policy head: base(x) -> Linear(output_dim, 256) -> ReLU -> Linear(256, output_size).

    Refers to ``torch.nn`` explicitly because this cell does not import
    ``torch.nn as nn``. Attribute names ``base``/``head`` and the Sequential
    layout define the checkpoint keys, so keep them unchanged.
    """

    def __init__(self, base, output_size):
        super().__init__()
        self.base = base
        self.head = torch.nn.Sequential(
            torch.nn.Linear(base.output_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.head(self.base(x))

def encode_game_state(obs):
    """Convert an observation to a float32 tensor with a leading batch axis
    and a trailing width-1 axis, placed on the module-level ``device``."""
    as_array = np.array(obs, dtype=np.float32)
    batched = torch.tensor(as_array)[None, ..., None]
    return batched.to(device)

# --- Define agents ---
class RandomBot:
    """Baseline opponent: picks uniformly among the legal actions."""

    def __init__(self, pid):
        self.pid = pid  # seat index

    def act(self, obs, valid_actions):
        # The observation is ignored; Python's global `random` RNG decides.
        pick = random.choice(valid_actions)
        return pick

# Action ids treated as discarding a terminal (1/9) or honor tile.
# NOTE(review): assumes discard ids 0-33 follow the 34-tile order
# (0-8 man, 9-17 pin, 18-26 sou, 27-33 honors) — confirm against pymahjong.
TERMINAL_HONOR_INDICES = {0, 8, 9, 17, 18, 26, 27, 28, 29, 30, 31, 32, 33}

# Winning declarations; presumably ron (42) and tsumo (43) — confirm.
WIN_INDICES = {42, 43}


def terminal_honor_win_action(valid_actions):
    """Heuristic policy over legal action ids.

    Priority: any action in WIN_INDICES, else any action in
    TERMINAL_HONOR_INDICES, else any legal action. Within the chosen group
    the pick is uniform via ``np.random``.

    Args:
        valid_actions: non-empty list of legal action ids.

    Returns:
        The chosen action id as a Python int.
    """
    win_actions = [a for a in valid_actions if a in WIN_INDICES]
    if win_actions:
        return int(np.random.choice(win_actions))

    terminal_honor_discards = [a for a in valid_actions if a in TERMINAL_HONOR_INDICES]
    if terminal_honor_discards:
        return int(np.random.choice(terminal_honor_discards))

    return int(np.random.choice(valid_actions))


class TerminalBot:
    """Opponent driven by the terminal_honor_win_action heuristic."""

    def __init__(self, pid):
        self.pid = pid  # seat index

    def act(self, obs, valid_actions):
        # The helper defined above is terminal_honor_win_action; the name
        # terminal_honor_action does not exist on a fresh kernel.
        return terminal_honor_win_action(valid_actions)

class TrainedBot:
    """Evaluation agent that samples from a trained policy over legal actions."""

    def __init__(self, pid, model):
        self.pid = pid      # seat index
        self.model = model  # policy network; output is squeezed to one logit per action id

    def act(self, obs, valid_actions):
        """Sample from the policy's softmax restricted to ``valid_actions``."""
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        # Pure inference: skip building an autograd graph on every move.
        with torch.no_grad():
            logits = self.model(state).squeeze(0)
            probs = F.softmax(logits[valid_actions], dim=0).cpu().numpy()
        index = np.random.choice(len(valid_actions), p=probs)
        return valid_actions[index]

# --- Load model ---
# NOTE(review): the training cell saves every 25000 episodes, so it never
# writes ep_10000; this checkpoint presumably comes from an earlier run.
# Point this at an existing ep_* directory when re-running from scratch.
CHECKPOINT_PATH = "checkpoints/ep_10000/action.pt"
base = MahjongCNNBase().to(device)
model = DiscreteHead(base, 136).to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA training
# run) onto the current device, so CPU-only machines can load the checkpoint.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()

# --- Evaluation loop ---
# Trained policy in seat 0 against three TerminalBot heuristics; reports
# seat 0's mean and total payoff.
NUM_GAMES = 1000
total_payoff = 0
env = MahjongEnv()  # created once, reset() before every game

for game in range(NUM_GAMES):
    env.reset()
    bots = [TrainedBot(0, model)] + [TerminalBot(seat) for seat in (1, 2, 3)]

    # Let whichever seat is to act choose until the hand ends.
    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        valid = env.get_valid_actions()
        action = bots[pid].act(obs, valid)
        env.step(pid, action)

    payoffs = env.get_payoffs()
    total_payoff += payoffs[0]

avg = total_payoff / NUM_GAMES
print(f"\nAverage payoff over {NUM_GAMES} games: {avg:.2f}")
print(f"\n total is {total_payoff}")



Average payoff over 1000 games: -34.60

 total is -34600.0


In [99]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from pymahjong import MahjongEnv

# Evaluate on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Load trained model ---
class MahjongCNNBase(torch.nn.Module):
    """Convolutional trunk; must mirror the training architecture so the
    saved state_dict (keys conv.0.*, conv.2.*) loads.

    Refers to ``torch.nn`` explicitly because this cell imports ``torch`` but
    not ``torch.nn as nn``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(93, 64, kernel_size=(3, 1), padding=0),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            torch.nn.ReLU(),
            torch.nn.Flatten()
        )
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        """Flattened feature length for a (1, 93, 34, 1) input."""
        # Shape probe only: no gradients needed.
        with torch.no_grad():
            dummy = torch.zeros((1, 93, 34, 1))
            return self.conv(dummy).shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscreteHead(torch.nn.Module):
    """MLP policy head: base(x) -> Linear(output_dim, 256) -> ReLU -> Linear(256, output_size).

    Refers to ``torch.nn`` explicitly because this cell does not import
    ``torch.nn as nn``. Attribute names ``base``/``head`` and the Sequential
    layout define the checkpoint keys, so keep them unchanged.
    """

    def __init__(self, base, output_size):
        super().__init__()
        self.base = base
        self.head = torch.nn.Sequential(
            torch.nn.Linear(base.output_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.head(self.base(x))

def encode_game_state(obs):
    """Convert an observation to a float32 tensor with a leading batch axis
    and a trailing width-1 axis, placed on the module-level ``device``."""
    as_array = np.array(obs, dtype=np.float32)
    batched = torch.tensor(as_array)[None, ..., None]
    return batched.to(device)

# --- Define agents ---
class RandomBot:
    """Baseline opponent: picks uniformly among the legal actions."""

    def __init__(self, pid):
        self.pid = pid  # seat index

    def act(self, obs, valid_actions):
        # The observation is ignored; Python's global `random` RNG decides.
        pick = random.choice(valid_actions)
        return pick

class TrainedBot:
    """Evaluation agent that samples from a trained policy over legal actions."""

    def __init__(self, pid, model):
        self.pid = pid      # seat index
        self.model = model  # policy network; output is squeezed to one logit per action id

    def act(self, obs, valid_actions):
        """Sample from the policy's softmax restricted to ``valid_actions``."""
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        # Pure inference: skip building an autograd graph on every move.
        with torch.no_grad():
            logits = self.model(state).squeeze(0)
            probs = F.softmax(logits[valid_actions], dim=0).cpu().numpy()
        index = np.random.choice(len(valid_actions), p=probs)
        return valid_actions[index]

# --- Load model ---
# NOTE(review): the training cell saves every 25000 episodes, so it never
# writes ep_10000; this checkpoint presumably comes from an earlier run.
# Point this at an existing ep_* directory when re-running from scratch.
CHECKPOINT_PATH = "checkpoints/ep_10000/action.pt"
base = MahjongCNNBase().to(device)
model = DiscreteHead(base, 136).to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA training
# run) onto the current device, so CPU-only machines can load the checkpoint.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()

# --- Evaluation loop ---
# Debug run: trained policy in seat 0 vs three RandomBots, printing the table
# (env.render) before each step is applied.
NUM_GAMES = 1
total_payoff = 0
env = MahjongEnv()  # created once, reset() before every game

for game in range(NUM_GAMES):
    env.reset()
    bots = [TrainedBot(0, model)] + [RandomBot(seat) for seat in (1, 2, 3)]

    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        valid = env.get_valid_actions()
        action = bots[pid].act(obs, valid)
        env.render()  # ← this prints the table state
        env.step(pid, action)

    payoffs = env.get_payoffs()
    total_payoff += payoffs[0]
    print(f"Game {game + 1}: payoff = {payoffs[0]}")

avg = total_payoff / NUM_GAMES
print(f"\nAverage payoff over {NUM_GAMES} games: {avg:.2f}")


----------- current player: 0 -----------------
[Player 0]
Pt: 25000
Wind: East
Hand: 2m 2m 4m 2p 0p 3s 4s 0s 8s 1z 4z 5z 6z 3m 
River: 
Riichi: No
Menzen: Yes
[Player 1]
Pt: 25000
Wind: South
Hand: 3m 4m 4m 9m 2p 2p 4p 5s 6s 9s 2z 2z 7z 
River: 
Riichi: No
Menzen: Yes
[Player 2]
Pt: 25000
Wind: West
Hand: 5m 9m 3p 3p 7p 7p 9p 3s 4s 8s 1z 1z 6z 
River: 
Riichi: No
Menzen: Yes
[Player 3]
Pt: 25000
Wind: North
Hand: 3m 5m 7m 8m 8m 3p 4p 6p 7s 9s 4z 5z 6z 
River: 
Riichi: No
Menzen: Yes
----------- current player: 1 -----------------
[Player 0]
Pt: 25000
Wind: East
Hand: 2m 2m 3m 4m 2p 3s 4s 0s 8s 1z 4z 5z 6z 
River: 0p1h 
Riichi: No
Menzen: Yes
[Player 1]
Pt: 25000
Wind: South
Hand: 3m 4m 4m 9m 2p 2p 4p 5s 6s 9s 2z 2z 7z 1s 
River: 
Riichi: No
Menzen: Yes
[Player 2]
Pt: 25000
Wind: West
Hand: 5m 9m 3p 3p 7p 7p 9p 3s 4s 8s 1z 1z 6z 
River: 
Riichi: No
Menzen: Yes
[Player 3]
Pt: 25000
Wind: North
Hand: 3m 5m 7m 8m 8m 3p 4p 6p 7s 9s 4z 5z 6z 
River: 
Riichi: No
Menzen: Yes
----------- curre