-
Notifications
You must be signed in to change notification settings - Fork 17
/
DataContainers.jl
122 lines (96 loc) · 3.65 KB
/
DataContainers.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
module DataContainers
## Imports
import Base: size #must import to add a definition to size
##Exports
export DataContainer, PairedDataContainer
export size
export get_data, get_inputs, get_outputs
## Objects
"""
    DataContainer{FT <: Real}

Container that stores data samples as the columns of a matrix.

# Fields
- `stored_data::AbstractMatrix{FT}`: the samples, laid out as
  [data dimension × number of samples].

# Constructor
    DataContainer(stored_data::AbstractMatrix{FT}; data_are_columns = true)

With `data_are_columns = true` (the default) each column of `stored_data` is one sample and
a deep copy of the matrix is stored. With `data_are_columns = false` each row is one sample
and the matrix is stored with its two dimensions swapped.
"""
struct DataContainer{FT <: Real}
    # samples held column-wise: [data dimension × number of samples]
    stored_data::AbstractMatrix{FT}

    # build from a 2D array whose samples are either its columns or its rows
    function DataContainer(stored_data::AbstractMatrix{FT}; data_are_columns = true) where {FT <: Real}
        # Row-wise samples: `permutedims` returns a newly allocated array, so no
        # additional copy is taken on this path.
        data_are_columns || return new{FT}(permutedims(stored_data, (2, 1)))
        # Column-wise samples: copy so the container does not alias the caller's array.
        return new{FT}(deepcopy(stored_data))
    end
end
"""
    PairedDataContainer{FT <: Real}

Stores input-output pairs as two `DataContainer`s. There must be an equal number of input
and output samples.

# Fields
- `inputs::DataContainer{FT}`: input samples, [input dimension × number of samples].
- `outputs::DataContainer{FT}`: output samples, [output dimension × number of samples].

# Constructors
    PairedDataContainer(inputs::AbstractMatrix{FT}, outputs::AbstractMatrix{FT};
                        data_are_columns = true)
    PairedDataContainer(inputs::DataContainer, outputs::DataContainer)

The matrix form wraps each matrix in a `DataContainer`, forwarding `data_are_columns`
(samples are columns when `true`, rows when `false`). The container form stores the two
given containers as they are.

# Throws
- `DimensionMismatch` if the number of input samples differs from the number of output
  samples.
"""
struct PairedDataContainer{FT <: Real}
    # containers for inputs and outputs, each holding an array of
    # size [data/parameter dimension × number samples]
    inputs::DataContainer{FT}
    outputs::DataContainer{FT}

    # constructor with 2D arrays
    function PairedDataContainer(
        inputs::AbstractMatrix{FT},
        outputs::AbstractMatrix{FT};
        data_are_columns = true,
    ) where {FT <: Real}
        # samples run along dimension 2 when stored as columns, otherwise along dimension 1
        sample_dim = data_are_columns ? 2 : 1
        n_in = size(inputs, sample_dim)
        n_out = size(outputs, sample_dim)
        if n_in != n_out
            throw(
                DimensionMismatch(
                    "There must be the same number of samples of both inputs and outputs. Got $(n_in) input samples and $(n_out) output samples.",
                ),
            )
        end

        stored_inputs = DataContainer(inputs; data_are_columns = data_are_columns)
        stored_outputs = DataContainer(outputs; data_are_columns = data_are_columns)
        return new{FT}(stored_inputs, stored_outputs)
    end

    # constructor with DataContainers
    # `FT` is taken from the type parameter of `inputs`, so the stored matrix does not
    # have to be copied just to look up its element type.
    function PairedDataContainer(inputs::DataContainer{FT}, outputs::DataContainer) where {FT <: Real}
        # a DataContainer always keeps its samples as columns
        n_in = size(inputs, 2)
        n_out = size(outputs, 2)
        if n_in != n_out
            throw(
                DimensionMismatch(
                    "There must be the same number of samples of both inputs and outputs. Got $(n_in) input samples and $(n_out) output samples.",
                ),
            )
        end
        return new{FT}(inputs, outputs)
    end
end
## Functions
"""
    size(dc::DataContainer)
    size(dc::DataContainer, idx::Integer)

Return the size of the data stored in `dc`. When `idx` is given, return only the size along
dimension `idx`.
"""
function size(dc::DataContainer)
    return size(dc.stored_data)
end

function size(dc::DataContainer, idx::IT) where {IT <: Integer}
    return size(dc.stored_data, idx)
end
# Full sizes of both stored containers, returned as the tuple (input size, output size).
function size(pdc::PairedDataContainer)
    return (size(pdc.inputs), size(pdc.outputs))
end

"""
    size(pdc::PairedDataContainer, idx::Integer)

Return a tuple holding the sizes of the inputs and of the outputs of `pdc` along
dimension `idx`.
"""
function size(pdc::PairedDataContainer, idx::IT) where {IT <: Integer}
    return (size(pdc.inputs, idx), size(pdc.outputs, idx))
end
"""
    get_data(dc::DataContainer)
    get_data(pdc::PairedDataContainer)

Return a deep copy of the data stored in `dc`. For a paired container, return the tuple
`(inputs, outputs)` as given by `get_inputs(pdc)` and `get_outputs(pdc)`.
"""
function get_data(dc::DataContainer)
    # hand out a copy so callers cannot modify the stored array
    return deepcopy(dc.stored_data)
end

function get_data(pdc::PairedDataContainer)
    return (get_inputs(pdc), get_outputs(pdc))
end
"""
    get_inputs(pdc::PairedDataContainer)

Return the input data held by `pdc`, obtained by calling `get_data` on its `inputs`
container.
"""
function get_inputs(pdc::PairedDataContainer)
    input_container = pdc.inputs
    return get_data(input_container)
end
"""
    get_outputs(pdc::PairedDataContainer)

Return the output data held by `pdc`, obtained by calling `get_data` on its `outputs`
container.
"""
function get_outputs(pdc::PairedDataContainer)
    output_container = pdc.outputs
    return get_data(output_container)
end
end # module