In [1]:
using SymPy

In [33]:
"""
    _sym(x)

Walk an expression tree and wrap every numeric literal in `Sym`, so that arithmetic
on those literals is carried out symbolically by SymPy instead of in machine
arithmetic. `Expr` nodes are rewritten in place and the same object is returned;
every other leaf (symbols, quoted nodes, strings, …) is passed through unchanged.

NOTE(review): literals used as array indices (e.g. `a[1]`) are wrapped as well.
"""
function _sym(x::Expr)
    for k in eachindex(x.args)
        x.args[k] = _sym(x.args[k])
    end
    return x
end

_sym(x::Number) = Sym(x)   # numeric literal -> exact symbolic number
_sym(x) = x                # anything else is left as-is

"""
    @sym expr

Evaluate `expr` in the caller's scope after converting each numeric literal in it
to a `Sym`, e.g. `@sym (5 - √5)/10` keeps the square root exact.
"""
macro sym(x)
    return esc(_sym(x))
end

@sym (macro with 1 method)

In [16]:
"""
    LagrangeBasis{T,N}(x)
    LagrangeBasis(x::Vector{T})

Lagrange interpolation basis on the `N` nodes `x`, with data precomputed for
`eval_basis`/`deriv_basis`:

- `denom[i] = 1 / ∏_{j≠i} (x[i] - x[j])` — the barycentric weight of node `i`;
- `diffs[i,j] = x[i] - x[j]` — all pairwise node differences (zero on the diagonal).

The nodes are assumed to be distinct. A repeated node makes one of the products
zero, and `1/0` then behaves however `T` defines it (e.g. `Inf` for floats).

Throws an `ArgumentError` if `length(x) != N`.
"""
struct LagrangeBasis{T,N}
    x::Vector{T}

    denom::Vector{T}
    diffs::Matrix{T}

    function LagrangeBasis{T,N}(x) where {T,N}
        # Explicit check instead of `@assert`: this validates caller input and
        # must not disappear when assertions are disabled.
        length(x) == N || throw(ArgumentError(
            "LagrangeBasis{T,N} requires exactly N = $N nodes, got $(length(x))"))

        denom = zeros(T, N)
        diffs = zeros(T, N, N)

        for i in eachindex(x)
            p = one(T)
            for j in eachindex(x)
                diffs[i, j] = x[i] - x[j]
                if i ≠ j
                    p *= diffs[i, j]
                end
            end
            denom[i] = 1 / p
        end

        return new(x, denom, diffs)
    end
end

function LagrangeBasis(x::Vector{T}) where {T}
    # The node count becomes the type parameter `N`.
    return LagrangeBasis{T, length(x)}(x)
end

# Size queries for a `LagrangeBasis`: all are fixed by the node-count type
# parameter `N` (N basis polynomials on N nodes, each of degree N - 1).
nbasis(::LagrangeBasis{T,N}) where {T,N} = N
nnodes(::LagrangeBasis{T,N}) where {T,N} = N
degree(::LagrangeBasis{T,N}) where {T,N} = N - 1

"""
    eval_basis(b::LagrangeBasis{T,N}, j::Int, x::T)

Value at `x` of the `j`-th Lagrange basis polynomial of `b`:
`ℓⱼ(x) = (∏_{i≠j} (x - b.x[i])) * b.denom[j]`.
"""
function eval_basis(b::LagrangeBasis{T,N}, j::Int, x::T) where {T,N}
    acc::T = 1
    for i in 1:N
        if i != j
            acc *= x - b.x[i]
        end
    end
    return acc * b.denom[j]
end

"""
    deriv_basis(b::LagrangeBasis{T,N}, j::Int, x::T)
    deriv_basis(b::LagrangeBasis, j::Int, i::Int)

Derivative of the `j`-th Lagrange basis polynomial of `b`. The first method evaluates
it at the point `x`; the second evaluates it at the `i`-th node `b.x[i]`. It uses
the product-rule expansion

    ℓⱼ'(x) = Σ_{l≠j} 1/(x_j - x_l) · ∏_{m≠j,l} (x - x_m)/(x_j - x_m)

NOTE(review): when `T == Int` the two methods both match `(b, ::Int, ::Int)` and are
likely ambiguous — confirm before using integer-valued nodes.
"""
function deriv_basis(b::LagrangeBasis{T,N}, j::Int, x::T) where {T,N}
    total::T = 0

    for l in 1:N
        l == j && continue
        term::T = 1 / b.diffs[j, l]
        for m in 1:N
            m == j && continue
            m == l && continue
            term *= (x - b.x[m]) / b.diffs[j, m]
        end
        total += term
    end

    return total
end

# Evaluate at the node with index `i`.
deriv_basis(b::LagrangeBasis, j::Int, i::Int) = deriv_basis(b, j, b.x[i])

deriv_basis (generic function with 2 methods)

In [40]:
"""
    get_lobatto_nodes(s)

Return the `s` Gauss–Lobatto nodes on `[0, 1]` as an exact symbolic vector.
Numeric literals are wrapped in `Sym` by `@sym`, so radicals such as `√5` stay exact.

Only `s ∈ 2:5` is tabulated; any other `s` throws an `ArgumentError`.
"""
function get_lobatto_nodes(s)
    if s == 2
        return @sym [0, 1]
    elseif s == 3
        return @sym [0, 1//2, 1]
    elseif s == 4
        return @sym [0, (5-√5)/10, (5+√5)/10, 1]
    elseif s == 5
        return @sym [0, 1//2-√21/14, 1//2, 1//2+√21/14, 1]
    end

    # Raise instead of merely logging: there is no node vector to return here.
    throw(ArgumentError("Lobatto nodes for " * string(s) * " stages not implemented."))
end

get_lobatto_nodes (generic function with 1 method)

In [35]:
"""
    get_deriv_matrix(s)

Exact `s×s` differentiation matrix on the `s` Lobatto nodes. Entry `(i, j)` is the
derivative of the `j`-th Lagrange basis polynomial evaluated at node `i`, passed
through `simplify`.
"""
function get_deriv_matrix(s)
    basis = LagrangeBasis(get_lobatto_nodes(s))
    entries = [deriv_basis(basis, j, i) for i in 1:s, j in 1:s]
    return simplify.(entries)
end

get_deriv_matrix (generic function with 1 method)

In [36]:
# Exact differentiation matrix on the 2 Lobatto nodes.
get_deriv_matrix(2)

2×2 Array{Sym,2}:
 -1  1
 -1  1

In [37]:
# Exact differentiation matrix on the 3 Lobatto nodes.
get_deriv_matrix(3)

3×3 Array{Sym,2}:
 -3   4  -1
 -1   0   1
  1  -4   3

In [38]:
# Exact differentiation matrix on the 4 Lobatto nodes (entries involve √5).
get_deriv_matrix(4)

4×4 Array{Sym,2}:
               -6   5/2 + 5*sqrt(5)/2  -5*sqrt(5)/2 + 5/2                 1
 -sqrt(5)/2 - 1/2                   0             √5  -sqrt(5)/2 + 1/2
 -1/2 + sqrt(5)/2            -√5                   0   1/2 + sqrt(5)/2
               -1  -5/2 + 5*sqrt(5)/2  -5*sqrt(5)/2 - 5/2                 6

In [41]:
# Exact differentiation matrix on the 5 Lobatto nodes (entries involve √21).
get_deriv_matrix(5)

5×5 Array{Sym,2}:
                  -10   7*sqrt(21)/6 + 49/6  …                    -1
 -3/2 - 3*sqrt(21)/14                     0     -3*sqrt(21)/14 + 3/2
                  3/4        -7*sqrt(21)/12                     -3/4
 -3/2 + 3*sqrt(21)/14            sqrt(21)/3      3*sqrt(21)/14 + 3/2
                    1  -49/6 + 7*sqrt(21)/6                       10

In [104]:
"""
    get_velocities(s)

Apply the `s`-node Lobatto differentiation matrix to the real symbols `Q1, …, Qs`.
Returns the simplified symbolic velocities at the nodes.
"""
function get_velocities(s)
    D = get_deriv_matrix(s)
    Q = [symbols("Q$i", real=true) for i in 1:s]
    return simplify.(D * Q)
end

get_velocities (generic function with 1 method)

In [59]:
# Free symbolic coefficient; it is the unknown solved for in the weight-vector cells below.
c = symbols("c")

c

In [54]:
# Symbolic nodal velocities for s = 2 in terms of Q1, Q2.
v2 = get_velocities(2)

2-element Array{Sym,1}:
 -Q₁ + Q₂
 -Q₁ + Q₂

In [55]:
# Check that the alternating combination v₁ - v₂ vanishes identically.
d2 = @sym [+1, -1]
v2' * d2

0

In [56]:
# Symbolic nodal velocities for s = 3 in terms of Q1, Q2, Q3.
v3 = get_velocities(3)

3-element Array{Sym,1}:
 -3⋅Q₁ + 4⋅Q₂ - Q₃
          -Q₁ + Q₃
  Q₁ - 4⋅Q₂ + 3⋅Q₃

In [60]:
# Find the coefficient c for which v₁ + c·v₂ + v₃ vanishes for all Q.
d3 = @sym [+1, c, +1]
solve(v3' * d3, c)

1-element Array{Sym,1}:
 -2

In [48]:
# Symbolic nodal velocities for s = 4 in terms of Q1, …, Q4.
v4 = get_velocities(4)

4-element Array{Sym,1}:
  -6*Q1 + Q2*(5/2 + 5*sqrt(5)/2) + Q3*(-5*sqrt(5)/2 + 5/2) + Q4
     Q1*(-sqrt(5)/2 - 1/2) + sqrt(5)*Q3 + Q4*(-sqrt(5)/2 + 1/2)
      Q1*(-1/2 + sqrt(5)/2) - sqrt(5)*Q2 + Q4*(1/2 + sqrt(5)/2)
 -Q1 + Q2*(-5/2 + 5*sqrt(5)/2) + Q3*(-5*sqrt(5)/2 - 5/2) + 6*Q4

In [61]:
# Find the coefficient c for which v₁ - c·v₂ + c·v₃ - v₄ vanishes for all Q.
d4 = @sym [+1, -c, +c, -1]
solve(v4' * d4, c)

1-element Array{Sym,1}:
 √5

In [121]:
# Symbolic nodal velocities for s = 5 in terms of Q1, …, Q5.
v5 = get_velocities(5)

5-element Array{Sym,1}:
                -10*Q1 + 7*Q2*(sqrt(21) + 7)/6 - 16*Q3/3 + 7*Q4*(-sqrt(21) + 7)/6 - Q5
 -3*Q1*(sqrt(21) + 7)/14 + 16*sqrt(21)*Q3/21 - sqrt(21)*Q4/3 + 3*Q5*(-sqrt(21) + 7)/14
                                 3*Q1/4 - 7*sqrt(21)*Q2/12 + 7*sqrt(21)*Q4/12 - 3*Q5/4
 -3*Q1*(-sqrt(21) + 7)/14 + sqrt(21)*Q2/3 - 16*sqrt(21)*Q3/21 + 3*Q5*(sqrt(21) + 7)/14
                 Q1 - 7*Q2*(-sqrt(21) + 7)/6 + 16*Q3/3 - 7*Q4*(sqrt(21) + 7)/6 + 10*Q5

In [157]:
# Set up the linear system v5 - V = 0 linking nodal values Q1..Q5 to prescribed
# nodal velocities V1..V5; Q's are then eliminated one at a time below.
# NOTE: `Q5` here names the whole vector [Q1, …, Q5], and `V5` the vector [V1, …, V5].
s = 5
Q5 = [symbols("Q$i", real=true) for i in 1:s]
V5 = [symbols("V$i", real=true) for i in 1:s]
v5 = simplify.(get_deriv_matrix(s) * Q5)
equs = v5 .- V5

5-element Array{Sym,1}:
                -10*Q1 + 7*Q2*(sqrt(21) + 7)/6 - 16*Q3/3 + 7*Q4*(-sqrt(21) + 7)/6 - Q5 - V1
 -3*Q1*(sqrt(21) + 7)/14 + 16*sqrt(21)*Q3/21 - sqrt(21)*Q4/3 + 3*Q5*(-sqrt(21) + 7)/14 - V2
                                 3*Q1/4 - 7*sqrt(21)*Q2/12 + 7*sqrt(21)*Q4/12 - 3*Q5/4 - V3
 -3*Q1*(-sqrt(21) + 7)/14 + sqrt(21)*Q2/3 - 16*sqrt(21)*Q3/21 + 3*Q5*(sqrt(21) + 7)/14 - V4
                 Q1 - 7*Q2*(-sqrt(21) + 7)/6 + 16*Q3/3 - 7*Q4*(sqrt(21) + 7)/6 + 10*Q5 - V5

In [158]:
# Solve equation 1 for Q1.
Q51 = solve(equs[1], Q5[1])[1]

7⋅√21⋅Q₂   49⋅Q₂   8⋅Q₃   7⋅√21⋅Q₄   49⋅Q₄   Q₅   V₁
──────── + ───── - ──── - ──────── + ───── - ── - ──
   60        60     15       60        60    10   10

In [159]:
# Substitute Q1 into every equation (equation 1 becomes 0).
equs = [simplify(subs(equs[i], Q5[1], Q51)) for i in 1:s]

5-element Array{Sym,1}:
                                                                                                                                                     0
 -7*Q2/4 - 7*sqrt(21)*Q2/20 + 4*Q3/5 + 92*sqrt(21)*Q3/105 - sqrt(21)*Q4/3 - 7*Q4/10 - 27*sqrt(21)*Q5/140 + 33*Q5/20 + 3*sqrt(21)*V1/140 + 3*V1/20 - V2
                                                   -119*sqrt(21)*Q2/240 + 49*Q2/80 - 2*Q3/5 + 49*Q4/80 + 119*sqrt(21)*Q4/240 - 33*Q5/40 - 3*V1/40 - V3
 -7*Q2/10 + sqrt(21)*Q2/3 - 92*sqrt(21)*Q3/105 + 4*Q3/5 - 7*Q4/4 + 7*sqrt(21)*Q4/20 + 27*sqrt(21)*Q5/140 + 33*Q5/20 - 3*sqrt(21)*V1/140 + 3*V1/20 - V4
                                                      -147*Q2/20 + 77*sqrt(21)*Q2/60 + 24*Q3/5 - 147*Q4/20 - 77*sqrt(21)*Q4/60 + 99*Q5/10 - V1/10 - V5

In [160]:
# Solve equation 5 for Q5.
Q55 = solve(equs[5], Q5[5])[1]

  7⋅√21⋅Q₂   49⋅Q₂   16⋅Q₃   7⋅√21⋅Q₄   49⋅Q₄   V₁   10⋅V₅
- ──────── + ───── - ───── + ──────── + ───── + ── + ─────
     54        66      33       54        66    99     99 

In [161]:
# Substitute Q5 into every equation (equation 5 becomes 0).
equs = [simplify(subs(equs[i], Q5[5], Q55)) for i in 1:s]

5-element Array{Sym,1}:
                                                                                                                     0
 -70*sqrt(21)*Q2/99 + 32*sqrt(21)*Q3/33 - 26*sqrt(21)*Q4/99 + 3*sqrt(21)*V1/154 + V1/6 - V2 - 3*sqrt(21)*V5/154 + V5/6
                                                             -7*sqrt(21)*Q2/18 + 7*sqrt(21)*Q4/18 - V1/12 - V3 - V5/12
  26*sqrt(21)*Q2/99 - 32*sqrt(21)*Q3/33 + 70*sqrt(21)*Q4/99 - 3*sqrt(21)*V1/154 + V1/6 - V4 + 3*sqrt(21)*V5/154 + V5/6
                                                                                                                     0

In [162]:
# Solve equation 2 for Q2.
Q52 = solve(equs[2], Q5[2])[1]

48⋅Q₃   13⋅Q₄   27⋅V₁   11⋅√21⋅V₁   33⋅√21⋅V₂   27⋅V₅   11⋅√21⋅V₅
───── - ───── + ───── + ───────── - ───────── - ───── + ─────────
  35      35     980       980         490       980       980   

In [163]:
# Substitute Q2 into every equation (equation 2 becomes 0).
equs = [simplify(subs(equs[i], Q5[2], Q52)) for i in 1:s]

5-element Array{Sym,1}:
                                                                                                                    0
                                                                                                                    0
     -8*sqrt(21)*Q3/15 + 8*sqrt(21)*Q4/15 - 7*V1/40 - 3*sqrt(21)*V1/280 + 11*V2/20 - V3 - 7*V5/40 + 3*sqrt(21)*V5/280
 -64*sqrt(21)*Q3/105 + 64*sqrt(21)*Q4/105 - 3*sqrt(21)*V1/245 + 8*V1/35 - 13*V2/35 - V4 + 3*sqrt(21)*V5/245 + 8*V5/35
                                                                                                                    0

In [164]:
# Solve equation 4 for Q4; the result still contains Q3 as a free parameter.
Q54 = solve(equs[4], Q5[4])[1]

     √21⋅V₁   9⋅V₁   13⋅√21⋅V₂   5⋅√21⋅V₄   √21⋅V₅   9⋅V₅
Q₃ - ────── + ──── + ───────── + ──────── - ────── - ────
       56     448       448         64        56     448 

In [165]:
# Substitute Q4 into every equation. Only equation 3 remains nonzero, and it involves
# only V1..V5: a linear compatibility condition on the nodal velocities.
equs = [simplify(subs(equs[i], Q5[4], Q54)) for i in 1:s]

5-element Array{Sym,1}:
                                       0
                                       0
 -3*V1/8 + 7*V2/8 - V3 + 7*V4/8 - 3*V5/8
                                       0
                                       0

In [166]:
# Scale the remaining condition by 8 to clear its denominators.
equs[3]*8

-3⋅V₁ + 7⋅V₂ - 8⋅V₃ + 7⋅V₄ - 3⋅V₅

In [167]:
# Cross-check: find c for which 3v₁ + c·v₂ + 8v₃ + c·v₄ + 3v₅ vanishes for all Q.
# The weights can be compared with the compatibility condition derived above.
d5 = @sym [+3, c, +8, c, +3]
solve(v5' * d5, c)

1-element Array{Sym,1}:
 -7