In [2]:
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from Pong import PongGame

pygame 2.6.1 (SDL 2.28.4, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Linear Approximation

Linear function approximation is a method used in reinforcement learning to estimate the Q-values when the state-action space is large or continuous. Instead of maintaining a large Q-table, we approximate the Q-function using a linear model.

## Key Concepts

- **Model Structure:**  
  The Q-function is approximated by a simple linear layer without any hidden layers or nonlinear activation functions. The model takes the state representation as input and outputs Q-values for each possible action.

- **Input and Output:**  
  - Input: Feature vector representing the current state (e.g., paddle position, ball position and direction).  
  - Output: Vector of Q-values corresponding to each possible action.

- **Advantages:**  
  - Computationally efficient and fast to train.  
  - Simpler and less prone to overfitting compared to deep networks.  
  - Suitable for problems where the relationship between state and Q-values is approximately linear.

- **Limitations:**  
  - Cannot capture complex, nonlinear relationships in the data.  
  - Performance may degrade on more complicated environments where nonlinear function approximators (e.g., deep neural networks) excel.

## How it Works in Practice

- The linear model predicts Q-values using:  
  $$Q(s, a) \approx \mathbf{w}_a^\top \mathbf{x}_s + b_a$$
  where $\mathbf{x}_s$ is the feature vector of state $s$, and $\mathbf{w}_a, b_a$ are the weights and bias for action $a$.

- During training, we minimize the mean squared error between predicted Q-values and target Q-values derived from the Bellman equation:  
  $$\text{Loss} = \left(Q(s, a) - \left(r + \gamma \max_{a'} Q(s', a')\right)\right)^2$$

- The parameters $\mathbf{w}_a$ and $b_a$ are updated with a gradient-based optimizer such as Adam.

## Summary

Using a linear function approximator for Q-learning provides a simple yet effective way to generalize across states without storing an explicit Q-table. It is best suited for environments where state-action values can be reasonably approximated by linear functions.


In [7]:
# --- Reproducibility -------------------------------------------------------
# Training draws from `random` (episode starts, epsilon-greedy exploration)
# and from torch (initial network weights). Seed every generator up front so
# that Restart Kernel -> Run All reproduces the same policy.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Q-learning hyperparameters --------------------------------------------
LEARNING_RATE = 0.001           # step size handed to the Adam optimizer
DISCOUNT_FACTOR = 0.95          # gamma applied to the bootstrapped next-state value
EPSILON = 0.1                   # probability of taking a random (exploratory) action
EPISODES = 2500                 # number of training episodes
MAX_ACTIONS_PER_EPISODE = 1000  # hard cap on steps within a single episode

# --- Discretised state grid ------------------------------------------------
NUM_PADDLE_POS = 30  # paddle_y takes values 0 .. NUM_PADDLE_POS - 1
NUM_BALL_X = 50      # ball_x takes values 0 .. NUM_BALL_X - 1
NUM_BALL_Y = 40      # ball_y takes values 0 .. NUM_BALL_Y - 1

# Change applied to paddle_y each step; the list index is the network output index.
ACTIONS = [-1, 0, 1]

# --- Playfield geometry (presumably pixels, matching the Pong module; confirm) ---
WIDTH, HEIGHT = 800, 600
PADDLE_WIDTH, PADDLE_HEIGHT = 10, 100

In [8]:
class QNetwork(nn.Module):
    """Linear Q-function approximator: one affine layer, no hidden units.

    Maps a state feature vector of length ``state_dim`` to one Q-value per
    action (``action_dim`` outputs).
    """

    def __init__(self, state_dim: int, action_dim: int) -> None:
        """Create the single ``nn.Linear`` layer holding all weights and biases."""
        super().__init__()
        self.linear = nn.Linear(in_features=state_dim, out_features=action_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-value estimates for the feature tensor ``x``."""
        q_values = self.linear(x)
        return q_values

# Size of the feature vector produced for each state.
state_dim = 5  # paddle_y, ball_x, ball_y, ball_dx, ball_dy
# One network output (Q-value) per entry in ACTIONS.
action_dim = len(ACTIONS)
# Global linear Q-function shared by action selection, training and policy export.
q_network = QNetwork(state_dim, action_dim)
# Adam updates the layer's weights and biases with step size LEARNING_RATE.
optimizer = optim.Adam(q_network.parameters(), lr=LEARNING_RATE)
# Squared-error loss between predicted and target Q-value vectors.
loss_fn = nn.MSELoss()

In [9]:
def discretize(value, max_value, bins):
    """Map ``value`` from the range ``[0, max_value]`` onto a bin index.

    The index is ``int(value / max_value * bins)`` clamped to
    ``0 .. bins - 1``, so out-of-range values land in the first or last bin.
    """
    raw_index = int(value / max_value * bins)
    # Clamp from below first, then from above.
    floored = raw_index if raw_index > 0 else 0
    top_bin = bins - 1
    return floored if floored < top_bin else top_bin

In [10]:
def state_to_features(state):
    """Convert a discrete state tuple into the network's float32 input tensor.

    ``state`` is ``(paddle_y, ball_x, ball_y, ball_dx, ball_dy)``. The three
    position indices are divided by their respective grid sizes; the two
    direction components are passed through unchanged.
    """
    paddle_y, ball_x, ball_y, ball_dx, ball_dy = state

    positions = (paddle_y, ball_x, ball_y)
    grid_sizes = (NUM_PADDLE_POS, NUM_BALL_X, NUM_BALL_Y)
    scaled_positions = [index / size for index, size in zip(positions, grid_sizes)]

    return torch.tensor(scaled_positions + [ball_dx, ball_dy], dtype=torch.float32)

In [11]:
def choose_action(state, epsilon):
    """Select an action with an epsilon-greedy policy over the Q-network.

    Args:
        state: ``(paddle_y, ball_x, ball_y, ball_dx, ball_dy)`` tuple.
        epsilon: Probability of returning a uniformly random action.

    Returns:
        One element of ``ACTIONS``.
    """
    # Explore: ignore the network entirely.
    if random.uniform(0, 1) < epsilon:
        return random.choice(ACTIONS)

    # Exploit: action selection never backpropagates, so skip building the
    # autograd graph for this forward pass.
    with torch.no_grad():
        q_values = q_network(state_to_features(state))
    return ACTIONS[torch.argmax(q_values).item()]

In [12]:
def calculate_reward(state, action):
    """Score ``action`` taken in ``state``.

    The reward is the sum of three parts:
      * a penalty proportional to the vertical distance between the paddle
        centre and the ball (distance divided by half the field height);
      * +0.5 when the action's sign matches the side the ball is on;
      * when the ball is in the last column with ``ball_dx == 1``: +5 if the
        ball is within half a paddle height of the paddle centre, else -100.
    """
    paddle_y, ball_x, ball_y, ball_dx, _ = state

    # Convert grid indices to field coordinates.
    paddle_step = HEIGHT / NUM_PADDLE_POS
    ball_step = HEIGHT / NUM_BALL_Y
    paddle_center = paddle_y * paddle_step + PADDLE_HEIGHT / 2
    ball_actual_y = ball_y * ball_step

    gap = abs(paddle_center - ball_actual_y)
    total = -gap / (HEIGHT / 2)

    # Bonus for stepping in the direction of the ball.
    wants_positive_step = ball_actual_y > paddle_center
    wants_negative_step = ball_actual_y < paddle_center
    if (wants_positive_step and action == 1) or (wants_negative_step and action == -1):
        total += 0.5

    # Hit / miss outcome at the last column.
    ball_arriving = ball_dx == 1 and ball_x == NUM_BALL_X - 1
    if ball_arriving:
        total += 5 if gap <= PADDLE_HEIGHT / 2 else -100

    return total

In [13]:
def update_q_network(state, action, reward, next_state, done):
    """Run one Q-learning gradient step on the global ``q_network``.

    Args:
        state: State tuple in which ``action`` was taken.
        action: The chosen element of ``ACTIONS``.
        reward: Scalar reward observed for the transition.
        next_state: State tuple reached after the transition.
        done: True when the transition ended the episode; the target is then
            the reward alone, with no bootstrapped term.
    """
    q_values = q_network(state_to_features(state))

    # TD target: r + gamma * max_a' Q(s', a'). The bootstrap value is used as
    # a constant, so it is computed without autograd and only when needed.
    target = reward
    if not done:
        with torch.no_grad():
            next_q_values = q_network(state_to_features(next_state))
        target += DISCOUNT_FACTOR * torch.max(next_q_values).item()

    # Copy the current predictions and overwrite only the taken action, so
    # the other outputs contribute zero error to the loss.
    target_q_values = q_values.detach().clone()
    target_q_values[ACTIONS.index(action)] = target

    loss = loss_fn(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [14]:
# Train the Q-network with one-step Q-learning on a simplified grid simulation.
for episode in range(EPISODES):
    # Every episode starts from a uniformly random grid configuration.
    paddle_y = random.randint(0, NUM_PADDLE_POS - 1)
    ball_x = random.randint(0, NUM_BALL_X - 1)
    ball_y = random.randint(0, NUM_BALL_Y - 1)
    ball_dx = random.choice([-1, 1])
    ball_dy = random.choice([-1, 1])

    done = False
    action_count = 0

    while not done and action_count < MAX_ACTIONS_PER_EPISODE:
        action_count += 1
        state = (paddle_y, ball_x, ball_y, ball_dx, ball_dy)

        action = choose_action(state, EPSILON)
        paddle_y = max(0, min(NUM_PADDLE_POS - 1, paddle_y + action))

        ball_x += ball_dx
        ball_y += ball_dy

        # Top/bottom walls: keep the ball on the grid and send it back into
        # the field. Only flipping the sign let a ball that started on a wall
        # while moving into it step off the grid and then oscillate there.
        if ball_y <= 0:
            ball_y, ball_dy = 0, 1
        elif ball_y >= NUM_BALL_Y - 1:
            ball_y, ball_dy = NUM_BALL_Y - 1, -1

        reached_left = ball_x <= 0
        reached_right = ball_x >= NUM_BALL_X - 1
        ball_x = max(0, min(NUM_BALL_X - 1, ball_x))

        if reached_right:
            # calculate_reward applies its hit/miss term only to a state with
            # the ball in the last column and ball_dx == 1, so score the
            # contact state now, before the direction is reversed, using the
            # paddle position that results from this step's action.
            contact_state = (paddle_y, ball_x, ball_y, ball_dx, ball_dy)
            reward = calculate_reward(contact_state, action)
        else:
            reward = calculate_reward(state, action)

        if reached_left or reached_right:
            ball_dx = -ball_dx

        next_state = (paddle_y, ball_x, ball_y, ball_dx, ball_dy)

        # The episode ends whenever the ball reaches either side wall.
        done = reached_left or reached_right

        update_q_network(state, action, reward, next_state, done)

    if (episode + 1) % 100 == 0:
        print(f"Episode {episode + 1}/{EPISODES}")

Episode 100/2500
Episode 200/2500
Episode 300/2500
Episode 400/2500
Episode 500/2500
Episode 600/2500
Episode 700/2500
Episode 800/2500
Episode 900/2500
Episode 1000/2500
Episode 1100/2500
Episode 1200/2500
Episode 1300/2500
Episode 1400/2500
Episode 1500/2500
Episode 1600/2500
Episode 1700/2500
Episode 1800/2500
Episode 1900/2500
Episode 2000/2500
Episode 2100/2500
Episode 2200/2500
Episode 2300/2500
Episode 2400/2500
Episode 2500/2500


In [15]:
def generate_all_possible_states():
    """Enumerate every ``(paddle_y, ball_x, ball_y, ball_dx, ball_dy)`` tuple.

    Positions range over their full grids and each direction component over
    ``-1`` and ``1``. The list is ordered with ``paddle_y`` varying slowest
    and ``ball_dy`` varying fastest.
    """
    directions = [-1, 1]
    return [
        (paddle_y, ball_x, ball_y, ball_dx, ball_dy)
        for paddle_y in range(NUM_PADDLE_POS)
        for ball_x in range(NUM_BALL_X)
        for ball_y in range(NUM_BALL_Y)
        for ball_dx in directions
        for ball_dy in directions
    ]

In [16]:
# Evaluate the trained network on every grid state and keep its greedy action.
states = generate_all_possible_states()

# JSON object keys must be strings, so each state tuple is stored under its
# repr, e.g. "(0, 0, 0, -1, -1)".
policy = dict(
    (
        str(grid_state),
        ACTIONS[q_network(state_to_features(grid_state)).detach().argmax().item()],
    )
    for grid_state in states
)

# Persist the state -> action lookup table.
with open("pong_la.json", "w") as f:
    json.dump(policy, f)

In [5]:
# Same file name the policy-export cell writes to.
# NOTE(review): this cell's execution count is lower than the training and
# export cells above it; re-run it after them so it uses the fresh policy file.
policy_file = "pong_la.json"
# PongGame comes from the local Pong module; presumably it loads the policy
# JSON and plays the game with it - confirm against Pong.py.
game = PongGame(policy_file)
game.run()