## Train surrogate model

With the previously generated data, train a surrogate model (Gaussian process or neural network).

## Model training

We use data loaded from a file or generated by a function to train a perception surrogate $g(\Delta x)$.
In this case $\Delta x$ is defined as the relative position (distance) between a tree and the drone: $\Delta x = \| x - t_i \|$

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [None]:
nn_input_dim = 3  # Surrogate input size: drone (x, y, yaw) relative to the tree; the model sin/cos-encodes the yaw internally

In [None]:
def gausspdf(x, mu, sigma):
    """Evaluate the normal density N(mu, sigma^2) at ``x`` (scalar or NumPy array)."""
    norm_const = 1 / (sigma * np.sqrt(2 * np.pi))
    z = (x - mu) / sigma
    return norm_const * np.exp(-0.5 * z ** 2)

# Quick visual check of gausspdf: the N(4, 1.3^2) density lifted by 0.5.
x = np.linspace(0, 10, 100)
y = gausspdf(x, 4, 1.3) + 0.5
fig, ax = plt.subplots()
ax.plot(x, y)
fig.show()

In [None]:
import torch
import numpy as np
from scipy.stats import norm  # For Gaussian PDF
def fake_confidence(drone_pos, tree_pos, include_yaw=True, fov=20, *,
                    max_range=10, peak_distance=2.5, peak_width=1):
    """Simulate the maturity-confidence output of the perception NN for one tree.

    The confidence is a Gaussian bump over the drone-tree distance, offset by
    0.5 (the uninformative value of a binary belief):
    ``norm.pdf(distance, loc=peak_distance, scale=peak_width) + 0.5``.
    It is exactly 0.5 when the tree is (almost) at the drone position, farther
    than ``max_range``, or, when ``include_yaw`` is set, outside the FOV.

    Args:
        drone_pos: Drone pose; elements 0-1 are the planar position and
            element 2 the yaw in radians (only read when ``include_yaw``).
            Extra trailing elements are ignored.
        tree_pos: Planar tree position (x, y).
        include_yaw: Whether to apply the field-of-view test.
        fov: Half-angle of the field of view, in degrees.
        max_range: Trees farther than this get the uninformative 0.5.
        peak_distance: Distance at which the confidence peaks.
        peak_width: Standard deviation of the Gaussian bump.

    Returns:
        A float confidence >= 0.5.
    """
    drone_pos = np.asarray(drone_pos)
    direction = np.asarray(tree_pos) - drone_pos[:2]
    distance = np.linalg.norm(direction)

    # Degenerate or out-of-range: uninformative. Checked before normalizing
    # `direction` so a zero distance no longer triggers a 0/0 division.
    if distance < 0.001 or distance > max_range:
        return 0.5

    if include_yaw:
        theta = drone_pos[2]
        drone_forward = np.array([np.cos(theta), np.sin(theta)])
        cos_angle = (direction / distance) @ drone_forward
        # cos is decreasing on [0, pi]: angle > fov  <=>  cos(angle) < cos(fov)
        if cos_angle < np.cos(np.deg2rad(fov)):  # Tree outside the FOV
            return 0.5

    return norm.pdf(distance, loc=peak_distance, scale=peak_width) + 0.5

def generate_fake_dataset(
    samples_xy, 
    samples_yaw, 
    x_dimension, 
    use_fov=True,
    x_low=-8,          # Lower bound for drone X position
    x_high=8,          # Upper bound for drone X position
    y_low=-8,          # Lower bound for drone Y position
    y_high=8,          # Upper bound for drone Y position
    tree_low=-8,       # Lower bound for tree position
    tree_high=8        # Upper bound for tree position
):
    """Build a synthetic (features, confidence) dataset on a regular drone-pose grid.

    For every (x, y) grid point, one sample is drawn per yaw in
    ``linspace(-pi, pi, samples_yaw)`` when ``use_fov`` is True. Otherwise
    exactly one sample is drawn, with yaw ``-arctan2(y, x)``; that yaw does not
    affect the label because the FOV test is disabled. Each sample draws a tree
    uniformly from ``[tree_low, tree_high)`` per coordinate and labels it with
    ``fake_confidence``.

    Returns:
        (X, Y) float tensors. Each row of X is ``pose[:x_dimension]`` followed
        by the tree (x, y); Y holds one confidence per row.
    """
    grid_x = np.linspace(x_low, x_high, samples_xy)
    grid_y = np.linspace(y_low, y_high, samples_xy)
    grid_yaw = np.linspace(-np.pi, np.pi, samples_yaw)

    features, targets = [], []
    for px in grid_x:
        for py in grid_y:
            headings = grid_yaw if use_fov else (-np.arctan2(py, px),)
            for heading in headings:
                tree = np.random.uniform(low=tree_low, high=tree_high, size=(1, 2)).flatten()
                pose = np.array([px, py, heading])
                label = fake_confidence(pose, tree, use_fov)
                features.append(np.concatenate((pose[:x_dimension], tree)))
                targets.append(label)

    return torch.Tensor(np.array(features)), torch.Tensor(np.array(targets))
use_fov = True

# Training data: drone poses on a 40x40 (x, y) grid with 10 yaw samples each.
# The tree is fixed at the origin (tree_low == tree_high == 0), so the drone
# (x, y) columns are also the drone-minus-tree relative position.
X, Y = generate_fake_dataset(
    samples_xy=40,
    samples_yaw=10,
    x_dimension=3,
    use_fov=use_fov,
    x_low=-10,        # Drone X ranges from -10 to 10
    x_high=10,
    y_low=-10,        # Drone Y ranges from -10 to 10
    y_high=10,
    tree_low=0,       # Tree fixed at the origin
    tree_high=0
)

print("Features shape:", X.shape)
print("Targets shape:", Y.shape)


In [None]:
hidden_size = 64
hidden_layers = 3


class MultiLayerPerceptron(torch.nn.Module):
    """Fully connected perception surrogate g(.) producing one confidence value.

    When the input has 3 features, the last one is treated as an angle (yaw)
    and replaced by (sin, cos), so the angle is encoded continuously; the first
    layer then takes 4 inputs. Attribute names (``input_layer``,
    ``hidden_layer``, ``out_layer``) are kept stable so saved ``state_dict``
    checkpoints still load.

    Args:
        input_dim: Number of raw input features.
        hidden_dim: Width of every hidden layer (defaults to ``hidden_size``).
        num_hidden: Number of tanh hidden layers (defaults to ``hidden_layers``).
    """

    def __init__(self, input_dim, hidden_dim=hidden_size, num_hidden=hidden_layers):
        super().__init__()
        # 3-feature inputs grow to 4 through the sin/cos encoding in forward().
        in_features = input_dim + 1 if input_dim == 3 else input_dim
        self.input_layer = torch.nn.Linear(in_features, hidden_dim)
        self.hidden_layer = torch.nn.ModuleList(
            torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden)
        )
        self.out_layer = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Map inputs of shape ``(..., input_dim)`` to outputs of shape ``(..., 1)``."""
        if x.shape[-1] == 3:
            angle = x[..., -1:]
            x = torch.cat([x[..., :-1], torch.sin(angle), torch.cos(angle)], dim=-1)
        # The input layer must run for every input size, not only the 3-feature path.
        x = self.input_layer(x)
        for layer in self.hidden_layer:
            x = torch.tanh(layer(x))
        return self.out_layer(x)


In [None]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import datetime


nn_input_dim = 3
train = False  # Set to True to (re)train; otherwise later cells load the latest saved checkpoint
batch_size = 1
lr = 1e-4
epochs = 4
validation_split = 0.2

# Directory to save models
model_dir = f"saved_models_{nn_input_dim}d"
os.makedirs(model_dir, exist_ok=True)

# Prepare dataset: keep the drone (x, y, yaw) columns. The dropped columns are
# the tree (x, y), which the dataset cell generates fixed at the origin.
X = X[:, :3]
dataset = TensorDataset(X, Y)

if train:
    # Train-validation split
    num_train_samples = int((1 - validation_split) * len(dataset))
    num_val_samples = len(dataset) - num_train_samples
    train_dataset, val_dataset = random_split(dataset, [num_train_samples, num_val_samples])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiLayerPerceptron(input_dim=nn_input_dim)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(model)

    # Set up TensorBoard logging
    log_dir = f"runs/fcn_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    writer = SummaryWriter(log_dir)

    try:
        # Training loop with validation
        best_val_loss = float('inf')

        for epoch in range(epochs):
            model.train()

            # Train step
            train_loss = 0
            for x_batch, y_batch in train_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                y_pred = model(x_batch)
                # The model outputs (B, 1) and the targets are (B,). Reshape the
                # targets so MSELoss compares element-wise instead of
                # broadcasting to (B, B) whenever batch_size > 1.
                loss = criterion(y_pred, y_batch.view_as(y_pred))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            avg_train_loss = train_loss / len(train_loader)

            # Validation step
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for x_batch, y_batch in val_loader:
                    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                    y_pred = model(x_batch)
                    val_loss += criterion(y_pred, y_batch.view_as(y_pred)).item()

            avg_val_loss = val_loss / len(val_loader)

            print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

            # Log to TensorBoard
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
            writer.add_scalar('Loss/val', avg_val_loss, epoch)

            # Save the model only when validation loss improves, so the highest
            # saved epoch number of a run is its best checkpoint.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                model_path = os.path.join(model_dir, f"best_model_epoch_{epoch+1}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model saved to {model_path}")
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

In [None]:
import re
_BEST_MODEL_PATTERN = re.compile(r"best_model_epoch_(\d+)\.pth")


def get_latest_best_model(model_dir):
    """Return the path of the checkpoint with the highest epoch number in *model_dir*.

    Only files named exactly ``best_model_epoch_<N>.pth`` are considered, so
    backups such as ``...pth.bak`` are ignored. The training loop saves a
    checkpoint only when validation loss improves, so within one run the
    highest epoch is the best one.

    NOTE(review): checkpoints left over from earlier runs in the same directory
    are compared too; clear the directory between runs if that matters.

    Raises:
        FileNotFoundError: if no matching checkpoint exists in *model_dir*.
    """
    epochs_by_name = {}
    for name in os.listdir(model_dir):
        match = _BEST_MODEL_PATTERN.fullmatch(name)
        if match:
            epochs_by_name[name] = int(match.group(1))
    if not epochs_by_name:
        raise FileNotFoundError(f"No 'best_model_epoch_<N>.pth' checkpoint found in {model_dir!r}")
    latest_model = max(epochs_by_name, key=epochs_by_name.get)
    return os.path.join(model_dir, latest_model)

In [None]:
# Compare ground truth and surrogate predictions on a freshly generated dataset.
fig = make_subplots(
    rows=1, cols=2,  # 1 row, 2 columns
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],  # Both subplots are 3D scatter plots
    subplot_titles=("Ground Truth", "Predicted")  # Titles for each subplot
)

X, Y = generate_fake_dataset(
    samples_xy=40,
    samples_yaw=10,
    x_dimension=3,
    use_fov=True,
    x_low=-8,         # Drone X ranges from -8 to 8
    x_high=8,
    y_low=-8,         # Drone Y ranges from -8 to 8
    y_high=8,
    tree_low=0,       # Tree fixed at the origin
    tree_high=0
)
X = X[:, :3]  # Same (x, y, yaw) features the model was trained on

# Marker style shared by both subplots; colour encodes the plotted value.
marker_style = dict(colorscale='Viridis', colorbar=dict(title='Z Value'), size=5, opacity=0.8)
x_plot = X[:, 0].numpy()
y_plot = X[:, 1].numpy()
y_true = Y.numpy().flatten()

fig.add_trace(go.Scatter3d(
    x=x_plot, y=y_plot, z=y_true,
    name='Ground Truth',
    mode='markers',  # Scatter3d defaults to lines+markers; match the "Predicted" subplot
    marker=dict(marker_style, color=y_true),
), row=1, col=1)

model = MultiLayerPerceptron(input_dim=nn_input_dim)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(get_latest_best_model(model_dir), map_location="cpu"))
model.eval()

# Model prediction
with torch.no_grad():
    y_test = model(X).detach().numpy()

fig.add_trace(go.Scatter3d(
    x=x_plot, y=y_plot, z=y_test.flatten(),
    name='Predicted',
    mode='markers',
    marker=dict(marker_style, color=y_test.flatten()),
), row=1, col=2)

axis_titles = dict(xaxis_title="x1", yaxis_title="x2", zaxis_title="Output")
fig.update_layout(
    title_text="Ground Truth vs Predicted",
    scene=axis_titles,   # first subplot
    scene2=axis_titles,  # second subplot
    showlegend=False
)

fig.show()

## Define the working environment

The problem definition is as follows:

$x$ state of the robot defined by the position of the drone and its velocity [x, y ,z, vx, vy, vz]

$f(x, u)$ state transition function for the drone

$u$ acceleration commands to the drone [ax, ay, az]

$t \in T$ tree positions for the $T$ trees $[[t_{x,1}, t_{y,1}], \dots, [t_{x,T}, t_{y,T}]]$

$\lambda$ belief in the trees' maturity confidence $[\lambda_1, \dots, \lambda_T]$ (values from 0 to 1). For practicality, it can be seen as part of $x$.

$z$ observation vector for the tree maturity confidence $[z_1, \dots, z_T]$ (values from 0 to 1)

$g(\Delta x)$ observation surrogate. It is applied to every tree.

$b(\lambda, z)$ bayesian update to the previous belief.

$H(\lambda)$ entropy function for the belief defined (for the case of binary distribution) as: $-\lambda \log{\lambda} - (1-\lambda) \log(1-\lambda)$.

$J(\lambda)$ the cost function of the MPC defined as: $\sum_{i=1}^{T} \delta_1 H(\lambda_i) \cdot \delta_2 \Delta x_i^2 + \delta_3 \|u\|$. The terms aim to reduce the entropy of each tree (using $\Delta x_i^2$ to guide the planner when there is no observation) and to reduce the control inputs.

The steps of the system are as follows:
1. Load learned $g()$ which works for one tree.
2. Initialize $x$ in a $x_0$ position, $\lambda$ with $0.5$ values for each tree, and $t$ as known.
3. Run the NMPC from $x$ for $N$ iterations. In each step:
  - Compute $\Delta x$ for each tree with the new drone $x$.
  - Get estimation from NN for each tree: $z = g(\Delta x)$.
  - Fuse the estimation in $\lambda$ for each tree: $\lambda_{k} = b(\lambda_{k-1}, z)$
4. Apply the solution from the MPC.
5. Get a real observation.
6. Integrate the real observation into $\lambda$.
7. Go back to step 3.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import casadi as ca
# Define the entropy function
def entropy(lambda_val):
    """Binary (Bernoulli) entropy, in bits, of the belief ``lambda_val``.

    Built from CasADi's ``log10`` so it accepts symbolic MX expressions as well
    as numeric inputs; dividing by ``log10(2)`` converts the result to base 2.
    The value is undefined (log of 0) at exactly 0 and 1.
    """
    p = lambda_val
    q = 1 - lambda_val
    return (-p * ca.log10(p) - q * ca.log10(q)) / ca.log10(2)

# Sample lambda strictly inside (0, 1): the entropy would take log(0) at the endpoints.
x = np.linspace(0.01, 0.99, 100)
entropy_vals = entropy(x)

fig_entropy, ax_entropy = plt.subplots(figsize=(8, 6))
ax_entropy.plot(x, entropy_vals, label="Entropy", color="blue")
ax_entropy.set_title("Entropy Function")
ax_entropy.set_xlabel("Lambda (λ)")
ax_entropy.set_ylabel("Entropy")
ax_entropy.grid(True)
ax_entropy.legend()
plt.show()

In [None]:
import numpy as np
import casadi as ca
import l4casadi as l4c
import time

# Constants
T = 1.0  # Nominal duration of the MPC horizon
N = 10  # Number of MPC steps over the horizon (main() uses it as the horizon length)
dt = T / N  # Length of one step
nx = 3  # Number of pose components (x, y, theta); the full state stacks pose and velocities (vx, vy, omega), i.e. 2*nx entries

def generate_tree_positions(grid_size, spacing):
    """Lay trees out on a regular ``grid_size[0] x grid_size[1]`` grid.

    Each row of the result is an (x, y) position; x varies fastest. Every
    coordinate is a multiple of ``spacing`` shifted by +0.1.
    """
    columns = np.arange(0, grid_size[0] * spacing, spacing)
    rows = np.arange(0, grid_size[1] * spacing, spacing)
    grid_x, grid_y = np.meshgrid(columns, rows)
    positions = np.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
    return positions + 0.1

def get_domain(tree_positions):
    """Return the axis-aligned bounding box of the trees as ``([x_min, y_min], [x_max, y_max])``."""
    xs = tree_positions[:, 0]
    ys = tree_positions[:, 1]
    lower = [np.min(xs), np.min(ys)]
    upper = [np.max(xs), np.max(ys)]
    return lower, upper

def kin_model(T=1.0, N=20):
    """Build the discrete-time kinematic model as a CasADi function ``F(x, u) -> x_next``.

    State ``x`` (2*nx entries): pose (x, y, theta) followed by velocities
    (vx, vy, omega). Control ``u`` (nx entries): accelerations
    (acc_x, acc_y, angular_acc). One explicit-Euler step of length ``T / N``
    is taken: the pose advances by the velocities and the velocities by the
    accelerations.

    Args:
        T: Horizon duration.
        N: Number of steps in the horizon; the step length is ``T / N``.
            mpc_opt calls this with T=1.0 and N=steps, which equals the
            module-level ``dt`` for the default horizon of 10 steps.
    """
    # Use the arguments instead of silently ignoring them in favour of the global dt.
    step = T / N
    x = ca.MX.sym('x', nx*2)
    u = ca.MX.sym('u', nx)  # Control: [acc_x, acc_y, angular_acc]
    xf = x + step * ca.vertcat(x[nx:], u)
    return ca.Function('F', [x, u], [xf])

# Bayesian update function
def bayes(lambda_prev, z):
    """Fuse an observation confidence ``z`` into the prior belief ``lambda_prev`` (binary Bayes rule).

    posterior = lambda*z / (lambda*z + (1 - lambda)*(1 - z)). The operation is
    element-wise, so it is applied to NumPy arrays and CasADi expressions as
    well as scalars.
    """
    support = lambda_prev * z
    against = (1 - lambda_prev) * (1 - z)
    return support / (support + against)


# MPC optimization using CasADi
def mpc_opt(g_nn, trees, lb, ub, x0, lambda_vals, steps=10):
    """Build and solve the perception-aware MPC once, then export it as a CasADi function.

    Decision variables are the state trajectory ``X`` (2*nx x steps+1) and the
    controls ``U`` (nx x steps). Along the horizon, the surrogate ``g_nn`` is
    evaluated for every tree on (drone position - tree position, heading). Its
    output is floored at 0.5 and fused into the tree beliefs with ``bayes``.
    The objective sums the per-step change in total belief entropy, which
    telescopes to final minus initial entropy, plus small control-effort
    penalties.

    Args:
        g_nn: CasADi-callable surrogate (an L4CasADi wrapper in ``main``).
        trees: array of tree positions, one (x, y) row per tree.
        lb, ub: lower/upper corners of the tree bounding box. Position bounds
            are this box enlarged by 3 on each side.
        x0: initial state (2*nx entries).
        lambda_vals: initial beliefs, one per tree.
        steps: number of MPC steps in the horizon.

    Returns:
        (mpc_step, u0, X, w, lam_g). ``mpc_step`` maps (parameters, primal
        variables, constraint multipliers) to (first control, state
        trajectory, primal variables, constraint multipliers); ``main`` feeds
        the previous solution back in. The remaining values come from this
        first solve, as ``DM``.
    """
    opti = ca.Opti()
    F_ = kin_model(T=1.0, N=steps)

    # Parameters: initial state followed by the initial belief of every tree.
    P0 = opti.parameter(nx*2 + len(trees))
    X = opti.variable(nx*2, steps + 1)
    U = opti.variable(nx, steps)

    opti.subject_to(X[:, 0] == P0[:nx*2])
    # Column j of lambda_k holds the predicted beliefs after j horizon steps.
    lambda_k = P0[nx*2:]

    for i in range(steps):
        # State constraints
        opti.subject_to(opti.bounded(lb[0]-3, X[0, i+1], ub[0]+3))
        opti.subject_to(opti.bounded(lb[1]-3, X[1, i+1], ub[1]+3))
        opti.subject_to(opti.bounded(-10, X[3, i+1], 10))  # Velocity constraints
        opti.subject_to(opti.bounded(-10, X[4, i+1], 10))  # Velocity constraints
        opti.subject_to(opti.bounded(-3.14, X[5, i+1], 3.14))  # Angular velocity
        opti.subject_to(opti.bounded(-10, U[:2, i], 10))  # Acceleration constraints
        opti.subject_to(opti.bounded(-3.14, U[2, i], 3.14))  # Angular acceleration
        opti.subject_to(X[:, i+1] == F_(X[:, i], U[:, i]))

        # Calculate relative positions and distances
        relative_pos = ca.repmat(X[:2, i+1].T, trees.shape[0], 1) - trees
        squared_dist = ca.sum2(relative_pos**2)  # Squared drone-tree distance per tree
        # NOTE(review): within_range is never used, so unlike the entropy replay in
        # plot_animated_trajectory_and_entropy_2d, trees beyond 10 units are not gated here.
        within_range = squared_dist <= 100

        # Neural network input and prediction
        heading = ca.repmat(X[2, i+1], trees.shape[0], 1)
        nn_input = ca.horzcat(relative_pos, heading)
        g_out = g_nn(nn_input)
        # NOTE(review): already_seen is only referenced by the commented-out gating below.
        already_seen = (lambda_k[:,-1] < 0.9)
        z_k = ca.fmax(g_out, 0.5) #* already_seen + 0.5 * (1- already_seen)

        # Bayesian update
        bayes_update = bayes(lambda_k[:,-1], z_k)
        lambda_k = ca.horzcat(lambda_k, bayes_update)

    # Objective function: Minimize the entropy
    obj = 0
    for i in range(steps):
        # Change in summed belief entropy between consecutive horizon steps
        obj += ca.sum1(entropy(lambda_k[:, i+1]) - entropy(lambda_k[:, i]))

    obj += 1e-5 * ca.sumsqr(U[:2,:]) + 1e-8 * ca.sumsqr(U[2,:])
    opti.minimize(obj)
    options = {"ipopt": {"tol":1e-5, "warm_start_init_point" : 'no', "hessian_approximation": "limited-memory", "print_level":0, "sb": "no", "mu_strategy": "monotone", "max_iter":500}} #reduced print level
    opti.solver('ipopt', options)

    # Initial step: bind the parameters and solve once
    opti.set_value(P0, ca.vertcat(x0, lambda_vals))

    sol = opti.solve()

    # Export the problem so later iterations can re-solve with new parameters
    inputs = [P0, opti.x,opti.lam_g]
    outputs = [U[:,0], X, opti.x, opti.lam_g]
    mpc_step = opti.to_function('mpc_step',inputs,outputs)

    return mpc_step, ca.DM(sol.value(U[:,0])), ca.DM(sol.value(X)), ca.DM(sol.value(opti.x)), ca.DM(sol.value(opti.lam_g))

# Main loop
def main(max_iterations=700, entropy_threshold=0.5):
    """Run the closed-loop perception-aware MPC simulation.

    Each iteration:
      1. logs the current beliefs, summed entropy and planned trajectory;
      2. moves the drone to the second state of the planned trajectory;
      3. fuses a simulated observation (``fake_confidence`` with a 25 degree
         half-FOV) into the tree beliefs;
      4. re-solves the MPC, feeding back the previous primal/dual solution.

    Args:
        max_iterations: Upper bound on closed-loop iterations.
        entropy_threshold: Stop once the summed belief entropy drops below it.

    Returns:
        (all_trajectories, entropy_history, lambda_history, durations, g_nn, trees, lb, ub)
    """
    # Step 2: Initialize x in a x_0 position, lambda with 0.5 values for each tree, and t as known.
    trees = generate_tree_positions([5, 5], 4)
    lb, ub = get_domain(trees)
    lambda_k = ca.DM.ones(len(trees)) * 0.5
    mpc_horizon = N

    # Step 1: Load learned g(.) which works for one tree.
    # Fall back to the CPU so the simulation also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MultiLayerPerceptron(input_dim=nn_input_dim)
    model.load_state_dict(torch.load(get_latest_best_model(model_dir), map_location="cpu"))
    model.eval()
    g_nn = l4c.L4CasADi(model, generate_jac_jac=True, batched=True, device=device)

    # Initialize robot state: pose (1, 1, 1 rad) and zero velocities
    x0 = ca.vertcat(ca.DM.ones(nx), ca.DM.zeros(nx))

    # Log
    all_trajectories = []
    lambda_history = []
    entropy_history = []
    durations = []

    # Step 2.5: Initialize mpc (first full solve; later iterations reuse mpc_step)
    mpc_step, u, x_, x, lam = mpc_opt(g_nn, trees, lb, ub, x0, lambda_k, mpc_horizon)

    iteration = 0
    while iteration < max_iterations:
        iteration += 1

        # Logging phase
        lambda_history.append(lambda_k.full().flatten().tolist())
        entropy_history.append(ca.sum1(entropy(lambda_k)).full().flatten().tolist()[0])
        all_trajectories.append(x_[:nx, :].full())

        # Step 4: Apply command to the drone (jump to the next planned state)
        x0 = x_[:, 1].full().flatten()

        # Step 5: Get a real observation (simulated)
        z_k = ca.vertcat(*[fake_confidence(x0, tree, fov=25) for tree in trees])

        # Step 6: Integrate observation into lambda
        lambda_k = bayes(lambda_k, z_k)

        # Step 3: Run MPC, seeded with the previous primal/dual solution
        start_time = time.perf_counter()  # monotonic, high-resolution timer
        u, x_, x, lam = mpc_step(ca.vertcat(x0, lambda_k), x, lam)
        durations.append(time.perf_counter() - start_time)
        entropy_k = ca.sum1(entropy(lambda_k)).full().flatten()[0]

        print(f"Iteration {iteration}: x={x0}, Weights sum = {entropy_k}, {lambda_k}")  # log the position and summed entropy
        if entropy_k < entropy_threshold:
            break

    return all_trajectories, entropy_history, lambda_history, durations, g_nn, trees, lb, ub

# Run the closed-loop simulation; its outputs feed the animation cell below.
all_trajectories, entropy_history, lambda_history, durations, g_nn, trees, lb, ub = main()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def _yaw_arrows(xs, ys, thetas, color):
    """Return plotly arrow annotations of length 0.5 starting at (x, y) and pointing along yaw theta."""
    arrows = []
    for x0, y0, theta in zip(xs, ys, thetas):
        arrows.append(go.layout.Annotation(dict(
            x=x0 + 0.5 * np.cos(theta),  # Arrow end x
            y=y0 + 0.5 * np.sin(theta),  # Arrow end y
            xref="x", yref="y",
            text="",
            showarrow=True,
            axref="x", ayref="y",
            ax=x0,  # Arrow start x
            ay=y0,  # Arrow start y
            arrowhead=3,
            arrowwidth=1.5,
            arrowcolor=color,
        )))
    return arrows


def _predict_horizon_entropy(trajectories, entropy_history, lambda_history, trees):
    """Replay the surrogate along every planned horizon and return the summed entropy per step.

    Uses the global ``g_nn`` with its output floored at 0.5 and fused with
    ``bayes``, as in the MPC. Unlike the MPC, trees farther than 10 units are
    forced to the uninformative value 0.5.

    Returns:
        Array of shape (iterations, horizon points); column 0 is the logged entropy.
    """
    n_trees = trees.shape[0]
    predictions = []
    for k in range(trajectories.shape[0]):
        lambda_k = lambda_history[k]
        entropies = [entropy_history[k]]  # Start with the entropy logged at iteration k
        for i in range(1, trajectories.shape[2]):
            relative_position_robot_trees = np.tile(trajectories[k, :2, i], (n_trees, 1)) - trees
            distance_robot_trees = np.sqrt(np.sum(relative_position_robot_trees**2, axis=1))
            theta = np.tile(trajectories[k, 2, i], (n_trees, 1))  # Drone yaw
            input_nn = ca.horzcat(relative_position_robot_trees, theta)
            z_k = (distance_robot_trees > 10) * 0.5 + (distance_robot_trees <= 10) * ca.fmax(g_nn(input_nn), 0.5)
            lambda_k = bayes(lambda_k, z_k)
            entropies.append(ca.sum1(entropy(lambda_k)).full().flatten()[0])
        predictions.append(entropies)
    return np.array(predictions)


def plot_animated_trajectory_and_entropy_2d(all_trajectories, entropy_history, lambda_history, trees, lb, ub, computation_durations):
    """Animate the receding-horizon MPC run and save it to ``neural_mpc_results.html``.

    Layout (2x2 grid):
      * top-left: planned horizon (red), first pose of every plan so far
        (orange), trees coloured by belief, yaw arrows;
      * top-right: past summed entropy (blue) and the entropy predicted along
        the current horizon (purple, dotted);
      * bottom-right: per-iteration MPC solve time;
      * bottom-left: unused.

    Args:
        all_trajectories: per-iteration planned trajectories as logged by
            ``main``; rows 0, 1, 2 are x, y, yaw and columns are horizon steps.
        entropy_history: summed belief entropy logged at each iteration.
        lambda_history: per-iteration list of per-tree beliefs.
        trees: array of tree positions, one (x, y) row per tree.
        lb, ub: lower/upper corners of the tree bounding box (axis ranges).
        computation_durations: per-iteration MPC solve time in seconds.

    Returns:
        The predicted summed entropy along each planned horizon, as an array
        of shape (iterations, horizon points).
    """
    print(np.array(all_trajectories).shape)

    all_trajectories = np.array(all_trajectories)
    lambda_history = np.array(lambda_history)
    x_trajectory = all_trajectories[:, 0]
    y_trajectory = all_trajectories[:, 1]
    theta_trajectory = all_trajectories[:, 2]  # Drone orientation (yaw)
    n_trees = trees.shape[0]

    entropy_mpc_pred = _predict_horizon_entropy(all_trajectories, entropy_history, lambda_history, trees)

    fig = make_subplots(
        rows=2, cols=2,
        column_widths=[0.7, 0.3],
        row_heights=[0.6, 0.4],
        specs=[
            [{"type": "scatter"}, {"type": "scatter"}],  # First row: 2D map and entropy plot
            [{"type": "scatter"}, {"type": "scatter"}],  # Second row: empty and computation durations plot
        ],
    )

    # Trace styles shared by the initial figure and every animation frame.
    horizon_style = dict(mode="lines+markers", line=dict(color="red", width=4), marker=dict(size=5, color="blue"))
    path_style = dict(mode="lines+markers", line=dict(color="orange", width=4), marker=dict(size=5, color="orange"))
    past_style = dict(mode="lines+markers", line=dict(color="blue", width=2), marker=dict(size=5, color="blue"))
    future_style = dict(mode="lines+markers", line=dict(color="purple", width=2, dash="dot"), marker=dict(size=5, color="purple"))
    duration_style = dict(mode="lines+markers", line=dict(color="green", width=2), marker=dict(size=5, color="green"))

    def tree_trace(i, color):
        """Single-marker trace for tree ``i`` on a red (0) to green (1) colour scale."""
        return go.Scatter(
            x=[trees[i, 0]],
            y=[trees[i, 1]],
            mode="markers",
            marker=dict(
                size=10,
                color=color,
                colorscale=[[0, "#FF0000"], [1, "#00FF00"]],
                cmin=0,
                cmax=1,
                showscale=False,
            ),
            name=f"Tree {i}",
        )

    # Initial traces. Frame data below lists traces in this same order, because
    # frames without explicit trace indices update the figure traces by position.
    fig.add_trace(go.Scatter(x=x_trajectory[0], y=y_trajectory[0], name="MPC Future Trajectory", **horizon_style), row=1, col=1)
    fig.add_trace(go.Scatter(x=[], y=[], name="Drone Trajectory", **path_style), row=1, col=1)
    for i in range(n_trees):
        fig.add_trace(tree_trace(i, "#FF0000"), row=1, col=1)
    fig.add_trace(go.Scatter(x=[], y=[], name="Sum of Entropies (Past)", **past_style), row=1, col=2)
    fig.add_trace(go.Scatter(x=[], y=[], name="Sum of Entropies (Future)", **future_style), row=1, col=2)
    fig.add_trace(go.Scatter(x=[], y=[], name="Computation Durations", **duration_style), row=2, col=2)

    frames = []
    for k in range(len(entropy_mpc_pred)):
        entropy_past = entropy_history[:k + 1]
        entropy_future = entropy_mpc_pred[k]
        durations_past = computation_durations[:k + 1]

        # Red arrows: yaw along the current plan; orange: yaw at the first pose of every plan so far.
        arrows = _yaw_arrows(x_trajectory[k], y_trajectory[k], theta_trajectory[k], "red")
        arrows += _yaw_arrows(x_trajectory[:k + 1, 0], y_trajectory[:k + 1, 0], theta_trajectory[:k + 1, 0], "orange")

        frames.append(go.Frame(
            data=[
                go.Scatter(x=x_trajectory[k], y=y_trajectory[k], **horizon_style),
                go.Scatter(x=x_trajectory[:k + 1, 0], y=y_trajectory[:k + 1, 0], **path_style),
                # Beliefs in [0.5, 1] are stretched onto the [0, 1] colour scale.
                *[tree_trace(i, [2 * (lambda_history[k][i] - 0.5)]) for i in range(n_trees)],
                go.Scatter(x=np.arange(len(entropy_past)), y=entropy_past, **past_style),
                go.Scatter(x=np.arange(k, k + len(entropy_future)), y=entropy_future, **future_style),
                go.Scatter(x=np.arange(len(durations_past)), y=durations_past, **duration_style),
            ],
            name=f"Frame {k}",
            layout=dict(annotations=arrows),
        ))

    fig.frames = frames

    fig.update_layout(
        title="Drone Trajectory, Sum of Entropies, and Computation Durations",
        xaxis=dict(title="X Position", range=[lb[0] - 3, ub[0] + 3]),
        yaxis=dict(title="Y Position", range=[lb[1] - 3, ub[1] + 3]),
        xaxis2=dict(title="Time Step"),
        yaxis2=dict(title="Sum of Entropies"),
        # The durations plot is the bottom-right cell (row 2, col 2), i.e. axis 4;
        # axis 3 is the unused bottom-left cell.
        xaxis4=dict(title="Time Step"),
        yaxis4=dict(title="Computation Duration (s)"),
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[None, {"frame": {"duration": 200, "redraw": True}, "fromcurrent": True}]
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}]
                    )
                ],
                showactive=True,
                x=0.1,
                y=0
            )
        ],
        sliders=[{
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Frame:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 50, "easing": "cubic-in-out"},
            "pad": {"b": 10, "t": 50},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    "args": [[f.name], {"frame": {"duration": 50, "redraw": True}, "mode": "immediate"}],
                    "label": str(k),
                    "method": "animate",
                }
                for k, f in enumerate(fig.frames)
            ],
        }]
    )

    fig.show()
    fig.write_html('neural_mpc_results.html')
    return entropy_mpc_pred
# Build and save the animation of the MPC run above; also keeps the predicted
# per-horizon summed entropy for inspection.
entropy_mpc_pred = plot_animated_trajectory_and_entropy_2d(all_trajectories, entropy_history, lambda_history, trees, lb, ub, durations)
