In [1]:
import numpy as np

# Path of a single dataset shard, loaded here only to inspect what it contains.
DATASET_PATH = "dataset/riichi_ds_v02/0003000.npz"

# np.load on a .npz archive returns an NpzFile; individual arrays are fetched by key.
loads = np.load(DATASET_PATH)
# Disabled leftover: it reads keys "x"/"y" and uses `torch`, which is not imported in this
# cell; "x" is also not among the key names this notebook reads from the archive below.
# x, y = torch.tensor(loads["x"].astype(np.float32)), torch.tensor(loads["y"].astype(np.float32))

In [2]:
# Peek at the archive's keys. NOTE: the printed KeysView repr is abbreviated (it ends
# in "..."), so it does not show every key; `loads.files` holds the complete list.
print(loads.keys())

KeysView(NpzFile 'dataset/riichi_ds_v02/0003000.npz' with keys: x_hand, x_hand_red, x_river, x_river_red, x_river_riichi...)


In [3]:
# Keys read from the shard: the thirteen "x_*" input arrays followed by the target "y".
varnames = (
    "x_hand",
    "x_hand_red",
    "x_river",
    "x_river_red",
    "x_river_riichi",
    "x_meld",
    "x_meld_red5",
    "x_meld_throw",
    "x_meld_throw_red",
    "x_dora",
    "x_score",
    "x_pool",
    "x_winds",
    "y",
)

# Print each key on one line and the shape of its array on the next.
for varname in varnames:
    print(varname, loads[varname].shape, sep="\n")

x_hand
(799885, 4, 4, 9)
x_hand_red
(799885, 3)
x_river
(799885, 4, 4, 24, 9)
x_river_red
(799885, 4, 24)
x_river_riichi
(799885, 4, 24)
x_meld
(799885, 4, 4, 4, 9)
x_meld_red5
(799885, 4, 3)
x_meld_throw
(799885, 4, 4, 4, 9)
x_meld_throw_red
(799885, 4, 3)
x_dora
(799885, 4, 4, 9)
x_score
(799885, 4)
x_pool
(799885,)
x_winds
(799885, 2, 9)
y
(799885, 4)


In [4]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import glob

class MahjongDataset(Dataset):
    """Map-style dataset that serves one ``.npz`` shard at a time.

    Every path in ``data_paths`` is opened with ``np.load`` up front, but only
    one shard is converted to float32 and held in memory at any moment.
    ``load_next`` swaps in the following shard (wrapping around after the last
    one), so ``len(dataset)`` and the valid index range change with each call.

    Args:
        data_paths: Iterable of paths to ``.npz`` archives. Each archive must
            contain every key in ``X_KEYS`` plus ``"y"``.

    Raises:
        ValueError: If ``data_paths`` is empty.
    """

    # Input arrays read from every shard, in the order they appear in ``x_dict``.
    X_KEYS = (
        "x_hand",
        "x_hand_red",
        "x_river",
        "x_river_red",
        "x_river_riichi",
        "x_meld",
        "x_meld_red5",
        "x_meld_throw",
        "x_meld_throw_red",
        "x_dora",
        "x_score",
        "x_pool",
        "x_winds",
    )

    # Divisor applied to the targets, "x_score" and "x_pool" to scale the score values down.
    SCORE_SCALE = 250

    def __init__(self, data_paths):
        data_paths = list(data_paths)
        if not data_paths:
            # Without this guard the modulo in load_next fails with an opaque ZeroDivisionError.
            raise ValueError("data_paths must contain at least one .npz file")

        self.x_dict = {key: None for key in self.X_KEYS}
        self.y = None
        # Open each archive exactly once; the handles stay open for later load_next calls.
        self.loads = [np.load(path) for path in data_paths]

        # Start at -1 so the first load_next() lands on shard 0.
        self.load_index = -1
        self.load_next()

    def unload(self):
        """Drop the in-memory arrays of the current shard (archive handles stay open)."""
        self.y = None
        for key in self.x_dict:
            self.x_dict[key] = None

    def load_next(self):
        """Load and pre-process the next shard, replacing the current one."""
        self.load_index = (self.load_index + 1) % len(self.loads)
        archive = self.loads[self.load_index]

        self.y = archive["y"].astype(np.float32)
        for key in self.x_dict:
            self.x_dict[key] = archive[key].astype(np.float32)

        # Data pre-processing
        # Scale the score-valued arrays by the same constant.
        self.y /= self.SCORE_SCALE
        self.x_dict["x_score"] /= self.SCORE_SCALE
        self.x_dict["x_pool"] /= self.SCORE_SCALE

        # Resize inputs towards (B, X, 24, 9); so far only the two hand arrays are resized.
        # x_hand: every row of the second-to-last axis is repeated 6 times (4 -> 24 for the
        # shard shapes recorded in this notebook).
        self.x_dict["x_hand"] = np.repeat(self.x_dict["x_hand"], 6, -2)
        # x_hand_red: (B, 3) -> (B, 1, 1, 3) -> (B, 1, 1, 9) -> (B, 1, 24, 9).
        # NOTE(review): each of the 3 flags is stretched over 3 consecutive columns of the
        # last axis - confirm this is the intended alignment with the 9-wide axis.
        hand_red = np.expand_dims(self.x_dict["x_hand_red"], (-2, -3))
        hand_red = np.repeat(hand_red, 3, -1)
        self.x_dict["x_hand_red"] = np.repeat(hand_red, 24, -2)
        print(self.x_dict["x_hand"].shape)
        print(self.x_dict["x_hand_red"].shape)

        # To predict score difference at the end of the round
        # (both operands have already been scaled above).
        self.y -= self.x_dict["x_score"]

    def __len__(self):
        """Return the number of samples in the currently loaded shard."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for ``idx`` within the current shard.

        Args:
            idx: Integer, slice, index list, or tensor (converted via ``tolist()``).

        Returns:
            A dict mapping every key of ``x_dict`` to the indexed input array,
            and the target rows selected by the same index.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        x = {key: values[idx] for key, values in self.x_dict.items()}
        # Index the targets as well so inputs and target describe the same sample(s).
        return x, self.y[idx]

# Sort the shard paths by filename so the split below is reproducible.
_shard_paths = sorted(glob.glob("dataset/riichi_ds_v02/*.npz"))

# The first 30 shards are reserved for training (currently disabled):
# train_ds = MahjongDataset(_shard_paths[:30])
# Everything after the first 30 shards is the test split.
test_ds = MahjongDataset(_shard_paths[30:])

(262763, 4, 24, 9)
(262763, 1, 24, 9)


In [24]:
# Shape of "x_hand_red" for the first 8 samples ([0] picks the input dict of the pair).
# NOTE(review): the recorded (8, 1, 1, 3) disagrees with the (N, 1, 24, 9) shape printed
# by load_next in the class cell, and the execution count (24) is out of order - this
# output looks stale; re-check after Restart & Run All.
test_ds[:8][0]["x_hand_red"].shape

(8, 1, 1, 3)

In [20]:
# Dump the resized "x_hand" input of the first 4 samples; runs of identical rows are
# expected because load_next repeats each row 6 times with np.repeat.
test_ds[:4][0]["x_hand"]

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 1., 1.],
         [0., 0., 0., ..., 0., 1., 1.],
         [0., 0., 0., ..., 0., 1., 1.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 1., ..., 1., 1., 0.],
         [0., 0., 1., ..., 1., 1., 0.],
         [0., 0., 1., ..., 1., 1., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 0., ..., 0., 0., 0.],
         [1., 0., 

In [12]:
# Sample count of the currently loaded shard only (__len__ returns len(self.y)).
len(test_ds)

262763

In [13]:
# Swap in the next shard; the previous shard's arrays are replaced.
# NOTE(review): the recorded output is a single un-resized (N, 4, 4, 9) shape, which the
# class as written above would not print - this output appears to come from an older
# version of load_next.
test_ds.load_next()

(270608, 4, 4, 9)


In [14]:
# The length differs from the earlier call because it reflects the newly loaded shard.
len(test_ds)

270608