# Experiment 2: Two-stage suborbital launcher


In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [2]:
from dataclasses import dataclass, field
from math import radians
import random
import time

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange
import ipywidgets as widgets

from cw.context import time_it
from cw.simulation import Simulation, StatesBase, AB3Integrator, BatchLogger, Logging

from topone.sim_post_processing import sim_post_processing
from topone.dynamics_1 import Dynamics1, Stage
from topone.pid_agent import PIDAgent

from agent import Agent
from linear_softmax_agent import LinearSoftmaxAgent, State
from environment import Environment
from ideal_agent import IdealAgent

## Simulation configuration

In [3]:
@dataclass
class States(StatesBase):
    """State container passed to ``Simulation`` as ``states_class``.

    The integrated state vector has seven entries, laid out as
    ``[xii (2), vii (2), theta, mass, score]``; the matching derivative
    vector is ``[vii (2), aii (2), theta_dot, mass_dot, reward]``.
    All remaining fields are plain attributes that are not part of the
    state vector.

    Array-valued fields use ``default_factory`` so that each instance gets
    its own freshly allocated array. A bare ``np.zeros(2)`` default would be
    shared by every instance, and is rejected by ``@dataclass`` on
    Python 3.11+ because ``ndarray`` is unhashable.
    """

    t: float = 0
    command_engine_on: bool = False
    command_drop_stage: bool = False
    gii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    xii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    vii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    aii: np.ndarray = field(default_factory=lambda: np.zeros(2))
    tci: np.ndarray = field(default_factory=lambda: np.eye(2))
    vic: np.ndarray = field(default_factory=lambda: np.zeros(2))
    fii_thrust: np.ndarray = field(default_factory=lambda: np.zeros(2))
    theta: float = 0.
    theta_dot: float = 0.
    mass: float = 0.
    mass_dot: float = 0.
    h: float = 0.
    engine_on: bool = False
    stage_state: int = 0
    stage_idx: int = 0
    gamma_i: float = 0.
    gamma_e: float = 0.
    latitude: float = 0.

    reward: float = 0.
    score: float = 0.
    done: bool = False
    delta_v: float = 0.

    def get_y_dot(self):
        """Return the derivative vector ``[vii, aii, theta_dot, mass_dot, reward]``."""
        y = np.empty(7)
        y[:2] = self.vii
        y[2:4] = self.aii
        y[4] = self.theta_dot
        y[5] = self.mass_dot
        # `reward` occupies the slot whose state counterpart is `score`.
        y[6] = self.reward
        return y

    def get_y(self):
        """Return the state vector ``[xii, vii, theta, mass, score]``."""
        y = np.empty(7)
        y[:2] = self.xii
        y[2:4] = self.vii
        y[4] = self.theta
        y[5] = self.mass
        y[6] = self.score
        return y

    def set_t_y(self, t, y):
        """Store time ``t`` and unpack ``y`` using the layout of ``get_y``.

        ``xii`` and ``vii`` are assigned slices of ``y`` (views rather than
        copies when ``y`` is a NumPy array).
        """
        self.t = t
        self.xii = y[:2]
        self.vii = y[2:4]
        self.theta = y[4]
        self.mass = y[5]
        self.score = y[6]

In [4]:
# Earlier agent configuration, kept commented out.
# agent = Agent(
#     epsilon=0.01,
#     alpha=0.9,
#     gamma=0.9,
#     path="./set_3"
# )

# NOTE(review): the next cell rebinds `agent` to a LinearSoftmaxAgent, so when the
# notebook is run top to bottom this IdealAgent is replaced before the simulation
# is constructed. Run only one of the two cells, depending on the agent wanted.
agent = IdealAgent()

In [5]:
# Binds the notebook-global `agent` used by the simulation and by run_batch.
# NOTE(review): alpha/gamma are presumably the learning rate and discount factor,
# and load_last=True presumably resumes from the newest backup under `path` --
# confirm against LinearSoftmaxAgent.
agent = LinearSoftmaxAgent(
    alpha=.01,
    gamma=.99,
    path="./set_ls_9",
    load_last=True
)

In [6]:
# Integrator settings used by the simulation.
integrator = AB3Integrator(h=0.1, rk4=False, fd_max_order=1)

# The two launcher stages, in the order they are handed to Dynamics1.
launcher_stages = (
    Stage(dry_mass=2, propellant_mass=0.08, specific_impulse=100, thrust=12*1.7),
    Stage(dry_mass=1, propellant_mass=0.01, specific_impulse=150, thrust=1.1*1.7),
)

# NOTE(review): 1737.4e3 and 4.9048695e12 look like the Moon's mean radius [m] and
# gravitational parameter [m^3/s^2]; the keyword is `surface_diameter`, so confirm
# whether Dynamics1 expects a radius or a diameter here.
launcher_dynamics = Dynamics1(
    surface_diameter=1737.4e3,
    mu=4.9048695e12,
    stages=launcher_stages,
    initial_altitude=1.,
    initial_theta_e=radians(90),
    initial_latitude=radians(90),
)

# Module order is preserved: dynamics, then environment, then the agent.
simulation = Simulation(
    states_class=States,
    integrator=integrator,
    modules=[launcher_dynamics, Environment(), agent],
    logging=Logging(),
    initial_state_values=None,
)

# The batch logger is initialised against the simulation before the states are
# stashed; run_batch and the single-run cell call restore_states() afterwards.
batch_logger = BatchLogger()
batch_logger.initialize(simulation)
simulation.stash_states()

In [7]:
def post_processing(result):
    """Apply ``sim_post_processing`` to a simulation result.

    Thin wrapper: the return value of ``sim_post_processing`` is discarded, so
    this function returns ``None``. Callers keep using ``result`` afterwards,
    so the processing presumably happens in place -- confirm in
    ``topone.sim_post_processing``.
    """
    sim_post_processing(result)

## Batch run

In [8]:
def run_batch(n_episodes, backup_period=30, timeout=60):
    """Run up to ``n_episodes`` episodes with the batch logger installed.

    Each episode restores the stashed states and calls ``simulation.run(1000)``.
    The agent is saved and its greedy policy redisplayed every
    ``backup_period`` seconds, and the loop stops early once ``timeout``
    seconds have elapsed or the user interrupts the kernel.

    Parameters
    ----------
    n_episodes : int
        Maximum number of episodes to run.
    backup_period : float
        Minimum number of seconds between agent backups.
    timeout : float
        Wall-clock budget in seconds for the whole batch.

    Returns
    -------
    The value returned by ``batch_logger.finish_batch()``.

    Raises
    ------
    Exception
        Any exception other than ``KeyboardInterrupt`` raised during the batch
        now propagates after cleanup. Previously the ``return`` inside
        ``finally`` silently discarded it.
    """
    batch_logger.reset_batch()

    out = widgets.Output(layout={})
    display(out)

    def show_policy():
        # Render the greedy policy into the widget, then schedule a clear so the
        # next rendering replaces this one instead of being appended below it.
        with out:
            agent.display_greedy_policy()
        out.clear_output(wait=True)

    # Swap in the faster batch logger immediately before the guarded region so
    # the `finally` clause below always restores the original one.
    original_logger = simulation.logging
    simulation.logging = batch_logger

    # Monotonic clock: elapsed-time checks are unaffected by system clock changes.
    start_time = time.monotonic()
    last_backup_time = start_time

    try:
        for i in trange(n_episodes):
            simulation.restore_states()
            simulation.run(1000)

            if i == 0:
                show_policy()

            if time.monotonic() - last_backup_time >= backup_period:
                agent.save()
                last_backup_time = time.monotonic()
                show_policy()

            if time.monotonic() - start_time >= timeout:
                break

    except KeyboardInterrupt:
        print("Batch cancelled")
    finally:
        # Restore the logger first so a failure in the remaining cleanup cannot
        # leave the batch logger installed on the simulation.
        simulation.logging = original_logger
        agent.save()
        batch_results = batch_logger.finish_batch()
        out.clear_output()
        with out:
            agent.display_greedy_policy()

    # Returning outside `finally` lets unexpected exceptions reach the caller.
    return batch_results

In [9]:
# Not referenced by the call below.
# NOTE(review): presumably a measured episodes-per-minute figure -- confirm or remove.
n_per_min = 8400

# Up to ten million episodes, backing up every 5 s, capped at ten hours.
batch_results = run_batch(
    n_episodes=10_000_000,
    backup_period=5,
    timeout=10 * 60 * 60,
)
display(batch_results)

Output()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10000000.0), HTML(value='')))




In [11]:
agent.states

[]

In [17]:
agent.display_greedy_policy()

Stage 0
  UNFIRED: 0 [0.33388714 0.33245358 0.33365928] 
  FIRING: 2 [0.33370065 0.32999068 0.33630867]
  FIRED: 1 [0.33192865 0.33415273 0.33391862]
Stage 1
  UNFIRED: 1 [0.33200877 0.33445959 0.33353164] 
  FIRING: 2 [0.33077645 0.33159956 0.337624  ]
  FIRED: 1 [0.3313694  0.33441207 0.33421854]


In [16]:
agent.clean(False)

## Single simulation

In [11]:
# Start from the restored states (presumably those stashed when the simulation was
# built), then run and time a single simulation.
simulation.restore_states()
with time_it("simulation run"):
    result = simulation.run(10)
post_processing(result)
# Bare last expression so the notebook renders `result`.
result

1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
0 False
1
env 1
7.611627328877092 True
simulation run: 0.012357464001979679 [s]


In [11]:
# print(agent.rewards)
# result.h



<generator object IdealAgent.step at 0x7f6598f486d0>

In [9]:
# What the agent thinks is the right policy.
# print(max(result.h).item())
agent.display_greedy_policy()

Stage 0
  UNFIRED: 1 
  FIRING: 1
  FIRED: 2
Stage 1
  UNFIRED: 1 
  FIRING: 1
  FIRED: 2


In [25]:
# Right policy
print(max(result.h).item())
agent.display_greedy_policy()

663.7453234132845
Stage 0
  UNFIRED: 1 
  FIRING: 1
  FIRED: 2
Stage 1
  UNFIRED: 1 
  FIRING: 1
  FIRED: 0


In [33]:
# Plot `vii` against time. The label previously read "command_engine_on"
# (copied from another cell) although the plotted quantity is `vii`.
plt.figure()
result.vii.plot.line(x="t", label="vii")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fad06df5fa0>,
 <matplotlib.lines.Line2D at 0x7fad06df5d00>]

In [31]:
plt.figure()
result.h.plot.line(x="t", label="h")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fad06ea2ee0>]

In [38]:
# Plot `vii` against time. The label previously read "h" (copied from the
# altitude plot) although the plotted quantity is `vii`.
plt.figure()
result.vii.plot.line(x="t", label="vii")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f0ea6d6e2e0>,
 <matplotlib.lines.Line2D at 0x7f0ea6d95100>]

In [32]:
# Overlay the two command signals from the single-simulation result on one figure.
plt.figure()
result.command_engine_on.plot.line(x="t", label="command_engine_on")
result.command_drop_stage.plot.line(x="t", label="command_drop_stage")
plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f788d2261c0>

In [33]:
# Mass flow together with the stage bookkeeping signals on one figure.
# NOTE(review): the 1e3 scale on mass_dot and the +1 offset on stage_idx are
# presumably there only to keep the traces distinguishable on a shared axis;
# the legend labels do not mention the scaling/offset.
plt.figure()
(result.mass_dot * 1e3).plot.line(x="t", label="mass_dot")
(result.stage_idx + 1).plot(label="stage_idx")
(result.stage_state).plot(label="stage_state")
plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f7898d4ca30>

In [19]:
# Mass, thrust and stage index on one figure.
# NOTE(review): the stored output of this cell is a NameError for `result`; the
# single-simulation cell must have been run in the current kernel first.
plt.figure()
(result.mass).plot.line(x="t", label="mass")
(result.fii_thrust).plot(label="fii_thrust")
(result.stage_idx).plot(label="stage_idx")
plt.legend()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

NameError: name 'result' is not defined

In [10]:
agent.get_backup_indices()[-1]

1032

In [14]:
# Collect the policy output for State(1, 1) at every saved backup and plot how it
# evolves across backups.
pis = []
for i in agent.get_backup_indices():
    # Presumably loads backup `i` into `agent` in place (the return value is
    # unused), so after the loop `agent` holds the last backup that was loaded.
    agent.load(i)
    pis.append(agent.pi(State(1, 1)))

plt.figure()
plt.plot(pis)
# One legend entry per component of the policy output, in order.
plt.legend(["Engine off", "Engine on", "Drop stage"])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f0757c0bb20>

In [14]:
import inspect

In [15]:
def foo():
    """Toy generator for trying out the ``send`` protocol.

    Prints "1", yields "ok", then prints "2" followed by whatever value was
    sent in, and finishes (so the ``send`` call raises ``StopIteration``).
    """
    print("1")
    received = yield "ok"
    print("2", received)

inspect.isgeneratorfunction(agent.step)

True

In [53]:
gen = foo()

In [54]:
next(gen)

1


'ok'

In [52]:
x = 0

In [60]:
# Send `x` into the suspended generator. `foo` has a single yield, so the send
# runs it to completion and raises StopIteration, which is expected here.
try:
    gen.send(x)
except StopIteration:
    pass

# Not idempotent: every run of this cell increments `x` and replaces `gen` with a
# fresh generator advanced to its first yield.
x += 1
gen = foo()
next(gen)

2 5
1


'ok'

In [61]:
gen.send(x)

2 6


StopIteration: 