In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!nvidia-smi

In [3]:
import os
import copy
import ast
import json
import random
import glob
import numpy as np
from functools import partial
from collections import Counter
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

from sklearn.metrics import f1_score

## Data Extraction (Do not run if data files exist)

In [None]:
!unzip -q '/content/drive/MyDrive/mjdata/mjdata.zip'

In [None]:
game_folders = ['MO/', 'PLAY/', 'LIU/']
folder_name = 'output2017/'
all_game_files = []

for game_folder in game_folders:
    for file_name in os.listdir(folder_name + game_folder):
        all_game_files.append(folder_name + game_folder + file_name)

### Chow-able data

In [None]:
def detect_chowable(hand_tiles, discarded_tile):
    discard_type, discard_num = discarded_tile[0], int(discarded_tile[1])
    # search for "W" tiles if discarded is "W", and tiles with number within [discard_num-2, discard_num+2]
    candidates = set([int(t[1]) for t in hand_tiles if t.startswith(discard_type) and abs(int(t[1])-discard_num) <= 2])
    if not candidates:  # No candidate exists
        return False

    # chow at end
    if discard_num-2 in candidates and discard_num-1 in candidates:
        return True
    
    # chow at middle
    if discard_num-1 in candidates and discard_num+1 in candidates:
        return True

    # chow at start
    if discard_num+1 in candidates and discard_num+2 in candidates:
        return True

    return False

def extract_target_data(file, history_len=4):
    '''
    Args:
    - history_len (int): max length of history of a single player to use

    Returns:
    - target_data (list): all chow-able data of all players
    '''

    def custom_eval(x):
        try:
            return eval(x)
        except (NameError, SyntaxError):
            return x

    with open(file, 'r') as f:
        lines = f.readlines()

    target_data = []

    # Record of all discarded tiles in sequential order
    players_discard_tiles = {}

    # The histories for respective players
    players_history = {}

    # The tiles they respectively have seen
    players_seen = {}

    # Latest hand tiles for respective players
    players_latest_hands = {}

    for line in lines[2:6]:
        player_num, hands, _ = list(map(custom_eval, line.split('\t')))
        players_latest_hands[player_num] = hands
        players_seen[player_num] = hands
        players_history[player_num] = []
        players_discard_tiles[player_num] = []


    # Play records
    plays = lines[6:]
    prev_turn_info = [-999, -999, -999, -999]  # dummy for first turn
    for turn_i, line in enumerate(plays):
        turn_info = list(map(custom_eval, line.split('\t')))   # e.g. [3, '打牌', ['F2'], '\n']
        turn_player = turn_info[0]
        action = turn_info[1]
        turn_data = {'turn_id': turn_i, 'turn_player': turn_player}

        if action == '补花':
            continue

        if action == '和牌':
            continue

        if action == '补花后摸牌' or action == '杠后摸牌': # ['3', '杠后摸牌', ['W1'], '\n']
            players_latest_hands[turn_player].append(turn_info[2][0])

        if action == '打牌':
            discard = turn_info[2][0]
            players_discard_tiles[turn_player].append(discard)
            players_latest_hands[turn_player].remove(discard)

        if action == '摸牌' or action == '碰' or action == '明杠' or action == '暗杠':
            turn_data['steal'] = turn_info[2][0]
            turn_data['hand'] = players_latest_hands[turn_player].copy()  # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            if prev_turn_info[0] != turn_player and prev_turn_info[1] == '打牌':  # preceeding player discarded tiles
                prev_player_discard = prev_turn_info[2][0]
                if detect_chowable(players_latest_hands[turn_player], prev_player_discard):
                    turn_data['label'] = 0  # Not chow
                    target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe

            # Update player's hand tiles
            if action == '摸牌':  # ['3', '摸牌', ['W1'], '\n']
                players_latest_hands[turn_player].append(turn_info[2][0])
            elif action == '碰':  # ['3', '碰', ['W1','W1','W1'], 'W2', '2\n']
                for tile in turn_info[2][:-1]:
                    players_latest_hands[turn_player].remove(tile)
            elif action == '明杠' or action == '暗杠':  # ['3', '明杠', ['F2','F2','F2','F2'], 'F2', '2\n']
                for tile in turn_info[2][:-1]:
                    players_latest_hands[turn_player].remove(tile)

        if action == '吃':  # ['3', '吃', ['W1','W2','W3'], 'W2', '2\n']
            turn_data['steal'] = turn_info[3]
            turn_data['hand'] = players_latest_hands[turn_player].copy()  # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            chow_permutation = sorted(turn_info[2])

            # chow at start
            if turn_info[3] == chow_permutation[0]:
                turn_data['label'] = 1

            # chow in middle
            if turn_info[3] == chow_permutation[1]:
                turn_data['label'] = 2

            # chow at end
            if turn_info[3] == chow_permutation[2]:
                turn_data['label'] = 3

            for tile in turn_info[2]:
                if tile != turn_info[3]:  # exclude steal tile
                    players_latest_hands[turn_player].remove(tile)

            target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe

        # history only includes states in '摸牌'/'碰'/'杠'/'吃' situations, states = [own hand(1 dim) + discard(4 dim) + steal(1 dim)]
        if 'steal' in turn_data:
            players_history[turn_player].append(turn_data)
        if len(players_history[turn_player]) > history_len:  # keep the 4 most recent histories only
            players_history[turn_player].pop(0)  # remove the oldest
        prev_turn_info = turn_info
    
    return target_data

In [None]:
for file in tqdm(all_game_files):
    chowable_data = extract_target_data(file, history_len=4)
    with open('/content/drive/MyDrive/mjdata/chowable_data.txt', 'a') as f:
        for line in chowable_data:
            f.write(json.dumps(line)+'\n')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=530458.0), HTML(value='')))




In [None]:
# Split data (unbalanced random)
import random

val_ratio = 0.25
test_ratio = 0.25
train_ratio = 0.5

with open('/content/drive/MyDrive/mjdata/chowable_data.txt', 'r') as f:
    all_data_str = f.readlines()

random.shuffle(all_data_str)
train_size = int(train_ratio*len(all_data_str))
val_size = int(val_ratio*len(all_data_str))
train_data_str = all_data_str[:train_size]
val_data_str = all_data_str[train_size:train_size+val_size]
test_data_str = all_data_str[train_size+val_size:]

with open('/content/drive/MyDrive/mjdata/chowable_train.txt', 'w') as f:
    for line in train_data_str:
        f.write(line)

with open('/content/drive/MyDrive/mjdata/chowable_val.txt', 'w') as f:
    for line in val_data_str:
        f.write(line)

with open('/content/drive/MyDrive/mjdata/chowable_test.txt', 'w') as f:
    for line in test_data_str:
        f.write(line)

### Pong-able data

In [None]:
def detect_pongable(hand_tiles, discarded_tile):
    discard_type, discard_num = discarded_tile[0], int(discarded_tile[1])
    # search for "W" tiles if discarded is "W", and tiles with number within [discard_num-2, discard_num+2]
    candidates = [t for t in hand_tiles if t == discarded_tile]

    if len(candidates) >= 2:
        return True
    else:
        return False

def extract_target_data(file, history_len=4):
    '''
    Args:
    - history_len (int): max length of history of a single player to use

    Returns:
    - target_data (list): all chow-able data of all players
    '''

    def custom_eval(x):
        try:
            return eval(x)
        except (NameError, SyntaxError):
            return x

    with open(file, 'r') as f:
        lines = f.readlines()

    winds = ['F1', 'F2', 'F3', 'F4']  #风 东南西北

    target_data = []

    # Record of all discarded tiles in sequential order
    players_discard_tiles = {}

    # The histories for respective players
    players_history = {}

    # Latest hand tiles for respective players
    players_latest_hands = {}

    # Latest hand tiles for respective players
    players_open_melds = {}

    wind2id = {'东': 'F1', '南': 'F2', '西': 'F3', '北': 'F4'}
    round_wind = wind2id[lines[1][0]]
    discard = None

    # Distribute wind
    for line in lines[2:6]:
        player_num, hands, _ = list(map(custom_eval, line.split('\t')))
        players_latest_hands[player_num] = [tile for tile in hands if not tile.startswith('H')]
        players_history[player_num] = []
        players_discard_tiles[player_num] = []
        players_open_melds[player_num] = []
        if len(hands) == 14:
            dealer = player_num

    # Assign winds
    player_winds = winds[dealer:] + winds[:dealer]
    player_winds = {player: wind for player, wind in enumerate(player_winds)}

    # Play records
    plays = lines[6:]
    prev_turn_info = [-999, -999, -999, -999]  # dummy for first turn
    for turn_i, line in enumerate(plays):
        turn_info = list(map(custom_eval, line.split('\t')))   # e.g. [3, '打牌', ['F2'], '\n']
        turn_player = turn_info[0]
        action = turn_info[1]
        turn_data = {'turn_id': turn_i, 'turn_player': turn_player, 'round_wind': round_wind, 'player_wind': player_winds[turn_player]}

        if action == '补花':
            continue

        if action == '和牌':
            continue

        if action == '补花后摸牌' or action == '杠后摸牌': # ['3', '杠后摸牌', ['W1'], '\n']
            players_latest_hands[turn_player] = [tile for tile in players_latest_hands[turn_player] if not tile.startswith('H')]  # remove flowers for safe
            if not turn_info[2][0].startswith('H'):
                players_latest_hands[turn_player].append(turn_info[2][0])

        if action == '打牌':
            discard = turn_info[2][0]
            players_discard_tiles[turn_player].append(discard)
            players_latest_hands[turn_player].remove(discard)

        if action == '摸牌' or action == '明杠' or action == '暗杠' or action == '补杠':
            turn_data['steal'] = discard
            turn_data['hand'] = players_latest_hands[turn_player].copy()  # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            turn_data['open_melds'] = copy.deepcopy(players_open_melds)  # children are lists (mutable), deep copy is needed
            if prev_turn_info[0] != turn_player and prev_turn_info[1] == '打牌':  # preceeding player discarded tiles
                prev_player_discard = prev_turn_info[2][0]
                if detect_pongable(players_latest_hands[turn_player], prev_player_discard):
                    turn_data['label'] = 0  # Not pong
                    target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe

            # Update player's hand tiles
            if action == '摸牌' and not turn_info[2][0].startswith('H'):  # ['3', '摸牌', ['W1'], '\n']
                players_latest_hands[turn_player].append(turn_info[2][0])
                
            elif action == '明杠' or action == '暗杠' or action == '补杠':  # ['3', '明杠', ['F2','F2','F2','F2'], 'F2', '2\n']
                for tile in turn_info[2][:-1]:
                  if tile in players_latest_hands[turn_player]:
                    players_latest_hands[turn_player].remove(tile)

                if action == '明杠' or action == '暗杠':
                  players_open_melds[turn_player] += turn_info[2]
                elif action == '补杠':
                  players_open_melds[turn_player].append(turn_info[3])

        if action == '碰': # ['3', '碰', ['W1','W1','W1'], 'W2', '2\n']
            turn_data['steal'] = discard
            turn_data['hand'] = players_latest_hands[turn_player].copy()  # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            turn_data['open_melds'] = copy.deepcopy(players_open_melds)  # children are lists (mutable), deep copy is needed
            turn_data['label'] = 1
            target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe
            for tile in turn_info[2][:-1]:
                players_latest_hands[turn_player].remove(tile)
            players_open_melds[turn_player] += turn_info[2]

        if action == '吃':  # ['3', '吃', ['W1','W2','W3'], 'W2', '2\n']
            for tile in turn_info[2]:
                if tile != turn_info[3]:  # exclude steal tile
                    players_latest_hands[turn_player].remove(tile)
            players_open_melds[turn_player] += turn_info[2]

        # history only includes states in '摸牌'/'碰'/'杠'/'吃' situations, states = [own hand(1 dim) + discard(4 dim) + steal(1 dim)]
        if 'steal' in turn_data:
            players_history[turn_player].append(turn_data)
        if len(players_history[turn_player]) > history_len:  # keep the 4 most recent histories only
            players_history[turn_player].pop(0)  # remove the oldest
        prev_turn_info = turn_info
    
    return target_data

In [None]:
all_pong_data = []
save_dir = '/content/drive/MyDrive/mjdata/pong'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
for file in tqdm(all_game_files, desc="Extracting all pong-able data: "):
    pongable_data = extract_target_data(file)
    for data in pongable_data:
        data = json.dumps(data)
        all_pong_data.append(data)

val_ratio = 0.1
test_ratio = 0.05
random.shuffle(all_pong_data)
num_val = int(len(all_pong_data)*val_ratio)
num_test = int(len(all_pong_data)*test_ratio)
val_data = all_pong_data[:num_val]
test_data = all_pong_data[num_val:num_val+num_test]
train_data = all_pong_data[num_val+num_test:]

with open(os.path.join(save_dir, 'pongable_train.txt'), 'a') as f:
    for data in tqdm(train_data, desc="Saving training data: "):
        f.writelines(data+'\n')

with open(os.path.join(save_dir, 'pongable_val.txt'), 'a') as f:
    for data in tqdm(val_data, desc="Saving validation data: "):
        f.writelines(data+'\n')

with open(os.path.join(save_dir, 'pongable_test.txt'), 'a') as f:
    for data in tqdm(test_data, desc="Saving test data: "):
        f.writelines(data+'\n')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, description='Extracting all pong-able data: ', max=530458.0, style=Pro…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Saving training data: ', max=1261120.0, style=ProgressSty…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Saving validation data: ', max=148367.0, style=ProgressSt…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Saving test data: ', max=74183.0, style=ProgressStyle(des…




### Kong data

In [None]:
def detect_kongable(hand_tiles, discarded_tile):
    discard_type, discard_num = discarded_tile[0], int(discarded_tile[1])
    # search for "W" tiles if discarded is "W", and tiles with number within [discard_num-2, discard_num+2]
    candidates = [t for t in hand_tiles if t == discarded_tile]

    if len(candidates) == 3:
        return True
    else:
        return False

def extract_target_data(file, history_len=4):
    '''
    Args:
    - history_len (int): max length of history of a single player to use

    Returns:
    - target_data (list): all chow-able data of all players
    '''

    def custom_eval(x):
        try:
            return eval(x)
        except (NameError, SyntaxError):
            return x

    with open(file, 'r') as f:
        lines = f.readlines()

    winds = ['F1', 'F2', 'F3', 'F4']  #风 东南西北

    target_data = []

    # Record of all discarded tiles in sequential order
    players_discard_tiles = {}

    # The histories for respective players
    players_history = {}

    # Latest hand tiles for respective players
    players_latest_hands = {}

    # Latest hand tiles for respective players
    players_open_melds = {}

    wind2id = {'东': 'F1', '南': 'F2', '西': 'F3', '北': 'F4'}
    round_wind = wind2id[lines[1][0]]
    discard = None

    # Distribute wind
    for line in lines[2:6]:
        player_num, hands, _ = list(map(custom_eval, line.split('\t')))
        players_latest_hands[player_num] = [tile for tile in hands if not tile.startswith('H')]
        players_history[player_num] = []
        players_discard_tiles[player_num] = []
        players_open_melds[player_num] = []
        if len(hands) == 14:
            dealer = player_num

    # Assign winds
    player_winds = winds[dealer:] + winds[:dealer]
    player_winds = {player: wind for player, wind in enumerate(player_winds)}

    # Play records
    plays = lines[6:]
    prev_turn_info = [-999, -999, -999, -999]  # dummy for first turn
    for turn_i, line in enumerate(plays):
        # ##############################################
        # for player_h in players_latest_hands.values():
        #     if len(player_h) > 14:
        #         print(file)
        # ##############################################
        turn_info = list(map(custom_eval, line.split('\t')))   # e.g. [3, '打牌', ['F2'], '\n']
        turn_player = turn_info[0]
        action = turn_info[1]
        turn_data = {'turn_id': turn_i, 'turn_player': turn_player, 'round_wind': round_wind, 'player_wind': player_winds[turn_player]}

        if action == '补花':
            continue

        if action == '和牌':
            continue

        if action == '补花后摸牌' or action == '杠后摸牌': # ['3', '杠后摸牌', ['W1'], '\n']
            players_latest_hands[turn_player] = [tile for tile in players_latest_hands[turn_player] if not tile.startswith('H')]  # remove flowers for safe, may have flowers since start of the game
            if not turn_info[2][0].startswith('H'):
                players_latest_hands[turn_player].append(turn_info[2][0])

        if action == '打牌':
            discard = turn_info[2][0]
            players_discard_tiles[turn_player].append(discard)
            players_latest_hands[turn_player].remove(discard)

        if action == '摸牌' or action == '碰':
            turn_data['steal'] = discard
            turn_data['hand'] = copy.deepcopy(players_latest_hands[turn_player]) # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            turn_data['open_melds'] = copy.deepcopy(players_open_melds)  # children are lists (mutable), deep copy is needed
            if prev_turn_info[0] != turn_player and prev_turn_info[1] == '打牌':  # preceeding player discarded tiles
                prev_player_discard = prev_turn_info[2][0]
                if detect_kongable(players_latest_hands[turn_player], prev_player_discard):
                    turn_data['label'] = 0  # Not kong
                    target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe

            # Update player's hand tiles
            if action == '摸牌' and not turn_info[2][0].startswith('H'):  # ['3', '摸牌', ['W1'], '\n']
                players_latest_hands[turn_player].append(turn_info[2][0])

            elif action == '碰': # ['3', '碰', ['W1','W1','W1'], 'W2', '2\n']
                for tile in turn_info[2][:-1]:
                    players_latest_hands[turn_player].remove(tile)
                players_open_melds[turn_player] += turn_info[2]

        if action == '明杠' or action == '暗杠' or action == '补杠': # ['3', '明杠', ['F2','F2','F2','F2'], 'F2', '2\n']
            turn_data['steal'] = discard
            turn_data['hand'] = players_latest_hands[turn_player].copy()  # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            turn_data['open_melds'] = copy.deepcopy(players_open_melds)  # children are lists (mutable), deep copy is needed
            turn_data['label'] = 1
            target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe
            for tile in turn_info[2][:-1]:
              if tile in players_latest_hands[turn_player]:
                players_latest_hands[turn_player].remove(tile)
            if action == '明杠' or action == '暗杠':
              players_open_melds[turn_player] += turn_info[2]
            elif action == '补杠':
              players_open_melds[turn_player].append(turn_info[3])

        if action == '吃':  # ['3', '吃', ['W1','W2','W3'], 'W2', '2\n']
            for tile in turn_info[2]:
                if tile != turn_info[3]:  # exclude steal tile
                    players_latest_hands[turn_player].remove(tile)
            players_open_melds[turn_player] += turn_info[2]

        # history only includes states in '摸牌'/'碰'/'杠'/'吃' situations, states = [own hand(1 dim) + discard(4 dim) + steal(1 dim)]
        if 'steal' in turn_data:
            players_history[turn_player].append(turn_data)
        if len(players_history[turn_player]) > history_len:  # keep the 4 most recent histories only
            players_history[turn_player].pop(0)  # remove the oldest
        prev_turn_info = turn_info
    
    return target_data

In [None]:
all_kong_data = []
save_dir = '/content/drive/MyDrive/mjdata/kong'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
for file in tqdm(all_game_files, desc="Extracting all kong-able data: "):
    kongable_data = extract_target_data(file)
    for data in kongable_data:
        data = json.dumps(data)
        all_kong_data.append(data)

val_ratio = 0.1
test_ratio = 0.1
random.shuffle(all_kong_data)
num_val = int(len(all_kong_data)*val_ratio)
num_test = int(len(all_kong_data)*test_ratio)
val_data = all_kong_data[:num_val]
test_data = all_kong_data[num_val:num_val+num_test]
train_data = all_kong_data[num_val+num_test:]

with open(os.path.join(save_dir, 'kongable_train.txt'), 'a') as f:
    for data in tqdm(train_data, desc="Saving training data: "):
        f.writelines(data+'\n')

with open(os.path.join(save_dir, 'kongable_val.txt'), 'a') as f:
    for data in tqdm(val_data, desc="Saving validation data: "):
        f.writelines(data+'\n')

with open(os.path.join(save_dir, 'kongable_test.txt'), 'a') as f:
    for data in tqdm(test_data, desc="Saving test data: "):
        f.writelines(data+'\n')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, description='Extracting all kong-able data: ', max=530458.0, style=Pro…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Saving training data: ', max=146097.0, style=ProgressStyl…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Saving validation data: ', max=18261.0, style=ProgressSty…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Saving test data: ', max=18261.0, style=ProgressStyle(des…




### Discard data

In [None]:
def extract_target_data(file, history_len=4):
    '''
    Args:
    - history_len (int): max length of history of a single player to use

    Returns:
    - target_data (list): all chow-able data of all players
    '''

    def custom_eval(x):
        try:
            return eval(x)
        except (NameError, SyntaxError):
            return x

    with open(file, 'r') as f:
        lines = f.readlines()

    winds = ['F1', 'F2', 'F3', 'F4']  #风 东南西北

    target_data = []

    # Record of all discarded tiles in sequential order
    players_discard_tiles = {}

    # The histories for respective players
    players_history = {}

    # Latest hand tiles for respective players
    players_latest_hands = {}

    # Latest hand tiles for respective players
    players_open_melds = {}

    wind2id = {'东': 'F1', '南': 'F2', '西': 'F3', '北': 'F4'}
    round_wind = wind2id[lines[1][0]]
    steal = None

    # Distribute wind
    for line in lines[2:6]:
        player_num, hands, _ = list(map(custom_eval, line.split('\t')))
        players_latest_hands[player_num] = [tile for tile in hands if not tile.startswith('H')]
        players_history[player_num] = []
        players_discard_tiles[player_num] = []
        players_open_melds[player_num] = []
        if len(hands) == 14:
            dealer = player_num

    # Assign winds
    player_winds = winds[dealer:] + winds[:dealer]
    player_winds = {player: wind for player, wind in enumerate(player_winds)}

    # Play records
    plays = lines[6:]
    prev_turn_info = [-999, -999, -999, -999]  # dummy for first turn
    for turn_i, line in enumerate(plays):
        turn_info = list(map(custom_eval, line.split('\t')))   # e.g. [3, '打牌', ['F2'], '\n']
        turn_player = turn_info[0]
        action = turn_info[1]
        turn_data = {'turn_id': turn_i, 'turn_player': turn_player, 'round_wind': round_wind, 'player_wind': player_winds[turn_player]}

        if action == '补花':
            continue

        if action == '和牌':
            continue

        if action == '补花后摸牌' or action == '杠后摸牌': # ['3', '杠后摸牌', ['W1'], '\n']
            players_latest_hands[turn_player] = [tile for tile in players_latest_hands[turn_player] if not tile.startswith('H')]  # remove flowers for safe
            if not turn_info[2][0].startswith('H'):
                steal = turn_info[2][0]
                players_latest_hands[turn_player].append(turn_info[2][0])

        if action == '打牌':
            discard = turn_info[2][0]
            turn_data['steal'] = steal
            turn_data['hand'] = players_latest_hands[turn_player].copy()  # all children are string, shallow copy is ok
            turn_data['discard'] = copy.deepcopy(players_discard_tiles)  # children are lists (mutable), deep copy is needed
            turn_data['open_melds'] = copy.deepcopy(players_open_melds)  # children are lists (mutable), deep copy is needed
            turn_data['label'] = discard
            target_data.append([copy.deepcopy(turn_data)] + copy.deepcopy(players_history[turn_player][::-1]))  # data are dynamic (keep changing in loop), deepcopy for safe
            # Updates
            players_discard_tiles[turn_player].append(discard)
            players_latest_hands[turn_player].remove(discard)

        if action == '摸牌':  # ['3', '摸牌', ['W1'], '\n']
            # Update player's hand tiles
            if not turn_info[2][0].startswith('H'):
                steal = turn_info[2][0]
                players_latest_hands[turn_player].append(turn_info[2][0])

        if action == '碰' or action == '明杠' or action == '暗杠' or action == '补杠': # ['3', '碰', ['W1','W1','W1'], 'W2', '2\n']
            steal = turn_info[3]
            # Update player's hand tiles
            for tile in turn_info[2][:-1]:
              if tile in players_latest_hands[turn_player]:
                players_latest_hands[turn_player].remove(tile)
            if action == '明杠' or action == '暗杠':
              players_open_melds[turn_player] += turn_info[2]
            elif action == '补杠':
              players_open_melds[turn_player].append(turn_info[3])

        if action == '吃':  # ['3', '吃', ['W1','W2','W3'], 'W2', '2\n']
            steal = turn_info[3]
            for tile in turn_info[2]:
                if tile != turn_info[3]:  # exclude steal tile
                    players_latest_hands[turn_player].remove(tile)
            players_open_melds[turn_player] += turn_info[2]

        # history only includes states in '摸牌'/'碰'/'杠'/'吃' situations, states = [own hand(1 dim) + discard(4 dim) + steal(1 dim)]
        if 'steal' in turn_data:
            players_history[turn_player].append(turn_data)
        if len(players_history[turn_player]) > history_len:  # keep the 4 most recent histories only
            players_history[turn_player].pop(0)  # remove the oldest
        prev_turn_info = turn_info
    
    return target_data

In [None]:
random.shuffle(all_game_files)
all_game_files = all_game_files[:200000]

all_discard_data = []
save_dir = '/content/drive/MyDrive/mjdata/discard'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
for file in tqdm(all_game_files, desc="Extracting all discard data: "):
    discard_data = extract_target_data(file)
    for i, data in enumerate(discard_data):
        data = json.dumps(data)
        if i not in [len(discard_data)-1, len(discard_data)-2]:
          with open(os.path.join(save_dir, 'discard_train.txt'), 'a') as f:
                f.writelines(data+'\n')
        elif i == len(discard_data)-1:
          with open(os.path.join(save_dir, 'discard_val.txt'), 'a') as f:
                f.writelines(data+'\n')

        elif i == len(discard_data)-2:
          with open(os.path.join(save_dir, 'discard_test.txt'), 'a') as f:
                f.writelines(data+'\n')

# val_ratio = 0.1
# test_ratio = 0.1
# random.shuffle(all_discard_data)
# num_val = int(len(all_discard_data)*val_ratio)
# num_test = int(len(all_discard_data)*test_ratio)
# val_data = all_discard_data[:num_val]
# test_data = all_discard_data[num_val:num_val+num_test]
# train_data = all_discard_data[num_val+num_test:]

# with open(os.path.join(save_dir, 'discard_train.txt'), 'a') as f:
#     for data in tqdm(train_data, desc="Saving training data: "):
#         f.writelines(data+'\n')

# with open(os.path.join(save_dir, 'discard_val.txt'), 'a') as f:
#     for data in tqdm(val_data, desc="Saving validation data: "):
#         f.writelines(data+'\n')

# with open(os.path.join(save_dir, 'discard_test.txt'), 'a') as f:
#     for data in tqdm(test_data, desc="Saving test data: "):
#         f.writelines(data+'\n')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, description='Extracting all discard data: ', max=200000.0, style=Progr…

## Data Checking

In [None]:
with open('/content/drive/MyDrive/mjdata/pong/pongable_train.txt', 'r') as f:
    data = f.readlines()

wrong_count = 0
pos_count = 0
neg_count = 0
cls_ratio = {}
for i, line in enumerate(tqdm(data)):
    turn_data = json.loads(line)
    if 'label' not in turn_data[0]:
        wrong_count += 1
    elif turn_data[0]['label'] == 0:
        neg_count += 1
        cls_ratio[0] = cls_ratio.get(0, 0) + 1
    elif turn_data[0]['label'] == 1:
        pos_count += 1
        cls_ratio[1] = cls_ratio.get(1, 0) + 1
    elif turn_data[0]['label'] == 2:
        pos_count += 1
        cls_ratio[2] = cls_ratio.get(2, 0) + 1
    elif turn_data[0]['label'] == 3:
        pos_count += 1
        cls_ratio[3] = cls_ratio.get(3, 0) + 1

print('(All) Number of wrongs: ', wrong_count)
print(f'Number of positives: {pos_count} ({(pos_count/len(data)):.2f}%)')
print(f'Number of negatives: {neg_count} ({(neg_count/len(data)):.2f}%)')
print(cls_ratios)

In [None]:
with open('/content/drive/MyDrive/mjdata/kong/kongable_train.txt', 'r') as f:
    data = f.readlines()

hand_wrongs = 0
discard_wrongs = 0
open_wrongs = 0
player_wind_wrongs = 0
round_wind_wrongs = 0
steal_wrongs = 0

for line in data:
    line = json.loads(line)
    for turn_data in line:
        discard = [i for d in turn_data['discard'].values() for i in d]
        if len([t for t in discard if t.startswith("H")]) != 0:
            # print("Flowers in discard")
            discard_wrongs += 1
            break
            
        if len([t for t in turn_data['hand'] if t.startswith("H")]) != 0:
            # print(turn_data['hand'])
            # print("Flowers in hand")
            hand_wrongs += 1
            break

        open_melds = [i for d in turn_data['open_melds'].values() for i in d]
        if len([t for t in open_melds if t.startswith("H")]) != 0:
            # print("Flowers in open melds")
            open_wrongs += 1
            break

        if turn_data['player_wind'].startswith("H"):
            # print("Flowers in player wind")
            player_wind_wrongs += 1
            break

        if turn_data['round_wind'].startswith("H"):
            # print("Flowers in round wind")
            round_wind_wrongs += 1
            break
        if turn_data['steal'] is not None:
            if turn_data['steal'].startswith("H"):
                # print("Flowers in steal")
                steal_wrongs += 1
                break

print(f'hand_wrongs: {hand_wrongs}')
print(f'discard_wrongs: {discard_wrongs}')
print(f'open_wrongs: {open_wrongs}')
print(f'player_wind_wrongs: {player_wind_wrongs}')
print(f'round_wind_wrongs: {round_wind_wrongs}')
print(f'steal_wrongs: {steal_wrongs}')

hand_wrongs: 0
discard_wrongs: 0
open_wrongs: 0
player_wind_wrongs: 0
round_wind_wrongs: 0
steal_wrongs: 0


In [None]:
with open('/content/drive/MyDrive/mjdata/kong/kongable_train.txt', 'r') as f:
    data = f.readlines()

hand_wrongs = 0

for line in data:
    line = json.loads(line)
    for turn_data in line:   
        counter = Counter(turn_data['hand'])
        for tile in counter:
            if counter[tile] > 4:
                print(turn_data['hand'])
                hand_wrongs += 1

print(f'hand_wrongs: {hand_wrongs}')

['B6', 'W2', 'B6', 'B9', 'T4', 'T1', 'W7', 'T2', 'B3', 'B6', 'B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'T2', 'B3', 'B6', 'W5', 'T2', 'J3', 'B4']
['B6', 'W2', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'T2', 'B3', 'B6', 'B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'T2', 'B3', 'B6', 'W5', 'T2', 'J3']
['B6', 'W2', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'T2', 'B3', 'B6', 'B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'J3', 'T2', 'B3', 'B6', 'W5', 'T2']
['B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'T2', 'B3', 'B6', 'B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'J3', 'T2', 'B3', 'B6', 'W5']
['B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'J3', 'T2', 'B3', 'B6', 'B6', 'W2', 'F3', 'B6', 'B9', 'T4', 'T9', 'T1', 'W7', 'J3', 'T2', 'B3', 'B6']
hand_wrongs: 5


In [None]:
import matplotlib.pyplot as plt


with open('/content/drive/MyDrive/mjdata/pong/pongable_train.txt', 'r') as f:
    data = f.readlines()

hand_wrongs = []

for line in data:
    line = json.loads(line)
    for turn_data in line:
        # if len(turn_data['hand']) > 14:
        hand_wrongs.append(len(turn_data['hand']))
        break

# plt.hist(hand_wrongs)
Counter(hand_wrongs)

Counter({4: 12534,
         5: 1463,
         6: 84,
         7: 108817,
         8: 7703,
         9: 237,
         10: 331665,
         11: 8251,
         12: 114,
         13: 790250,
         14: 1,
         26: 1})

In [None]:
'''
own wind: (1, 34, 1)
round wind: (1, 34, 1)
own hand: (4, 34, 1)  # [w1, w1, w2, w2, ...]
steal tile: (1, 34, 1)  # w1
own all discards: (4, 34, 1)
own+1 all discards: (4, 34, 1)
own+2 all discards: (4, 34, 1)
own+3 all discards: (4, 34, 1)
own open melds: (4, 34, 1)
own+1 open melds: (4, 34, 1)
own+2 open melds: (4, 34, 1)
own+3 open melds: (4, 34, 1)
'''

''
# own last discard: (1, 34, 1)
# own+1 last discard: (1, 34, 1)
# own+2 last discard: (1, 34, 1)
# own+3 last discard: (1, 34, 1)'
''
line = json.loads(data[7])
line

[{'discard': {'0': ['T1', 'T7', 'F1', 'W5', 'W6', 'T9', 'W7', 'T2', 'W6'],
   '1': ['F3', 'J2', 'F2', 'W7', 'W4', 'T8', 'W9', 'F4'],
   '2': ['F1', 'W8', 'F3', 'F2', 'B9', 'B8', 'T7', 'J1', 'B1'],
   '3': ['F4', 'W3', 'W4', 'W1', 'W9', 'W8', 'W6', 'T5', 'T4']},
  'hand': ['B1', 'J3', 'B1', 'B6', 'B8', 'B6', 'J3', 'B4', 'H7', 'B7', 'B2'],
  'label': 1,
  'open_melds': {'0': ['F4', 'F4', 'F4'],
   '1': [],
   '2': ['W2', 'W3', 'W4', 'T9', 'T9', 'T9'],
   '3': ['T8', 'T8', 'T8']},
  'player_wind': 'F2',
  'round_wind': 'F1',
  'steal': 'B1',
  'turn_id': 73,
  'turn_player': 3},
 {'discard': {'0': ['T1', 'T7', 'F1', 'W5', 'W6', 'T9', 'W7', 'T2'],
   '1': ['F3', 'J2', 'F2', 'W7', 'W4', 'T8', 'W9'],
   '2': ['F1', 'W8', 'F3', 'F2', 'B9', 'B8', 'T7', 'J1'],
   '3': ['F4', 'W3', 'W4', 'W1', 'W9', 'W8', 'W6', 'T5']},
  'hand': ['B1', 'J3', 'B1', 'B6', 'B8', 'B6', 'J3', 'T4', 'B4', 'H7', 'B7'],
  'open_melds': {'0': ['F4', 'F4', 'F4'],
   '1': [],
   '2': ['W2', 'W3', 'W4', 'T9', 'T9', 'T9'],
 

### Data cleaning & overwriting

In [None]:
def clean_write_data(filepath):
    with open(filepath, 'r') as f:
        data = f.readlines()

    valid_data = []

    for line_str in tqdm(data):
        line = json.loads(line_str)
        wrongs = 0
        for turn_data in line:
            if len(turn_data['hand']) > 14:
                wrongs += 1
            for cnt in Counter(turn_data['hand']).values():
                if cnt > 4:
                    wrongs += 1
                    break
        if wrongs == 0:
            valid_data.append(line_str)

    hand_stats = []

    for line in tqdm(valid_data):
        line = json.loads(line)
        for turn_data in line:
            hand_stats.append(len(turn_data['hand']))

    hand_stats = Counter(hand_stats)

    print(hand_stats)

    with open(filepath, 'w') as f:
        for new_data in tqdm(valid_data, desc="Overwriting data: "):
            f.writelines(new_data)

In [None]:
# clean_write_data('/content/drive/MyDrive/mjdata/pong/pongable_train.txt')
# clean_write_data('/content/drive/MyDrive/mjdata/pong/pongable_val.txt')
# clean_write_data('/content/drive/MyDrive/mjdata/pong/pongable_test.txt')

# clean_write_data('/content/drive/MyDrive/mjdata/kong/kongable_train.txt')
# clean_write_data('/content/drive/MyDrive/mjdata/kong/kongable_val.txt')
# clean_write_data('/content/drive/MyDrive/mjdata/kong/kongable_test.txt')

clean_write_data('/content/drive/MyDrive/mjdata/discard/discard_train.txt')
clean_write_data('/content/drive/MyDrive/mjdata/discard/discard_val.txt')
clean_write_data('/content/drive/MyDrive/mjdata/discard/discard_test.txt')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=915159.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=915159.0), HTML(value='')))


Counter({14: 2332060, 11: 999248, 8: 378128, 5: 90054, 2: 8192, 12: 7454, 9: 6341, 6: 2580, 3: 515, 13: 305, 10: 156, 7: 134, 4: 1})


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Overwriting data: ', max=915159.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, max=18788.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=18788.0), HTML(value='')))


Counter({14: 36815, 11: 31793, 8: 17968, 5: 5539, 2: 721, 9: 338, 12: 287, 6: 195, 3: 76, 7: 8, 10: 7, 13: 3, 4: 2})


HBox(children=(FloatProgress(value=0.0, description='Overwriting data: ', max=18788.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=18788.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=18788.0), HTML(value='')))


Counter({14: 36274, 11: 31306, 8: 18640, 5: 5916, 2: 640, 9: 370, 12: 303, 6: 211, 3: 35, 7: 7, 13: 2, 10: 2})


HBox(children=(FloatProgress(value=0.0, description='Overwriting data: ', max=18788.0, style=ProgressStyle(des…




## Data Preprocessing

In [4]:
# Old codes (for chow)
# class ICJAIDataset(Dataset):
#     def __init__(self, data_path, history_len, data_ratio=1.):
#         with open(data_path, 'r') as f:
#             self.all_data_str = f.readlines()
#         if data_ratio != 1.:
#             random.shuffle(self.all_data_str)
#             data_size = int(len(self.all_data_str)*data_ratio)
#             self.all_data_str = self.all_data_str[:data_size]
#         self.hist_len = history_len
#         self._init_mj2id()
#         self._init_cls_ratios()

#     def __getitem__(self, idx):
#         data = json.loads(self.all_data_str[idx])
#         x, y = self.preprocess(data)
#         return x, y

#     def __len__(self):
#         return len(self.all_data_str)

#     def _init_mj2id(self):
#         w_dict = {'W' + str(i+1): i for i in range(9)} #万
#         b_dict = {'B' + str(i+1): i+9 for i in range(9)} #饼
#         t_dict = {'T' + str(i+1): i+18 for i in range(9)} #条
#         f_dict = {'F' + str(i+1): i+27 for i in range(4)} #风 东南西北
#         j_dict = {'J' + str(i+1): i+31 for i in range(3)} #（剑牌）中发白
#         h_dict = {'H' + str(i+1): i+34 for i in range(8)} #梅兰竹菊
#         self.mj2id = {**w_dict, **b_dict,**t_dict,**f_dict,**j_dict,**h_dict}

#     def _init_cls_ratios(self):
#         cls_ratios = {}
#         for i, line in enumerate(tqdm(self.all_data_str, desc='Calculating class ratios: ')):
#             turn_data = json.loads(line)
#             label = turn_data[0]['label']
#             cls_ratios[label] = cls_ratios.get(label, 0) + 1
#         # cls_ratios = {k: v/len(self.all_data_str) for k, v in cls_ratios.items()}
#         self.cls_ratios = cls_ratios

#     def tiles2mat(self, mj_list):
#         '''
#         Args:
#         - mj_list (list): list of mahjongs (e.g. ['B1', 'B3', 'B9', 'T1', 'T4'])

#         Returns:
#         - repr (torch.tensor, float32): shape (4, 42) 
#         '''
#         repr = torch.zeros(4, 34, dtype=torch.float32)
#         count = Counter(mj_list)
#         for i in count:
#             index = self.mj2id[i]
#             nums = count[i]
#             for j in range(nums):
#                 repr[j, index] = 1
#         return repr

#     def preprocess(self, data):
#         '''
#         Args:
#         - data (list): include 5 data from latest to oldest (current state, 4 history states)

#         Returns:
#         - x (torch.tensor, float32): shape [(self.hist_len+1)*4*4, 42, 1]  (C, H, W)
#         - y (torch.tensor, int64): shape [1]
#         '''
#         player_order = [0,1,2,3]
#         x = torch.zeros((self.hist_len+1)*4*4, 34, dtype=torch.float32)
#         y = torch.tensor(data[0]['label'], dtype=torch.int64)
#         for hist_i, hist_data in enumerate(data):
#             player = str(hist_data['turn_player'])
#             cur_order = player_order[player:] + player_order[:player]
#             own_hand = self.tiles2mat(hist_data['hand'])
#             steal = self.tiles2mat([hist_data['steal']])
#             own_discard = self.tiles2mat(hist_data['discard'][player])
#             own_1_discard = self.tiles2mat(hist_data['discard'][cur_order[1]])
#             own_2_discard = self.tiles2mat(hist_data['discard'][cur_order[2]])
#             own_3_discard = self.tiles2mat(hist_data['discard'][cur_order[3]])
#             # for player_id, discard in hist_data['discard'].items():
#             #     if player_id != player:
#             #         others_discard += discard
#             # others_discard = self.tiles2mat(others_discard)
#             # # x[hist_i*4, :, :] = own_hand
#             # # x[(hist_i*4)+1, :, :] = steal
#             # # x[(hist_i*4)+2, :, :] = own_discard
#             # # x[(hist_i*4)+3, :, :] = others_discard
#             hist_x = torch.cat([own_hand, steal, own_discard, own_1_discard, own_2_discard, own_3_discard], dim=0)
#             x[hist_i*hist_x.shape[0]:(hist_i+1)*hist_x.shape[0], :] = hist_x
#         return x.unsqueeze(-1), y

In [5]:
class ICJAIDataset(Dataset):
    def __init__(self, data_path, history_len, data_ratio=1.):
        with open(data_path, 'r') as f:
            self.all_data_str = f.readlines()
        if data_ratio != 1.:
            random.shuffle(self.all_data_str)
            data_size = int(len(self.all_data_str)*data_ratio)
            self.all_data_str = self.all_data_str[:data_size]
        self.hist_len = history_len
        self._init_mj2id()
        self._init_cls_ratios()

    def __getitem__(self, idx):
        data = json.loads(self.all_data_str[idx])
        x, y = self.preprocess(data)
        return x, y

    def __len__(self):
        return len(self.all_data_str)

    def _init_mj2id(self):
        w_dict = {'W' + str(i+1): i for i in range(9)} #万
        b_dict = {'B' + str(i+1): i+9 for i in range(9)} #饼
        t_dict = {'T' + str(i+1): i+18 for i in range(9)} #条
        f_dict = {'F' + str(i+1): i+27 for i in range(4)} #风 东南西北
        j_dict = {'J' + str(i+1): i+31 for i in range(3)} #（剑牌）中发白
        h_dict = {'H' + str(i+1): i+34 for i in range(8)} #梅兰竹菊
        self.mj2id = {**w_dict, **b_dict,**t_dict,**f_dict,**j_dict,**h_dict}

    def _init_cls_ratios(self):
        cls_ratios = {}
        for i, line in enumerate(tqdm(self.all_data_str, desc='Calculating class ratios: ')):
            turn_data = json.loads(line)
            label = turn_data[0]['label']
            cls_ratios[label] = cls_ratios.get(label, 0) + 1
        # cls_ratios = {k: v/len(self.all_data_str) for k, v in cls_ratios.items()}
        self.cls_ratios = cls_ratios

    def tiles2mat(self, mj_list):
        '''
        Args:
        - mj_list (list): list of mahjongs (e.g. ['B1', 'B3', 'B9', 'T1', 'T4'])

        Returns:
        - repr (torch.tensor, float32): shape (4, 34) 
        '''
        repr = torch.zeros(4, 34, dtype=torch.float32)
        count = Counter(mj_list)
        for i in count:
            index = self.mj2id[i]
            nums = count[i]
            for j in range(nums):
                repr[j, index] = 1
        return repr

    def tiles2vec(self, mj_list):
        '''
        Args:
        - mj_list (list): list of mahjongs (e.g. ['B1', 'B3', 'B9', 'T1', 'T4'])

        Returns:
        - repr (torch.tensor, float32): shape (1, 34) 
        '''
        repr = torch.zeros(1, 34, dtype=torch.float32)
        count = Counter(mj_list)
        for i in count:
            index = self.mj2id[i]
            nums = count[i]
            for j in range(nums):
                repr[j, index] = 1
        return repr

    def preprocess(self, data):
        '''
        Args:
        - data (list): include 5 data from latest to oldest (current state, 4 history states)

        Returns:
        - x (torch.tensor, float32): shape [(self.hist_len+1)*4*4, 42, 1]  (C, H, W)
        - y (torch.tensor, int64): shape [1]
        '''
        player_order = ['0','1','2','3']
        x = torch.zeros((self.hist_len+1)*39, 34, dtype=torch.float32)
        label = self.mj2id[data[0]['label']]
        y = torch.tensor(label, dtype=torch.int64)
        for hist_i, hist_data in enumerate(data):
            player = str(hist_data['turn_player'])
            cur_order = player_order[int(player):] + player_order[:int(player)]

            # Accessing features
            own_wind = self.tiles2vec([hist_data['player_wind']])   # [1, 34, 1]
            round_wind = self.tiles2vec([hist_data['round_wind']])   # [1, 34, 1]
            own_hand = self.tiles2mat(hist_data['hand'])   # [4, 34, 1]
            steal = self.tiles2vec([hist_data['steal']]) if hist_data['steal'] is not None else torch.zeros(1, 34, dtype=torch.float32) # [1, 34, 1]
            own_discard = self.tiles2mat(hist_data['discard'][player])   # [4, 34, 1]
            own_1_discard = self.tiles2mat(hist_data['discard'][cur_order[1]])   # [4, 34, 1]
            own_2_discard = self.tiles2mat(hist_data['discard'][cur_order[2]])   # [4, 34, 1]
            own_3_discard = self.tiles2mat(hist_data['discard'][cur_order[3]])   # [4, 34, 1]
            own_open_melds = self.tiles2mat(hist_data['open_melds'][player])   # [4, 34, 1]
            own_1_open_melds = self.tiles2mat(hist_data['open_melds'][cur_order[1]])   # [4, 34, 1]
            own_2_open_melds = self.tiles2mat(hist_data['open_melds'][cur_order[2]])   # [4, 34, 1]
            own_3_open_melds = self.tiles2mat(hist_data['open_melds'][cur_order[3]])   # [4, 34, 1]

            hist_x = torch.cat([
                    own_wind, 
                    round_wind, 
                    own_hand, 
                    steal, 
                    own_discard, 
                    own_1_discard, 
                    own_2_discard, 
                    own_3_discard, 
                    own_open_melds, 
                    own_1_open_melds, 
                    own_2_open_melds, 
                    own_3_open_melds], 
            dim=0)
            
            x[hist_i*hist_x.shape[0]:(hist_i+1)*hist_x.shape[0], :] = hist_x
        return x.unsqueeze(-1), y

## Model

In [6]:
class SamePadConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)

conv3x1 = partial(SamePadConv2d, kernel_size=(3,1))

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.layer1 = self.make_layer(in_channels)
        self.layer2 = self.make_layer(in_channels)

    def make_layer(self, in_channels, dropout_prob=0.5):
        layer = nn.Sequential(
            conv3x1(in_channels, in_channels),
            nn.BatchNorm2d(256),
            nn.Dropout2d(dropout_prob),
            nn.LeakyReLU()
        )
        return layer

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out += x
        return out

class MJModel(nn.Module):
    def __init__(self, history_len, n_cls=4, n_residuals=50):
        super().__init__()
        self.net = self.create_model((history_len+1)*39, n_residuals, n_cls)

    def forward(self, x):
        return self.net(x)

    def create_model(self, in_channels, n_residuals, n_cls):
        # First layer
        module_list = nn.ModuleList([
            conv3x1(in_channels, 256),
            nn.BatchNorm2d(256),
            nn.Dropout2d(0.5),
            nn.LeakyReLU()
        ])
        # Adding residual blocks
        for layer_i in range(n_residuals):
            module_list.append(ResidualBlock(256))

        # Flatten & then fc layers
        module_list.append(nn.Flatten())
        out_feat = 1024
        module_list += nn.ModuleList([
            *self.linear_block(256*34, 1024, dropout_prob=0.2),
            *self.linear_block(1024, 256, dropout_prob=0.2),
            nn.Linear(256, n_cls)
        ])

        return nn.Sequential(*module_list)

    def linear_block(self, n_feat, out_feat, dropout_prob=0.5):
        block = nn.ModuleList([
            nn.Linear(n_feat, out_feat),
            nn.BatchNorm1d(out_feat),
            nn.Dropout(dropout_prob),
            nn.LeakyReLU()
        ])
        return block

In [7]:
# model = ChowModel(history_len=4).cuda()
# summary(model, (80, 42, 1))

## Hyperparameters

In [18]:
lr = 0.001
n_epoch = 100
batch_size = 256
n_cls = 34
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Training

### Chow Model

In [None]:
train_set = ICJAIDataset('/content/drive/MyDrive/mjdata/chowable_train.txt', history_len=4, data_ratio=0.1)
val_set = ICJAIDataset('/content/drive/MyDrive/mjdata/chowable_val.txt', history_len=4, data_ratio=0.05)

train_loader = DataLoader(dataset=train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8)

val_loader = DataLoader(dataset=val_set,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Calculating class ratios: ', max=126112.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, description='Calculating class ratios: ', max=7418.0, style=ProgressSt…




In [None]:
# Chow training
def compute_acc(pred, target):
    '''
    Args:
    - pred (torch.tensor, float32): unnormalized logits (before softmax) shape [bs, 4]
    - target (torch.tensor, int64): shape [bs]

    Returns:
    - acc (float): exact classification accuracy
    - bin_acc (float): binary classification accuracy
    '''
    pred = torch.argmax(F.softmax(pred, dim=-1), dim=-1)

    # Binary accuracy
    bin_pred = (pred != 0).int()
    bin_target = (target != 0).int()
    bin_acc = (bin_pred == bin_target).sum()/(bin_target.shape[0])

    # Exact accuracy
    acc = (pred == target).sum()/(target.shape[0])
    return acc.item(), bin_acc.item()

def validate(model, val_loader, epoch):
    model.eval()
    val_acc, val_bin_acc = 0, 0
    pbar = tqdm(val_loader, desc=f"Epoch {epoch} Validation")
    with torch.no_grad():
        for bi, (X, Y) in enumerate(pbar):
            # Forward
            X, Y = list(map(lambda x: x.to(device), [X, Y]))
            preds = model(X)

            acc, bin_acc = compute_acc(preds.cpu(), Y.cpu())
            val_acc += acc
            val_bin_acc += bin_acc
    val_acc /= (bi+1)
    val_bin_acc /= (bi+1)
    pbar.set_postfix_str(f'Val Acc: {val_acc:.4f} | Val Bin Acc: {val_bin_acc:.4f}')
    return val_acc, val_bin_acc

def save_checkpoints(epoch, model, optimizer, train_loss, val_acc):
    save_dir = '/content/drive/MyDrive/mjdata/ckpts/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, f'ep{epoch}-val_acc{val_acc:.4f}.tar')
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, save_path)
    print('Checkpoint saved')

In [None]:
# Model
model = MJModel(history_len=4)

# Loss function
cls_weights = [1/train_set.cls_ratios[i] for i in range(n_cls)]
cls_weights = torch.tensor([w/sum(cls_weights) for w in cls_weights])
criterion = nn.CrossEntropyLoss(cls_weights)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=n_epoch, div_factor=20)

In [None]:
model.to(device)
criterion.to(device)

best_val_acc = 0
for epoch in range(n_epoch):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epoch}")
    train_loss, train_acc, train_bin_acc = 0, 0, 0
    for bi, (X, Y) in enumerate(pbar):
        optimizer.zero_grad()

        # Forward
        X, Y = list(map(lambda x: x.to(device), [X, Y]))
        preds = model(X)

        # Calculate loss & update
        loss = criterion(preds, Y)
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.detach().item()
        acc, bin_acc = compute_acc(preds.cpu(), Y.cpu())
        train_acc += acc
        train_bin_acc += bin_acc

        pbar.set_postfix_str(f'Train loss: {loss.detach().item():.4f} | Train Acc: {(train_acc/(bi+1)):.4f} | Train Bin Acc: {(train_bin_acc/(bi+1)):.4f}')

    # End of epoch
    train_loss /= (bi+1)
    train_acc /= (bi+1)
    train_bin_acc /= (bi+1)
    val_acc, val_bin_acc = validate(model, val_loader, epoch)

    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        save_checkpoints(epoch, model, optimizer, train_loss, val_acc)

    pbar.set_postfix_str(f'Train loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Bin Acc: {train_bin_acc:.4f}')

### Pong & Kong Model

In [17]:
training_type = 'discard'

In [18]:
train_set = ICJAIDataset(f'/content/drive/MyDrive/mjdata/{training_type}/{training_type}_train.txt', history_len=4, data_ratio=.1)
val_set = ICJAIDataset(f'/content/drive/MyDrive/mjdata/{training_type}/{training_type}_val.txt', history_len=4, data_ratio=1.)

train_loader = DataLoader(dataset=train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8)

val_loader = DataLoader(dataset=val_set,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Calculating class ratios: ', max=91515.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, description='Calculating class ratios: ', max=18788.0, style=ProgressS…




In [19]:
def compute_acc_f1(pred, target):
    '''
    Args:
    - pred (torch.tensor, float32): unnormalized logits (before softmax) shape [bs, 4]
    - target (torch.tensor, int64): shape [bs]

    Returns:
    - acc (float): exact classification accuracy
    - f1 (float)
    '''
    pred = torch.argmax(F.softmax(pred, dim=-1), dim=-1)

    acc = (pred == target).sum()/(target.shape[0])
    f1= f1_score(target, pred, average='weighted')
    return acc.item(), f1

def validate(model, val_loader, epoch):
    model.eval()
    val_acc, val_f1 = 0, 0
    pbar = tqdm(val_loader, desc=f"Epoch {epoch} Validation")
    with torch.no_grad():
        for bi, (X, Y) in enumerate(pbar):
            # Forward
            X, Y = list(map(lambda x: x.to(device), [X, Y]))
            preds = model(X)

            acc, f1 = compute_acc_f1(preds.cpu(), Y.cpu())
            val_acc += acc
            val_f1 += f1

    val_acc /= (bi+1)
    val_f1 /= (bi+1)
    pbar.set_postfix_str(f'Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}')
    return val_acc, val_f1

def save_checkpoints(epoch, model, optimizer, train_loss, val_acc, val_f1):
    save_dir = f'/content/drive/MyDrive/mjdata/{training_type}_ckpts/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, f'ep{epoch}-val_acc_{val_acc:.4f}-val_f1_{val_f1:.4f}.tar')
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, save_path)
    print('Checkpoint saved')

In [20]:
# Model
model = MJModel(history_len=4, n_cls=34)

# Loss function
# cls_weights = [1/train_set.cls_ratios[i] for i in range(n_cls)]
# cls_weights = torch.tensor([w/sum(cls_weights) for w in cls_weights])
# criterion = nn.CrossEntropyLoss(cls_weights)
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=n_epoch, div_factor=20)

In [None]:
model.to(device)
criterion.to(device)

best_val_f1 = 0
for epoch in range(n_epoch):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epoch}")
    train_loss, train_acc, train_f1 = 0, 0, 0
    for bi, (X, Y) in enumerate(pbar):
        optimizer.zero_grad()

        # Forward
        X, Y = list(map(lambda x: x.to(device), [X, Y]))
        preds = model(X)

        # Calculate loss & update
        loss = criterion(preds, Y)
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.detach().item()
        acc, f1 = compute_acc_f1(preds.cpu(), Y.cpu())
        train_acc += acc
        train_f1 += f1

        pbar.set_postfix_str(f'Train loss: {loss.detach().item():.4f} | Train Acc: {(train_acc/(bi+1)):.4f} | Train F1: {(train_f1/(bi+1)):.4f}')

    # End of epoch
    train_loss /= (bi+1)
    train_acc /= (bi+1)
    train_f1 /= (bi+1)
    val_acc, val_f1 = validate(model, val_loader, epoch)

    if val_f1 >= best_val_f1:
        best_val_f1 = val_f1
        save_checkpoints(epoch, model, optimizer, train_loss, val_acc, val_f1)

    pbar.set_postfix_str(f'Train loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, description='Epoch 0/100', max=358.0, style=ProgressStyle(description_…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Epoch 0 Validation', max=74.0, style=ProgressStyle(descri…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 1/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 1 Validation', max=74.0, style=ProgressStyle(descri…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 2/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 2 Validation', max=74.0, style=ProgressStyle(descri…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 3/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 3 Validation', max=74.0, style=ProgressStyle(descri…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 4/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 4 Validation', max=74.0, style=ProgressStyle(descri…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 5/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 5 Validation', max=74.0, style=ProgressStyle(descri…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 6/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 6 Validation', max=74.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 7/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 7 Validation', max=74.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 8/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 8 Validation', max=74.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 9/100', max=358.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 9 Validation', max=74.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Epoch 10/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 10 Validation', max=74.0, style=ProgressStyle(descr…


Checkpoint saved


HBox(children=(FloatProgress(value=0.0, description='Epoch 11/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 11 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 12/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 12 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 13/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 13 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 14/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 14 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 15/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 15 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 16/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 16 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 17/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 17 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 18/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 18 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 19/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 19 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 20/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 20 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 21/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 21 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 22/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 22 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 23/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 23 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 24/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 24 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 25/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 25 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 26/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 26 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 27/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 27 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 28/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 28 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 29/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 29 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 30/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 30 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 31/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 31 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 32/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 32 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 33/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 33 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 34/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 34 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 35/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 35 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 36/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 36 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 37/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 37 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 38/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 38 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 39/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 39 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 40/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 40 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 41/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 41 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 42/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 42 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 43/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 43 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 44/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 44 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 45/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 45 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 46/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 46 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 47/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 47 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 48/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 48 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 49/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 49 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 50/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 50 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 51/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 51 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 52/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 52 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 53/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 53 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 54/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 54 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 55/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 55 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 56/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 56 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 57/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 57 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 58/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 58 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 59/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 59 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 60/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 60 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 61/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 61 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 62/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 62 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 63/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 63 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 64/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 64 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 65/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 65 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 66/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 66 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 67/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 67 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 68/100', max=358.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 68 Validation', max=74.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Epoch 69/100', max=358.0, style=ProgressStyle(description…

## Test

In [None]:
test_set = ICJAIDataset('/content/drive/MyDrive/mjdata/chowable_test.txt', history_len=4)

test_loader = DataLoader(dataset=test_set,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Calculating class ratios: ', max=1237252.0, style=Progres…




In [None]:
# Model
model = ChowModel(history_len=4)
best_ckpt = max(glob.glob('/content/drive/MyDrive/mjdata/ckpts/*.tar'), key=lambda x: float(x.split('/')[-1][-10:-5]))
checkpoint = torch.load(best_ckpt)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
model.to(device)

test_acc, test_bin_acc = 0, 0
with torch.no_grad():
    pbar = tqdm(test_loader, desc="Testing")
    for bi, (test_X, test_Y) in enumerate(pbar):
        # Forward
        test_X, test_Y = list(map(lambda x: x.to(device), [test_X, test_Y]))
        preds = model(test_X)

        # Calculate accuracy
        acc, bin_acc = compute_acc(preds, test_Y)

        test_acc += acc
        test_bin_acc += bin_acc

        pbar.set_postfix_str(f'Test Acc: {(test_acc/(bi+1)):.4f} | Test Bin Acc: {(test_bin_acc/(bi+1)):.4f}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if sys.path[0] == '':


HBox(children=(FloatProgress(value=0.0, description='Testing', max=4834.0, style=ProgressStyle(description_wid…

## Trash

In [9]:
# Model
model = MJModel(history_len=4, n_cls=34)
best_ckpt = '/content/drive/MyDrive/mjdata/discard_ckpts/ep10-val_acc_0.6793-val_f1_0.6787.tar'
checkpoint = torch.load(best_ckpt)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [11]:
test_set = ICJAIDataset('/content/drive/MyDrive/mjdata/discard/discard_test.txt', history_len=4)

# test_loader = DataLoader(dataset=test_set,
#                           batch_size=batch_size,
#                           shuffle=True,
#                           num_workers=8)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, description='Calculating class ratios: ', max=18788.0, style=ProgressS…




In [27]:
x, y = test_set[100]

model.eval()
with torch.no_grad():
  y_pred = model(x.unsqueeze(0))

In [28]:
y

tensor(25)

In [30]:
nn.Softmax(-1)(y_pred), y_pred.argmax(-1)

(tensor([[6.0017e-05, 1.3811e-05, 9.2958e-03, 1.6336e-05, 1.4314e-04, 1.5888e-04,
          2.1670e-06, 5.3210e-03, 6.7318e-05, 5.7189e-05, 1.2867e-04, 2.9143e-05,
          1.6681e-05, 4.7916e-03, 1.1038e-04, 1.8255e-04, 8.6267e-03, 1.5235e-05,
          1.3702e-04, 2.4336e-06, 1.0051e-04, 3.4924e-05, 3.0378e-04, 8.4500e-05,
          2.9281e-04, 9.4456e-01, 2.0951e-02, 3.3996e-04, 2.1272e-05, 3.5763e-03,
          4.0531e-04, 3.2354e-06, 1.3762e-04, 9.2876e-06]]), tensor([25]))

In [9]:
torch.save(model.state_dict(),  '/content/drive/MyDrive/mjdata/discard_ckpts/discard-ep10-val_acc_0.6793-val_f1_0.6787.pth')

In [None]:
b = {'a': deque([]), 'b': deque([])}

b['a'].append(1)
b['a'].append(2)
b['a'][0]

1

In [None]:
# 打牌 摸牌 明杠 暗杠 杠后摸牌 补花 补花后摸牌 碰 吃 和牌

In [None]:
lines[6].split('\t')

['3', '打牌', "['F2']", '\n']

In [None]:
def generate_training_set(file_name, player_num, history_num = 4):
    other_player = [i for i in range(4)]
    other_player.pop(player_num)
    
    master_X = np.empty((0,30, 4, 42), int)
    master_Y = np.empty((0,42), int)
    
    f = open(file_name, "r")
    title = f.readline()
    game_info = f.readline()
    wind = game_info.split('\t')[0]
    score = game_info.split('\t')[1]
    discard_card = [[],[],[],[]]
    past_game_history = deque([np.zeros((6, 4, 42),int) for x in range(history_num)])

    # Get player starting tile information
    for i in range(4):
        player_hand = f.readline()
        start = player_hand.find('[')
        end = player_hand.find(']') + 1
        starting_tile = ast.literal_eval(player_hand[start:end])
        if i == player_num:
            player_tile = starting_tile


    round_info = f.readline()
    
    while round_info != None and len(round_info) > 0:
        try:
            # Extracting information of that round
            round_info = round_info.split('\t')
            round_player_num = int(round_info[0])
            action = round_info[1]
            round_tile = ast.literal_eval(round_info[2])[0]
            eat_tile = round_info[3]

            # Not yet implemneted for target
            # Everyone discard tile can be seen for everyone
            # Now we have feature (X)
            if action == '打牌':
                if player_num == round_player_num:
                    player_tile.remove(round_tile)
                    Y = serialize_y(round_tile)
                discard_card[round_player_num].append(round_tile)

            elif action == '摸牌' or action == '补花后摸牌':
                if player_num == round_player_num:
                    player_tile.append(round_tile)
                else:
                    pass

            elif action == '补花':
                if player_num == round_player_num:
                  # Add in the flower list
                    pass
                else:
                    pass

            elif action == '吃':
                if player_num == round_player_num:
                    player_tile.append(eat_tile)
                else:
                    pass

            round_info = f.readline()

            # If round player number = your player number
            if round_player_num == player_num:
                # Feature encoding
                # Encoding own hand feature
                X = serialize(player_tile) 

                for discard_list in discard_card:
                    # Encoding all the discard card
                    X = np.concatenate((X, serialize(discard_list)), axis = 0) 

                # Encoding the current card
                X = np.concatenate((X, serialize([round_tile])), axis = 0) 

                X_with_history = np.copy(X)

                for past_game_list in past_game_history:
                    X_with_history = np.concatenate((X_with_history, past_game_list), axis = 0) # Encode the past history situation into feature

                past_game_history.append(X) # Append the current situation into the history queue
                past_game_history.pop() # Remove the last information history

#                 print(master_X.shape)
                # print(X.shape)
                # print(X_with_history.shape)
                if action == '打牌':
                    master_X = np.append(master_X, np.expand_dims(X_with_history, axis = 0), axis = 0)
                    master_Y = np.append(master_Y, np.expand_dims(Y, axis = 0), axis = 0)

        except:
            break
            
    f.close()
    return master_X, master_Y

In [None]:
master_X = np.empty((0, 30, 4, 42), int)
master_Y = np.empty((0, 42), int)
file_count = 0


for file_name in all_game_files:
    #for player_num in range(0,4):
    X, Y = generate_training_set(file_name, player_num=0, history_num = 4)
    master_X = np.append(master_X, X, axis = 0)
    master_Y = np.append(master_Y, Y, axis = 0)
    if master_X.shape[0] > 10000:
        break
    # print('Shape', master_X.shape[0])

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30, 4, 42)
(6, 4, 42)
(30,

KeyboardInterrupt: ignored

In [None]:
np.save('./input_X.npx', master_X)
np.save('./input_Y.npx', master_Y)

In [None]:
master_X[0]

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]])