<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Full_Working_Q_Learning_Implementation_(Function_Based).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Adam

# Q-network definition
class QNetwork(nn.Module):
    """Small fully connected network producing one output value per action.

    The model is a single hidden layer of 128 ReLU units between two linear
    layers, stored as the ``fc`` sequential container.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output values.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        # Layers are instantiated in input-to-output order and wrapped in a
        # single Sequential so that forward() is a plain pass-through.
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        outputs = self.fc(x)
        return outputs

# Select action via ε-greedy policy
def select_action(state, epsilon, q_net, action_dim):
    """Choose an action with an ε-greedy policy.

    With probability ``epsilon`` a uniformly random action index in
    ``[0, action_dim)`` is returned; otherwise the index of the largest value
    that ``q_net`` produces for ``state`` is returned.

    Args:
        state: A single observation, as a NumPy array or any array-like that
            ``torch.as_tensor`` accepts (e.g. a list or tuple of numbers).
        epsilon: Exploration probability.
        q_net: Callable that takes a float32 tensor holding a batch of one
            state and returns the action values.
        action_dim: Number of discrete actions to sample from when exploring.

    Returns:
        The selected action index as a Python ``int``.
    """
    if np.random.rand() < epsilon:
        return int(np.random.randint(action_dim))
    with torch.no_grad():
        # as_tensor handles lists/tuples and any numeric NumPy dtype, whereas
        # torch.from_numpy raises TypeError for anything but an ndarray.
        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        q_values = q_net(state_tensor)
        return int(q_values.argmax().item())

# Training loop
def train_q_learning(env_name="CartPole-v1", episodes=500, gamma=0.99, lr=1e-3):
    """Train a Q-network online with one-step Q-learning.

    Each environment step performs a single gradient update on the transition
    just observed (no replay buffer, no target network). Actions are chosen
    ε-greedily; ε starts at 1.0 and is multiplied by 0.995 after every episode
    down to a floor of 0.01. Progress is printed once per episode.

    Args:
        env_name: Environment id passed to ``gym.make``.
        episodes: Number of episodes to run.
        gamma: Discount factor applied to the bootstrapped next-state value.
        lr: Learning rate for the Adam optimizer.

    Returns:
        A list with the total reward collected in each episode, in order.
    """
    env = gym.make(env_name)
    episode_rewards = []
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        q_net = QNetwork(state_dim, action_dim)
        optimizer = Adam(q_net.parameters(), lr=lr)
        loss_fn = nn.MSELoss()

        epsilon = 1.0
        epsilon_min = 0.01
        epsilon_decay = 0.995

        for episode in range(episodes):
            state = env.reset()
            # reset() returns (observation, info) in the newer API and the
            # bare observation in the older one.
            state = state[0] if isinstance(state, tuple) else state
            state = np.asarray(state, dtype=np.float32)
            done = False
            total_reward = 0

            while not done:
                action = select_action(state, epsilon, q_net, action_dim)
                result = env.step(action)
                if len(result) == 5:
                    next_state, reward, terminated, truncated, _ = result
                else:
                    # Legacy 4-tuple (obs, reward, done, info): the fourth
                    # item is the info dict, not a truncation flag. The
                    # TimeLimit wrapper reports truncation through info.
                    next_state, reward, legacy_done, info = result
                    truncated = bool(legacy_done) and bool(
                        info.get("TimeLimit.truncated", False)
                    )
                    terminated = bool(legacy_done) and not truncated
                done = bool(terminated or truncated)
                next_state = np.asarray(next_state, dtype=np.float32)

                state_tensor = torch.as_tensor(state).unsqueeze(0)
                next_state_tensor = torch.as_tensor(next_state).unsqueeze(0)

                with torch.no_grad():
                    next_q = torch.max(q_net(next_state_tensor)).item()

                # Bootstrap unless the episode truly terminated; a time-limit
                # truncation does not mean the next state has zero value.
                bootstrap = 0.0 if terminated else 1.0
                target = reward + gamma * next_q * bootstrap

                # Single forward pass. The target equals the prediction for
                # every action except the one taken, so only that entry
                # contributes to the loss; detach() keeps the target constant.
                pred_q = q_net(state_tensor)
                target_q = pred_q.detach().clone()
                target_q[0, action] = target

                loss = loss_fn(pred_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                state = next_state
                total_reward += reward

            epsilon = max(epsilon * epsilon_decay, epsilon_min)
            episode_rewards.append(total_reward)
            print(f"Episode {episode+1}, Reward: {total_reward}, Epsilon: {epsilon:.3f}")
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()

    return episode_rewards

# Uncomment to train
# train_q_learning()