-
Notifications
You must be signed in to change notification settings - Fork 331
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #225 from Huizerd/reward
Implementation of reward prediction error
- Loading branch information
Showing 5 changed files with 118 additions and 22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Union | ||
|
||
import torch | ||
|
||
|
||
class AbstractReward(ABC):
    # language=rst
    """
    Abstract base class for reward computation.

    Subclasses must implement both :meth:`compute` (produce the reward signal for the current step) and
    :meth:`update` (refresh whatever internal state that computation relies on). All inputs are passed as keyword
    arguments so that each implementation can define its own set of parameters.
    """

    @abstractmethod
    def compute(self, **kwargs) -> torch.Tensor:
        # language=rst
        """
        Computes/modifies reward.

        :param kwargs: Implementation-specific keyword arguments.
        :return: The computed/modified reward.
        """
        # The return annotation matches the concrete implementation in this file (``MovingAvgRPE.compute``), which
        # returns a tensor; annotating ``None`` here would make every such override type-incompatible.
        pass

    @abstractmethod
    def update(self, **kwargs) -> None:
        # language=rst
        """
        Updates internal variables needed to modify reward. Usually called once per episode.

        :param kwargs: Implementation-specific keyword arguments.
        """
        pass
|
||
|
||
class MovingAvgRPE(AbstractReward):
    # language=rst
    """
    Calculates reward prediction error (RPE) based on an exponential moving average (EMA) of past rewards.

    Two EMAs are tracked: one of the average reward per step (used as the prediction in :meth:`compute`) and one of
    the reward accumulated per episode (recorded after every :meth:`update` for plotting).
    """

    def __init__(self) -> None:
        # language=rst
        """
        Constructor for EMA reward prediction error. Both running averages start at zero.
        """
        self.reward_predict = torch.tensor(0.0)  # Predicted reward (per step).
        self.reward_predict_episode = torch.tensor(0.0)  # Predicted reward per episode.
        self.rewards_predict_episode = []  # List of predicted rewards per episode (used for plotting).

    def compute(self, **kwargs) -> torch.Tensor:
        # language=rst
        """
        Computes the reward prediction error using EMA.

        Keyword arguments:

        :param Union[float, torch.Tensor] reward: Current reward.
        :return: Reward prediction error, i.e. the current reward minus the predicted per-step reward.
        :raises KeyError: If ``reward`` is not supplied.
        """
        # Get keyword arguments.
        reward = kwargs['reward']

        return reward - self.reward_predict

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Updates the EMAs. Called once per episode.

        Keyword arguments:

        :param Union[float, torch.Tensor] accumulated_reward: Reward accumulated over one episode.
        :param int steps: Steps in that episode. Must be positive.
        :param float ema_window: Width of the averaging window. Must be positive. Defaults to ``10.0``.
        :raises KeyError: If ``accumulated_reward`` or ``steps`` is not supplied.
        :raises ValueError: If ``steps`` or ``ema_window`` is not positive.
        """
        # Get keyword arguments. ``as_tensor`` avoids the copy (and the warning) ``torch.tensor`` produces when the
        # caller already passes a tensor; both values are cast to float so an integer ``ema_window`` can never be
        # subjected to integer division when its reciprocal is taken.
        accumulated_reward = kwargs['accumulated_reward']
        steps = torch.as_tensor(kwargs['steps']).float()
        ema_window = torch.as_tensor(kwargs.get('ema_window', 10.0)).float()

        # Validate before touching any state: a zero divisor would push inf/NaN into the running averages, and an
        # EMA never recovers from that.
        if bool((steps <= 0).any()):
            raise ValueError('steps must be positive')
        if bool((ema_window <= 0).any()):
            raise ValueError('ema_window must be positive')

        # Compute average reward per step.
        reward = accumulated_reward / steps

        # Update EMAs; the newest observation gets weight 1 / ema_window.
        alpha = 1 / ema_window
        self.reward_predict = (1 - alpha) * self.reward_predict + alpha * reward
        self.reward_predict_episode = (1 - alpha) * self.reward_predict_episode + alpha * accumulated_reward
        self.rewards_predict_episode.append(self.reward_predict_episode.item())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters