In [1]:
# Automatically reload imported modules before each cell executes, so edits to
# project code (cw, topone, environment, actor_critic_agent) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib figures rendered as widgets (requires ipympl).
%matplotlib widget

In [2]:
from dataclasses import dataclass, field
from math import radians
import time

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange
import ipywidgets as widgets

from cw.context import time_it
from cw.simulation import Simulation, StatesBase, AB3Integrator, BatchLogger, Logging
from cw.filters import smooth_signal

from topone.sim_post_processing import sim_post_processing
from topone.dynamics_1 import Dynamics1, Stage

from environment import Environment
from actor_critic_agent import ActorCriticAgent
from actor_critic_agent import ActorCriticAgent

In [3]:
def _zero_vec2() -> np.ndarray:
    """Return a fresh length-2 zero vector.

    Used as a dataclass ``default_factory`` so that every ``States`` instance
    owns its own array instead of all instances sharing one module-level array.
    """
    return np.zeros(2)


@dataclass
class States(StatesBase):
    """Simulation state vector plus reinforcement-learning bookkeeping.

    The integrated subset of the state is packed into a length-7 array by
    ``get_y`` as ``[xii(2), vii(2), theta, mass, score]``. ``get_y_dot``
    returns its time derivative ``[vii(2), aii(2), theta_dot, mass_dot, reward]``
    in the same order. ``set_t_y`` unpacks such an array back into the fields.
    Because ``reward`` sits in the derivative slot of ``score``, the integrator
    accumulates ``score`` as the time integral of ``reward``.

    Array-valued fields use ``default_factory``. A bare ``np.zeros(2)`` default
    would be one array shared by every instance, and on Python >= 3.11
    ``dataclasses`` rejects it outright because ndarrays are unhashable.

    NOTE(review): if cw.simulation inspects ``dataclasses.Field.default`` of
    these fields (rather than instance values), confirm it handles
    ``default_factory`` fields.
    """

    t: float = 0
    command_engine_on: bool = False
    command_drop_stage: bool = False

    # Quantities below are written by the simulation modules. Meanings are
    # inferred from the names (e.g. ``xii``/``vii`` position/velocity).
    # NOTE(review): confirm frames and units in Dynamics1.
    gii: np.ndarray = field(default_factory=_zero_vec2)
    xii: np.ndarray = field(default_factory=_zero_vec2)   # integrated: y[0:2]
    vii: np.ndarray = field(default_factory=_zero_vec2)   # integrated: y[2:4]; also d(xii)/dt
    aii: np.ndarray = field(default_factory=_zero_vec2)   # d(vii)/dt
    tci: np.ndarray = field(default_factory=lambda: np.eye(2))
    vic: np.ndarray = field(default_factory=_zero_vec2)
    fii_thrust: np.ndarray = field(default_factory=_zero_vec2)
    theta: float = 0.       # integrated: y[4]
    theta_dot: float = 0.   # d(theta)/dt
    mass: float = 0.        # integrated: y[5]
    mass_dot: float = 0.    # d(mass)/dt
    h: float = 0.
    engine_on: bool = False
    stage_state: int = 0
    stage_idx: int = 0
    gamma_i: float = 0.
    gamma_e: float = 0.
    latitude: float = 0.

    # Reinforcement-learning signals.
    reward: float = 0.      # d(score)/dt
    score: float = 0.       # integrated: y[6]
    done: bool = False

    delta_v: float = 0.

    def get_y_dot(self) -> np.ndarray:
        """Return the state derivative, laid out to match ``get_y``.

        Returns:
            New length-7 array ``[vii(2), aii(2), theta_dot, mass_dot, reward]``.
        """
        y = np.empty(7)
        y[:2] = self.vii
        y[2:4] = self.aii
        y[4] = self.theta_dot
        y[5] = self.mass_dot
        y[6] = self.reward
        return y

    def get_y(self) -> np.ndarray:
        """Pack the integrated state into a flat array.

        Returns:
            New length-7 array ``[xii(2), vii(2), theta, mass, score]``.
        """
        y = np.empty(7)
        y[:2] = self.xii
        y[2:4] = self.vii
        y[4] = self.theta
        y[5] = self.mass
        y[6] = self.score
        return y

    def set_t_y(self, t: float, y: np.ndarray) -> None:
        """Unpack time ``t`` and a ``get_y``-layout array ``y`` into the fields.

        ``xii`` and ``vii`` are assigned slices of ``y``, i.e. numpy views that
        share memory with ``y``.
        NOTE(review): if the integrator later mutates ``y`` in place, these
        fields change with it. Confirm against cw.simulation or add ``.copy()``.
        """
        self.t = t
        self.xii = y[:2]
        self.vii = y[2:4]
        self.theta = y[4]
        self.mass = y[5]
        self.score = y[6]

In [6]:
# Single integration step used by both the environment and the integrator, so
# the two values cannot silently drift apart when one is edited.
# NOTE(review): assumes Environment.target_time_step is meant to equal the
# integrator step ``h`` (they were identical literals) — confirm in environment.py.
TIME_STEP = 0.01

environment = Environment(target_time_step=TIME_STEP)
agent = ActorCriticAgent(
    path="./set_1",
    alpha=0.01,
    gamma=0.99,
    environment=environment,
)

simulation = Simulation(
    states_class=States,
    integrator=AB3Integrator(
        h=TIME_STEP,
        rk4=False,
        fd_max_order=1),
    modules=[
        Dynamics1(
            # NOTE(review): 1737.4e3 equals the Moon's mean *radius* in metres
            # (and ``mu`` below equals lunar GM in m^3/s^2), yet it is passed as
            # ``surface_diameter``. Confirm whether Dynamics1 treats this
            # argument as a radius; the lunar diameter would be ~3474.8e3.
            surface_diameter=1737.4e3,
            mu=4.9048695e12,
            stages=(
                Stage(
                    dry_mass=1,
                    propellant_mass=0.02,
                    specific_impulse=100,
                    # Presumably two engines of 1.7 each — confirm units/intent.
                    thrust=2*1.7),
            ),
            initial_altitude=1000,
            initial_theta_e=radians(90),
            initial_latitude=radians(90),
        ),
        # The same Environment instance is shared with the agent above.
        environment,
    ],
    logging=Logging(),
    initial_state_values=None,
)

In [7]:
# Presumably starts the environment's simulation in a separate thread (per the
# method name). NOTE(review): the meaning of the argument 1000 is not visible
# here — confirm in environment.py and prefer a keyword argument.
environment.start_simulation_thread(1000)

In [8]:
# Train the actor-critic agent against the shared environment.
# NOTE(review): the positional arguments (1000, 1) are opaque here; the saved
# progress-bar output has max=1.0, suggesting the outer loop ran once. Confirm
# their meaning in actor_critic_agent.py and pass them by keyword.
agent.train(1000, 1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [9]:
# Persist the trained agent (presumably under the ``path`` given at
# construction, "./set_1" — confirm in actor_critic_agent.py).
agent.save()

In [10]:
# Display the environment's most recent results (rendered as a JSON display
# object in the saved output).
environment.last_results

<IPython.core.display.JSON object>

In [16]:
# Run one episode with the greedy policy and keep the result in ``r`` for the
# plot that follows.
# NOTE(review): execution counts jump from In [10] to In [16], so cells were
# deleted or run out of order — verify the notebook with Restart & Run All.
r = agent.run_episode_greedy(1000)

In [17]:
# Plot the ``h`` series from the greedy-episode result (presumably altitude,
# matching the ``States.h`` field — confirm). The trailing ``;`` suppresses the
# ``[<Line2D ...>]`` return-value repr so only the figure is displayed.
r.h.plot();

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f1f2c02f040>]