# 001 Tests on how to extend the stellarflow system into a `tf-agents`-compatible RL environment.

Inspiration: https://towardsdatascience.com/creating-a-custom-environment-for-tensorflow-agent-tic-tac-toe-example-b66902f73059

In [9]:
## Imports
import sys
sys.path.append("./source/")

import stellarflow as stf
import numpy as np
import tensorflow as tf

## Settings
# Constants exposed by stf.System, bound to short module-level names.
AU = stf.System._AU
ED = stf.System._ED

## Checking if GPU is used
# An empty device list is falsy, so this warns when TensorFlow sees no GPU.
if not tf.config.list_physical_devices('GPU'):
    print("NO GPU FOUND!")

In [10]:
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

In [11]:
## TODO: Plan: create a py_environment based on stellarflow.System and wrap it in a TFPyEnvironment

In [14]:
# Scratch check: joining two 1-D arrays along the default axis (0) gives one flat array.
np.concatenate((np.array([1, 2]), np.array([2, 3])))

array([1, 2, 2, 3])

In [None]:
class stfaEnv(py_environment.PyEnvironment):
    """Work-in-progress `tf-agents` environment for a spacecraft in a `stf.System`.

    On construction the spacecraft is appended as an additional body to the
    given system (its state row is added to `_Q`, its mass to `masses`).
    Actions are declared as 3-component float32 "boost" vectors, observations
    as the 3-component float32 "location" of the spacecraft.

    NOTE(review): `_step` still carries the card-game logic of the tf-agents
    environment tutorial and does not advance the stellarflow system yet.
    """
    ## TODO: Implement mass / propellant consumtion
    def __init__(self,
            mass: float,
            initial_location: np.ndarray,
            initial_velocity: np.ndarray,
            stfSystem: stf.System):
        """Set up the specs and register the spacecraft in `stfSystem`.

        Args:
            mass: Mass of the spacecraft; appended to `stfSystem.masses`.
            initial_location: Start location; expected to match the
                observation spec, i.e. 3 components.
            initial_velocity: Start velocity of the spacecraft.
            stfSystem: System the spacecraft is added to. It is modified in
                place (`_Q`, `_Q_hist`, `masses` and `_M` are replaced).
        """
        # Let the base class initialise its own bookkeeping.
        super().__init__()

        ## Providing for tfa-py_env
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(3,), dtype=np.float32, name="boost")
        self._observation_spec = array_spec.ArraySpec(
            shape=(3,), dtype=np.float32, name="location")
        # Keep private copies: `_step` updates `_state` in place, which must
        # neither change the stored reset state nor the caller's array.
        self._initial_location = np.array(initial_location, dtype=np.float32)
        self._state = self._initial_location.copy()
        self._episode_ended = False

        ## Additional Information for spacecraft
        self._m = mass
        self._x = initial_location
        self._y = initial_velocity
        self._q = np.concatenate([initial_location, initial_velocity], axis=0)

        ## Binding with stf.System:
        # Append the spacecraft state as one more row of the system state.
        temp = np.concatenate(
            [stfSystem._Q.numpy(), np.expand_dims(self._q, axis=0)], axis=0)
        stfSystem._Q = tf.Variable(temp, dtype=tf.float32)
        stfSystem._Q_hist = [stfSystem._Q]
        # `mass` is a scalar; np.concatenate rejects 0-d inputs, so lift it
        # to a 1-element array first.
        stfSystem.masses = np.concatenate(
            [stfSystem.masses, np.atleast_1d(mass)], axis=0)
        # Inside this class body `stfSystem.__reshape_masses` would be
        # name-mangled to `_stfaEnv__reshape_masses`. Spell out the mangled
        # name of the owning class instead.
        # NOTE(review): assumes the private method is defined on a class
        # named `System` -- confirm against stellarflow.
        stfSystem._M = stfSystem._System__reshape_masses(stfSystem.masses)

    def action_spec(self):
        """Return the spec of the 3-component float32 "boost" action."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec of the 3-component float32 "location" observation."""
        return self._observation_spec

    def _observation(self) -> np.ndarray:
        """Return a float32 copy of the current state, as the spec declares."""
        return np.array(self._state, dtype=np.float32)

    def _reset(self):
        """Restore the initial location and start a new episode.

        TODO: the bound stf.System is not reset here.
        """
        self._state = self._initial_location.copy()
        self._episode_ended = False
        return ts.restart(self._observation())

    def _step(self, action: np.ndarray):
        """Placeholder step taken from the tf-agents card-game tutorial.

        NOTE(review): the control flow below treats `action` and
        `self._state` as scalars. With multi-element arrays the comparisons
        `action == 1` and `self._state >= 21` raise a ValueError, so this has
        to be replaced by the actual spacecraft dynamics.
        """
        if self._episode_ended:
            # The last action ended the episode. Ignore the current action
            # and start a new episode.
            return self.reset()

        # Make sure episodes don't go on forever.
        if action == 1:
            self._episode_ended = True
        elif action == 0:
            new_card = np.random.randint(1, 11)
            self._state += new_card
        else:
            raise ValueError('`action` should be 0 or 1.')

        if self._episode_ended or self._state >= 21:
            reward = self._state - 21 if self._state <= 21 else -21
            return ts.termination(self._observation(), reward)
        else:
            return ts.transition(
                self._observation(), reward=0.0, discount=1.0)
