In [1]:
using SymPy

In [2]:
"""
    _sym(x)

Return a copy of the expression `x` in which every numeric literal has been
replaced by its SymPy `Sym` equivalent. `Expr` nodes are rebuilt recursively;
any other non-numeric leaf (symbols, line-number nodes, ...) is returned as is.
The input expression is never modified.
"""
function _sym(x::Expr)
    # Build a fresh Expr rather than overwriting `x.args` in place, so a caller
    # that still holds `x` does not see it silently rewritten.
    return Expr(x.head, map(_sym, x.args)...)
end

_sym(x::Number) = Sym(x)
_sym(x) = x

"""
    @sym expr

Evaluate `expr` with every numeric literal converted to a SymPy `Sym`, so that
arithmetic such as `1/2 - √3/6` stays exact instead of producing floats.

Note: *every* number literal is converted, including ones used as indices
(`v[1]` becomes `v[Sym(1)]`), so avoid indexing inside `@sym`.
"""
macro sym(x)
    return esc(_sym(x))
end

@sym (macro with 1 method)

In [3]:
"""
    array_to_code(name, x::AbstractMatrix)

Return a string of Julia source that assigns the matrix `x` to the variable
`name`, one row per line. For example, `array_to_code(:x, [1 2; 3 4])` gives

    x = [
    [ 1  2 ]
    [ 3  4 ]
    ]

Each entry is rendered with `string`, so symbolic (`Sym`) entries come out as
expressions such as `sqrt(30)/72 + 1/4`.
"""
function array_to_code(name, x::AbstractMatrix)
    io = IOBuffer()
    print(io, name, " = [\n")
    for i in axes(x, 1)
        print(io, "[")
        for j in axes(x, 2)
            # every entry is padded with one space on each side
            print(io, " ", string(x[i, j]), " ")
        end
        print(io, "]\n")
    end
    print(io, "]\n")
    return String(take!(io))
end

array_to_code (generic function with 1 method)

In [4]:
"""
    get_glrk_nodes(s)

Return the `s` Gauss–Legendre collocation nodes on [0, 1], in increasing
order, as an exact (`Sym`) vector. Supported for `s = 1, …, 5`.

Throws an `ArgumentError` for any other `s`.
"""
function get_glrk_nodes(s)
    if s == 1
        c = @sym Vector([1//2])
    elseif s == 2
        c = @sym [1/2-√3/6, 1/2+√3/6]
    elseif s == 3
        c = @sym [1/2-√15/10,  1/2,        1/2+√15/10 ]
    elseif s == 4
        c = @sym [
          1/2-√(3/7+2*√30/35)/2,
          1/2-√(3/7-2*√30/35)/2,
          1/2+√(3/7-2*√30/35)/2,
          1/2+√(3/7+2*√30/35)/2
        ]
    elseif s == 5
        c = @sym [
          1/2-√(5/9+2*√70/63)/2,
          1/2-√(5/9-2*√70/63)/2,
          1/2,
          1/2+√(5/9-2*√70/63)/2,
          1/2+√(5/9+2*√70/63)/2
        ]
    else
        # Fail immediately with a descriptive error; otherwise `c` would be
        # undefined at the `return` below.
        throw(ArgumentError("GLRK nodes for $s stages not implemented."))
    end

    return c
end

get_glrk_nodes (generic function with 1 method)

In [5]:
"""
    get_lobatto_nodes(s)

Return the `s` Gauss–Lobatto nodes on [0, 1] (both endpoints included), in
increasing order, as an exact (`Sym`) vector. Supported for `s = 2, …, 5`.

Throws an `ArgumentError` for any other `s`.
"""
function get_lobatto_nodes(s)
    if s == 2
        c = @sym [0, 1]
    elseif s == 3
        c = @sym [0, 1//2, 1]
    elseif s == 4
        c = @sym [0,  (5-√5)/10,   (5+√5)/10,  1]
    elseif s == 5
        c = @sym [0,  1//2-√21/14,  1//2,  1//2+√21/14,  1]
    else
        # Fail immediately with a descriptive error; otherwise `c` would be
        # undefined at the `return` below.
        throw(ArgumentError("Lobatto nodes for $s stages not implemented."))
    end

    return c
end

get_lobatto_nodes (generic function with 1 method)

In [6]:
"""
    get_glrk_lobatto_tableau(s, σ=2)

Compute the σ×s matrix `A` of weights that integrate from 0 up to each of the
σ Lobatto nodes `c̄[i]`, using values at the `s` GLRK nodes `c`.

Row `i` is the solution of the linear system

    Σⱼ A[i, j] * c[j]^(k-1) = c̄[i]^k / k,   k = 1, …, s,

i.e. exact integration of polynomials of degree < s. Each row's system is
passed to SymPy's `solve` in one call.
"""
function get_glrk_lobatto_tableau(s, σ=2)
    nodes = get_glrk_nodes(s)
    lobatto = get_lobatto_nodes(σ)
    unknowns = [Sym("a_" * string(row) * string(col)) for row in 1:σ, col in 1:s]

    # Moment condition k for a given row, written as (lhs - rhs); terms are
    # accumulated starting from -c̄^k/k, one column at a time.
    residual(row, k) = foldl(1:s; init = -lobatto[row]^k / k) do acc, col
        acc + unknowns[row, col] * nodes[col]^(k - 1)
    end

    solution = Dict()
    for row in 1:σ
        system = Any[residual(row, k) for k in 1:s]
        solution = merge(solution, solve(Sym.(system), unknowns[row, :]))
    end

    A = [solution[unknowns[row, col]]
         for row in axes(unknowns, 1), col in axes(unknowns, 2)]
    return A
end

get_glrk_lobatto_tableau (generic function with 2 methods)

In [7]:
# Single GLRK node (1/2) with the default σ = 2 Lobatto nodes [0, 1].
get_glrk_lobatto_tableau(1)

2×1 Array{Sym,2}:
 0
 1

In [8]:
# Two GLRK nodes with the default σ = 2 Lobatto nodes [0, 1].
get_glrk_lobatto_tableau(2)

2×2 Array{Sym,2}:
   0    0
 1/2  1/2

In [9]:
# Three GLRK nodes with the default σ = 2 Lobatto nodes [0, 1].
get_glrk_lobatto_tableau(3)

2×3 Array{Sym,2}:
    0    0     0
 5/18  4/9  5/18

In [10]:
"""
    get_glrk_lobatto_tableau_manual(s, σ=2)

Compute the same σ×s matrix as `get_glrk_lobatto_tableau`, i.e. for each
Lobatto node `c̄[i]` the solution of

    Σⱼ A[i, j] * c[j]^(k-1) = c̄[i]^k / k,   k = 1, …, s,

but by hand. Forward elimination calls SymPy's `solve` on one scalar
equation at a time and simplifies the remaining equations after each
substitution. Back substitution then makes every entry explicit.

Rows whose Lobatto node is zero have an all-zero right-hand side and are left
as zeros without solving. With `get_lobatto_nodes` this is exactly row 1,
since `c̄[1] == 0`.

Returns a σ×s matrix of simplified `Sym` entries.
"""
function get_glrk_lobatto_tableau_manual(s, σ=2)
    c = get_glrk_nodes(s)
    c̄ = get_lobatto_nodes(σ)
    a = [Sym("a_" * string(i) * string(j)) for i in 1:σ, j in 1:s]
    A = [Sym(0) for i in 1:σ, j in 1:s]

    for i in 1:σ
        # zero right-hand side ⇒ zero weights; keep the preset zeros
        c̄[i] == 0 && continue

        eqs = Sym[]
        for k in 1:s
            eq = - c̄[i]^k / k
            for j in 1:s
                eq += a[i,j] * c[j]^(k-1)
            end
            push!(eqs, eq)
        end

        # Forward elimination: solve equation k for a[i,k] (in terms of
        # a[i,k+1:s]) and substitute it into the remaining equations.
        for k in 1:s
            A[i,k] = solve(eqs[k], a[i,k])[1]
            for l in k+1:s
                eqs[l] = simplify(subs(eqs[l], a[i,k], A[i,k]))
            end
        end

        # Back substitution: A[i,s] is already explicit; resolve the other
        # columns from the last one backwards.
        for k in s-1:-1:1
            for l in k+1:s
                A[i,k] = subs(A[i,k], a[i,l], A[i,l])
            end
        end
    end

    return simplify.(A)
end

get_glrk_lobatto_tableau_manual (generic function with 2 methods)

In [11]:
# Same tableau via the step-by-step elimination routine; compare with get_glrk_lobatto_tableau(2).
a2 = get_glrk_lobatto_tableau_manual(2)

2×2 Array{Sym,2}:
   0    0
 1/2  1/2

In [12]:
# Same tableau via the step-by-step elimination routine; compare with get_glrk_lobatto_tableau(3).
a3 = get_glrk_lobatto_tableau_manual(3)

2×3 Array{Sym,2}:
    0    0     0
 5/18  4/9  5/18

In [13]:
# Four GLRK nodes, computed with the step-by-step elimination routine.
a4 = get_glrk_lobatto_tableau_manual(4)

2×4 Array{Sym,2}:
                  0  …                   0
 -sqrt(30)/72 + 1/4     -sqrt(30)/72 + 1/4

In [14]:
# Structural (symbolic) comparison of the two middle entries of row 2.
a4[2,2] == a4[2,3]

false

In [15]:
# Numerical comparison of the same two entries.
N(a4[2,2]) ≈ N(a4[2,3])

true

In [16]:
# The entries agree numerically but not structurally; give a4[2,2] the form of a4[2,3].
a4[2,2] = a4[2,3]

√30   1
─── + ─
 72   4

In [17]:
# Display the patched four-stage tableau.
a4

2×4 Array{Sym,2}:
                  0                  0                  0                   0
 -sqrt(30)/72 + 1/4  sqrt(30)/72 + 1/4  sqrt(30)/72 + 1/4  -sqrt(30)/72 + 1/4

In [18]:
# Five GLRK nodes, computed with the step-by-step elimination routine.
a5 = get_glrk_lobatto_tableau_manual(5)

2×5 Array{Sym,2}:
                                                                                                                 0  …                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    0
 (-5797*sqrt(70) - 19635*sqrt(6) + 3297*sqrt(105) + 34069)/(900*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124))     (-1309*sqrt(30*sqrt(70) + 525)/180 - 187*sqrt(14*sqrt(70) + 245)/20 + 189*sqrt(-20*sqrt(70) + 350)/200 + 21*sqrt(-84*sqrt(70) + 1470)/40 + 27*sqrt(-14*sqrt(70) + 245)/4 + 21*sqrt(-

In [19]:
# Structural (symbolic) comparison of the first and last entries of row 2.
a5[2,1] == a5[2,5]

false

In [20]:
# Numerical comparison of the same two entries.
N(a5[2,1]) ≈ N(a5[2,5])

true

In [21]:
# The entries agree numerically but not structurally; give a5[2,5] the form of a5[2,1].
a5[2,5] = a5[2,1]

-5797⋅√70 - 19635⋅√6 + 3297⋅√105 + 34069
────────────────────────────────────────
 900⋅(-31⋅√70 - 105⋅√6 + 12⋅√105 + 124) 

In [24]:
# Display the patched five-stage tableau.
a5

2×5 Array{Sym,2}:
                                                                                                                 0  …                                                                                                                  0
 (-5797*sqrt(70) - 19635*sqrt(6) + 3297*sqrt(105) + 34069)/(900*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124))     (-5797*sqrt(70) - 19635*sqrt(6) + 3297*sqrt(105) + 34069)/(900*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124))

In [22]:
# Emit the exact (symbolic) five-stage tableau as Julia source.
array_to_code(:x, a5)

"x = [\n[ 0  0  0  0  0 ]\n[ (-5797*sqrt(70) - 19635*sqrt(6) + 3297*sqrt(105) + 34069)/(900*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124))  3*(-155*sqrt(70) - 525*sqrt(6) + 21*sqrt(105) + 217)/(100*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124))  64/225  3*(-155*sqrt(70) - 525*sqrt(6) + 21*sqrt(105) + 217)/(100*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124))  (-5797*sqrt(70) - 19635*sqrt(6) + 3297*sqrt(105) + 34069)/(900*(-31*sqrt(70) - 105*sqrt(6) + 12*sqrt(105) + 124)) ]\n]\n"

In [23]:
# Emit the five-stage tableau with every entry evaluated numerically by N.
array_to_code(:x, N.(a5))

"x = [\n[ 0  0  0  0  0 ]\n[ 0.1184634425280945437571320203599586813216300011062070077914139441108586442015225  0.2393143352496832340206457574178190964561477766715707699863638336669191335762568  64//225  0.2393143352496832340206457574178190964561477766715707699863638336669191335762568  0.1184634425280945437571320203599586813216300011062070077914139441108586442015225 ]\n]\n"