CasADi is great at building expression graphs, but it is poor at simplifying them. Checking whether a system is group affine is one example where this falls short.

In [6]:
import casadi as ca
from casadi.tools.graph import graph
import matplotlib.image as mpimg
from io import BytesIO
import matplotlib.pyplot as plt
import os
from IPython.display import Image


def draw_graph(x, dpi=None):
    """Render the expression graph of ``x`` as an inline PNG image.

    Parameters
    ----------
    x :
        CasADi expression handed to ``graph.dotgraph``.
    dpi : optional
        When given, set as the ``dpi`` attribute of the dot graph before
        rendering (e.g. ``300`` for a sharper image). ``None`` keeps the
        renderer's default resolution.

    Returns
    -------
    IPython.display.Image
        The PNG bytes produced by rendering the graph with the ``dot`` program.
    """
    # Named `dot` rather than `g` so the module-level helper `g` is not shadowed.
    dot = graph.dotgraph(x)
    if dpi is not None:
        dot.set('dpi', dpi)
    png = dot.create('dot', 'png')
    return Image(png)

In [7]:
# Symbolic matrix type used by every helper below (via expr.sym, expr(n, m), expr.eye).
expr = ca.SX

def g(name):
    """Build a symbolic 5x5 matrix whose symbols are suffixed with ``name``.

    The matrix has the block layout::

        [[R, v, x],
         [0, 1, 0],
         [0, 0, 1]]

    with ``R`` a 3x3 symbol and ``v``, ``x`` 3x1 symbols (presumably an
    extended-pose group element — confirm against the intended model).

    Returns
    -------
    tuple
        ``(X, R, v, x)``: the assembled matrix followed by its three symbols.
    """
    rot = expr.sym('R' + name, 3, 3)
    vel = expr.sym('v' + name, 3, 1)
    pos = expr.sym('x' + name, 3, 1)
    # Start block-diagonal: R in the upper-left, ones on the last two diagonal entries.
    elem = ca.diagcat(rot, 1, 1)
    # Fill the top three rows of the last two columns.
    for col, vec in ((3, vel), (4, pos)):
        elem[:3, col] = vec
    return elem, rot, vel, pos

def skew(x):
    """Return the 3x3 skew-symmetric (cross-product) matrix of the 3-vector ``x``.

    The result is::

        [[    0, -x[2],  x[1]],
         [ x[2],     0, -x[0]],
         [-x[1],  x[0],     0]]

    The diagonal entries are never assigned.
    """
    S = expr(3, 3)
    # Each cyclic index triple (i, j, k) fills one mirrored off-diagonal pair.
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        S[i, j] = -x[k]
        S[j, i] = x[k]
    return S

def f(X, u):
    """Evaluate the dynamics ``dX = f(X, u)`` as a 5x5 symbolic matrix.

    Parameters
    ----------
    X :
        5x5 matrix read as blocks ``R = X[:3, :3]`` and ``v = X[:3, 3]``
        (the layout produced by the helper ``g``).
    u :
        Input vector. ``u[0:3]`` is mapped through ``skew`` (named ``Omega``,
        presumably an angular rate), ``u[3:6]`` is ``a``, and ``u[6]`` becomes
        the z-component of the additive vector ``grav``.

    Returns
    -------
    5x5 ``expr`` with::

        dX[:3, :3] = R @ Omega
        dX[:3, 3]  = R @ a + [0, 0, u[6]]
        dX[:3, 4]  = v

    The bottom two rows are never assigned.
    """
    R = X[:3, :3]
    v = X[:3, 3]
    Omega = skew(u[0:3])
    a = u[3:6]
    # Only the z-entry is set; named `grav` to avoid shadowing the helper `g`.
    grav = expr(3, 1)
    grav[2] = u[6]
    dX = expr(5, 5)
    dX[:3, :3] = ca.mtimes(R, Omega)
    dX[:3, 3] = ca.mtimes(R, a) + grav
    dX[:3, 4] = v
    return dX

In [8]:
# Two independent symbolic group elements and the 5x5 identity.
g_a, R_a, v_a, x_a = g('a')
g_b, R_b, v_b, x_b = g('b')
g_I = expr.eye(5)
# Input vector u = [omega; a; g]. The scalar symbol is bound to `grav`, not `g`,
# so the helper function `g` is not clobbered (re-running this cell would
# otherwise fail at g('a')).
omega = expr.sym('omega', 3)
a = expr.sym('a', 3)
grav = expr.sym('g')
u = ca.vertcat(omega, a, grav)
# Group-affine residual: it is zero exactly when
# f(X_a X_b) = f(X_a) X_b + X_a f(X_b) - X_a f(I) X_b.
res = (ca.mtimes(f(g_a, u), g_b)
       + ca.mtimes(g_a, f(g_b, u))
       - ca.mtimes([g_a, f(g_I, u), g_b])
       - f(ca.mtimes(g_a, g_b), u))

CasADi cannot determine symbolically that the two expression graphs are equivalent.

In [9]:
# Right-hand side of the group-affine identity:
#   f(X_a) X_b + X_a f(X_b) - X_a f(I) X_b
x = (ca.mtimes(f(g_a, u), g_b)
     + ca.mtimes(g_a, f(g_b, u))
     - ca.mtimes([g_a, f(g_I, u), g_b]))
# Left-hand side: f(X_a X_b)
y = f(ca.mtimes(g_a, g_b), u)
# Ask CasADi whether the simplified graphs coincide; 100 is the comparison
# depth argument of ca.is_equal.
ca.is_equal(ca.simplify(x), ca.simplify(y), 100)

False

We can, however, check numerically by evaluating both sides at random inputs.

In [10]:
# Wrap both sides of the identity as numeric functions of the same inputs.
f_X = ca.Function('x', [R_a, v_a, x_a, R_b, v_b, x_b, u], [x])
f_Y = ca.Function('y', [R_a, v_a, x_a, R_b, v_b, x_b, u], [y])

import numpy as np
from numpy.random import randn

# Fixed seed so the check (and any failure it reports) is reproducible.
np.random.seed(0)
# NOTE: R_av / R_bv are generic random matrices, not rotations; this tests the
# identity at generic inputs only.
R_av = randn(3, 3)
v_av = randn(3, 1)
x_av = randn(3, 1)
R_bv = randn(3, 3)
v_bv = randn(3, 1)
x_bv = randn(3, 1)
uv = randn(7, 1)

args = (R_av, v_av, x_av, R_bv, v_bv, x_bv, uv)
tol = 1e-5  # absolute tolerance on the Frobenius norm of the difference
residual = ca.norm_fro(f_X(*args) - f_Y(*args))
group_affine = bool(residual < tol)
group_affine

True