<a href="https://colab.research.google.com/github/Neelavo/RL_Basics/blob/main/dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

env = gym.make('CartPole-v1')
# Network input width (length of one observation vector) and the number of
# discrete actions the agent can choose from.
state_size, action_size = env.observation_space.shape[0], env.action_space.n

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        state_dim: Input width. Defaults to the module-level ``state_size``.
        action_dim: Output width (number of actions). Defaults to the
            module-level ``action_size``.
        hidden_size: Width of each of the two hidden layers (default 30).
    """

    def __init__(self, state_dim=None, action_dim=None, hidden_size=30):
        super().__init__()
        # Fall back to the notebook globals so existing ``DQN()`` calls keep working.
        if state_dim is None:
            state_dim = state_size
        if action_dim is None:
            action_dim = action_size
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)

    def forward(self, x):
        """Return raw (unbounded) Q-values; the last dimension indexes actions."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
class DQNAgent:
    """DQN agent with experience replay and a periodically synced target network.

    Builds its online and target networks with the module-level ``DQN`` class
    and reads the module-level ``action_size`` once, at construction.

    Args:
        memory_size: Maximum number of stored transitions; the oldest are
            discarded once full. ``None`` keeps the buffer unbounded.
    """

    def __init__(self, memory_size=10000):
        self.model = DQN()
        self.target_model = DQN()
        # Start the target network from the same weights as the online network.
        self.update_target_model()
        self.action_size = action_size
        self.memory = deque(maxlen=memory_size)
        self.gamma = 0.95  # Discount factor for Q-learning
        self.epsilon = 1.0  # Exploration rate
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.batch_size = 32
        self.update_target_freq = 5  # sync target net every N replay() updates
        self.target_update_counter = 0
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Store one transition. ``done`` should be True only for true terminal states."""
        self.memory.append((state, action, reward, next_state, done))

    def choose_action(self, state):
        """Return an epsilon-greedy action index for ``state``."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_size)
        # Inference only: no autograd graph is needed for action selection.
        with torch.no_grad():
            q_values = self.model(torch.as_tensor(np.asarray(state), dtype=torch.float32))
        return int(torch.argmax(q_values).item())

    def replay(self):
        """Run one gradient step on a random minibatch from memory.

        Does nothing until memory holds at least ``batch_size`` transitions.
        Afterwards it decays epsilon (clamped at ``epsilon_min``) and syncs the
        target network every ``update_target_freq`` calls.

        Returns:
            The minibatch loss as a float, or ``None`` if no update was made.
        """
        if len(self.memory) < self.batch_size:
            return None

        indices = np.random.choice(len(self.memory), self.batch_size, replace=False)
        batch = [self.memory[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64)).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

        # Q(s, a) predicted by the online network for the actions actually taken.
        predicted = self.model(states).gather(1, actions).squeeze(1)

        # TD target r + gamma * max_a' Q_target(s', a'). The bootstrap term is
        # zeroed for terminal transitions. It is kept out of the autograd graph.
        with torch.no_grad():
            next_q = self.target_model(next_states).max(dim=1).values
            targets = rewards + self.gamma * next_q * (1.0 - dones)

        loss = F.mse_loss(predicted, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update exploration rate without undershooting the floor.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Update target model weights
        self.target_update_counter += 1
        if self.target_update_counter % self.update_target_freq == 0:
            self.update_target_model()

        return loss.item()

In [None]:
def _reset_state(environment):
    """Reset ``environment`` and return only the initial observation.

    Accepts both the old gym API (``reset()`` -> obs) and the newer one
    (``reset()`` -> ``(obs, info)``).
    """
    result = environment.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_transition(environment, action):
    """Step ``environment`` and return ``(next_state, reward, terminated, done)``.

    ``terminated`` marks a true terminal state, which should stop bootstrapping.
    ``done`` marks the end of the episode for any reason, including a time limit.
    """
    result = environment.step(action)
    if len(result) == 5:  # new API: obs, reward, terminated, truncated, info
        next_state, reward, terminated, truncated, _ = result
        return next_state, reward, terminated, terminated or truncated
    # Old 4-tuple API: a time-limit cut-off is reported through info, when present.
    next_state, reward, done, info = result
    truncated = isinstance(info, dict) and info.get('TimeLimit.truncated', False)
    return next_state, reward, bool(done and not truncated), done


def train_dqn_agent(agent, num_episodes=1000, environment=None):
    """Train ``agent`` for ``num_episodes`` episodes, learning after every step.

    Args:
        agent: Object exposing ``choose_action``, ``remember``, ``replay`` and
            ``epsilon`` (e.g. ``DQNAgent``).
        num_episodes: Number of episodes to run.
        environment: Gym-style environment; defaults to the module-level ``env``.

    Returns:
        List of total rewards, one per episode.
    """
    if environment is None:
        environment = env
    episode_rewards = []
    for episode in range(num_episodes):
        state = _reset_state(environment)
        total_reward = 0
        done = False

        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, done = _step_transition(environment, action)
            # Store `terminated` (not `done`) so time-limit cut-offs still bootstrap.
            agent.remember(state, action, reward, next_state, terminated)
            # One minibatch update per environment step; a single update per
            # episode is far too little training signal.
            agent.replay()
            state = next_state
            total_reward += reward

        episode_rewards.append(total_reward)
        print(f"Episode: {episode + 1}, Reward: {total_reward}, Epsilon: {agent.epsilon}")

    print("Training completed.")
    return episode_rewards

In [None]:
# Build a fresh agent and train it with the default episode budget.
dqn_agent = DQNAgent()
train_dqn_agent(agent=dqn_agent)

Episode: 1, Reward: 13.0, Epsilon: 1.0
Episode: 2, Reward: 29.0, Epsilon: 1.0
Episode: 3, Reward: 18.0, Epsilon: 0.995
Episode: 4, Reward: 30.0, Epsilon: 0.990025
Episode: 5, Reward: 30.0, Epsilon: 0.985074875
Episode: 6, Reward: 10.0, Epsilon: 0.9801495006250001
Episode: 7, Reward: 13.0, Epsilon: 0.9752487531218751
Episode: 8, Reward: 14.0, Epsilon: 0.9703725093562657
Episode: 9, Reward: 30.0, Epsilon: 0.9655206468094844
Episode: 10, Reward: 13.0, Epsilon: 0.960693043575437
Episode: 11, Reward: 21.0, Epsilon: 0.9558895783575597
Episode: 12, Reward: 12.0, Epsilon: 0.9511101304657719
Episode: 13, Reward: 13.0, Epsilon: 0.946354579813443
Episode: 14, Reward: 27.0, Epsilon: 0.9416228069143757
Episode: 15, Reward: 21.0, Epsilon: 0.9369146928798039
Episode: 16, Reward: 23.0, Epsilon: 0.9322301194154049
Episode: 17, Reward: 26.0, Epsilon: 0.9275689688183278
Episode: 18, Reward: 32.0, Epsilon: 0.9229311239742362
Episode: 19, Reward: 12.0, Epsilon: 0.918316468354365
Episode: 20, Reward: 36.0, 

In [None]:
# Evaluate the trained agent for one episode with a greedy policy.
# Exploration is switched off for the run and the agent's epsilon is restored afterwards.
saved_epsilon = dqn_agent.epsilon
dqn_agent.epsilon = 0.0
try:
    reset_result = env.reset()
    # Newer gym APIs return (obs, info); older ones return obs alone.
    if isinstance(reset_result, tuple) and len(reset_result) == 2 and isinstance(reset_result[1], dict):
        state = reset_result[0]
    else:
        state = reset_result
    total_reward = 0
    done = False

    while not done:
        action = dqn_agent.choose_action(state)
        step_result = env.step(action)
        if len(step_result) == 5:  # obs, reward, terminated, truncated, info
            next_state, reward, terminated, truncated, _ = step_result
            done = terminated or truncated
        else:  # obs, reward, done, info
            next_state, reward, done, _ = step_result
        state = next_state
        total_reward += reward
finally:
    dqn_agent.epsilon = saved_epsilon

print(f"Total Reward: {total_reward}")

env.close()

Total Reward: 10.0
