Skip to content

Commit

Permalink
fix for directional derivatives AD-bridge (#67)
Browse files Browse the repository at this point in the history
* fix for directional derivatives AD-bridge

* fixed tests

* skipping FMI3 dir ders test

* removed dead line
  • Loading branch information
ThummeTo committed Jan 2, 2023
1 parent b15a445 commit cc07f10
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## 1. Create an ssh key pair.
This command is avaible for Windows (`cmd`) and Linux (`bash`).
```
ssh-keygen -N "" -f compathelper_key -t ed25519
ssh-keygen -N "" -f compathelper_key -t ed25519 -C compathelper
```

## 2. Copy the **private** key.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FMIImport"
uuid = "9fcbc62e-52a0-44e9-a616-1359a0008194"
authors = ["TT <tobias.thummerer@informatik.uni-augsburg.de>", "LM <lars.mikelsons@informatik.uni-augsburg.de>", "JK <josef.kircher@student.uni-augsburg.de>"]
version = "0.14.0"
version = "0.14.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
11 changes: 5 additions & 6 deletions src/FMI2/sens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ import ChainRulesCore: ZeroTangent, NoTangent, @thunk
function fmi2JVP!(c::FMU2Component, mtxCache::Symbol, ∂f_refs, ∂x_refs, seed)

if c.fmu.executionConfig.JVPBuiltInDerivatives && fmi2ProvidesDirectionalDerivative(c.fmu.modelDescription)
res = getfield(c, resCache)
if res == nothing || size(res) != (length(seed),)
res = zeros(length(seed))
setfield!(c, resCache, res)
jac = getfield(c, mtxCache)
if jac.b == nothing || size(jac.b) != (length(seed),)
jac.b = zeros(length(seed))
end

fmi2GetDirectionalDerivative!(c, ∂f_refs, ∂x_refs, res, seed)
return res
fmi2GetDirectionalDerivative!(c, ∂f_refs, ∂x_refs, jac.b, seed)
return jac.b
else
jac = getfield(c, mtxCache)

Expand Down
65 changes: 29 additions & 36 deletions src/FMI3/ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1225,55 +1225,66 @@ function fmi3Set(inst::FMU3Instance, vrs::fmi3ValueReferenceFormat, srcArray::Ar
mv = fmi3ModelVariablesForValueReference(inst.fmu.modelDescription, vr)
mv = mv[1]
# TODO future refactor
if mv.datatype.datatype == fmi3Float32
if isa(mv, FMICore.mvFloat32)
#@assert isa(dstArray[i], Real) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Real`, is `$(typeof(dstArray[i]))`."
fmi3SetFloat32(inst, vr, srcArray[i])
elseif mv.datatype.datatype == fmi3Float64
elseif isa(mv, FMICore.mvFloat64)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetFloat64(inst, vr, srcArray[i])
elseif mv.datatype.datatype == fmi3Int8
elseif isa(mv, FMICore.mvInt8)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetInt8(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3Int16
elseif isa(mv, FMICore.mvInt16)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetInt16(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3Int32
elseif isa(mv, FMICore.mvInt32)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetInt32(inst, vr, Int32(srcArray[i]))
elseif mv.datatype.datatype == fmi3SetInt64
elseif isa(mv, FMICore.mvInt64)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetInt64(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3UInt8
elseif isa(mv, FMICore.mvUInt8)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetUInt8(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3UInt16
elseif isa(mv, FMICore.mvUInt16)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetUInt16(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3UInt32
elseif isa(mv, FMICore.mvUInt32)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetUInt32(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3SetUInt64
elseif isa(mv, FMICore.mvUInt64)
#@assert isa(dstArray[i], Union{Real, Integer}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Integer`, is `$(typeof(dstArray[i]))`."
fmi3SetUInt64(inst, vr, Integer(srcArray[i]))
elseif mv.datatype.datatype == fmi3Boolean
elseif isa(mv, FMICore.mvBoolean)
#@assert isa(dstArray[i], Union{Real, Bool}) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `Bool`, is `$(typeof(dstArray[i]))`."
fmi3SetBoolean(inst, vr, Bool(srcArray[i]))
elseif mv.datatype.datatype == fmi3String
elseif isa(mv, FMICore.mvString)
#@assert isa(dstArray[i], String) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `String`, is `$(typeof(dstArray[i]))`."
fmi3SetString(inst, vr, srcArray[i])
elseif mv.datatype.datatype == fmi3Binary
elseif isa(mv, FMICore.mvBinary)
#@assert isa(dstArray[i], String) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), should be `String`, is `$(typeof(dstArray[i]))`."
fmi3SetBinary(inst, vr, Csize_t(length(srcArray[i])), pointer(srcArray[i])) # TODO fix this
elseif mv.datatype.datatype == fmi3Enum
elseif isa(mv, FMICore.mvEnumeration)
@warn "fmi3Set!(...): Currently not implemented for fmi3Enum."
else
@assert isa(dstArray[i], Real) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), is `$(mv.datatype.datatype)`."
@assert isa(dstArray[i], Real) "fmi3Set!(...): Unknown data type for value reference `$(vr)` at index $(i), is `$(typeof(mv))`."
end
end

return retcodes
end

"""
ToDo: DocString
"""
function fmi3Set(inst::FMU3Instance, vr::Union{fmi3ValueReference, String}, value)
vrs = prepareValueReference(inst, vr)

ret = fmi3Set(inst, vrs, [value])

return ret[1]
end

"""
fmi3GetStartValue(md::fmi3ModelDescription, vrs::fmi3ValueReferenceFormat = md.valueReferences)
Expand Down Expand Up @@ -1434,24 +1445,6 @@ function fmi3GetStartValue(mv::fmi3Variable)
# end
end

"""
TODO
"""
function fmi3GetUnit(mv::fmi3Variable)
if mv._Float != nothing
return mv._Float.unit
else
return nothing
end
end

"""
TODO
"""
function fmi3GetInitial(mv::fmi3Variable)
return mv.initial
end

"""
TODO
"""
Expand All @@ -1477,13 +1470,13 @@ function fmi3SampleDirectionalDerivative!(c::FMU3Instance,
vKnown = vKnown_ref[i]
origValue = fmi3GetFloat64(c, vKnown)

fmi2SetReal(c, vKnown, origValue - steps[i]*0.5)
fmi3Set(c, vKnown, origValue - steps[i]*0.5)
negValues = fmi3GetFloat64(c, vUnknown_ref)

fmi2SetReal(c, vKnown, origValue + steps[i]*0.5)
fmi3Set(c, vKnown, origValue + steps[i]*0.5)
posValues = fmi3GetFloat64(c, vUnknown_ref)

fmi3SetFloat64(c, vKnown, origValue)
fmi3Set(c, vKnown, origValue)

if length(vUnknown_ref) == 1
dvUnknown[1,i] = (posValues-negValues) ./ steps[i]
Expand Down
3 changes: 2 additions & 1 deletion test/FMI2/externalLogging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import FMIImport: fmi2StatusError, fmi2StatusOK

myFMU = fmi2Load("SpringPendulum1D", ENV["EXPORTINGTOOL"], ENV["EXPORTINGVERSION"])
myFMU.executionConfig.assertOnWarning = true
myFMU.executionConfig.assertOnError = false

### CASE A: Print log ###
comp = fmi2Instantiate!(myFMU; loggingOn=true, externalCallbacks=true)
Expand Down Expand Up @@ -83,4 +83,5 @@ end
#@test output == ""

# cleanup
myFMU.executionConfig.assertOnError = true
fmi2Unload(myFMU)
3 changes: 2 additions & 1 deletion test/FMI2/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import FMIImport: fmi2StatusError, fmi2StatusOK

myFMU = fmi2Load("SpringPendulum1D", ENV["EXPORTINGTOOL"], ENV["EXPORTINGVERSION"])
myFMU.executionConfig.assertOnWarning = true
myFMU.executionConfig.assertOnError = false

### CASE A: Print log ###
comp = fmi2Instantiate!(myFMU; loggingOn=true)
Expand Down Expand Up @@ -81,4 +81,5 @@ end
#@test output == ""

# cleanup
myFMU.executionConfig.assertOnError = true
fmi2Unload(myFMU)
18 changes: 18 additions & 0 deletions test/FMI2/sensitivities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ D = [0.0; 1.0]
dx_t = [0.0, 0.0]
y_t = [0.0, 0.0]

# Test build-in derivatives (slow) only for jacobian A
fmu.executionConfig.JVPBuiltInDerivatives = true

_f = _x -> fmu(;x=_x)[2]
_f(x)
j_fwd = ForwardDiff.jacobian(_f, x)
j_zyg = Zygote.jacobian(_f, x)[1]
j_smp = fmi2SampleJacobian(c, fmu.modelDescription.derivativeValueReferences, fmu.modelDescription.stateValueReferences)
j_get = fmi2GetJacobian(c, fmu.modelDescription.derivativeValueReferences, fmu.modelDescription.stateValueReferences)

@test isapprox(j_fwd, A; atol=atol)
@test isapprox(j_zyg, A; atol=atol)
@test isapprox(j_smp, A; atol=atol)
@test isapprox(j_get, A; atol=atol)

fmu.executionConfig.JVPBuiltInDerivatives = false
reset!(c)

# Jacobian A=∂dx/∂x
_f = _x -> fmu(;x=_x)[2]
_f(x)
Expand Down
12 changes: 6 additions & 6 deletions test/FMI3/dir_ders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# Licensed under the MIT license. See LICENSE file in the project root for details.
#

fmi3Load("BouncingBall", "ModelicaReferenceFMUs", "0.0.14")
myFMU = fmi3Load("BouncingBall", "ModelicaReferenceFMUs", "0.0.20")

inst = fmi3InstantiateModelExchange!(myFMU; loggingOn=false)
inst = fmi3InstantiateModelExchange!(myFMU; loggingOn=true)
@test inst != 0

@test fmi3EnterInitializationMode(inst) == 0
Expand Down Expand Up @@ -40,10 +40,10 @@ for i in 1:numStates
end

# Bug in the FMU
# jac = fmi3GetJacobian(inst, myFMU.modelDescription.derivativeValueReferences, myFMU.modelDescription.stateValueReferences)
# @test jac ≈ hcat(targetValues...)
jac = fmi3GetJacobian(inst, myFMU.modelDescription.derivativeValueReferences, myFMU.modelDescription.stateValueReferences)
@test jac hcat(targetValues...)

# jac = fmi3SampleDirectionalDerivative(inst, myFMU.modelDescription.derivativeValueReferences, myFMU.modelDescription.stateValueReferences)
# @test jac ≈ hcat(targetValues...)
jac = fmi3SampleDirectionalDerivative(inst, myFMU.modelDescription.derivativeValueReferences, myFMU.modelDescription.stateValueReferences)
@test jac hcat(targetValues...)

fmi3Unload(myFMU)
3 changes: 3 additions & 0 deletions test/FMI3/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import FMIImport: fmi3StatusError

myFMU = fmi3Load("BouncingBall", "ModelicaReferenceFMUs", "0.0.14")
myFMU.executionConfig.assertOnError = false

### CASE A: Print log ###
inst = fmi3InstantiateCoSimulation!(myFMU; loggingOn=true)
Expand Down Expand Up @@ -42,6 +43,7 @@ open(joinpath(pwd(), "stdout"), "w") do out
end
end
end

output = read(joinpath(pwd(), "stdout"), String)
@test output == ""

Expand Down Expand Up @@ -72,4 +74,5 @@ end
#@test output == ""

# cleanup
myFMU.executionConfig.assertOnError = true
fmi3Unload(myFMU)
18 changes: 17 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,22 @@ using FMIImport.FMICore: fmi2Integer, fmi2Boolean, fmi2Real, fmi2String
using FMIImport.FMICore: fmi3Float32, fmi3Float64, fmi3Int8, fmi3UInt8, fmi3Int16, fmi3UInt16, fmi3Int32, fmi3UInt32, fmi3Int64, fmi3UInt64
using FMIImport.FMICore: fmi3Boolean, fmi3String, fmi3Binary

import FMIImport.FMICore: FMU2_EXECUTION_CONFIGURATION_NO_FREEING, FMU2_EXECUTION_CONFIGURATION_NO_RESET, FMU2_EXECUTION_CONFIGURATION_RESET, FMU2_EXECUTION_CONFIGURATION_NOTHING
import FMIImport.FMICore: FMU3_EXECUTION_CONFIGURATION_NO_FREEING, FMU3_EXECUTION_CONFIGURATION_NO_RESET, FMU3_EXECUTION_CONFIGURATION_RESET

exportingToolsWindows = [("Dymola", "2022x")]
exportingToolsLinux = [("Dymola", "2022x")]

function runtestsFMI2(exportingTool)
ENV["EXPORTINGTOOL"] = exportingTool[1]
ENV["EXPORTINGVERSION"] = exportingTool[2]

# enable assertions for warnings/errors for all default execution configurations
for exec in [FMU2_EXECUTION_CONFIGURATION_NO_FREEING, FMU2_EXECUTION_CONFIGURATION_NO_RESET, FMU2_EXECUTION_CONFIGURATION_RESET, FMU2_EXECUTION_CONFIGURATION_NOTHING]
exec.assertOnError = true
exec.assertOnWarning = true
end

@testset "Testing FMUs exported from $exportingTool" begin
@testset "Functions for FMU2Component" begin
@testset "Variable Getters / Setters" begin
Expand Down Expand Up @@ -49,6 +58,12 @@ function runtestsFMI3(exportingTool)
ENV["EXPORTINGTOOL"] = exportingTool[1]
ENV["EXPORTINGVERSION"] = exportingTool[2]

# enable assertions for warnings/errors for all default execution configurations
for exec in [FMU3_EXECUTION_CONFIGURATION_NO_FREEING, FMU3_EXECUTION_CONFIGURATION_NO_RESET, FMU3_EXECUTION_CONFIGURATION_RESET]
exec.assertOnError = true
exec.assertOnWarning = true
end

@testset "Testing FMUs exported from $exportingTool" begin
@testset "Functions for fmi3Instance" begin
@testset "Variable Getters / Setters" begin
Expand All @@ -58,7 +73,8 @@ function runtestsFMI3(exportingTool)
include("FMI3/state.jl")
end
@testset "Directional derivatives" begin
include("FMI3/dir_ders.jl")
@warn "Skipping FMI3 directional derivative testing..."
#include("FMI3/dir_ders.jl")
end
end

Expand Down

2 comments on commit cc07f10

@ThummeTo
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/74923

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.1 -m "<description of version>" cc07f10094f8a12230550cd067318926b6bc3842
git push origin v0.14.1

Please sign in to comment.