In [227]:
import jax.numpy as jnp
import jax
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# from aa598.hw2_helper import simulate_dynamics
from cbfax.dynamics import DynamicallyExtendedSimpleCar
import torch
from jax import grad
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import math

In [228]:
@jax.jit
def simulate_dynamics(dynamics, state, controls, dt):
    T = controls.shape[0]
    states = [state]
    for c in controls:
        state = dynamics.discrete_step(state, c, 0., dt)
        states.append(state)
    return jnp.stack(states)

In [229]:
radius = 1.
v_max = 2.0
acceleration_max = 1.0
acceleration_min = -1.0
steering_max = 0.3
steering_min = -0.3

In [230]:
@jax.jit
def evaluate_trajectory_cost(robot_states, robot_controls, human_states_samples, coeff=[0.1, 0.3, 0.5, 0.4, 3., 3.]):
    # lower is better
    steering = robot_controls[:,1]
    acceleration = robot_controls[:,0]
    # steering effort
    turning_effort = (steering**2).mean() 
    # acceleration effort
    acceleration_effort = (acceleration**2).mean() 
    # speed limit
    speed = jax.nn.relu(robot_states[:,-1].max() - v_max) + jax.nn.relu(-robot_states[:,-1].min()) 
    # progress to goal
    progress = robot_states[-1,2]**2 + (robot_states[:,1]**2).mean()
    # collision 
    collision = jax.nn.relu(-(jnp.linalg.norm((robot_states - human_states_samples)[:,:,:2], axis=-1).min(-1) - radius).mean())
    # control limits
    control_limits = jax.nn.relu(steering.max() - steering_max) + jax.nn.relu(steering_min - steering.min()) + jax.nn.relu(acceleration.max() - acceleration_max) + jax.nn.relu(acceleration_min - acceleration.min()) 

    return jnp.dot(jnp.array(coeff), jnp.array([turning_effort, acceleration_effort, speed, progress, collision, control_limits]))
    


In [231]:
planning_horizon = 25 # planning horizon to compute cost over
n_human_samples = 64 # number of human future trajectories to sample
n_robot_samples = 32 # number of robot trajectories to sample for MPPI
dt = 0.1 # timestep size
num_iterations = 20 # number of MPPI iteraciotns
num_time_steps = 50 # number of timesteps to simulate
human_control_prediction_noise_limit = 0.25
human_control_prediction_variance = 0.25
robot_control_noise_limit = 0.25
robot_control_noise_variance = 0.25


num_of_data = 24
# making datasets
robot_trajectory_dataset = []
human_trajectory_dataset = []
robot_controls_list_dataset = []
past_state_data_set = [] # decision context # include past sate of both robot and human
human_controls_list_dataset = []

for i in range(num_of_data):

    robot = DynamicallyExtendedSimpleCar() # robot dynamics
    human = DynamicallyExtendedSimpleCar() # human dynamics

    # initial states
    if i < num_of_data / 2:
        robot_state = jnp.array([-3.0 - 0.02 * i, 0 + 0.02*i, 0., 1.])
        human_state = jnp.array([-1., -2., jnp.pi/2., 1.])
    else:
        robot_state = jnp.array([1.0 + 0.02 * (i/2), 0 + 0.02*(i/2), jnp.pi, 1.])
        human_state = jnp.array([-1., -2., jnp.pi/2., 1.])

    # robot_state = jnp.array([-3.0, -0., np.random.uniform(-np.pi/4, np.pi/4), 1.])
    # human_state = jnp.array([-1., -2., jnp.pi/2., 1.])
    # human_state = jnp.array([-1., -2., np.random.uniform(-np.pi/4, np.pi/4), 1.])

    # robot_state = jnp.array([-3.0 + 0.02 * i, 0 + 0.02*i, 0., 1.])
    # human_state = jnp.array([-1., -2., jnp.pi/2., 1.])


    # nominal controls
    robot_nominal_controls = jnp.zeros([planning_horizon, robot.control_dim])
    # assume human wants to follow a constant velocity mode (i.e., zero control input)
    human_nominal_controls = jnp.zeros([planning_horizon, human.control_dim])

    # making lists of things for plotting later
    robot_trajectory = [robot_state]
    human_trajectory = [human_state]
    robot_controls_list = []
    human_controls_list = []
    human_samples = []
    robot_nominal_controls_list = [robot_nominal_controls]

    coeffs = [0.2, 0.1, 5., 10., 55., 5.]   # <----- try different values!

    for ti in range(num_time_steps):
        # very simple human prediction model -- just gaussian noise about a constant velocity model.
        dus = jnp.clip(jnp.array(np.random.randn(n_human_samples, planning_horizon, human.control_dim) * human_control_prediction_variance), -human_control_prediction_noise_limit, human_control_prediction_noise_limit)
        human_controls_samples = jnp.clip(human_nominal_controls + dus, min=jnp.array([acceleration_min, steering_min]), max=jnp.array([acceleration_max, steering_max]))
        human_states_samples = jax.vmap(simulate_dynamics, [None, None, 0, None])(human, human_state, human_controls_samples, dt)
        human_samples.append(human_states_samples)
        
        for t in range(num_iterations):
            temperature = 1 - (t / num_iterations)
            # sampling robot control trajectories
            dus = jnp.clip(jnp.array(np.random.randn(n_robot_samples, planning_horizon, robot.control_dim) * robot_control_noise_variance), -robot_control_noise_limit, robot_control_noise_limit)
            robot_controls_samples = jnp.clip(robot_nominal_controls + dus, min=jnp.array([acceleration_min, steering_min]), max=jnp.array([acceleration_max, steering_max]))
            # simulate robot trajectory for each control trajectory sample
            robot_states_samples = jax.vmap(simulate_dynamics, [None, None, 0, None])(robot, robot_state, robot_controls_samples, dt)
            # evaluate cost of each robot trajectory sample
            trajectory_costs = jax.vmap(evaluate_trajectory_cost, [0, 0, None, None])(robot_states_samples, robot_controls_samples, human_states_samples, coeffs)
            # weight for each trajectory sample
            weights = jax.nn.softmax(-trajectory_costs / temperature).reshape([-1, 1, 1])
            # compute new nominal control using weighted sum
            
            robot_nominal_controls = jnp.clip(robot_nominal_controls + (dus * weights).sum(0), min=jnp.array([acceleration_min, steering_min]), max=jnp.array([acceleration_max, steering_max]))
            
        # use final nominal control to step forward in time by one step
        robot_nominal_controls_list.append(robot_nominal_controls)
        robot_state = simulate_dynamics(robot, robot_state, robot_nominal_controls[:1], dt)[-1]
        human_state = simulate_dynamics(human, human_state, human_controls_samples[0][:1], dt)[-1]
        # collect the new state and controls for plotting purposes
        robot_trajectory.append(robot_state)
        human_trajectory.append(human_state)

        robot_controls_list.append(robot_nominal_controls[:1])
        human_controls_list.append(human_controls_samples[0][:1])

    # turn things into jnp.array
    robot_trajectory = jnp.stack(robot_trajectory)
    human_trajectory = jnp.stack(human_trajectory)
    human_samples = jnp.stack(human_samples)
    robot_controls_list = jnp.concatenate(robot_controls_list, 0)
    human_controls_list = jnp.concatenate(human_controls_list, 0)

    # convert the robot_traj into desire form [action, control]
    robot_trajectory = np.array(robot_trajectory)
    robot_controls_list = np.array(robot_controls_list)
    

    traj_mat = np.hstack((robot_trajectory[9:50], robot_controls_list[9:]))
    traj_mat = traj_mat.T
    print(traj_mat.shape)

    human_traj_mat = np.hstack((human_trajectory[9:50], human_controls_list[9:]))
    human_traj_mat = human_traj_mat.T

    decision_context = np.hstack((robot_trajectory[:9], human_trajectory[:9]))
    decision_context = decision_context.T

    robot_trajectory_dataset.append(traj_mat)
    past_state_data_set.append(decision_context)


    human_trajectory_dataset.append(human_traj_mat)
    robot_controls_list_dataset.append(robot_controls_list)
    human_controls_list_dataset.append(human_controls_list)

(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)
(6, 41)


In [232]:
# converting datasets to np array
robot_trajectory_dataset = np.array(robot_trajectory_dataset)
past_state_data_set = np.array(past_state_data_set)
human_trajectory_dataset = np.array(human_trajectory_dataset)
robot_controls_list_dataset = np.array(robot_controls_list_dataset)




In [233]:
# picking a history for simulation. 

random_index = np.random.randint(0, num_of_data)

ego_initial_nominal_control = human_controls_list_dataset[random_index][9]

actor_past_states = past_state_data_set[random_index][:4, :]

# used as prediction for calculating the guidance cost. 
human_future = human_trajectory_dataset[random_index]

print(human_future.shape)
ego_past_states = past_state_data_set[random_index][-4:]

print("actor_past_states: ", actor_past_states)
print("ego past state: ", ego_past_states.shape)

(6, 41)
actor_past_states:  [[1.21       1.1072052  0.9993344  0.8859301  0.767802   0.6432161
  0.50970197 0.37105933 0.22729968]
 [0.21       0.21093018 0.21430437 0.22039686 0.22625376 0.23075113
  0.2378351  0.2505655  0.26495788]
 [3.1415927  3.123495   3.0971506  3.0786908  3.105414   3.1056058
  3.0715632  3.0284922  3.0551295 ]
 [1.         1.0560086  1.1025245  1.1688643  1.1966707  1.2966707
  1.3774959  1.4072367  1.4824146 ]]
ego past state:  (4, 9)


In [234]:

# forming the dataset. 

class TrajectoryDataset(Dataset):
    def __init__(self, past_state_data, robot_trajectory_data, human_trajectory_data):
        self.past_state_data = [torch.tensor(s) for s in past_state_data]
        self.robot_trajectory = [torch.tensor(t) for t in robot_trajectory_data]
        self.human_trajectory = [torch.tensor(h) for h in human_trajectory_data]
    
    def __len__(self):
        return len(self.human_trajectory) 

    def __getitem__(self, idx):
        return {
            "robot_and_human_past_state": self.past_state_data[idx],
            "robot_trajectory": self.robot_trajectory[idx],
            "human_trajectory": self.human_trajectory[idx]
        }

# Example data

# Create dataset and dataloader
dataset = TrajectoryDataset(past_state_data_set, robot_trajectory_dataset, human_trajectory_dataset)
dataloader = DataLoader(dataset, batch_size=3, shuffle=True)


In [235]:
# From: https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing#scrollTo=buW6BaNga-XH


import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """ 
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    print(t.dtype)
    out = vals.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

def forward_diffusion_sample(x_0, t):
    """ 
    Takes an image and a timestep as input and 
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t * x_0 \
    + sqrt_one_minus_alphas_cumprod_t * noise, noise


# Define beta schedule
T = 100
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [236]:
# adpoting from: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py to include decision context as input. 

class SinusoidalPositionEmbeddings(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
    


# Unet for predicting noise
class DoubleConv(torch.nn.Module):

    def __init__(self, in_features, out_features, mid_features=None):
        super().__init__()
        if not mid_features:
            mid_features = out_features
        self.double_conv = torch.nn.Sequential(
            torch.nn.Conv1d(in_features, mid_features, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm1d(mid_features),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv1d(mid_features, out_features, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm1d(out_features),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(torch.nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.maxpool_conv = torch.nn.Sequential(
            torch.nn.MaxPool1d(2),
            DoubleConv(in_features, out_features)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
    

class Up(torch.nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_features, out_features, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of features
        if bilinear:
            self.up = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_features, out_features, in_features // 2)
        else:
            self.up = torch.nn.ConvTranspose1d(in_features, in_features // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_features, out_features)

    def forward(self, x1, x2):
        x1 = self.up(x1)
 
        diffY = x2.size()[1] - x1.size()[1]
        diffX = x2.size()[2] - x1.size()[2]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super(OutConv, self).__init__()
        self.conv = torch.nn.Conv1d(in_features, out_features, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(torch.nn.Module):
    def __init__(self, robot_feature, time_size, c_features, bilinear=False):
        # robot_feature: 6.  num of feature in robots trajectory matrix (size: 6 x 50)
        # time_size: 1.  time is the size of a time step, it is an integer
        # c_features: 4. num of feature in human trajectory matrix (size: 4 x 51)


        super(UNet, self).__init__()
        # self.n_channels = n_channels
        # self.n_classes = n_classes
        self.bilinear = bilinear

        # self.inc = (DoubleConv(robot_feature + time_size + c_features, 50))
        self.inc = (DoubleConv(robot_feature + c_features, 50))
        self.down1 = (Down(50, 100))
        self.down2 = (Down(100, 200))
        self.down3 = (Down(200, 400))
        factor = 2 if bilinear else 1
        self.down4 = (Down(400, 800 // factor))
        self.up1 = (Up(800, 400 // factor, bilinear))
        self.up2 = (Up(400, 200 // factor, bilinear))
        self.up3 = (Up(200, 100 // factor, bilinear))
        self.up4 = (Up(100, 50, bilinear))
        self.outc = (OutConv(50, robot_feature))

        self.time_mlp = torch.nn.Sequential(
                SinusoidalPositionEmbeddings(42),
                torch.nn.Linear(42, 41),
                torch.nn.ReLU()
        )
        # self.decision_context_embedded = torch.nn.Linear(9, 41)
        self.decision_context_embedded = torch.nn.Sequential(
            torch.nn.Linear(9, 41),
            torch.nn.ReLU()
        )

        self.dropout = torch.nn.Dropout(0.50)
        self.Relu = torch.nn.ReLU6()
        self.tanh = torch.nn.Tanh()

    def forward(self, x, t, c):

        # embedded time
        # t = torch.tensor([t], dtype=torch.float32).view(1, 1)
        # t_embedded = self.time_mlp(torch.tensor([t], dtype=torch.float32).view(1, 1))
        t_embedded = self.Relu(self.time_mlp(torch.tensor([t], dtype=torch.float32)))
        c_embedded = self.Relu(self.decision_context_embedded(c))
        t_embedded = t_embedded.expand(x.shape[0], 6, 41)
        # c_embedded = c_embedded.expand(x.shape[0], 6, 41)
        # t_embedded = t_embedded.expand(-1, 6, 1)


        print("t size: ", t_embedded.size())
        print("c size: ",c_embedded.size())
        x = x + t_embedded
        # x = torch.hstack((x, t_embedded, c_embedded))
        x = torch.hstack((x, c_embedded))
        # x = torch.cat((x, t_embedded, c_embedded))
        print("x after cat size is:", x.size())

        x = F.normalize(x)

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.dropout(x)
        output = self.outc(x)
        print("output shape is:" , output.size())
        # output = self.tanh(output)
        return output





        



In [237]:
# Model initialization. 
model = UNet(6, 1, 8)

In [238]:

def get_loss(model, x_0, t, c):
    x_noisy, noise = forward_diffusion_sample(x_0, t)
    noise_pred = model(x_noisy, t, c)
    return F.l1_loss(noise, noise_pred)

In [239]:

# Forward sampling function. 
@torch.no_grad()
def sample_timestep(x, t, c):
    """
    Calls the model to predict the noise in the image and returns 
    the denoised image. 
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
    # Call model (current image - noise prediction)
    print(type(x))
    print(type(c))
    print(type(t))
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t, c) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
    if t == 0:
        # As pointed out by Luis Pereira (see YouTube comment)
        # The t's are offset from the t's in the paper
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

In [240]:

# training loop. 
learning_rate = 1e-5
num_epochs = 38
# train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.95)
criterion = torch.nn.MSELoss()

for epoch in range(num_epochs):
    for idx, batch in enumerate(dataloader):
      optimizer.zero_grad()

      t = torch.randint(0, T, (1,)).long()

      # t = np.random.randint(0, T)
      print(t)
      loss = get_loss(model, batch['robot_trajectory'], t,  batch['robot_and_human_past_state'])
      print(loss)
      loss.backward()
      optimizer.step()

tensor([90])
torch.int64
torch.int64
t size:  torch.Size([3, 6, 41])
c size:  torch.Size([3, 8, 41])
x after cat size is: torch.Size([3, 14, 41])
output shape is: torch.Size([3, 6, 41])
tensor(0.9396, grad_fn=<MeanBackward0>)
tensor([59])
torch.int64
torch.int64
t size:  torch.Size([3, 6, 41])
c size:  torch.Size([3, 8, 41])
x after cat size is: torch.Size([3, 14, 41])
output shape is: torch.Size([3, 6, 41])
tensor(0.9133, grad_fn=<MeanBackward0>)
tensor([16])
torch.int64
torch.int64
t size:  torch.Size([3, 6, 41])
c size:  torch.Size([3, 8, 41])
x after cat size is: torch.Size([3, 14, 41])
output shape is: torch.Size([3, 6, 41])
tensor(0.9024, grad_fn=<MeanBackward0>)
tensor([31])
torch.int64
torch.int64
t size:  torch.Size([3, 6, 41])
c size:  torch.Size([3, 8, 41])
x after cat size is: torch.Size([3, 14, 41])
output shape is: torch.Size([3, 6, 41])
tensor(0.9445, grad_fn=<MeanBackward0>)
tensor([8])
torch.int64
torch.int64
t size:  torch.Size([3, 6, 41])
c size:  torch.Size([3, 8, 4

In [241]:
def np_arr_jnp_arr(np_arr):
    jnp_arr = jnp.zeros(len(np_arr))
    for i in range(len(np_arr)):
        jnp_arr = jnp_arr.at[i].set(np_arr[i][0])
    return jnp_arr




In [None]:
# Constants
lambda_t = 1.0  # Hyperparameter for time-to-collision (TTC)
lambda_d = 1.0  # Hyperparameter for collision distance
alpha = 1.0  # Step size for gradient-based updates


def calculate_distance(traj1, traj2):
    return torch.norm(traj1[:2, :] - traj2[:2, :], dim=0)

def collision_cost(traj_adv, traj_ego):
    distances = calculate_distance(traj_adv, traj_ego)
    return -torch.sum(distances)








In [243]:
# Calculate the guidance function cost. 

def guidance_cost(traj_ego, traj_adv, init_state, steps, dt=0.1):
    # Extract controls
    a = traj_adv[4, :]  # Acceleration
    omega = traj_adv[5, :]  # Steering rate

    # Initialize trajectory
    predicted_traj = [init_state]
    state = init_state.clone()

    for t in range(steps):
        x, y, theta, v = state[0], state[1], state[2], state[3]
        
        # Update the state based on dynamics
        x_next = x + v * torch.cos(theta) * dt
        y_next = y + v * torch.sin(theta) * dt
        theta_next = theta + omega[t] * dt
        v_next = v + a[t] * dt
        
        # Create the next state vector
        next_state = torch.tensor([x_next, y_next, v_next, theta_next])
        
        # Append to the trajectory and update current state
        predicted_traj.append(next_state)
        state = next_state

    predicted_traj =  torch.stack(predicted_traj, dim=1)

    # calculate the cost
    Jcol = collision_cost(predicted_traj[:, :-1], traj_ego)
    return Jcol
    

    # return torch.stack(predicted_traj, dim=1)

In [None]:
# simulation
planning_horizon = 25 # planning horizon to compute cost over
n_ego_samples = 64 # number of human future trajectories to sample
n_robot_samples = 32 # number of robot trajectories to sample for MPPI
dt = 0.1 # timestep size
# K = 100 # number of diffusion steps
num_time_steps = 50 # number of timesteps to simulate
ego_control_prediction_noise_limit = 0.25
ego_control_prediction_variance = 0.25
robot_control_noise_limit = 0.25
robot_control_noise_variance = 0.25
guidance_step = 10 # num_of steps
guidance_weight = 20.

model.eval()

ego = DynamicallyExtendedSimpleCar() # robot dynamics
actor = DynamicallyExtendedSimpleCar() # human dynamics
# actor2 = DynamicallyExtendedSimpleCar()

# initial states
actor_state = actor_past_states[:, -1:]
ego_state = ego_past_states[:, -1:]

# convert initial states to jnp array

actor_state = np_arr_jnp_arr(actor_state)
ego_state = np_arr_jnp_arr(ego_state)

# nominal controls
ego_nominal_controls =  jnp.zeros([planning_horizon, ego.control_dim])

# initialize trajectory
actor_trajectory = []
actor2_trajectory = []
ego_trajectory = [] 

for i in range(9):
    # print(ego_past_states[:, i])
    past_states = actor_past_states[:, i]
    jrr = jnp.array([past_states[0], past_states[1], past_states[2], past_states[3]])
    actor_trajectory.append(jrr)



for i in range(9):
    past_states = ego_past_states[:, i]
    jrr = jnp.array([past_states[0], past_states[1], past_states[2], past_states[3]])
    ego_trajectory.append(jrr)




ego_controls_list = []
actor_controls_list = []
actor2_controls_list = []
ego_samples = []
actor_nominal_controls_list = []

coeffs = [0.2, 0.1, 5., 2., 15., 5.]   # <----- try different values!





In [None]:
actor_hist = actor_trajectory[-9:]
# actor2_hist = actor2_trajectory[-9:]
ego_hist = ego_trajectory[-9:]

actor_hist = np.array(actor_hist)
# actor2_hist = np.array(actor2_hist)
ego_hist = np.array(ego_hist)

actor_hist = actor_hist.T
# actor2_hist = actor2_hist.T
ego_hist = ego_hist.T
    # forming decision context
c = np.vstack((actor_hist, ego_hist))
c = np.float32(c)

# c2 = np.vstack((actor2_hist, ego_hist))
# c2 = np.float32(c2)


c = torch.from_numpy(c).unsqueeze(0)

t_K = torch.randn_like(torch.from_numpy(traj_mat))
for k in range(T - 1, 0, -1):
    k = torch.Tensor([k]).to(torch.int64)
    traj_hat = sample_timestep(t_K.unsqueeze(0), k, c)
    t_K = traj_hat.squeeze(0)

# do a guidance after the diffusion
t_K.requires_grad_()
cost = guidance_cost(t_K, torch.from_numpy(human_future), torch.from_numpy(actor_hist[:, -1].T), 41, 0.1)

# Compute gradients of the cost with respect to the adversarial trajectory
cost.backward()  # Backpropagation

# Update trajectory using gradient descent
with torch.no_grad():  # Disable gradient tracking for updates
    t_K -= alpha * t_K.grad * betas[-1]




for ti in range(num_time_steps):


    # very simple human prediction model -- just gaussian noise about a constant velocity model.
    dus = jnp.clip(jnp.array(np.random.randn(n_ego_samples, planning_horizon, ego.control_dim) * ego_control_prediction_variance), -ego_control_prediction_noise_limit, ego_control_prediction_noise_limit)
    ego_controls_samples = jnp.clip(ego_nominal_controls + dus, min=jnp.array([acceleration_min, steering_min]), max=jnp.array([acceleration_max, steering_max]))
    ego_states_samples = jax.vmap(simulate_dynamics, [None, None, 0, None])(ego, ego_state, ego_controls_samples, dt)
    ego_samples.append(ego_states_samples)
    
    if(ti < 41):
    # extract controls from the denoised traj
        actor_nominal_controls = t_K.T[ti][-2:]
        actor_nominal_controls = actor_nominal_controls.detach().numpy()
        # actor_nominal_controls = actor_nominal_controls[np.newaxis, :]
        # actor_nominal_controls = np.array([[actor_nominal_controls[0], actor_nominal_controls[1]]])
        actor_nominal_controls = np.array([[actor_nominal_controls[0], actor_nominal_controls[1]]])

    else:
        actor_nominal_controls = np.array([[0., 0.]])
        # actor2_nominal_controls = np.array([[0., 0.]])

    # actor_nominal_controls = actor_nominal_controls /1e19
    print("nominal control is:", actor_nominal_controls)
    # update both ego_state and actor state

    actor_nominal_controls_list.append(actor_nominal_controls)
    actor_state = simulate_dynamics(actor, actor_state, actor_nominal_controls, dt)[-1]
    actor_trajectory.append(actor_state)



    ego_state = simulate_dynamics(ego, ego_state, ego_controls_samples[0][:1], dt)[-1]
    ego_trajectory.append(ego_state)

    
    actor_controls_list.append(actor_nominal_controls[:1])
    ego_controls_list.append(ego_controls_samples[0][:1])

    # print(type(actor_trajectory))
    


actor_trajectory = jnp.stack(actor_trajectory) 
ego_trajectory = jnp.stack(ego_trajectory)
ego_samples = jnp.stack(ego_samples)
actor_controls_list = jnp.concatenate(actor_controls_list, 0)
ego_controls_list = jnp.concatenate(ego_controls_list, 0)







torch.int64
torch.int64
torch.int64
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
t size:  torch.Size([1, 6, 41])
c size:  torch.Size([1, 8, 41])
x after cat size is: torch.Size([1, 14, 41])
output shape is: torch.Size([1, 6, 41])
torch.int64
torch.int64
torch.int64
torch.int64
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
t size:  torch.Size([1, 6, 41])
c size:  torch.Size([1, 8, 41])
x after cat size is: torch.Size([1, 14, 41])
output shape is: torch.Size([1, 6, 41])
torch.int64
torch.int64
torch.int64
torch.int64
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
t size:  torch.Size([1, 6, 41])
c size:  torch.Size([1, 8, 41])
x after cat size is: torch.Size([1, 14, 41])
output shape is: torch.Size([1, 6, 41])
torch.int64
torch.int64
torch.int64
torch.int64
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
t size:  torch.Size([1, 6, 41])
c size:  torch.Size([1, 8, 41])
x after cat size is: torch.Size([

In [246]:
print(actor_controls_list)

[[-0.44380197  2.6355915 ]
 [ 0.44575864 -0.86775434]
 [ 7.400762   -0.39887813]
 [-3.1629848  -5.2904167 ]
 [-2.0438373  -3.311776  ]
 [-1.8574125  -0.57980657]
 [-0.10201012  0.04044207]
 [-0.25121862  1.8098193 ]
 [-0.60293555  0.3534741 ]
 [ 0.54684335  1.1396576 ]
 [ 2.2218707  -1.9021851 ]
 [-1.0470226   2.1873844 ]
 [ 2.8846326   1.7082258 ]
 [ 0.6935619  -2.8105428 ]
 [-0.03141943 -0.23449333]
 [ 0.15977652  3.0804803 ]
 [-1.6798009  -0.18413933]
 [ 1.9563813  -5.0715876 ]
 [-2.4823916  -3.670114  ]
 [ 2.742667    3.0534825 ]
 [ 3.885713   -4.02361   ]
 [-7.8850384  -1.7495539 ]
 [ 1.7982607  -1.9125365 ]
 [-1.1551977  -4.6857023 ]
 [ 5.352468   -2.5532942 ]
 [-2.1979795  -2.721492  ]
 [ 1.1990325   0.69400465]
 [-1.0307146  -1.9418298 ]
 [-2.2916489   0.93652576]
 [ 2.2285962  -3.588104  ]
 [-0.3012235  -0.08193616]
 [ 1.4099929   0.4263298 ]
 [-1.3897403  -3.2139916 ]
 [-0.11712819  1.0543668 ]
 [ 0.22934791 -4.8204155 ]
 [-1.9013462  -0.04779878]
 [ 2.8692763   0.87291604]
 

In [247]:
@interact(i=(0,num_time_steps-1))
def plot(i):
    fig, axs = plt.subplots(1,2, figsize=(18,8))
    ax = axs[0]
    actor_position = actor_trajectory[i, :2]
    # actor2_position = actor2_trajectory[i, :2]
    ego_position = ego_trajectory[i, :2]
    circle1 = plt.Circle(actor_position, radius / 2, color='C0', alpha=0.4)
    # circle1_ = plt.Circle(actor2_position, radius / 2, color='C2', alpha = 0.4)
    circle2 = plt.Circle(ego_position, radius / 2, color='C1', alpha=0.4)
    ax.add_patch(circle1)
    # ax.add_patch(circle1_)
    ax.add_patch(circle2)
    ax.plot(ego_samples[i,:,:,0].T, ego_samples[i,:,:,1].T, "o-", alpha=0.1, markersize=2, color='C1')
    ax.plot(actor_trajectory[:,0], actor_trajectory[:,1], "o-", markersize=3, color='C0')
    # ax.plot(actor2_trajectory[:,0], actor2_trajectory[:,1], "o-", markersize=3, color='C2')
    ax.plot(ego_trajectory[:,0], ego_trajectory[:,1], "o-", markersize=3, color='C1')
    ax.scatter(actor_trajectory[i:i+1,0], actor_trajectory[i:i+1,1], s=30, color='C0', label="Actor")
    # ax.scatter(actor2_trajectory[i:i+1,0], actor2_trajectory[i:i+1,1], s=30, color='C2', label="Actor2")
    ax.scatter(ego_trajectory[i:i+1,0], ego_trajectory[i:i+1,1], s=30, color='C1', label="Ego")
    ax.grid()
    ax.legend()
    ax.axis("equal")
    ax.set_xlim([-4,4])
    ax.set_ylim([-3, 6])

    ax.set_title("heading=%.2f velocity=%.2f" % (actor_trajectory[i,2], actor_trajectory[i,3]))
    
    ax = axs[1]
    plt.plot(actor_controls_list)
    plt.scatter([i], actor_controls_list[i:i+1, 0], label="Acceleration")
    plt.scatter([i], actor_controls_list[i:i+1, 1], label="Steering")
    ax.plot(actor_trajectory[:,-1], "o-", markersize=3, color='C0', label="Velocity")

    ax.legend()
    ax.grid()


interactive(children=(IntSlider(value=24, description='i', max=49), Output()), _dom_classes=('widget-interact'â€¦