In [1]:
import torch
import torch.nn as nn
from torch.autograd import grad 

import pytorch_lightning as pl
from torchdyn.numerics.odeint import odeint_hybrid
from torchdyn.numerics.solvers import DormandPrince45
import torchdyn.numerics.sensitivity
import attr

## lietorch:
import sys; sys.path.append('../')
import lie_torch as lie
import math

from functorch import vmap

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

<h1> Only General Potential Shaping: gradients immediately on algebra, two potential terms on SE(3) </h1>

Class Definition: Mixed Dynamics on SE3 with NNs for potential and damping injection

In [2]:
from models import se3_Dynamics_2,SE3_Dynamics, SE3_Quadratic, AugmentedSE3, PosDefSym, PosDefTriv, PosDefSym_Small


Sensitivity / Adjoint Method:

This only needs the final condition of the forward dynamics and adjoint dynamics, then it computes both state and adjoint state from there

In [3]:
from sensitivity import _gather_odefunc_hybrid_adjoint_light

Learners: Training step, etc.

In [4]:
from learners import EnergyShapingLearner

Event definition: Chart Switching on SE(3)

In [5]:
# Adapted from pytorch-implicit/exampels/network/simulate_tcp.ipynb, and Paper: Neural Hybrid Automata: Learning Dynamics withMultiple Modes and Stochastic Transitions (M.Poli, 2021)

@attr.s
class EventCallback(nn.Module):
    def __attrs_post_init__(self):
        super().__init__()

    def check_event(self, t, x):
        raise NotImplementedError

    def jump_map(self, t, x):
        raise NotImplementedError
    
    def batch_jump(self, t, x, ev):
        raise NotImplementedError

@attr.s
class ChartSwitch(EventCallback):       
    def check_event(self, t, xi): 
        # works for collection of states
        qi, P, i, rem = xi[...,:6], xi[...,6:12], xi[...,12], xi[...,13:]
        wi = qi[...,:3]
        ev = (torch.sqrt(lie.dot(wi,wi)) > math.pi*3/4).bool() 
        return ev

    def jump_map(self, t, xi):
        #xi = torch.squeeze(xi)
        qi, P, i, rem = xi[...,:6], xi[...,6:12], xi[...,12], xi[...,13:]
        H = lie.unchart(qi, i)
        j = torch.unsqueeze(lie.bestChart(H),0)
        qj = lie.chart_trans(qi, i, j)
        return torch.cat((qj, P,j,rem), -1) #torch.unsqueeze(,0)
    
    def batch_jump(self, t, xi, ev):
        xi[ev,:] = vmap(self.jump_map)(t[ev],xi[ev,:])
        return xi
        
    def jump_map_forced(self, t, xi, j):
        # jump to chart j for all
        #xi = torch.squeeze(xi)
        qi, P, i, rem = xi[...,:6], xi[...,6:12], xi[...,12], xi[...,13:]
        qj = qi*0;
        for k in range(qi.size(0)):
            qj[k,:] = lie.chart_trans(qi[k,:], i[k], j[k])
        return torch.cat((qj, P,torch.unsqueeze(j,1),rem), -1) #torch.unsqueeze(,0)
    
    def batch_jump_forced(self, t, xi, j):
        xi = vmap(self.jump_map_forced)(t,xi,j)
        return xi
        
    
@attr.s
class ChartSwitchAugmented(EventCallback):
    des_props = None
    # Expects x of type: z[:6] = qi, z[6:12] = P, z[12] = i, z[13:25] = λi, z[25:] = μ. This is used for the system augmented with co-state-dynamics for adjoint gradient method
    def check_event(self, t, z): 
        xi, i, λi, rem = self.to_input(z)
        w = xi[...,:3]
        ev = (torch.sqrt(lie.dot(w,w)) < -1).bool()  
        return ev

    def batch_jump_to_Id(self, λiT, xT):
        #only transitions λiT to corresponding version λT at identity of SE(3)
        λT = vmap(lie.dchart_trans_mix_Co_to_Id)(λiT,xT[...,:12],xT[...,-1])
        return λT
    
    def jump_map(self, t, z):
        xi, i, λi, rem  = z[...,:12], z[...,12], z[...,13:25], z[...,25:]
        qi = xi[...,:6]
        H = lie.unchart(qi, i)
        j = lie.bestChart(H)
        xj = lie.chart_trans_mix(xi, i, j)
        λj = lie.chart_trans_mix_Co(xi, λi, i, j)
        return torch.cat((xj, torch.unsqueeze(j,-1), λj, rem), -1)
    
    def batch_jump(self, t, z, ev):
        xi, i, λii, rem = self.to_input(z)
        z = torch.cat((xi,torch.unsqueeze(i,-1),λii),-1)
        z[ev,:] = vmap(self.jump_map)(t[:xi.shape[0]][ev],z[ev,:])
        xi, i, λii = z[...,:12], z[...,12], z[...,13:26] 
        return self.to_output(xi, i, λii, rem)
    
    def to_input(self, z):
        if (self.des_props!=None):
            numels,shapes = tuple(self.des_props)
            xii_nel, λi_nel = tuple(numels)
            xii_shp, λi_shp = tuple(shapes)
            xii, λii, rem = z[:xii_nel], z[xii_nel:xii_nel+λi_nel], z[xii_nel+λi_nel:]
            xii, λii = xii.reshape(xii_shp), λii.reshape(λi_shp)
            xi, i = xii[...,:12], torch.unsqueeze(xii[...,12],-1)
            return xi, i, λii, rem
        else:
            xi, i, λi, rem  = z[...,:12], z[...,12], z[...,13:25], z[...,25:]
            return xi, i, λi, rem
    
    def to_output(self, xj, j, λj, rem):
        if (self.des_props!= None):
            xjj = torch.cat((xj,torch.unsqueeze(j,-1)),-1)
            z = torch.cat((xjj.flatten(),λj.flatten(),rem))
        else:
            z = torch.cat((xj, torch.unsqueeze(j,-1), λj, rem), -1)
        return z

Parameters of Dynamics, Definition of Loss-Function

In [6]:
# Adapted from latent-energy-shaping-main/notebooks/optimal_energy_shaping.ipynb

I = torch.diag(torch.tensor((0.01,0.01,0.01,1,1,1))).to(device) ; # Inertia Tensor

from models import IntegralLoss


Definition of NNs for potential and damping injection

In [7]:
nh = 32

V1 = nn.Sequential(nn.Linear(12, nh), nn.Softplus(), nn.Linear(nh, nh), nn.Tanh(), nn.Linear(nh, 1)).to(device)
V2 = nn.Sequential(nn.Linear(3, nh), nn.Softplus(), nn.Linear(nh, nh), nn.Tanh(), nn.Linear(nh, 1)).to(device)

### Likewise for Damping injection:
nf = 6
B = nn.Sequential(nn.Linear(18, nh), nn.Softplus(), nn.Linear(nh, nh), nn.Tanh(), nn.Linear(nh, nf),PosDefTriv()).to(device)

### Initialize Parameters: 

for p in V1.parameters(): torch.nn.init.normal_(p, mean=0.0, std=0.01)#torch.nn.init.zeros_(p)
for p in V2.parameters(): torch.nn.init.normal_(p, mean=0.0, std=0.01)
for p in B.parameters(): torch.nn.init.normal_(p, mean=0.0, std=0.01) #torch.nn.init.zeros_(p)


Initialization of Chart-Switches, Prior- \& Target Distribution, and Dynamics

In [8]:
### Prior and Target Distribution

from utils import prior_dist_SE3, target_dist_SE3_HP,  target_cost_SE3, multinormal_target_dist

th_max = torch.tensor(math.pi).to(device); d_max = torch.tensor(0).to(device); pw_max = torch.tensor(0.0).to(device); pv_max = torch.tensor(1.0).to(device); ch_min = torch.tensor(0).to(device); ch_max = torch.tensor(0).to(device); 
prior = prior_dist_SE3(th_max,d_max,pw_max,pv_max,ch_min,ch_max,device)

H_target = torch.eye(4).to(device); sigma_th = torch.tensor(4).to(device); sigma_d = torch.tensor(10).to(device); sigma_pw = torch.tensor(5).to(device); sigma_p = torch.tensor(1).to(device);
                                    # sigma_th = torch.tensor(0.4).to(device); sigma_d = torch.tensor(0.4).to(device); sigma_pw = torch.tensor(1e-1).to(device); sigma_p = torch.tensor(1e-1).to(device);
target = target_cost_SE3(H_target,sigma_th,sigma_d,sigma_pw,sigma_p,device)

integral_loss_scale = torch.tensor(0.01).to(device)

### Callback:
callbacks = [ChartSwitch()]
jspan = 10 # maximum number of chart switches per iteration (if this many happen, something is wrong anyhow)

callbacks_adjoint = [ChartSwitchAugmented()]
jspan_adjoint = 10

### Initialize Dynamics 
#(I,B,V1,V2) = torch.load('IBV_fShaping_07_07_11:57.pt') #('IBV_fShaping_19_09_16:25.pt')
I = I.to(device); B = B.to(device); V1 = V1.to(device); V2 = V2.to(device)
f = se3_Dynamics_2(I,B,V1,V2,target).to(device) # (I,B,V) = torch.load('IBV_fShaping.pt')

## Augmented Dynamics with integral loss
aug_f = AugmentedSE3(f, IntegralLoss(f),target).to(device) 
aug_f.l.scale = integral_loss_scale

t_span = torch.linspace(0, 3, 30).to(device)  


Training Loop: Optimal Potential Shaping

In [9]:
#f.parameters()
#model_parameters = filter(lambda p: p.requires_grad, f.parameters())
#params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
#print(params)
#model_parameters = filter(lambda p: p.requires_grad, f.parameters())

#[print(p) for p in model_parameters]

In [10]:
from HybridODE import NeuralODE_Hybrid
from pytorch_lightning.loggers import WandbLogger
import datetime

today = datetime. datetime. now()
date_time = today. strftime("%d/%m_%H:%M")

solver = 'dopri5'
atol, rtol, atol_adjoint, rtol_adjoint = 1e-3,1e-4,1e-3,1e-4
dt_min, dt_min_adjoint = 0, 0

model = NeuralODE_Hybrid(f, jspan, callbacks, jspan_adjoint, callbacks_adjoint, solver, atol, rtol, dt_min, atol_adjoint, rtol_adjoint, dt_min_adjoint, IntegralLoss(f), sensitivity = 'hybrid_adjoint_full').to(device) 
aug_model = NeuralODE_Hybrid(aug_f, jspan, callbacks, jspan_adjoint, callbacks_adjoint, solver, atol, rtol, dt_min, atol_adjoint, rtol_adjoint, dt_min_adjoint, sensitivity = 'hybrid_adjoint_full').to(device) 

learn = EnergyShapingLearner(model, t_span, prior, target, aug_model).to(device) 
learn.lr = 1e-3
learn.batch_size = 2048

logger = WandbLogger(project='potential-shaping-SE3', name='NN_'+date_time)

trainer = pl.Trainer(max_epochs=150, logger=logger, gpus = torch.cuda.device_count())#
trainer.fit(learn)
#torch.save((I,B,V),'IBV_fShaping_'+date_time+'.pt')

[34m[1mwandb[0m: Currently logged in as: [33mypwotte[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | NeuralODE_Hybrid | 4.6 K 
1 | aug_model | NeuralODE_Hybrid | 4.6 K 
-----------------------------------------------
4.6 K     Trainable params
0         Non-trainable params
4.6 K     Total params
0.018     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [11]:
date_time = today.strftime("%d_%m_%H:%M")
torch.save((I,B,V1,V2),'IBV_fShaping_'+date_time+'.pt')

#%debug


In [12]:
%debug

ERROR:root:No traceback has been produced, nothing to debug.
