In [1]:
import math
import logging
import random

import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader, Dataset
import numpy as np

from utilities.utility_functions import decompress_dict, pad_tensor
from settings import MODEL_DATA_PATHS, DATA_BLUEPRINT, DATA_FEATURES

In [2]:
# Module-level logger named after this module; DiplomacyDataset logs through it
# when the shuffle buffer cannot be filled completely.
LOGGER = logging.getLogger(__name__)

In [3]:
class DiplomacyDataset(IterableDataset):
    """
    Streams examples from a text file, one example per line.

    Every line is passed through `decompress_dict` before it is handed to the
    consumer. Optionally the lines are shuffled on the fly with a bounded
    shuffle buffer, so the whole file never has to be held in memory.

    NOTE: the instance is its own iterator (`__iter__` returns `self`), so only
    one pass over the data can be in progress at a time, and no per-worker
    sharding is performed here.
    """

    def __init__(self, file_path: str, shuffle: bool, shuffle_buffer_size: int) -> None:
        """
        Args:
            file_path: path of the text file to stream from.
            shuffle: whether to shuffle lines through a shuffle buffer.
            shuffle_buffer_size: number of lines held in the shuffle buffer;
                must be greater than 1 when `shuffle` is True.

        Raises:
            ValueError: if `shuffle` is True and `shuffle_buffer_size` <= 1.
        """
        # `super(DiplomacyDataset).__init__()` would only build an unbound super
        # object and never reach the parent initialiser.
        super().__init__()

        self.file_path = file_path
        self.shuffle = shuffle
        self.shuffle_buffer_size = shuffle_buffer_size

        if self.shuffle and self.shuffle_buffer_size <= 1:
            raise ValueError(
                "Bad shuffle buffer size. "
                "If you want to shuffle the iterable data, you have to define a positive > 1 shuffle buffer size."
                f"\nGot 'shuffle': {self.shuffle}, shuffle_buffer_size: {self.shuffle_buffer_size}"
            )

    def sequential_iterator(self):
        """
        Yield the raw lines of the file in their original order.

        Implemented as a generator so the file handle is closed as soon as the
        generator is exhausted, closed or garbage collected.
        """
        with open(self.file_path, "r") as data_file:
            yield from data_file

    def shuffle_iterator(self):
        """
        Yield the raw lines of the file in a buffer-shuffled order.

        A buffer of up to `shuffle_buffer_size` lines is filled first. Then, for
        every further line read, a random buffered line is emitted and the new
        line takes its place. Once the file is exhausted the remaining buffer
        is drained in random order. An empty file yields nothing.
        """
        local_iterator = self.sequential_iterator()
        shuffle_buffer = []

        # fill up the initial buffer
        for line in local_iterator:
            shuffle_buffer.append(line)
            if len(shuffle_buffer) >= self.shuffle_buffer_size:
                break
        else:
            # The file ran out before the buffer was full: simply work with the
            # smaller buffer (the configured size is left untouched).
            LOGGER.info("Either the dataset file is too small, or shuffle buffer size is too big for the file. "\
                        "Shrinking the buffer...")

        # main loop: swap one random buffered line out for every new line read
        for line in local_iterator:
            remove_index = random.randint(0, len(shuffle_buffer) - 1)
            yield shuffle_buffer.pop(remove_index)
            shuffle_buffer.append(line)

        # drain whatever is left once the file is exhausted
        while shuffle_buffer:
            remove_index = random.randint(0, len(shuffle_buffer) - 1)
            yield shuffle_buffer.pop(remove_index)

    def __iter__(self):
        """Start a new pass over the file and return `self` as the iterator."""
        if self.shuffle:
            self.iterator = self.shuffle_iterator()
        else:
            self.iterator = self.sequential_iterator()

        return self

    def __next__(self):
        """Return the next line of the current pass, decoded with `decompress_dict`."""
        return decompress_dict(next(self.iterator))

In [20]:
# Data-loading configuration used by the dataset / dataloader cells.
batch_size = 120  # passed to the DataLoader as its batch size
shuffle = False  # False -> the file is read sequentially, no shuffle buffer
shuffle_buffer_size = 1  # only validated when shuffle is True (must then be > 1)
file_path = "data/model_data/full_dataset_training.txt"  # relative path of the training file

In [21]:
# Streaming dataset over the training file, configured by the settings above.
dataset = DiplomacyDataset(file_path = file_path, shuffle = shuffle, shuffle_buffer_size = shuffle_buffer_size)

In [22]:
# Names of the DATA_BLUEPRINT entries whose declared "shape" is [None];
# custom_collate_fn pads exactly these features per batch.
VARIABLE_LEN_FEATURES = [key for key, value in DATA_BLUEPRINT.items() if value["shape"] == [None]]

def _first_leaf(nested: list):
    """Return the first non-list value found depth-first in `nested`, or None if there is none."""
    for item in nested:
        if not isinstance(item, list):
            return item
        leaf = _first_leaf(item)
        if leaf is not None:
            return leaf
    return None


def custom_collate_fn(batch: list, pad_batch: bool = False):
    """
    Modified default pytorch collate function.

    Collates a list of samples into one batched value. Dictionaries are
    collated key by key onto their original keys, and the features listed in
    VARIABLE_LEN_FEATURES are zero-padded to the longest sample of the batch.

    If you modify the project, please pay attention to this function,
    as it is single-purpose and will likely need to be extended.

    Args:
        batch: list of samples of the same kind (floats, ints, strings,
            lists of numbers, or dictionaries of those).
        pad_batch: when the samples are lists, right-pad each one with zeros to
            the longest list in the batch. Padding only handles 1-D lists.

    Returns:
        - float samples -> float32 tensor
        - int samples   -> int32 tensor
        - str samples   -> the list of strings, unchanged
        - list samples  -> tensor whose dtype follows the first number found in
                           the batch (float -> float32, int -> int32); if every
                           list is empty torch picks its default dtype
        - dict samples  -> dict with every key collated recursively
        - anything else -> None
    """
    element = batch[0]

    # scalars
    if isinstance(element, float):
        return torch.tensor(batch, dtype=torch.float32)
    if isinstance(element, int):
        return torch.tensor(batch, dtype=torch.int32)
    if isinstance(element, str):
        return batch

    # lists (padding works with 1-D lists only)
    if isinstance(element, list):
        # Look through the whole batch for the first number, so an empty list in
        # the first sample (legal for variable-length features) is not a problem.
        leaf = _first_leaf(batch)
        if isinstance(leaf, float):
            list_type = torch.float32
        elif isinstance(leaf, int):
            list_type = torch.int32
        else:
            list_type = None  # nothing to go by: leave dtype inference to torch

        # pad the lists if needed
        if pad_batch:
            max_size = max(len(sample) for sample in batch)
            batch = [sample + [0] * (max_size - len(sample)) for sample in batch]

        return torch.tensor(batch, dtype=list_type)

    # collate every key across the batch of dictionaries
    if isinstance(element, dict):
        return {
            key: custom_collate_fn(
                batch=[sample[key] for sample in batch],
                pad_batch=key in VARIABLE_LEN_FEATURES,
            )
            for key in element
        }

    # unsupported sample types are not collated
    return None
    

In [23]:
# Show which features were detected as variable-length (these get padded per batch).
VARIABLE_LEN_FEATURES

['board_alignments', 'decoder_inputs', 'candidates']

In [24]:
# Batches the streamed examples; custom_collate_fn merges the per-example
# dictionaries key by key and pads the variable-length features.
dataloader = DataLoader(dataset, batch_size = batch_size, collate_fn = custom_collate_fn)

### Model

In [92]:
class Encoder(torch.nn.Module):
    """
    Work-in-progress PyTorch port of the encoder input preprocessing.

    The module has no trainable parameters yet: `forward` only normalises,
    pads and reshapes the collated batch. The FiLM gamma/beta computation has
    not been ported; the original TensorFlow code is kept at the end of
    `forward` as a commented-out reference.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: dict) -> dict:
        """
        Prepare a collated batch for the (not yet ported) encoder layers.

        Entries read from `inputs` (shapes / dtypes as noted by the original
        author, not verified here):
            board_state       - torch.float - (b, N_NODES, N_FEATURES)
            board_alignments  - torch.float - (b, N_NODES * len)
            prev_orders_state - torch.float - (b, N_PREV_ORDERS, N_NODES, N_ORDERS_FEATURES)
            decoder_inputs    - torch.int   - (b, <= 1 + N_SUPPLY_CENTERS)
            decoder_lengths   - torch.int   - (b,)
            candidates        - torch.int   - (b, n_locs * MAX_CANDIDATES)
        The notes also list current_power, current_season and dropout_rates
        (each (b,)); they are only needed by the un-ported FiLM code.

        Returns:
            A new dict holding the processed features plus
            "raw_decoder_lengths". The dict passed in and its tensors are left
            unmodified.
        """
        # Work on a shallow copy so the caller's batch is not silently rewritten.
        features = dict(inputs)

        # recast some features to float dtype
        for key in ("board_state", "board_alignments", "prev_orders_state"):
            features[key] = features[key].to(torch.float32)

        # variables for data processing
        batch_size = features["board_state"].shape[0]
        max_decoder_length = int(torch.max(features["decoder_lengths"]))

        # Reshaping board alignments.
        # TODO: remove once the dataset stops flattening them (manual de-flattening).
        alignments = torch.reshape(
            features["board_alignments"], (batch_size, -1, DATA_FEATURES["N_NODES"])
        )
        # Divide every row by its sum, using 1 as the divisor for rows summing
        # to less than 1. Out-of-place on purpose: the reshape can be a view of
        # the caller's tensor.
        row_sums = torch.sum(alignments, dim=-1, keepdim=True)
        alignments = alignments / torch.clamp(row_sums, min=1.0)

        # Not ported (judged useless by the original author):
        #   # Overriding dropout_rates if pholder('dropout_rate') > 0
        #   dropout_rates = tf.cond(tf.greater(pholder('dropout_rate'), 0.),
        #                           true_fn=lambda: tf.zeros_like(dropout_rates) + pholder('dropout_rate'),
        #                           false_fn=lambda: dropout_rates)

        # Padding board_alignments, decoder_inputs and candidates.
        # pad_tensor is a project helper; presumably it pads `axis` up to at
        # least `min_size` - TODO confirm. The original author suspected this
        # padding may be unnecessary.
        features["board_alignments"] = pad_tensor(alignments, axis = 1, min_size = max_decoder_length)
        features["decoder_inputs"] = pad_tensor(features["decoder_inputs"], axis = -1, min_size = 2)
        candidates = pad_tensor(
            features["candidates"], axis = -1, min_size = DATA_FEATURES["MAX_CANDIDATES"]
        )

        # Making sure all RNN lengths are at least 1
        # No need to trim, because the fields are variable length
        features["raw_decoder_lengths"] = features["decoder_lengths"]
        features["decoder_lengths"] = torch.clamp(features["decoder_lengths"], min=1)

        # Placeholders
        # TODO: figure out what is this decoder_type used for
        #   decoder_type = tf.reduce_max(pholder('decoder_type'))
        # this is possibly useless here
        #   is_training = pholder('is_training')

        # Reshaping candidates
        candidates = torch.reshape(candidates, (batch_size, -1, DATA_FEATURES["MAX_CANDIDATES"]))
        features["candidates"] = candidates[:, :max_decoder_length, :]  # torch.int - (b, n_locs, MAX_CANDIDATES)

        # TODO: port the FiLM gammas / betas computation to PyTorch.
        # Original TensorFlow implementation, kept as a porting reference:
        #
        # with tf.variable_scope('film_scope'):
        #     power_embedding = uniform(name='power_embedding',
        #                               shape=[NB_POWERS, hps('power_emb_size')],
        #                               scale=1.)
        #     current_power_mask = tf.one_hot(current_power, NB_POWERS, dtype=tf.float32)
        #     current_power_embedding = tf.reduce_sum(power_embedding[None]
        #                                             * current_power_mask[:, :, None], axis=1)  # (b, power_emb)
        #     film_embedding_input = current_power_embedding
        #
        #     # Also conditioning on current_season
        #     season_embedding = uniform(name='season_embedding',
        #                                shape=[NB_SEASONS, hps('season_emb_size')],
        #                                scale=1.)
        #     current_season_mask = tf.one_hot(current_season, NB_SEASONS, dtype=tf.float32)
        #     current_season_embedding = tf.reduce_sum(season_embedding[None]                 # (b,season_emb)
        #                                              * current_season_mask[:, :, None], axis=1)
        #     film_embedding_input = tf.concat([film_embedding_input, current_season_embedding], axis=1)
        #
        #     film_output_dims = [hps('gcn_size')] * (hps('nb_graph_conv') - 1) + [hps('attn_size') // 2]
        #
        #     # For board_state
        #     board_film_weights = tf.layers.Dense(units=2 * sum(film_output_dims),               # (b, 1, 750)
        #                                          use_bias=True,
        #                                          activation=None)(film_embedding_input)[:, None, :]
        #     board_film_gammas, board_film_betas = tf.split(board_film_weights, 2, axis=2)       # (b, 1, 750)
        #     board_film_gammas = tf.split(board_film_gammas, film_output_dims, axis=2)
        #     board_film_betas = tf.split(board_film_betas, film_output_dims, axis=2)
        #
        #     # For prev_orders
        #     prev_ord_film_weights = tf.layers.Dense(units=2 * sum(film_output_dims),            # (b, 1, 750)
        #                                             use_bias=True,
        #                                             activation=None)(film_embedding_input)[:, None, :]
        #     prev_ord_film_weights = tf.tile(prev_ord_film_weights, [NB_PREV_ORDERS, 1, 1])      # (n_pr, 1, 750)
        #     prev_ord_film_gammas, prev_ord_film_betas = tf.split(prev_ord_film_weights, 2, axis=2)
        #     prev_ord_film_gammas = tf.split(prev_ord_film_gammas, film_output_dims, axis=2)
        #     prev_ord_film_betas = tf.split(prev_ord_film_betas, film_output_dims, axis=2)
        #
        #     # Storing as temporary output
        #     self.add_output('_board_state_conv_film_gammas', board_film_gammas)
        #     self.add_output('_board_state_conv_film_betas', board_film_betas)
        #     self.add_output('_prev_orders_conv_film_gammas', prev_ord_film_gammas)
        #     self.add_output('_prev_orders_conv_film_betas', prev_ord_film_betas)

        return features


In [93]:
# Instantiate the encoder (its constructor takes no arguments).
encoder = Encoder()

In [94]:
# Smoke test: run the first two batches of the dataloader through the encoder.
# `real_batch` keeps a reference to the last batch that was fed in, so it can be
# inspected in later cells.
real_batch = None

i = 0  # index of the batch currently being processed
for batch in dataloader:
    # Batch keys that were printed (value and .shape) while debugging:
    #   request_id, player_seed, board_state,
    #   board_alignments (NEED TO DISABLE FLATTENING), prev_orders_state,
    #   decoder_inputs, decoder_lengths, candidates, noise,
    #   temperature, dropout_rate, current_power, current_season,
    #   draw_target, value_target
    real_batch = batch
    encoder(real_batch)
    # stop after the second batch (i == 1)
    if i == 1:
        break
    i+=1

torch.Size([120, 4080])
torch.Size([120, 17, 240])
torch.Size([120, 17, 240])
17
torch.Size([120, 3120])
torch.Size([120, 13, 240])
torch.Size([120, 13, 240])
13
