# Import modules

In [3]:
import numpy as np
import torch
import random
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments
import pandas as pd
from datasets import Dataset, DatasetDict

# Collator definition

A collator is a helper class for the training loop. It is responsible for batching together the individual samples and preparing them for the model. The collator is called by the DataLoader for each batch.

In [4]:
class DecisionTransformerGymDataCollator:
    """Build Decision Transformer training batches from an offline dataset.

    The collator is called once per batch. It only uses the *number* of
    features it receives: the trajectories themselves are re-sampled from the
    dataset given to the constructor, with probability proportional to their
    length, and a random window of at most ``max_len`` steps is cut from each.
    Shorter windows are left-padded to ``max_len``.
    """

    return_tensors: str = "pt"
    max_len: int = 20  # subsets of the episode we use for training
    state_dim: int = 20  # size of state space (overwritten from the dataset)
    act_dim: int = 6  # size of action space (overwritten from the dataset)
    max_ep_len: int = 1000  # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.ndarray = None  # per-dimension state mean, set in __init__
    state_std: np.ndarray = None  # per-dimension state std, set in __init__
    p_sample: np.ndarray = None  # sampling probability of each trajectory
    n_traj: int = 0  # number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        """Read the dimensions and normalization statistics from ``dataset``.

        Args:
            dataset: object supporting ``dataset[i]`` (one episode as a dict)
                and ``dataset["observations"]`` (all observation sequences).
        """
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset

        # calculate dataset stats for normalization of states
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        self.n_traj = len(traj_lens)
        states = np.vstack(states)
        self.state_mean = np.mean(states, axis=0)
        # the small offset keeps the later division safe for constant dimensions
        self.state_std = np.std(states, axis=0) + 1e-6

        traj_lens = np.array(traj_lens)
        self.p_sample = traj_lens / traj_lens.sum()

    def _discount_cumsum(self, x, gamma):
        """Return ``out`` with ``out[t] = x[t] + gamma * out[t + 1]``.

        The last element is ``x[-1]``; ``x`` must therefore be non-empty.
        """
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        """Sample ``len(features)`` windows and return them as padded tensors.

        Returns:
            dict with keys ``states``, ``actions``, ``rewards``,
            ``returns_to_go``, ``timesteps`` and ``attention_mask``; every
            tensor has ``len(features)`` rows and ``max_len`` time steps.
        """
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )

        s, a, r, rtg, timesteps, mask = [], [], [], [], [], []

        for ind in batch_inds:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)
            window = slice(si, si + self.max_len)

            # get sequences from dataset
            obs = np.array(feature["observations"][window]).reshape(1, -1, self.state_dim)
            act = np.array(feature["actions"][window]).reshape(1, -1, self.act_dim)
            rew = np.array(feature["rewards"][window]).reshape(1, -1, 1)
            tlen = obs.shape[1]
            pad = self.max_len - tlen

            steps = np.arange(si, si + tlen).reshape(1, -1)
            steps[steps >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff

            # undiscounted return-to-go from the window start to the episode end
            ret = self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)
            ret = ret[:tlen].reshape(1, -1, 1)
            if ret.shape[1] < tlen:
                # only reachable if an episode has more observations than rewards
                ret = np.concatenate([ret, np.zeros((1, 1, 1))], axis=1)

            # left-padding and state + return normalization
            obs = np.concatenate([np.zeros((1, pad, self.state_dim)), obs], axis=1)
            obs = (obs - self.state_mean) / self.state_std
            act = np.concatenate([np.full((1, pad, self.act_dim), -10.0), act], axis=1)
            rew = np.concatenate([np.zeros((1, pad, 1)), rew], axis=1)
            ret = np.concatenate([np.zeros((1, pad, 1)), ret], axis=1) / self.scale
            steps = np.concatenate([np.zeros((1, pad)), steps], axis=1)

            s.append(obs)
            a.append(act)
            r.append(rew)
            rtg.append(ret)
            timesteps.append(steps)
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        return {
            "states": torch.from_numpy(np.concatenate(s, axis=0)).float(),
            "actions": torch.from_numpy(np.concatenate(a, axis=0)).float(),
            "rewards": torch.from_numpy(np.concatenate(r, axis=0)).float(),
            "returns_to_go": torch.from_numpy(np.concatenate(rtg, axis=0)).float(),
            "timesteps": torch.from_numpy(np.concatenate(timesteps, axis=0)).long(),
            "attention_mask": torch.from_numpy(np.concatenate(mask, axis=0)).float(),
        }

# Define trainable transformer model

In [5]:
class TrainableDT(DecisionTransformerModel):
    """Decision Transformer whose ``forward`` returns a loss for ``Trainer``."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Run the base model and return ``{"loss": mse}``.

        The loss is the mean squared error between the predicted actions
        (second entry of the base model output) and ``kwargs["actions"]``,
        taken only over positions where ``kwargs["attention_mask"]`` is > 0.
        """
        base_output = super().forward(**kwargs)

        predicted = base_output[1]
        target = kwargs["actions"]
        keep = kwargs["attention_mask"].reshape(-1) > 0  # drop padded steps

        n_act = predicted.shape[2]
        predicted = predicted.reshape(-1, n_act)[keep]
        target = target.reshape(-1, n_act)[keep]

        squared_error = (predicted - target) ** 2
        return {"loss": torch.mean(squared_error)}

    def original_forward(self, **kwargs):
        """Expose the unmodified base-class forward pass (used for inference)."""
        return super().forward(**kwargs)

# Import dataset

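The dataset is a `DatasetDict` that can hold train, validation and test splits (only the `train` split is loaded below). Each split is a table with one row per episode and the following columns:
- observations: [episode_length x dim_obs] array, one observation vector per timestep
- actions: [episode_length x dim_act] array, one action vector per timestep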
- rewards: [episode_length] numpy array
- dones: [episode_length] numpy array

In [6]:
# Read the recorded episodes and wrap them in a DatasetDict with a single
# "train" split, which is what the collator and the Trainer consume below.
train_dataset = Dataset.from_pandas(
    pd.read_parquet('decision_transformer_satellites_rendezvous-train.parquet')
)
dataset = DatasetDict({"train": train_dataset})

In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['observations', 'actions', 'rewards', 'dones'],
        num_rows: 42
    })
})

In [8]:
# The collator reads state/action sizes from the data; the model config
# reuses them so that model and batches always agree.
collator = DecisionTransformerGymDataCollator(dataset["train"])

config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
)
model = TrainableDT(config)

# Train model

In [9]:
training_args = TrainingArguments(
    # bookkeeping
    output_dir="output/",
    remove_unused_columns=False,  # keep every dataset column available to the collator
    # schedule
    num_train_epochs=120,
    per_device_train_batch_size=64,
    warmup_ratio=0.1,
    # optimizer
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=1e-4,
    max_grad_norm=0.25,
)

trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=collator,
    train_dataset=dataset["train"],
)

# Fit the model, then store the trained weights next to the Trainer output.
trainer.train()
model.save_pretrained("output/")

  0%|          | 0/120 [00:00<?, ?it/s]

{'train_runtime': 19.2511, 'train_samples_per_second': 261.803, 'train_steps_per_second': 6.233, 'train_loss': 0.010625876983006795, 'epoch': 120.0}


# Develop environment for prediction loop

We use the gym library and create a custom environment for the prediction loop. The environment is defined as a class with the following methods:
- reset: resets the environment and returns the initial observation
- step: takes an action and returns the next observation, the reward and the done flag (no info dictionary is returned)
- render: not implemented in this notebook



In [10]:
# we should define limits similar to what is done for matlab
from dataclasses import MISSING, dataclass, field, fields


def _array_field(template):
    """Declare a dataclass field whose default is a fresh copy of ``template``.

    ``@dataclass`` rejects unhashable defaults such as ``np.ndarray`` on
    Python >= 3.11, so array defaults have to go through ``default_factory``.
    """
    return field(default_factory=lambda: np.array(template))


def _expose_factory_defaults(cls):
    """Class decorator: also publish factory-built defaults on the class.

    The notebook reads these constants directly from the class (for example
    ``parameters.pberth``), but ``@dataclass`` removes the class attribute of
    every ``default_factory`` field. This puts one default value back.
    """
    for spec in fields(cls):
        if spec.default_factory is not MISSING:
            setattr(cls, spec.name, spec.default_factory())
    return cls


# ----- PARAMETERS FOR DYNAMIC SIMULATION ----- %

@_expose_factory_defaults
@dataclass
class parameters:
    """Constants of the dynamic simulation.

    Fields without a usage note are not referenced in this notebook; they are
    presumably consumed by ``ode_model.dynamics``, which receives this class.
    """

    LC: float = 1
    LT: float = 3
    J_C: np.ndarray = _array_field(np.eye(3))
    J_T: np.ndarray = _array_field(np.eye(3))
    m_C: float = 1
    OM: float = 0.005  # z-component of the rate vector built in check_success
    pberth: np.ndarray = _array_field(np.array([5, 0, 0]))  # rotated by the target quaternion in check_success
    kP_tr: np.ndarray = _array_field(0.1 * np.eye(3))
    kD_tr: np.ndarray = _array_field(1 * np.eye(3))
    kP_rot: np.ndarray = _array_field(1 * np.eye(3))
    u_lim: np.ndarray = _array_field(np.ones((6,)))  # symmetric bounds of the action space
    r2: float = 3**2  # keep out zone (squared: compared as sqrt(r2) against the position norm)
    timestep: float = 0.3  # integration span of one environment step


# ----- META-PARAMETERS: OPTIONS FOR SOLVERS AND REWARD DEFINITION ----- %

@_expose_factory_defaults
@dataclass
class options:
    """Reward terms and sampling limits for random initial states."""

    K_action: np.ndarray = _array_field(np.eye(6))  # weights the action inside the reward
    R_success: float = 5  # added to the reward when check_success is true
    R_collision: float = -10  # added to the reward inside the keep-out zone
    R_timeout: float = -5  # not referenced in this notebook

    # lower/upper bounds used by random_state for each part of the state
    pos_low_lim: np.ndarray = _array_field(np.array([-15, -15, -15]))
    pos_high_lim: np.ndarray = _array_field(np.array([15, 15, 15]))
    vel_low_lim: np.ndarray = _array_field(np.array([-0.1, -0.1, -0.1]))
    vel_high_lim: np.ndarray = _array_field(np.array([0.1, 0.1, 0.1]))
    ang_low_lim: np.ndarray = _array_field(np.array([-0.1, -0.1, -0.1]))
    ang_high_lim: np.ndarray = _array_field(np.array([0.1, 0.1, 0.1]))
    quat_low_lim: np.ndarray = _array_field(np.array([-1, -1, -1, -1]))
    quat_high_lim: np.ndarray = _array_field(np.array([1, 1, 1, 1]))

In [11]:
from gym import Env, spaces
import numpy as np
from scipy.integrate import solve_ivp
import quat
import ode_model
# quaternions
# other

def random_state():
    """Draw a random 20-element state vector within the limits in ``options``.

    Layout of the returned array: position (3), velocity (3), chaser
    quaternion (4), chaser angular velocity (3), target quaternion (4),
    target angular velocity (3). Both quaternions are scaled to unit norm.
    """

    def draw(n, low, high):
        # uniform sample of n components between the given bounds
        return np.random.uniform(size=(n,), low=low, high=high)

    def unit_quaternion():
        q = draw(4, options.quat_low_lim, options.quat_high_lim)
        return q / np.linalg.norm(q)

    # the draw order below fixes how the global NumPy random stream is consumed
    position = draw(3, options.pos_low_lim, options.pos_high_lim)
    velocity = draw(3, options.vel_low_lim, options.vel_high_lim)
    chaser_quat = unit_quaternion()
    chaser_rate = draw(3, options.ang_low_lim, options.ang_high_lim)
    target_quat = unit_quaternion()
    target_rate = draw(3, options.ang_low_lim, options.ang_high_lim)

    return np.concatenate(
        (position, velocity, chaser_quat, chaser_rate, target_quat, target_rate)
    )

def check_success(state):
    """Return True when ``state`` matches the berthing reference within 1e-6.

    The reference position is ``parameters.pberth`` rotated by the target
    quaternion. The reference velocity is the cross product of the relative
    angular rate with that position. The error is the sum of the norms of the
    position, velocity, quaternion and angular-rate differences.
    """
    chaser_pos = state[0:3]
    chaser_vel = state[3:6]
    chaser_quat = state[6:10] / np.linalg.norm(state[6:10])
    chaser_rate = state[10:13]
    target_quat = state[13:17] / np.linalg.norm(state[13:17])
    target_rate = state[17:20]

    frame_rate = np.array([0, 0, parameters.OM])

    berth_pos = quat.rotate(parameters.pberth, target_quat)
    # ang. velocity of line of sight chaser-target
    relative_rate = quat.quat2rotm(chaser_quat) @ chaser_rate - frame_rate
    # chaser must have this velocity to keep up with rotation
    berth_vel = np.cross(relative_rate, berth_pos)

    err = (
        np.linalg.norm(chaser_pos - berth_pos)
        + np.linalg.norm(chaser_vel - berth_vel)
        + np.linalg.norm(chaser_quat - target_quat)
        + np.linalg.norm(chaser_rate - target_rate)
    )

    tol = 1e-6
    return bool(err < tol)



def compute_reward(obs, action):
    """Return the reward for reaching ``obs`` after applying ``action``.

    The reward is the sum of three terms: the weighted action norm scaled by
    the time step, ``options.R_collision`` when the position lies inside the
    keep-out radius, and ``options.R_success`` when ``check_success`` holds.
    """
    # NOTE(review): the action term enters with a positive sign; confirm this
    # matches the reward convention of the training data.
    effort_term = parameters.timestep * np.linalg.norm(options.K_action @ action)

    inside_keep_out = np.linalg.norm(obs[0:3]) < np.sqrt(parameters.r2)
    collision_term = options.R_collision if inside_keep_out else 0

    success_term = options.R_success if check_success(obs) else 0

    return effort_term + collision_term + success_term


class SpacecraftRendezvous(Env):
    """Gym environment that propagates the rendezvous dynamics step by step.

    Observations are the 20-element state produced by ``random_state``;
    actions are 6-element vectors bounded by ``parameters.u_lim``.
    """

    def __init__(self):
        super().__init__()

        # Define observation space (unbounded)
        self.observation_shape = (20,)
        self.observation_space = spaces.Box(
            low=np.full(self.observation_shape, -np.inf),
            high=np.full(self.observation_shape, np.inf),
            dtype=np.float64,
        )

        # Define an action space
        self.action_shape = (6,)
        self.action_space = spaces.Box(
            low=-parameters.u_lim,
            high=parameters.u_lim,
            dtype=np.float64,
        )

        # single source of truth for the step length (was a hard-coded copy)
        self.timestep = parameters.timestep
        # start outside the keep-out zone even if step() is called before reset()
        self.current_state = self._sample_start_state()

    @staticmethod
    def _inside_keep_out(state):
        """Return whether the position part of ``state`` is inside the keep-out radius."""
        return np.linalg.norm(state[0:3]) < np.sqrt(parameters.r2)

    def _sample_start_state(self):
        """Draw random states until one lies outside the keep-out zone."""
        state = random_state()
        while self._inside_keep_out(state):
            state = random_state()
        return state

    def reset(self):
        """Start a new episode and return the initial observation."""
        self.current_state = self._sample_start_state()
        return self.current_state

    def step(self, action):  # TODO : use a certain dt instead of hardcoded from definition
        """Integrate the dynamics over one time step under a constant ``action``.

        Returns:
            ``(observation, reward, done)``. Unlike the usual gym signature no
            info dictionary is returned, and callers in this notebook unpack
            exactly three values.
        """
        # propagate the dynamics; dt comes from the same constant that
        # compute_reward uses, so both stay in agreement
        dt = parameters.timestep
        sol = solve_ivp(
            lambda t, state: ode_model.dynamics(t, state, parameters, action),
            [0, dt],
            self.current_state,
            atol=1e-6,
            rtol=1e-6,
        )
        observation = sol.y[:, -1]
        self.current_state = observation

        # the episode ends on impact or when the objective is reached
        done = bool(self._inside_keep_out(observation) or check_success(observation))

        # compute the reward
        reward = compute_reward(observation, action)

        return observation, reward, done

# Define variables for prediction loop

In [12]:
# build the environment and switch the freshly trained model to inference mode
model = model.to("cpu")
model.eval()  # disable dropout: the model has only been used for training so far
env = SpacecraftRendezvous()
max_ep_len = collator.max_ep_len
device = "cpu"
# Returns must be normalized exactly as during training, where the collator
# divides returns-to-go by collator.scale.
scale = collator.scale
TARGET_RETURN = 95 / scale  # evaluation is conditioned on a return of 95, scaled accordingly

state_mean = collator.state_mean.astype(np.float32)
state_std = collator.state_std.astype(np.float32)
print(state_mean)

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# normalization statistics as tensors, for use inside the prediction loop
state_mean = torch.from_numpy(state_mean).to(device=device)
state_std = torch.from_numpy(state_std).to(device=device)

[ 3.1382589e+00 -1.3246049e+00  4.6059570e-01  4.3193098e-02
  2.7685380e-02 -1.8177604e-02  2.4311738e-02  1.7281880e-03
 -1.6144600e-02  8.6094147e-01 -2.1195339e-02 -6.4675897e-02
 -9.2611000e-02 -1.0331314e-02  1.3281289e-03 -4.3708045e-02
  4.7704020e-01 -5.1562369e-02  2.7764983e-02  6.1281215e-02]


Function that computes the predicted action from the auto-regressive history

In [13]:
# Function that gets an action from the model using autoregressive prediction with a window of the previous 20 timesteps.
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
    """Predict the next action from the most recent history window.

    The inputs are reshaped to a batch of one, cut to the last
    ``model.config.max_length`` steps and left-padded with zeros to exactly
    that length. The attention mask marks the padded steps with 0.

    Args:
        model: object exposing ``config.state_dim``, ``config.act_dim``,
            ``config.max_length`` and ``original_forward(**kwargs)``.
        states: state history, reshaped to ``(1, -1, state_dim)``.
        actions: action history, reshaped to ``(1, -1, act_dim)``.
        rewards: passed to the model unchanged.
        returns_to_go: return-to-go history, reshaped to ``(1, -1, 1)``.
        timesteps: time-step history, reshaped to ``(1, -1)``.

    Returns:
        The action predicted for the last step, detached from autograd.
    """
    # This implementation does not condition on past rewards
    state_dim = model.config.state_dim
    act_dim = model.config.act_dim
    # NOTE(review): presumably equal to the collator's max_len of 20; confirm
    window = model.config.max_length

    states = states.reshape(1, -1, state_dim)[:, -window:]
    actions = actions.reshape(1, -1, act_dim)[:, -window:]
    returns_to_go = returns_to_go.reshape(1, -1, 1)[:, -window:]
    timesteps = timesteps.reshape(1, -1)[:, -window:]

    seq_len = states.shape[1]
    padding = window - seq_len
    device = states.device  # build the padding where the inputs live

    # pad all tokens to sequence length
    attention_mask = torch.cat(
        [torch.zeros(padding, device=device), torch.ones(seq_len, device=device)]
    )
    attention_mask = attention_mask.to(dtype=torch.long).reshape(1, -1)
    states = torch.cat(
        [torch.zeros((1, padding, state_dim), device=device), states], dim=1
    ).float()
    actions = torch.cat(
        [torch.zeros((1, padding, act_dim), device=device), actions], dim=1
    ).float()
    returns_to_go = torch.cat(
        [torch.zeros((1, padding, 1), device=device), returns_to_go], dim=1
    ).float()
    timesteps = torch.cat(
        [torch.zeros((1, padding), dtype=torch.long, device=device), timesteps], dim=1
    )

    # pure inference: do not build an autograd graph that the caller would
    # otherwise keep alive through its action history
    with torch.no_grad():
        state_preds, action_preds, return_preds = model.original_forward(
            states=states,
            actions=actions,
            rewards=rewards,
            returns_to_go=returns_to_go,
            timesteps=timesteps,
            attention_mask=attention_mask,
            return_dict=False,
        )

    return action_preds[0, -1]


def saturate(action, lim=1):
    """Clamp every component of ``action`` to ``[-lim, lim]`` in place.

    Args:
        action: NumPy array (clipped in one vectorized call) or any other
            mutable indexable sequence of numbers.
        lim: symmetric bound; the default of 1 keeps the previous behavior.

    Returns:
        The same ``action`` object, after modification.
    """
    if isinstance(action, np.ndarray):
        np.clip(action, -lim, lim, out=action)
        return action

    # generic fallback for lists and other indexable containers
    for i, value in enumerate(action):
        if value > lim:
            action[i] = lim
        elif value < -lim:
            action[i] = -lim
    return action

# Prediction loop

In [14]:
# Roll out one episode with the trained model and record the visited states
episode_return, episode_length = 0, 0
state = env.reset()
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
rewards = torch.zeros(0, device=device, dtype=torch.float32)

timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
for t in range(max_ep_len):
    # placeholders for the step that is about to be predicted
    actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
    rewards = torch.cat([rewards, torch.zeros(1, device=device)])

    with torch.no_grad():  # inference only, no autograd graph needed
        action = get_action(
            model,
            (states - state_mean) / state_std,
            actions,
            rewards,
            target_return,
            timesteps,
        )

    # action limitation
    action = saturate(action.detach().cpu().numpy().reshape((-1,)))
    # record the action that is actually applied, so the history fed back to
    # the model matches what the environment saw
    actions[-1] = torch.from_numpy(action).to(device=device)
    state, reward, done = env.step(action)

    # keep the state history in float32 like its first entry
    cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
    states = torch.cat([states, cur_state], dim=0)
    rewards[-1] = reward

    # the return still to be collected shrinks by the scaled reward just received
    pred_return = target_return[0, -1] - (float(reward) / scale)
    target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
    timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

    episode_return += reward
    episode_length += 1

    if done:
        break

Export the results to MATLAB for visualization

In [13]:
from pathlib import Path

from scipy.io import savemat

# Build the path from its components so the separator is correct on every
# platform (a literal backslash path only resolves on Windows).
result_file = Path("..") / "optimal-control" / "result_vis.mat"

# state history of the rollout, one row per step, stored under the key "x2"
mdic = {"x2": states.detach().cpu().numpy()}
savemat(str(result_file), mdic)