# In [None]:
import torch
import gym
import torch.nn as nn
import torch.optim as optim
import numpy as np
import copy


class TransformerExperienceReplay:
    """Fixed-capacity ring buffer of experiences with transformer-guided sampling.

    Experiences are ``(state, action, reward, next_state, done)`` tuples.
    ``sample`` scores every stored experience against the current state with
    an ``nn.Transformer`` and draws experiences with probability proportional
    to the softmax of those scores.

    Args:
        capacity: Maximum number of experiences kept; once full, the oldest
            slot is overwritten.
        input_dim: Feature size (``d_model``) of the transformer. The
            transformer is built with ``nhead=4``, so this must be divisible
            by 4.
    """

    def __init__(self, capacity, input_dim):
        self.capacity = capacity
        # Pre-allocated slots; unused slots hold None.
        self.buffer = [None] * capacity
        # Index of the slot the next push() writes to.
        self.position = 0
        # Number of slots that actually hold an experience. The buffer list is
        # pre-allocated, so len(self.buffer) cannot be used for this.
        self.size = 0
        self.input_dim = input_dim
        self.transformer = nn.Transformer(d_model=input_dim, nhead=4,
                                          num_encoder_layers=2, num_decoder_layers=2)

    @staticmethod
    def _as_detached_tensor(value):
        """Return an independent, graph-free tensor copy of ``value``."""
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        # Non-tensor input (NumPy array, list, ...): convert, then clone so the
        # stored tensor never aliases the caller's memory.
        return torch.as_tensor(value, dtype=torch.float32).clone()

    def _fit_to_input_dim(self, tensor):
        """Flatten ``tensor`` to 1-D float32 and zero-pad/truncate to ``input_dim``."""
        flat = tensor.reshape(-1).to(torch.float32)
        missing = self.input_dim - flat.nelement()
        if missing > 0:
            return torch.nn.functional.pad(flat, (0, missing), "constant", 0)
        return flat[:self.input_dim]

    def push(self, experience):
        """Store one experience, overwriting the oldest slot when full.

        Args:
            experience: ``(state, action, reward, next_state, done)``. The
                states may be tensors or anything ``torch.as_tensor`` accepts;
                ``action``, ``reward`` and ``done`` are scalars.
        """
        state, action, reward, next_state, done = experience
        state_tensor = self._as_detached_tensor(state)
        next_state_tensor = self._as_detached_tensor(next_state)

        action_tensor = torch.tensor([action], dtype=torch.float32)
        reward_tensor = torch.tensor([reward], dtype=torch.float32)
        done_tensor = torch.tensor([done], dtype=torch.float32)
        self.buffer[self.position] = (state_tensor, action_tensor, reward_tensor,
                                      next_state_tensor, done_tensor)
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, current_state):
        """Draw ``batch_size`` experiences weighted by relevance to ``current_state``.

        Args:
            batch_size: Number of experiences to draw (with replacement).
            current_state: State used as the query; tensor or array-like.

        Returns:
            A list of ``batch_size`` stored experience tuples, or ``[]`` when
            ``batch_size`` is not positive or fewer than ``batch_size``
            experiences have been stored.
        """
        if batch_size <= 0 or len(self) < batch_size:
            return []

        query = self._prepare_query(current_state)   # (1, 1, input_dim)
        keys_values = self._prepare_keys_values()    # (n, 1, input_dim)

        # Only indices come out of this method, so no autograd graph is needed.
        with torch.no_grad():
            # nn.Transformer.forward takes (src, tgt, ...); with the default
            # batch_first=False the output has the shape of tgt, i.e. one
            # row per stored experience.
            attention_output = self.transformer(query, keys_values)
            # Collapse each row to one logit via a scaled dot product with the
            # query. (A plain mean over features would be near-constant because
            # the decoder ends in a LayerNorm.)
            scores = (attention_output * query).sum(dim=-1).reshape(-1)
            scores = scores / (self.input_dim ** 0.5)
            attention_weights = torch.softmax(scores, dim=0)

        sampled_indices = torch.multinomial(attention_weights, batch_size, replacement=True)
        # Filled slots are always buffer[0:size], in the same order as the rows
        # of keys_values, so row indices are valid buffer indices.
        return [self.buffer[i] for i in sampled_indices.tolist()]

    def _prepare_query(self, current_state, batch_size=1, seq_length=1):
        """Turn ``current_state`` into a ``(batch_size, seq_length, input_dim)`` query.

        The state is flattened and zero-padded or truncated to ``input_dim``
        elements, so ``batch_size * seq_length`` must equal 1.
        """
        state = torch.as_tensor(current_state)
        fitted = self._fit_to_input_dim(state)
        return fitted.reshape(batch_size, seq_length, self.input_dim)

    def _prepare_keys_values(self, batch_size=1, seq_length=1):
        """Stack every stored experience into one tensor, one row per experience.

        Each experience is flattened in the order state, next_state, action,
        reward, done, then zero-padded or truncated to ``input_dim``; when the
        states alone fill ``input_dim`` the trailing scalars are cut off.

        Returns:
            Tensor of shape ``(n * batch_size, seq_length, input_dim)`` for
            ``n`` stored experiences, or zeros of shape
            ``(batch_size, seq_length, input_dim)`` when the buffer is empty.
        """
        processed_experiences = []
        for exp in self.buffer[:self.size]:
            state, action, reward, next_state, done = exp
            exp_tensor = torch.cat(
                [part.reshape(-1).to(torch.float32)
                 for part in (state, next_state, action, reward, done)]
            )
            fitted = self._fit_to_input_dim(exp_tensor)
            processed_experiences.append(
                fitted.reshape(batch_size, seq_length, self.input_dim)
            )

        if not processed_experiences:
            return torch.zeros((batch_size, seq_length, self.input_dim))

        return torch.cat(processed_experiences, dim=0)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return self.size

    def count_parameters(self):
        """Return the number of trainable parameters in the transformer."""
        return sum(p.numel() for p in self.transformer.parameters() if p.requires_grad)

