In [None]:
# default_exp recall

<IPython.core.display.Javascript object>

# recall

> Memories for (perfect) recall

In [None]:
# hide
from nbdev.showdoc import *

%load_ext nb_black
%matplotlib inline
%config InlineBackend.figure_format='retina'
%config IPCompleter.greedy=True

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [None]:
# export
import random
import sys

import numpy as np

<IPython.core.display.Javascript object>

Let's make a simple kind of memory. One that perfectly remembers what it sees, but that has only so many slots to store experience. Once the slots are full, it starts overwriting the oldest memories first. I'll call this the `Recall` memory. 

- The `capacity` is the size of the memory (the maximum number of stored experiences).
- The `encode(*args)` method adds memories (in order).
- The `sample(n)` method samples `n` experiences at random.

In [None]:
# export
class Recall(object):
    """A fixed-size memory that overwrites its oldest entries first.

    Experiences are written into a list used as a ring buffer: while the
    list is shorter than `capacity` it grows by one slot per `encode`
    call; once full, the write `position` wraps around and the oldest
    slot is replaced.

    Parameters
    ----------
    capacity : int
        Maximum number of experiences kept. Must be >= 1.

    Raises
    ------
    ValueError
        If `capacity` is less than 1.
    """

    def __init__(self, capacity):
        # A memory without slots can never store anything; fail at
        # construction instead of with an IndexError on the first encode.
        if capacity < 1:
            raise ValueError("capacity must be >= 1")
        self.capacity = capacity
        self.memory = []
        # Index of the slot the next encode() call writes to.
        self.position = 0

    def encode(self, *args):
        """Saves a memory tuple.

        `args` are forwarded unchanged to `Transition(*args)` and the
        result is stored in the current slot.
        """
        # NOTE(review): `Transition` is not defined in the visible part of
        # this notebook -- confirm it is defined/imported before use
        # (presumably a namedtuple describing one experience).
        # Pad out: grow the buffer until it reaches capacity
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        # Remember: write, then advance the cursor with wrap-around
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def __call__(self, *args):
        """Shorthand for `encode(*args)`."""
        self.encode(*args)

    def sample(self, n):
        """Returns `n` stored memories, drawn at random without replacement.

        Raises
        ------
        ValueError
            Raised by `random.sample` if `n` is larger than the number of
            stored memories (or negative).
        """
        return random.sample(self.memory, n)

    def __len__(self):
        """Number of memories currently stored (at most `capacity`)."""
        return len(self.memory)

<IPython.core.display.Javascript object>

Now let's make another memory. This one is nearly the same except we add a way to prioritize recall. The priority can be any number > 0. Larger numbers mean higher priority. 

In [None]:
# export
class PriorityRecall(object):
    """A fixed-size memory whose recall is weighted by priority.

    Works like a ring buffer: memories are stored in order and, once
    `capacity` is reached, the oldest slot is overwritten. Every memory
    carries a positive priority weight; `sample` draws memories with
    probability proportional to that weight.

    Parameters
    ----------
    capacity : int
        Maximum number of memories kept. Must be >= 1.

    Raises
    ------
    ValueError
        If `capacity` is less than 1.
    """

    def __init__(self, capacity):
        # A memory without slots can never store anything; fail at
        # construction instead of with an IndexError on the first encode.
        if capacity < 1:
            raise ValueError("capacity must be >= 1")
        self.capacity = capacity
        self.memory = []
        # priority[i] is the weight of memory[i]
        self.priority = []
        # Sampling probabilities computed by the most recent sample() call
        self.probs = None
        # Index of the slot the next encode() call writes to
        self.position = 0
        # Prevents div by 0 problems
        self.eps = sys.float_info.min

    def encode(self, weight, *args):
        """Saves a weight and a memory tuple.

        Parameters
        ----------
        weight : float-like
            Priority of this memory. Must be finite and > 0 (values that
            `np.isclose` treats as zero are rejected too).
        *args
            Forwarded unchanged to `Transition(*args)`.

        Raises
        ------
        ValueError
            If `weight` is not finite, or is not > 0.
        """
        # Sanity: check before touching the buffers so a bad weight
        # leaves the memory unchanged.
        weight = float(weight)
        if not np.isfinite(weight):
            raise ValueError("w must be finite")
        # Negative weights would yield negative sampling probabilities.
        if weight <= 0.0 or np.isclose(weight, 0.0):
            raise ValueError("w must be > 0")
        # NOTE(review): `Transition` is not defined in the visible part of
        # this notebook -- confirm it is defined/imported before use
        # (presumably a namedtuple describing one experience).
        # Pad out: grow both buffers in lock-step until capacity is reached
        if len(self.memory) < self.capacity:
            self.memory.append(None)
            self.priority.append(None)
        # Remember: write, then advance the cursor with wrap-around
        self.priority[self.position] = weight
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def __call__(self, weight, *args):
        """Shorthand for `encode(weight, *args)`."""
        self.encode(weight, *args)

    def sample(self, n):
        """Returns a list of `n` memories drawn by priority.

        Draws are made with replacement (the `np.random.choice` default),
        so the same memory may appear more than once. The probabilities
        used are stored in `self.probs`.

        Raises
        ------
        ValueError
            If the memory is empty.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty memory")
        # Est probs from priority weights
        summed = sum(self.priority) + self.eps
        self.probs = [w / summed for w in self.priority]
        # Weighted sample. Draw indices rather than the items themselves:
        # np.random.choice needs a 1-d input and would turn a list of
        # tuples into a 2-d array.
        picks = np.random.choice(len(self.memory), size=n, p=self.probs)
        return [self.memory[i] for i in picks]

    def __len__(self):
        """Number of memories currently stored (at most `capacity`)."""
        return len(self.memory)

<IPython.core.display.Javascript object>