In [73]:
import math
import logging
import random

import torch
from torch.utils.data import IterableDataset, DataLoader, Dataset
import numpy as np

from utilities.utility_functions import decompress_dict
from settings import MODEL_DATA_PATHS

In [2]:
# Logger shared by everything in this module, keyed by the module's own name.
LOGGER = logging.getLogger(name=__name__)

In [148]:
class DiplomacyDataset(IterableDataset):
    """Stream samples line-by-line from a text file, optionally shuffled.

    Each line of ``file_path`` is one serialized sample. With ``shuffle=False``
    lines are produced in file order. With ``shuffle=True`` they pass through a
    bounded shuffle buffer (approximate shuffle), so memory use is capped at
    ``shuffle_buffer_size`` lines regardless of file size.

    Inside a ``DataLoader`` with ``num_workers > 0``, each worker reads a
    disjoint, interleaved subset of lines so samples are not duplicated.

    Args:
        file_path: path of the dataset text file, one sample per line.
        shuffle: whether to pass lines through the shuffle buffer.
        shuffle_buffer_size: number of lines held in the shuffle buffer; must
            be > 1 when ``shuffle`` is True (not used otherwise).

    Raises:
        ValueError: if ``shuffle`` is True and ``shuffle_buffer_size <= 1``.
    """

    def __init__(self, file_path, shuffle, shuffle_buffer_size) -> None:
        # Zero-argument super() so IterableDataset.__init__ actually runs.
        super().__init__()

        self.file_path = file_path
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size

        if self.shuffle and self.shuffle_buffer_size <= 1:
            raise ValueError("""Bad shuffle buffer size.
If you want to shuffle the iterable data, you have to define a positive > 1 shuffle buffer size.""")

    def sequential_iterator(self):
        """Yield the raw lines of ``file_path`` in file order.

        The file is opened inside the generator with ``with``, so the handle is
        closed when the lines are exhausted or the generator is closed or
        garbage collected. Inside a DataLoader worker only lines whose index
        satisfies ``index % num_workers == worker_id`` are yielded.
        """
        # None in the main process (num_workers=0): no sharding in that case.
        worker_info = torch.utils.data.get_worker_info()
        with open(self.file_path, "r") as dataset_file:
            for line_index, line in enumerate(dataset_file):
                if worker_info is None or line_index % worker_info.num_workers == worker_info.id:
                    yield line

    def shuffle_iterator(self):
        """Yield lines from ``sequential_iterator`` in buffered random order.

        Up to ``shuffle_buffer_size`` lines are kept in memory. Once the buffer
        is full, each newly read line first evicts and yields a uniformly random
        buffered line. After the input is exhausted, the remaining buffered
        lines are yielded in random order. An empty file yields nothing.
        """
        shuffle_buffer = []

        for line in self.sequential_iterator():
            if len(shuffle_buffer) >= self.shuffle_buffer_size:
                # Buffer full: emit a random element to make room for `line`.
                yield shuffle_buffer.pop(random.randrange(len(shuffle_buffer)))
            shuffle_buffer.append(line)

        # The buffer never filled up, so the whole input fit in memory.
        if len(shuffle_buffer) < self.shuffle_buffer_size:
            LOGGER.info("Either the dataset file is too small, or shuffle buffer size is too big for the file. "
                        "Shrinking the buffer...")

        # Drain whatever is left, still in random order.
        while shuffle_buffer:
            yield shuffle_buffer.pop(random.randrange(len(shuffle_buffer)))

    def __iter__(self):
        """Start a fresh pass over the file and return ``self`` as the iterator."""
        self.iterator = self.shuffle_iterator() if self.shuffle else self.sequential_iterator()
        return self

    def __next__(self):
        """Decode the next line into ``(np.array(decoder_inputs), 1)``.

        Raises StopIteration when the underlying line iterator is exhausted.
        NOTE(review): the second element is a hard-coded placeholder label
        ``1`` — confirm the real training target before use.
        """
        item = next(self.iterator)
        # presumably decompress_dict parses the raw line and decompresses the
        # listed keys; only "decoder_inputs" is used below — TODO confirm
        item = decompress_dict(item, ["board_state", "prev_orders_state", "board_alignments", "candidates"])
        return np.array(item["decoder_inputs"]), 1
            

In [149]:
# Loader settings for a quick smoke test of DiplomacyDataset.
# NOTE(review): settings.MODEL_DATA_PATHS is imported but not used in the cells
# shown; consider deriving this path from it instead of hard-coding it.
file_path = "data/model_data/full_dataset_training.txt"
batch_size = 3
shuffle = False
# One batch worth of lines; only consulted when `shuffle` is True.
shuffle_buffer_size = batch_size * 1

In [150]:
# Build the streaming dataset from the settings cell above.
dataset = DiplomacyDataset(
    file_path=file_path,
    shuffle=shuffle,
    shuffle_buffer_size=shuffle_buffer_size,
)

In [151]:
# Group samples into batches of `batch_size` (single process: default num_workers=0).
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

In [152]:
# Iterator over batches, stepped manually below for inspection.
bruh = iter(dataloader)

In [153]:
# Fetch the first batch from the iterator.
a = next(bruh)

In [156]:
# Each sample is an (inputs, label) pair, so the collated batch unpacks into
# batched inputs and batched labels.
(x, y) = a