# Comparison of Reward Functions

This notebook shows how the different reward functions behave, especially when a pod crosses a checkpoint.

Use it as a tool for designing a good reward function.

In [None]:
# See how the reward function changes after hitting a checkpoint.
# The goal here is to find a reward that will encourage the agent to
# go through the checkpoint while already pointing at the next one.

from pod.ai.ai_utils import gen_pods
from pod.constants import Constants
import math
import numpy as np

from pod.board import PodBoard
from pod.util import PodState
from pod.controller import SimpleController
from pod.ai.rewards import speed_reward, diff_reward, dist_reward, ang_reward, check_reward, make_reward
from pod.drawer import Drawer
from pod.player import Player
from vec2 import Vec2, UNIT

# Number of frames used by the animation and the reward comparison below
TURNS = 20

# Number of starting points, spread evenly over the half-circle [0, pi]
# around the first checkpoint (5 -> one every 45 degrees).
NUM_STARTS = 5
# Starting distance from the first checkpoint, in multiples of its radius
START_DIST_RADII = 2
# Starting speed, as a fraction of the maximum velocity
START_SPEED_FRAC = 0.1

board = PodBoard.grid(rows=1, cols=3, spacing=4000)

# Generate some starting points around checkpoint 0. Each pod's velocity
# points straight at the checkpoint; its heading is set to the same angle
# (assumes PodState.angle uses the same radian convention as Vec2.rotate).
pods = []
labels = []
# linspace includes the endpoint exactly; a float-step arange would need an
# epsilon added to the stop value to be sure of including pi.
for ang in np.linspace(0, math.pi, NUM_STARTS):
    toward_check = ang + math.pi
    check_to_pos = UNIT.rotate(ang) * (START_DIST_RADII * Constants.check_radius())
    vel = UNIT.rotate(toward_check) * (Constants.max_vel() * START_SPEED_FRAC)
    pods.append(PodState(
        pos=board.checkpoints[0] + check_to_pos,
        vel=vel,
        angle=toward_check
    ))
    labels.append("%.1f°" % math.degrees(ang))

# For each starting point, create a Player
players = [Player(SimpleController(board), pod) for pod in pods]

drawer = Drawer(
    board,
    players=players,
    labels=labels
)

Show the initial state of the board and players.

In [None]:
# Render a single frame: the board plus every pod in its starting state.
drawer.draw_frame(pods)

Animate the players for up to `TURNS` frames.

In [None]:
# Animate the players for at most TURNS frames, played back at 3 frames per second.
drawer.animate(max_frames=TURNS, fps=3)

And now, the interesting part: compare the different reward functions.

In [None]:
# Weighted blend of the individual rewards; each entry is (weight, reward_fn).
# ang_reward gets a weight ten times smaller than the others.
r_func = make_reward([
    (0.2, speed_reward),
    (0.2, diff_reward),
    (0.2, dist_reward),
    (0.02, ang_reward),
    (0.2, check_reward),
])

# Index into `players` / `labels` of the player to compare on
# (-1 selects the last one).
player_idx = -1

# (name, reward function) pairs to evaluate side by side
reward_funcs = [
    ('custom', r_func),
    ('speed', speed_reward),
    ('diff', diff_reward),
    ('dist', dist_reward),
    ('ang', ang_reward),
]
drawer.compare_rewards(
    reward_funcs,
    [players[player_idx]],
    [labels[player_idx]],
    max_frames=TURNS,
)