In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import glob

class MahjongDataset(Dataset):
    def __init__(self, data_paths, load=True):
        self.x_dict = {
            "x_hand": None, 
            "x_hand_red": None, 
            "x_river": None, 
            "x_river_red": None, 
            "x_river_riichi": None, 
            "x_meld": None, 
            "x_meld_red5": None, 
            "x_meld_throw": None, 
            "x_meld_throw_red": None, 
            "x_dora": None, 
            "x_score": None, 
            "x_pool": None, 
            "x_winds": None}
        self.y = None
        self.loads = list(map(np.load, data_paths))
        self.loads = [np.load(path) for path in data_paths]

        if load:
            self.load_index = -1
            self.load_next()

    def unload(self):
        self.y = None
        self.x = None
        for key in self.x_dict.keys():
            self.x_dict[key] = None
    
    def load_next(self):
        self.load_index += 1
        self.load_index %= len(self.loads)
        self.y = self.loads[self.load_index]["y"].astype(np.float32)
        for key in self.x_dict.keys():
            self.x_dict[key] = self.loads[self.load_index][key].astype(np.float32)
        self.post_load()
    
    def post_load(self):
        # Data pre-processing
        # Divide by 250 to average out the score values
        self.y /= 250
        self.x_dict["x_score"] /= 250
        self.x_dict["x_pool"] /= 250
        # To predict score difference at the end of the round
        self.y -= self.x_dict["x_score"]  
        # Resize all input to B, X, 24, 9
        self.x_dict["x_hand"] = np.repeat(self.x_dict["x_hand"], 6, -2)
        self.x_dict["x_hand_red"] = np.expand_dims(self.x_dict["x_hand_red"], (-1, -2))
        self.x_dict["x_hand_red"] = np.repeat(self.x_dict["x_hand_red"], 9, -1)
        self.x_dict["x_hand_red"] = np.repeat(self.x_dict["x_hand_red"], 24, -2)
        self.x_dict["x_river"] = np.reshape(self.x_dict["x_river"], (*self.x_dict["x_river"].shape[:1], -1, *self.x_dict["x_river"].shape[3:]))
        self.x_dict["x_river_red"] = np.expand_dims(self.x_dict["x_river_red"], -1)
        self.x_dict["x_river_red"] = np.repeat(self.x_dict["x_river_red"], 9, -1)
        self.x_dict["x_river_riichi"] = np.expand_dims(self.x_dict["x_river_riichi"], -1)
        self.x_dict["x_river_riichi"] = np.repeat(self.x_dict["x_river_riichi"], 9, -1)
        self.x_dict["x_meld"] = np.reshape(self.x_dict["x_meld"], (*self.x_dict["x_meld"].shape[:1], -1, *self.x_dict["x_meld"].shape[3:]))
        self.x_dict["x_meld"] = np.repeat(self.x_dict["x_meld"], 6, -2)
        self.x_dict["x_meld_red5"] = np.reshape(self.x_dict["x_meld_red5"], (self.x_dict["x_meld_red5"].shape[0], -1))
        self.x_dict["x_meld_red5"] = np.expand_dims(self.x_dict["x_meld_red5"], (-1, -2))
        self.x_dict["x_meld_red5"] = np.repeat(self.x_dict["x_meld_red5"], 9, -1)
        self.x_dict["x_meld_red5"] = np.repeat(self.x_dict["x_meld_red5"], 24, -2)
        self.x_dict["x_meld_throw"] = np.reshape(self.x_dict["x_meld_throw"], (self.x_dict["x_meld_throw"].shape[0], -1, *self.x_dict["x_meld_throw"].shape[3:]))
        self.x_dict["x_meld_throw"] = np.repeat(self.x_dict["x_meld_throw"], 6, -2)
        self.x_dict["x_meld_throw_red"] = np.reshape(self.x_dict["x_meld_throw_red"], (self.x_dict["x_meld_throw_red"].shape[0], -1))
        self.x_dict["x_meld_throw_red"] = np.expand_dims(self.x_dict["x_meld_throw_red"], (-1, -2))
        self.x_dict["x_meld_throw_red"] = np.repeat(self.x_dict["x_meld_throw_red"], 9, -1)
        self.x_dict["x_meld_throw_red"] = np.repeat(self.x_dict["x_meld_throw_red"], 24, -2)
        self.x_dict["x_dora"] = np.repeat(self.x_dict["x_dora"], 6, -2)
        self.x_dict["x_score"] = np.expand_dims(self.x_dict["x_score"], (-1, -2))
        self.x_dict["x_score"] = np.repeat(self.x_dict["x_score"], 9, -1)
        self.x_dict["x_score"] = np.repeat(self.x_dict["x_score"], 24, -2)
        self.x_dict["x_pool"] = np.expand_dims(self.x_dict["x_pool"], (-1, -2, -3))
        self.x_dict["x_pool"] = np.repeat(self.x_dict["x_pool"], 9, -1)
        self.x_dict["x_pool"] = np.repeat(self.x_dict["x_pool"], 24, -2)
        self.x_dict["x_winds"] = np.expand_dims(self.x_dict["x_winds"], -2)
        self.x_dict["x_winds"] = np.repeat(self.x_dict["x_winds"], 24, -2)
        
        # for varname in self.x_dict.keys():
        #     print(varname, self.x_dict[varname].shape)

        self.x = np.concatenate(list(self.x_dict.values()), axis=1)

        # Unload dict to save memory
        for key in self.x_dict.keys():
            self.x_dict[key] = None
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.x[idx], self.y[idx]

# train_ds = MahjongDataset(sorted(glob.glob("dataset/riichi_ds_v02/*.npz"))[:150])
# test_ds = MahjongDataset(sorted(glob.glob("dataset/riichi_ds_v02/*.npz"))[150:])

full_ds = MahjongDataset(sorted(glob.glob("dataset/riichi_ds_v02/*.npz")))

In [2]:
import torch

# test_dl = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=True)
# train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

full_dl = torch.utils.data.DataLoader(full_ds, batch_size=64, shuffle=True, drop_last=True)

### Models

In [3]:
# Increasing layers

from torchsummary import summary


class ModelV1(torch.nn.Module):
    def __init__(self):
        super(ModelV1, self).__init__()
        self.relu = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()
        self.max_pool = torch.nn.MaxPool2d(3)
        
        self.conv2d_1 = torch.nn.Conv2d(98, 256, (6, 7), padding=(0, 3))
        self.bn_1 = torch.nn.BatchNorm2d(256)
        
        self.conv2d_2 = torch.nn.Conv2d(256, 256, (6, 7), padding=(0, 3))
        self.bn_2 = torch.nn.BatchNorm2d(256)
        
        self.conv2d_3 = torch.nn.Conv2d(256, 256, (6, 7), padding=(0, 3))
        self.bn_3 = torch.nn.BatchNorm2d(256)
        
        self.conv2d_4 = torch.nn.Conv2d(256, 512, 3, padding=1)
        self.bn_4 = torch.nn.BatchNorm2d(512)
        
        self.conv2d_5 = torch.nn.Conv2d(512, 1024, 3, padding=1)
        self.bn_5 = torch.nn.BatchNorm2d(1024)
        
        self.fc_1 = torch.nn.Linear(9216, 256)
        self.bn_6 = torch.nn.BatchNorm1d(256)
        
        self.fc_2 = torch.nn.Linear(256, 4)
    
    def forward(self, x):
        x = self.bn_1(self.relu(self.conv2d_1(x)))
        x = self.bn_2(self.relu(self.conv2d_2(x)))
        x = self.bn_3(self.relu(self.conv2d_3(x)))
        x = self.bn_4(self.relu(self.conv2d_4(x)))
        x = self.max_pool(x)
        x = self.bn_5(self.relu(self.conv2d_5(x)))
        x = self.flatten(x)
        x = self.bn_6(self.relu(self.fc_1(x)))
        x = self.fc_2(x)
        return x

summary(ModelV1(), (98, 24, 9), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 256, 19, 9]       1,053,952
              ReLU-2           [-1, 256, 19, 9]               0
       BatchNorm2d-3           [-1, 256, 19, 9]             512
            Conv2d-4           [-1, 256, 14, 9]       2,752,768
              ReLU-5           [-1, 256, 14, 9]               0
       BatchNorm2d-6           [-1, 256, 14, 9]             512
            Conv2d-7            [-1, 256, 9, 9]       2,752,768
              ReLU-8            [-1, 256, 9, 9]               0
       BatchNorm2d-9            [-1, 256, 9, 9]             512
           Conv2d-10            [-1, 512, 9, 9]       1,180,160
             ReLU-11            [-1, 512, 9, 9]               0
      BatchNorm2d-12            [-1, 512, 9, 9]           1,024
        MaxPool2d-13            [-1, 512, 3, 3]               0
           Conv2d-14           [-1, 102

### Training

In [6]:
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import datetime
import tqdm

EPOCHS = 150
RUN_PATH = "runs/riichi-dsv02-mv01"

timestamp = datetime.datetime.today().strftime("%Y%m%d_%H%M%S")
full_writer = SummaryWriter(f"{RUN_PATH}/{timestamp}")
self_writer = SummaryWriter(f"{RUN_PATH}/{timestamp}/self")
early_writer = SummaryWriter(f"{RUN_PATH}/{timestamp}/early")
mid_writer = SummaryWriter(f"{RUN_PATH}/{timestamp}/mid")
late_writer = SummaryWriter(f"{RUN_PATH}/{timestamp}/late")

model = ModelV1()
model = torch.nn.DataParallel(model)
model.to("cuda")

get_mae = torch.nn.L1Loss()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_deque = deque(maxlen=len(full_dl))

write_counter = 0
for epoch in range(EPOCHS):
    pbar = tqdm.tqdm(
        range(len(full_dl)), 
        desc=f"Epoch {epoch+1}/{EPOCHS}", 
        position=0, leave=True, ncols=90)
    
    model.train(True)
    for i, (x, y) in enumerate(full_dl):
        x, y = x.to("cuda"), y.to("cuda")
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        
        full_mae = get_mae(pred, y) * 25000
        self_mae = get_mae(pred[:, 0], y[:, 0]) * 25000
        tile_nums = torch.sum(x[:, 7:23], dim=(1, 2, 3))
        early_game, mid_game, late_game = tile_nums <= 24, (24 < tile_nums) & (tile_nums <= 48), 48 <= tile_nums
        early_mae = get_mae(pred[early_game], y[early_game]) * 25000
        mid_mae = get_mae(pred[mid_game], y[mid_game]) * 25000
        late_mae = get_mae(pred[late_game], y[late_game]) * 25000

        loss_deque.append(loss.item())
        full_writer.add_scalar("loss", loss.item(), write_counter)
        full_writer.add_scalar("score_mae", full_mae, write_counter)
        self_writer.add_scalar("score_mae", self_mae, write_counter)
        early_writer.add_scalar("score_mae", early_mae, write_counter)
        mid_writer.add_scalar("score_mae", mid_mae, write_counter)
        late_writer.add_scalar("score_mae", late_mae, write_counter)
        write_counter += 1
        pbar.set_postfix({"loss": f"{sum(loss_deque) / len(loss_deque):.3f}"})
        pbar.update()
    full_ds.load_next()
pbar.close()

Epoch 6/150:   0%|                                                | 0/872 [00:05<?, ?it/s]
Epoch 1/150: 100%|██████████████████████████| 872/872 [01:16<00:00, 11.39it/s, loss=0.041]
Epoch 2/150: 100%|██████████████████████████| 832/832 [01:11<00:00, 11.58it/s, loss=0.033]
Epoch 3/150: 100%|██████████████████████████| 826/826 [01:12<00:00, 11.35it/s, loss=0.028]
Epoch 4/150: 100%|██████████████████████████| 862/862 [01:14<00:00, 11.52it/s, loss=0.029]
Epoch 5/150: 100%|██████████████████████████| 833/833 [01:09<00:00, 12.02it/s, loss=0.030]
Epoch 6/150: 100%|██████████████████████████| 826/826 [01:10<00:00, 11.77it/s, loss=0.031]
Epoch 7/150: 100%|██████████████████████████| 843/843 [01:10<00:00, 11.99it/s, loss=0.030]
Epoch 8/150: 100%|██████████████████████████| 852/852 [01:10<00:00, 12.02it/s, loss=0.030]
Epoch 9/150: 100%|██████████████████████████| 830/830 [01:09<00:00, 11.99it/s, loss=0.031]
Epoch 10/150: 100%|█████████████████████████| 840/840 [01:10<00:00, 11.99it/s, loss=0.029]

### Evaluation

In [7]:
torch.save(model.state_dict(), "models/model.pth")

In [8]:
torch.set_printoptions(sci_mode=False)
x, y = next(iter(full_dl))
x, y = x.to("cuda"), y.to("cuda")

model(x) * 25000

tensor([[ -1118.0691,  -2460.5576,   2965.7051,    849.8802],
        [    16.0355,   3333.3032,   1550.1235,  -4873.4150],
        [ -3164.3574,   6006.6348,   -364.6023,  -2448.3506],
        [ -5868.3174,   3699.3816,    590.9872,   1628.3444],
        [ -2604.8049,  -1888.7352,   1894.8672,   2862.0574],
        [   333.4638,    873.0424,  -1424.9915,    399.4430],
        [  3947.2683,   -760.9034,   -234.7463,  -2356.8179],
        [ -3137.8054,    892.3446,  -1313.4896,   3474.6860],
        [  -739.7477,    338.8961,    307.7733,      0.3733],
        [  4938.5493,    -55.2997,  -2705.8667,  -2094.5667],
        [  3846.3994,   5107.6597,  -2973.0164,  -5757.3398],
        [ -3107.0449,   4115.6987,  -1431.2545,    387.8928],
        [  -193.6262,  -3247.8831,   -585.0303,   4402.7236],
        [ -1435.0062,  -1394.9271,  -2990.1636,   5841.6812],
        [    84.0741,  -1261.3650,    513.6969,    843.0516],
        [   607.7133,   -226.2752,   -448.8474,    332.2622],
        

In [9]:
y * 25000

tensor([[     0.0000,  -1000.0005,      0.0000,    999.9990],
        [     0.0000,  16999.9980,      0.0000, -17000.0000],
        [   999.9990,      0.0000,   -999.9990,      0.0000],
        [ -1000.0005,   -999.9990,   3000.0000,  -1000.0005],
        [  -400.0008,   -599.9998,   1399.9999,   -400.0008],
        [  -999.9990,   6000.0000,   -999.9990,  -3000.0000],
        [     0.0000,      0.0000,  -1800.0006,   2799.9998],
        [     0.0000,      0.0000,  -1000.0020,   1000.0005],
        [     0.0000,  -8700.0000,   8700.0000,      0.0000],
        [   999.9990,  -3000.0000,      0.0000,      0.0000],
        [ 13299.9990,   7699.9990,      0.0000, -20000.0000],
        [     0.0000,      0.0000,  -8300.0000,   8300.0000],
        [     0.0000,   7999.9985,  -8000.0015,      0.0000],
        [ -1999.9996,  -1999.9996,  -1999.9996,   7000.0010],
        [     0.0000,  16000.0010,      0.0000, -16000.0000],
        [     0.0000,      0.0000,  -7700.0005,   7700.0005],
        