In [1]:
!pip install pymahjong

Collecting pymahjong
  Downloading pymahjong-1.0.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_31_x86_64.whl.metadata (2.5 kB)
Downloading pymahjong-1.0.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_31_x86_64.whl (413 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.5/413.5 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pymahjong
Successfully installed pymahjong-1.0.4


In [2]:
# Section 1: Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
from pymahjong import MahjongEnv

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

OBS_CHANNELS = 93
HEIGHT = 34
WIDTH = 1
NUM_ACTIONS = 136


In [3]:
def encode_game_state(obs):
    """
    Converts pymahjong (93, 34) obs into CNN input format: (1, 93, 34, 1)
    """
    obs = np.array(obs, dtype=np.float32)  # (93, 34)
    return torch.tensor(obs).unsqueeze(0).unsqueeze(-1).to(device)  # (1, 93, 34, 1)


In [40]:
class MahjongCNNBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(OBS_CHANNELS, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Flatten()
        )
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        dummy = torch.zeros((1, OBS_CHANNELS, HEIGHT, WIDTH))
        return self.conv(dummy).shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscreteHead(nn.Module):
    def __init__(self, base, output_size):
        super().__init__()
        self.base = base
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.head(self.base(x))


In [41]:
class TrainingBot:
    def __init__(self, pid, models, logs):
        self.pid = pid
        self.models = models
        self.logs = logs

    def act(self, obs, valid_actions, reward=0.0):
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        logits = self.models["action"](state)  # shape: (1, N) where N = model output (e.g. 136)
        logits = logits.squeeze(0)

        # Only use logits for valid actions
        valid_logits = logits[valid_actions]
        probs = F.softmax(valid_logits, dim=0).detach().cpu().numpy()
        action_index = np.random.choice(len(valid_actions), p=probs)
        action = valid_actions[action_index]

        self.logs.append((state.squeeze(0), action, reward))
        return action


In [69]:
def train_pg(model, optimizer, log_data, normalize_advantage=True, reward_scale=1.0):
    """
    Train a policy using REINFORCE with optional reward scaling and advantage normalization.

    Parameters:
        model: policy model outputting logits over actions
        optimizer: optimizer for the model
        log_data: list of (state, action, reward)
        normalize_advantage: if True, normalize the reward signal to mean=0, std=1
        reward_scale: scale the reward signal to prevent exploding loss
    """
    if not log_data:
        return 0.0

    states, actions, rewards = zip(*log_data)
    states = torch.stack(states).to(device)
    actions = torch.tensor(actions, dtype=torch.long).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float32).to(device)

    # 🧠 Scale reward down to avoid huge loss values
    rewards *= reward_scale

    # 🎯 Use advantage = reward (baseline = 0 for now)
    advantages = rewards.clone()

    if normalize_advantage:
        mean = advantages.mean()
        std = advantages.std()
        if std > 1e-5:
            advantages = (advantages - mean) / (std + 1e-8)
        else:
            advantages = advantages - mean

    # 🔁 Forward pass
    logits = model(states)
    log_probs = F.log_softmax(logits, dim=1)
    selected_log_probs = log_probs[torch.arange(len(actions)), actions]

    # 🎯 Policy gradient loss
    loss = -(advantages * selected_log_probs).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 🔎 Debug: print stats
    print(f"Loss: {loss.item():.4f}, Reward mean: {rewards.mean().item():.2f}, Std: {rewards.std().item():.2f}")

    return loss.item()


In [74]:
def simulate_hanchan(bots, gamma=1.0, placement_bonus={1: 90, 2: 45, 3: 0, 4: -150}):
    from pymahjong import MahjongEnv
    import numpy as np

    env = MahjongEnv()
    logs = []
    scores = [25000, 25000, 25000, 25000]
    oya = 0
    kyoutaku = 0
    honba = 0
    wind_order = ["east", "south"]

    def determine_next_oya(current_oya, winner_id, tenpai, is_draw):
        if not is_draw and winner_id == current_oya:
            return current_oya, True
        if is_draw and tenpai[current_oya]:
            return current_oya, True
        return (current_oya + 1) % 4, False

    def get_player_placement(scores, oya):
        """
        Returns 1-based rank of player 0 using score and Tenhou-style clockwise tiebreaking.
        """
        def oya_tiebreak_key(pid): return (pid - oya) % 4
        player_scores = [(score, pid) for pid, score in enumerate(scores)]
        sorted_players = sorted(player_scores, key=lambda x: (-x[0], oya_tiebreak_key(x[1])))
        for rank, (_, pid) in enumerate(sorted_players, start=1):
            if pid == 0:
                return rank

    busted = False

    for wind in wind_order:
        for round_index in range(4):
            consecutive_oya = 0
            while consecutive_oya < 4:
                env.reset(
                    oya=oya,
                    game_wind=wind,
                    scores=scores,
                    kyoutaku=kyoutaku,
                    honba=honba
                )

                hand_logs = []
                step_index = 0
                while not env.is_over():
                    pid = env.get_curr_player_id()
                    obs = env.get_obs(pid)
                    valid = env.get_valid_actions()
                    reward = 0

                    action = bots[pid].act(obs, valid, reward)
                    env.step(pid, action)

                    if pid == 0:
                        hand_logs.append((encode_game_state(obs).squeeze(0), action, reward))
                    step_index += 1

                payoffs = env.get_payoffs()
                scores = list(np.array(scores) + np.array(payoffs))
                scores = [int(s) for s in scores]

                # 💥 Bust-out check
                if any(score < 0 for score in scores):
                    busted = True
                    break

                # Discounted reward shaping
                for i in range(len(hand_logs)):
                    s, a, r = hand_logs[i]
                    r += payoffs[0] * (gamma ** i)
                    hand_logs[i] = (s, a, r)
                logs += hand_logs

                # Next hand setup
                try:
                    winner = env.t.winner_id
                except:
                    winner = -1
                try:
                    tenpai = env.t.tenpais
                except:
                    tenpai = [False] * 4

                oya_next, oya_continues = determine_next_oya(oya, winner, tenpai, winner == -1)
                if oya_next == oya:
                    honba += 1
                else:
                    honba = 0
                    oya = oya_next
                    consecutive_oya += 1

            if busted:
                break
        if busted:
            break

    # 🎯 Final placement bonus (Tenhou style + oya tiebreaker)
    placement = get_player_placement(scores, oya)
    bonus = placement_bonus[placement]
    logs = [(s, a, r + bonus) for (s, a, r) in logs]

    return logs


In [75]:




# Init model + optimizer
base = MahjongCNNBase().to(device)
models = {
    "action": DiscreteHead(base, NUM_ACTIONS).to(device)
}
optimizers = {
    name: torch.optim.Adam(models[name].parameters(), lr=1e-4)
    for name in models
}

NUM_EPISODES = 10000
SAVE_EVERY = 100

for episode in range(1, NUM_EPISODES + 1):
    env = MahjongEnv()
    env.reset()
    logs = []

    bots = [
        TrainingBot(pid, models, logs if pid == 0 else [])
        for pid in range(4)
    ]
    logs = simulate_hanchan(bots, gamma=1.0)
    loss = train_pg(models["action"], optimizers["action"], logs)
    # while not env.is_over():
    #     pid = env.get_curr_player_id()
    #     obs = env.get_obs(pid)
    #     valid = env.get_valid_actions()
    #     action = bots[pid].act(obs, valid)
    #     env.step(pid, action)

    # # Get final reward for REINFORCE
    # payoffs = env.get_payoffs()  # [p0, p1, p2, p3]
    # for i in range(len(logs)):
    #     state, action, _ = logs[i]
    #     logs[i] = (state, action, payoffs[0])

    # Train
    # loss = train_pg(models["action"], optimizers["action"], logs)
    print(f"[Episode {episode}] | Loss: {loss:.4f}")

    if episode % SAVE_EVERY == 0:
        os.makedirs(f"checkpoints/ep_{episode}", exist_ok=True)
        torch.save(models["action"].state_dict(), f"checkpoints/ep_{episode}/action.pt")
        print(f"✓ Model saved at ep {episode}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[Episode 1217] | Loss: -0.0028
Loss: -0.0039, Reward mean: -2444.87, Std: 2688.61
[Episode 1218] | Loss: -0.0039
Loss: -0.0038, Reward mean: 573.61, Std: 1193.62
[Episode 1219] | Loss: -0.0038
Loss: -0.0138, Reward mean: 4986.86, Std: 7386.57
[Episode 1220] | Loss: -0.0138
Loss: 0.0082, Reward mean: -1790.18, Std: 2778.15
[Episode 1221] | Loss: 0.0082
Loss: -0.0057, Reward mean: -1549.48, Std: 3010.40
[Episode 1222] | Loss: -0.0057
Loss: 0.0114, Reward mean: 7673.85, Std: 9192.09
[Episode 1223] | Loss: 0.0114
Loss: 0.0043, Reward mean: -159.57, Std: 366.86
[Episode 1224] | Loss: 0.0043
Loss: -0.0057, Reward mean: -3007.63, Std: 3129.26
[Episode 1225] | Loss: -0.0057
Loss: 0.0038, Reward mean: 2090.00, Std: 1261.51
[Episode 1226] | Loss: 0.0038
Loss: -0.0026, Reward mean: -2864.52, Std: 3351.83
[Episode 1227] | Loss: -0.0026
Loss: 0.0164, Reward mean: -3449.15, Std: 2649.87
[Episode 1228] | Loss: 0.0164
Loss: -0.0174, Rewa

KeyboardInterrupt: 

In [68]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from pymahjong import MahjongEnv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load trained model ---
class MahjongCNNBase(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(93, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
            nn.ReLU(),
            nn.Flatten()
        )
        self.output_dim = self._get_output_dim()

    def _get_output_dim(self):
        dummy = torch.zeros((1, 93, 34, 1))
        return self.conv(dummy).shape[1]

    def forward(self, x):
        return self.conv(x)

class DiscreteHead(nn.Module):
    def __init__(self, base, output_size):
        super().__init__()
        self.base = base
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.head(self.base(x))

def encode_game_state(obs):
    obs = np.array(obs, dtype=np.float32)
    return torch.tensor(obs).unsqueeze(0).unsqueeze(-1).to(device)

# --- Define agents ---
class RandomBot:
    def __init__(self, pid):
        self.pid = pid

    def act(self, obs, valid_actions):
        return random.choice(valid_actions)

TERMINAL_HONOR_INDICES = set([
    0, 8, 9, 17, 18, 26, 27, 28, 29, 30, 31, 32, 33
])

WIN_INDICES = set([42,43])

def terminal_honor_win_action(valid_actions):
    # Filter for discard actions only (0–33)

    win_actions = [a for a in valid_actions if a in WIN_INDICES]
    if win_actions:
        return np.random.choice(win_actions)
    terminal_honor_discards = [a for a in valid_actions if a in TERMINAL_HONOR_INDICES]

    if terminal_honor_discards:
        return np.random.choice(terminal_honor_discards)
    else:
        return np.random.choice(valid_actions)

class TerminalBot:
    def __init__(self, pid):
        self.pid = pid

    def act(self, obs, valid_actions):
        return terminal_honor_win_action(valid_actions)

class TrainedBot:
    def __init__(self, pid, model):
        self.pid = pid
        self.model = model

    def act(self, obs, valid_actions):
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        logits = self.model(state).squeeze(0)
        valid_logits = logits[valid_actions]
        probs = F.softmax(valid_logits, dim=0).detach().cpu().numpy()
        index = np.random.choice(len(valid_actions), p=probs)
        return valid_actions[index]

# --- Load model ---
base = MahjongCNNBase().to(device)
model = DiscreteHead(base, 136).to(device)
model.load_state_dict(torch.load("./checkpoints/ep_1200/action.pt", map_location=torch.device('cpu')))
model.eval()

# --- Evaluation loop ---
NUM_GAMES = 1
total_payoff = 0
env = MahjongEnv()

for game in range(NUM_GAMES):
    env.reset()
    bots = [
        TrainingBot(0, model),
        TerminalBot(1),
        TerminalBot(2),
        TerminalBot(3)
    ]
    simulate_hanchan(bots=bots)
    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        valid = env.get_valid_actions()
        action = bots[pid].act(obs, valid)
        env.step(pid, action)

    payoffs = env.get_payoffs()
    total_payoff += payoffs[0]
    print(f"Game {game + 1}: payoff = {payoffs[0]}")

avg = total_payoff / NUM_GAMES
print(f"\nAverage payoff over {NUM_GAMES} games: {avg:.2f}")
print(f"\n total is {total_payoff}")


TypeError: TrainingBot.__init__() missing 1 required positional argument: 'logs'

In [82]:
import torch
import torch.nn.functional as F
import numpy as np
import random
from pymahjong import MahjongEnv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # --- Load trained model ---
# class MahjongCNNBase(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(93, 64, kernel_size=(3, 1), padding=0),
#             nn.ReLU(),
#             nn.Conv2d(64, 64, kernel_size=(3, 1), padding=0),
#             nn.ReLU(),
#             nn.Flatten()
#         )
#         self.output_dim = self._get_output_dim()

#     def _get_output_dim(self):
#         dummy = torch.zeros((1, 93, 34, 1))
#         return self.conv(dummy).shape[1]

#     def forward(self, x):
#         return self.conv(x)

# class DiscreteHead(nn.Module):
#     def __init__(self, base, output_size):
#         super().__init__()
#         self.base = base
#         self.head = nn.Sequential(
#             nn.Linear(base.output_dim, 256),
#             nn.ReLU(),
#             nn.Linear(256, output_size)
#         )

#     def forward(self, x):
#         return self.head(self.base(x))

# def encode_game_state(obs):
#     obs = np.array(obs, dtype=np.float32)
#     return torch.tensor(obs).unsqueeze(0).unsqueeze(-1).to(device)

# # --- Define agents ---
# class RandomBot:
#     def __init__(self, pid):
#         self.pid = pid

#     def act(self, obs, valid_actions):
#         return random.choice(valid_actions)

class TrainedBot:
    def __init__(self, pid, model):
        self.pid = pid
        self.model = model

    def act(self, obs, valid_actions):
        state = encode_game_state(obs)  # shape: (1, 93, 34, 1)
        logits = self.model(state).squeeze(0)
        valid_logits = logits[valid_actions]
        probs = F.softmax(valid_logits, dim=0).detach().cpu().numpy()
        index = np.random.choice(len(valid_actions), p=probs)
        return valid_actions[index]

# --- Load model ---
base = MahjongCNNBase().to(device)
model = DiscreteHead(base, 136).to(device)
model.load_state_dict(torch.load("./checkpoints/ep_3700/action.pt", map_location=torch.device("cpu")))
model.eval()

# --- Evaluation loop ---
NUM_GAMES = 1000
total_payoff = 0
env = MahjongEnv()

for game in range(NUM_GAMES):
    env.reset()
    bots = [
        TrainedBot(0, model),
        TerminalBot(1),
        TerminalBot(2),
        TerminalBot(3)
    ]
    #play a hanchan

    while not env.is_over():
        pid = env.get_curr_player_id()
        obs = env.get_obs(pid)
        valid = env.get_valid_actions()
        action = bots[pid].act(obs, valid)



        # env.render()  # ← this prints the table state
        env.step(pid, action)

    payoffs = env.get_payoffs()
    total_payoff += payoffs[0]
    print(f"Game {game + 1}: payoff = {payoffs[0]}")

avg = total_payoff / NUM_GAMES
print(f"\nAverage payoff over {NUM_GAMES} games: {avg:.2f}")


Game 1: payoff = 0.0
Game 2: payoff = 0.0
Game 3: payoff = 0.0
Game 4: payoff = -1500.0
Game 5: payoff = 0.0
Game 6: payoff = 0.0
Game 7: payoff = 0.0
Game 8: payoff = -1000.0
Game 9: payoff = 0.0
Game 10: payoff = 0.0
Game 11: payoff = 0.0
Game 12: payoff = -3900.0
Game 13: payoff = 0.0
Game 14: payoff = -1000.0
Game 15: payoff = 0.0
Game 16: payoff = 0.0
Game 17: payoff = 0.0
Game 18: payoff = -1000.0
Game 19: payoff = -1000.0
Game 20: payoff = 0.0
Game 21: payoff = -1500.0
Game 22: payoff = -1000.0
Game 23: payoff = 0.0
Game 24: payoff = 0.0
Game 25: payoff = 0.0
Game 26: payoff = 0.0
Game 27: payoff = 0.0
Game 28: payoff = -1000.0
Game 29: payoff = 0.0
Game 30: payoff = 0.0
Game 31: payoff = -1000.0
Game 32: payoff = -1500.0
Game 33: payoff = -1500.0
Game 34: payoff = -1500.0
Game 35: payoff = -1500.0
Game 36: payoff = -1000.0
Game 37: payoff = 0.0
Game 38: payoff = -1000.0
Game 39: payoff = 0.0
Game 40: payoff = -1000.0
Game 41: payoff = -1500.0
Game 42: payoff = -1000.0
Game 43: 