# Блокнот для генерации деревьев траекторий, чтобы потом из них можно было сделать хорошую БД для обучения

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

from missile_gym import MissileGym
from gymtree import GymTree, Node

%matplotlib widget
# %config InlineBackend.figure_format = 'svg'

### Генерируем стандартный сценарий

In [3]:
# Build the environment for scenario 'sc_simple_3' and wrap it in a GymTree
# that the following cells grow, plot and save.
gym = MissileGym.make('sc_simple_3')
tree = GymTree(gym)
tree.reset()  # NOTE(review): presumably (re)initialises the root node — confirm in GymTree.reset

In [4]:
# Grow the tree: first a few walks from the root, then walks started from
# nodes chosen by tree.get_perspective_node2().

def f(x, d0=900):
    """Map the current distance to target onto the value passed to ``Node.walk``.

    The map is linear in ``x``. For a positive ``d0`` it returns 1 when ``x == d0``
    (distance at the start of the walk) and 0 when ``x == d0 / 3``. Values outside
    [0, 1] are returned for ``x`` outside that range.

    :param x: current distance to the target (``node.get_distance_to_trg()``).
    :param d0: distance to the target at the node the walk started from;
        ``d0 == 0`` divides by zero.
    :returns: the parameter passed to ``Node.walk``.
    """
    return (x - d0 / 3) / abs(d0 - d0 / 3)


def walk_from(node):
    """Walk down from ``node`` until ``Node.walk`` returns a falsy value.

    At every step the walk parameter is ``f(current distance, distance at the
    start node)``.
    """
    d0 = node.get_distance_to_trg()
    while node:
        node = node.walk(f(node.get_distance_to_trg(), d0))


N_ROOT_WALKS = 13         # walks that always start from tree.root
N_PERSPECTIVE_WALKS = 17  # walks started from tree.get_perspective_node2()

for _ in tqdm(range(N_ROOT_WALKS)):
    walk_from(tree.root)
for _ in tqdm(range(N_PERSPECTIVE_WALKS)):
    # NOTE(review): if get_perspective_node2() can return None (e.g. nothing
    # left to explore), walk_from raises AttributeError — confirm in GymTree.
    walk_from(tree.get_perspective_node2())

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))




In [5]:
# Plot the grown tree (interactive figure, since the widget backend is enabled above);
# what exactly is drawn is defined by GymTree.plot.
tree.plot()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:

# Scatter view of the same tree; what is scattered is defined by GymTree.plot_scatter.
tree.plot_scatter()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
# Static 300-dpi scatter of the tree, saved to saves/sc_simple_3.png.
os.makedirs('saves', exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(figsize=(10, 7))
tree.plot_scatter(ax=ax)
ax.grid()
ax.axis('equal')
# Save this figure explicitly instead of whatever pyplot considers "current".
fig.savefig('saves/sc_simple_3.png', format='png', dpi=300);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
# сохраняем
# Persist the grown tree to saves/sc_simple_3.bin (format defined by GymTree.save_to_file).
# NOTE(review): the saves/ directory must already exist unless save_to_file creates it — confirm.
tree.save_to_file('saves/sc_simple_3.bin')

## А теперь вырастим деревья для всех сценариев параллельно

In [9]:
from dask.distributed import Client, LocalCluster

In [10]:
# Start a local Dask cluster on this machine and connect a client to it.
cluster = LocalCluster()
client = Client(cluster)
# Show the rich repr: scheduler address, dashboard link, workers/cores/memory.
client

0,1
Client  Scheduler: tcp://127.0.0.1:55349  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 8  Memory: 17.11 GB


In [18]:
def grow_tree(scenario_name, n_root_walks=77, n_perspective_walks=777):
    """Grow, save and plot a trajectory tree for one scenario.

    Meant to be submitted to Dask workers via ``client.map``. The project imports
    are kept inside the function, presumably so it is self-contained when shipped
    to a worker.

    Writes ``saves/<scenario_name>.bin`` (the tree) and
    ``saves/<scenario_name>.png`` (a scatter plot of it).

    :param scenario_name: scenario name accepted by ``MissileGym.make``.
    :param n_root_walks: number of walks started from ``tree.root``.
    :param n_perspective_walks: number of walks started from
        ``tree.get_perspective_node2()``.
    :returns: ``True`` on success. Otherwise a string with the scenario name and
        the formatted traceback. Errors are reported rather than raised so that
        one failing scenario does not stop collecting the others.
    """
    import os
    import traceback

    try:
        from missile_gym import MissileGym
        from gymtree import GymTree

        def f(x, d0=900):
            # Linear map: 1 at x == d0 and 0 at x == d0 / 3 (for positive d0).
            return (x - d0 / 3) / abs(d0 - d0 / 3)

        def walk_from(node):
            # Walk down until Node.walk returns a falsy value.
            d0 = node.get_distance_to_trg()
            while node:
                node = node.walk(f(node.get_distance_to_trg(), d0))

        gym = MissileGym.make(scenario_name)
        tree = GymTree(gym)
        tree.reset()

        for _ in range(n_root_walks):
            walk_from(tree.root)
        for _ in range(n_perspective_walks):
            walk_from(tree.get_perspective_node2())

        os.makedirs('saves', exist_ok=True)  # neither writer below is known to create it
        tree.save_to_file(f'saves/{scenario_name}.bin')

        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 7))
        try:
            tree.plot_scatter(ax=ax)
            ax.grid()
            ax.axis('equal')
            # A worker may run several tasks in threads of one process, and
            # pyplot's "current figure" is process-global. plt.savefig could
            # therefore save another task's figure, so save this one explicitly.
            fig.savefig(f'saves/{scenario_name}.png', format='png', dpi=300)
        finally:
            # pyplot keeps every open figure alive; close it so long-lived
            # worker processes do not accumulate figures.
            plt.close(fig)
        return True
    except Exception:
        return f'{scenario_name}: {traceback.format_exc()}'

In [19]:
# grow_tree has side effects: it writes saves/<scenario>.bin/.png, and presumably
# grows a different tree on every run. Mark it impure. With the default
# pure=True, Dask keys tasks by (function, arguments) and would reuse still-alive
# futures instead of recomputing when this cell is re-run.
futs = client.map(grow_tree, MissileGym.scenario_names, pure=False)


In [22]:
from dask.distributed import as_completed

# Report each scenario as soon as its task finishes. as_completed does not expose
# a length, so pass tqdm the total explicitly to get a determinate progress bar.
# grow_tree returns True on success and does not name the scenario, so map
# future keys back to scenario names (same order as passed to client.map).
scenario_by_key = {fut.key: name for fut, name in zip(futs, MissileGym.scenario_names)}
for future, result in tqdm(as_completed(futs, with_results=True), total=len(futs)):
    print(f'{scenario_by_key[future.key]}: {result}')

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

True
True
True
True

