<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Reinforcement_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class QNetwork(nn.Module):
    """Fully connected Q-value estimator: state vector in, one score per action out.

    Architecture: Linear(input_dim, 128) -> ReLU -> Linear(128, 128) -> ReLU
    -> Linear(128, output_dim). The last layer has no activation, so the
    outputs are unbounded Q-value estimates.

    Args:
        input_dim: Number of features in each input vector.
        output_dim: Number of outputs (one per discrete action).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        width = 128  # hidden-layer width shared by both hidden layers
        # Layers are created in this order so parameter initialisation and
        # state_dict keys (fc1.*, fc2.*, fc3.*) stay stable.
        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return the output-layer activations for input ``x``."""
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(features)

class QLearningAgent:
    """Deep Q-learning agent with experience replay and a periodically synced target network.

    Args:
        input_dim: Size of the state vector fed to the Q-networks.
        output_dim: Number of discrete actions (one Q-value output per action).
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability that ``act`` picks a uniformly random action.
        lr: Adam learning rate for the online Q-network.
        replay_buffer_size: Capacity of the replay buffer; once full, storing a
            new transition evicts the oldest one.
        batch_size: Number of transitions sampled per ``update`` call; ``update``
            is a no-op until at least this many transitions are stored.
        target_update_freq: Copy the online weights into the target network
            every this many gradient steps.
    """

    def __init__(self, input_dim, output_dim, gamma=0.99, epsilon=0.1, lr=1e-3,
                 replay_buffer_size=10000, batch_size=64, target_update_freq=10):
        self.q_network = QNetwork(input_dim, output_dim)
        self.target_network = QNetwork(input_dim, output_dim)
        # Start both networks from identical weights.
        self.target_network.load_state_dict(self.q_network.state_dict())
        # Only the online network is optimised; the target is refreshed by copying.
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        # deque evicts the oldest transition in O(1) (list.pop(0) is O(n)) and
        # still supports len(), append() and random.sample().
        self.replay_buffer = deque()
        self.replay_buffer_size = replay_buffer_size
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.update_counter = 0  # number of gradient steps taken so far

    def act(self, state):
        """Return an action index for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon``, a uniformly random action in
        ``[0, output_dim - 1]`` is returned (exploration). Otherwise, the
        argmax of the online network's output is returned (exploitation).
        Assumes ``state`` is a single unbatched state; ``unsqueeze(0)`` adds
        the batch dimension.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.q_network.fc3.out_features - 1)  # Exploration
        with torch.no_grad():
            return torch.argmax(self.q_network(state.unsqueeze(0))).item()  # Exploitation

    def update(self):
        """Take one gradient step on a random minibatch from the replay buffer.

        Returns immediately (doing nothing) while fewer than ``batch_size``
        transitions are stored. Every ``target_update_freq`` steps, the
        online weights are copied into the target network.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample a batch from the replay buffer.
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert batch elements to tensors.
        states = torch.stack(states)
        next_states = torch.stack(next_states)
        actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1)  # Shape: (batch_size, 1)
        rewards = torch.tensor(rewards, dtype=torch.float32)  # Shape: (batch_size,)
        dones = torch.tensor(dones, dtype=torch.float32)  # Shape: (batch_size,)

        # Q(s, a) for the actions actually taken. squeeze(1) rather than a
        # bare squeeze() keeps the batch dimension when batch_size == 1.
        q_values = self.q_network(states).gather(1, actions).squeeze(1)

        # The bootstrapped target is a constant for this step. Computing it
        # under no_grad stops backward() from building a graph through the
        # target network and accumulating .grad into parameters the optimizer
        # never steps or zeroes.
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(dim=1).values
            # (1 - dones) drops the bootstrap term for terminal transitions.
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = F.mse_loss(q_values, target_q_values)

        # Optimize the online network.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically sync the target network with the online network.
        self.update_counter += 1
        if self.update_counter % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest ones if the buffer is at capacity."""
        # `while` (not `if`) also trims correctly if replay_buffer_size was lowered.
        while len(self.replay_buffer) >= self.replay_buffer_size:
            self.replay_buffer.popleft()
        self.replay_buffer.append((state, action, reward, next_state, done))

# Example usage
input_dim = 4  # Example for CartPole observation space
output_dim = 2  # Example for CartPole action space

agent = QLearningAgent(input_dim=input_dim, output_dim=output_dim)

# Simulate an example interaction.
state = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float32)  # Example state
action = agent.act(state)

# update() does nothing until the replay buffer holds at least
# `agent.batch_size` transitions. Store that many copies of this toy
# transition so the call below actually performs a gradient step.
for _ in range(agent.batch_size):
    agent.remember(state, action, 1, state, False)
agent.update()