In [566]:
import numpy as np
from numpy import zeros, eye, array, exp, zeros_like, vstack, outer
from numpy.linalg import norm, solve, inv
from autograd import jacobian
import autograd.numpy as anp
from autograd.numpy.linalg import solve as asolve
from math import log10, floor

from scipy.stats import multivariate_normal as MVN

from Manifolds.GeneralizedEllipse import GeneralizedEllipse

- Density of a multivariate normal centered at $\mu=(0, 0)$ with covariance matrix $\Sigma$ is
$$
p(x) = (2\pi)^{-1}\det(\Sigma)^{-1/2}\exp\left(-\frac{1}{2}x^\top\Sigma^{-1} x\right)
$$
- Function is the log density $f:\mathbb{R}^n\to\mathbb{R}$
$$
f(x) = \log p(x) = -\log(2\pi) - \frac{1}{2}\log\det(\Sigma) - \frac{1}{2}x^\top\Sigma^{-1}x
$$
- Gradient 
$$
\nabla f(x) = -\Sigma^{-1}x
$$
- Jacobian function $J_f:\mathbb{R}^n\to\mathbb{R}^{1\times n}$
$$
J_f(x) = - x^\top \Sigma^{-1}
$$
- Nabla squared and Hessian matrix
$$
H_f(x) = \nabla^2 f(x) = -\Sigma^{-1}
$$

In [619]:
# Target distribution: zero-mean bivariate Gaussian with diagonal covariance Σ.
μ = zeros(2)
Σ = array([[3.0, 0.0], [0.0, 1.0]])
Σinv = inv(Σ)  # cached inverse of Σ
π = MVN(μ, Σ)
f = π.logpdf  # log-density of the target
# exp(z0) is handed to GeneralizedEllipse; presumably the density level that
# defines the contour — TODO confirm against Manifolds.GeneralizedEllipse.
z0 = -4
ellipse = GeneralizedEllipse(μ, Σ, exp(z0))
grad_f = lambda x: - solve(Σ, x - μ)  # gradient of the log-density: -Σ⁻¹(x - μ)
hess_f = lambda x: - inv(Σ)  # Hessian of the log-density: -Σ⁻¹, independent of x
q = MVN(zeros(2), eye(2))  # standard bivariate normal

In [620]:
# Starting position comes from `ellipse`, starting velocity from the standard normal q.
# NOTE(review): no random seed is set in the visible cells before these draws,
# so the numeric outputs recorded further down will not reproduce on a fresh run.
x0 = ellipse.sample()
v0 = q.rvs()

In [621]:
def thug_dynamic(x0, v0, T, B, α, grad_log_pi):
    """Run one tilted bounce trajectory and record every intermediate state.

    The initial velocity is tilted along the normalized gradient at ``x0``,
    then ``B`` steps of size ``δ = T / B`` are taken. Each step moves half a
    step, reflects the velocity about the normalized gradient evaluated at the
    midpoint, and moves another half step. After the last step the tilt is
    undone using the gradient at the final position.

    Parameters
    ----------
    x0 : array, initial position.
    v0 : array, initial velocity (same shape as ``x0``).
    T : float, total integration time.
    B : int, number of bounce steps; must be at least 1.
    α : float, tilt strength; ``α == 1`` makes the final un-tilt divide by zero.
    grad_log_pi : callable mapping a position to the gradient at that position.

    Returns
    -------
    positions : array with ``2*B + 1`` rows
        ``x0``, then for each step the midpoint followed by the end-point.
    velocities : array with ``B + 3`` rows
        ``v0``, the tilted velocity ``w0``, the reflected velocities
        ``w1 .. wB``, and finally the un-tilted output velocity.
    """
    g = grad_log_pi(x0)              # Gradient at the starting point
    g = g / norm(g)                  # Normalize
    w = v0 - α * g * (g @ v0)        # Tilt velocity
    x = x0

    # Collect states in lists and stack once at the end; calling vstack inside
    # the loop re-copies the whole history on every iteration.
    positions = [x0]
    velocities = [v0, w]

    δ = T / B                        # Step size
    for _ in range(B):
        x = x + δ*w/2                # Move to midpoint
        positions.append(x)
        g = grad_log_pi(x)           # Gradient at midpoint
        ghat = g / norm(g)           # Normalize
        w = w - 2*(w @ ghat) * ghat  # Reflect velocity using midpoint gradient
        velocities.append(w)
        x = x + δ*w/2                # Move from midpoint to end-point
        positions.append(x)

    # Undo the initial tilt using the gradient at the final position
    g = grad_log_pi(x)
    g = g / norm(g)
    v = w + (α / (1 - α)) * g * (g @ w)
    velocities.append(v)
    return vstack(positions), vstack(velocities)

In [622]:
T = 10.0   # total integration time
B = 10     # number of bounce steps
# Step size. With these values δ = 1.0, so δ**2 = 1.0 as well: the O(δ²)
# remainders checked later in the notebook are not small for this choice.
δ = T / B

In [623]:
α = 0.5
γ = α*(2-α) / ((1-α)**2)

In [624]:
positions, velocities = thug_dynamic(x0, v0, T, B, α, grad_f)

In [625]:
assert len(positions)  == 2*B + 1
assert len(velocities) == 3 + B

In [626]:
def N(x):
    """Return g gᵀ / (gᵀ g) for g = grad_f(x).

    This is the rank-one projection matrix onto the gradient direction at x.
    """
    g = grad_f(x)  # evaluate the gradient once instead of four times
    return outer(g, g) / (g @ g)

In [627]:
def interwave_arrays(a, b):
    """Interleave two sequences into one object array.

    Elements of ``a`` land on the even indices and elements of ``b`` on the
    odd indices of the result, whose length is ``len(a) + len(b)``.
    """
    merged = np.empty(len(a) + len(b), dtype=object)
    merged[::2], merged[1::2] = a, b
    return merged

In [628]:
xindeces = interwave_arrays(['{}'.format(i) for i in range(B+1)], ['{}o2'.format(i) for i in range(1, 2*B, 2)])

In [629]:
def construct_velocities_dictionary(velocities, B):
    """Label the rows of ``velocities`` by name.

    Rows 1 .. B+1 are stored under 'w0' .. 'wB'; the first row is stored
    under 'v0' and the last row under 'vB'.
    """
    vdict = {}
    for step in range(B + 1):
        vdict['w{}'.format(step)] = velocities[step + 1]
    vdict.update(v0=velocities[0], vB=velocities[-1])
    return vdict

In [630]:
def construct_positions_dictionary(positions, xindeces):
    """Map 'x<label>' to each position, taking the label from ``xindeces``.

    The i-th entry of ``positions`` is stored under the key built from the
    i-th entry of ``xindeces``.
    """
    xdict = {}
    for row, point in enumerate(positions):
        xdict['x{}'.format(xindeces[row])] = point
    return xdict

In [631]:
xdict = construct_positions_dictionary(positions, xindeces)

In [632]:
vdict = construct_velocities_dictionary(velocities, B)

In [633]:
def get_x(b):
    """Look up a recorded position by step label.

    Reads the module-level ``positions`` array and ``xindeces`` label array.

    Parameters
    ----------
    b : int or str
        An integer selects the end-point of that step (every second row of
        ``positions``). A string must be one of the labels in ``xindeces``,
        e.g. '3' or '5o2'.

    Raises
    ------
    ValueError
        If ``b`` is neither a string nor an integer, or is not a known label.
    """
    if isinstance(b, str):
        if b in xindeces:
            return positions[np.where(xindeces == b)[0][0]]
        raise ValueError("b is string but not in xindeces.")
    # bool is a subclass of int; it is rejected here as it was before.
    if isinstance(b, (int, np.integer)) and not isinstance(b, bool):
        whole_steps = [int(ind) for ind in xindeces if ind.isdigit()]
        if b in whole_steps:
            return positions[::2][b]
        raise ValueError("b is integer but not within range.")
    # Previously guarded by an `assert`, which left this branch unreachable
    # and disappears entirely when Python runs with -O.
    raise ValueError('b must be either string or integer')
    
    
def get_w(b):
    """Return the velocity w_b from the module-level ``velocities`` array.

    ``b`` is an int or a digit-only string with 0 <= b <= B. The first and
    last rows of ``velocities`` are skipped before indexing.
    """
    assert type(b) == int or (type(b) == str and b.isdigit())
    step = int(b)
    assert 0 <= step <= B
    return velocities[1:-1][step]

In [634]:
# Named handles for the first few positions (mid-points 'ko2' and end-points k).
# The hard-coded indices require B >= 5.
x1o2 = get_x('1o2')
x1   = get_x(1)
x3o2 = get_x('3o2')
x2   = get_x(2)
x5o2 = get_x('5o2')
x3   = get_x(3)
x7o2 = get_x('7o2')
x4   = get_x(4)
x9o2 = get_x('9o2')
x5   = get_x(5)

# Named handles for the first few tilted/reflected velocities.
w0 = get_w(0)
w1 = get_w(1)
w2 = get_w(2)
w3 = get_w(3)
w4 = get_w(4)
w5 = get_w(5)
# NOTE: despite its name, v5 is the last row of `velocities`, i.e. the final
# un-tilted velocity after B steps, not a quantity at step 5.
v5 = velocities[-1]
vB = v5

# Velocities at the last two steps
wBm1 = get_w(B-1)
wB   = get_w(B)

# Positions at the last two steps
xB   = get_x(B)
xBm1 = get_x(B-1)

# Change in squared velocity norm between the start and end of the trajectory
Δ = vB@vB - v0@v0

In [635]:
N0    = N(x0)
N_1o2 = N(get_x('1o2'))
N_9o2 = N(get_x('9o2'))
N1    = N(x1)
N5    = N(x5)
NB    = N(xB)

In [636]:
# Correct, gives v5@v5
# w5 @ (I + ((α*(2-α)) / ((1-α)**2)) * N(x5)) @ w5

In [637]:
# Correct, gives v5@v5 - v0@v0
#-α*(2-α)*(v0@N0@v0) + (α*(2-α)/((1-α)**2))*w4@(I - 2*N_9o2)@N5@(I - 2*N_9o2)@w4

In [638]:
def N_autograd(x):
    """Autograd-traceable g gᵀ / (gᵀ g) with g = Σ⁻¹(x - μ).

    Written with autograd's solve/outer so that `jacobian` can differentiate it.
    """
    direction = asolve(Σ, x - μ)
    squared_norm = direction @ direction
    return anp.outer(direction, direction) / squared_norm

In [639]:
# NOTE: `I` (bottom of this cell) and `sum_to_transpose` (defined in a later
# cell) are referenced by the lambdas below before they are bound. This works
# only because lambdas resolve names when called, so none of these may be
# called until both names exist.
Tfunc = lambda x: I - N(x)  # complement of N(x)
J = lambda x: grad_f(x).reshape(1, -1)  # gradient as a 1-row matrix
Jplus = lambda x: J(x).T @ inv(J(x)@J(x).T)  # right pseudo-inverse Jᵀ(J Jᵀ)⁻¹
DNa = jacobian(N_autograd)  # derivative of N obtained by automatic differentiation
# Hand-built derivative of N; the reshape hard-codes dimension 2.
DN = lambda x: sum_to_transpose(np.outer(Jplus(x), hess_f(x)).reshape(2, 2, 2) @ Tfunc(x))
# Closed-form directional derivative of N at x in direction w, written with Σinv.
DNexplicit = lambda x, w: ((Σinv@(outer(w, x) + outer(x, w))@Σinv)/(x@Σinv@Σinv@x)) - 2*((w@Σinv@Σinv@x)/((x@Σinv@Σinv@x)**2))*Σinv@outer(x, x)@Σinv
I = eye(2)

In [640]:
sum_to_transpose = lambda A: A + A.T

In [641]:
def Γ(b):
    """Exact Γ_b built from N at end-point b and at mid-point b-1/2.

    Returns N_b - 2 N_b N_{b-1/2} - 2 N_{b-1/2} N_b + 4 N_{b-1/2} N_b N_{b-1/2}.
    """
    N_end = N(get_x(b))
    N_mid = N(get_x('{}o2'.format(2*b - 1)))
    return N_end - 2*N_end@N_mid - 2*N_mid@N_end + 4*N_mid@N_end@N_mid


def Γ̂(b):
    """First-order approximation of Γ_b expanded around x_{b-1}.

    Returns N_{b-1} + δ · DNa(x_{b-1}) @ N_{b-1/2} @ w_{b-1}, using the
    module-level step size δ.
    """
    x_prev = get_x(b - 1)
    N_mid = N(get_x('{}o2'.format(2*b - 1)))
    first_order = δ*DNa(x_prev)@N_mid@get_w(b - 1)
    return N(x_prev) + first_order

In [642]:
def Λ(b):
    """Evaluate w_b @ DNa(x_b) @ N_{b+1/2} @ w_b @ w_b for step b."""
    w_b = get_w(b)
    jac_at_b = DNa(get_x(b))
    N_next_mid = N(get_x('{}o2'.format(2*b + 1)))
    return w_b@jac_at_b@N_next_mid@w_b@w_b

In [643]:
# There are three elements Jplus(x0), hess_f(x0), and T(x0)

In [644]:
#v @ DNa(x0)@ v

In [645]:
#np.kron(Jplus(x0), hess_f(x0).reshape(2,1,2)) @ T(x0)

In [281]:
#np.outer(Jplus(x), hess_f(x)).shape

$$
N_{b-1/2} = N_{b-1} + \frac{\delta}{2} DN_{b-1}(\cdot, \cdot) w_{b-1} + \mathcal{O}(\delta^2)
$$

In [282]:
#norm(N_1o2 - (N0 + (δ/2)*DNa(x0) @ w0))

$$
N_{b-1/2}N_b N_{b-1/2} = N_{b-1} + \frac{\delta}{2} DN_{b-1}(\cdot, \cdot) w_{b-1} + \mathcal{O}(\delta^2)
$$

In [283]:
#norm(N_1o2@N1@N_1o2 - (N0 + (δ/2)*DNa(x0)@w0))

In [284]:
#norm(Γ(B) - Γ̂(B))

In [285]:
######### Compute delta

In [286]:
#Δ = v5@v5 - v0@v0

In [287]:
#Δcorrect = -α*(2-α)*v0@N0@v0 + ((α*(2-α))/((1-α)**2))*wBm1@Γ(B)@wBm1

In [288]:
#Δ - Δcorrect

In [289]:
#Δexpand = -α*(2-α)*v0@N0@v0 + ((α*(2-α))/((1-α)**2))*wBm1@Γ̂(B)@wBm1

In [290]:
#Δ - Δexpand

In [291]:
# check if 
#wBm1@DNa(xBm1)@wBm1

In [292]:
#wBm1 @ (DNa(xBm1)@NBm1o2@wBm1) @ wBm1

In [293]:
#wB@DNa(xB)@wB

In [294]:
#δ*α*(2-α)*np.sum([Λ(b) for b in range(B)])

In [295]:
#wB@wB + ((α*(2-α))/((1-α)**2))*(w0@N0@w0 + δ*np.sum([Λ(b) for b in range(B)]))

In [296]:
#min([norm(Γ(b) - Γ̂(b)) for b in range(1, B+1)]) < δ**2

At some point we get to
$$
\begin{align}
    \Delta 
    &= -\alpha(2-\alpha)v_0^\top N_0 v_0 + \frac{\alpha(2-\alpha)}{(1-\alpha)^2}w_{B}^\top N_B w_{B} \\
    &= -\alpha(2-\alpha)v_0^\top N_0 v_0 + \frac{\alpha(2-\alpha)}{(1-\alpha)^2}w_{B-1}^\top\Gamma_B w_{B-1}\\
    &= -\alpha(2-\alpha)v_0^\top N_0 v_0 + \gamma w_{B-1}^\top\Gamma_B w_{B-1}
\end{align}
$$
where we have defined $\gamma = \frac{\alpha(2-\alpha)}{(1-\alpha)^2}$ for simplicity.

In [297]:
#γ = α*(2-α) / ((1-α)**2)

In [298]:
# Indeed we can check this
#Δ - (-α*(2-α)*v0@N0@v0 + γ*wBm1@Γ(B)@wBm1)

At this point we would like to check our various approximations. For instance, it is still true that one can use the approximation to $\Gamma$

In [299]:
#abs(Δ - (-α*(2-α)*v0@N0@v0 + γ*wBm1@Γ̂(B)@wBm1))

This means that the approximation works. Our approximation is 
$$
\hat{\Gamma}_B = N_{B-1} + \delta DN_{B-1}(\cdot, \cdot)N_{B-1/2}w_{B-1} + \mathcal{O}(\delta^2)
$$

This fundamentally means that our approximate $\Delta$ is
$$
\hat{\Delta} = -\alpha(2-\alpha)v_0^\top N_0 v_0 + \gamma w_{B-1}^\top N_{B-1} w_{B-1} + \gamma \delta w_{B-1}^\top DN_{B-1}(\cdot, \cdot)N_{B-1/2}w_{B-1} w_{B-1} + \mathcal{O}(\delta^2)
$$

Now we can apply the same to $w_{B-1}^\top N_{B-1} w_{B-1}$ giving
$$
\begin{align}
    w_{B-1}^\top N_{B-1} w_{B-1}
    &= w_{B-2}^\top \Gamma_{B-1} w_{B-2} \\
    &= w_{B-2}^\top N_{B-2}w_{B-2} + \delta w_{B-2}^\top DN_{B-2}(\cdot, \cdot) N_{B-3/2} w_{B-2}w_{B-2} + \mathcal{O}(\delta^2)
\end{align}
$$

If one denotes
$$
\Lambda_b = w_b^\top DN_b(\cdot, \cdot) N_{b+1/2} w_{b} w_b
$$
then basically this tells us
$$
\begin{align}
    \hat{\Delta}
    &= -\alpha(2-\alpha)\|v_0^\perp\|^2 + \gamma w_0^\top N_0 w_0 + \gamma\delta\sum_{b=0}^{B-1}\Lambda_b + \mathcal{O}(\delta^2)
\end{align}
$$

Can we check that
$$
\gamma w_B^\top N_B w_B = \gamma w_0^\top N_0 w_0 + \delta\gamma \sum_{b=0}^{B-1} \Lambda_{b}
$$

In [300]:
#wB = get_w(B)

In [301]:
#NB = N(get_x(B))

In [302]:
#γ*wB@NB@wB

In [303]:
#γ*w0@N0@w0 + δ*γ*sum([Λ(b) for b in range(B)])

Therefore one should also have

In [304]:
#vB@vB - (w0@w0 + γ*wB@NB@wB)

In [305]:
#abs(vB@vB - (w0@w0 + γ*w0@N0@w0 + δ*γ*sum([Λ(b) for b in range(B)])))

And so also $\Delta$ should be close enough

In [306]:
#vB@vB - v0@v0

In [307]:
#-α*(2-α)*v0@N0@v0 + δ*γ*sum([Λ(b) for b in range(B)])

In [308]:
#Δ

In [309]:
#w0@w0- (v0@v0 - α*(2-α)*v0@N0@v0)

In [495]:
abs(Δ - δ*γ*sum([Λ(b) for b in range(B)]))

0.0013539139848017523

Test that
$$
N_{b+3/2}w_{b+1} = -N_{b+1/2}w_b + \mathcal{O}(\delta)
$$

In [311]:
#norm(N(get_x('3o2'))@get_w(1) + N(get_x('1o2'))@get_w(0))

Check whether
$$
w_{b+1}^\top J_{b+1}^+ = -w_b^\top J_b^+ + \mathcal{O}(\delta)
$$

In [312]:
#abs((get_w(1)@Jplus(get_x(1)) + get_w(0)@Jplus(get_x(0)))[0])

Check that
$$
DN_{b+1}(\cdot, \cdot) = DN_b(\cdot, \cdot) + \mathcal{O}(\delta)
$$

In [313]:
#norm(DNa(get_x(1)) - DNa(get_x(0)))

In [314]:
#Λ(1) - Λ(0)

At this point, it might be possible to approximate the whole sum (up to order $\delta$) using $B\Lambda(0)$

$$
\sum_{b=0}^{B-1} \Lambda_b = B\Lambda_0
$$

In [315]:
#sum([Λ(b) for b in range(B)]) - B*Λ(0)

And therefore we can then approximate
$$
\Delta = \gamma T \Lambda_0 + \mathcal{O}(\delta^2)
$$

In [496]:
abs(Δ - γ*(B*δ)*Λ(0))

0.12255434385028963

In [370]:
w0@DNa(x0)@wB

array([-0.05436017,  0.43205513])

In [377]:
N(get_x(1))@DNa(get_x(1))@get_w(1)@N(get_x(1))

array([[ 0.05930491,  0.23332837],
       [-0.01507349, -0.05930491]])

In [380]:
DNexplicit(get_x(1), get_w(1)@N(get_x(1)))@N(get_x(1))

array([[-0.01056784, -0.04157795],
       [ 0.00268602,  0.01056784]])

In [383]:
N(get_x(1))@DNexplicit(get_x(1), get_w(1))@N(get_x(1))

array([[2.19492018e-18, 1.04367431e-17],
       [9.82086275e-18, 4.42679497e-17]])

In [373]:
DNexplicit(x0, wB) @ w0

array([-0.05436017,  0.43205513])

In [356]:
# `A` is not defined anywhere in the visible notebook (leftover kernel state).
# The expression has exactly the form of DNexplicit(x0, wB), which is written
# with Σinv, so A is bound to Σinv here — confirm this matches the original intent.
A = Σinv
((A@(outer(wB, x0) + outer(x0, wB))@A)/(x0@A@A@x0)) - 2*((wB@A@A@x0)/((x0@A@A@x0)**2))*A@outer(x0, x0)@A

array([[ 0.13164051,  0.27961329],
       [ 0.27961329, -0.13164051]])

array([1.12356085, 1.67476483])

In [345]:
(np.outer(Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ Tfunc(x0)) + ((np.outer(Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ Tfunc(x0))).T

array([[[ 0.08074039,  0.17149802],
        [-0.05416686, -0.11505404]],

       [[ 0.17149802, -0.08074039],
        [-0.11505404,  0.05416686]]])

In [None]:
(np.outer(Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ Tfunc(x0))

In [329]:
N(x0)@np.outer(Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ Tfunc(x0) + N(x0)@(np.outer(Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ Tfunc(x0)).T

array([[[-0.00769079, -0.01633576],
        [-0.03439138, -0.0730496 ]],

       [[-0.01633576,  0.00769079],
        [-0.0730496 ,  0.03439138]]])

In [337]:
np.outer(Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ (N(x0)@ Tfunc(x0))

array([[[ 2.22220723e-19, -5.75374674e-19],
        [-2.82912313e-18, -6.08295129e-18]],

       [[ 9.93717744e-19, -2.57293747e-18],
        [-1.26511597e-17, -2.72014984e-17]]])

In [334]:
(np.outer(N(x0)@Jplus(x0), hess_f(x0)).reshape(2, 2, 2) @ Tfunc(x0)).T

array([[[ 0.0403702 ,  0.18052583],
        [-0.02708343, -0.12111059]],

       [[-0.00902781, -0.0403702 ],
        [ 0.00605655,  0.02708343]]])

In [386]:
N(get_x('1o2'))@N(get_x(1))@N(get_x('1o2'))

array([[0.05483872, 0.22764634],
       [0.22764634, 0.94500483]])

In [389]:
N(get_x(0)) + (δ/2)*N(get_x(0))@DNexplicit(get_x(0), get_w(0))

array([[0.05105042, 0.21220942],
       [0.22828521, 0.94894958]])

In [391]:
N(get_x(0)) @ DNexplicit(get_x(0), get_w(0)) @ N(get_x(0))

array([[ 1.08385034e-18,  8.59775728e-19],
       [-2.60216149e-19, -1.35934175e-17]])

In [395]:
N(get_x(0))@DNexplicit(get_x(0), N(get_x('1o2'))@get_w(0))@N(get_x(0))

array([[1.41702195e-18, 5.87227221e-18],
       [5.87205805e-18, 2.49577964e-17]])

In [397]:
DNexplicit(get_x(0), get_w(0))@N(get_x(0))

array([[ 0.06847487,  0.30620317],
       [-0.01531273, -0.06847487]])

In [403]:
N(get_x(0))@DNexplicit(get_x(0), get_w(0))

array([[ 0.06847487, -0.01531273],
       [ 0.30620317, -0.06847487]])

In [407]:
DNexplicit(get_x(0), get_w(0))@N(get_x(0)) + N(get_x(0))@DNexplicit(get_x(0), get_w(0))

array([[ 0.13694973,  0.29089044],
       [ 0.29089044, -0.13694973]])

In [406]:
DNexplicit(get_x(0), get_w(0))

array([[ 0.13694973,  0.29089044],
       [ 0.29089044, -0.13694973]])

# Checking everything

In [646]:
def check_if_at_least_same_order(number, reference=None) -> bool:
    """Tell whether |number| is of the same order of magnitude as |reference|, or smaller.

    Orders of magnitude are compared through floor(log10(|·|)).

    Parameters
    ----------
    number : float
        Quantity to test, typically an approximation error.
    reference : float, optional
        Scale to compare against. Defaults to the current value of the
        module-level ``δ**2``, looked up when the function is called rather
        than frozen when the function is defined.

    Returns
    -------
    bool
        True when the decimal exponent of ``number`` does not exceed that of
        ``reference``. A ``number`` of exactly zero always passes.

    Raises
    ------
    ValueError
        If ``reference`` is zero (log10 is undefined there).

    Notes
    -----
    The earlier version compared ``abs(floor(...))`` of both exponents. That
    let any large number pass (its exponent has a large absolute value too)
    and made the check always True whenever the reference exponent was 0,
    e.g. for δ = 1.
    """
    if reference is None:
        reference = δ**2
    if number == 0:
        # log10(0) is undefined; a zero error is smaller than any reference.
        return True
    return floor(log10(abs(number))) <= floor(log10(abs(reference)))

In [647]:
def Λ2(b):
    """Evaluate -w_b @ DNexplicit(x_b, (I - 3 N_{b+1/2}) w_b) @ w_b for step b.

    The leading minus sign is part of the returned value.
    """
    w_b = get_w(b)
    direction = (I - 3*N(get_x('{}o2'.format(2*b + 1))))@w_b
    return -w_b@DNexplicit(get_x(b), direction)@w_b

In [648]:
b     = 3
bm1o2 = '{}o2'.format(2*b - 1)

$$
    N_b = N_{b-1} + \delta DN_{b-1}(I - N_{b-1/2})w_{b-1} + \mathcal{O}(\delta^2)
$$

In [649]:
check_if_at_least_same_order(
    norm(
        N(get_x(b)) - (N(get_x(b-1)) + δ*DNexplicit(get_x(b-1), (I - N(get_x(bm1o2)))@get_w(b-1)))
    )
)

True

$$
N_{b-1/2} = N_{b-1} + \frac{\delta}{2} DN_{b-1} w_{b-1} + \mathcal{O}(\delta^2)
$$

In [650]:
# The mid-point is x_{b-1} + (δ/2) w_{b-1}, so the first-order term carries δ/2
# (the previous code used a full δ, contradicting the formula being checked).
check_if_at_least_same_order(
    norm(
        N(get_x(bm1o2)) - (N(get_x(b-1)) + (δ/2)*DNexplicit(get_x(b-1), get_w(b-1)))
    )
)

True

$$
N_b N_{b-1/2} = N_{b-1} + \frac{\delta}{2}N_{b-1}DN_{b-1} w_{b-1} + \delta DN_{b-1} T_{b-1/2} w_{b-1} N_{b-1} + \mathcal{O}(\delta^2)
$$

In [651]:
# Third term is δ·DN_{b-1}[T_{b-1/2} w_{b-1}] @ N_{b-1}; the trailing N_{b-1}
# factor from the formula being checked was missing from the code.
check_if_at_least_same_order(
    norm(
        N(get_x(b))@N(get_x(bm1o2)) - (N(get_x(b-1)) + (δ/2)*N(get_x(b-1))@DNexplicit(get_x(b-1), get_w(b-1)) + δ*DNexplicit(get_x(b-1), Tfunc(get_x(bm1o2))@get_w(b-1))@N(get_x(b-1)))
    )
)

True

$$
N_{b-1/2} N_b = N_{b-1} + \delta N_{b-1} DN_{b-1} T_{b-1/2} w_{b-1} + \frac{\delta}{2} DN_{b-1} w_{b-1} N_{b-1} + \mathcal{O}(\delta^2)
$$

In [652]:
check_if_at_least_same_order(
    norm(
        N(get_x(bm1o2))@N(get_x(b)) - (N(get_x(b-1)) + δ*N(get_x(b-1))@DNexplicit(get_x(b-1), Tfunc(get_x(bm1o2))@get_w(b-1)) + (δ/2)*DNexplicit(get_x(b-1), get_w(b-1))@N(get_x(b-1)))
    )
)

True

$$
N_{b-1/2}N_b N_{b-1/2} = N_{b-1} + \delta DN_{b-1} w_{b-1} + \mathcal{O}(\delta^2)
$$

In [653]:
check_if_at_least_same_order(
    norm(
        N(get_x(bm1o2))@N(get_x(b))@N(get_x(bm1o2)) - (N(get_x(b-1)) + δ*DNexplicit(get_x(b-1), get_w(b-1)))
    )
)

True

$$
\Gamma_b = N_{b-1} - \delta DN_{b-1} (I - 3N_{b-1/2})w_{b-1} + \mathcal{O}(\delta^2)
$$

In [654]:
check_if_at_least_same_order(
    norm(
        Γ(b) - (N(get_x(b-1)) - δ*DNexplicit(get_x(b-1), (I - 3*N(get_x(bm1o2)))@get_w(b-1)))
    )
)

True

$$
\Delta = -\alpha(2-\alpha) v_0^\top N_0 v_0 + \gamma w_{B-1}^\top \left(N_{B-1} - \delta DN_{B-1}(I-3N_{B-1/2})w_{B-1} \right) w_{B-1} + \mathcal{O}(\delta^2)
$$

In [655]:
check_if_at_least_same_order(
    norm(
        Δ - (-α*(2-α)*v0@N0@v0 + γ*get_w(B-1)@N(get_x(B-1))@get_w(B-1) + γ*δ*Λ2(B-1))
    )
)

True

$$
\begin{align}
\Delta = &-\alpha(2-\alpha)\|v_0^\perp\|^2 \\
         &+ \gamma w_{B-2}^\top N_{B-2} w_{B-2} \\
         &- \gamma\delta w_{B-2}^\top DN_{B-2}(I - 3N_{B-3/2})w_{B-2}w_{B-2}\\
         &- \gamma\delta w_{B-1}^\top DN_{B-1}(I - 3N_{B-1/2})w_{B-1}w_{B-1} + \mathcal{O}(\delta^2)
\end{align}
$$

In [656]:
check_if_at_least_same_order(
    norm(
        Δ - (-α*(2-α)*v0@N0@v0 + γ*get_w(B-2)@N(get_x(B-2))@get_w(B-2) + γ*δ*Λ2(B-2) + γ*δ*Λ2(B-1))
    )
)

True

$$
\begin{align}
\Delta = &-\alpha(2-\alpha)\|v_0^\perp\|^2 \\
         &+ \gamma w_{B-3}^\top N_{B-3} w_{B-3} \\
         &- \gamma\delta w_{B-3}^\top DN_{B-3}(I - 3N_{B-5/2})w_{B-3}w_{B-3} \\
         &- \gamma\delta w_{B-2}^\top DN_{B-2}(I - 3N_{B-3/2})w_{B-2}w_{B-2}\\
         &- \gamma\delta w_{B-1}^\top DN_{B-1}(I - 3N_{B-1/2})w_{B-1}w_{B-1} + \mathcal{O}(\delta^2)
\end{align}
$$

In [657]:
check_if_at_least_same_order(
    norm(
        Δ - (-α*(2-α)*v0@N0@v0 + γ*get_w(B-3)@N(get_x(B-3))@get_w(B-3) + γ*δ*Λ2(B-3) + γ*δ*Λ2(B-2) + γ*δ*Λ2(B-1))
    )
)

True

$$
\begin{align}
\Delta = &-\alpha(2-\alpha)\|v_0^\perp\|^2 \\
         &+ \gamma w_{B-4}^\top N_{B-4} w_{B-4} \\
         &- \gamma\delta w_{B-4}^\top DN_{B-4}(I - 3N_{B-7/2})w_{B-4}w_{B-4} \\ 
         &- \gamma\delta w_{B-3}^\top DN_{B-3}(I - 3N_{B-5/2})w_{B-3}w_{B-3} \\
         &- \gamma\delta w_{B-2}^\top DN_{B-2}(I - 3N_{B-3/2})w_{B-2}w_{B-2}\\
         &- \gamma\delta w_{B-1}^\top DN_{B-1}(I - 3N_{B-1/2})w_{B-1}w_{B-1} + \mathcal{O}(\delta^2)
\end{align}
$$

In [658]:
check_if_at_least_same_order(
    norm(
        Δ - (-α*(2-α)*v0@N0@v0 + γ*get_w(B-4)@N(get_x(B-4))@get_w(B-4) + γ*δ*Λ2(B-4) + γ*δ*Λ2(B-3) + γ*δ*Λ2(B-2) + γ*δ*Λ2(B-1))
    )
)

True

$$
\Delta = -\gamma\delta \sum_{b=0}^{B-1}w_b^\top DN_{b}(I - 3N_{b+1/2})w_b w_b + \mathcal{O}(\delta^2)
$$

In [683]:
# Λ2(b) already includes the leading minus sign of -w_bᵀ DN_b (I - 3N_{b+1/2}) w_b w_b,
# so the sum enters with +γδ (the previous code negated it a second time).
check_if_at_least_same_order(
    norm(
        Δ - γ*δ*sum([Λ2(b) for b in range(B)])
    )
)

True

In [684]:
check_if_at_least_same_order(Δ)

True

$$
\Delta = 4\gamma\delta \sum_{b=0}^{B/2 - 1}w_{2b}^\top DN_{2b} N_{2b+1/2}w_{2b}w_{2b} + \mathcal{O}(\delta^2)
$$

In [687]:
# Sum over even steps only; skipped when B is odd.
# NOTE(review): the markdown formula above this cell has prefactor +4γδ, while
# the expression below uses -4γδ. One of the two signs is wrong — confirm.
if B % 2 == 0:
    print(check_if_at_least_same_order(
        norm(
            Δ - (-4*γ*δ*sum([get_w(b)@DNexplicit(get_x(b), N(get_x('{}o2'.format(2*b+1)))@get_w(b))@get_w(b) for b in range(0, B, 2)]))
        )
    ))

True


In [702]:
DNexplicit(get_x(0), N(get_x(0))@v0)

array([[ 0.03220269,  0.03554794],
       [ 0.03554794, -0.03220269]])

In [705]:
DNa(get_x(0))@N(get_x(0))@v0

array([[ 0.03220269,  0.03554794],
       [ 0.03554794, -0.03220269]])

In [701]:
DNexplicit(get_x(0), get_w(0))@get_w(0)

array([-0.0661564, -0.095169 ])

$$
w_{b-1}^\top\Gamma_b w_{b-1} = w_{b-1}^\top N_{b-1} w_{b-1} - \delta w_{b-1}^\top DN_{b-1}(I-3 N_{b-1/2})w_{b-1} w_{b-1} + \mathcal{O}(\delta^2)
$$

In [538]:
get_w(b-1)@DNexplicit(get_x(b-1), (I-3*N(get_x(bm1o2)))@get_w(b-1))@get_w(b-1)

-0.006126008799801175

In [539]:
check_if_at_least_same_order(
    norm(
        get_w(b-1)@Γ(b)@get_w(b-1) - (get_w(b-1)@N(get_x(b-1))@get_w(b-1) - δ*get_w(b-1)@DNexplicit(get_x(b-1), (I-3*N(get_x(bm1o2)))@get_w(b-1))@get_w(b-1))
    )
)

True

$$
\Delta = \gamma\delta\sum_{b=0}^{B-1}\Lambda_b + \mathcal{O}(\delta^2)
$$
where the new terms $\Lambda_b$ are given by
$$
\Lambda_b = - w_b^\top DN_b (I-3N_{b+1/2})w_b w_b
$$

In [517]:
Δ

0.007372304101414429

In [543]:
γ*δ*sum([-get_w(b)@DNexplicit(get_x(b), (I-3*N(get_x('{}o2'.format(2*b+1))))@get_w(b))@get_w(b) for b in range(B)])

0.01749910578667169

In [521]:
γ*δ*sum([Λ2(b) for b in range(B)])

0.01749910578667169

In [525]:
γ*get_w(B)@N(get_x(B))@get_w(B)

0.3193231808074642

In [532]:
γ*get_w(0)@N(get_x(0))@get_w(0) + γ*δ*sum([Λ2(b) for b in range(B)])

0.3294499824927218

In [530]:
γ*get_w(B)@N(get_x(B))@get_w(B) - γ*get_w(0)@N(get_x(0))@get_w(0)

0.0073723041014140955

In [533]:
γ*get_w(B)@N(get_x(B))@get_w(B) - (γ*get_w(0)@N(get_x(0))@get_w(0) + γ*δ*sum([Λ2(b) for b in range(B)]))

-0.010126801685257592