In [1]:
from collections import deque
import random
import torch
import torch.nn as nn

In [None]:
class ReplayBuffer:
    """Fixed-size FIFO store of transitions used for experience replay.

    Backed by a ``deque`` with ``maxlen=capacity``: once full, appending a new
    transition silently evicts the oldest one.
    """

    def __init__(self, capacity=10000):
        self.buffer = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random and collate them.

        ``state``/``next_state`` entries are combined with ``torch.stack``;
        actions become an int64 tensor, rewards and done flags float32 tensors.
        ``random.sample`` raises ``ValueError`` if ``batch_size`` exceeds the
        number of stored transitions.

        Returns:
            tuple: ``(states, actions, rewards, next_states, dones)``.
        """
        minibatch = random.sample(self.buffer, batch_size)
        state_col, action_col, reward_col, next_state_col, done_col = zip(*minibatch)

        state_batch = torch.stack(state_col)
        action_batch = torch.tensor(action_col, dtype=torch.long)
        reward_batch = torch.tensor(reward_col, dtype=torch.float32)
        next_state_batch = torch.stack(next_state_col)
        done_batch = torch.tensor(done_col, dtype=torch.float32)
        return state_batch, action_batch, reward_batch, next_state_batch, done_batch

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


In [None]:
class DQNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to one Q-value per action.

    Layout: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim)``.
    The output layer is exposed as ``fc2``; its ``out_features`` equals ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in this order so parameter initialisation consumes
        # the global torch RNG in a stable sequence.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw (unnormalised) Q-values for input ``x``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


In [None]:
class DeepQAgent:
    """DQN agent with an epsilon-greedy policy, experience replay and a target network.

    Args:
        input_dim: Size of the state vector fed to the Q-network.
        hidden_dim: Width of the Q-network's hidden layer.
        num_actions: Number of discrete actions (size of the Q-network output).
        lr: Adam learning rate for the online network.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Initial exploration probability used by ``select_action``.
        epsilon_decay: Multiplicative decay applied to epsilon after each gradient step.
        epsilon_min: Floor below which epsilon is never decayed.
        batch_size: Transitions sampled per ``train`` call; training is skipped
            until the replay buffer holds at least this many.
        memory_capacity: Maximum number of transitions kept in the replay buffer.
        target_update_freq: If a positive int, the target network is synced with
            the online network every ``target_update_freq`` gradient steps.
            ``None`` (default) keeps the original behaviour: the target network
            only changes when the caller invokes ``update_target_network``.
    """

    def __init__(self, input_dim, hidden_dim, num_actions, lr=0.001, gamma=0.99, epsilon=1.0,
                 epsilon_decay=0.995, epsilon_min=0.1, batch_size=32, memory_capacity=10000,
                 target_update_freq=None):
        if target_update_freq is not None and target_update_freq <= 0:
            raise ValueError("target_update_freq must be a positive int or None")

        self.num_actions = num_actions
        self.model = DQNetwork(input_dim, hidden_dim, num_actions)
        # Separate network used only for bootstrap targets; held fixed between syncs for stability.
        self.target_model = DQNetwork(input_dim, hidden_dim, num_actions)
        # Fully qualified `torch.optim` so no separate `optim` import is required.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.epsilon_decay = epsilon_decay  # Multiplicative decay per gradient step
        self.epsilon_min = epsilon_min  # Lower bound for epsilon
        self.batch_size = batch_size
        self.replay_buffer = ReplayBuffer(memory_capacity)
        self.target_update_freq = target_update_freq
        self._train_steps = 0  # Number of gradient updates performed so far

        # Start the target network as an exact copy of the online network.
        self.update_target_network()

    def select_action(self, state):
        """Pick an action index for a single, unbatched ``state`` tensor.

        With probability ``self.epsilon`` a uniformly random action is returned
        (explore); otherwise the action with the highest predicted Q-value (exploit).

        Returns:
            int: Action index in ``[0, num_actions)``.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.num_actions)  # Explore
        with torch.no_grad():
            q_values = self.model(state.unsqueeze(0))  # Add batch dimension
        return int(q_values.argmax(dim=1).item())  # Exploit

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def train(self):
        """Run one DQN gradient step on a minibatch sampled from the replay buffer.

        Also decays epsilon and, if ``target_update_freq`` is set, periodically
        syncs the target network.

        Returns:
            float | None: The MSE loss of this step, or ``None`` when the buffer
            holds fewer than ``batch_size`` transitions (nothing is changed).
        """
        if len(self.replay_buffer) < self.batch_size:
            return None  # Wait until enough experiences are in the buffer

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Q(s, a) for the action actually taken in each sampled transition.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped target r + gamma * max_a' Q_target(s', a'); the (1 - done)
        # factor drops the bootstrap term for terminal transitions.
        with torch.no_grad():
            max_next_q_values = self.target_model(next_states).max(dim=1).values
            target_q_values = rewards + self.gamma * max_next_q_values * (1 - dones)

        loss = self.loss_fn(q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._train_steps += 1
        if self.target_update_freq and self._train_steps % self.target_update_freq == 0:
            self.update_target_network()

        # Decay exploration once per gradient step, never below epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return loss.item()
