In [1]:
import math
import logging
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import uniform_, xavier_uniform_, zeros_
from torch.utils.data import IterableDataset, DataLoader, Dataset
import numpy as np

from utilities.utility_functions import decompress_dict
from settings import MODEL_DATA_PATHS, DATA_BLUEPRINT, DATA_FEATURES, H_PARAMETERS
from supervised_model import Encoder

In [2]:
# Seed every RNG used in this notebook so Restart & Run All is reproducible:
# the stdlib `random` module drives DiplomacyDataset's shuffle buffer, while
# torch and numpy global generators are seeded for everything else.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [3]:
# Module-level logger, named after this module.
LOGGER = logging.getLogger(name=__name__)

In [4]:
class DiplomacyDataset(IterableDataset):
    """
    Stream samples from a line-oriented dataset file, one sample per line.

    Every line is passed through `decompress_dict` when the dataset is iterated.
    Optionally the lines are shuffled approximately with a fixed-size shuffle
    buffer, so the file never has to be loaded into memory at once.

    NOTE(review): `__iter__` returns `self` and keeps a single iterator on the
    instance; with a multi-worker DataLoader each worker would read the whole
    file and samples would be duplicated — TODO confirm only num_workers=0 is used.
    """

    def __init__(self, file_path: str, shuffle: bool, shuffle_buffer_size: int) -> None:
        """
        Args:
            file_path: path of the text file holding one serialized sample per line.
            shuffle: iterate through a shuffle buffer instead of in file order.
            shuffle_buffer_size: number of lines kept in the shuffle buffer;
                must be > 1 when `shuffle` is True.

        Raises:
            ValueError: if `shuffle` is True and `shuffle_buffer_size` <= 1.
        """
        # Zero-argument super() binds to `self`, so the parent initializer really runs.
        super().__init__()

        self.file_path = file_path
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size

        if self.shuffle and self.shuffle_buffer_size <= 1:
            raise ValueError(
                "Bad shuffle buffer size. "
                "If you want to shuffle the iterable data, you have to define a positive > 1 shuffle buffer size."
                f"\nGot 'shuffle': {self.shuffle}, shuffle_buffer_size: {self.shuffle_buffer_size}"
            )

    def sequential_iterator(self):
        """
        Yield the raw lines of `file_path` in file order.

        The file is opened on the first `next()` and closed once the generator
        is exhausted, closed, or garbage-collected — no leaked handles.
        """
        with open(self.file_path, "r") as dataset_file:
            yield from dataset_file

    def shuffle_iterator(self):
        """
        Yield the raw lines of `file_path` in approximately random order.

        The first `shuffle_buffer_size` lines fill a buffer; afterwards every new
        line evicts (and yields) a randomly chosen buffered line before taking its
        place. Whatever remains buffered is drained in random order at the end.
        Every line is yielded exactly once; randomness comes from `random.randint`.
        """
        # Guard against a non-positive size when called directly: at least one
        # slot is needed, otherwise the eviction step would sample an empty range.
        capacity = max(self.shuffle_buffer_size, 1)
        shuffle_buffer = []

        for line in self.sequential_iterator():
            if len(shuffle_buffer) < capacity:
                # still filling the initial buffer
                shuffle_buffer.append(line)
                continue
            remove_index = random.randint(0, len(shuffle_buffer) - 1)
            yield shuffle_buffer.pop(remove_index)
            shuffle_buffer.append(line)

        if len(shuffle_buffer) < capacity:
            # The configured size is left untouched: shrinking it (to 0 for an
            # empty file) would break the next epoch.
            LOGGER.info("Either the dataset file is too small, or shuffle buffer size is too big for the file. "
                        "The whole file fit into the shuffle buffer.")

        # Drain the buffer; an empty file simply yields nothing here.
        while shuffle_buffer:
            remove_index = random.randint(0, len(shuffle_buffer) - 1)
            yield shuffle_buffer.pop(remove_index)

    def __iter__(self):
        """Start a new pass over the file (shuffled or sequential) and return self."""
        if self.shuffle:
            self.iterator = self.shuffle_iterator()
        else:
            self.iterator = self.sequential_iterator()

        return self

    def __next__(self):
        """Return the next line decoded by `decompress_dict`; StopIteration ends the pass."""
        return decompress_dict(next(self.iterator))

In [5]:
# Training file; DiplomacyDataset reads it line by line and decodes each line
# with decompress_dict.
file_path = "data/model_data/full_dataset_training.txt"

# Shuffling is disabled. A buffer size > 1 is only required when shuffle is True.
shuffle = False
shuffle_buffer_size = 1

batch_size = H_PARAMETERS["batch_size"]

In [6]:
# Streaming dataset over the training file, configured by the cell above.
dataset = DiplomacyDataset(
    file_path=file_path,
    shuffle=shuffle,
    shuffle_buffer_size=shuffle_buffer_size,
)

In [7]:
# Names of blueprint features whose declared shape is [None], i.e. whose length
# varies between samples; custom_collate_fn pads these within each batch.
VARIABLE_LEN_FEATURES = [
    feature_name
    for feature_name, feature_spec in DATA_BLUEPRINT.items()
    if feature_spec["shape"] == [None]
]

def custom_collate_fn(batch: list, pad_batch: bool = False):
    """
    Collate a batch of samples into tensors, recursing into dictionaries.

    Modified version of the default PyTorch collate function. A batch of
    dictionaries is collated key by key onto the original keys. Features listed
    in `VARIABLE_LEN_FEATURES` are right-padded with 0 to the longest list of
    their key within the batch.

    Supported element types (decided by the first element of `batch`):
      * float -> float32 tensor
      * int (bool included, as a subclass of int) -> int32 tensor
      * str -> the batch list is returned unchanged
      * list -> tensor whose dtype follows the first non-empty list
        (NOTE: only 1-D lists are supported)
      * dict -> dict with the same keys, each value collated recursively

    If you modify the project, please pay attention to this function: it is
    single-purpose and must be extended for new feature types.

    Args:
        batch: non-empty list of samples that all share the same type.
        pad_batch: right-pad list samples with 0 to a common length.

    Returns:
        A tensor, a list of strings, a dict of those, or None for unsupported types.
    """
    element = batch[0]

    # scalars
    if isinstance(element, float):
        return torch.tensor(batch, dtype=torch.float32)
    if isinstance(element, int):
        return torch.tensor(batch, dtype=torch.int32)
    if isinstance(element, str):
        return batch

    # lists NOTE: WORKS WITH ONLY 1-D lists!!!
    if isinstance(element, list):
        # Infer the element dtype from the first non-empty list: for variable-length
        # features the first sample's list may legitimately be empty.
        first_non_empty = next((sample for sample in batch if len(sample) > 0), None)
        # dtype=None lets torch fall back to its default when every list is empty.
        list_type = None if first_non_empty is None else custom_collate_fn(first_non_empty).dtype

        if pad_batch:
            max_size = max(len(sample) for sample in batch)
            batch = [sample + [0] * (max_size - len(sample)) for sample in batch]

        return torch.tensor(batch, dtype=list_type)

    # collate every key across the batch of dictionaries, padding variable-length ones
    if isinstance(element, dict):
        return {
            key: custom_collate_fn(
                batch=[dictionary[key] for dictionary in batch],
                pad_batch=key in VARIABLE_LEN_FEATURES,
            )
            for key in element
        }

    # NOTE(review): any other element type (e.g. None) yields None, exactly as
    # before — consider raising a TypeError so bad features fail loudly.
    return None
    

In [8]:
# Batches are assembled by custom_collate_fn, which pads the variable-length features.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=custom_collate_fn,
)

### Model

In [9]:
# Encoder from supervised_model, built from the project-wide feature and hyper-parameter settings.
encoder = Encoder(
    data_features=DATA_FEATURES,
    h_params=H_PARAMETERS,
)

In [12]:
# Smoke test: push the first batch from the dataloader through the encoder.
# Both `batch` and `first_attempt` stay None if the dataloader yields nothing.
# TODO: inspecting batch["board_alignments"] reportedly needs flattening disabled
# (note carried over from earlier debugging).
batch = next(iter(dataloader), None)
first_attempt = encoder(batch) if batch is not None else None