In [None]:
import sys

sys.path.append('..')

import numpy as np
import decision_transformer.manage as DT_manager
from optimization.ocp import ocp_cvx
from dynamics.orbit_dynamics import dynamics_roe_optimization, map_rtn_to_roe
from optimization.rpod_scenario import oe_0_ref, t_0, n_time_rpod, dock_param_maker
%matplotlib ipympl

In [None]:
# Simulation configuration: choose the checkpoint, then read back the settings stored for it
transformer_model_name = 'checkpoint_rtn_ctgrtg'
import_config = DT_manager.transformer_import_config(transformer_model_name)

# Unpack the settings that drive data loading and inference.
# Note: the config's 'model_name' entry rebinds `transformer_model_name`.
state_representation, dataset_to_use, mdp_constr, transformer_model_name, timestep_norm = (
    import_config['state_representation'],
    import_config['dataset_to_use'],
    import_config['mdp_constr'],
    import_config['model_name'],
    import_config['timestep_norm'],
)

In [None]:
# Load the train/validation/test splits and their loaders for the configuration selected above
datasets, dataloaders = DT_manager.get_train_val_test_data(state_representation, dataset_to_use, mdp_constr, transformer_model_name, timestep_norm)
train_dataset, val_dataset, test_dataset = datasets
train_loader, eval_loader, test_loader = dataloaders
# Normalization statistics come from the training split; they are used to map
# normalized samples back to physical units (x * std + mean).
data_stats = train_dataset.data_stats

In [None]:
# Split sizes shown as (train, test, train + test); the validation split is not counted here
(train_dataset.n_data,
 test_dataset.n_data,
 sum(split.n_data for split in (train_dataset, test_dataset)))

In [None]:
# Instantiate the decision-transformer model for this checkpoint, then switch it to eval mode
model = DT_manager.get_DT_model(
    transformer_model_name,
    train_loader,
    eval_loader,
)
model.eval();  # the trailing ';' stops the notebook from echoing the returned object

In [None]:
# Use the model to predict the trajectory of one test sample.
# TEST_SAMPLE_IX = None -> take the first batch yielded by test_loader.
# NOTE(review): if test_loader shuffles, that batch changes between runs; set
# TEST_SAMPLE_IX to an integer to pick a fixed entry via test_loader.dataset.getix.
TEST_SAMPLE_IX = None
if TEST_SAMPLE_IX is None:
    test_sample = next(iter(test_loader))
else:
    test_sample = test_loader.dataset.getix(TEST_SAMPLE_IX)

DT_trajectory, _ = DT_manager.use_model_for_imitation_learning(model, test_loader, test_sample, state_representation, use_dynamics=True, output_attentions=True)

In [None]:
# Precompute the stm, cim and psi matrices (plus orbital elements and time grids)
# over this sample's horizon via dynamics_roe_optimization.
states_i, actions_i, rtgs_i, ctgs_i, goal_i, timesteps_i, attention_mask_i, oe, dt, time_sec, horizons, ix = test_sample
hrz = horizons.item()
stm_hrz, cim_hrz, psi_hrz, oe_hrz, time_hrz, dt_hrz = dynamics_roe_optimization(oe_0_ref, t_0, hrz, n_time_rpod)

# De-normalize the initial state of the first sample in the batch (x * std + mean).
# NOTE(review): index [0] of the stats presumably selects the first-timestep statistics — confirm.
state_0 = np.array((states_i[0, 0, :] * data_stats['states_std'][0]) + data_stats['states_mean'][0])
if state_representation == 'roe':
    state_roe_0 = state_0
elif state_representation == 'rtn':
    state_rtn_0 = state_0
    # ocp_cvx takes the initial state in ROE, so convert the RTN state using this
    # sample's orbital elements oe[0, :, 0] (presumably those at the first timestep).
    state_roe_0 = map_rtn_to_roe(state_rtn_0, np.array(oe[0, :, 0]))
else:
    # Without this guard, state_roe_0 would be undefined on a fresh kernel or,
    # worse, silently stale from an earlier execution of this cell.
    raise ValueError(f"Unsupported state_representation: {state_representation!r} (expected 'roe' or 'rtn')")

# De-normalize the goal the same way and build the docking parameters from it
dock_param, _ = dock_param_maker(np.array((goal_i[0, 0, :] * data_stats['goal_std'][0]) + data_stats['goal_mean'][0]))

# Compute the convex-optimization solution and the DT rollout propagated through the same matrices
states_cvx, action_cvx, feas_cvx = ocp_cvx(stm_hrz, cim_hrz, psi_hrz, state_roe_0, dock_param, n_time_rpod)
dyn_trajectory, _ = DT_manager.torch_model_inference_dyn(model, test_loader, test_sample, stm_hrz, cim_hrz, psi_hrz, state_representation, rtg_perc=1., ctg_perc=0., rtg=None)

In [None]:
# Print the index of the trajectory and plot it.
# Set PLOT_DYN_ROLLOUT = True to plot dyn_trajectory's 'roe_dyn'/'rtn_dyn'/'dv_dyn'
# in place of DT_trajectory's 'roe_ol'/'rtn_ol'/'dv_ol' entries.
PLOT_DYN_ROLLOUT = False

if PLOT_DYN_ROLLOUT:
    # Build a new mapping rather than mutating DT_trajectory in place, so re-running
    # this cell never leaves the original prediction overwritten.
    # NOTE(review): assumes DT_trajectory is a plain dict — confirm.
    trajectory_to_plot = dict(
        DT_trajectory,
        roe_ol=dyn_trajectory['roe_dyn'],
        rtn_ol=dyn_trajectory['rtn_dyn'],
        dv_ol=dyn_trajectory['dv_dyn'],
    )
else:
    trajectory_to_plot = DT_trajectory

print(test_sample[-1])
DT_manager.plot_DT_trajectory(trajectory_to_plot)