In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import glob

class MahjongDataset(Dataset):
    """Dataset over a list of ``.npz`` archives, holding one archive in memory at a time.

    Each archive must contain a ``"y"`` array plus every key listed in
    ``_FEATURE_KEYS``.  ``load_next`` reads the next archive (cyclically),
    pre-processes it and stores the result in ``self.x`` / ``self.y``;
    ``__len__`` and ``__getitem__`` only ever see the archive that is
    currently loaded.

    After pre-processing:

    * ``self.x`` is the concatenation, along axis 1, of all features resized
      to ``(B, C_i, 24, 9)``, in ``_FEATURE_KEYS`` order.
    * ``self.y`` is ``y / score_scale - x_score / score_scale``, i.e. the
      scaled score difference at the end of the round.

    Args:
        data_paths: iterable of paths to ``.npz`` archives; must not be empty.
        max_samples: keep only the first ``max_samples`` rows of every array
            in an archive (``None`` keeps all rows).  Defaults to 2000.
        score_scale: divisor applied to ``y``, ``x_score`` and ``x_pool``.
            Defaults to 250.

    Raises:
        ValueError: if ``data_paths`` is empty.
    """

    # Keys read from each archive.  The order matters: it is the channel order
    # of the concatenated ``self.x``.
    _FEATURE_KEYS = (
        "x_hand",
        "x_hand_red",
        "x_river",
        "x_river_red",
        "x_river_riichi",
        "x_meld",
        "x_meld_red5",
        "x_meld_throw",
        "x_meld_throw_red",
        "x_dora",
        "x_score",
        "x_pool",
        "x_winds",
    )
    # Every feature is resized to planes of shape (_PLANE_ROWS, _PLANE_COLS).
    _PLANE_ROWS = 24
    _PLANE_COLS = 9
    # Each row of hand/meld/dora features is repeated this many times along
    # axis -2 (presumably 4 rows -> 24 rows; the archives' shapes are not
    # visible here).
    _ROW_REPEAT = 6

    def __init__(self, data_paths, max_samples=2000, score_scale=250):
        paths = list(data_paths)
        if not paths:
            # Without this guard load_next() fails with an opaque
            # ZeroDivisionError from the modulo on len(self.loads).
            raise ValueError(
                "MahjongDataset requires at least one .npz path, got an empty list"
            )

        self.max_samples = max_samples
        self.score_scale = score_scale
        self.x_dict = {key: None for key in self._FEATURE_KEYS}
        self.x = None
        self.y = None
        # Open each archive exactly once (np.load on an .npz reads arrays on access).
        self.loads = [np.load(path) for path in paths]

        self.load_index = -1
        self.load_next()

    def unload(self):
        """Drop the currently loaded arrays so their memory can be reclaimed."""
        self.y = None
        self.x = None
        for key in self.x_dict.keys():
            self.x_dict[key] = None

    @staticmethod
    def _merge_axes_1_2(arr):
        """Collapse axes 1 and 2 of ``arr`` into one axis, keeping all others."""
        return np.reshape(arr, (arr.shape[0], -1, *arr.shape[3:]))

    @classmethod
    def _stretch_rows(cls, arr):
        """Repeat every entry along axis -2 ``_ROW_REPEAT`` times."""
        return np.repeat(arr, cls._ROW_REPEAT, -2)

    @classmethod
    def _fill_cols(cls, arr):
        """Append a trailing axis and repeat each value across ``_PLANE_COLS`` columns."""
        return np.repeat(np.expand_dims(arr, -1), cls._PLANE_COLS, -1)

    @classmethod
    def _fill_planes(cls, arr):
        """Append two trailing axes and repeat each value over a full 24x9 plane."""
        arr = np.expand_dims(arr, (-1, -2))
        arr = np.repeat(arr, cls._PLANE_COLS, -1)
        return np.repeat(arr, cls._PLANE_ROWS, -2)

    def load_next(self):
        """Load and pre-process the next archive, wrapping around after the last one.

        Replaces ``self.x`` and ``self.y`` with the arrays built from the
        archive at the new ``self.load_index``.
        """
        self.load_index = (self.load_index + 1) % len(self.loads)
        archive = self.loads[self.load_index]
        limit = self.max_samples

        self.y = archive["y"][:limit].astype(np.float32)
        feats = self.x_dict
        for key in feats.keys():
            feats[key] = archive[key][:limit].astype(np.float32)

        # Data pre-processing
        # Scale down the raw score values
        self.y /= self.score_scale
        feats["x_score"] /= self.score_scale
        feats["x_pool"] /= self.score_scale
        # To predict score difference at the end of the round.
        # Must run before x_score is expanded to planes below.
        self.y -= feats["x_score"]

        # Resize all input to B, X, 24, 9
        feats["x_hand"] = self._stretch_rows(feats["x_hand"])
        feats["x_hand_red"] = self._fill_planes(feats["x_hand_red"])
        feats["x_river"] = self._merge_axes_1_2(feats["x_river"])
        feats["x_river_red"] = self._fill_cols(feats["x_river_red"])
        feats["x_river_riichi"] = self._fill_cols(feats["x_river_riichi"])
        feats["x_meld"] = self._stretch_rows(self._merge_axes_1_2(feats["x_meld"]))
        # The *_red flags are flattened to (B, -1) so that each flag becomes one plane.
        red5 = feats["x_meld_red5"]
        feats["x_meld_red5"] = self._fill_planes(np.reshape(red5, (red5.shape[0], -1)))
        feats["x_meld_throw"] = self._stretch_rows(
            self._merge_axes_1_2(feats["x_meld_throw"])
        )
        throw_red = feats["x_meld_throw_red"]
        feats["x_meld_throw_red"] = self._fill_planes(
            np.reshape(throw_red, (throw_red.shape[0], -1))
        )
        feats["x_dora"] = self._stretch_rows(feats["x_dora"])
        feats["x_score"] = self._fill_planes(feats["x_score"])
        # x_pool gets three trailing axes (channel, row, column) before tiling.
        pool = np.expand_dims(feats["x_pool"], (-1, -2, -3))
        pool = np.repeat(pool, self._PLANE_COLS, -1)
        feats["x_pool"] = np.repeat(pool, self._PLANE_ROWS, -2)
        winds = np.expand_dims(feats["x_winds"], -2)
        feats["x_winds"] = np.repeat(winds, self._PLANE_ROWS, -2)

        self.x = np.concatenate(list(feats.values()), axis=1)

        # Unload dict to save memory
        for key in feats.keys():
            feats[key] = None

    def __len__(self):
        """Number of samples in the currently loaded archive."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])``; ``idx`` may be an int, slice, list or tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.x[idx], self.y[idx]

# Archives are ordered by file name; the first 30 are reserved for training
# (currently disabled) and the remainder form the test set.
npz_paths = sorted(glob.glob("dataset/riichi_ds_v02/*.npz"))
# train_ds = MahjongDataset(npz_paths[:30])
test_ds = MahjongDataset(npz_paths[30:])

In [2]:
# Shapes of the (x, y) pair for the first 8 samples of the loaded archive.
tuple(part.shape for part in test_ds[:8])

((8, 98, 24, 9), (8, 4))