-
Notifications
You must be signed in to change notification settings - Fork 24
/
diagnostics.jl
54 lines (42 loc) · 2.1 KB
/
diagnostics.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
    apply_diagnostic_check(check, something, stream)

Apply the diagnostic `check` to `stream` and return the resulting stream.

- If `check` is `nothing`, no check is applied and `stream` is returned unchanged.
- If `check` is a `Tuple` of checks, they are applied left to right: each check
  receives the stream returned by the previous one, starting from `stream`.
- `ObjectiveDiagnosticCheckNaNs` and `ObjectiveDiagnosticCheckInfs` wrap `stream` so
  that it errors when a `NaN` or an infinite value, respectively, is observed.

The `something` argument is accepted by every method; the checks defined in this
file do not use it.
"""
function apply_diagnostic_check end
"""
    ObjectiveDiagnosticCheckNaNs

Diagnostic check that, when enabled, verifies that neither the variable-bound nor the
factor-bound score functions in the objective computation produce a `NaN`.
If a `NaN` value passes through the checked stream, the stream errors.
"""
struct ObjectiveDiagnosticCheckNaNs end

# Plain numbers are tested directly; a `CountingReal` is unwrapped with
# `BayesBase.value` and its underlying value is tested instead.
check_isnan(x) = isnan(x)
check_isnan(x::CountingReal) = check_isnan(BayesBase.value(x))

function apply_diagnostic_check(::ObjectiveDiagnosticCheckNaNs, something, stream)
    # The message is static; the offending value is ignored.
    nan_error_message(_) = """
    Failed to compute the final objective value. The result is `NaN`.
    Use `free_energy_diagnostics` keyword argument in the `inference` function to suppress this error.
    """
    checked = stream |> error_if(check_isnan, nan_error_message)
    return checked
end
"""
    ObjectiveDiagnosticCheckInfs

Diagnostic check that, when enabled, verifies that neither the variable-bound nor the
factor-bound score functions in the objective computation produce an infinite value
(`isinf` is true for both `Inf` and `-Inf`). If such a value passes through the
checked stream, the stream errors.
"""
struct ObjectiveDiagnosticCheckInfs end

# Plain numbers are tested directly; a `CountingReal` is unwrapped with
# `BayesBase.value` and its underlying value is tested instead.
check_isinf(x) = isinf(x)
check_isinf(x::CountingReal) = check_isinf(BayesBase.value(x))

function apply_diagnostic_check(::ObjectiveDiagnosticCheckInfs, something, stream)
    # The message is static; the offending value is ignored.
    inf_error_message(_) = """
    Failed to compute the final objective value. The result is `Inf`.
    Use `free_energy_diagnostics` keyword argument in the `inference` function to suppress this error.
    """
    checked = stream |> error_if(check_isinf, inf_error_message)
    return checked
end
# No check configured: pass the stream through untouched.
function apply_diagnostic_check(::Nothing, something, stream)
    return stream
end

# Several checks: thread the stream through each check in order, left to right.
function apply_diagnostic_check(checks::Tuple, something, stream)
    return foldl(checks; init = stream) do current, check
        apply_diagnostic_check(check, something, current)
    end
end
"""
    DefaultObjectiveDiagnosticChecks

Default objective diagnostic checks: a tuple holding `ObjectiveDiagnosticCheckNaNs()`
followed by `ObjectiveDiagnosticCheckInfs()`.
"""
const DefaultObjectiveDiagnosticChecks = (
    ObjectiveDiagnosticCheckNaNs(),
    ObjectiveDiagnosticCheckInfs(),
)