<a href="https://colab.research.google.com/github/JacobFV/Nodal-RL/blob/master/notebooks/Nodal_RL_core.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import numpy as np
import tensorflow as tf

In [17]:
# Display the TensorFlow release version, the compiler version it was built
# with, and the GraphDef version, as a tuple (notebook cell output).
tf.version.VERSION, tf.version.COMPILER_VERSION, tf.version.GRAPH_DEF_VERSION

('2.2.0', '5.4.0', 175)

In [3]:
class Node:
    """
    Generic Node base class

    A node keeps per-parent and per-child record dictionaries that are
    keyed by time, plus a record of its own states. Subclasses are
    expected to override `update`.
    """

    def __init__(self, p=1, default_parent_biasing_params=None):
        """
        initializes abstract node class

        :param p: period for timesteps. number
        :param default_parent_biasing_params:
                default biasing_params for parents
                Dict<str,number>. When omitted, a fresh
                `{"w_otarget": 1.0, "o_target_lambda": 1.0}` is used

                properties:
                    - otarget weight: "w_otarget" -> number
                    - recency otarget weight decay factor: "o_target_lambda" -> number

        :return: returns Node
        """

        # build the default per instance so that nodes never share
        # (and accidentally mutate) one module-level dict
        if default_parent_biasing_params is None:
            default_parent_biasing_params = {
                "w_otarget": 1.0,
                "o_target_lambda": 1.0,
            }

        self.p = p
        self.default_parent_biasing_params = default_parent_biasing_params

        # initialize stateful variables
        self.t = 0

        # data for each parent node
        # Dict<Node, Dict<str, number>>
        #
        # usage: [Node]["property"] = value
        #
        # properties:
        #     - otarget weight: 'w_otarget' -> number
        #     - recency otarget weight decay factor: 'o_target_lambda' -> number
        self.parent_biasing_params = {}

        # actual observations used from parents
        #
        # NOTE: the parent may have more observations
        #     in its `s` records that are not used by
        #     `self` because they are at a finer timescale
        #
        # dict<Node, dict<number,target observation>>
        self.parent_o = {}

        # biasing target observations delivered to parents
        #
        # dict<Node, dict<number,target observation>>
        self.parent_otarget = {}

        # top down biasing target observations obeyed from children
        #
        # NOTE: the children may have more target observations
        #     in their `parent_otarget` records that are not used by
        #     `self` because they are at a finer timescale
        #
        # dict<Node, dict<number,target observation>>
        self.child_otarget = {}

        # record of states for each timestep
        # that `update` is called.
        #
        # Some states are referenced by children
        #
        # Dict<number, state>
        self.s = {}

    def __call__(self, *parents):
        """
        register `parents` as its own parents and reach into
        each parent to register `self` as child. Since this method
        simply adds parents, it can be called multiple times with
        subsets of parents or once with all parents to achieve
        equivalent effects. Applies a copy of
        `self.default_parent_biasing_params` to every parent that is
        given as a plain node. To manually specify
        `parent_biasing_params`, pass a dict<Node, parent_biasing_params>

        :param *parents: each argument is either a parent node,
                a list/tuple/set of parent nodes,
                or a dict<Node, parent_biasing_params>

        :return: returns nothing
        """

        for entry in parents:
            # normalize every argument to (parent, biasing_params) pairs
            # regardless of how it was presented
            if isinstance(entry, dict):
                pairs = list(entry.items())
            elif isinstance(entry, (list, tuple, set, frozenset)):
                pairs = [
                    (parent, dict(self.default_parent_biasing_params))
                    for parent in entry
                ]
            else:
                pairs = [(entry, dict(self.default_parent_biasing_params))]

            for parent, parent_biasing_params in pairs:
                self._add_parent(parent, parent_biasing_params)
                parent._add_child(self)

    def _add_parent(self, parent, parent_biasing_params):
        """
        store `parent_biasing_params` for `parent` and create empty
        observation / target observation records for it

        :param parent: parent node
        :param parent_biasing_params: Dict<str, number> for `parent`

        :return: returns nothing
        """
        self.parent_biasing_params[parent] = parent_biasing_params
        self.parent_o[parent] = {}
        self.parent_otarget[parent] = {}

    def _add_child(self, child):
        """
        create an empty target observation record for `child`

        :param child: child node

        :return: returns nothing
        """
        self.child_otarget[child] = {}

    def update(self):
        """
        update node's internal variables at time intervals `self.p`
        This function should be overridden

        :raises NotImplementedError: always, in this base class

        :return: returns nothing
        """
        raise NotImplementedError()

    def closest_record(self, dictionary, time):
        """
        Identify highest indexed record in `dictionary`
        that is less than or equal to `time`

        NOTE: call from Node containing dictionary.
        because later versions of this algorithm may use
        `self.p` to jump ahead intelligently

        :param dictionary: dict<number, obj> to query
        :param time: time to query for closest entries to

        :raises LookupError: if `dictionary` is empty or holds
                no record at or before `time`

        :return: returns closest matching (time, record) record
        """

        if not dictionary:
            raise LookupError("`dictionary` empty")

        # insertion order is not guaranteed to be time order,
        # so pick the largest eligible key explicitly
        eligible = [tau for tau in dictionary if tau <= time]
        if not eligible:
            raise LookupError(
                f"error searching for closest record to time:{time} in `dictionary`")

        tau = max(eligible)
        return tau, dictionary[tau]

    def child_obs_target_weighted_mean(self, other_terms=()):
        """
        computes weighted mean of child nodes'
        target observations for `self` and `other_terms`

        The weighted mean observation target is
        used by actuators and information nodes

        In computing `child_obs_target_weighted_mean`,
        child node target observations are selected
        by recency and added to `self.child_otarget`
        for later training

        NOTE: `CER` is not defined in this class; it must be
        available at module scope when this method is called.

        :param other_terms: (opt.) [(weight, value)] tuple list
                of additional terms to incorporate into the
                weighted mean calculation

        :raises LookupError: if a child has no target observation
                for `self` at or before `self.t + self.p`
        :raises ZeroDivisionError: if the weights sum to zero
                (e.g. no children and no `other_terms`)

        :return: returns top down biasing target state
        """

        weights = [weight for weight, value in other_terms]
        values = [value for weight, value in other_terms]

        horizon = self.t + self.p

        # for each child (the keys of `child_otarget` are exactly the
        # nodes that registered `self` as a parent)
        for child in self.child_otarget:
            # the biasing params describing the child -> self link are
            # stored on the child, keyed by its parent
            biasing_params = child.parent_biasing_params[self]

            # find child target observation for time tau
            # that is closest to `self.t`
            tau, otarget = child.closest_record(
                child.parent_otarget[self], horizon)
            values.append(otarget)

            # compute weight in mean: base weight, scaled by CER and
            # exponentially decayed by the age of the record
            weights.append(
                biasing_params["w_otarget"]
                * CER(otarget)
                * np.exp(
                    -biasing_params["o_target_lambda"]
                    * (horizon - tau)
                )
            )

            # save otarget for later training
            self.child_otarget[child][tau] = otarget

        # compute weighted sum
        weighted_sum = sum(
            weight * value
            for weight, value
            in zip(weights, values)
        )

        # return weighted mean
        return weighted_sum / sum(weights)

    def save_information_episode(self, path):
        """
        saves data to `path` for training later

        :param path: savepath

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO maybe use pickle or shelve modules in addition to specific tf formats
        raise NotImplementedError()

    def load_information_episode(self, path):
        """
        loads data from `path` for training

        :param path: savepath

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO maybe use pickle or shelve modules in addition to specific tf formats
        raise NotImplementedError()

    def reset_information_episode(self):
        """
        resets all inference stateful variables
        including observations and internal states

        This means clearing the records held in the
            `self.child_otarget`,
            `self.parent_otarget`,
            `self.parent_o`,
            and `self.s` dictionaries

        The per-node entries themselves are kept, so the parents
        and children registered through `__call__` stay connected.

        :return: returns nothing
        """
        for per_node_records in (self.child_otarget,
                                 self.parent_otarget,
                                 self.parent_o):
            for records in per_node_records.values():
                records.clear()
        self.s.clear()

In [4]:
class Sensory_Node(Node):
    """
    Sensory node (placeholder: `update` is not implemented yet)
    """

    def update(self):
        """
        update node's internal variables

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()

In [5]:
class Actuator_Node(Node):
    """
    Actuator node (placeholder: `update` is not implemented yet)
    """

    def update(self):
        """
        update node's internal variables

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()

In [9]:
class Information_Node(Node):
    """
    Node that holds an abstractor, a predictor and a policy
    (placeholder: `update` and `train` are not implemented yet)
    """

    def __init__(self,
                 abstractor,
                 predictor,
                 policy,
                 params=None,
                 **node_kwargs
                 ):
        """
        initialize an `Information_Node`

        :param abstractor: stored as `self.abstractor`
        :param predictor: stored as `self.predictor`
        :param policy: stored as `self.policy`
        :param params: Dict<str, value> stored as `self.params`.
                When omitted, a fresh dict with the keys
                "pred-weight", "alpha-d1", "alpha-d2" and "alpha-d3"
                (all `None`) is used
        :param **node_kwargs: (opt.) keyword arguments forwarded
                to `Node.__init__`

        :return: returns Information_Node
        """

        # set up the record dictionaries and timing state of the base node
        super().__init__(**node_kwargs)

        # build the default per instance so instances never share one dict
        if params is None:
            params = {
                "pred-weight": None,
                "alpha-d1": None,
                "alpha-d2": None,
                "alpha-d3": None,
            }

        self.params = params
        self.abstractor = abstractor
        self.predictor = predictor
        self.policy = policy

    def update(self):
        """
        update node's internal variables

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()

    def train(self,
              lr=0.001,
              verbose=True,
              beta_piD=1,
              beta_piC=1,
              beta_piI=1,
              beta_prD=1,
              beta_prC=1,
              ):
        """
        trains `Information_Node`

        :param lr: learning rate for optimizer
        :param verbose: (True/False) log information verbosely
        :param beta_piD: training hyperparameter. See 'Putting It All Together'
        :param beta_piC: training hyperparameter. See 'Putting It All Together'
        :param beta_piI: training hyperparameter. See 'Putting It All Together'
        :param beta_prD: training hyperparameter. See 'Putting It All Together'
        :param beta_prC: training hyperparameter. See 'Putting It All Together'

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()

In [None]:
class Network:
    """
    Holds a collection of nodes together with a shared clock `time`
    """

    def __init__(self, nodes):
        """
        initializes the network

        :param nodes: all nodes that should be managed by the network.

        :return: returns initialized network
        """
        self.nodes = nodes
        self.time = 0

    def run(self, duration=-1, timestep=1):
        """
        calls `step` repeatedly until `self.time` reaches `duration`

        :param duration: time to run until. A negative value
                (default -1) means forever
        :param timestep: time increments used during `step`

        :return: returns nothing
        """
        # a negative duration is the "run forever" sentinel; otherwise stop
        # as soon as the clock reaches `duration` instead of stepping past it
        while duration < 0 or self.time < duration:
            self.step(timestep)

    def step(self, timestep):
        """
        increase timestep and check if any nodes should be updated

        NOTE: `self.time` is advanced before the (not yet implemented)
        node update logic raises.

        :param timestep: positive time increment used

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        self.time += timestep
        #TODO
        raise NotImplementedError()

    def save_episode(self, path):
        """
        save data to `path` for training

        :param path: savepath

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()

    def load_episode(self, path):
        """
        loads data from `path` for training

        :param path: savepath

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()

    def reset_episode(self):
        """
        resets all nodes' inference stateful variables
        including observations and internal states

        :raises NotImplementedError: always, not implemented yet

        :return: returns nothing
        """
        #TODO
        raise NotImplementedError()