In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import gym

# Define the Deep Q-Network (DQN) architecture
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Define replay memory
class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

# Define epsilon greedy exploration strategy
def epsilon_greedy_action(policy_net, state, epsilon, action_dim):
    if random.random() < epsilon:
        return random.randint(0, action_dim - 1)
    else:
        with torch.no_grad():
            return policy_net(state).argmax().item()

# Define function to optimize the Q-network
def optimize_model(policy_net, target_net, memory, batch_size, gamma, optimizer, loss_fn):
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    batch = zip(*transitions)
    state_batch = torch.tensor(batch[0], dtype=torch.float32)
    action_batch = torch.tensor(batch[1], dtype=torch.long)
    reward_batch = torch.tensor(batch[2], dtype=torch.float32)
    next_state_batch = torch.tensor(batch[3], dtype=torch.float32)
    done_batch = torch.tensor(batch[4], dtype=torch.bool)

    Q_values = policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
    next_Q_values = torch.zeros(batch_size)
    next_Q_values[~done_batch] = target_net(next_state_batch[~done_batch]).max(1)[0].detach()
    target_Q_values = reward_batch + (gamma * next_Q_values)

    loss = loss_fn(Q_values, target_Q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Define training function
def train_dqn(env_name='CartPole-v1', num_episodes=1000, batch_size=64, gamma=0.99, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995, target_update=10, memory_capacity=10000):
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    policy_net = DQN(input_dim, output_dim)
    target_net = DQN(input_dim, output_dim)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    memory = ReplayMemory(memory_capacity)

    epsilon = epsilon_start
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        total_reward = 0

        while not done:
            action = epsilon_greedy_action(policy_net, torch.tensor(state, dtype=torch.float32), epsilon, output_dim)
            next_state, reward, done, _ = env.step(action)
            memory.push((state, action, reward, next_state, done))
            state = next_state
            total_reward += reward

            optimize_model(policy_net, target_net, memory, batch_size, gamma, optimizer, loss_fn)

        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        print(f"Episode {episode + 1}, Reward: {total_reward}")

    env.close()

# Train the DQN
train_dqn()


Tensor:
tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)

variable:
<tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
array([[0., 0.],
       [0., 0.]], dtype=float32)>

Updated variable:
<tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
array([[1., 1.],
       [1., 1.]], dtype=float32)>


Step Result: (array([-0.0405223 ,  0.20130984,  0.01130362, -0.24377534], dtype=float32), 1.0, False, False, {})


ValueError: too many values to unpack (expected 4)