# Zygote gives wrong derivatives when using redegree function to extend the degree of TPS struc.
# The results prove Zygote has some problems. But it is difficult to locate the source of the error.
# Here are some tests of the redegree function

In [7]:
using Zygote
include("../src/JuTPSA.jl")
using .JuTPSA

# a simplified redegree function
function f(a::T) where T<:Number
    # create a variable = a + x1, the polymap is [a, 1] for [constant, x1]
    ctps = CTPS(a, 1, 1, 2)

    # extend the variable to second degree
    degree = ctps.degree + 1
    Max_TPS_Degree = 2
    TPS_Dim = 1
    degree = min(degree+1, Max_TPS_Degree)
    terms = binomial(TPS_Dim + degree, degree)

    # new_map = zeros(T, terms)
    # new_map[1:ctps.terms] = ctps.map  # mutating array
    new_map = [i <= length(ctps.map) ? ctps.map[i] : zero(T) for i in 1:terms] # extend the map, this approach avoids mutating the array

    # the output is a polymap with [a, 1, 0] for [constant, x1, x1^2]
    return new_map
end

a = 3.0
result = f(a)
grad = jacobian(f, a)
println("result of f(a): $result")
println("derivative with respect to a: $grad")

result of f(a): [3.0, 1.0, 0.0]
derivative with respect to a: ([2.0, 0.0, 0.0],)




In [11]:
function f(a::T) where T<:Number
    # directly create the map [a, 1] for [constant, x1], without using CTPS struct
    ctps = CTPS(a, 1, 1, 2)
    map = [a, 1] # the same as ctps.map
    println(ctps.map)
    println(map)
    # extend the variable to second degree
    degree = ctps.degree + 1
    Max_TPS_Degree = 2
    TPS_Dim = 1
    degree = min(degree+1, Max_TPS_Degree)
    terms = binomial(TPS_Dim + degree, degree)

    # new_map = [i <= length(ctps.map) ? ctps.map[i] : zero(T) for i in 1:terms] # extend the map, this approach avoids mutating the array
    new_map = [i <= length(ctps.map) ? map[i] : zero(T) for i in 1:terms]
    
    # the output is a polymap with [a, 1, 0] for [constant, x1, x1^2]
    return new_map
end

a = 3.0
result = f(a)
grad = jacobian(f, a)
println("result of f(a): $result")
println("derivative with respect to a: $grad")

[3.0, 1.0]
[3.0, 1.0]
[3.0, 1.0]
[3.0, 1.0]


result of f(a): [3.0, 1.0, 0.0]
derivative with respect to a: ([1.0, 0.0, 0.0],)


In [9]:
function f(a::T) where T<:Number
    # create a variable = a + x1, the polymap is [a, 1] for [constant, x1]
    ctps = CTPS(a, 1, 1, 2)

    # extend the variable to second degree
    degree = ctps.degree + 1
    Max_TPS_Degree = 2
    TPS_Dim = 1
    degree = min(degree+1, Max_TPS_Degree)
    terms = binomial(TPS_Dim + degree, degree)

    # # extend the map, using Zygote.Buffer that allows mutation of arrays
    new_map = zeros(T, terms)
    new_map_buffer = Zygote.Buffer(new_map)
    for i in 1:ctps.terms
        new_map_buffer[i] = ctps.map[i]
    end
    for i in ctps.terms+1:terms
        new_map_buffer[i] = zero(T)
    end
    new_map = copy(new_map_buffer)
    
    return new_map
end

a = 3.0
grad = jacobian(f, a)
println("result of f(a): $result")
println("derivative with respect to a: $grad")

result of f(a): [3.0, 1.0, 0.0]
derivative with respect to a: ([1.0, 0.0, 0.0],)


In [1]:
# Zygote has problems dealing with mutable struct

using Zygote

mutable struct S1
    map::Vector{Float64}
end

function S1(a)
    return S1([a, 1])
end

function f1(a)
    s1 = S1(a)
    terms = length(s1.map) + 1
    new_map = [i <= length(s1.map) ? s1.map[i] : 0.0 for i in 1:terms]
    return new_map
end

a = 3.0
grad = jacobian(f1, a)
# println("result of f(a): $result")
println("derivative with respect to a: $grad")

derivative with respect to a: ([2.0, 0.0, 0.0],)


In [1]:
include("../src/JuTPSA.jl")
using .JuTPSA
using Zygote
ctps1 = CTPS(1.0, 1, 3, 3)
ctps2 = CTPS(2.0, 2, 3, 3)
println("ctps1: $ctps1")
println("ctps2: $ctps2")

function f(k)
    ctps1 = CTPS(1.0, 1, 3, 3)
    ctps2 = CTPS(2.0, 2, 3, 3)
    ctps3 = k*ctps1 * ctps2
    return ctps3.map
end

grad = jacobian(f, 3.0)
println("derivative with respect to a: $grad")

ctps1: CTPS{Float64, 3, 3}(1, 4, [1.0, 1.0, 0.0, 0.0], Base.RefValue{PolyMap}(PolyMap(3, 3, [[0, 0, 0, 0], [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1], [2, 2, 0, 0], [2, 1, 1, 0], [2, 1, 0, 1], [2, 0, 2, 0], [2, 0, 1, 1], [2, 0, 0, 2], [3, 3, 0, 0], [3, 2, 1, 0], [3, 2, 0, 1], [3, 1, 2, 0], [3, 1, 1, 1], [3, 1, 0, 2], [3, 0, 3, 0], [3, 0, 2, 1], [3, 0, 1, 2], [3, 0, 0, 3]])))
ctps2: CTPS{Float64, 3, 3}(1, 4, [2.0, 0.0, 1.0, 0.0], Base.RefValue{PolyMap}(PolyMap(3, 3, [[0, 0, 0, 0], [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1], [2, 2, 0, 0], [2, 1, 1, 0], [2, 1, 0, 1], [2, 0, 2, 0], [2, 0, 1, 1], [2, 0, 0, 2], [3, 3, 0, 0], [3, 2, 1, 0], [3, 2, 0, 1], [3, 1, 2, 0], [3, 1, 1, 1], [3, 1, 0, 2], [3, 0, 3, 0], [3, 0, 2, 1], [3, 0, 1, 2], [3, 0, 0, 3]])))


derivative with respect to a: ([2.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],)


In [1]:
include("../src/JuTPSA.jl")
using .JuTPSA
x = CTPS(0, 1, 6, 2)
px = CTPS(0, 2, 6, 2)
y = CTPS(0, 3, 6, 2)
py = CTPS(0, 4, 6, 2)
dp = 0.001

delta = CTPS(dp, 5, 6, 2)
z = CTPS(0, 6, 6, 2)
print("x: $x\n, px: $px\n, y: $y\n, py: $py\n, delta: $delta\n, z: $z\n")

x: CTPS{Int64, 6, 2}(1, 7, [0, 1, 0, 0, 0, 0, 0], Base.RefValue{PolyMap}(PolyMap(6, 2, [[0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0], [1, 0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0, 1], [2, 2, 0, 0, 0, 0, 0], [2, 1, 1, 0, 0, 0, 0], [2, 1, 0, 1, 0, 0, 0], [2, 1, 0, 0, 1, 0, 0], [2, 1, 0, 0, 0, 1, 0], [2, 1, 0, 0, 0, 0, 1], [2, 0, 2, 0, 0, 0, 0], [2, 0, 1, 1, 0, 0, 0], [2, 0, 1, 0, 1, 0, 0], [2, 0, 1, 0, 0, 1, 0], [2, 0, 1, 0, 0, 0, 1], [2, 0, 0, 2, 0, 0, 0], [2, 0, 0, 1, 1, 0, 0], [2, 0, 0, 1, 0, 1, 0], [2, 0, 0, 1, 0, 0, 1], [2, 0, 0, 0, 2, 0, 0], [2, 0, 0, 0, 1, 1, 0], [2, 0, 0, 0, 1, 0, 1], [2, 0, 0, 0, 0, 2, 0], [2, 0, 0, 0, 0, 1, 1], [2, 0, 0, 0, 0, 0, 2]])))
, px: CTPS{Int64, 6, 2}(1, 7, [0, 0, 1, 0, 0, 0, 0], Base.RefValue{PolyMap}(PolyMap(6, 2, [[0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0], [1, 0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0, 1], [2, 2, 0, 0, 



In [2]:
D1 = Drift("D1", 1.0)
Q1 = Quad("Q1", 1.0, 2.0, 0)
seq = [D1, Q1]
track(seq, [x, px, y, py, delta, z])

MethodError: MethodError: no method matching track(::Vector{Main.JuTPSA.AbstractElement}, ::Vector{CTPS{T, 6, 2} where T})

Closest candidates are:
  track(::Any, !Matched::Array{CTPS{T, TPS_Dim, Max_TPS_Degree}, 1}) where {T, TPS_Dim, Max_TPS_Degree}
   @ Main.JuTPSA c:\Users\WAN\Desktop\JuTPSA\JuTPSA\src\tracking\TPSAtranfermap.jl:101
