In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import gym

# Define the DRQN model
class DRQN(nn.Module):
    """Recurrent Q-network: one LSTM layer followed by a linear Q-value head.

    The LSTM is built with PyTorch defaults (``batch_first=False``), so a 3-D
    input is interpreted as ``(seq_len, batch, input_size)``.

    Args:
        input_size: Number of features per timestep.
        hidden_size: LSTM hidden/cell state size (also stored as ``self.hidden_size``).
        output_size: Number of Q-values produced per timestep.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Attribute names are part of the state_dict keys ("lstm.*", "fc.*").
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x, hidden_state):
        """Encode ``x`` with the LSTM and map every timestep to Q-values.

        Args:
            x: Input tensor in a layout accepted by ``nn.LSTM``.
            hidden_state: ``(h, c)`` from a previous call, or ``None`` so the
                LSTM starts from its default zero state.

        Returns:
            Tuple ``(q_values, next_hidden_state)``.
        """
        features, next_hidden = self.lstm(x, hidden_state)
        return self.fc(features), next_hidden

# Define a replay memory buffer
class ReplayMemory:
    """Bounded FIFO store of transitions used for experience replay.

    Backed by a ``deque`` with ``maxlen=capacity``: once full, every new
    transition silently evicts the oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, transition):
        """Store one transition, dropping the oldest entry when at capacity."""
        buffer = self.memory
        buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds the
        number of stored transitions.
        """
        population = self.memory
        return random.sample(population, k=batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)

# Define the DRQN agent
class DRQNAgent:
    """DRQN agent: online/target recurrent Q-networks plus a replay memory.

    Replayed transitions are trained as independent single-step samples: each
    sampled state is fed to the LSTM as a length-1 sequence starting from the
    LSTM's default zero hidden state.

    Args:
        input_size: Number of values in one observation.
        hidden_size: LSTM hidden size.
        output_size: Number of discrete actions (Q-values per step).
        lr: Adam learning rate for the online network.
        gamma: Discount factor used in the TD target.
        batch_size: Transitions per training step; training is skipped until
            the memory holds at least this many.
        memory_capacity: Maximum number of stored transitions.
    """

    def __init__(self, input_size, hidden_size, output_size, lr, gamma, batch_size, memory_capacity):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DRQN(input_size, hidden_size, output_size).to(self.device)
        self.target_model = DRQN(input_size, hidden_size, output_size).to(self.device)
        # Start the target as an exact copy of the online network; otherwise the
        # first TD targets come from an unrelated, randomly initialised network.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.memory = ReplayMemory(memory_capacity)
        self.gamma = gamma
        self.batch_size = batch_size
        # Recurrent (h, c) carried between select_action calls; callers set it
        # back to None to start a new sequence.
        self.hidden_state = None

    def select_action(self, state):
        """Return the greedy action for ``state`` and advance ``self.hidden_state``.

        Args:
            state: One observation (tensor or array-like) with ``input_size`` values.

        Returns:
            int: Index of the action with the highest Q-value.
        """
        # Inference only: without no_grad the carried hidden state would keep an
        # autograd graph that grows with every step of the episode.
        with torch.no_grad():
            x = torch.as_tensor(state, dtype=torch.float32, device=self.device).reshape(1, 1, -1)
            q_values, self.hidden_state = self.model(x, self.hidden_state)
        return q_values.argmax().item()

    def store_transition(self, state, action, next_state, reward, done):
        """Push ``(state, action, next_state, reward, done)`` into replay memory.

        ``done`` should be True only when the episode truly terminated (no
        bootstrapping from ``next_state``); it may be a bool or a number.
        """
        self.memory.push((state, action, next_state, reward, done))

    def train(self):
        """Run one DQN update on a random minibatch from replay memory.

        Returns:
            The loss as a Python float, or None if the memory does not yet
            hold ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return None

        transitions = self.memory.sample(self.batch_size)
        # zip(*) yields one tuple per field; it is an iterator, so unpack it
        # directly instead of indexing it.
        states, actions, next_states, rewards, dones = zip(*transitions)

        # (batch, input) -> (seq_len=1, batch, input). nn.LSTM is not
        # batch_first, and a 2-D input would be read as a single unbatched
        # sequence of length `batch`, leaking state between unrelated samples.
        state_batch = torch.stack(
            [torch.as_tensor(s, dtype=torch.float32) for s in states]
        ).unsqueeze(0).to(self.device)
        next_state_batch = torch.stack(
            [torch.as_tensor(s, dtype=torch.float32) for s in next_states]
        ).unsqueeze(0).to(self.device)
        action_batch = torch.tensor([int(a) for a in actions], dtype=torch.long, device=self.device)
        reward_batch = torch.tensor([float(r) for r in rewards], dtype=torch.float32, device=self.device)
        # Float, not bool: `1 - bool_tensor` is rejected by PyTorch.
        done_batch = torch.tensor([float(d) for d in dones], dtype=torch.float32, device=self.device)

        q_values, _ = self.model(state_batch, None)
        # (1, batch, actions) -> (batch,) Q-value of the action actually taken.
        q_values = q_values.squeeze(0).gather(1, action_batch.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_values, _ = self.target_model(next_state_batch, None)
            max_next_q_values = next_q_values.squeeze(0).max(1)[0]
        expected_q_values = reward_batch + self.gamma * max_next_q_values * (1.0 - done_batch)

        loss = nn.MSELoss()(q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

# Define training parameters
input_size = 4  # Input size for CartPole environment
hidden_size = 128  # Hidden size for LSTM
output_size = 2  # Output size for CartPole environment
lr = 0.001  # Learning rate
gamma = 0.99  # Discount factor
batch_size = 32  # Batch size
memory_capacity = 10000  # Replay memory capacity
epsilon_start = 1.0  # Starting value of epsilon for epsilon-greedy policy
epsilon_end = 0.01  # Minimum value of epsilon
epsilon_decay = 0.995  # Per-episode multiplicative decay of epsilon
target_update_every = 10  # Copy online weights into the target net every N episodes
num_episodes = 1000


def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Accepts both the classic Gym API (``reset() -> obs``) and the newer
    Gym>=0.26 / Gymnasium API (``reset() -> (obs, info)``).
    """
    out = env.reset()
    return out[0] if isinstance(out, tuple) else out


def _step_env(env, action):
    """Step ``env`` and return ``(obs, reward, terminated, truncated)``.

    Newer Gym/Gymnasium returns these directly (plus ``info``). Classic Gym
    returns ``(obs, reward, done, info)``; its TimeLimit wrapper reports a
    time-limit cut-off via ``info["TimeLimit.truncated"]``, which is split out
    here so both APIs yield the same tuple.
    """
    out = env.step(action)
    if len(out) == 5:
        obs, reward, terminated, truncated, _ = out
    else:
        obs, reward, done, info = out
        truncated = bool(info.get("TimeLimit.truncated", False))
        terminated = bool(done) and not truncated
    return obs, reward, bool(terminated), bool(truncated)


# Initialize environment and agent
env = gym.make('CartPole-v1')
agent = DRQNAgent(input_size, hidden_size, output_size, lr, gamma, batch_size, memory_capacity)

# Training loop; try/finally so the environment is closed even on error/interrupt.
try:
    for episode in range(num_episodes):
        state = torch.tensor(_reset_env(env), dtype=torch.float32)
        agent.hidden_state = None  # new episode -> fresh LSTM state
        episode_reward = 0.0
        # Epsilon depends only on the episode index, so compute it once here.
        epsilon = max(epsilon_end, epsilon_start * epsilon_decay ** episode)
        done = False

        while not done:
            # Always run the network so the LSTM state sees every observation,
            # including those on which an exploratory action is taken.
            greedy_action = agent.select_action(state)
            action = env.action_space.sample() if random.random() < epsilon else greedy_action

            obs, reward, terminated, truncated = _step_env(env, action)
            next_state = torch.tensor(obs, dtype=torch.float32)
            reward = float(reward)
            done = terminated or truncated

            # Store `terminated`, not `done`: a time-limit truncation is not a
            # true terminal state, so the TD target should still bootstrap.
            agent.store_transition(state, action, next_state, reward, terminated)
            state = next_state
            episode_reward += reward

            agent.train()

        if episode % target_update_every == 0:
            agent.update_target_network()

        print(f"Episode {episode + 1}, Reward: {episode_reward}")
finally:
    env.close()
