In [1]:
# Reload imported modules automatically before executing code, so edits to the
# local project modules (environment, actor_critic_agent, topone, cw) are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib figures (requires the ipympl extension).
%matplotlib widget

In [2]:
from dataclasses import dataclass, field
from math import radians
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import trange
import ipywidgets as widgets

from cw.context import time_it
from cw.simulation import Simulation, StatesBase, AB3Integrator, BatchLogger, Logging
from cw.filters import smooth_signal

from topone.sim_post_processing import sim_post_processing
from topone.dynamics_1 import Dynamics1, Stage

from environment import Environment
from actor_critic_agent import ActorCriticAgent
from actor_critic_agent import ActorCriticAgent

In [3]:
@dataclass
class States(StatesBase):
    """State container for the 2-D simulation and the RL environment.

    The integrated subset of the state is exchanged with the integrator as a
    flat 7-element vector (``get_y`` / ``get_y_dot`` / ``set_t_y``)::

        index | y (get_y) | y_dot (get_y_dot)
        ------+-----------+------------------
        0:2   | xii       | vii
        2:4   | vii       | aii
        4     | theta     | theta_dot
        5     | mass      | mass_dot
        6     | score     | reward

    Because ``reward`` is paired with ``score``, the integrator accumulates
    ``reward`` into ``score``. All other fields are not part of that vector.

    Array fields use ``default_factory`` so each instance gets its own arrays.
    A plain ``np.zeros(2)`` default is a single array shared by every instance,
    and Python >= 3.11 rejects such unhashable defaults outright
    (``ValueError: mutable default ... use default_factory``).
    """
    t: float = 0
    command_engine_on: bool = False
    command_drop_stage: bool = False
    gii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    xii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    vii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    aii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    tci: np.ndarray = field(default_factory=lambda: np.eye(2))
    vic: np.ndarray = field(default_factory=lambda: np.zeros(2))
    fii_thrust: np.ndarray = field(default_factory=lambda: np.zeros(2))
    theta: float = 0.
    theta_dot: float = 0.
    mass: float = 0.
    mass_dot: float = 0.
    h: float = 0.
    engine_on: bool = False
    stage_state: int = 0
    stage_idx: int = 0
    gamma_i: float = 0.
    gamma_e: float = 0.
    latitude: float = 0.

    # RL bookkeeping; ``score`` is integrated from ``reward`` (see class docstring).
    reward: float = 0.
    score: float = 0.
    done: bool = False

    delta_v: float = 0.

    def get_y_dot(self):
        """Return the 7-element time derivative of the integrated state vector."""
        y = np.empty(7)
        y[:2] = self.vii
        y[2:4] = self.aii
        y[4] = self.theta_dot
        y[5] = self.mass_dot
        y[6] = self.reward
        return y

    def get_y(self):
        """Pack the integrated state into a new 7-element vector."""
        y = np.empty(7)
        y[:2] = self.xii
        y[2:4] = self.vii
        y[4] = self.theta
        y[5] = self.mass
        y[6] = self.score
        return y

    def set_t_y(self, t, y):
        """Unpack time ``t`` and a 7-element state vector ``y`` into the fields.

        Note: ``xii`` and ``vii`` become slices of ``y`` (numpy views, not
        copies), so later in-place changes to ``y`` are visible through them.
        """
        self.t = t
        self.xii = y[:2]
        self.vii = y[2:4]
        self.theta = y[4]
        self.mass = y[5]
        self.score = y[6]

In [4]:
# Step size shared by the environment and the AB3 integrator (presumably in
# seconds — confirm). Both were previously configured with a literal 0.01;
# a single constant keeps them from silently drifting apart.
TIME_STEP = 0.01

environment = Environment(target_time_step=TIME_STEP)
agent = ActorCriticAgent(
    path="./set_2",
    alpha=0.01,
    gamma=0.99,
    environment=environment,
)

simulation = Simulation(
    states_class=States,
    integrator=AB3Integrator(
        h=TIME_STEP,
        rk4=False,
        fd_max_order=1),
    modules=[
        Dynamics1(
            # NOTE(review): 1737.4e3 matches the Moon's mean radius in metres
            # (and mu matches the Moon's GM in m^3/s^2), not its diameter.
            # Confirm whether Dynamics1 treats this argument as a radius.
            surface_diameter=1737.4e3,
            mu=4.9048695e12,
            stages=(
                Stage(
                    dry_mass=1,
                    propellant_mass=0.02,
                    specific_impulse=100,
                    thrust=2*1.7),
            ),
            initial_altitude=1,
            initial_theta_e=radians(90),
            initial_latitude=radians(90),
        ),
        # The same Environment instance is used by the agent above and
        # runs as a simulation module here.
        environment
    ],
    logging=Logging(),
    initial_state_values=None,
)

In [5]:
# Start the environment's simulation thread. The meaning of the 1000 argument is
# defined by Environment.start_simulation_thread (presumably a step/episode
# budget) — TODO confirm and give it a named constant.
environment.start_simulation_thread(1000)

In [14]:
# Train the agent. The positional arguments go straight to ActorCriticAgent.train;
# the captured progress bar (max=10000) suggests the second is the episode count — confirm.
agent.train(1000, 10000, 5)

# Each history record is expected to be a namedtuple whose ``_fields`` name the
# columns. Fail with an explicit message, instead of a bare IndexError on ``[0]``,
# when training recorded nothing.
if len(agent.training_history) == 0:
    raise RuntimeError("agent.train() recorded no training history")
training_history = pd.DataFrame.from_records(
    agent.training_history,
    columns=agent.training_history[0]._fields,
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [12]:
# Show the agent's current greedy policy; the captured output lists one entry
# per stage state (UNFIRED, FIRING, FIRED).
agent.display_greedy_policy()

UNFIRED: 1 [[0.24263829 0.7573617 ]]
FIRING: 1 [[0.25990498 0.740095  ]]
FIRED: 1 [[0.16294645 0.8370536 ]]


In [40]:
# NOTE(review): what `clean` removes and what its boolean argument controls are
# defined in ActorCriticAgent and not visible here — pass it as a keyword
# argument or document the effect.
agent.clean(False)

In [19]:

# Learning curve: reward_sum per training-history record, smoothed with
# cw.filters.smooth_signal (wn=0.11).
fig, ax = plt.subplots()
ax.plot(smooth_signal(training_history.reward_sum, wn=0.11))
ax.set(
    title="Training reward (smoothed)",
    xlabel="training history record",
    ylabel="reward_sum (smoothed, wn=0.11)",
);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f56546bdf40>]

In [9]:
# Persist the agent via ActorCriticAgent.save(); presumably written under the
# `path` given at construction ("./set_2") — confirm.
agent.save()

In [39]:
# NOTE(review): duplicate of the earlier display_greedy_policy cell. Its captured
# output ((1, 32) vectors per stage state) differs in shape from the earlier
# run's 2-element vectors, so the two outputs come from different agent/network
# states — keep one cell and re-run in order.
agent.display_greedy_policy()

(1, 32)
UNFIRED: 28 [[ 3.0149581  -0.52868605  0.8434439  -0.18095444 -0.42848393 -1.0057224
   0.92954445 -1.72662    -2.9032679  -3.8678105  -1.0684673  -0.13156109
  -2.1698103   0.30838996  2.799077    0.8303299  -0.6751958   0.84378713
  -1.2124715   2.4412723   0.79393595 -0.4389961   0.87997603  0.9085807
   3.0855093  -0.31790808 -1.3364947  -0.5592322   6.1183968  -2.9916596
  -1.5339713  -1.7977095 ]]
FIRING: 28 [[ 2.9956753  -0.52811426  0.8395922  -0.17942986 -0.42583463 -0.99852836
   0.9228872  -1.7168169  -2.8830466  -3.8396018  -1.0590945  -0.13242353
  -2.1523142   0.30522648  2.7776392   0.8196669  -0.6719858   0.8382596
  -1.2027322   2.4274576   0.78856075 -0.43645698  0.87171596  0.90215755
   3.0627096  -0.31558916 -1.3279064  -0.55552465  6.0760236  -2.971034
  -1.5230126  -1.7860043 ]]
FIRED: 28 [[ 3.0698047  -0.5303124   0.85439914 -0.1852908  -0.43601942 -1.0261844
   0.9484798  -1.754503   -2.9607835  -3.9480453  -1.0951263  -0.12910801
  -2.2195745   0.31738

In [10]:
# Display the environment's `last_results` (rendered as JSON in the captured
# output); presumably the results of the most recent simulation run — confirm.
environment.last_results

<IPython.core.display.JSON object>

In [10]:
# Run one episode with the greedy policy and display the result. The 1000 is
# passed straight to run_episode_greedy (presumably a step limit — TODO confirm).
# NOTE(review): `r` is read by the engine-state plot cell below; a descriptive
# name (e.g. greedy_episode) would read better if both cells are renamed together.
r = agent.run_episode_greedy(1000)
r

In [11]:
# Greedy episode: commanded vs. actual engine state, overlaid on one axes.
# `r` comes from the run_episode_greedy cell above (run that first).
fig, ax = plt.subplots()
r.command_engine_on.plot(ax=ax, label="command_engine_on")
r.engine_on.plot(ax=ax, label="engine_on")
ax.set(title="Greedy episode: engine command vs. engine state", ylabel="engine on")
ax.legend();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7efba403fd90>]