In [1]:
import math
import logging
import random

import torch
from torch.utils.data import IterableDataset, DataLoader, Dataset
import numpy as np

from utilities.utility_functions import decompress_dict
from settings import MODEL_DATA_PATHS, DATA_BLUEPRINT

In [2]:
# Module-level logger named after this module; DiplomacyDataset.shuffle_iterator
# uses it to report when the shuffle buffer cannot be filled completely.
LOGGER = logging.getLogger(__name__)

In [16]:
class DiplomacyDataset(IterableDataset):
    """
    Streams samples from a text file that stores one compressed record per line.

    Every raw line is passed through `decompress_dict` before it is handed to
    the consumer. Lines can be served in file order, or approximately shuffled
    through a fixed-size in-memory buffer so the whole file never has to be
    loaded at once.

    Args:
        file_path: path of the text file to read.
        shuffle: when truthy, lines are served through the shuffle buffer.
        shuffle_buffer_size: number of lines kept in the shuffle buffer;
            must be greater than 1 when `shuffle` is truthy.

    Raises:
        ValueError: if shuffling is requested with a buffer size <= 1.
    """

    def __init__(self, file_path, shuffle, shuffle_buffer_size) -> None:
        # Zero-argument super(): the previous `super(DiplomacyDataset)` built an
        # unbound super object and never reached the parent initializer.
        super().__init__()

        self.file_path = file_path
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size
        # Active line iterator; created by __iter__ (or lazily by __next__).
        self.iterator = None

        if self.shuffle and self.shuffle_buffer_size <= 1:
            raise ValueError("Bad shuffle buffer size."
                             "If you want to shuffle the iterable data, you have to define a positive > 1 shuffle buffer size."
                             f"\nGot 'shuffle': {self.shuffle}, shuffle_buffer_size: {self.shuffle_buffer_size}")

    def sequential_iterator(self):
        """
        Yield the raw lines of the file in order.

        Implemented as a generator so the file handle is closed as soon as the
        iteration is exhausted or the generator is closed/garbage-collected.
        Inside a DataLoader worker process the lines are sharded round-robin
        across workers, so multi-worker loading does not return every sample
        once per worker.
        """
        worker_info = torch.utils.data.get_worker_info()
        with open(self.file_path, "r") as data_file:
            if worker_info is None:
                # Single-process loading: serve every line.
                yield from data_file
            else:
                for line_number, line in enumerate(data_file):
                    if line_number % worker_info.num_workers == worker_info.id:
                        yield line

    def shuffle_iterator(self):
        """
        Yield the raw lines in buffered-shuffle order.

        The buffer is filled with up to `shuffle_buffer_size` lines; then a
        uniformly random buffered line is yielded and replaced by the next line
        of the file until both the file and the buffer are exhausted. An empty
        file yields nothing. The configured `shuffle_buffer_size` is never
        modified, so later iterations use the same setting.
        """
        shuffle_buffer = []
        local_iterator = self.sequential_iterator()

        try:
            # Fill up the initial buffer (at least one slot, so a direct call
            # with a degenerate size still streams the file).
            for _ in range(max(self.shuffle_buffer_size, 1)):
                shuffle_buffer.append(next(local_iterator))
        except StopIteration:
            # Too small a file or too big a buffer: carry on with whatever
            # was read instead of failing.
            LOGGER.info("Either the dataset file is too small, or shuffle buffer size is too big for the file. "
                        "Shrinking the buffer...")

        # One loop covers both the steady state and the final drain: once the
        # file is exhausted nothing is appended and the buffer empties itself.
        while shuffle_buffer:
            remove_index = random.randint(0, len(shuffle_buffer) - 1)
            yield shuffle_buffer.pop(remove_index)

            next_line = next(local_iterator, None)
            if next_line is not None:
                shuffle_buffer.append(next_line)

    def __iter__(self):
        """Start a fresh pass over the file and return `self` as the iterator."""
        # Release the file held by a previous, possibly unfinished, pass.
        close_previous = getattr(getattr(self, "iterator", None), "close", None)
        if close_previous is not None:
            close_previous()

        if self.shuffle:
            self.iterator = self.shuffle_iterator()
        else:
            self.iterator = self.sequential_iterator()

        return self

    def __next__(self):
        """Return the next decompressed sample; raises StopIteration at the end."""
        if getattr(self, "iterator", None) is None:
            # next() called before iter(): start a pass instead of failing
            # with AttributeError.
            iter(self)

        return decompress_dict(next(self.iterator))

In [17]:
# Hand-set parameters for the smoke test in the cells below.
batch_size = 1  # number of samples per DataLoader batch
shuffle = True  # passed to DiplomacyDataset to enable its buffered shuffling
shuffle_buffer_size = 10  # lines held in the shuffle buffer; DiplomacyDataset requires > 1 when shuffle is True
file_path = "data/model_data/full_dataset_training.txt"  # relative path to the dataset text file

In [18]:
# Build the streaming dataset from the parameters set above.
dataset = DiplomacyDataset(file_path = file_path, shuffle = shuffle, shuffle_buffer_size = shuffle_buffer_size)

In [19]:
def custom_collate_fn(batch):
    """
    Modified default pytorch collate function.
    It is meant to collate the received dictionary values onto original keys,
    also pad the non-fixed dimension features to max length of their group per batch.
    
    If you modify the project, please pay attention to this function,
    as it single-purpose and should be extented.
    """
    element = batch[0]
    
    # scalars
    if isinstance(element, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(element, int):
        return torch.tensor(batch)
    elif isinstance(element, str):
        return batch
    
    # data structures
    elif isinstance(element, list):
        
        first_list_element = element[0]
        if isinstance(first_list_element, list):
            list_type = custom_collate_fn(first_list_element).dtype
        else:
            list_type = custom_collate_fn(element).dtype
            
        return torch.tensor(batch, dtype = list_type)
        
    elif isinstance(element, dict):
        return dict({key: custom_collate_fn([dictionary[key] for dictionary in batch]) for key in element})

In [20]:
# Wrap the dataset in a DataLoader that batches samples with custom_collate_fn.
# No shuffle argument is passed here; shuffling is configured on the dataset itself.
dataloader = DataLoader(dataset, batch_size = batch_size, collate_fn = custom_collate_fn)

In [21]:
# Display the feature blueprint imported from settings (shape and static/variable type per feature).
DATA_BLUEPRINT

{'request_id': {'shape': [], 'type': 'static'},
 'player_seed': {'shape': [], 'type': 'static'},
 'board_state': {'shape': [81, 35], 'type': 'static'},
 'board_alignments': {'shape': [None], 'type': 'variable'},
 'prev_orders_state': {'shape': [1, 81, 40], 'type': 'static'},
 'decoder_inputs': {'shape': [None], 'type': 'variable'},
 'decoder_lengths': {'shape': [], 'type': 'static'},
 'candidates': {'shape': [None], 'type': 'variable'},
 'noise': {'shape': [], 'type': 'static'},
 'temperature': {'shape': [], 'type': 'static'},
 'dropout_rate': {'shape': [], 'type': 'static'},
 'current_power': {'shape': [], 'type': 'static'},
 'current_season': {'shape': [], 'type': 'static'},
 'draw_target': {'shape': [], 'type': 'static'},
 'value_target': {'shape': [], 'type': 'static'}}

In [23]:
# Smoke test: pull the first six batches (i runs from 0 to 5) and print the
# shape of the collated "board_alignments" tensor for each one.
i = 0
ba = None  # keeps the most recent batch so later cells can inspect it
for batch in dataloader:
    ba = batch
    print(batch["board_alignments"].shape)
    
    # Stop once the sixth batch has been printed.
    if i == 5:
        break
    i+=1

torch.Size([1, 2, 81])
torch.Size([1, 3, 81])
torch.Size([1, 3, 81])
torch.Size([1, 4, 81])
torch.Size([1, 6, 81])
torch.Size([1, 4, 81])


In [15]:
# Shape of the last inspected batch's "board_alignments" as a numpy array.
# NOTE(review): this cell's execution count (15) predates the cells above and its
# recorded output has a batch dimension of 2, so the output looks stale - re-run to confirm.
ba["board_alignments"].numpy().shape

(2, 3, 81)

In [24]:
# Scratch check: np.array on a list of equal-length 1-D arrays produces a single 2-D array.
np.array([np.array([1,2]), np.array([3,4])])

array([[1, 2],
       [3, 4]])