# SAC Evaluation: Drone Formation Control

This notebook evaluates the trained SAC agent on the 3D drone formation task.
It loads the best checkpoint together with the saved VecNormalize statistics, runs one deterministic evaluation rollout, and produces:
- numeric summary metrics (also saved as JSON),
- 3D trajectory plots (static and interactive),
- time-series plots,
- a heatmap of formation error,
- a GIF of the flight.


In [2]:
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env

from envs.drone_formation_env import DroneFormationEnv

import pandas as pd
from IPython.display import Image, display

import plotly.graph_objects as go


# All training artefacts are resolved relative to the working directory,
# which is assumed to be the repository root.
BASE_TRAIN_DIR = "training"
MODELS_DIR = os.path.join(BASE_TRAIN_DIR, "models")
LOGS_DIR = os.path.join(BASE_TRAIN_DIR, "logs", "tensorboard")
VECNORM_PATH = os.path.join(MODELS_DIR, "vecnormalize.pkl")
BEST_MODEL_PATH = os.path.join(MODELS_DIR, "best", "best_model.zip")

# Every evaluation artefact (report, figures, GIF, HTML) is written here.
EVAL_DIR = "eval_outputs"
os.makedirs(EVAL_DIR, exist_ok=True)

# Use the GPU only when PyTorch reports one as available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

KeyboardInterrupt: 

In [None]:
def unpack_obs(raw):
    """Split a flat 30-element observation into ten consecutive 3-slices.

    The slices are returned in this order (names follow the original
    variable names; presumably per-drone position, velocity, Euler angles and
    angular velocity, then target and formation offset — confirm against
    DroneFormationEnv):

        p1, v1, ang1, omega1, p2, v2, ang2, omega2, t, off

    Args:
        raw: any sliceable sequence (list, NumPy array, ...) of length >= 30.

    Returns:
        A tuple of ten slices ``raw[0:3], raw[3:6], ..., raw[27:30]``.
    """
    return tuple(raw[start:start + 3] for start in range(0, 30, 3))


def drone_mesh(px, py, pz, ang, size=0.3):
    """Return the world-frame endpoints of a drone's two crossed arms.

    The body-frame arms lie along the body x and y axes, each spanning
    ``[-size, +size]``. They are rotated by ``R = Rz(yaw) @ Ry(pitch) @ Rx(roll)``
    and translated to ``(px, py, pz)``.

    Args:
        px, py, pz: drone centre in world coordinates.
        ang: ``(roll, pitch, yaw)`` in radians.
        size: half-length of each arm.

    Returns:
        ``(arm1, arm2)``, each a ``(2, 3)`` array of endpoint coordinates:
        ``arm1`` for the body-x arm and ``arm2`` for the body-y arm.
    """
    roll, pitch, yaw = ang
    c_r, s_r = np.cos(roll), np.sin(roll)
    c_p, s_p = np.cos(pitch), np.sin(pitch)
    c_y, s_y = np.cos(yaw), np.sin(yaw)

    rot = np.array([
        [c_y * c_p, c_y * s_p * s_r - s_y * c_r, c_y * s_p * c_r + s_y * s_r],
        [s_y * c_p, s_y * s_p * s_r + c_y * c_r, s_y * s_p * c_r - c_y * s_r],
        [-s_p,      c_p * s_r,                   c_p * c_r],
    ])

    # Rows 0-1: tips of the body-x arm; rows 2-3: tips of the body-y arm.
    tips_body = np.array([
        [size, 0, 0],
        [-size, 0, 0],
        [0, size, 0],
        [0, -size, 0],
    ])
    centre = np.array([px, py, pz])
    tips_world = tips_body @ rot.T + centre
    return tips_world[:2], tips_world[2:]


In [None]:
def load_training_metrics(log_dir=LOGS_DIR):
    """Collect scalar training curves from CSV files found under ``log_dir``.

    Walks ``log_dir`` recursively and reads every ``*.csv`` file, skipping one
    header line. Column 1 is taken as the step and column 2 as the value
    (this matches TensorBoard's "Wall time,Step,Value" CSV export —
    NOTE(review): confirm against the exporter actually used). A file is
    routed to a series by substring match on its file name. A file matching
    several patterns is added to each.

    Args:
        log_dir: root directory to search.

    Returns:
        ``None`` if ``log_dir`` does not exist. Otherwise a 3-tuple
        ``((xs_r, ys_r), (xs_c, ys_c), (xs_e, ys_e))`` for episode reward
        mean, critic loss and entropy coefficient. Each ``xs``/``ys`` is a
        list holding one NumPy array per matching file.
    """
    if not os.path.exists(log_dir):
        return None

    xs_r, ys_r = [], []
    xs_c, ys_c = [], []
    xs_e, ys_e = [], []
    # File-name substring -> accumulators. os.walk file names never contain
    # "/", so only the underscore-separated forms can ever match.
    routes = (
        ("rollout_ep_rew_mean", xs_r, ys_r),
        ("train_critic_loss", xs_c, ys_c),
        ("train_ent_coef", xs_e, ys_e),
    )

    for root, _dirs, files in os.walk(log_dir):
        for f in files:
            if not f.endswith(".csv"):
                continue
            path = os.path.join(root, f)
            try:
                # Read only the step and value columns.
                data = np.genfromtxt(
                    path, delimiter=",", skip_header=1, usecols=(1, 2)
                )
            except (ValueError, IndexError):
                # Fewer than three columns or malformed rows: not a usable
                # series, so skip it instead of crashing the whole scan.
                continue
            # genfromtxt squeezes a single data row to shape (2,). Restore
            # the row axis so one-point files are kept, not silently dropped.
            data = np.reshape(data, (-1, 2))
            if data.shape[0] == 0:  # header-only / empty file
                continue
            step = data[:, 0]
            val = data[:, 1]

            for pattern, xs, ys in routes:
                if pattern in f:
                    xs.append(step)
                    ys.append(val)

    return (xs_r, ys_r), (xs_c, ys_c), (xs_e, ys_e)


In [None]:
# Headless, single-instance evaluation environment with wind disturbance.
eval_env_kwargs = {
    "gui": False,
    "episode_len": 600,
    "use_wind": True,
    "wind_std": 0.3,
}
env_raw = make_vec_env(DroneFormationEnv, n_envs=1, env_kwargs=eval_env_kwargs)

# Reuse the observation statistics saved during training, and freeze them:
# no running-stat updates, and rewards are passed through unnormalized.
vecnorm = VecNormalize.load(VECNORM_PATH, env_raw)
vecnorm.training = False
vecnorm.norm_reward = False

model = SAC.load(BEST_MODEL_PATH, env=vecnorm, device=device)

print("Device:", device)


In [None]:
def run_episode(env, model, max_thrust=12.0):
    """Run one deterministic evaluation episode and return its telemetry.

    Args:
        env: a single-environment ``VecNormalize`` wrapper. Raw observations
            are recovered with ``get_original_obs`` (and ``unnormalize_obs``
            on the terminal step). Rewards are recorded as returned by
            ``env.step``.
        model: a Stable-Baselines3 model. Actions come from
            ``model.predict(obs, deterministic=True)``.
        max_thrust: scale used only for the recorded ``a1s``/``a2s`` arrays.
            It maps each action component from [-1, 1] to [0, max_thrust].
            The env itself receives the unscaled action. The default 12.0 is
            the previously hard-coded value. NOTE(review): confirm it matches
            the env's own thrust mapping.

    Returns:
        dict of NumPy arrays with one row per environment step:

        - ``p1s``, ``p2s``, ``ts``, ``v1s``, ``v2s``, ``ang1s``, ``ang2s``:
          3-slices of the raw observation, as returned by ``unpack_obs``.
        - ``a1s``, ``a2s``: action components 0-3 and 4-7, scaled.
        - ``rs``: per-step reward.
        - ``ds``, ``fs``, ``tls``, ``ints``: per-step values of the
          ``dist_target``, ``form_error``, ``tilt`` and ``inter_drone_dist``
          info keys.
    """
    keys = (
        "p1s", "p2s", "ts",
        "v1s", "v2s",
        "ang1s", "ang2s",
        "a1s", "a2s",
        "rs", "ds", "fs", "tls", "ints",
    )
    rec = {k: [] for k in keys}

    obs = env.reset()
    done = False

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, r, dones, infos = env.step(action)
        done = bool(dones[0])
        info = infos[0]

        if done and "terminal_observation" in info:
            # VecEnvs auto-reset when an episode ends, so get_original_obs()
            # would now return the first observation of the NEXT episode.
            # The real final state is in info["terminal_observation"], which
            # VecNormalize has already normalized; undo that. The result is
            # exact up to float rounding unless VecNormalize clipped the value.
            raw = env.unnormalize_obs(info["terminal_observation"])
        else:
            raw = env.get_original_obs()[0]
        p1, v1, ang1, _, p2, v2, ang2, _, target, _ = unpack_obs(raw)

        rec["p1s"].append(p1)
        rec["p2s"].append(p2)
        rec["ts"].append(target)
        rec["v1s"].append(v1)
        rec["v2s"].append(v2)
        rec["ang1s"].append(ang1)
        rec["ang2s"].append(ang2)

        # Affine map [-1, 1] -> [0, max_thrust] of action components 0-3
        # (recorded for drone 1) and 4-7 (recorded for drone 2).
        rec["a1s"].append(0.5 * (action[0, 0:4] + 1.0) * max_thrust)
        rec["a2s"].append(0.5 * (action[0, 4:8] + 1.0) * max_thrust)

        rec["rs"].append(float(r[0]))
        rec["ds"].append(info["dist_target"])
        rec["fs"].append(info["form_error"])
        rec["tls"].append(info["tilt"])
        rec["ints"].append(info["inter_drone_dist"])

    return {k: np.array(v) for k, v in rec.items()}


In [None]:
# Roll out one evaluation episode and condense it into summary metrics.
episode = run_episode(vecnorm, model)

rs, ds, fs, tls, ints = (episode[key] for key in ("rs", "ds", "fs", "tls", "ints"))

report = dict(
    total_reward=float(rs.sum()),
    avg_reward=float(rs.mean()),
    avg_dist_target=float(ds.mean()),
    avg_form_error=float(fs.mean()),
    avg_tilt=float(tls.mean()),
    min_inter_drone_dist=float(ints.min()),
    steps=int(len(rs)),
)

# Persist the metrics as indented JSON alongside the other outputs.
report_path = os.path.join(EVAL_DIR, "report.json")
with open(report_path, "w") as fh:
    json.dump(report, fh, indent=4)

# Display the same metrics as a single-row table.
pd.DataFrame([report])


In [None]:
p1s = episode["p1s"]
p2s = episode["p2s"]
ts = episode["ts"]

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection="3d")

# Both drone paths use the default colour cycle. The target path is drawn
# as a red dashed line.
for path, label in ((p1s, "Drone 1"), (p2s, "Drone 2")):
    ax.plot(path[:, 0], path[:, 1], path[:, 2], label=label)
ax.plot(ts[:, 0], ts[:, 1], ts[:, 2], "r--", label="Target")

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.legend()

plt.tight_layout()
plt.savefig(os.path.join(EVAL_DIR, "traj.png"), dpi=300)
plt.show()


In [None]:
# Per-step time series of the rollout, stacked on a shared step axis:
# reward, target/formation errors, velocity norms, and tilt.
t = np.arange(len(rs))
v1s = episode["v1s"]
v2s = episode["v2s"]

fig, axs = plt.subplots(4, 1, figsize=(12, 9), sharex=True)

axs[0].plot(t, rs)
axs[0].set_ylabel("Reward")

axs[1].plot(t, ds, label="dist_target")
axs[1].plot(t, fs, label="form_error")
axs[1].set_ylabel("Error")
axs[1].legend()

# Euclidean norm of each drone's velocity vector at every step.
axs[2].plot(t, np.linalg.norm(v1s, axis=1), label="|v1|")
axs[2].plot(t, np.linalg.norm(v2s, axis=1), label="|v2|")
axs[2].set_ylabel("Speed")
axs[2].legend()

axs[3].plot(t, tls)
axs[3].set_ylabel("Tilt")
axs[3].set_xlabel("Step")

plt.tight_layout()
plt.savefig(os.path.join(EVAL_DIR, "time.png"), dpi=300)
plt.show()


In [None]:
# Formation error as a single-row heatmap: one column per step.
fig = plt.figure(figsize=(10, 2))
ax = fig.add_subplot()

heat = ax.imshow(fs[np.newaxis, :], aspect="auto", cmap="magma")
fig.colorbar(heat, label="Form error")
ax.set_yticks([])
ax.set_xlabel("Step")

fig.tight_layout()
fig.savefig(os.path.join(EVAL_DIR, "hm_form.png"), dpi=300)
plt.show()


In [None]:
# Set up the figure for the flight GIF. Each drone is drawn as two crossed
# arm segments.
p1s = episode["p1s"]
p2s = episode["p2s"]
ang1s = episode["ang1s"]
ang2s = episode["ang2s"]

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d")

# One line artist per arm: drone 1 in blue, drone 2 in green.
line1x, = ax.plot([], [], [], "b-", linewidth=3)
line1y, = ax.plot([], [], [], "b-", linewidth=3)
line2x, = ax.plot([], [], [], "g-", linewidth=3)
line2y, = ax.plot([], [], [], "g-", linewidth=3)

# Fixed limits covering both trajectories plus a 1-unit margin, so the view
# does not rescale between frames.
xmin = min(p1s[:, 0].min(), p2s[:, 0].min()) - 1
xmax = max(p1s[:, 0].max(), p2s[:, 0].max()) + 1
ymin = min(p1s[:, 1].min(), p2s[:, 1].min()) - 1
ymax = max(p1s[:, 1].max(), p2s[:, 1].max()) + 1
# The floor stays at z=0, but it is lowered if a drone ever goes below zero.
# Otherwise that part of the flight would be clipped out of view.
z_low = min(p1s[:, 2].min(), p2s[:, 2].min())
zmin = 0 if z_low >= 0 else z_low - 1
zmax = max(p1s[:, 2].max(), p2s[:, 2].max(), 0) + 1

ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_zlim(zmin, zmax)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")

def upd(i):
    """Animation callback: redraw both drones' arm segments for step ``i``.

    Reads the module-level ``p1s``/``p2s``/``ang1s``/``ang2s`` arrays and
    ``line1x``/``line1y``/``line2x``/``line2y`` artists.

    Returns:
        The four updated line artists.
    """
    def _place(line, seg):
        # seg is a (2, 3) array of endpoints; its columns are x, y, z.
        line.set_data([seg[0, 0], seg[1, 0]], [seg[0, 1], seg[1, 1]])
        line.set_3d_properties([seg[0, 2], seg[1, 2]])

    drones = (
        (line1x, line1y, p1s[i], ang1s[i]),
        (line2x, line2y, p2s[i], ang2s[i]),
    )
    for line_a, line_b, pos, ang in drones:
        arm_a, arm_b = drone_mesh(pos[0], pos[1], pos[2], ang)
        _place(line_a, arm_a)
        _place(line_b, arm_b)

    return line1x, line1y, line2x, line2y

# One frame per recorded step. The pillow writer encodes the GIF at 15 fps.
gif_path = os.path.join(EVAL_DIR, "traj.gif")
ani = animation.FuncAnimation(
    fig,
    upd,
    frames=len(p1s),
    interval=50,
    blit=False,
)
ani.save(gif_path, writer="pillow", fps=15)
plt.close(fig)  # suppress the static inline figure; only the GIF is shown

display(Image(filename=gif_path))


In [None]:
# Interactive 3D trajectory with Plotly: both drones plus a dashed target path.

p1s = episode["p1s"]
p2s = episode["p2s"]
ts = episode["ts"]

fig = go.Figure()

trace_specs = (
    (p1s, "Drone 1", {}),
    (p2s, "Drone 2", {}),
    (ts, "Target", {"line": dict(dash="dash")}),
)
for xyz, label, extra in trace_specs:
    fig.add_trace(go.Scatter3d(
        x=xyz[:, 0],
        y=xyz[:, 1],
        z=xyz[:, 2],
        mode="lines",
        name=label,
        **extra,
    ))

fig.update_layout(
    scene=dict(
        xaxis_title="X",
        yaxis_title="Y",
        zaxis_title="Z",
        aspectmode="data",  # equal scaling per axis, matching the data extents
    ),
    title="Drone Formation – 3D Trajectory (Interactive)",
)

fig.show()

# Also save a standalone HTML copy of the interactive plot.
html_path = os.path.join(EVAL_DIR, "traj_interactive.html")
fig.write_html(html_path)
print("Saved interactive plot to:", html_path)
