In [None]:
import numpy as np
from stable_baselines3.common.buffers import ReplayBuffer

class CustomReplayBuffer(ReplayBuffer):
    """Replay buffer that samples stored transitions proportionally to a priority.

    Every new transition receives the current maximum priority (or 1.0 while
    no priority has been set yet).  Priorities can later be adjusted with
    :meth:`update_priorities`, e.g. from TD errors.
    """

    def __init__(self, buffer_size, observation_space, action_space, device, n_envs=1, optimize_memory_usage=False):
        """Create the underlying buffer and a zero-initialised priority array.

        The arguments are forwarded unchanged to ``ReplayBuffer.__init__``.
        """
        super().__init__(buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage)

        # Size the priority array from the capacity the parent actually uses
        # (self.buffer_size), so priorities and stored transitions always
        # share the same index range.
        self.priorities = np.zeros((self.buffer_size,), dtype=np.float32)

    def add(self, obs, next_obs, action, reward, done, infos):
        """Store a transition and give it the highest priority seen so far."""
        # The parent add() does not return the written index, so capture the
        # write position before delegating (the parent advances self.pos).
        write_pos = self.pos
        super().add(obs, next_obs, action, reward, done, infos)

        # New experiences get the maximum priority so that they are likely to
        # be replayed at least once; fall back to 1.0 for the very first ones.
        max_priority = self.priorities.max()
        self.priorities[write_pos] = max_priority if max_priority > 0 else 1.0

    def sample(self, batch_size, env=None):
        """Sample ``batch_size`` transitions with probability proportional to priority.

        Only slots that have already been written are considered.  If all of
        their priorities are zero, sampling is uniform.

        Raises:
            ValueError: if the buffer does not contain any transition yet.
        """
        # Number of slots that currently hold a transition.
        upper_bound = self.buffer_size if self.full else self.pos
        if upper_bound == 0:
            raise ValueError("Cannot sample from an empty replay buffer")

        # Normalise in float64 so the probabilities sum to 1 within the
        # tolerance np.random.choice checks for.
        weights = self.priorities[:upper_bound].astype(np.float64)
        total = weights.sum()
        if total > 0:
            probabilities = weights / total
        else:
            probabilities = np.full(upper_bound, 1.0 / upper_bound)

        # NOTE(review): the parent's special-casing of the slot at self.pos
        # when optimize_memory_usage=True is not replicated here -- confirm
        # before enabling that option with this buffer.
        idxs = np.random.choice(upper_bound, batch_size, p=probabilities)
        return self._get_samples(idxs, env)

    def update_priorities(self, idxs, priorities):
        """Overwrite the priorities stored at buffer positions ``idxs``."""
        self.priorities[idxs] = priorities

In [None]:
import gym
import torch
from stable_baselines3 import DQN

# Define the environment
env = gym.make('CartPole-v1')

# Replay buffer configuration
buffer_size = 10000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the DQN model with the custom replay buffer.
# DQN has no `replay_buffer` argument: it builds the buffer itself from
# `replay_buffer_class`, using the model's `buffer_size` and `device`.
model = DQN(
    'MlpPolicy',
    env,
    verbose=1,
    buffer_size=buffer_size,
    replay_buffer_class=CustomReplayBuffer,
    learning_rate=1e-3,
    batch_size=64,
    target_update_interval=1000,
    device=device,
)

# Handle to the buffer instance the model actually trains from
custom_replay_buffer = model.replay_buffer

# Train the agent
model.learn(total_timesteps=100000)

# Save the model
model.save("dqn_cartpole_custom_replay_buffer")

# Load the model
model = DQN.load("dqn_cartpole_custom_replay_buffer")

# Evaluate the agent
obs = env.reset()
for _ in range(1000):
    action, _ = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    env.render()
    # Start a new episode once the current one has terminated; stepping a
    # finished environment is undefined behaviour in gym.
    if dones:
        obs = env.reset()

env.close()