In [1]:
import numpy as np
import crocoddyl
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.utils import shuffle
torch.set_default_dtype(torch.double)

if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [2]:
#https://discuss.pytorch.org/t/advice-on-implementing-input-and-output-data-scaling/64369/2

In [3]:
#.... HYPERPARAMS



BATCHSIZE     = 32
lr            = 1e-3
DECAY         = 0.09

DEVICE        = 'cpu'

In [4]:
class PolicyNetwork(nn.Module):

        def __init__(self, 
                     input_dims:int  = 3,
                     output_dims:int = 60,
                     fc1_dims:int    = 256,
                     fc2_dims:int    = 256,                     
                     activation      = nn.ELU(),
                     device:str      = 'cpu'
                     ):


                super(PolicyNetwork, self).__init__()
                """Instantiate an untrained neural network with the given params

                Args
                ........
                        
                        1: input_dims   = dimensions of the state space of the robot. 3 for unicycle
                        2: output_dims  = dimensions of the next state
                        3: fc1_dims     = number of units in the first fully connected layer. Default 100
                        4: fc2_dims     = number of units in the second fully connected layer. Default 100
                        5: fc3_dims     = number of units in the third fully connected layer. Default 1
                        6: activation   = activation for the layers, default ReLU.
                        7: device       = device for computations. Generally CPU

                """

                self.input_dims    = input_dims
                self.output_dims   = output_dims
                self.fc1_dims      = fc1_dims
                self.fc2_dims      = fc2_dims
                self.activation1    = activation
                self.activation2 = nn.ReLU()


                #........... Structure
                self.fc1 = nn.Linear(self.input_dims, self.fc1_dims)
                self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
                self.fc3 = nn.Linear(self.fc2_dims, self.output_dims)

                #........... Weight Initialization protocol
                nn.init.xavier_uniform_(self.fc1.weight)
                nn.init.xavier_uniform_(self.fc2.weight)
                nn.init.xavier_uniform_(self.fc3.weight)

                
                #........... Bias Initialization protocol
                nn.init.constant_(self.fc1.bias, 0.003)
                nn.init.constant_(self.fc2.bias, 0.003)
                nn.init.constant_(self.fc3.bias, 0.003)
                
                # Send the neural net to device
                self.device = torch.device(device)
                self.to(self.device)


                
        def forward(self, state):

                next_state = self.activation1(self.fc1(state))
                next_state = self.activation1(self.fc2(next_state))
                next_state = self.fc3(next_state)
                
                return next_state

        def guess_xs(self, state, horizon=20):
                """
                Given a starting state, predict the state trajectory for the entire length of the horizon.
                The predicted trajectory should be of length horion +1
                """
                xs = []
                xs.append(state)

                for _ in range(horizon):
                        state = self(state)
                        xs.append(state)

                return torch.stack(xs).cpu().detach().numpy().reshape(horizon+1,3)


In [5]:
net = PolicyNetwork()

In [6]:
def griddedData(n_points:int = 1500,
                xy_limits:list = [-1.9,1.9],
                theta_limits:list = [-np.pi/2, np.pi/2]
                ):
    size = int(np.sqrt(n_points)) + 1
    min_x, max_x = [*xy_limits]
    xrange = np.linspace(min_x,max_x,size, endpoint=True)
    trange = np.linspace(*theta_limits, size, endpoint=True)
    points = np.array([ [x1,x2, x3] for x1 in xrange for x2 in xrange for x3 in trange])
    np.round_(points, decimals=6)
    np.random.shuffle(points)
    points = points[0:n_points, : ]
    return points

In [7]:
states = griddedData(500, xy_limits=[-1.9,1.9],theta_limits=[-np.pi/2, np.pi/2])
x = []
y = []
for x0 in tqdm(states):
    model = crocoddyl.ActionModelUnicycle()
    model.costweights = np.array([1., 1.]).T
    problem = crocoddyl.ShootingProblem(x0.T, [model]*20, model)
    ddp = crocoddyl.SolverDDP(problem)
    ddp.solve([], [] , 1000)
    
    xs = np.array(ddp.xs[1:]).flatten()
    """
    xs_us = np.hstack((np.array(ddp.xs[1:]), np.array(ddp.us)))
    np.round_(xs, 6)
    np.round_(xs_us, 6)
    assert len(xs) == len(xs_us) + 1
    for xs in xs[:-1]:
        x.append(xs)
        
    for xsus in xs_us:
        y.append(xsus)
    """
    x.append(x0)
    y.append(xs)
    
x = np.array(x) 

y = np.array(y).squeeze()


100%|██████████| 500/500 [00:00<00:00, 1421.23it/s]


In [8]:
x, y = shuffle(x, y, random_state=0)

xtest = torch.Tensor(x[0:1, :])
ytest = y[0:1, :]

xtrain = torch.Tensor(x[1:,:])
ytrain = torch.Tensor(y[1:,:])

dataset = torch.utils.data.TensorDataset(xtrain, ytrain)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = BATCHSIZE, shuffle=True)

#......  CRITERIA
criterion1 = torch.nn.MSELoss(reduction='sum')
criterion2 = torch.nn.L1Loss(reduction='mean')

opt = torch.optim.Adam(net.parameters(), lr = lr, betas=[0.5, 0.9], weight_decay=DECAY)
#opt = torch.optim.ASGD(net.parameters(), lr = lr, weight_decay=DECAY)
#opt = torch.optim.LBFGS(net.parameters(), history_size=10, max_iter=4)

net.to(DEVICE)
xtest.to(DEVICE)
for epoch in range(1000):
    for data, target in dataloader:
        net.train()
        opt.zero_grad()

        data        = data.to(DEVICE)
        target      = target.to(DEVICE)
        output      = net(data)
        
        
            
           
        
        loss        = criterion1(output, target) #+ 0.01*criterion2(output, target)
        loss.backward()
              
        opt.step()
        

    
    # Validation
    acc = 0
    acc2 = 0
    pred = net(xtest).detach().numpy()
    mae = np.mean(np.abs(pred - ytest))
    
    for i in range(len(xtest)):
        xtest[i].resize_(1, 3)
        pred = net(xtest[i]).detach().numpy()
        
        if np.mean(np.abs(pred - ytest[i])) < 0.01:
            acc += 1
        elif np.mean(np.abs(pred - ytest[i])) < 0.1:
            acc2 += 1
    print(f"EPOCH : {epoch} || MAE : {mae} || Within 0.01: {acc}/{len(xtest)} || Within 0.1 : {acc2}/{len(xtest)}")
    

EPOCH : 0 || MAE : 0.08741720879124067 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1 || MAE : 0.05853710176892415 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2 || MAE : 0.06538808201609016 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 3 || MAE : 0.04540931510330743 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 4 || MAE : 0.057146131626472634 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 5 || MAE : 0.05956207781839915 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 6 || MAE : 0.0663823975222021 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 7 || MAE : 0.05007852728251117 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 8 || MAE : 0.07214080869698876 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 9 || MAE : 0.08626231573304653 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 10 || MAE : 0.061905286694134774 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 11 || MAE : 0.0724912508926764 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 12 || MAE : 0.044906198145165235 || Within

EPOCH : 104 || MAE : 0.05001899598219807 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 105 || MAE : 0.04778333884130614 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 106 || MAE : 0.030576176960757213 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 107 || MAE : 0.03143755857978818 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 108 || MAE : 0.045657757883913956 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 109 || MAE : 0.04605613900995728 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 110 || MAE : 0.041497589761449814 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 111 || MAE : 0.0484670052669354 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 112 || MAE : 0.040108102254708115 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 113 || MAE : 0.030200007049621527 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 114 || MAE : 0.05100608679015513 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 115 || MAE : 0.03285659325846589 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 116 || MAE : 0.0

EPOCH : 206 || MAE : 0.03203804159693112 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 207 || MAE : 0.03776694159394753 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 208 || MAE : 0.028906858709207103 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 209 || MAE : 0.03177486147152317 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 210 || MAE : 0.03334664808265891 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 211 || MAE : 0.03150704950133147 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 212 || MAE : 0.03702214943311621 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 213 || MAE : 0.028886487745623943 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 214 || MAE : 0.05305862796529844 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 215 || MAE : 0.03998803101489861 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 216 || MAE : 0.020608820671595553 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 217 || MAE : 0.016108594705266523 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 218 || MAE : 0.0

EPOCH : 309 || MAE : 0.024205577895158658 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 310 || MAE : 0.015949390310722546 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 311 || MAE : 0.031230758423370333 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 312 || MAE : 0.03593690572654112 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 313 || MAE : 0.028964721634098425 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 314 || MAE : 0.02399217165270032 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 315 || MAE : 0.046011661560268334 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 316 || MAE : 0.025626750497036832 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 317 || MAE : 0.015249153746034978 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 318 || MAE : 0.0233756624647621 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 319 || MAE : 0.022393198050931932 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 320 || MAE : 0.033042303107114625 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 321 || MAE :

EPOCH : 412 || MAE : 0.025648478092696218 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 413 || MAE : 0.030414425073301176 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 414 || MAE : 0.016582080479562753 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 415 || MAE : 0.028739216694486362 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 416 || MAE : 0.023301715894938124 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 417 || MAE : 0.023483778569104427 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 418 || MAE : 0.020840761167006908 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 419 || MAE : 0.021720443991849716 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 420 || MAE : 0.025364165830681025 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 421 || MAE : 0.0291410120828468 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 422 || MAE : 0.023570225762701354 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 423 || MAE : 0.03229320719615241 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 424 || MAE 

EPOCH : 518 || MAE : 0.021666959884594408 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 519 || MAE : 0.017752938981049328 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 520 || MAE : 0.02324303732247673 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 521 || MAE : 0.02297619563835349 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 522 || MAE : 0.023349438873342857 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 523 || MAE : 0.02217696435729597 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 524 || MAE : 0.013501200684880577 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 525 || MAE : 0.018547318276174892 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 526 || MAE : 0.025255338360043473 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 527 || MAE : 0.01970233834763822 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 528 || MAE : 0.028458489495061575 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 529 || MAE : 0.02106917858453059 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 530 || MAE : 

EPOCH : 625 || MAE : 0.021287853653055137 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 626 || MAE : 0.022610064432452056 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 627 || MAE : 0.018450004813225417 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 628 || MAE : 0.024348889266024902 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 629 || MAE : 0.012445404375518434 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 630 || MAE : 0.029736423152646702 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 631 || MAE : 0.022005095388512978 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 632 || MAE : 0.0356664693163132 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 633 || MAE : 0.03205238874526413 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 634 || MAE : 0.021945448660204793 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 635 || MAE : 0.025202389764393544 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 636 || MAE : 0.03266756769834894 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 637 || MAE :

EPOCH : 732 || MAE : 0.02445689602343169 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 733 || MAE : 0.028284486752855532 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 734 || MAE : 0.016514524182385023 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 735 || MAE : 0.023936186909882994 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 736 || MAE : 0.023936343729866597 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 737 || MAE : 0.014834695974305303 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 738 || MAE : 0.027921432131788118 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 739 || MAE : 0.021880400738888448 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 740 || MAE : 0.02598611566267921 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 741 || MAE : 0.028217597414594398 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 742 || MAE : 0.015718240828986138 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 743 || MAE : 0.01857795417631673 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 744 || MAE 

EPOCH : 839 || MAE : 0.027622927220274627 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 840 || MAE : 0.011959906823150018 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 841 || MAE : 0.01709347863394401 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 842 || MAE : 0.02357987756137783 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 843 || MAE : 0.012828800495741133 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 844 || MAE : 0.016781265737458258 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 845 || MAE : 0.02013842319793366 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 846 || MAE : 0.010828597957688616 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 847 || MAE : 0.02330716562644604 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 848 || MAE : 0.02090095057245418 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 849 || MAE : 0.018253445874236513 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 850 || MAE : 0.025090090186791744 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 851 || MAE : 

EPOCH : 942 || MAE : 0.02278567673222018 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 943 || MAE : 0.020267246890223682 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 944 || MAE : 0.01718813077247243 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 945 || MAE : 0.022702711922323926 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 946 || MAE : 0.019084942401307863 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 947 || MAE : 0.021849122043181218 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 948 || MAE : 0.02561902562843738 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 949 || MAE : 0.01671741584245748 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 950 || MAE : 0.018075169526636778 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 951 || MAE : 0.017975523026416376 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 952 || MAE : 0.019284376357346047 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 953 || MAE : 0.017327517333281976 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 954 || MAE :

EPOCH : 1042 || MAE : 0.02090331018893676 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1043 || MAE : 0.012494242855642426 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1044 || MAE : 0.027180187390746343 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1045 || MAE : 0.016919676717768974 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1046 || MAE : 0.012716839686652283 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1047 || MAE : 0.02020119733194368 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1048 || MAE : 0.02378255734886057 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1049 || MAE : 0.024411397846976643 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1050 || MAE : 0.019306055815460185 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1051 || MAE : 0.017032512675975955 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1052 || MAE : 0.023954997605340922 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1053 || MAE : 0.021673333312703824 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH :

EPOCH : 1146 || MAE : 0.022081010424252447 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1147 || MAE : 0.015842871135081354 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1148 || MAE : 0.030782765017708182 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1149 || MAE : 0.029776014221781198 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1150 || MAE : 0.014675683956700766 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1151 || MAE : 0.02071267283614031 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1152 || MAE : 0.016755297618916736 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1153 || MAE : 0.020753750138246704 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1154 || MAE : 0.011460743382524171 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1155 || MAE : 0.014223060654094745 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1156 || MAE : 0.01750883916271077 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1157 || MAE : 0.0174364917364024 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 

EPOCH : 1252 || MAE : 0.025211803002241396 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1253 || MAE : 0.016527349818266538 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1254 || MAE : 0.021226151035637342 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1255 || MAE : 0.013833351994753591 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1256 || MAE : 0.027533386031810204 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1257 || MAE : 0.01986909212099013 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1258 || MAE : 0.017250839070899682 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1259 || MAE : 0.026416130672990968 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1260 || MAE : 0.019779548412140403 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1261 || MAE : 0.013794854731847623 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1262 || MAE : 0.019918474457520288 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1263 || MAE : 0.026955922294484548 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH

EPOCH : 1355 || MAE : 0.024612531158945496 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1356 || MAE : 0.015466873590878899 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1357 || MAE : 0.016151112522111554 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1358 || MAE : 0.01442680978406921 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1359 || MAE : 0.020187459365337876 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1360 || MAE : 0.017959085641822976 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1361 || MAE : 0.024409237794664657 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1362 || MAE : 0.012668945131171224 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1363 || MAE : 0.0216380269829844 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1364 || MAE : 0.023968699556111243 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1365 || MAE : 0.022265908962006423 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1366 || MAE : 0.022685313686926542 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH :

EPOCH : 1457 || MAE : 0.0200468645850858 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1458 || MAE : 0.01779139712707877 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1459 || MAE : 0.014082693724817467 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1460 || MAE : 0.018034937953407756 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1461 || MAE : 0.018157852700747594 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1462 || MAE : 0.009193107762612196 || Within 0.01: 1/1 || Within 0.1 : 0/1
EPOCH : 1463 || MAE : 0.01719226419017745 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1464 || MAE : 0.015075118328631171 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1465 || MAE : 0.021958614879453452 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1466 || MAE : 0.016116003240696335 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1467 || MAE : 0.02113578817823376 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1468 || MAE : 0.025579087026363287 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1

EPOCH : 1562 || MAE : 0.017092962405768304 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1563 || MAE : 0.020279259382361175 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1564 || MAE : 0.023166992539350957 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1565 || MAE : 0.021870516372444415 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1566 || MAE : 0.025380366572340236 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1567 || MAE : 0.01805869989014265 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1568 || MAE : 0.021159825445929356 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1569 || MAE : 0.025363510710861938 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1570 || MAE : 0.00902151669317788 || Within 0.01: 1/1 || Within 0.1 : 0/1
EPOCH : 1571 || MAE : 0.018111009782537827 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1572 || MAE : 0.021619322985410486 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1573 || MAE : 0.024617003398138396 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH 

EPOCH : 1664 || MAE : 0.019895365400969568 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1665 || MAE : 0.02147255750527948 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1666 || MAE : 0.01739891480364107 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1667 || MAE : 0.016151726479374455 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1668 || MAE : 0.023757220652938036 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1669 || MAE : 0.012205059815902978 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1670 || MAE : 0.02002758737772592 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1671 || MAE : 0.015370160708848215 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1672 || MAE : 0.01614431179576654 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1673 || MAE : 0.01569902114106678 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1674 || MAE : 0.020224578448805246 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1675 || MAE : 0.023922038569799176 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1

EPOCH : 1765 || MAE : 0.0108077835693003 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1766 || MAE : 0.019233150055549392 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1767 || MAE : 0.014686980301193892 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1768 || MAE : 0.013298985073242998 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1769 || MAE : 0.024296501784759384 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1770 || MAE : 0.03033708889181663 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1771 || MAE : 0.02322335409946776 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1772 || MAE : 0.021985368887886412 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1773 || MAE : 0.02274704327148091 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1774 || MAE : 0.013654636871478621 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1775 || MAE : 0.019716695484105193 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1776 || MAE : 0.019111908513949876 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1

EPOCH : 1868 || MAE : 0.01599362912524088 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1869 || MAE : 0.021320811027872106 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1870 || MAE : 0.02064619686128823 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1871 || MAE : 0.013479515594310099 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1872 || MAE : 0.020757575199343462 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1873 || MAE : 0.019377964918626565 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1874 || MAE : 0.01894642184216098 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1875 || MAE : 0.013983015644303494 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1876 || MAE : 0.017580451607612298 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1877 || MAE : 0.01947486090317361 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1878 || MAE : 0.018163913127988367 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1879 || MAE : 0.02192161868707188 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1

EPOCH : 1970 || MAE : 0.017603896888558727 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1971 || MAE : 0.017183862553854668 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1972 || MAE : 0.015531101405944029 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1973 || MAE : 0.01969488120093619 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1974 || MAE : 0.022902710238229474 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1975 || MAE : 0.014271639999848448 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1976 || MAE : 0.020355019066014904 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1977 || MAE : 0.013096700826176649 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1978 || MAE : 0.014172821284519134 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1979 || MAE : 0.014679634701870527 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1980 || MAE : 0.022174517084362737 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 1981 || MAE : 0.017736156504518708 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH

EPOCH : 2076 || MAE : 0.01392093154206845 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2077 || MAE : 0.022068366108068904 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2078 || MAE : 0.01317992376545422 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2079 || MAE : 0.015035372060291562 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2080 || MAE : 0.02879184925811706 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2081 || MAE : 0.028757900778269552 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2082 || MAE : 0.01718064534506895 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2083 || MAE : 0.017652263919932333 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2084 || MAE : 0.020472208980108302 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2085 || MAE : 0.016323585621796075 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2086 || MAE : 0.016951830176047678 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2087 || MAE : 0.0186895977156359 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 20

EPOCH : 2176 || MAE : 0.009754947785048125 || Within 0.01: 1/1 || Within 0.1 : 0/1
EPOCH : 2177 || MAE : 0.030499140049655975 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2178 || MAE : 0.020345618997533475 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2179 || MAE : 0.012153260561855119 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2180 || MAE : 0.016152808290023565 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2181 || MAE : 0.010680065617840288 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2182 || MAE : 0.016022254319172365 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2183 || MAE : 0.017279041058504616 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2184 || MAE : 0.014935167911189167 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2185 || MAE : 0.014539284663486335 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2186 || MAE : 0.012375710715217775 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH : 2187 || MAE : 0.01511999510149173 || Within 0.01: 0/1 || Within 0.1 : 1/1
EPOCH

KeyboardInterrupt: 

In [9]:
def guess(x):
    x = torch.Tensor(x.reshape(1, 3))
    xs, us = [], []
    xs.append(x)
    for _ in range(30):
        
        policy = net(x).detach().reshape(1, 5)
        
        xs.append(policy[:,0:3])
        us.append(policy[:,3:])
        x = policy[:,0:3]
    return torch.stack(xs).squeeze().numpy(), torch.stack(us).squeeze().numpy()

In [11]:
x  = np.array([np.random.uniform(-1.5,1.5), np.random.uniform(-1.5,1.5), 0])
#np.round_(x, 4)


    
#_xs, _us = guess(x)
_xs = []
_xs.append(x)
p_xs = net(torch.Tensor(x.reshape(1, 3))).detach().numpy().reshape(20, 3)
for _ in p_xs:
    _xs.append(_/1000)

_xs = np.array(_xs).reshape(31, 3)


model               = crocoddyl.ActionModelUnicycle()
model.costWeights   = np.array([1.,1.]).T
problem             = crocoddyl.ShootingProblem(x.T, [model]*20, model)
ddp                 = crocoddyl.SolverDDP(problem)
log                 = crocoddyl.CallbackLogger()
ddp.setCallbacks([log])
ddp.solve([], [], 1000)
stops = log.stops[1:]
xs = np.array(ddp.xs)

plt.clf()
fig, axs = plt.subplots(1, figsize=(8,6))
axs.plot(xs[:,0], xs[:,1], c = 'blue', label= " Crocoddyl")

axs.plot(_xs[:,0], _xs[:,1], c = 'grey', label= " Guess")
axs.set_xlim([-1.5,1.5 ])
axs.set_ylim([-1.5,1.5 ])

ValueError: cannot reshape array of size 63 into shape (31,3)