In [1]:
!pip install mjai

Collecting mjai
  Downloading mjai-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting mahjong~=1.2.0 (from mjai)
  Downloading mahjong-1.2.1-py3-none-any.whl.metadata (1.6 kB)
Collecting loguru (from mjai)
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Downloading mjai-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading mahjong-1.2.1-py3-none-any.whl (60 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.9/60.9 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading loguru-0.7.3-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mahjong, loguru, mjai
Successfully installed loguru-0.7.3 mahjong-1.2.1 mjai-0.2.1


In [5]:
# Section 1: Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mjai import Bot, Simulator

NUM_PLAYERS = 4
HEIGHT = 34      # tile types: 9 man + 9 pin + 9 sou + 7 honours
WIDTH = 4        # copies of each tile type
PLANES = 37      # feature planes: 24 for the current round + 13 for the previous one

# Channels-first, matching the tensor built by encode_mahjong_state and the
# (N, C, H, W) layout that nn.Conv2d consumes.
INPUT_SHAPE = (PLANES, HEIGHT, WIDTH)  # (37, 34, 4)



In [6]:
# Section 2: Tile mapping — converts tile strings to tile-type indices (0-33)
def tile_to_index(tile_str):
    """
    Map a tile string to its tile-type index in 0-33.

    Accepted notations:
    - Suited tiles "1m"-"9m", "1p"-"9p", "1s"-"9s" -> 0-8, 9-17, 18-26.
    - Red fives, MJAI style ("5mr", "5pr", "5sr") or Tenhou style
      ("0m", "0p", "0s"), map to the ordinary five of their suit; redness is
      encoded separately by the caller.
    - Honours as "1z"-"7z" -> 27-33, or by letter: winds "E", "S", "W", "N"
      -> 27-30 and dragons "P", "F", "C" -> 31-33.

    Raises:
    - ValueError for anything else (empty string, unknown suit, out-of-range
      rank such as "8z", the MJAI hidden tile "?").
    """
    suit_offsets = {'m': 0, 'p': 9, 's': 18, 'z': 27}
    honor_letters = {'E': 27, 'S': 28, 'W': 29, 'N': 30, 'P': 31, 'F': 32, 'C': 33}

    if not tile_str:
        raise ValueError(f"Invalid tile string: {tile_str}")

    if tile_str[0] in honor_letters:
        return honor_letters[tile_str[0]]

    # Strip the MJAI red-five marker ("5mr" -> "5m") before parsing rank/suit.
    is_mjai_red = tile_str.endswith('r')
    core = tile_str[:-1] if is_mjai_red else tile_str

    if len(core) == 2 and core[0] in '0123456789' and core[1] in suit_offsets:
        rank, suit = int(core[0]), core[1]
        if rank == 0 and suit != 'z' and not is_mjai_red:
            rank = 5  # Tenhou-style red five
        max_rank = 7 if suit == 'z' else 9
        # The 'r' marker is only meaningful on a suited five.
        red_marker_ok = not is_mjai_red or (rank == 5 and suit != 'z')
        if red_marker_ok and 1 <= rank <= max_rank:
            return suit_offsets[suit] + rank - 1

    raise ValueError(f"Invalid tile string: {tile_str}")

# Main encoding function
def _place_tiles(state_tensor, tiles, plane):
    """Mark each tile of `tiles` on `plane`, one copy slot per occurrence.

    Copy slots along the last axis are filled left to right; occurrences beyond
    WIDTH for the same tile type are ignored.
    """
    for tile in tiles:
        row = state_tensor[plane, tile_to_index(tile)]  # view into state_tensor
        for slot in range(WIDTH):
            if row[slot] == 0:
                row[slot] = 1
                break


def _encode_round(state_tensor, snapshot, player_id, hand_plane, discards_start,
                  melds_start, dora_plane, riichi_start):
    """Write hand, discard, meld, dora and opponent-riichi planes for one round.

    `snapshot` is read through the attributes ``hands``, ``discards``,
    ``melds`` (per-seat lists of dicts with a 'tiles' key), ``dora_markers``
    and ``riichi_declarations`` (queried with ``.get(seat, False)``).
    """
    _place_tiles(state_tensor, snapshot.hands[player_id], hand_plane)

    for pid in range(NUM_PLAYERS):
        _place_tiles(state_tensor, snapshot.discards[pid], discards_start + pid)
        meld_tiles = [tile for meld in snapshot.melds[pid] for tile in meld['tiles']]
        _place_tiles(state_tensor, meld_tiles, melds_start + pid)

    _place_tiles(state_tensor, snapshot.dora_markers, dora_plane)

    # Only three riichi planes exist (one per opponent), so index them by seat
    # offset from player_id: seat +1 -> 0, +2 -> 1, +3 -> 2. Indexing by the
    # absolute seat overflowed into the neighbouring plane block.
    for pid in range(NUM_PLAYERS):
        if pid != player_id and snapshot.riichi_declarations.get(pid, False):
            opponent_slot = (pid - player_id) % NUM_PLAYERS - 1
            state_tensor[riichi_start + opponent_slot] = 1


def encode_mahjong_state(game_state, player_id):
    """
    Encode the game state seen by `player_id` into a float32 tensor of shape
    (PLANES, HEIGHT, WIDTH) = (37, 34, 4).

    HEIGHT indexes the tile type (see tile_to_index), WIDTH the copy slot.
    Plane layout:
      0       own hand
      1       own red fives
      2-5     discards of seats 0-3
      6-9     meld tiles of seats 0-3
      10      dora indicators
      11-13   riichi flags of the opponents at seat offsets +1, +2, +3
      14-21   kyoku, honba, kyotaku, oya, then each score, all divided by 100000
              (assumes game_state.scores has 4 entries — TODO confirm)
      22      round wind index / 3 (E, S, W, N -> 0, 1/3, 2/3, 1)
      23      own seat wind index / 3, counted from the dealer (oya)
      24-36   previous round from ``game_state.past_round``: hand (24),
              discards (25-28), melds (29-32), dora (33), riichi (34-36);
              left at zero when there is no past round.

    Parameters:
    - game_state: the current game state object (attribute access as above).
    - player_id: the seat (0-3) for whom the encoding is done.

    Returns:
    - A PyTorch tensor of shape (37, 34, 4).
    """
    state_tensor = torch.zeros((PLANES, HEIGHT, WIDTH), dtype=torch.float32)

    plane_idx = {
        'own_hand': 0,
        'aka_five': 1,
        'discards_start': 2,
        'melds_start': 6,
        'dora': 10,
        'riichi_flags_start': 11,
        'kyoku_info_start': 14,
        'round_wind': 22,
        'own_wind': 23,
        'past_hand': 24,
        'past_discards_start': 25,
        'past_melds_start': 29,
        'past_dora': 33,
        'past_riichi_flags_start': 34,
    }

    # Current round: hand, discards, melds, dora, opponent riichi.
    _encode_round(
        state_tensor, game_state, player_id,
        hand_plane=plane_idx['own_hand'],
        discards_start=plane_idx['discards_start'],
        melds_start=plane_idx['melds_start'],
        dora_plane=plane_idx['dora'],
        riichi_start=plane_idx['riichi_flags_start'],
    )

    # Red fives: MJAI writes them as "5mr", Tenhou as "0m".
    aka_tiles = [
        tile for tile in game_state.hands[player_id]
        if tile.endswith('r') or tile.startswith('0')
    ]
    _place_tiles(state_tensor, aka_tiles, plane_idx['aka_five'])

    kyoku_info = [
        game_state.kyoku,    # current round number
        game_state.honba,    # honba sticks
        game_state.kyotaku,  # riichi sticks on the table
        game_state.oya,      # dealer seat
        *game_state.scores,  # scores of all players
    ]
    for i, info in enumerate(kyoku_info):
        state_tensor[plane_idx['kyoku_info_start'] + i] = info / 100000

    wind_map = {'E': 0, 'S': 1, 'W': 2, 'N': 3}
    round_wind = wind_map.get(game_state.bakaze, 0)
    state_tensor[plane_idx['round_wind']] = round_wind / 3

    # The dealer always sits East, so the seat wind is the distance from oya,
    # independent of the round wind.
    seat_wind = (player_id - game_state.oya) % NUM_PLAYERS
    state_tensor[plane_idx['own_wind']] = seat_wind / 3

    past = getattr(game_state, 'past_round', None)
    if past is not None:
        _encode_round(
            state_tensor, past, player_id,
            hand_plane=plane_idx['past_hand'],
            discards_start=plane_idx['past_discards_start'],
            melds_start=plane_idx['past_melds_start'],
            dora_plane=plane_idx['past_dora'],
            riichi_start=plane_idx['past_riichi_flags_start'],
        )

    return state_tensor



In [7]:
# Section 3: Shared CNN Backbone
class MahjongCNNBase(nn.Module):
    """Shared convolutional feature extractor used by every action head.

    Maps a batch shaped (N, in_planes, height, width) to flat features of
    length ``self.output_dim``. Each layer is Conv2d -> BatchNorm2d -> ReLU ->
    Dropout, with no padding, followed by a final Flatten.

    Parameters (all optional; defaults reproduce the original network):
    - in_planes, height, width: input size; None means PLANES, HEIGHT, WIDTH.
    - channels: conv output channels per layer.
    - num_layers: number of conv blocks.
    - kernel_size: conv kernel size.
    - dropout: dropout probability after each block.
    """

    def __init__(self, in_planes=None, height=None, width=None,
                 channels=100, num_layers=3, kernel_size=(5, 2), dropout=0.5):
        super().__init__()
        # Resolved at construction time so the module-level constants are
        # only needed when the defaults are actually used.
        self.in_planes = PLANES if in_planes is None else in_planes
        self.height = HEIGHT if height is None else height
        self.width = WIDTH if width is None else width

        layers = []
        in_channels = self.in_planes
        for _ in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels, kernel_size=kernel_size, padding=0),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_channels = channels
        layers.append(nn.Flatten())
        # Same layer order as before, so state_dict keys (conv.0 ... conv.12)
        # remain compatible with existing checkpoints.
        self.conv = nn.Sequential(*layers)
        self.output_dim = self._get_flattened_dim()

    def _get_flattened_dim(self):
        """Return the flattened feature size by probing with a zero input.

        The probe runs in eval mode so BatchNorm neither updates its running
        statistics from the dummy batch nor rejects a 1-sample batch; the
        previous training/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros((1, self.in_planes, self.height, self.width))
                out = self.conv(probe)
        finally:
            self.train(was_training)
        return out.shape[1]

    def forward(self, x):
        return self.conv(x)

In [8]:
# Section 4: Action-Specific Heads

class _ActionHead(nn.Module):
    """Two-layer MLP on top of a shared feature extractor.

    Parameters:
    - base: module exposing ``output_dim`` and returning (N, output_dim) features.
    - num_outputs: size of the final logit layer.
    - hidden: width of the hidden layer.
    """

    def __init__(self, base, num_outputs, hidden=300):
        super().__init__()
        self.base = base
        # Attribute names and layer order keep state_dict keys
        # (head.0.*, head.2.*) identical to the original per-head classes.
        self.head = nn.Sequential(
            nn.Linear(base.output_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, x):
        return self.head(self.base(x))


class DiscardHead(_ActionHead):
    """One logit per tile type (34) for choosing the discard."""

    def __init__(self, base, hidden=300):
        super().__init__(base, num_outputs=34, hidden=hidden)


class PonHead(_ActionHead):
    """Single logit: pon when sigmoid(logit) > 0.5."""

    def __init__(self, base, hidden=300):
        super().__init__(base, num_outputs=1, hidden=hidden)


class ChiHead(_ActionHead):
    """Four logits: 0 = pass, 1-3 = chi types."""

    def __init__(self, base, hidden=300):
        super().__init__(base, num_outputs=4, hidden=hidden)


class RiichiHead(_ActionHead):
    """Two logits: index 0 = do not declare, index 1 = declare riichi."""

    def __init__(self, base, hidden=300):
        super().__init__(base, num_outputs=2, hidden=hidden)


In [9]:
# Section 5: Modular MJAI Bot
class CNNBot(Bot):
    """Greedy bot choosing among legal actions with the per-action CNN heads.

    NOTE(review): assumes the ``Bot`` base provides ``self.game_state`` and
    ``self.get_legal_actions()``; confirm against the installed mjai version.
    """

    def __init__(self, player_id, models):
        super().__init__(player_id)
        # Kept explicitly because think() needs the seat to encode the state.
        self.player_id = player_id
        self.models = models  # dict of action heads: "discard", "pon", "chi", "riichi"

    def _infer(self, name, state):
        """Run head `name` in eval mode without autograd, then restore its mode.

        Eval mode keeps dropout and BatchNorm deterministic and prevents
        batch-of-one inference from updating BatchNorm running statistics.
        """
        model = self.models[name]
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                return model(state)
        finally:
            model.train(was_training)

    def _choose_discard(self, logits):
        """Return (tile string, tile index) of the best-scoring tile in hand.

        Only tiles actually held are considered, and the hand's own string is
        returned (so a red five keeps its notation). Raises ValueError if the
        hand is empty.
        """
        hand = self.game_state.hands[self.player_id]
        tile = max(hand, key=lambda t: logits[0, tile_to_index(t)].item())
        return tile, tile_to_index(tile)

    def think(self):
        """Return the highest-scoring legal action as a dict.

        Every candidate is scored by the probability its head assigns to it
        (softmax for multi-logit heads, sigmoid for pon), so scores from
        different heads are on the same [0, 1] scale when compared.
        Returns {"type": "none"} when no head proposes an action.
        """
        state = encode_mahjong_state(self.game_state, self.player_id).unsqueeze(0)
        legal = self.get_legal_actions()
        candidates = []  # (action dict, probability) pairs

        if "dahai" in legal:
            logits = self._infer("discard", state)
            tile, index = self._choose_discard(logits)
            prob = F.softmax(logits, dim=1)[0, index].item()
            # NOTE(review): tsumogiri should be True only when discarding the
            # tile just drawn; that information is not available here.
            candidates.append(({"type": "dahai", "pai": tile, "tsumogiri": True}, prob))

        if "pon" in legal:
            prob = torch.sigmoid(self._infer("pon", state)).item()
            if prob > 0.5:
                candidates.append(({"type": "pon"}, prob))

        if "chi" in legal:
            probs = F.softmax(self._infer("chi", state), dim=1)
            chi_type = torch.argmax(probs).item()
            if chi_type > 0:
                candidates.append(({"type": "chi", "option": chi_type}, probs[0, chi_type].item()))

        if "riichi" in legal:
            probs = F.softmax(self._infer("riichi", state), dim=1)
            if torch.argmax(probs).item() == 1:
                candidates.append(({"type": "riichi"}, probs[0, 1].item()))

        if candidates:
            return max(candidates, key=lambda c: c[1])[0]

        return {"type": "none"}



In [12]:
# Section 6: Self-Play Training Loop
import random
from collections import defaultdict

# Hyperparameters
NUM_EPISODES = 1000
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
SEED = 0

# Seed every RNG the training code touches (batch sampling, weight init,
# dropout) so a Restart & Run All reproduces the same run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Shared CNN base
base = MahjongCNNBase()

# Modular action heads; all four wrap the same `base` instance.
models = {
    "discard": DiscardHead(base),
    "pon": PonHead(base),
    "chi": ChiHead(base),
    "riichi": RiichiHead(base)
}

# One optimizer per head.
# NOTE(review): each head's parameters() includes the shared base, so the base
# is updated by four independent Adam states; use a single optimizer over the
# union of parameters if that is not intended.
optimizers = {
    name: torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    for name, model in models.items()
}

# Experience buffer per head, keyed exactly like `models`.
experience = {name: [] for name in models}

# Training functions
def train_discard(samples):
    """Take one cross-entropy step on the discard head.

    samples: sequence of (state tensor, tile index) pairs.
    Returns the batch loss as a Python float.
    """
    batch_states, batch_targets = zip(*samples)
    inputs = torch.stack(batch_states)
    targets = torch.tensor(batch_targets)
    head, opt = models["discard"], optimizers["discard"]
    loss = F.cross_entropy(head(inputs), targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

def train_binary(name, samples):
    """Take one binary-cross-entropy step on the yes/no head `name`.

    samples: sequence of (state tensor, 0/1 label) pairs.

    Works for both single-logit heads (PonHead) and two-logit heads
    (RiichiHead). For the latter the BCE target is applied to
    ``logit[1] - logit[0]``; sigmoid of that difference equals the softmax
    probability of class 1, so the loss is exactly 2-class cross-entropy.
    Passing a (B, 2) output straight into BCE against (B, 1) targets fails
    with a shape mismatch.

    Returns the batch loss as a Python float.
    """
    states, labels = zip(*samples)
    states = torch.stack(states)
    labels = torch.tensor(labels).float().unsqueeze(1)
    logits = models[name](states)
    if logits.shape[1] == 2:
        logits = (logits[:, 1] - logits[:, 0]).unsqueeze(1)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    optimizers[name].zero_grad()
    loss.backward()
    optimizers[name].step()
    return loss.item()

def train_chi(samples):
    """Take one cross-entropy step on the chi head.

    samples: sequence of (state tensor, class index 0-3) pairs.
    Returns the batch loss as a Python float.
    """
    batch_states, batch_classes = zip(*samples)
    inputs = torch.stack(batch_states)
    targets = torch.tensor(batch_classes)
    head, opt = models["chi"], optimizers["chi"]
    loss = F.cross_entropy(head(inputs), targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

# Custom bot that logs its own decisions
class TrainingCNNBot(CNNBot):
    """CNNBot variant that records each decision it makes in `experience`.

    Every entry appended is ``(state, label)``. ``state`` is the encoded
    tensor, and ``label`` is the choice the head made: the tile index for
    "discard", 0/1 for "pon" and "riichi", and 0-3 for "chi".

    NOTE(review): the labels are the heads' own greedy outputs, so training on
    them only reinforces the current policy; a reward signal (e.g. final
    placement) is needed for self-play to improve play.
    """

    def __init__(self, player_id, models, experience):
        super().__init__(player_id, models)
        # Kept explicitly because think() needs the seat to encode the state.
        self.player_id = player_id
        self.experience = experience

    def _infer(self, name, state):
        """Run head `name` in eval mode without autograd, then restore its mode."""
        model = self.models[name]
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                return model(state)
        finally:
            model.train(was_training)

    def _choose_discard(self, logits):
        """Return (tile string, tile index) of the best-scoring tile in hand."""
        hand = self.game_state.hands[self.player_id]
        tile = max(hand, key=lambda t: logits[0, tile_to_index(t)].item())
        return tile, tile_to_index(tile)

    def think(self):
        """Pick an action, record it, and return it as a dict."""
        state = encode_mahjong_state(self.game_state, self.player_id).unsqueeze(0)
        sample = state.squeeze(0)
        legal = self.get_legal_actions()

        # Riichi must be considered before the discard. If "dahai" were
        # checked first, it would return early whenever both are legal, and
        # riichi could never be declared.
        if "riichi" in legal:
            declare = self._infer("riichi", state).argmax().item() == 1
            self.experience["riichi"].append((sample, int(declare)))
            if declare:
                return {"type": "riichi"}

        if "dahai" in legal:
            tile, index = self._choose_discard(self._infer("discard", state))
            self.experience["discard"].append((sample, index))
            # NOTE(review): tsumogiri is hard-coded; it should be True only when
            # the discarded tile is the one just drawn.
            return {"type": "dahai", "pai": tile, "tsumogiri": True}

        if "pon" in legal:
            prob = torch.sigmoid(self._infer("pon", state)).item()
            call = prob > 0.5
            self.experience["pon"].append((sample, int(call)))
            if call:
                return {"type": "pon"}

        if "chi" in legal:
            choice = self._infer("chi", state).argmax().item()
            self.experience["chi"].append((sample, choice))
            if choice > 0:
                return {"type": "chi", "option": choice}

        return {"type": "none"}

# Self-play simulation loop
def _sample_batch(buffer):
    """Return up to BATCH_SIZE distinct samples from `buffer`.

    random.sample raises ValueError when asked for more items than the buffer
    holds. That happens whenever a game yields fewer than BATCH_SIZE
    pon/chi/riichi decisions.
    """
    return random.sample(buffer, min(BATCH_SIZE, len(buffer)))


for episode in range(NUM_EPISODES):
    # Fresh buffers each episode; the dict itself is shared with the bots below.
    for name in experience:
        experience[name] = []

    bots = [TrainingCNNBot(pid, models, experience) for pid in range(NUM_PLAYERS)]

    # NOTE(review): mjai.Simulator expects submission *paths* (zip files), not
    # Bot instances. This call raises
    # "TypeError: expected str, bytes or os.PathLike object, not TrainingCNNBot".
    # Bots loaded from submissions presumably run outside this process, so they
    # would not fill `experience`; in-process self-play needs another driver.
    simulator = Simulator(bots, logs_dir=".")
    simulator.run()

    # Skip the update until enough discard decisions were collected.
    if len(experience["discard"]) < BATCH_SIZE:
        continue

    # Make sure dropout/BatchNorm are in training mode for the updates.
    for model in models.values():
        model.train()

    loss_d = train_discard(_sample_batch(experience["discard"]))
    loss_p = train_binary("pon", _sample_batch(experience["pon"])) if experience["pon"] else 0
    loss_r = train_binary("riichi", _sample_batch(experience["riichi"])) if experience["riichi"] else 0
    loss_c = train_chi(_sample_batch(experience["chi"])) if experience["chi"] else 0

    print(f"[Episode {episode}] Losses - Discard: {loss_d:.3f}, Pon: {loss_p:.3f}, Chi: {loss_c:.3f}, Riichi: {loss_r:.3f}")


TypeError: expected str, bytes or os.PathLike object, not TrainingCNNBot