In [None]:
import numpy as np
import dill
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.collections import LineCollection
import matplotlib.patches as patches
import auxiliaries as aux

In [None]:
class EmptyAgent:
    """Bare attribute container standing in for a simulation agent.

    Instances start without attributes; ``load_data`` attaches ``id``,
    ``cl_x`` and ``cl_u`` and the transform cell later adds ``r0``.
    """

    def __init__(self):
        """Create an instance with no attributes set."""

def load_data(path):
    """Load a dill-serialised simulation result and wrap its agents.

    Prints the path, the number of agents and the MAS type as a side effect.

    Parameters
    ----------
    path : str or path-like
        File written with ``dill``. NOTE(review): unpickling can execute
        arbitrary code, so only load files from trusted sources.

    Returns
    -------
    data
        The unpickled object, used as a nested mapping with (at least) the
        keys ``'MAS_parameters'`` and ``'agents'``.
    agents : list of EmptyAgent
        One instance per key of ``data['agents']`` (in iteration order) with
        ``id``, ``cl_x`` and ``cl_u`` copied from the corresponding record.
    """
    import dill

    with open(path, 'rb') as stream:
        data = dill.load(stream)

    print(f"Data loaded from {path}.")
    mas_parameters = data['MAS_parameters']
    print(f"Number of agents: {mas_parameters['num_agents']}")
    print(f"MAS type: {mas_parameters['MAS_type']}")

    agent_records = data['agents']
    agents = []
    for agent_id in agent_records:
        record = agent_records[agent_id]
        wrapper = EmptyAgent()
        wrapper.id = agent_id
        wrapper.cl_x = record['cl_x']
        wrapper.cl_u = record['cl_u']
        agents.append(wrapper)

    return data, agents

def extract_keys(d, parent_key=""):
    """Return the set of dotted key paths contained in a nested dict.

    Every key at every nesting level is included, prefixed by the path of
    its parents and joined with ``"."`` (e.g. ``{'a': {'b': 1}}`` gives
    ``{'a', 'a.b'}``). Values that are not dicts contribute no keys.

    Parameters
    ----------
    d : object
        The (possibly nested) dict to walk; anything else yields ``set()``.
    parent_key : str, optional
        Path prefix for the keys of ``d``; an empty/falsy prefix means the
        keys are used as-is.
    """
    if not isinstance(d, dict):
        return set()

    collected = set()
    for name, child in d.items():
        dotted = f"{parent_key}.{name}" if parent_key else name
        collected.add(dotted)
        collected |= extract_keys(child, dotted)
    return collected

In [None]:
"""Select colours for plotting."""
# Hex palette for the static plots: agent i is drawn with colours[i], so at
# most len(colours) == 11 agents can be plotted before indexing fails.
colours = [
    "#0072B2", "#D55E00",  # blue, orange
    "#009E73", "#CC79A7",  # green, magenta
    "#56B4E9", "#E69F00",  # light blue, yellow-orange
    "#B22222", "#6A3D9A",  # red, purple
    "#117733", "#88CCEE",  # teal green, cyan
    "#DDCC77",             # muted yellow-orange
]

In [None]:
# Specify the path to your data file.
path = "./data/satellite_constellation_data.dill"
# Set to True if you want to generate an animation (requires ffmpeg).
animate = False

# Load the simulation result and list its simulation parameters.
data, agents = load_data(path)
for key, value in data['sim_pars'].items():
    print(f"{key}:", value)

In [None]:
"""Extract and transform data."""
max_sim_time = data['sim_data']['max_sim_time']

# Per-step cost values may be stored as a list of entries; stack them into flat
# numpy arrays. Each key is checked on its own, so a file where only some keys
# are lists is handled too, and re-running the cell leaves converted arrays as is.
for cost_key in ('cooperative_cost', 'tracking_cost', 'change_cost', 'J'):
    if isinstance(data['sim_data'][cost_key], list):
        data['sim_data'][cost_key] = np.vstack(data['sim_data'][cost_key]).flatten()

for agent in agents:
    # Closed-loop trajectories stored as lists (presumably one column per time
    # step) are joined along axis 1, which the plotting cells use as time.
    # States and inputs are checked independently of each other.
    if isinstance(agent.cl_x, list):
        agent.cl_x = np.hstack(agent.cl_x)
    if isinstance(agent.cl_u, list):
        agent.cl_u = np.hstack(agent.cl_u)
    agent.r0 = data['MAS_parameters']['r0']

sf = data['MAS_parameters']['scaling_factor']
r0 = data['MAS_parameters']['r0']

In [None]:
"""Plot the value function."""
# Plot from t1 to t2 (t2 is exclusive, like a Python slice end).
t1 = 0
t2 = max_sim_time+1

# Select a feasible start time (the end time is clipped per cost series below).
t1 = min(t1, max_sim_time+1)

# Draw the cost terms over time. The figure gets its own `num`: reusing the
# name of another, still open figure would make pyplot return that figure and
# draw these curves on top of it.
fig_V, ax_V = plt.subplots(figsize=(10, 6), num='value function')

# (data key, first time step, legend label, line style, colour)
cost_curves = [
    ('cooperative_cost', t1, 'cooperative', '-', colours[0]),
    ('tracking_cost', t1, 'tracking', '-', colours[1]),
    # Plotted from step 1 at the earliest (presumably undefined at step 0).
    ('change_cost', max(t1, 1), 'change', '-', colours[2]),
    ('J', t1, 'J', '--', colours[3]),
]
for key, start, label, linestyle, colour in cost_curves:
    series = data['sim_data'][key]
    # Clip the end index to this series' own length so the x and y data always
    # have the same number of points, whatever t1 is.
    stop = min(t2, len(series))
    ax_V.plot(range(start, stop), series[start:stop], linestyle=linestyle, label=label, color=colour)

ax_V.set_xlabel('time steps')
ax_V.set_ylabel('cost')
ax_V.set_title('Value function over time')
ax_V.grid(True)
ax_V.legend()

# Set the y-axis to logarithmic scale.
#ax_V.set_yscale('log')

plt.show()

print(f'Value function difference between the first and last time step: {data["sim_data"]["J"][-1] - data["sim_data"]["J"][0]}')
print(f'Value function at start: {data["sim_data"]["J"][0]}')
print(f'Value function at stop:  {data["sim_data"]["J"][-1]}')
print(f'Cooperation cost at stop: {data["sim_data"]["cooperative_cost"][-1]}')
print(f'Tracking cost at stop: {data["sim_data"]["tracking_cost"][-1]}')

In [None]:
"""Plot the closed-loop state evolution."""
# Plot from t1 to t2 (both inclusive), every `step`-th sample.
t1 = 0
t2 = max_sim_time+1
step = 1

# Select a feasible start time (the end time is controlled automatically).
t1 = min(t1, max_sim_time+1)

# Draw the evolution in the orbital plane. The figure name is unique so this
# plot never lands on top of another open figure.
fig_cl, ax_cl = plt.subplots(figsize=(9, 9), num='closed-loop state evolution')

for i, agent in enumerate(agents):
    # x_1 is the radius relative to r0 and x_2 the angle in radians; convert
    # the selected window from polar to Cartesian coordinates.
    # NOTE(review): unlike the x_1 subplot, no division by `sf` is applied
    # here -- confirm sf == 1 or that r0 is in the same (scaled) unit.
    radius = agent.cl_x[0, t1:t2+1:step] + r0
    angle = agent.cl_x[1, t1:t2+1:step]
    pos_x = radius * np.cos(angle)
    pos_y = radius * np.sin(angle)
    if pos_x.size == 0:
        continue  # nothing to draw in the selected window

    ax_cl.plot(pos_x, pos_y, color=colours[i], label=f'{agent.id}_x', linewidth=0.4)
    # Mark the initial state with a larger circle.
    ax_cl.plot(pos_x[0], pos_y[0], color=colours[i], marker='o', markersize=6)
    # Mark the last plotted state with a cross (taken from the plotted window,
    # so it is correct for any t2/step, never an unfilled placeholder).
    ax_cl.plot(pos_x[-1], pos_y[-1], color=colours[i], marker='x', markersize=6)

r_max = (data['MAS_parameters']['r_max'] + r0)
ax_cl.set_xlim(-r_max, r_max)
ax_cl.set_ylim(-r_max, r_max)
ax_cl.set_aspect('equal')  # keep circular orbits circular
ax_cl.set_xlabel('x position')
ax_cl.set_ylabel('y position')

# Plot Earth (mean radius; assumes the positions are in metres -- confirm).
earth_radius = 6371e3
circle = patches.Circle((0, 0), earth_radius, facecolor='#4169E1', fill=True, edgecolor='#4169E1', linewidth=2)
ax_cl.add_patch(circle)

# ax_cl.grid()
ax_cl.legend()

plt.show()

In [None]:
"""All states and inputs"""
# Define time range (inclusive end indices, clipped to the available data below)
t1 = 0
t2_state = max_sim_time + 1
t2_input = max_sim_time

# Number of states and inputs
num_states = 4
num_inputs = 2

# Agents to plot
agents2plot = agents[:]

# Index of the agent whose angle is the reference for the relative x_2 plot
# (index 2 is "Satellite 3" in the subplot title).
ref_idx = 2
assert ref_idx < len(agents), f"Reference agent index {ref_idx} needs at least {ref_idx + 1} agents."
ref_agent = agents[ref_idx]

# Desired angular spacing; guide lines are drawn at +-1 and +-2 multiples of it.
theta_des = data['cooperative_task']['theta_des']

# Select feasible start time
t1 = min(t1, max_sim_time + 1)

# Create a figure with multiple subplots: states in rows 1-2, inputs in row 3.
fig, axes = plt.subplots(3, 2, figsize=(12, 12), num='State & Input Evolution')

## --- Plot all states ---
for idx_state in range(num_states):
    ax = axes[idx_state // 2, idx_state % 2]  # Get subplot position
    if idx_state == 1:
        title_state = f'Closed-loop state $x_{idx_state+1}$ (relative to Satellite {ref_idx+1} in degrees)'
    else:
        title_state = f'Closed-loop state $x_{idx_state+1}$'

    for i, agent in enumerate(agents):
        if agent not in agents2plot:
            continue
        tf = min(t2_state, agent.cl_x.shape[1] - 1)
        if idx_state == 1:
            # The relative angle also needs the reference trajectory up to tf.
            tf = min(tf, ref_agent.cl_x.shape[1] - 1)
        steps = range(t1, tf+1)
        state_row = agent.cl_x[idx_state, t1:tf+1]

        if idx_state == 1:
            # Angular offset to the reference satellite in degrees. The slice
            # already starts at t1, so it lines up with `steps` directly.
            dtheta = np.degrees(state_row - ref_agent.cl_x[idx_state, t1:tf+1])
            ax.plot(steps, dtheta, color=colours[i], label=f'{agent.id}_x{idx_state+1}', linewidth=2)
        elif idx_state == 0:
            # Divide by the scaling factor and convert to km.
            ax.plot(steps, state_row/sf*1e-3, color=colours[i], label=f'{agent.id}_x{idx_state+1}', linewidth=2)
        else:
            ax.plot(steps, state_row, color=colours[i], label=f'{agent.id}_x{idx_state+1}', linewidth=2)

        # # Write radii and angular positions to files.
        # if idx_state == 0:
        #     trajectory_table = "t r\n"
        #     for j in range(t1, tf, step):
        #         trajectory_table += f"{j} {agent.cl_x[idx_state, j]/sf*1e-3 + r0*1e-3}\n"

        #     # Write to file.
        #     with open(f"./plotdata/sat_{agent.id}_r.tex", "w") as f:
        #         f.write(trajectory_table)
        # elif idx_state == 1:
        #     trajectory_table = "t dtheta\n"
        #     for j in range(t1, tf, step):
        #         trajectory_table += f"{j} {dtheta[j - t1]}\n"

        #     # Write to file.
        #     with open(f"./plotdata/sat_{agent.id}_dtheta.tex", "w") as f:
        #         f.write(trajectory_table)

    if idx_state == 1:
        # Guide lines for the desired spacing, drawn once per subplot.
        for multiple in (1, -1, 2, -2):
            ax.axhline(multiple * theta_des, color='lightgray', linewidth=1.5, linestyle='--')

    if np.linalg.norm(ax.get_ylim()) < 1e-8:
        ax.set_ylim(-0.1, 0.1)

    ax.grid()
    ax.legend()
    ax.set_title(title_state)
    ax.set_xlabel('time steps')

    if idx_state == 0:
        # x_1 is the radius offset from r0 (the orbit plot adds r0 to it).
        ax.set_ylabel('$r - r_0$ in km')
    elif idx_state == 1:
        ax.set_ylabel('$\\Delta \\vartheta$ in degrees')

## --- Plot all inputs ---
for idx_input in range(num_inputs):
    ax = axes[-1, idx_input]  # Get subplot position (last row)
    title_input = f'Closed-loop input $u_{idx_input+1}$'

    for i, agent in enumerate(agents):
        if agent not in agents2plot:
            continue
        tf = min(t2_input, agent.cl_u.shape[1] - 1)
        ax.plot(range(t1, tf+1), agent.cl_u[idx_input, t1:tf+1],
                color=colours[i], label=f'{agent.id}_u{idx_input+1}', markersize=0, linewidth=2, marker='o')

    if np.linalg.norm(ax.get_ylim()) < 1e-8:
        ax.set_ylim(-0.1, 0.1)

    ax.grid()
    ax.legend()
    ax.set_title(title_input)
    ax.set_xlabel('time steps')
    ax.set_ylabel(f'$u_{idx_input+1}$')

# Adjust layout and show plot
plt.tight_layout()
plt.show()


In [None]:
"""Animate"""

dark_rgb_colours = [
    (0.00, 0.45, 0.70),
    (0.83, 0.37, 0.00),
    (0.00, 0.62, 0.45),
    (0.94, 0.89, 0.26),
    (0.80, 0.47, 0.65),
    (0.34, 0.71, 0.91),
    (0.90, 0.56, 0.00),
    (0.60, 0.60, 0.60),
    (0.70, 0.13, 0.13),
    (0.42, 0.24, 0.60),
    (0.07, 0.45, 0.20),
    (0.53, 0.80, 0.93),
    (0.87, 0.80, 0.47),
]

if animate:
    # Define the figure and axis
    fig, ax = plt.subplots(figsize=(10, 10))

    # Axis extent: r_max is stored relative to r0, like the radial state.
    r_max = data['MAS_parameters']['r_max'] + r0

    ax.set_xlim(-r_max, r_max)
    ax.set_ylim(-r_max, r_max)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.grid()

    # Frame range t1..t2, both inclusive. t2 is clipped to the LONGEST
    # trajectory so no agent is cut short; `update` blanks agents whose data
    # ends earlier.
    t1 = 0
    t2 = max_sim_time + 1
    t2 = min(t2, max(agent.cl_x.shape[1] for agent in agents) - 1)
    step = 1
    history_length = 10  # Number of previous steps to fade out

    # Plot Earth
    earth = patches.Circle((0, 0), 6371e3, facecolor='#4169E1', edgecolor='#4169E1', linewidth=2)
    ax.add_patch(earth)

    # Plot elements to be updated
    satellite_plots = []  # Current-position markers (Line2D), one per agent
    trail_collections = []  # Faded trail LineCollections, one per agent

    # Initialise the plot elements
    for i in range(len(agents)):
        # Current position marker
        satellite_plot, = ax.plot([], [], color=dark_rgb_colours[i], marker='o', markersize=6, linestyle='None')
        satellite_plots.append(satellite_plot)

        # Create a LineCollection for the fading trail with the correct colour
        trail_collection = LineCollection([], linewidth=1.5, colors=[dark_rgb_colours[i]])
        trail_collections.append(trail_collection)
        ax.add_collection(trail_collection)

    def update(frame):
        """Draw time step `frame`: every satellite's position plus a trail over
        the last `history_length` steps that fades from old to new."""
        for i, satellite in enumerate(agents):
            if frame >= satellite.cl_x.shape[1]:
                satellite_plots[i].set_data([], [])  # Clear the satellite if the frame is out of bounds
                trail_collections[i].set_segments([])  # Clear the trail if the frame is out of bounds
                continue
            # Convert polar to Cartesian coordinates.
            # NOTE(review): no division by `sf` here -- confirm sf == 1 or that
            # r0 is in the same (scaled) unit as the radial state.
            current_r = satellite.cl_x[0, frame]
            current_theta = satellite.cl_x[1, frame]
            current_x = np.array([(current_r + r0) * np.cos(current_theta)])
            current_y = np.array([(current_r + r0) * np.sin(current_theta)])

            past_start = max(0, frame - history_length)
            trail_r = satellite.cl_x[0, past_start:frame+1]
            trail_theta = satellite.cl_x[1, past_start:frame+1]
            trail_x = (trail_r + r0) * np.cos(trail_theta)
            trail_y = (trail_r + r0) * np.sin(trail_theta)

            # Update the satellite's position
            satellite_plots[i].set_data(current_x, current_y)

            # Create faded segments only if there are enough points
            if len(trail_x) > 1:
                segments = [((trail_x[j], trail_y[j]), (trail_x[j+1], trail_y[j+1])) for j in range(len(trail_x)-1)]

                # Alpha per segment: oldest 0.1 (nearly transparent) to newest 1 (opaque)
                alpha_values = np.linspace(0.1, 1, len(segments))
                faded_colors = [(*dark_rgb_colours[i], alpha) for alpha in alpha_values]

                # Update the LineCollection with new segments and per-segment colours
                trail_collections[i].set_segments(segments)
                trail_collections[i].set_color(faded_colors)
            else:
                trail_collections[i].set_segments([])  # Clear trail if no valid segments

        return satellite_plots + trail_collections

    # Create animation (disable blitting for compatibility). The frame range is
    # inclusive of t2 so the final state is part of the video.
    interval = 80  # milliseconds between frames
    ani = animation.FuncAnimation(fig, update, frames=range(t1, t2 + 1, step), interval=interval, blit=False)

    ani.save("./data/satellite_orbit.mp4", writer="ffmpeg", dpi=300)
    # Close the figure so the notebook does not also render a static last frame.
    plt.close(fig)