In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from pathlib import Path

import numpy as np
import torch
from feed_forward_net import FeedForwardNet
import crocoddyl
import example_robot_data
from crocoddyl.utils.pendulum import CostModelDoublePendulum, ActuationModelDoublePendulum
from datagen import Datagen
from ddp_solver import solve_problem
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Grid of initial conditions for the value-function comparison.
# NOTE(review): presumably `size` is points per axis and `limits` the [low, high]
# range of each coordinate — confirm against Datagen.grid_data.
data = Datagen.grid_data(
    size=30,
    limits=[-1.0, 1.0],
)

In [None]:
data.shape  # sanity-check the generated grid; presumably (n_conditions, state_dim) — TODO confirm

In [5]:
# Datagen.values over the whole grid is slow (an earlier run of this cell had to be
# interrupted), which makes Restart & Run All impractical. Cache the result on disk
# and reuse it on later runs.
# NOTE(review): the cache is NOT keyed on the contents of `data` or on the horizon —
# delete the file whenever either changes.
VALUES_CACHE = Path("values_cache_grid30_horizon200.npz")
if VALUES_CACHE.exists():
    with np.load(VALUES_CACHE) as cached:
        positions, values = cached["positions"], cached["values"]
else:
    positions, values = Datagen.values(init_conditions=data, horizon=200)
    # Normalise to arrays so a fresh run and a cached run yield the same types.
    positions, values = np.asarray(positions), np.asarray(values)
    np.savez(VALUES_CACHE, positions=positions, values=values)

KeyboardInterrupt: 

In [None]:
# Load the trained network, saved as a whole pickled module.
# map_location='cpu' lets a checkpoint saved from a GPU load on this machine; the
# evaluation cells below feed CPU tensors and call `.numpy()`, which needs CPU tensors.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On torch >= 2.6, `weights_only` defaults to True, which rejects a pickled nn.Module;
# pass weights_only=False there (or switch to saving/loading a state_dict).
net = torch.load('net.pth', map_location='cpu')
# Inference mode: disables dropout and uses running batch-norm statistics, if the model has any.
net.eval();

In [None]:
# Evaluate the network on exactly the states Crocoddyl returned.
position_tensor = torch.Tensor(positions)

# Pure inference: no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
    val_p = net(position_tensor).detach().numpy()


In [None]:
# Tensor copy of the Crocoddyl values (torch.tensor always copies its input).
# NOTE(review): `val_t` is not read by any cell visible here — consider dropping this cell.
val_t = torch.tensor(values)

In [None]:
# Flatten both sides: if the net returns an (N, 1) array and `values` is (N,),
# plain subtraction silently broadcasts to an (N, N) matrix and the mean is meaningless.
predicted_flat = np.asarray(val_p).reshape(-1)
target_flat = np.asarray(values).reshape(-1)
assert predicted_flat.shape == target_flat.shape, (
    f"prediction/target size mismatch: {predicted_flat.shape} vs {target_flat.shape}"
)
error = predicted_flat - target_flat

# The signed mean lets over- and under-estimates cancel out, so also report the absolute error.
print(f" Mean Error: {np.mean(error)}")
print(f" Mean Absolute Error: {np.mean(np.abs(error))}")


In [None]:
# NumPy view of the network inputs, used as scatter coordinates for the "Neural net" panel.
pos = position_tensor.detach().cpu().numpy()

In [None]:
# Stacked scatter plots: Crocoddyl values (top) vs. network predictions (bottom),
# both drawn over the same states.
font = {'family': 'serif',
        'color':  'darkred',
        'weight': 'normal',
        'size': 16,
        }

crocoddyl_values = np.asarray(values).reshape(-1)
predicted_values = np.asarray(val_p).reshape(-1)

# One colour scale for both panels: with independent scales the same colour means
# different values in each panel, which defeats the visual comparison.
color_min = min(crocoddyl_values.min(), predicted_values.min())
color_max = max(crocoddyl_values.max(), predicted_values.max())

fig, axs = plt.subplots(2, figsize=(4, 6), sharex=True, sharey=True)
# The panels are stacked vertically, so their spacing is `hspace`
# (`wspace` only separates columns and has no effect on this layout).
fig.subplots_adjust(left=0.02, bottom=0.2, right=0.8, top=0.8, hspace=0.45)
fig.suptitle('Comparisons of Value Functions. Double Pendulum', fontsize=15)

# The axes are shared, so setting ticks on one panel applies to both. This is set
# before plotting, in the same place the original pyplot calls ran.
axs[1].set_yticks(np.arange(-2, 2.1, 1))
axs[1].set_xticks(np.arange(-2, 2.1, 1))

im0 = axs[0].scatter(x=positions[:, 0], y=positions[:, 1], c=crocoddyl_values,
                     vmin=color_min, vmax=color_max)
axs[0].set_title("Crocoddyl", fontdict=font)

im1 = axs[1].scatter(x=pos[:, 0], y=pos[:, 1], c=predicted_values,
                     vmin=color_min, vmax=color_max)
axs[1].set_title("Neural net", fontdict=font)

fig.colorbar(im1, ax=axs[1]).set_label(" Predicted Value ", fontdict=font)
fig.colorbar(im0, ax=axs[0]).set_label(" Value ", fontdict=font)


In [None]:
def validate(net, n_samples=100, horizon=200, angle1_lim=(-1, 1), angle2_lim=(-1, 1)):
    """Compare `net` against freshly generated Crocoddyl values and print error statistics.

    Draws `n_samples` random initial conditions with
    ``Datagen.random_starting_conditions``, computes their values with
    ``Datagen.values`` over `horizon` steps, evaluates `net` on the returned
    positions and prints the signed mean error and the mean absolute error.

    Args:
        net: torch module mapping a float tensor of positions to predicted values.
        n_samples: number of random initial conditions to draw.
        horizon: horizon passed to ``Datagen.values``.
        angle1_lim, angle2_lim: sampling limits forwarded to
            ``Datagen.random_starting_conditions``. The previous hard-coded
            angle2 limit was the degenerate range [-1, -1]; presumably limits are
            [low, high], so it now defaults to [-1, 1] like angle1 and the grid
            used above — TODO confirm against Datagen.

    Raises:
        ValueError: if the network output and the targets differ in element count.
    """
    init_conds = Datagen.random_starting_conditions(
        n_samples, angle1_lim=list(angle1_lim), angle2_lim=list(angle2_lim))
    xt, yt = Datagen.values(init_conditions=init_conds, horizon=horizon)
    xt, yt = torch.Tensor(xt), torch.Tensor(yt)

    # Pure evaluation: no autograd graph needed.
    with torch.no_grad():
        yp = net(xt)

    if yp.numel() != yt.numel():
        raise ValueError(
            f"prediction/target size mismatch: {tuple(yp.shape)} vs {tuple(yt.shape)}")

    # Flatten so an (N, 1) prediction and an (N,) target cannot broadcast to (N, N).
    e = yp.reshape(-1) - yt.reshape(-1)

    # .item() prints a plain number instead of a tensor repr; the absolute error
    # is reported too because signed errors can cancel out.
    print(f" Mean Error: {torch.mean(e).item()}")
    print(f" Mean Absolute Error: {torch.mean(torch.abs(e)).item()}")


validate(net)