In [3]:
import random
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Define Q-Network
class QNetwork(nn.Module):
    """Fully connected Q-value approximator.

    Maps an observation vector to one unbounded Q-value per discrete action
    through ``input_dim -> 64 -> 64 -> output_dim`` linear layers, with a
    ReLU after each of the two hidden layers.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of outputs (one Q-value per action).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        width = 64
        # Registration order fc1 -> fc2 -> fc3 fixes both the parameter
        # initialisation order and the state_dict keys that
        # load_state_dict() relies on.
        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return raw Q-values for ``x`` (last dim must equal ``input_dim``)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        # Linear output head: Q-values are not squashed.
        return self.fc3(x)

# Initialize the environment
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
print(f"input dim: {input_dim}")
output_dim = env.action_space.n


def _reset_env(environment):
    """Reset ``environment`` and return only the initial observation.

    Works with the legacy gym API (``reset()`` returns ``obs``) and with
    gym >= 0.26 / gymnasium (``reset()`` returns ``(obs, info)``).
    """
    result = environment.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(environment, act):
    """Step ``environment`` and return ``(obs, reward, terminated, episode_over)``.

    ``terminated`` is True only when a true terminal state was reached, which
    means the Bellman target must not bootstrap. ``episode_over`` is True
    whenever the episode ends, including a time-limit truncation. Truncated
    transitions must still bootstrap from the next state.
    """
    result = environment.step(act)
    if len(result) == 5:  # gym >= 0.26 / gymnasium
        obs, rew, terminated, truncated, _info = result
        return obs, rew, bool(terminated), bool(terminated or truncated)
    obs, rew, done, info = result  # legacy 4-tuple API
    # The legacy TimeLimit wrapper flags a pure time-limit end via this key.
    truncated = bool(info.get("TimeLimit.truncated", False))
    return obs, rew, bool(done and not truncated), bool(done)


# Initialize the Q-network (the network being trained)
q_network = QNetwork(input_dim, output_dim)

# Initialize the target Q-network. It is a frozen copy used only to compute
# bootstrap targets, and it is re-synced every `target_update_freq` env steps.
# Syncing after every step (as before) made it identical to q_network and
# removed the stabilising effect it exists for.
target_q_network = QNetwork(input_dim, output_dim)
target_q_network.load_state_dict(q_network.state_dict())
target_q_network.eval()

# Initialize the optimizer and the loss. The loss module is stateless, so it
# is built once instead of on every step.
optimizer = optim.Adam(q_network.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

# Initialize replay buffer. A bounded deque drops the oldest transition in
# O(1) once full (list.pop(0) is O(n)).
replay_buffer_size = 10000
replay_buffer = deque(maxlen=replay_buffer_size)
batch_size = 32

# Epsilon-greedy exploration parameters
epsilon = 0.1
epsilon_decay = 0.995
min_epsilon = 0.01

# Discount factor
gamma = 0.99

# Training-run parameters
num_episodes = 3
target_update_freq = 100  # env steps between target-network syncs
verbose = True  # print per-step Q-values and actions (debug output)

total_steps = 0

# Training loop. try/finally ensures the env is closed even if the loop is
# interrupted (e.g. a notebook KeyboardInterrupt).
try:
    for episode in range(num_episodes):
        state = _reset_env(env)
        episode_reward = 0

        while True:
            # env.render()  # Render the environment

            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = int(env.action_space.sample())
            else:
                with torch.no_grad():
                    q_values = q_network(torch.tensor(state, dtype=torch.float32))
                    if verbose:
                        print(f"qvalues: {q_values}")
                    action = q_values.argmax().item()
            if verbose:
                print(f"action: {action}")
            next_state, reward, terminated, done = _step_env(env, action)
            total_steps += 1

            # Store the transition. Only a true terminal state (not a
            # time-limit truncation) disables bootstrapping.
            replay_buffer.append((state, action, reward, next_state, terminated))

            # Sample random minibatch from replay buffer
            if len(replay_buffer) >= batch_size:
                minibatch = random.sample(replay_buffer, batch_size)
                states, actions, rewards, next_states, terminals = zip(*minibatch)
                # Stack into one ndarray first. Building a tensor from a list
                # of ndarrays is very slow and triggers a torch warning.
                states = torch.tensor(np.array(states), dtype=torch.float32)
                actions = torch.tensor(actions, dtype=torch.int64)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
                terminals = torch.tensor(terminals, dtype=torch.float32)

                # Q-value targets: r + gamma * max_a' Q_target(s', a'), with
                # the bootstrap term zeroed for terminal transitions.
                with torch.no_grad():
                    max_target_q_values = target_q_network(next_states).max(dim=1).values
                    q_value_targets = rewards + gamma * (1.0 - terminals) * max_target_q_values

                # Q-values of the actions actually taken. Squeeze only the
                # action axis so a batch of size 1 keeps its batch dimension.
                q_values = q_network(states)
                q_values_actions = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Compute loss
                loss = loss_fn(q_values_actions, q_value_targets)

                # Update Q-network
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Periodically sync the target network with the online network
            if total_steps % target_update_freq == 0:
                target_q_network.load_state_dict(q_network.state_dict())

            episode_reward += reward
            state = next_state

            if done:
                print("Episode: {}, Reward: {}, Epsilon: {:.2f}".format(episode, episode_reward, epsilon))
                break

        # Decay epsilon
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
finally:
    env.close()


input dim: 4
qvalues: tensor([ 0.0776, -0.0558])
action: 0
qvalues: tensor([ 0.0960, -0.0708])
action: 0
qvalues: tensor([ 0.1025, -0.0678])
action: 0
action: 1
qvalues: tensor([ 0.1033, -0.0656])
action: 0
qvalues: tensor([ 0.1020, -0.0575])
action: 0
qvalues: tensor([ 0.0974, -0.0531])
action: 0
qvalues: tensor([ 0.0989, -0.0542])
action: 0
qvalues: tensor([ 0.1015, -0.0550])
action: 0
action: 1
qvalues: tensor([ 0.1052, -0.0566])
action: 0
Episode: 0, Reward: 11.0, Epsilon: 0.10
qvalues: tensor([ 0.0773, -0.0554])
action: 0
qvalues: tensor([ 0.0951, -0.0687])
action: 0
qvalues: tensor([ 0.1009, -0.0669])
action: 0
qvalues: tensor([ 0.1014, -0.0583])
action: 0
qvalues: tensor([ 0.0953, -0.0519])
action: 0
qvalues: tensor([ 0.0978, -0.0511])
action: 0
qvalues: tensor([ 0.1004, -0.0532])
action: 0
qvalues: tensor([ 0.1071, -0.0564])
action: 0
action: 0
qvalues: tensor([ 0.1263, -0.0706])
action: 0
Episode: 1, Reward: 10.0, Epsilon: 0.10
qvalues: tensor([ 0.0719, -0.0498])
action: 0
act