In [2]:
## Imports
import math
import os
import time
from math import pi

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchdyn.core import NeuralODE
from torchdyn.numerics import odeint, Euler, HyperEuler

from models.FNPEG import EOM_FNPEG_long

ModuleNotFoundError: No module named 'torchdyn'

In [None]:
# Loss function declaration
# Use the GPU when one is present so the notebook also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sf_star = torch.Tensor([7.952]).to(device)
# NOTE(review): `LossFunc` is not defined in any visible cell or import —
# TODO define/import it so the notebook survives Restart & Run All.
loss_func = LossFunc(sf_star)

# Time span
t0, tf = 0, 2  # initial and final time for controlling the system
steps = 20 + 1  # so we have a time step of 0.1s
t_span = torch.linspace(t0, tf, steps).to(device)

lRef = 6.378135  # km
# torch.sqrt only accepts Tensors; lRef / 9.81 is a plain Python float.
# NOTE(review): lRef is commented as km while 9.81 looks like Earth's g in m/s^2,
# and this scenario uses Mars — confirm the intended reference length/gravity.
tRef = math.sqrt(lRef / 9.81)
Rmars = 3390.0

# Nominal initial state; angles are given in degrees and converted to radians.
# Presumably ordered (r, theta, phi, v, gamma, psi, s) as EOM_FNPEG_long expects — TODO confirm.
r0, theta0, phi0, v0, gamma0, psi0, s0 = Rmars + 130.0, 90*pi/180, 45*pi/180, 4000.0, -15*pi/180, 70*pi/180, 0.0
x0 = torch.tensor([r0, theta0, phi0, v0, gamma0, psi0, s0], device=device)
# TODO: the training cell samples from `init_dist`, which is not defined anywhere;
# build it here (e.g. a distribution around x0) before training.

# Problem parameters

# planet
planet = {}
# One full revolution (2*pi rad) per 24.6 h rotation period.
planet['Omega'] = 2 * pi / (24.6 * 3600)  # rad/s

# vehicle
vehicle = {}
vehicle['B0'] = 155  # kg/m^2
vehicle['LD'] = 0.15

# Guidance
guid = {}
guid['filter'] = {}
guid['filter']['rho_L'] = 1.0
guid['filter']['rho_D'] = 1.0
guid['FNPEG'] = {}
guid['FNPEG']['bankProfile'] = 'linear'
muM = 4.282837e13
# NOTE(review): muM looks like Mars' GM in m^3/s^2, while r0 is built from Rmars + 130.0
# and eF is commented as km^2/s^2 — confirm the unit system EOM_FNPEG_long expects.
e0 = muM/r0 - 0.5 * v0**2
sigmaF = 1.0
eF = 11.8944  # km^2/s^2
sigma0 = 90*pi/180

In [None]:
## Test integrator
# Smoke test: build the dynamics and integrate the nominal state x0 once.
# Place the module on the same `device` as x0/t_span instead of hard-coding 'cuda'.
dyn = EOM_FNPEG_long(sigma0, e0, sigmaF, eF, lRef, tRef, planet, vehicle, guid).to(device)
# torchdyn's odeint returns (t_eval, solution); unpack it so `trajectory` is the
# solution tensor, matching how the training loop uses it.
t_eval, trajectory = odeint(dyn, x0, t_span, solver='tsit5', atol=1e-7, rtol=1e-7)

In [None]:


# Hyperparameters
lr = 3e-3
epochs = 500
bs = 1024
# NOTE(review): `u` (the trainable model) and `init_dist` (the initial-state
# distribution) are not defined in any visible cell — TODO define them above
# so this cell survives Restart & Run All.
opt = torch.optim.Adam(u.parameters(), lr=lr)

# Training loop
# Dedicated names for the wall-clock start and the sampled batch, so this cell does
# not overwrite the integration start time `t0` or the nominal state `x0`.
train_start = time.time()
losses = []
for e in range(epochs):
    x0_batch = init_dist.sample((bs,)).to(device)
    _, trajectory = odeint(dyn, x0_batch, t_span, solver='tsit5', atol=1e-7, rtol=1e-7)
    loss = loss_func(trajectory)
    loss_val = loss.item()  # Python float; detaches and copies to host
    losses.append(loss_val)
    loss.backward()
    opt.step()
    opt.zero_grad()
    print('Loss {:.4f} , epoch {}'.format(loss_val, e), end='\r')
timing = time.time() - train_start
print('\nTraining time: {:.4f} s'.format(timing))

