In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import pickle
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib import animation
from matplotlib.collections import LineCollection
import matplotlib.patches as mpatches
import time
import os

In [None]:
# Device setup
# Prefer the GPU when one is visible, otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_dtype(torch.float32)

# Seed the global RNG so the out-of-sample winds drawn later in the notebook are
# identical on every Restart Kernel -> Run All. Change SEED for a different draw.
SEED = 0
torch.manual_seed(SEED)
print("Running on:", device)

In [None]:
def wind_process(T, theta, mu, wind_sigma, n, tau, dev=None):
    """Simulate ``n`` mean-reverting wind paths and sub-sample them to ``T + 1`` columns.

    Each path is advanced with the explicit update
    ``w_next = w + (mu - w) * theta * tau + wind_sigma * sqrt(tau) * N(0, 1)``.

    Args:
        T: Horizon; ``int(T / tau)`` fine steps are simulated.
        theta: Mean-reversion rate.
        mu: Level the process reverts to.
        wind_sigma: Noise scale.
        n: Number of independent paths.
        tau: Fine time step. May be a Python number or a 0-dim tensor.
        dev: Device for the simulation. Defaults to the notebook-level ``device``.

    Returns:
        Tensor of shape ``(n, 1 + k)`` where column 0 is the initial wind (uniform on
        ``[-0.25, 0.25)``) and the remaining ``k`` columns are every ``int(1 / tau)``-th
        simulated value.
    """
    if dev is None:
        dev = device  # notebook-level default set in the device-setup cell
    num_to_sim = int(T / tau)
    stride = int(1 / tau)
    # math.sqrt on a plain float: torch.sqrt only accepts tensors, so a float ``tau``
    # used to raise a TypeError here. float() also unwraps a 0-dim tensor.
    noise_scale = wind_sigma * math.sqrt(float(tau))

    winds = torch.zeros(n, num_to_sim + 1, device = dev)
    winds[:, 0] = 0.5 * torch.rand(n, device = dev) - 0.25
    for step in range(1, num_to_sim + 1):
        dW = torch.randn(n, device = dev)
        prev = winds[:, step - 1]
        winds[:, step] = prev + (mu - prev) * theta * tau + noise_scale * dW

    # NOTE(review): the slice below starts at fine step 1, so for tau < 1 it keeps
    # times tau, 1 + tau, ... rather than the integer times 1, 2, ... the original
    # comment promised. Left unchanged so test winds keep the same sampling scheme
    # as this function has always produced — confirm against the training code.
    final_wind = winds[:, 1:][:, ::stride]
    # add initial wind to the front
    winds = torch.cat((winds[:, 0].view(n, 1), final_wind), dim = 1)
    return winds

In [None]:
def running_cost(x, y, A, M):
    """Obstacle cost at position ``(x, y)``, evaluated element-wise on tensors.

    Computes ``M * (1 - 1 / (1 + exp(A * (1 - x**2 - y**2))))``: exactly ``M / 2``
    on the unit circle and, for positive ``A``, rising towards ``M`` inside it and
    falling towards 0 outside it.
    """
    margin = 1 - x**2 - y**2
    complement = 1/(1 + torch.exp(A * margin))
    return (1 - complement) * M

def terminal_cost(x, y, L_x):
    """Squared distance of each row's ``(x, y)`` from the point ``(L_x, 0)``.

    ``x`` and ``y`` are 2-D tensors; norms are taken along dim 1 with
    ``keepdim=True``, so the result has one column per row of the inputs.
    """
    x_gap = torch.norm(x - L_x, dim = 1, keepdim = True)
    y_gap = torch.norm(y, dim = 1, keepdim = True)
    return x_gap**2 + y_gap**2

def regularising(th1, th2, p, r):
    """Penalty ``||theta||**p + 1e-9 * exp(||theta||)`` on the concatenation of ``th1`` and ``th2``.

    ``r`` is accepted for signature compatibility but is not used in the computation.
    """
    joined = torch.cat((th1, th2))
    magnitude = torch.norm(joined)
    exp_term = 1e-9 * torch.exp(magnitude)
    return magnitude**p + exp_term

def regularising_t(theta1, theta2, p, beta):
    """Sum of ``regularising`` penalties, scaled by ``1 / (2 * beta**2)``.

    ``regularising`` is vectorised with ``torch.vmap`` over dim 0 of ``theta1`` and
    dim 1 of ``theta2``; ``p`` and ``theta1.shape[0]`` are passed unbatched.
    """
    batched_penalty = torch.vmap(regularising, in_dims=(0, 1, None, None))
    row_count = theta1.shape[0]
    ells_by_j = batched_penalty(theta1, theta2, p, row_count)
    scale = 2 * beta**2
    return torch.sum(ells_by_j) / scale

def gen_ref_path(reference_control, p0, winds, vs, T, n):
    """Roll out ``n`` paths steered by a fixed angle and displaced by the wind in y.

    Each step adds ``vs * (cos(angle), sin(angle)) + (0, winds[:, t])`` to the
    previous position.

    Args:
        reference_control: ``(n, 1)`` tensor of steering angles, constant over time.
        p0: Start position(s); assigned into an ``(n, 2)`` slice, so must broadcast to it.
        winds: Tensor whose column ``t`` holds the wind applied at step ``t`` (``t < T``).
        vs: Multiplier applied to the heading vector.
        T: Number of steps.
        n: Number of paths.

    Returns:
        Tensor of shape ``(n, T + 1, 2)`` on the same device as ``winds``.
    """
    # Allocate next to the wind data instead of reading the notebook-global `device`;
    # the two always coincide whenever the concatenation below is valid.
    dev = winds.device
    ref_path = torch.zeros(n, T + 1, 2, device = dev) # use a 3D tensor to store path information
    ref_path[:, 0, :] = p0

    # The control does not change over time, so the heading is computed once.
    heading = torch.cat([torch.cos(reference_control), torch.sin(reference_control)], dim = 1)
    no_x_wind = torch.zeros(n, 1, device = dev)
    for t in range(T):
        wind_vec = torch.cat([no_x_wind, winds[:, t].view(n, 1)], dim = 1)
        ref_path[:, t + 1, :] = ref_path[:, t] + vs * heading + wind_vec
    return ref_path

In [None]:
class NeuralNet_erm(nn.Module):
    """Fully connected network: four tanh hidden layers of equal width and a linear output.

    Attribute names are the ``state_dict`` keys of the saved checkpoints, so they
    must not be renamed. Note that ``self.sigmoid`` actually holds ``nn.Tanh``.
    """

    def __init__(self, input_dim, width, output_dim):
        super().__init__()
        # Layers are created in this order on purpose: it fixes both the
        # registration order and the order in which random initial weights are drawn.
        self.hidden_layer = nn.Linear(input_dim, width)
        self.hidden_layer_2 = nn.Linear(width, width)
        self.hidden_layer_3 = nn.Linear(width, width)
        self.hidden_layer_4 = nn.Linear(width, width)
        self.sigmoid = nn.Tanh()
        self.output_layer = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return the unscaled network output for a batch ``x`` of shape ``(batch, input_dim)``."""
        hidden_stack = (
            self.hidden_layer,
            self.hidden_layer_2,
            self.hidden_layer_3,
            self.hidden_layer_4,
        )
        activations = x
        for layer in hidden_stack:
            activations = self.sigmoid(layer(activations))
        return self.output_layer(activations)

class NeuralNet_entropy(nn.Module):
    """One-hidden-layer tanh network with biases pinned to zero and output divided by ``width``.

    Both linear layers have their bias zeroed and excluded from gradient updates.
    Attribute names are the ``state_dict`` keys of the saved checkpoints; note that
    ``self.sigmoid`` actually holds ``nn.Tanh``.
    """

    def __init__(self, input_dim, width, output_dim):
        super().__init__()
        self.hidden_layer = nn.Linear(input_dim, width)
        self._pin_bias_to_zero(self.hidden_layer)
        self.sigmoid = nn.Tanh()
        self.output_layer = nn.Linear(width, output_dim)
        self._pin_bias_to_zero(self.output_layer)
        self.width = width

    @staticmethod
    def _pin_bias_to_zero(layer):
        """Zero ``layer.bias`` in place and stop it from receiving gradients."""
        layer.bias.data.zero_()
        layer.bias.requires_grad = False

    def forward(self, x):
        """Return ``output_layer(tanh(hidden_layer(x))) / width``."""
        hidden = self.sigmoid(self.hidden_layer(x))
        return self.output_layer(hidden) / self.width

In [None]:
# Saved results of one training run; tensors are remapped onto the active device.
# NOTE(review): torch.load unpickles the file, so only open checkpoints you trust.
checkpoint_path = "./erm_vs_entropy/2025-09-25_sim1_of_3.pt"
with open(checkpoint_path, "rb") as f:
    data = torch.load(f, map_location = device)

In [None]:
# Summarise the ERM record rather than echoing it: the raw dict also holds the
# training winds and every per-step state_dict, which floods the cell output.
# Tensors are shown by shape, containers by type and length, scalars as-is.
# (assumes data["erm"] is a plain dict, as returned by torch.load — TODO confirm)
{
    key: tuple(value.shape) if isinstance(value, torch.Tensor)
    else f"{type(value).__name__}[{len(value)}]" if isinstance(value, (list, tuple, dict))
    else value
    for key, value in data["erm"].items()
}

In [None]:
# Summarise the entropy-regularised record rather than echoing it: the raw dict
# also holds every per-step state_dict, which floods the cell output.
# Tensors are shown by shape, containers by type and length, scalars as-is.
# (assumes data["entropy_regularised"] is a plain dict — TODO confirm)
{
    key: tuple(value.shape) if isinstance(value, torch.Tensor)
    else f"{type(value).__name__}[{len(value)}]" if isinstance(value, (list, tuple, dict))
    else value
    for key, value in data["entropy_regularised"].items()
}

In [None]:
# collect hyperparameters
erm_record = data["erm"]
entropy_record = data["entropy_regularised"]

# sample count, horizon and fine time step
n, T, tau = erm_record["n"], erm_record["T"], erm_record["tau"]
# x-coordinate passed to terminal_cost, and the multiplier on the heading vector
L_x, vs = erm_record["L_x"], erm_record["vs"]
# arguments of wind_process
theta, mu, wind_sigma = erm_record["theta"], erm_record["mu"], erm_record["wind_sigma"]
# arguments of running_cost
A, M = erm_record["A"], erm_record["M"]
# network widths for the two model families
erm_width = erm_record["network_width"]
entropy_width = entropy_record["network_width"]
# only the entropy-regularised run stores beta and sigma
beta, sigma = entropy_record["beta"], entropy_record["sigma"]
# winds used during training, indexed [sample, time step]
training_winds = erm_record["training_data"]

In [None]:
# generate reference trajectories
# Zero steering angle for every sample, all starting from (-20, 0).
ref_ctrl = torch.zeros(n, 1, device = device)
start_offset = torch.tensor([20, 0], device = device)
initial_points = torch.zeros(n, 2, device = device) - start_offset
ref_path = gen_ref_path(ref_ctrl, initial_points, training_winds, vs, T, n)

# One dashed grey line per sample; copy to the CPU once instead of per line.
ref_path_cpu = ref_path.cpu()
for track in ref_path_cpu:
    plt.plot(track[:, 0], track[:, 1], color = "gray", linestyle = "--")

In [None]:
# generate the realised paths
# Per-time-step state_dicts for the two model families.
erm_params = list(data["erm"]["erm_model_parameters"])
entropy_params = list(data["entropy_regularised"]["entropy_model_parameters"])

# One trained network per time step is required. zip() below stops at the shorter
# sequence, so a short checkpoint would otherwise leave the trailing models at
# their random initialisation without raising anything.
for label, params in (("erm", erm_params), ("entropy_regularised", entropy_params)):
    if len(params) < T:
        raise ValueError(f"checkpoint holds {len(params)} '{label}' networks but T = {T}")

# erm: inputs are (x / 20, y / 10, wind)
erm_models = [NeuralNet_erm(3, erm_width, 1).to(device) for _ in range(T)]
for model, state in zip(erm_models, erm_params):
    model.load_state_dict(state)  # copies into the existing parameters; device is unchanged
# entropy-regularised: same inputs plus a constant-one column
reg_models = [NeuralNet_entropy(4, entropy_width, 1).to(device) for _ in range(T)]
for model, state in zip(reg_models, entropy_params):
    model.load_state_dict(state)

train_paths_erm = [initial_points]
train_paths_reg = [initial_points]

# Pure inference: no_grad stops autograd from recording a graph through all T networks.
with torch.no_grad():
    # first for erm
    for t in range(T):
        angle = erm_models[t](torch.cat([train_paths_erm[-1]/torch.tensor([20, 10], device = device),
                                         training_winds[:, t].view(n, 1)], dim = 1))
        heading = torch.cat([torch.cos(angle), torch.sin(angle)], dim = 1)
        wind_vec = torch.cat([torch.zeros(n, 1, device = device), training_winds[:, t].view(n, 1)], dim = 1)
        new_p = train_paths_erm[-1] + vs * heading + wind_vec
        train_paths_erm.append(new_p)
    train_paths_erm = torch.stack(train_paths_erm, dim = 1)

    # then for entropy-regularised
    for t in range(T):
        angle = reg_models[t](torch.cat([train_paths_reg[-1]/torch.tensor([20, 10], device = device),
                                         training_winds[:, t].view(n, 1),
                                         torch.ones(n, 1, device = device)], dim = 1))
        heading = torch.cat([torch.cos(angle), torch.sin(angle)], dim = 1)
        wind_vec = torch.cat([torch.zeros(n, 1, device = device), training_winds[:, t].view(n, 1)], dim = 1)
        new_p = train_paths_reg[-1] + vs * heading + wind_vec
        train_paths_reg.append(new_p)
    train_paths_reg = torch.stack(train_paths_reg, dim = 1)

# Quick look: ERM in blue, entropy-regularised in red.
for pth in range(n):
    plt.plot(train_paths_erm[pth, :, 0].detach().cpu(), train_paths_erm[pth, :, 1].detach().cpu(), color = "blue")
    plt.plot(train_paths_reg[pth, :, 0].detach().cpu(), train_paths_reg[pth, :, 1].detach().cpu(), color = "red")

In [None]:
# generate test winds
test_size = 1000
test_winds = wind_process(T, theta, mu, wind_sigma, test_size, tau)

# generate test trajectories, every path starting from (-20, 0)
test_paths_erm = [torch.zeros(test_size, 2, device = device) - torch.tensor([20, 0], device = device)]
test_paths_reg = [torch.zeros(test_size, 2, device = device) - torch.tensor([20, 0], device = device)]

# Pure inference: no_grad stops autograd from recording a graph through all T networks.
with torch.no_grad():
    # erm
    for t in range(T):
        angle = erm_models[t](torch.cat([test_paths_erm[-1]/torch.tensor([20, 10], device = device),
                                         test_winds[:, t].view(test_size, 1)], dim = 1))
        heading = torch.cat([torch.cos(angle), torch.sin(angle)], dim = 1)
        wind_vec = torch.cat([torch.zeros(test_size, 1, device = device), test_winds[:, t].view(test_size, 1)], dim = 1)
        new_p = test_paths_erm[-1] + vs * heading + wind_vec
        test_paths_erm.append(new_p)
    test_paths_erm = torch.stack(test_paths_erm, dim = 1)

    # entropy-regularised
    for t in range(T):
        # The constant-one column must live on `device` like the other inputs;
        # a CPU tensor here makes torch.cat fail when running on CUDA.
        angle = reg_models[t](torch.cat([test_paths_reg[-1]/torch.tensor([20, 10], device = device),
                                         test_winds[:, t].view(test_size, 1),
                                         torch.ones(test_size, 1, device = device)], dim = 1))
        heading = torch.cat([torch.cos(angle), torch.sin(angle)], dim = 1)
        wind_vec = torch.cat([torch.zeros(test_size, 1, device = device), test_winds[:, t].view(test_size, 1)], dim = 1)
        new_p = test_paths_reg[-1] + vs * heading + wind_vec
        test_paths_reg.append(new_p)
    test_paths_reg = torch.stack(test_paths_reg, dim = 1)

# Quick look: ERM in blue, entropy-regularised in red.
for pth in range(test_size):
    plt.plot(test_paths_erm[pth, :, 0].detach().cpu(), test_paths_erm[pth, :, 1].detach().cpu(), color = "blue", linewidth = 0.5)
    plt.plot(test_paths_reg[pth, :, 0].detach().cpu(), test_paths_reg[pth, :, 1].detach().cpu(), color = "red", linewidth = 0.5)

In [None]:
# plotting parameters
# Large fonts for the full-HD-sized figures below: 30 pt for titles and axis
# labels, 25 pt for everything else.
font_sizes = {
    "axes.titlesize": 30,
    "axes.labelsize": 30,
    "xtick.labelsize": 25,
    "ytick.labelsize": 25,
    "legend.fontsize": 25,
    "font.size": 25,
}
for rc_key, rc_value in font_sizes.items():
    plt.rcParams[rc_key] = rc_value

In [None]:
# in-sample of both methods
# Legend proxies: the trajectories below are drawn line by line without labels.
red = mpatches.Patch(color="red", label="ERM")
green = mpatches.Patch(color="green", label="Entropy-Regularised")
fig, ax = plt.subplots(figsize=(19.2, 10.8), dpi=400,
                       constrained_layout = True)

# Obstacle backdrop: running cost on a grid over [-6, 6]^2; values below the 1.5
# threshold are set to NaN so only the obstacle region is shaded.
grid_res = 300  # Resolution of the grid
x_vals = torch.linspace(-6, 6, grid_res)
y_vals = torch.linspace(-6, 6, grid_res)
X, Y = torch.meshgrid(x_vals, y_vals, indexing='xy')
Z = running_cost(X, Y, A=2, M=10)*100
Z[Z < 1.5] = float('nan')
ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=10, cmap='Greys', alpha=0.8)

# Copy both path tensors to the CPU once, then draw each sample's ERM track
# followed by its regularised track.
erm_tracks = train_paths_erm.detach().cpu()
reg_tracks = train_paths_reg.detach().cpu()
for erm_track, reg_track in zip(erm_tracks, reg_tracks):
    ax.plot(erm_track[:, 0], erm_track[:, 1], color = "red", linewidth = 2)
    ax.plot(reg_track[:, 0], reg_track[:, 1], color = "green", linewidth = 2)

ax.set_title(f"In-Sample Trajectories, n = {n}")
targ = ax.scatter(x=20, y=0, label="Target", color="black", zorder=1000)
ax.legend(handles=[red, green, targ])
plt.show()
plt.close()

In [None]:
# compare out-of-sample trajectories
# Legend proxies: the trajectories below are drawn line by line without labels.
red = mpatches.Patch(color="red", label="ERM")
green = mpatches.Patch(color="green", label="Entropy-Regularised")
fig, ax = plt.subplots(figsize=(19.2, 10.8), dpi=400,
                       constrained_layout = True)

# Obstacle backdrop: running cost on a grid over [-6, 6]^2; values below the 1.5
# threshold are set to NaN so only the obstacle region is shaded.
grid_res = 300  # Resolution of the grid
x_vals = torch.linspace(-6, 6, grid_res)
y_vals = torch.linspace(-6, 6, grid_res)
X, Y = torch.meshgrid(x_vals, y_vals, indexing='xy')
Z = running_cost(X, Y, A=2, M=10)*100
Z[Z < 1.5] = float('nan')
ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=10, cmap='Greys', alpha=0.8)

# Copy both path tensors to the CPU once, then draw each test sample's ERM track
# followed by its regularised track.
erm_tracks = test_paths_erm.detach().cpu()
reg_tracks = test_paths_reg.detach().cpu()
for erm_track, reg_track in zip(erm_tracks, reg_tracks):
    ax.plot(erm_track[:, 0], erm_track[:, 1], color = "red", linewidth = 0.8)
    ax.plot(reg_track[:, 0], reg_track[:, 1], color = "green", linewidth = 0.8)

ax.set_title(f"Out-of-Sample Trajectories, test size of {test_size}")
targ = ax.scatter(x=20, y=0, label="Target", color="black", zorder=1000)
ax.legend(handles=[red, green, targ])
plt.show()
plt.close()

In [None]:
# erm in-sample vs. out-of-sample
# Legend proxies: the trajectories below are drawn line by line without labels.
red = mpatches.Patch(color="red", label="Out-of-Sample")
green = mpatches.Patch(color="green", label="In-Sample")
fig, ax = plt.subplots(figsize=(19.2, 10.8), dpi=400,
                       constrained_layout = True)

# Obstacle backdrop: running cost on a grid over [-6, 6]^2; values below the 1.5
# threshold are set to NaN so only the obstacle region is shaded.
grid_res = 300  # Resolution of the grid
x_vals = torch.linspace(-6, 6, grid_res)
y_vals = torch.linspace(-6, 6, grid_res)
X, Y = torch.meshgrid(x_vals, y_vals, indexing='xy')
Z = running_cost(X, Y, A=2, M=10)*100
Z[Z < 1.5] = float('nan')
ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=10, cmap='Greys', alpha=0.8)

# Thin out-of-sample tracks first, thicker in-sample tracks drawn on top.
for test_track in test_paths_erm.detach().cpu():
    ax.plot(test_track[:, 0], test_track[:, 1], color = "red", linewidth = 0.8)
for train_track in train_paths_erm.detach().cpu():
    ax.plot(train_track[:, 0], train_track[:, 1], color = "green", linewidth = 2)

ax.set_title(f"ERM Trajectories, n = {n}, test size of {test_size}")
targ = ax.scatter(x=20, y=0, label="Target", color="black", zorder=1000)
ax.legend(handles=[red, green, targ])
plt.show()
plt.close()

In [None]:
# entropy-regularised in-sample vs. out-of-sample
# Legend proxies: the trajectories below are drawn line by line without labels.
red = mpatches.Patch(color="red", label="Out-of-Sample")
green = mpatches.Patch(color="green", label="In-Sample")
fig, ax = plt.subplots(figsize=(19.2, 10.8), dpi=400,
                       constrained_layout = True)

# Obstacle backdrop: running cost on a grid over [-6, 6]^2; values below the 1.5
# threshold are set to NaN so only the obstacle region is shaded.
grid_res = 300  # Resolution of the grid
x_vals = torch.linspace(-6, 6, grid_res)
y_vals = torch.linspace(-6, 6, grid_res)
X, Y = torch.meshgrid(x_vals, y_vals, indexing='xy')
Z = running_cost(X, Y, A=2, M=10)*100
Z[Z < 1.5] = float('nan')
ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=10, cmap='Greys', alpha=0.8)

# Thin out-of-sample tracks first, thicker in-sample tracks drawn on top.
for test_track in test_paths_reg.detach().cpu():
    ax.plot(test_track[:, 0], test_track[:, 1], color = "red", linewidth = 0.8)
for train_track in train_paths_reg.detach().cpu():
    ax.plot(train_track[:, 0], train_track[:, 1], color = "green", linewidth = 2)

ax.set_title(f"Regularised Trajectories, n = {n}, test size of {test_size}, regularisation strength = {sigma**2/(2*beta**2):.3f}")
targ = ax.scatter(x=20, y=0, label="Target", color="black", zorder=1000)
ax.legend(handles=[red, green, targ])
plt.show()
plt.close()

In [None]:
# Sanity check on the stored training winds: the rollouts index them as
# training_winds[:, t], i.e. dim 0 is the sample and dim 1 the time step.
training_winds.shape

In [None]:
# generate the animations
# erm

# Reference rollout: zero steering angle, displaced in y by the training winds.
ref_control = torch.zeros(n, 1, device=device)
ref_path = [initial_points.to(device)]

for pos in range(T):
    wind_y = training_winds[:, pos].view(n, 1).to(device)
    wind_vector = torch.cat([torch.zeros_like(wind_y), wind_y], dim=1)
    heading = torch.cat([torch.cos(ref_control), torch.sin(ref_control)], dim=1)
    velocity = vs * heading + wind_vector
    new_p = ref_path[-1] + velocity
    ref_path.append(new_p)

# Backward rollout
# anim_paths[str(t)] follows the reference path up to step t and the ERM networks
# t, t + 1, ..., T - 1 afterwards; wind_values[str(t)] holds one wind per stored
# position so path segments can be coloured by the wind that produced them.
anim_paths = {}
wind_values = {}

anim_paths[f"{T}"] = ref_path
# Inference only: without no_grad every one of the T (T + 1) / 2 forward passes
# would keep its autograd graph alive inside anim_paths.
with torch.no_grad():
    for t in range(T - 1, -1, -1): # T - 1, T - 2, ..., 0
        curr_path = ref_path[:t + 1]  # slicing copies the list, so ref_path is not mutated
        curr_wind = [training_winds[:, i].view(n, 1) for i in range(t + 1)]

        curr_p = curr_path[-1]
        for rem in range(0, T - t):
            wind_val = training_winds[:, t + rem].view(n, 1).to(device)
            wind_vector = torch.cat([torch.zeros_like(wind_val), wind_val], dim=1)
            input_tensor = torch.cat((curr_p/torch.tensor([20, 10], device = device), wind_val), dim=1)

            curr_control = erm_models[t + rem](input_tensor)
            heading = torch.cat([torch.cos(curr_control), torch.sin(curr_control)], dim=1)
            velocity = vs * heading + wind_vector
            curr_p = curr_p + velocity

            curr_path.append(curr_p)
            curr_wind.append(wind_val)

        anim_paths[f"{t}"] = curr_path
        wind_values[f"{t}"] = curr_wind

# Prepare data for animation
plt.rcParams['animation.embed_limit'] = 100  # in MB
# One frame per backward-induction step, ordered T, T - 1, ..., 0. Each frame is
# (step, positions (steps, N, 2), winds (steps, N) or None when none were stored).
frame_data = []
for key in sorted(anim_paths, key=int, reverse=True):
    positions = torch.stack(anim_paths[key], dim=0).detach().cpu()
    wind_list = wind_values.get(key, [])
    if wind_list:
        wind_vals = torch.stack(wind_list, dim=0).squeeze(-1).detach().cpu()
    else:
        wind_vals = None
    frame_data.append((int(key), positions, wind_vals))

# Setup plot
# Axis half-widths of the animation. They get their own names so that the `L_x`
# loaded from the checkpoint (and passed to terminal_cost further down) is not
# overwritten by a plotting constant.
x_lim, y_lim = 20, 6
num_paths_to_show = n
interval = 100  # delay between frames, in milliseconds


plt.rcParams.update({
    "axes.titlesize": 30,
    "axes.labelsize": 30,
    "xtick.labelsize": 25,
    "ytick.labelsize": 25,
    "legend.fontsize": 25,
    "font.size": 25
})
fig, ax = plt.subplots(figsize=(19.2, 10.8), constrained_layout=True)
ax.set_xlim(-x_lim, x_lim)
ax.set_ylim(-y_lim, y_lim)
ax.plot(20, 0, 'ko', markersize=20, zorder=1000)  # target marker

# Obstacle contour: values below the 1.5 threshold are set to NaN so only the
# obstacle region is shaded.
grid_res = 300
x_vals = torch.linspace(-6, 6, grid_res)
y_vals = torch.linspace(-6, 6, grid_res)
X, Y = torch.meshgrid(x_vals, y_vals, indexing='xy')
Z = running_cost(X, Y, A=2, M=10) * 100
Z[Z < 1.5] = float('nan')
ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=10, cmap='Greys', alpha=0.8)

# Plot elements: one dashed grey line per path for the already-fixed past, plus a
# list of LineCollections (the coloured future) that is rebuilt on every frame.
grey_lines = [ax.plot([], [], linestyle='--', color='grey', lw=1)[0] for _ in range(num_paths_to_show)]
collections = []

# Colour scale shared by all frames, spanning the range of the training winds.
norm = plt.Normalize(training_winds.min().item(), training_winds.max().item())
cmap = plt.cm.plasma
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label("Wind Value")

# --- Animation functions ---
def init():
    """Reset the animation: empty every grey line and drop all coloured collections.

    Returns the list of grey line artists.
    """
    for line in grey_lines:
        line.set_data([], [])
    # Detach each collection from the axes while emptying the shared list.
    while collections:
        collections.pop().remove()
    return grey_lines

def update(frame_idx):
    """Draw frame ``frame_idx``: grey dashed past plus wind-coloured future for each path.

    Reads ``frame_data[frame_idx]`` as ``(step, positions, winds)``. Positions up to
    ``step`` are drawn on the grey lines; the remaining segments are drawn as a
    ``LineCollection`` coloured by the wind values stored from index ``step + 1`` on.

    Returns the grey lines followed by the collections created for this frame.
    """
    # Drop the coloured segments left over from the previous frame.
    while collections:
        collections.pop().remove()

    epoch_idx, positions, winds = frame_data[frame_idx]
    ax.set_title(f"ERM, Backward Inductive Step = {epoch_idx}")

    n_steps, n_tracks = positions.shape[0], positions.shape[1]
    has_future = winds is not None and epoch_idx < n_steps - 1

    for i in range(min(num_paths_to_show, n_tracks)):
        xs = positions[:, i, 0].numpy()
        ys = positions[:, i, 1].numpy()

        # Grey past
        grey_lines[i].set_data(xs[:epoch_idx + 1], ys[:epoch_idx + 1])
        if not has_future:
            continue

        # Color future: consecutive points are paired into (start, end) segments.
        future_pts = np.stack([xs[epoch_idx:], ys[epoch_idx:]], axis=-1)
        segs = np.stack([future_pts[:-1], future_pts[1:]], axis=1)
        colors = cmap(norm(winds[epoch_idx + 1:, i].numpy()))

        lc = LineCollection(segs, colors=colors, linewidths=2)
        ax.add_collection(lc)
        collections.append(lc)

    return grey_lines + collections

# One animation frame per entry of frame_data.
n_frames = len(frame_data)
ani = animation.FuncAnimation(
    fig,
    func=update,
    frames=n_frames,
    init_func=init,
    interval=interval,
    blit=False,
)

# Embed the animation in the notebook as an interactive JavaScript player.
HTML(ani.to_jshtml())

In [None]:
# entropy-regularised
# Reference rollout: zero steering angle, displaced in y by the training winds.
ref_control = torch.zeros(n, 1, device=device)
ref_path = [initial_points.to(device)]

for pos in range(T):
    wind_y = training_winds[:, pos].view(n, 1).to(device)
    wind_vector = torch.cat([torch.zeros_like(wind_y), wind_y], dim=1)
    heading = torch.cat([torch.cos(ref_control), torch.sin(ref_control)], dim=1)
    velocity = vs * heading + wind_vector
    new_p = ref_path[-1] + velocity
    ref_path.append(new_p)

# Backward rollout
# anim_paths[str(t)] follows the reference path up to step t and the regularised
# networks t, t + 1, ..., T - 1 afterwards; wind_values[str(t)] holds one wind per
# stored position so path segments can be coloured by the wind that produced them.
anim_paths = {}
wind_values = {}

anim_paths[f"{T}"] = ref_path
# Inference only: without no_grad every one of the T (T + 1) / 2 forward passes
# would keep its autograd graph alive inside anim_paths.
with torch.no_grad():
    for t in range(T - 1, -1, -1): # T - 1, T - 2, ..., 0
        curr_path = ref_path[:t + 1]  # slicing copies the list, so ref_path is not mutated
        curr_wind = [training_winds[:, i].view(n, 1) for i in range(t + 1)]

        curr_p = curr_path[-1]
        for rem in range(0, T - t):
            wind_val = training_winds[:, t + rem].view(n, 1).to(device)
            wind_vector = torch.cat([torch.zeros_like(wind_val), wind_val], dim=1)
            # Inputs: scaled position, wind, and a constant-one column.
            input_tensor = torch.cat((curr_p/torch.tensor([20, 10], device = device),
                                      wind_val, torch.ones(n, 1, device = device)), dim=1)

            curr_control = reg_models[t + rem](input_tensor)
            heading = torch.cat([torch.cos(curr_control), torch.sin(curr_control)], dim=1)
            velocity = vs * heading + wind_vector
            curr_p = curr_p + velocity

            curr_path.append(curr_p)
            curr_wind.append(wind_val)

        anim_paths[f"{t}"] = curr_path
        wind_values[f"{t}"] = curr_wind

# Prepare data for animation
plt.rcParams['animation.embed_limit'] = 100  # in MB
# One frame per backward-induction step, ordered T, T - 1, ..., 0. Each frame is
# (step, positions (steps, N, 2), winds (steps, N) or None when none were stored).
frame_data = []
for key in sorted(anim_paths, key=int, reverse=True):
    positions = torch.stack(anim_paths[key], dim=0).detach().cpu()
    wind_list = wind_values.get(key, [])
    if wind_list:
        wind_vals = torch.stack(wind_list, dim=0).squeeze(-1).detach().cpu()
    else:
        wind_vals = None
    frame_data.append((int(key), positions, wind_vals))

# Setup plot
# Axis half-widths of the animation. They get their own names so that the `L_x`
# loaded from the checkpoint (and passed to terminal_cost further down) is not
# overwritten by a plotting constant.
x_lim, y_lim = 20, 6
num_paths_to_show = n
interval = 100  # delay between frames, in milliseconds


plt.rcParams.update({
    "axes.titlesize": 30,
    "axes.labelsize": 30,
    "xtick.labelsize": 25,
    "ytick.labelsize": 25,
    "legend.fontsize": 25,
    "font.size": 25
})
fig, ax = plt.subplots(figsize=(19.2, 10.8), constrained_layout=True)
ax.set_xlim(-x_lim, x_lim)
ax.set_ylim(-y_lim, y_lim)
ax.plot(20, 0, 'ko', markersize=20, zorder=1000)  # target marker

# Obstacle contour: values below the 1.5 threshold are set to NaN so only the
# obstacle region is shaded.
grid_res = 300
x_vals = torch.linspace(-6, 6, grid_res)
y_vals = torch.linspace(-6, 6, grid_res)
X, Y = torch.meshgrid(x_vals, y_vals, indexing='xy')
Z = running_cost(X, Y, A=2, M=10) * 100
Z[Z < 1.5] = float('nan')
ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=10, cmap='Greys', alpha=0.8)

# Plot elements: one dashed grey line per path for the already-fixed past, plus a
# list of LineCollections (the coloured future) that is rebuilt on every frame.
grey_lines = [ax.plot([], [], linestyle='--', color='grey', lw=1)[0] for _ in range(num_paths_to_show)]
collections = []

# Colour scale shared by all frames, spanning the range of the training winds.
norm = plt.Normalize(training_winds.min().item(), training_winds.max().item())
cmap = plt.cm.plasma
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label("Wind Value")

# --- Animation functions ---
def init():
    """Reset the animation: empty every grey line and drop all coloured collections.

    Returns the list of grey line artists.
    """
    for line in grey_lines:
        line.set_data([], [])
    # Detach each collection from the axes while emptying the shared list.
    while collections:
        collections.pop().remove()
    return grey_lines

def update(frame_idx):
    """Draw frame ``frame_idx``: grey dashed past plus wind-coloured future for each path.

    Reads ``frame_data[frame_idx]`` as ``(step, positions, winds)``. Positions up to
    ``step`` are drawn on the grey lines; the remaining segments are drawn as a
    ``LineCollection`` coloured by the wind values stored from index ``step + 1`` on.

    Returns the grey lines followed by the collections created for this frame.
    """
    # Drop the coloured segments left over from the previous frame.
    while collections:
        collections.pop().remove()

    epoch_idx, positions, winds = frame_data[frame_idx]
    ax.set_title(f"Regularised, Backward Inductive Step = {epoch_idx}")

    n_steps, n_tracks = positions.shape[0], positions.shape[1]
    has_future = winds is not None and epoch_idx < n_steps - 1

    for i in range(min(num_paths_to_show, n_tracks)):
        xs = positions[:, i, 0].numpy()
        ys = positions[:, i, 1].numpy()

        # Grey past
        grey_lines[i].set_data(xs[:epoch_idx + 1], ys[:epoch_idx + 1])
        if not has_future:
            continue

        # Color future: consecutive points are paired into (start, end) segments.
        future_pts = np.stack([xs[epoch_idx:], ys[epoch_idx:]], axis=-1)
        segs = np.stack([future_pts[:-1], future_pts[1:]], axis=1)
        colors = cmap(norm(winds[epoch_idx + 1:, i].numpy()))

        lc = LineCollection(segs, colors=colors, linewidths=2)
        ax.add_collection(lc)
        collections.append(lc)

    return grey_lines + collections

# One animation frame per entry of frame_data.
n_frames = len(frame_data)
ani = animation.FuncAnimation(
    fig,
    func=update,
    frames=n_frames,
    init_func=init,
    interval=interval,
    blit=False,
)

# Embed the animation in the notebook as an interactive JavaScript player.
HTML(ani.to_jshtml())

In [None]:
def _log_terminal_cost(paths, count):
    """Log of the terminal cost at the last stored position of each of ``count`` paths."""
    end_x = paths[:, -1, 0].view(count, 1)
    end_y = paths[:, -1, 1].view(count, 1)
    cost = terminal_cost(end_x, end_y, L_x)
    # convert to numpy, then take logs
    return np.log(np.array(cost[:, 0].detach().cpu()))

terminal_costs_erm_train = _log_terminal_cost(train_paths_erm, n)
terminal_costs_erm_test = _log_terminal_cost(test_paths_erm, test_size)
terminal_costs_reg_train = _log_terminal_cost(train_paths_reg, n)
terminal_costs_reg_test = _log_terminal_cost(test_paths_reg, test_size)

In [None]:
# Entropy-regularised model: density histograms of the log terminal cost,
# out-of-sample vs. in-sample, with a vertical line at each sample mean.
fig, ax = plt.subplots(figsize = (19.2, 10.8), constrained_layout = True)
hist_style = dict(density = True, alpha = 0.6)
ax.hist(terminal_costs_reg_test, bins = 20, color = "blue", label = "Out-of-Sample", **hist_style)
ax.hist(terminal_costs_reg_train, bins = 10, color = "orange", label = "In-Sample", **hist_style)
ax.set_xlim(-5, 5)
ax.set_ylim(0, 0.5)

tick_location_1 = np.mean(terminal_costs_reg_test)
tick_location_2 = np.mean(terminal_costs_reg_train)
ax.axvline(x=tick_location_1, color='blue', linewidth=2, linestyle = "-")
ax.axvline(x=tick_location_2, color='orange', linewidth=2)

ax.set_xlabel("Log of Squared Distance")
ax.set_title(f"Regularised Terminal Costs, n = {n}, Test Size = {test_size}, Regularisation Strength = {sigma**2/(2*beta**2):.3f}")
ax.legend()
plt.show()

In [None]:
# ERM model: density histograms of the log terminal cost, out-of-sample vs.
# in-sample, with a vertical line at each sample mean.
fig, ax = plt.subplots(figsize = (19.2, 10.8), constrained_layout = True)
hist_style = dict(density = True, alpha = 0.6)
ax.hist(terminal_costs_erm_test, bins = 20, color = "blue", label = "Out-of-Sample", **hist_style)
ax.hist(terminal_costs_erm_train, bins = 10, color = "orange", label = "In-Sample", **hist_style)
ax.set_xlim(-5, 5)
ax.set_ylim(0, 0.5)

tick_location_1 = np.mean(terminal_costs_erm_test)
tick_location_2 = np.mean(terminal_costs_erm_train)
ax.axvline(x=tick_location_1, color='blue', linewidth=2, linestyle = "-")
ax.axvline(x=tick_location_2, color='orange', linewidth=2)

ax.set_xlabel("Log of Squared Distance")
ax.set_title(f"ERM Terminal Costs, n = {n}, Test Size = {test_size}")
ax.legend()
plt.show()

In [None]:
# Quick look at the ERM out-of-sample total cost on a log scale: running cost
# summed over every stored position of a path, plus the terminal cost at its end.
erm_test_running = torch.sum(running_cost(test_paths_erm[:, :, 0], test_paths_erm[:, :, 1], A, M),
                             dim = 1).detach().cpu().numpy()
erm_test_terminal = terminal_cost(test_paths_erm[:, -1, 0].view(test_size, 1),
                                  test_paths_erm[:, -1, 1].view(test_size, 1),
                                  L_x).detach().cpu().numpy()[:, 0]
plt.hist(np.log(erm_test_running + erm_test_terminal))

In [None]:
# let's do the same, but with the obstacle cost included too
def _log_total_cost(paths, count):
    """Log of (running cost summed over all stored positions + terminal cost) for ``count`` paths."""
    obstacle_part = torch.sum(running_cost(paths[:, :, 0], paths[:, :, 1], A, M), dim = 1).detach().cpu().numpy()
    arrival_part = terminal_cost(paths[:, -1, 0].view(count, 1),
                                 paths[:, -1, 1].view(count, 1), L_x).detach().cpu().numpy()[:, 0]
    return np.log(obstacle_part + arrival_part)

full_cost_train_erm = _log_total_cost(train_paths_erm, n)
full_cost_test_erm = _log_total_cost(test_paths_erm, test_size)
full_cost_train_reg = _log_total_cost(train_paths_reg, n)
full_cost_test_reg = _log_total_cost(test_paths_reg, test_size)


In [None]:
# ERM model: density histograms of the log total cost, out-of-sample vs.
# in-sample, with a vertical line at each sample mean.
fig, ax = plt.subplots(figsize = (19.2, 10.8), constrained_layout = True)
hist_style = dict(density = True, alpha = 0.6)
ax.hist(full_cost_test_erm, bins = 20, color = "blue", label = "Out-of-Sample", **hist_style)
ax.hist(full_cost_train_erm, bins = 10, color = "orange", label = "In-Sample", **hist_style)
ax.set_xlim(-5, 5)
ax.set_ylim(0, 0.5)

tick_location_1 = np.mean(full_cost_test_erm)
tick_location_2 = np.mean(full_cost_train_erm)
ax.axvline(x=tick_location_1, color='blue', linewidth=2, linestyle = "-")
ax.axvline(x=tick_location_2, color='orange', linewidth=2)

ax.set_xlabel("Log of Total Cost")
ax.set_title(f"ERM Costs, n = {n}, Test Size = {test_size}")
ax.legend()
plt.show()

In [None]:
# Entropy-regularised model: density histograms of the log total cost,
# out-of-sample vs. in-sample, with a vertical line at each sample mean.
fig, ax = plt.subplots(figsize = (19.2, 10.8), constrained_layout = True)
hist_style = dict(density = True, alpha = 0.6)
ax.hist(full_cost_test_reg, bins = 20, color = "blue", label = "Out-of-Sample", **hist_style)
ax.hist(full_cost_train_reg, bins = 10, color = "orange", label = "In-Sample", **hist_style)
ax.set_xlim(-5, 5)
ax.set_ylim(0, 0.5)

tick_location_1 = np.mean(full_cost_test_reg)
tick_location_2 = np.mean(full_cost_train_reg)
ax.axvline(x=tick_location_1, color='blue', linewidth=2, linestyle = "-")
ax.axvline(x=tick_location_2, color='orange', linewidth=2)

ax.set_xlabel("Log of Total Cost")
ax.set_title(f"Entropy-Regularised Costs, n = {n}, Test Size = {test_size}, Regularisation Strength = {sigma**2/(2*beta**2):.3f}")
ax.legend()
plt.show()