In [1]:
import sys
sys.path.append('..')


import torch 
import numpy as np 
from rlcard.agents.dqn_agent.memory import Memory
from rlcard.agents.dqn_agent.typing import Transition

#### Note: I have removed padding from `SequenceMemory`, since padding is more appropriately handled further down the pipeline. For example, with the current `SinusoidalPositionalEmbedding` there is no need for padding at all, and hence no need for masking.

In [2]:
def post_pad_transition(transitions: list[Transition], length: int) -> list[Transition]:
    """Return a copy of ``transitions`` extended at the end to ``length`` entries.

    Each appended entry is ``Transition.padding_transition(transitions[-1])``, i.e. it is
    derived from the last transition of the input. When ``transitions`` already holds
    ``length`` or more entries, an unchanged copy is returned (nothing is truncated).
    """
    padded = list(transitions)
    missing = length - len(transitions)
    for _ in range(missing):
        padded.append(Transition.padding_transition(transitions[-1]))
    return padded


class SequenceMemory(Memory):
    ''' Memory for saving sequences of transitions

    Sampling returns fixed-length windows of consecutive transitions rather than
    single transitions.
    '''
    def __init__(self, memory_size, batch_size, max_sequence_length: int):
        ''' Initialize
        Args:
            memory_size (int): the size of the memory buffer (forwarded to ``Memory``)
            batch_size (int): the number of sequences returned by ``sample``
            max_sequence_length (int): the maximum length of the sequence
        '''
        # NOTE(review): `memory` and `memory_size` below are read-only properties; this
        # assumes `Memory.__init__` does not assign attributes with those names - confirm.
        super().__init__(memory_size, batch_size)
        self.max_sequence_length = max_sequence_length
        self._memory = []

    @property
    def memory(self) -> list[Transition]:
        ''' The stored transitions, in list order '''
        return self._memory

    @property 
    def memory_size(self):
        ''' The number of transitions currently stored (not the configured capacity) '''
        return len(self.memory)

    def sample(self):
        ''' Sample a minibatch of fixed-length sequences from the replay memory

        Each sequence is a window of consecutive transitions starting at a uniformly
        sampled index. Windows that run past the end of the memory are post-padded to
        ``max_sequence_length`` so that every sequence in the batch has the same length
        and the batch can be stacked into arrays.

        The memory must be non-empty (``torch.randint`` rejects an empty range).

        Returns:
            state_batch (np.ndarray): a batch of sequences of states
            action_batch (np.ndarray): the action of the last transition of each sequence
            reward_batch (np.ndarray): the reward of the last transition of each sequence
            next_state_batch (np.ndarray): a batch of sequences of next states
            done_batch (np.ndarray): the done flag of the last transition of each sequence
            legal_actions_batch (tuple): the legal actions of the last transition of each sequence
        '''
        window = self.max_sequence_length

        # Sample a batch of starting indices and cut one window per index. Padding is
        # applied per window: padding the whole memory only up to `max_sequence_length`
        # left every window that did not start early enough shorter than the others.
        # NOTE(review): windows are contiguous slices of the memory, so they may span
        # more than one episode - confirm this is intended.
        start_idx = torch.randint(0, self.memory_size, (self.batch_size, )).tolist()
        sequences = [
            post_pad_transition(self.memory[start:start + window], window)
            for start in start_idx
        ] # [batch_size, sequence_length, 5]

        # The processing below is a bit convoluted, but it's just mirroring what the SimpleMemory does.
        def unpack_and_cat_transitions(sequence: list[Transition]):
            # NOTE(review): for a padded window, sequence[-1] is a padding transition;
            # presumably it carries over the fields of the last real transition - confirm.
            return (
                np.array([transition.state for transition in sequence]),
                sequence[-1].action, # Should be the last action in the sequence
                sequence[-1].reward, # TODO (Kacper) figure out if this should be the last reward or some combination of previous rewards
                np.array([transition.next_state for transition in sequence]), # This has to be a sequence 
                sequence[-1].done, # Sequence done if the last state done
                sequence[-1].legal_actions, # Take the last legal action
            )        
        sequences = map(unpack_and_cat_transitions, sequences)
        samples = list(zip(*sequences))
        # Everything except the legal actions is stacked into a numpy array.
        return tuple(map(np.array, samples[:-1])) + (samples[-1],)

In [3]:
import math 
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from rlcard.agents.dqn_agent.estimator import EstimatorNetwork
from rlcard.agents.dqn_agent.typing import Transition


def padding_mask(s: Tensor) -> Tensor:
    """Return a boolean mask over the last-but-one axes of ``s`` marking padded positions.

    A position is flagged ``True`` when *every* entry along the last (feature) axis equals
    ``Transition.padding_value()``. This is the polarity ``src_key_padding_mask`` of
    ``torch.nn.TransformerEncoder`` expects: ``True`` means "ignore this position".
    The previous ``!=`` comparison flagged the real (non-padded) positions instead.

    The mask must be computed on the raw states, before any embedding is applied.
    """
    return (s == Transition.padding_value()).all(dim=-1) # TODO we could add a test that no partial padding is present


class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional embedding from "Attention is All You Need".

    The table is precomputed for ``max_sequence_length`` positions and stored as a
    non-persistent buffer: it follows the module across ``.to(device)`` / ``.cuda()``
    but is not written to the ``state_dict``.

    Note that ``forward`` returns only the positional embedding; the caller is
    responsible for adding it to the token embedding. ``self.dropout`` is created
    but not applied in ``forward``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_sequence_length: int = 512):
        """
        Arguments:
            d_model: size of the embedding (last) dimension; may be odd
            dropout: probability of the (currently unused) dropout layer
            max_sequence_length: number of positions to precompute
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # From "Attention is All You Need"
        # Alternate sine and cosine of different frequencies
        position = torch.arange(max_sequence_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        positional_embedding = torch.zeros(max_sequence_length, 1, d_model)
        positional_embedding[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns.
        positional_embedding[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer (rather than a plain attribute) moves with the module between devices;
        # persistent=False keeps previously saved state_dicts loadable.
        self.register_buffer("positional_embedding", positional_embedding, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor, shape ``[seq_len, 1, d_model]``: the embeddings of the first
            ``x.size(0)`` positions, broadcastable over the batch dimension.
        """
        return self.positional_embedding[:x.size(0)]
    

class AverageSequencePooling(nn.Module):
    """Pooling layer that collapses one axis of a tensor by taking its arithmetic mean."""

    def __init__(self, dim: int = 1):
        """
        Arguments:
            dim: index of the axis that is averaged away in ``forward``
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Return the mean of ``x`` along ``self.dim``; that axis is dropped from the result."""
        pooled = torch.mean(x, dim=self.dim)
        return pooled


class TransformerEstimatorNetwork(EstimatorNetwork):
    """Transformer-encoder estimator mapping a batch of state sequences to one value per action.

    Pipeline (see ``forward``): linear embedding of every state plus a sinusoidal
    positional embedding -> dropout -> ``TransformerEncoder`` -> mean over the sequence
    axis -> linear layer with ``num_actions`` outputs.

    All tensors are batch-first: ``[batch, seq, feature]``.
    """

    def __init__(
        self, 
        num_actions=2, 
        state_shape=None, 
        num_layers: int = 2, 
        d_model: int = 128,
        nhead: int = 8, 
        dim_feedforward: int = 32,
        dropout: float = 0.1,
        max_sequence_length: int = 512, 
    ):
        """
        :param num_actions: number of outputs of the final linear layer (forwarded to ``EstimatorNetwork``)
        :param state_shape: shape of a single state (forwarded to ``EstimatorNetwork``)
        :param num_layers: number of stacked ``TransformerEncoderLayer`` blocks
        :param d_model: width of the embedding and of the encoder
        :param nhead: number of attention heads per encoder layer
        :param dim_feedforward: hidden width of the feed-forward sub-layer
        :param dropout: dropout probability used for the embedding and inside the encoder
        :param max_sequence_length: number of positions the positional embedding is precomputed for
        """
        super().__init__(num_actions=num_actions, state_shape=state_shape)

        # TODO (Kacper) maybe we should add batchnorm before embedding as in the original MLP?
        # TODO (Kacper) also find out whether this embedding method with a linear layer is common
        # `self.input_dims` / `self.num_actions` are presumably set by EstimatorNetwork.__init__ - confirm.
        self.embedding = nn.Linear(self.input_dims, d_model, bias=True)

        # With sinusoidal embedding there is technically no limit on the sequence length. 
        # However, the performance does deteriorate with longer sequences. 
        # Thus, a limit is helpful both for performance, speed, and memory.
        self.max_sequence_length = max_sequence_length
        self.positional_embedding = SinusoidalPositionalEmbedding(d_model=d_model, dropout=dropout, max_sequence_length=self.max_sequence_length)
        self.embedding_dropout = nn.Dropout(p=dropout)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward,
            dropout=dropout, 
            activation='relu', 
            layer_norm_eps=1e-5, 
            batch_first=True, # [batch, seq, feature]
            norm_first=False, # TODO (Kacper) check if modern version used layer norm prior to attention and feedforward or after
            bias=True, 
        )
        self.encoder = TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers,
            norm=None, # TODO (Kacper) check if modern architectures use layer norm (I don't think so)
            enable_nested_tensor=True,
        )
        self.pooling = AverageSequencePooling(dim=1) # 1 is the sequence dimension
        self.output_linear = nn.Linear(d_model, self.num_actions, bias=True)
        
    
    def forward(self, s: Tensor, pad: bool = True) -> Tensor:
        """
        :param s: Batch of input sequences [batch, seq, feature]
        :param pad: If True, derive a key-padding mask from ``s`` and pass it to the encoder;
            if False, the encoder attends to every position.
        :return: Tensor [batch, num_actions]
        """
        # The mask has to be derived from the raw states: after the embedding (and
        # dropout) the values can no longer be compared against the padding value.
        mask = padding_mask(s) if pad else None

        # The positional embedding module is sequence-first (it slices by dim 0 and returns
        # [seq, 1, d_model]), while this network is batch-first. Swap the axes on the way in
        # and out so positions are indexed along the sequence axis -> [1, seq, d_model].
        positions = self.positional_embedding(s.transpose(0, 1)).transpose(0, 1)
        s = self.embedding_dropout(self.embedding(s) + positions)

        s = self.encoder(s, src_key_padding_mask=mask)
        # NOTE(review): the mean below runs over all positions, including masked ones;
        # a mask-aware mean may be preferable once padded batches are used - confirm.
        s = self.pooling(s)
        return self.output_linear(s)

In [7]:
# Seed torch so the randomly initialised weights and the dummy batch are reproducible
# after a kernel restart.
torch.manual_seed(0)

state_shape = (12, )
model = TransformerEstimatorNetwork(state_shape=state_shape)

batch_size = 7
sequence_length = 10
# Dummy batch laid out as [batch, sequence, feature].
# NOTE: `input` shadows the Python builtin; the name is kept because later cells reference it.
input = torch.randn(batch_size, sequence_length, model.input_dims)

In [9]:
# Display the dimensions of the dummy batch ([batch, sequence, feature]).
input.size()

torch.Size([7, 10, 12])

In [8]:
# Smoke-test the forward pass with autograd switched off.
with torch.set_grad_enabled(False):
    model(input)