In [None]:
import ray

# Start Ray with explicit resource limits. The two memory limits are written
# as multiples of 1024**3 (presumably byte counts - confirm against ray.init).
GIB = 1024 * 1024 * 1024

ray.init(
    num_cpus=12,
    num_gpus=1,
    memory=GIB * 5,
    object_store_memory=GIB * 40,
)

In [None]:
from ray.tune.registry import register_env
import gym

def env_creator(env_config):
    """Create a new environment instance through ``gym.make``.

    ``env_config`` is accepted but ignored: the environment is always built
    from the fixed id ``"satellite_gym:SatelliteEnv-v1"`` with no extra
    arguments.
    """
    env_id = "satellite_gym:SatelliteEnv-v1"
    return gym.make(env_id)

# env = SatelliteEnv(df, sat_id=sat_id)

# Register the factory itself under the name "SatelliteEnv-v1". The creator
# passed to register_env is called with the env config and must return an
# environment; the previous `lambda x: env_creator` returned the factory
# function object instead of an environment instance.
register_env("SatelliteEnv-v1", env_creator)

In [None]:
import ray.rllib.agents.impala as impala
from ray.tune.logger import pretty_print
from pathlib import Path

def on_train_result(info):
    """Callback run on each training result: pick a phase and push it to all envs.

    The phase depends on ``info["result"]["episode_reward_mean"]``:
    strictly above 45 -> 2, strictly above 22 -> 1, otherwise 0. The phase is
    then applied by calling ``set_phase(phase)`` on every environment of every
    worker reachable through ``info["trainer"].workers``.
    """
    mean_reward = info["result"]["episode_reward_mean"]

    # The phase equals the number of thresholds the mean reward strictly exceeds.
    phase = sum(1 for threshold in (22, 45) if mean_reward > threshold)

    def apply_phase(env):
        return env.set_phase(phase)

    def update_worker(worker):
        return worker.foreach_env(apply_phase)

    info["trainer"].workers.foreach_worker(update_worker)
    
    
# Start from RLlib's IMPALA defaults.
config = impala.DEFAULT_CONFIG.copy()

# `.copy()` is shallow (assuming DEFAULT_CONFIG is a plain dict), so
# config["model"] is the same dict object as impala.DEFAULT_CONFIG["model"].
# Rebind it to a fresh dict instead of assigning into it, so the shared
# library defaults are not modified as a side effect of this cell.
config["model"] = dict(config["model"], use_lstm=True, vf_share_layers=True)

config["num_workers"] = 10
config["num_gpus"] = 1
config["seed"] = 0
config["eager"] = False
config["lr"] = 1e-05
config["num_envs_per_worker"] = 10
# config["num_gpus_per_worker"] = .1
config["sample_batch_size"] = 4000
config["train_batch_size"] = 80000
# Invoke on_train_result (defined above in this cell) after each training result.
config["callbacks"] = {"on_train_result": on_train_result}

# The env is passed as the gym id string "satellite_gym:SatelliteEnv-v1",
# not as the "SatelliteEnv-v1" name given to register_env earlier.
trainer = impala.ImpalaTrainer(config=config, env="satellite_gym:SatelliteEnv-v1")



In [None]:
# Load trainer state from a saved checkpoint file. The path is absolute and
# specific to the machine the checkpoint was written on.
checkpoint_path = "/home/golemxiv/ray_results/IMPALA-04gmoje1lu/checkpoint_201/checkpoint-201"
trainer.restore(checkpoint_path)

In [None]:
# Lower the learning rate stored in the trainer's config (1e-05 at build time).
# NOTE(review): this only edits the config dict after the trainer was built;
# whether the already-constructed policy picks the new value up is not visible
# from this notebook - confirm.
trainer.config.update(lr=1e-06)

In [None]:


NUM_ITERATIONS = 401
CHECKPOINT_EVERY = 50

for iteration in range(NUM_ITERATIONS):
    # Run one training iteration of the IMPALA trainer and print its result.
    result = trainer.train()
    print(pretty_print(result))

    # Save a checkpoint on iterations 0, 50, 100, ..., 400.
    if not iteration % CHECKPOINT_EVERY:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)

In [None]:
# Export the trainer's policy model, passing 'trained_model_v1' as the target.
trained_policy = trainer.get_policy()
trained_policy.export_model('trained_model_v1')


In [None]:
# NOTE(review): `df` is not defined in any cell shown in this notebook, so this
# cell relies on state already present in the kernel - confirm where `df` is
# loaded and add that step before this cell.
df.head()

In [None]:
# NOTE(review): `sat_id` is not assigned in any cell shown in this notebook;
# it must already exist in the kernel - confirm where it comes from.
sat_id

In [None]:
# Reset the environment and keep what reset() returns.
# NOTE(review): `env` is never created in the cells shown (the only candidate,
# `env = SatelliteEnv(df, sat_id=sat_id)`, is commented out) - confirm how it
# is built. `state` is reassigned later to the policy's initial state.
state = env.reset()

In [None]:
import numpy as np
# Alternative placeholder input: obs = np.zeros((256), dtype=np.float32)
# Take the first 256 rows of env.df as an array and drop any length-1 axes.
head_frame = env.df.head(256)
obs = np.squeeze(head_frame.values)

In [None]:
from satellite_gym.envs.satellite_env.satellite_env import TRAIN_COLUMNS, TEST_COLUMNS
# Run the policy once on the first row of the training columns, starting from
# the policy's initial state.
single_step_policy = trainer.get_policy()
first_row_obs = np.squeeze(env.df[TRAIN_COLUMNS].head(1).values)
single_step_policy.compute_single_action(
    obs=first_row_obs,
    state=trainer.get_policy().get_initial_state(),
)

In [None]:
# Reference values: the Vx, Vy, Vz columns of env.df, one row per data row.
true_value = env.df[['Vx', 'Vy', 'Vz']].values

policy = trainer.get_policy()
state = policy.get_initial_state()

# Select the input columns once instead of on every loop iteration.
features = env.df[TRAIN_COLUMNS]

# Collect predictions in a list and build the array once at the end. The
# previous version seeded the result with np.empty((1, 3)), which left one
# uninitialized (garbage) row at the front of predicted_value, and grew the
# array with np.append, which copies the whole array on every iteration.
predicted_rows = []
for i in range(len(true_value)):
    # Positional indexing keeps row i aligned with true_value[i], which is
    # positional as well.
    val = policy.compute_single_action(np.squeeze(features.iloc[i].values), state=state)
    # Carry the returned state into the next step.
    state = val[1]
    # Keep the action components from index 3 onwards (expected to be 3 values).
    predicted_rows.append(val[0][3:])

# Shape (n_rows, 3); raises if any prediction does not have exactly 3 values.
predicted_value = np.array(predicted_rows, dtype=float).reshape(len(predicted_rows), 3)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d, Axes3D #<-- Note the capitalization! 


fig = plt.figure()
# Create the 3D axes through the figure so they are attached to it; on recent
# Matplotlib releases a bare Axes3D(fig) is no longer added to the figure
# automatically, which leaves the plot empty.
ax = fig.add_subplot(111, projection='3d')

# Pass 1-D coordinate columns (index 0/1/2 correspond to Vx/Vy/Vz in true_value).
ax.scatter(xs=true_value[:, 0], ys=true_value[:, 1], zs=true_value[:, 2],
           marker='o', label='true')
ax.scatter(xs=predicted_value[:, 0], ys=predicted_value[:, 1], zs=predicted_value[:, 2],
           marker='^', label='predicted')

ax.set_xlabel('Vx')
ax.set_ylabel('Vy')
ax.set_zlabel('Vz')
ax.legend()
ax.view_init(elev=10., azim=20)
plt.show()
# for ii in range(0,360,1):
#         ax.view_init(elev=10., azim=ii)
#         fig.savefig("movie/movie%d.png" % ii)


#ax + by + cz + d = 0

In [None]:
# NOTE(review): `go` is not imported in any cell shown; presumably it is
# `plotly.graph_objects` - confirm and add the import to the imports cell.
# Coordinates are passed as 1-D columns rather than (n, 1) slices, since each
# coordinate entry should be a scalar.
fig = go.Figure(data=[go.Scatter3d(x=true_value[:, 0], y=true_value[:, 1], z=true_value[:, 2], mode='markers')])
fig.write_image("figure.png")