In [1]:
from dataset import MahjongSLDataset
from model_slim import MahJongNetBatchedRevised
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
import os
from torch import nn
from tqdm import tqdm

# from model import DataEntryModule
from collections import defaultdict

# from model_batched_MNN import MahJongNetBatchedRevised, TileProbModuleNew
from torch.optim.lr_scheduler import (
    SequentialLR,
    LambdaLR,
    CosineAnnealingWarmRestarts,
    ConstantLR,
    ExponentialLR,
)
import copy
import itertools
import matplotlib.pyplot as plt
import json
import feature


def default_zero():
    return 0


plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["figure.figsize"] = (15, 7)
plt.rcParams["font.size"] = 16


# fmt:off
tile_list_raw = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9',  # 饼
                 'W1', 'W2', 'W3', 'W4', 'W5', 'W6', 'W7', 'W8', 'W9',  # 万
                 'T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9',  # 条
                 'F1', 'F2', 'F3', 'F4', 'J1', 'J2', 'J3'  # 风、箭
                 ]
fan_list_raw = ["大四喜", "大三元", "绿一色", "九莲宝灯", "四杠", "连七对", "十三幺", "清幺九", "小四喜", "小三元", "字一色", "四暗刻", "一色双龙会",
                "一色四同顺", "一色四节高", "一色四步高", "三杠", "混幺九", "七对", "七星不靠", "全双刻", "清一色", "一色三同顺", "一色三节高", "全大",
                "全中", "全小", "清龙", "三色双龙会", "一色三步高", "全带五", "三同刻", "三暗刻", "全不靠", "组合龙", "大于五", "小于五", "三风刻", "花龙",
                "推不倒", "三色三同顺", "三色三节高", "无番和", "妙手回春", "海底捞月", "杠上开花", "抢杠和",
                "碰碰和", "混一色", "三色三步高", "五门齐", "全求人", "双暗杠", "双箭刻", "全带幺",
                "不求人", "双明杠", "和绝张", "箭刻", "圈风刻", "门风刻", "门前清", "平和", "四归一",
                "双同刻", "双暗刻", "暗杠", "断幺", "一般高", "喜相逢", "连六", "老少副", "幺九刻",
                "明杠", "缺一门", "无字", "边张", "嵌张", "单钓将", "自摸"]
# fmt:on

In [2]:
def reversed_tile_conversion(list_rep, key_order_list):
    ret_dict = defaultdict(default_zero)
    for i in range(len(key_order_list)):
        ret_dict[key_order_list[i]] += list_rep[i]
    return ret_dict

In [2]:
def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count


device = "cpu"
net = MahJongNetBatchedRevised(device).to(device)
count_param(net)

114

In [3]:
from model_TFx_OB_v6_ls import MahJongNetBatchedRevised
def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count


device = "cpu"
net = MahJongNetBatchedRevised(device).to(device)
count_param(net)

200562

In [4]:
# change params here
log_name = "07_16_02_53_57-model_slim_rev6_normed2"

# typically leave as is
logdir = "../log/"
log_suffix = "checkpoint"
extract_dir = "extracted_model_slim_params"
weight_list = {
    "top1": "best_acc.pkl",
    "top2": "best_acc_top2.pkl",
    "top3": "best_acc_top3.pkl",
}
# create path for model parameter extraction
extract_path = os.path.join(extract_dir, "top2")
if not os.path.isdir(extract_path):
    os.makedirs(extract_path)

# load model parameter from log
net.load_state_dict(
    torch.load(
        os.path.join(logdir, log_name, log_suffix, weight_list["top2"]),
        map_location=torch.device(device),
    )
)

<All keys matched successfully>

In [14]:
fan_output = reversed_tile_conversion(net.fan_coeff.detach().numpy(), fan_list_raw)

In [15]:
fan_output = sorted(fan_output.items(), key=lambda x: -x[1])

In [18]:
fan_output

[('十三幺', 1.0015867948532104),
 ('全不靠', 0.6198009848594666),
 ('清幺九', 0.3847324252128601),
 ('组合龙', 0.3242482841014862),
 ('大四喜', 0.30645227432250977),
 ('九莲宝灯', 0.30010512471199036),
 ('七对', 0.2564011514186859),
 ('一色四同顺', 0.21450115740299225),
 ('四暗刻', 0.17405714094638824),
 ('绿一色', 0.16177420318126678),
 ('小四喜', 0.15118086338043213),
 ('一色双龙会', 0.13844110071659088),
 ('全小', 0.12439137697219849),
 ('海底捞月', 0.123080775141716),
 ('无番和', 0.12171602249145508),
 ('三色双龙会', 0.11615011096000671),
 ('四杠', 0.1100180596113205),
 ('杠上开花', 0.10616981238126755),
 ('推不倒', 0.10579821467399597),
 ('一色三节高', 0.10329914093017578),
 ('大于五', 0.09802407026290894),
 ('三同刻', 0.09150362759828568),
 ('妙手回春', 0.09031521528959274),
 ('清一色', 0.09017141908407211),
 ('抢杠和', 0.08968039602041245),
 ('七星不靠', 0.08460734784603119),
 ('花龙', 0.08459737151861191),
 ('三色三节高', 0.08418107032775879),
 ('小于五', 0.07953164726495743),
 ('门前清', 0.07840710878372192),
 ('全带幺', 0.07794518768787384),
 ('全大', 0.0767446905374527),
 ('混一色'

In [11]:
tile_output = reversed_tile_conversion(net.tile_coeff.detach().numpy(), tile_list_raw)

In [12]:
tile_output = sorted(tile_output.items(), key=lambda x: -x[1])

In [13]:
tile_output

[('J3', 1.067261815071106),
 ('J2', 1.0413070917129517),
 ('J1', 1.026802897453308),
 ('F4', 0.989987850189209),
 ('F3', 0.9676172137260437),
 ('F1', 0.9501000642776489),
 ('F2', 0.9354256391525269),
 ('W8', 0.8719266653060913),
 ('W5', 0.8695725202560425),
 ('W9', 0.8673602938652039),
 ('W2', 0.8654977083206177),
 ('W1', 0.8652534484863281),
 ('T3', 0.8609082102775574),
 ('B9', 0.86046302318573),
 ('W3', 0.8592910170555115),
 ('W4', 0.8553491234779358),
 ('W7', 0.8547901511192322),
 ('T1', 0.8493829369544983),
 ('T2', 0.8435449600219727),
 ('W6', 0.8431706428527832),
 ('B3', 0.8396374583244324),
 ('T4', 0.8283160328865051),
 ('T5', 0.8252464532852173),
 ('T7', 0.8227099180221558),
 ('T6', 0.8167823553085327),
 ('B8', 0.8153514862060547),
 ('B7', 0.8137266039848328),
 ('B5', 0.809043288230896),
 ('T9', 0.808642566204071),
 ('B4', 0.8079805374145508),
 ('T8', 0.8060420155525208),
 ('B6', 0.8037850260734558),
 ('B1', 0.8004127740859985),
 ('B2', 0.7946748733520508)]

In [5]:
path_to_ds = "../tmp"
ds = MahjongSLDataset(path_to_ds, 0.0, 1)
loader = DataLoader(dataset=ds, batch_size=1, shuffle=False, num_workers=4)

In [6]:
d = next(iter(loader))

In [7]:
d[1] == ds[0][1]

tensor([[True, True, True, True, True, True]])

In [8]:
path_to_data = "../data"
data_name = "0.npy"
(
    botzone_log,  # 1
    tileWall_log,  # 2
    pack_log,  # 3
    handWall_log,  # 4
    obsWall_log,  # 5
    remaining_tile_log,  # 6
    botzone_id,  # 7
    winner_id,  # 8
    prevailing_wind,  # 9
    fan_sum,  # 10
    score,
    fan_list,  # 11
) = feature.load_log(path_to_data, data_name)

In [9]:
(
    meta,
    meta_feature_new,
    tile_wall_feature,
    k,
    q,
    m,
    n,
    o,
    # fan_summary,
    label_action,
    label_tile,
    v_info,
) = d

(
    meta,
    meta_feature_new,
    tile_wall_feature,
    k,
    q,
    m,
    n,
    o,
    # fan_summary,
    label_action,
    label_tile,
) = (
    meta.to(device),
    meta_feature_new.to(device),
    tile_wall_feature.to(device),
    k.to(device),
    q.to(device),
    m.to(device),
    n.to(device),
    o.to(device),
    # data_description.to(device).float(),
    # fan_summary.to(device),
    label_action.to(device),
    label_tile.to(device),
)
data1 = (k, q, m, n, o)

index_list_plus_one = torch.tensor([i + 1 for i in range(len(meta))]).to(device)
x = (
    meta_feature_new,
    tile_wall_feature,
    data1,
)

In [10]:
(meta_feature, tile_wall_feature, search_matrix) = x

(
    tile_prep_data,
    fan_prep,
    missing_tile_prep_data,
    count_prep,
    chi_peng_count_remain_data,
) = search_matrix

In [11]:
action, tile = net(x)

In [12]:
network_output_dict = reversed_tile_conversion(tile[0], tile_list_raw)
for key in tile_list_raw:
    if network_output_dict[key] == 0.0:
        del network_output_dict[key]

In [13]:
network_output_dict

defaultdict(<function __main__.default_zero()>,
            {'B2': tensor(1.1318, grad_fn=<AddBackward0>),
             'B7': tensor(3.6527, grad_fn=<AddBackward0>),
             'W3': tensor(1.2274, grad_fn=<AddBackward0>),
             'W6': tensor(1.2103, grad_fn=<AddBackward0>),
             'W9': tensor(5.2198, grad_fn=<AddBackward0>),
             'T1': tensor(0.9428, grad_fn=<AddBackward0>),
             'T6': tensor(5.1672, grad_fn=<AddBackward0>),
             'T8': tensor(4.5847, grad_fn=<AddBackward0>),
             'T9': tensor(6.1355, grad_fn=<AddBackward0>),
             'F4': tensor(1.7701, grad_fn=<AddBackward0>),
             'J1': tensor(2.0272, grad_fn=<AddBackward0>),
             'J3': tensor(2.2050, grad_fn=<AddBackward0>)})