In [1]:
from dataset import MahjongSLDataset
from model_slim import MahJongNetBatchedRevised
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import os
from torch import nn
from tqdm import tqdm
import pandas as pd

# from model import DataEntryModule
from collections import defaultdict

# from model_batched_MNN import MahJongNetBatchedRevised, TileProbModuleNew
from torch.optim.lr_scheduler import (
    SequentialLR,
    LambdaLR,
    CosineAnnealingWarmRestarts,
    ConstantLR,
    ExponentialLR,
)
import copy
import itertools
import matplotlib.pyplot as plt
import json
import feature


def default_zero():
    return 0


plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["figure.figsize"] = (15, 7)
plt.rcParams["font.size"] = 16


# fmt:off
tile_list_raw = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9',  # 饼
                 'W1', 'W2', 'W3', 'W4', 'W5', 'W6', 'W7', 'W8', 'W9',  # 万
                 'T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9',  # 条
                 'F1', 'F2', 'F3', 'F4', 'J1', 'J2', 'J3'  # 风、箭
                 ]
fan_list_raw = ["大四喜", "大三元", "绿一色", "九莲宝灯", "四杠", "连七对", "十三幺", "清幺九", "小四喜", "小三元", "字一色", "四暗刻", "一色双龙会",
                "一色四同顺", "一色四节高", "一色四步高", "三杠", "混幺九", "七对", "七星不靠", "全双刻", "清一色", "一色三同顺", "一色三节高", "全大",
                "全中", "全小", "清龙", "三色双龙会", "一色三步高", "全带五", "三同刻", "三暗刻", "全不靠", "组合龙", "大于五", "小于五", "三风刻", "花龙",
                "推不倒", "三色三同顺", "三色三节高", "无番和", "妙手回春", "海底捞月", "杠上开花", "抢杠和",
                "碰碰和", "混一色", "三色三步高", "五门齐", "全求人", "双暗杠", "双箭刻", "全带幺",
                "不求人", "双明杠", "和绝张", "箭刻", "圈风刻", "门风刻", "门前清", "平和", "四归一",
                "双同刻", "双暗刻", "暗杠", "断幺", "一般高", "喜相逢", "连六", "老少副", "幺九刻",
                "明杠", "缺一门", "无字", "边张", "嵌张", "单钓将", "自摸"]

tile_list_conversion = {'B1':'D1', 'B2':'D2', 'B3':'D3', 'B4':'D4', 'B5':'D5', 'B6':'D6', 'B7':'D7', 'B8':'D8', 'B9':'D9',  # 饼
                 'W1':'C1', 'W2':'C2', 'W3':'C3', 'W4':'C4', 'W5':'C5', 'W6':'C6', 'W7':'C7', 'W8':'C8', 'W9':'C9',  # 万
                 'T1':'B1', 'T2':'B2', 'T3':'B3', 'T4':'B4', 'T5':'B5', 'T6':'B6', 'T7':'B7', 'T8':'B8', 'T9':'B9',  # 条
                 'F1':'EW', 'F2':'SW', 'F3':'WW', 'F4':'NW', 'J1':'RD', 'J2':'GD', 'J3':'WD'  # 风、箭
                 }
# fmt:on

In [2]:
def reversed_tile_conversion(list_rep, key_order_list):
    ret_dict = defaultdict(default_zero)
    for i in range(len(key_order_list)):
        ret_dict[key_order_list[i]] += list_rep[i]
    return ret_dict

In [3]:
def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count


device = "cpu"
net = MahJongNetBatchedRevised(device).to(device)
count_param(net)

126

In [4]:
# change params here
log_name = "07_22_03_21_12-model_slim_7p_normed"

# typically leave as is
logdir = "../log/"
log_suffix = "checkpoint"
extract_dir = "extracted_model_slim_params"
weight_list = {
    "top1": "best_acc.pkl",
    "top2": "best_acc_top2.pkl",
    "top3": "best_acc_top3.pkl",
}

# load model parameter from log
net.load_state_dict(
    torch.load(
        os.path.join(logdir, log_name, log_suffix, weight_list["top1"]),
        map_location=torch.device(device),
    )
)

<All keys matched successfully>

In [5]:
net.prob_module_throw.state_dict()

OrderedDict([('final_layer.0.weight',
              tensor([[13.2868, -2.4386, -6.8902,  3.7414,  4.2955,  4.2075,  0.0544,  0.3670,
                       -1.7708,  0.5356,  0.4127]])),
             ('final_layer.0.bias', tensor([3.9621]))])

In [6]:
fan_output = reversed_tile_conversion(net.fan_coeff.detach().numpy(), fan_list_raw)

In [7]:
fan_output = sorted(fan_output.items(), key=lambda x: -x[1])

In [8]:
fan_output

[('七对', 8.898839950561523),
 ('一色双龙会', 1.2458250522613525),
 ('一色四节高', 0.6864998936653137),
 ('组合龙', 0.19646704196929932),
 ('十三幺', 0.1776566207408905),
 ('全双刻', 0.12358742207288742),
 ('海底捞月', 0.123080775141716),
 ('无番和', 0.12171602249145508),
 ('大四喜', 0.1164558157324791),
 ('四暗刻', 0.11355960369110107),
 ('全大', 0.11242605000734329),
 ('四杠', 0.1100180596113205),
 ('杠上开花', 0.10616981238126755),
 ('暗杠', 0.09763997048139572),
 ('一色四同顺', 0.09622961282730103),
 ('三杠', 0.09620863199234009),
 ('妙手回春', 0.09031521528959274),
 ('抢杠和', 0.08968039602041245),
 ('双明杠', 0.08683235943317413),
 ('双暗杠', 0.08668944984674454),
 ('清幺九', 0.08600226044654846),
 ('小四喜', 0.08466795086860657),
 ('七星不靠', 0.08460734784603119),
 ('绿一色', 0.08410496264696121),
 ('一色三同顺', 0.0808059424161911),
 ('字一色', 0.07909941673278809),
 ('门前清', 0.07840710878372192),
 ('三色双龙会', 0.07557263970375061),
 ('明杠', 0.0725608691573143),
 ('一色三节高', 0.06913502514362335),
 ('推不倒', 0.06112995371222496),
 ('一色四步高', 0.058565206825733185),
 ('连七对

In [9]:
tile_output = reversed_tile_conversion(net.tile_coeff.detach().numpy(), tile_list_raw)

In [10]:
_tile_pref = {
    "B1": 0.99,
    "B2": 0.991,
    "B3": 0.992,
    "B4": 0.993,
    "B5": 0.994,
    "B6": 0.995,
    "B7": 0.996,
    "B8": 0.996,
    "B9": 0.997,  # 饼
    "W1": 0.998,
    "W2": 0.999,
    "W3": 1.0,
    "W4": 1.001,
    "W5": 1.002,
    "W6": 1.003,
    "W7": 1.004,
    "W8": 1.005,
    "W9": 1.006,  # 万
    "T1": 0.989,
    "T2": 0.988,
    "T3": 0.987,
    "T4": 0.987,
    "T5": 0.986,
    "T6": 0.985,
    "T7": 0.984,
    "T8": 0.983,
    "T9": 0.982,  # 条
    "F1": 1.007,
    "F2": 1.008,
    "F3": 1.009,
    "F4": 1.01,
    "J1": 1.011,
    "J2": 1.012,
    "J3": 1.013,  # 风、箭
}

In [11]:
known_tile_pref = sorted(_tile_pref.items(), key=lambda x: -x[1])

In [12]:
l1 = []
for p in known_tile_pref:
    l1.append(p[0])

In [13]:
tile_output = sorted(tile_output.items(), key=lambda x: -x[1])

In [18]:
l2 = []
l2d = {}
for p in tile_output:
    l2.append(p[0])
    l2d[p[0]]=p[1]

In [15]:
l1_converted = [tile_list_conversion[x] for x in l1]
l2_converted = [tile_list_conversion[x] for x in l2]

In [16]:
df = pd.DataFrame()

In [19]:
l2d

{'J3': 1.413110375404358,
 'J2': 1.1383740901947021,
 'J1': 1.0742130279541016,
 'F4': 0.9988686442375183,
 'F3': 0.8750644326210022,
 'F2': 0.8038995265960693,
 'W9': 0.7539696097373962,
 'F1': 0.749397337436676,
 'W8': 0.7057797312736511,
 'W7': 0.6879855394363403,
 'W6': 0.6517385244369507,
 'W5': 0.6148751974105835,
 'W4': 0.5762643218040466,
 'W3': 0.5414806604385376,
 'W2': 0.4989180862903595,
 'W1': 0.4630112946033478,
 'B9': 0.44540563225746155,
 'B8': 0.4367682635784149,
 'B7': 0.43627840280532837,
 'B6': 0.4314165711402893,
 'B5': 0.4152246415615082,
 'B4': 0.4089546501636505,
 'B3': 0.39770740270614624,
 'T1': 0.39184990525245667,
 'B2': 0.39055466651916504,
 'B1': 0.3822522461414337,
 'T2': 0.3620748221874237,
 'T3': 0.32989853620529175,
 'T4': 0.3259626030921936,
 'T5': 0.2986399531364441,
 'T6': 0.2768508791923523,
 'T7': 0.2553465664386749,
 'T8': 0.2364477515220642,
 'T9': 0.2165195345878601}

In [30]:
df["Known Tile Order"] = l1_converted
df["Learned Tile Order"] = l2_converted
df["known order weight"] = [l2d[x] for x in l1]

In [31]:
print(
    df.to_latex(
        index=False,
        formatters={"name": str.upper},
        float_format="{:.3f}".format,
    )
)

\begin{tabular}{llr}
\toprule
Known Tile Order & Learned Tile Order &  known order weight \\
\midrule
              WD &                 WD &               1.413 \\
              GD &                 GD &               1.138 \\
              RD &                 RD &               1.074 \\
              NW &                 NW &               0.999 \\
              WW &                 WW &               0.875 \\
              SW &                 SW &               0.804 \\
              EW &                 C9 &               0.749 \\
              C9 &                 EW &               0.754 \\
              C8 &                 C8 &               0.706 \\
              C7 &                 C7 &               0.688 \\
              C6 &                 C6 &               0.652 \\
              C5 &                 C5 &               0.615 \\
              C4 &                 C4 &               0.576 \\
              C3 &                 C3 &               0.541 \\
              C2

  df.to_latex(


In [33]:
with open("a.txt", "w+") as f:
    f.writelines(
        df.to_latex(
            index=True,
            formatters={"name": str.upper},
            float_format="{:.4f}".format,
        )
    )

  df.to_latex(


In [10]:
path_to_ds = "../tmp"
ds = MahjongSLDataset(path_to_ds, 0.0, 1)
loader = DataLoader(dataset=ds, batch_size=1, shuffle=False, num_workers=4)

In [11]:
d = next(iter(loader))

In [12]:
d[1] == ds[0][1]

tensor([[True, True, True, True, True, True]])

In [13]:
path_to_data = "../data"
data_name = "0.npy"
(
    botzone_log,  # 1
    tileWall_log,  # 2
    pack_log,  # 3
    handWall_log,  # 4
    obsWall_log,  # 5
    remaining_tile_log,  # 6
    botzone_id,  # 7
    winner_id,  # 8
    prevailing_wind,  # 9
    fan_sum,  # 10
    score,
    fan_list,  # 11
) = feature.load_log(path_to_data, data_name)

In [14]:
(
    meta,
    meta_feature_new,
    tile_wall_feature,
    k,
    q,
    m,
    n,
    o,
    # fan_summary,
    label_action,
    label_tile,
    v_info,
) = d

(
    meta,
    meta_feature_new,
    tile_wall_feature,
    k,
    q,
    m,
    n,
    o,
    # fan_summary,
    label_action,
    label_tile,
) = (
    meta.to(device),
    meta_feature_new.to(device),
    tile_wall_feature.to(device),
    k.to(device),
    q.to(device),
    m.to(device),
    n.to(device),
    o.to(device),
    # data_description.to(device).float(),
    # fan_summary.to(device),
    label_action.to(device),
    label_tile.to(device),
)
data1 = (k, q, m, n, o)

index_list_plus_one = torch.tensor([i + 1 for i in range(len(meta))]).to(device)
x = (
    meta_feature_new,
    tile_wall_feature,
    data1,
)

In [15]:
(meta_feature, tile_wall_feature, search_matrix) = x

(
    tile_prep_data,
    fan_prep,
    missing_tile_prep_data,
    count_prep,
    chi_peng_count_remain_data,
) = search_matrix

In [16]:
action, tile = net(x)

In [17]:
network_output_dict = reversed_tile_conversion(tile[0], tile_list_raw)
for key in tile_list_raw:
    if network_output_dict[key] == 0.0:
        del network_output_dict[key]

In [19]:
network_output_dict

defaultdict(<function __main__.default_zero()>,
            {'B2': tensor(1.1383, grad_fn=<AddBackward0>),
             'B7': tensor(4.5847, grad_fn=<AddBackward0>),
             'W3': tensor(1.2013, grad_fn=<AddBackward0>),
             'W6': tensor(1.2322, grad_fn=<AddBackward0>),
             'W9': tensor(4.4820, grad_fn=<AddBackward0>),
             'T1': tensor(1.1608, grad_fn=<AddBackward0>),
             'T6': tensor(5.2495, grad_fn=<AddBackward0>),
             'T8': tensor(4.2498, grad_fn=<AddBackward0>),
             'T9': tensor(5.2461, grad_fn=<AddBackward0>),
             'F4': tensor(1.1229, grad_fn=<AddBackward0>),
             'J1': tensor(1.4294, grad_fn=<AddBackward0>),
             'J3': tensor(1.4016, grad_fn=<AddBackward0>)})