-
Notifications
You must be signed in to change notification settings - Fork 9
/
Augmentation.jl
62 lines (51 loc) · 1.78 KB
/
Augmentation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
A state augmentation used by explorers.
Internally, hijacks the recorders machinery to
store it in a Replica.
"""
mutable struct Augmentation{T}
"""
The payload, initially nothing until
[`get_buffer()`](@ref) is called.
"""
contents::Union{T,Nothing}
"""
When it is volatile, i.e. can be
reconstructed on the fly and is only
stored for efficiency purpose, it is
not worth serialialinzing it
"""
serialize::Bool
end
# Create a volatile augmentation (its `serialize` flag is `false`) whose
# payload maps buffer names to `Vector{Float64}` scratch space; entries are
# added on demand by `get_buffer`.
function buffers()
    payload = Dict{Symbol, Vector{Float64}}()
    return Augmentation(payload, false)
end
"""
$SIGNATURES
Return a Vector of length dim. Allocating only the first
time, after that the buffer is recycled and stored in the
Replica's recorders.
"""
function get_buffer(augmentation, key::Symbol, dim::Int)::Vector{Float64}
dict = augmentation.contents
if !haskey(dict, key)
dict[key] = zeros(dim)
end
return dict[key]
end
# Empty augmentation: no payload yet and the `serialize` flag unset.
function Augmentation{T}() where {T}
    return Augmentation{T}(nothing, false)
end
# Merging two augmentations of the same payload type discards both payloads
# and yields a fresh, empty, non-serialized augmentation.
function Base.merge(::Augmentation{T}, ::Augmentation{T}) where {T}
    return Augmentation{T}(nothing, false)
end
# Deliberately a no-op: in this case we do not want to lose the augmentation
# at the end of the round. Following the `Base.empty!` contract
# (`empty!(collection) -> collection`), the unchanged argument is returned.
Base.empty!(a::Augmentation) = a
# Custom serialization for `Augmentation`. It writes the object tag and the
# concrete type, then the `serialize` flag. The payload is written only when
# that flag is set, so volatile payloads are skipped entirely.
function Serialization.serialize(
    s::AbstractSerializer, instance::Augmentation{T}
) where {T}
    Serialization.writetag(s.io, Serialization.OBJECT_TAG)
    Serialization.serialize(s, Augmentation{T})
    keep_payload = instance.serialize
    Serialization.serialize(s, keep_payload)
    return keep_payload ? Serialization.serialize(s, instance.contents) : nothing
end
# Reads back what the custom `serialize` method wrote. It first reads the
# flag, then reads the payload only if the flag says it was stored. Otherwise
# the payload is restored as `nothing`.
function Serialization.deserialize(
    s::AbstractSerializer, ::Type{Augmentation{T}}
) where {T}
    has_payload = Serialization.deserialize(s)
    payload = nothing
    if has_payload
        payload = Serialization.deserialize(s)
    end
    return Augmentation{T}(payload, has_payload)
end