# Блокнот для генерации деревьев траектрий, чтобы потом из них можно было сделать хорошую БД для обучения

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from missile_gym import MissileGym
from gymtree import GymTree, Node
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

%matplotlib widget
# %config InlineBackend.figure_format = 'svg'

### Генерируем стандартный сценарий

In [3]:
MissileGym.scenario_names

{'Chengdu_1',
 'Chengdu_2',
 'Cy-57_1',
 'RaptorF_1',
 'RaptorF_2',
 'Refale_1',
 'Refale_2',
 'fail_1',
 'fail_2',
 'fail_3',
 'sc_simple_1',
 'sc_simple_2',
 'sc_simple_3',
 'standart',
 'success_1',
 'success_2',
 'success_3',
 'success_4',
 'success_5'}

In [4]:
gym = MissileGym.make('fail_3')
tree = GymTree(gym)
tree.reset()

In [5]:
# выращиваем дерево
def f(x, d0=900):
    t = (x-d0/3)/abs(d0-d0/3)
    return t
for i in tqdm(range(5)):
    node = tree.root # np.random.choice(tree.get_not_full_nodes()) 
    d0 = node.get_distance_to_trg()
    while node:
        d_curr = node.get_distance_to_trg()
        p = f(d_curr, d0)
        node = node.walk(0)
for i in tqdm(range(5)):
    node = tree.get_perspective_node2()
    d0 = node.get_distance_to_trg()
    while node:
        d_curr = node.get_distance_to_trg()
        p = f(d_curr, d0)
        node = node.walk(p)

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [6]:
tree.plot()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
tree.plot_scatter()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
fig, ax = plt.subplots(figsize=(10,7))
tree.plot_scatter(ax=ax)
ax.grid()
ax.axis('equal')
plt.savefig('saves/sc_simple_3.png', format='png', dpi=300);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# сохраняем
tree.save_to_file('saves/sc_simple_3.bin')

## А теперь посчитаем их все.... параллельно

In [1]:
from dask.distributed import Client, LocalCluster

In [2]:
client = Client(LocalCluster())
client

0,1
Client  Scheduler: tcp://127.0.0.1:51381  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 8  Memory: 17.11 GB


In [3]:
def grow_tree(scenario_name):
    try:
        import numpy as np
        from missile_gym import MissileGym
        from gymtree import GymTree, Node
        import os
        
        def f(x, d0=900):
            t = (x-d0/3)/abs(d0-d0/3)
            return t

        gym = MissileGym.make(scenario_name)
        tree = GymTree(gym)
        tree.reset()
        
        file_name = f'saves/{scenario_name}.bin'
        if os.path.isfile(file_name):
            tree.load_from_file(file_name)
            
        for i in range(1):
            node = tree.root # 
            d0 = node.get_distance_to_trg()
            while node:
                d_curr = node.get_distance_to_trg()
                node = node.walk(0)
       
        for i in range(177):
            node = tree.root # np.random.choice(tree.get_not_full_nodes()) #
            d0 = node.get_distance_to_trg()
            while node:
                d_curr = node.get_distance_to_trg()
                p = f(d_curr, d0)
                node = node.walk(p)
        for i in range(1077):
            node = tree.get_perspective_node2()
            d0 = node.get_distance_to_trg()
            while node:
                d_curr = node.get_distance_to_trg()
                p = f(d_curr, d0)
                node = node.walk(p)
        
        tree.save_to_file(file_name)
        
        import matplotlib.pyplot as plt
        
        fig, ax = plt.subplots(figsize=(10,7))
        tree.plot_scatter(ax=ax)
        ax.grid()
        ax.axis('equal')
        plt.savefig(f'saves/{scenario_name}.png', format='png', dpi=300)
        return True
    except Exception as e:
        return e

In [5]:
from missile_gym import MissileGym

futs = client.map(grow_tree, MissileGym.scenario_names)

In [18]:
futs

[<Future: finished, type: builtins.bool, key: grow_tree-e99f1e7b44b9b001d13f79886b6b785c>,
 <Future: finished, type: builtins.bool, key: grow_tree-ecbd52b9d054a834e5b6ddecd61d7aa9>,
 <Future: finished, type: builtins.bool, key: grow_tree-089f55f23d7d586c4d78c196bf9967b4>,
 <Future: finished, type: builtins.bool, key: grow_tree-e72f44557e9c2c9bae44b6d4e372a870>,
 <Future: finished, type: builtins.bool, key: grow_tree-e44b98fdba7556c08ae42555883095b3>,
 <Future: finished, type: builtins.bool, key: grow_tree-106c5e357c07defb732b7d46cb3fc42e>,
 <Future: finished, type: builtins.bool, key: grow_tree-00c1508db758f99419f962dff9afce94>,
 <Future: finished, type: builtins.bool, key: grow_tree-69675a7547f143db2c1ac2be73d3a489>,
 <Future: pending, key: grow_tree-18894876883b08aa31bc818648ef0126>,
 <Future: pending, key: grow_tree-105e8233cf5c37d0461038664a0669e4>,
 <Future: finished, type: builtins.bool, key: grow_tree-048c20535994d85f6f8cf11968025657>,
 <Future: pending, key: grow_tree-e6eb82b3d

In [32]:
futs[6].result()

"'a' cannot be empty unless no samples are taken"