# In [None]:
import random

import gym
import torch
import torch.nn as nn
import torch.optim as optim

# Define the Q-network
class QNetwork(nn.Module):
    """Multi-layer perceptron mapping a state vector to one Q-value per action.

    Layout: ``state_dim -> 64 -> 64 -> action_dim``. ReLU follows each of the
    two hidden layers; the output layer is linear so Q-values are unbounded.

    Args:
        state_dim: Size of the input (state) feature dimension.
        action_dim: Number of discrete actions, i.e. number of outputs.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 64
        # Layers are created in fc1 -> fc2 -> fc3 order; the attribute names
        # determine the parameter / state_dict keys.
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, action_dim)

    def forward(self, state):
        """Return Q-values for ``state``; its last dimension must be ``state_dim``."""
        activations = state
        for hidden_layer in (self.fc1, self.fc2):
            activations = torch.relu(hidden_layer(activations))
        return self.fc3(activations)

def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    gym >= 0.26 returns ``(obs, info)`` from ``reset()``; older releases
    return the observation alone. Both forms are accepted here.
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` with ``action``; return ``(obs, reward, terminated, truncated)``.

    gym >= 0.26 returns five values and separates ``terminated`` (a true
    terminal state) from ``truncated`` (e.g. a time limit). Older releases
    return a single ``done`` flag; the ``TimeLimit.truncated`` info key is
    used to recover the distinction when it is present.
    """
    result = env.step(action)
    if len(result) == 5:
        obs, reward, terminated, truncated, _ = result
    else:
        obs, reward, done, info = result
        truncated = bool(info.get("TimeLimit.truncated", False))
        terminated = done and not truncated
    return obs, reward, terminated, truncated


# Create the environment
env = gym.make('CartPole-v0')

# Define the Q-network, optimizer and loss
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
q_network = QNetwork(state_dim, action_dim)
optimizer = optim.Adam(q_network.parameters(), lr=0.001)
loss_fn = nn.MSELoss()  # stateless, so build it once rather than every step

# Hyperparameters
num_episodes = 1000
gamma = 0.99  # discount factor applied to the next state's value
epsilon_start = 1.0  # initial probability of taking a random action
epsilon_end = 0.05  # lower bound on the exploration probability
epsilon_decay = 0.995  # multiplicative decay of epsilon after each episode

# Training loop
epsilon = epsilon_start
for episode in range(num_episodes):
    state = _reset_env(env)
    total_reward = 0.0
    done = False

    while not done:
        # Forward pass through the Q-network to get Q-values for this state
        state_tensor = torch.tensor(state, dtype=torch.float32)
        q_values = q_network(state_tensor)

        # Epsilon-greedy: explore with probability epsilon, otherwise
        # exploit the action with the highest predicted Q-value.
        if random.random() < epsilon:
            action = random.randrange(action_dim)
        else:
            action = torch.argmax(q_values).item()

        # Take the chosen action and observe the next state and reward
        next_state, reward, terminated, truncated = _step_env(env, action)
        done = terminated or truncated
        total_reward += reward

        # Bellman target. A terminated state has no future value, so the
        # bootstrap term is dropped; a time-limit truncation still bootstraps.
        with torch.no_grad():
            next_state_tensor = torch.tensor(next_state, dtype=torch.float32)
            next_value = torch.max(q_network(next_state_tensor))
            continuation = 0.0 if terminated else 1.0
            target_q_value = float(reward) + gamma * continuation * next_value

        # Loss between the predicted Q-value of the taken action and the target
        loss = loss_fn(q_values[action], target_q_value)

        # Update the Q-network parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        state = next_state

    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    # Print the accumulated reward of the episode
    print(f"Episode {episode+1}: Total Reward = {total_reward}")

env.close()
