In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import random
from einops.layers.torch import Rearrange
from einops import rearrange

from typing import Any, Dict, Tuple, Optional
from game_mechanics import (
    ChooseMoveCheckpoint,
    ShooterEnv,
    checkpoint_model,
    choose_move_randomly,
    human_player,
    load_network,
    play_shooter,
    save_network,
)
from tqdm.notebook import tqdm

from functools import partial
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
from copy import deepcopy
from functools import partial

from utils import *
%load_ext autoreload
%autoreload 2

pygame 2.1.2 (SDL 2.0.16, Python 3.8.10)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [22]:
# Hristo's variant, kept for comparison with `calculate_gae` below. It multiplies every
# TD-error weight by (1 - gamma*lamda), so it is NOT the standard GAE: without the
# last-term correction its output is exactly (1 - gamma*lamda) times `calculate_gae`'s.
def gae(rewards, values, successor_values, dones, gamma, lamda, correct_last_term):
    """Weighted sum of TD-errors using normalised (gamma*lamda)-geometric weights.

    Args:
        rewards, values, successor_values: 1-D tensors of length N, combined
            element-wise into TD-errors r + gamma * V' - V (assumed aligned per
            transition -- TODO confirm with caller).
        dones: 1-D tensor of length N; non-zero entries mark the step at which an
            episode ends. No row looks past the end of its own episode.
        gamma: discount factor.
        lamda: GAE lambda.
        correct_last_term: if True, each row's weights are made to sum to one by
            adding the missing mass to the weight of the last TD-error of that row's
            episode (its terminal step, or step N-1 if the episode is unfinished).

    Returns:
        1-D tensor of length N.
    """
    N = len(rewards)
    deltas = rewards + gamma * successor_values - values
    if N == 0:
        # torch.stack below cannot stack an empty list.
        return deltas
    gamlam = gamma * lamda

    # Build the weights with the TD-errors' dtype/device so the final matmul is valid
    # for float64 and non-CPU inputs too.
    weight_dtype = deltas.dtype if deltas.is_floating_point() else torch.get_default_dtype()
    gamlam_geo_series = torch.as_tensor(
        [gamlam**i for i in range(N)], dtype=weight_dtype, device=deltas.device
    ) * (1 - gamlam)

    # Row n is the series shifted right by n; roll wraps the tail to the front and triu
    # removes it, leaving (1 - gamlam) * gamlam**(j - n) at column j >= n.
    full_gamlam_matrix = torch.stack([torch.roll(gamlam_geo_series, shifts=n) for n in range(N)])
    full_gamlam_matrix = torch.triu(full_gamlam_matrix)

    # Rows up to a terminal step must not use TD-errors after it.
    done_indexes = torch.squeeze(dones.nonzero(), dim=1).tolist()
    for terminal_index in done_indexes:
        full_gamlam_matrix[: terminal_index + 1, terminal_index + 1:] = 0

    # end_index[i] = last step of the episode containing step i. Steps after the final
    # terminal belong to an unfinished episode whose last collected step is N - 1.
    end_index = torch.full((N,), N - 1, dtype=torch.long)
    for start, end in zip([-1] + done_indexes[:-1], done_indexes):
        end_index[start + 1:end + 1] = end
    if correct_last_term:
        # make sure it sums to one:
        # (by making the term for the last value be 1 - sum(all other terms))
        full_gamlam_matrix[torch.arange(N), end_index] += 1 - full_gamlam_matrix.sum(axis=1)
    return full_gamlam_matrix @ deltas

def calculate_gae(
        rewards: torch.Tensor,
        values: torch.Tensor,
        successor_values: torch.Tensor,
        is_terminals: torch.Tensor,
        gamma: float,
        lamda: float,
):
    r"""
    Calculate the Generalized Advantage Estimator (GAE) for a batch of transitions.

    A_t = \sum_{k >= 0} (gamma * lamda)^k * delta_{t+k},
    delta_t = rewards[t] + gamma * successor_values[t] - values[t]

    The sum for step t stops at the end of the episode containing t (the first
    terminal step at or after t) or at the last step of the batch.

    Args:
        rewards, values, successor_values: 1-D tensors of length N (assumed aligned
            per transition -- TODO confirm with caller). successor_values are used
            as-is; presumably callers already zero them for terminal transitions.
        is_terminals: 1-D tensor of length N; non-zero entries mark episode ends.
        gamma: discount factor.
        lamda: GAE lambda.

    Returns:
        1-D tensor of length N with the advantage estimate for every step.
    """
    N = len(rewards)

    # Gets the delta terms: the TD-errors
    delta_terms = rewards + gamma * successor_values - values
    if N == 0:
        # torch.stack below cannot stack an empty list.
        return delta_terms

    gamlam = gamma * lamda

    # Match the TD-errors' dtype/device so the final matmul works for float64 / GPU inputs.
    weight_dtype = delta_terms.dtype if delta_terms.is_floating_point() else torch.get_default_dtype()
    gamlam_geo_series = torch.tensor(
        [gamlam ** n for n in range(N)], dtype=weight_dtype, device=delta_terms.device
    )

    # Shift the coefficients to the right for each successive row; roll wraps the tail
    # to the front, and triu zeroes those wrapped entries (everything below the diagonal),
    # leaving gamlam**(j - n) at column j >= n of row n.
    gamlam_matrix = torch.triu(
        torch.stack([torch.roll(gamlam_geo_series, shifts=n) for n in range(N)])
    )

    # Zero out terms that are after an episode termination. This must modify the same
    # matrix that is used in the product below.
    for terminal_index in torch.squeeze(is_terminals.nonzero(), dim=1).tolist():
        gamlam_matrix[: terminal_index + 1, terminal_index + 1:] = 0

    return torch.matmul(gamlam_matrix, delta_terms)

In [10]:
# Toy batch: 20 non-terminal transitions with constant reward and value estimates,
# so every TD-error is the same: 1 + 0.95 * 10 - 10 ≈ 0.5.
n_steps = 20
rewards = torch.ones(n_steps)
values = torch.full((n_steps,), 10.0)
successor_values = values.clone()
dones = torch.zeros(n_steps)
gamma, lamda = 0.95, 0.8

In [25]:
hristo_no_correcting_last_term = gae(rewards, values, successor_values, dones, gamma, lamda, False)
hristo_corrected_last_term = gae(rewards, values, successor_values, dones, gamma, lamda, True)
tom = calculate_gae(rewards, values, successor_values, dones, gamma, lamda)

print('Hristo no correcting for the last term:')
print(hristo_no_correcting_last_term)
print()
print('Tom:')
print(tom)
print()
print('They are the same, just multiplied by a constant')
ratio = hristo_no_correcting_last_term / tom
print(ratio)
# The constant is the (1 - gamma * lamda) factor that `gae` multiplies into its weights;
# check the claim instead of eyeballing the printed tensor.
assert torch.allclose(ratio, torch.full_like(ratio, 1 - gamma * lamda))
print()
# With the correction each row's weights sum to one; since every TD-error in this toy
# batch is the same, the corrected result equals that TD-error at every step.
print('However, correcting for the last term is very significant:')
print(hristo_corrected_last_term)


Hristo no correcting for the last term:
tensor([0.4979, 0.4973, 0.4964, 0.4953, 0.4938, 0.4918, 0.4893, 0.4859, 0.4814,
        0.4756, 0.4679, 0.4577, 0.4443, 0.4268, 0.4037, 0.3732, 0.3332, 0.2805,
        0.2112, 0.1200])

Tom:
tensor([2.0747, 2.0720, 2.0684, 2.0637, 2.0575, 2.0494, 2.0386, 2.0245, 2.0060,
        1.9815, 1.9494, 1.9071, 1.8515, 1.7782, 1.6819, 1.5551, 1.3883, 1.1688,
        0.8800, 0.5000])

They are the same, just multiplied by a constant
tensor([0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400,
        0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400, 0.2400,
        0.2400, 0.2400])

However, correcting for the last term is very significant:
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000])
