In [None]:
!pip3 install pymahjong

In [None]:
import pymahjong
import numpy as np

In [None]:
# Action ids the TH / THW heuristics prefer: the 1 and 9 of each numbered suit
# plus the seven honor tiles.
# NOTE(review): assumes pymahjong numbers its discard actions 0-33 by tile kind in the
# order 9 man, 9 pin, 9 sou, then 7 honors — confirm against the pymahjong action table.
TERMINAL_HONOR_INDICES = {
    suit_start + rank_offset
    for suit_start in (0, 9, 18)   # first id of each numbered suit
    for rank_offset in (0, 8)      # ranks 1 and 9
} | set(range(27, 34))             # honors

# Action ids the THW agent always takes when available.
# NOTE(review): presumably pymahjong's win declarations (ron / tsumo) — confirm.
WIN_INDICES = {42, 43}

class THW:
  """Heuristic agent: win if possible, else prefer terminal/honor actions, else anything.

  Priority order, each choice made uniformly at random via NumPy's global RNG:
    1. any valid action whose id is in WIN_INDICES;
    2. any valid action whose id is in TERMINAL_HONOR_INDICES;
    3. any valid action.
  The observation argument is accepted for interface compatibility and ignored.
  """

  def select_action(self, obs_np, valid_actions):
    winning = [act for act in valid_actions if act in WIN_INDICES]
    if winning:
      return np.random.choice(winning)
    # No win available: fall back to the terminal/honor preference, then to any action.
    preferred = [act for act in valid_actions if act in TERMINAL_HONOR_INDICES]
    return np.random.choice(preferred or valid_actions)

class TH:
  """Heuristic agent that prefers actions whose ids are in TERMINAL_HONOR_INDICES.

  Picks uniformly at random (NumPy global RNG) among the preferred valid actions, or
  among all valid actions when none is preferred. The observation is ignored.
  """

  def select_action(self, obs_np, valid_actions):
    preferred = list(filter(TERMINAL_HONOR_INDICES.__contains__, valid_actions))
    pool = preferred if preferred else valid_actions
    return np.random.choice(pool)

class RAND:
  """Baseline agent: ignores the observation and picks a valid action uniformly at random."""

  def select_action(self, obs_np, valid_actions):
    # Draws from NumPy's global RNG, so np.random.seed(...) makes runs reproducible.
    choice = np.random.choice(valid_actions)
    return choice


In [None]:
def run_game(agents, num_rounds=1000):
    """Play ``num_rounds`` independent games and report player 0's cumulative payoff.

    Args:
        agents: sequence indexed by player id; ``agents[pid].select_action(obs, valid_actions)``
            is called for whichever player is to act. It may return either a bare action id
            or a tuple whose first element is the action id (``CNNAgent.select_action``
            returns ``(action, log_prob)``).
        num_rounds: number of games to play; a fresh ``pymahjong.MahjongEnv`` is built per game.

    Returns:
        The sum of ``payoffs[0]`` over all games (it is also printed).
    """
    reward_sum = 0
    for _ in range(num_rounds):
        env = pymahjong.MahjongEnv()
        env.reset()
        while True:
            curr_pid = env.get_curr_player_id()
            valid_actions = env.get_valid_actions()  # e.g., [0, 3, 4, 20, 21]
            executor_obs = env.get_obs(curr_pid)
            action = agents[curr_pid].select_action(executor_obs, valid_actions)
            if isinstance(action, tuple):
                # Learning agents also return bookkeeping (e.g. a log-prob);
                # env.step only takes the action id.
                action = action[0]

            env.step(curr_pid, action)

            if env.is_over():
                payoffs = env.get_payoffs()  # payoffs = [p0, p1, p2, p3]
                reward_sum += payoffs[0]
                break
    print("total payoff = {} after {} rounds".format(reward_sum, num_rounds))
    return reward_sum

In [None]:
# Baseline comparison: all four seats are played by the same heuristic agent,
# and run_game prints player 0's total payoff over the games.
for label, agent_cls in (
    ("terminal / honor discard agent", TH),
    ("terminal / honor / win discard agent", THW),
    ("random agent", RAND),
):
    print(label)
    run_game(agents=[agent_cls()] * 4, num_rounds=1000)

In [None]:
# Training a CNN policy with REINFORCE (player 0 vs. three random opponents)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MahjongCNNPolicy(nn.Module):
    """Convolutional policy network mapping an observation grid to action logits.

    Two 3x3 convolutions (padding=1, so height/width are preserved) feed a
    two-layer fully connected head that outputs one unnormalized logit per action.

    Args:
        obs_rows: height of the observation grid (default 93, the executor
            observation size this notebook feeds in).
        obs_cols: width of the observation grid (default 34).
        num_actions: number of output logits (default 47).
        in_channels: number of input channels (default 1).
        hidden_dim: width of the hidden fully connected layer (default 512).

    The defaults reproduce the original fixed architecture, so saved state_dicts
    still load.
    """

    def __init__(self, obs_rows=93, obs_cols=34, num_actions=47, in_channels=1, hidden_dim=512):
        super().__init__()
        self.num_actions = num_actions
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Spatial size is unchanged by the convs, so the flattened width is 64 * rows * cols.
        # NOTE(review): with the defaults this layer alone holds 64*93*34*512 ≈ 103.6M
        # weights; consider pooling before the head if memory or speed becomes a problem.
        self.fc1 = nn.Linear(64 * obs_rows * obs_cols, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_actions)

    def forward(self, x):
        """Map x of shape [batch, in_channels, obs_rows, obs_cols] to logits [batch, num_actions]."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)  # logits

In [None]:
class CNNAgent:
    def __init__(self, model, device="cpu"):
        self.model = model.to(device)
        self.device = device

    def select_action(self, obs_np, valid_actions, training=True):
        # Convert 93×34 obs to torch tensor [1, 1, 93, 34]
        obs_tensor = torch.tensor(obs_np, dtype=torch.float32, device=self.device).unsqueeze(0).unsqueeze(0)

        if not training:
          with torch.no_grad():
            logits = self.model(obs_tensor)[0]
        else:
            logits = self.model(obs_tensor)[0]  # no torch.no_grad() here!

        valid_actions = [a for a in valid_actions if 0 <= a < 47]


        logits = logits[0]  # [47]

        # Mask invalid actions
        mask = torch.full((47,), float('-inf'), device=self.device)
        mask[valid_actions] = 0  # allow valid actions only

        masked_logits = logits + mask
        probs = F.softmax(masked_logits, dim=0)

        dist = torch.distributions.Categorical(probs)
        action = dist.sample()

        return action.item(), dist.log_prob(action)

In [None]:
def _play_training_episode(agent, verbose=False):
    """Play one game with ``agent`` as player 0 against three uniform-random players.

    Returns:
        ``(log_probs, reward)``: the log-prob tensors of every action the agent
        sampled, in order and with autograd history, plus player 0's final payoff.
    """
    env = pymahjong.MahjongEnv()
    env.reset()
    log_probs = []
    while True:
        pid = env.get_curr_player_id()
        valid_actions = env.get_valid_actions()
        executor_obs = env.get_obs(pid)

        if pid == 0:  # the learning CNN agent
            action, log_prob = agent.select_action(executor_obs, valid_actions, training=True)
            if verbose:
                print(f"cnn turn action: {action}")
            log_probs.append(log_prob)
        else:  # 3 other random agents
            action = np.random.choice(valid_actions)

        env.step(pid, action)

        if env.is_over():
            payoffs = env.get_payoffs()  # payoffs = [p0, p1, p2, p3]
            return log_probs, payoffs[0]


def train_cnn_agent_w_randoms(num_games=1000, learning_rate=1e-4,
                              save_path="trained_cnn_rand_model.pth", verbose=False):
    """Train a MahjongCNNPolicy as player 0 against three random players with REINFORCE.

    Each game is one episode. Its return is player 0's final payoff (undiscounted, no
    baseline), and the loss is ``-return * sum(log_probs of the CNN's actions)``.

    Args:
        num_games: number of training episodes.
        learning_rate: Adam learning rate.
        save_path: file the trained ``state_dict`` is written to after training.
        verbose: if True, also print every CNN action and per-episode debug values.

    Returns:
        The trained MahjongCNNPolicy.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    model = MahjongCNNPolicy()
    agent = CNNAgent(model, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for episode in range(num_games):
        log_probs, reward = _play_training_episode(agent, verbose=verbose)

        if not log_probs:
            # Player 0 never acted, so there is nothing to differentiate (the original
            # code crashed here on log_probs[0] / int.backward()).
            print(f"Episode [{episode}] reward: {reward:.2f}, skipped (no CNN actions)")
            continue

        # REINFORCE: minimize -R * sum_t log pi(a_t | s_t). float() keeps the product a tensor
        # even when the env returns a NumPy scalar.
        loss = -float(reward) * torch.stack(log_probs).sum()
        if verbose:
            print(f"raw reward {reward}. raw log prob[0]: {log_probs[0].item()}. loss: {loss}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Episode [{episode}] reward: {reward:.2f}, loss: {loss.item():.2f}")

    # Save the model
    torch.save(model.state_dict(), save_path)
    print("Training complete and model saved.")

    return model

In [None]:
# Train the CNN policy as player 0 against three random opponents. The training
# function also writes the learned weights to disk.
model = train_cnn_agent_w_randoms(num_games=1000, learning_rate=1e-4)
# To evaluate, wrap the network in a CNNAgent first (the nn.Module itself has no
# select_action). CNNAgent.select_action returns (action, log_prob), so make sure
# run_game passes only the action id to env.step:
# run_game(agents=[CNNAgent(model), RAND(), RAND(), RAND()], num_rounds=1000)