-
Notifications
You must be signed in to change notification settings - Fork 0
/
GroupedTransform.jl
220 lines (192 loc) · 7.88 KB
/
GroupedTransform.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
using SparseArrays
@doc raw"""
    GroupedTransform

A struct to describe a GroupedTransformation.

# Fields
* `system::String` - choice of `"exp"` or `"cos"` or `"chui1"` or `"chui2"` or `"chui3"` or `"chui4"`
* `setting::Vector{NamedTuple{(:u, :mode, :bandwidths),Tuple{Vector{Int},Module,Vector{Int}}}}` - vector of the dimensions, mode, and bandwidths for each term/group, see also [`get_setting(system::String,d::Int,ds::Int,N::Vector{Int})::Vector{NamedTuple{(:u, :mode, :bandwidths),Tuple{Vector{Int},Module,Vector{Int}}}}`](@ref) and [`get_setting(system::String,U::Vector{Vector{Int}},N::Vector{Int})::Vector{NamedTuple{(:u, :mode, :bandwidths),Tuple{Vector{Int},Module,Vector{Int}}}}`](@ref)
* `X::Array{Float64}` - array of nodes; each column is one node and a term `u` uses the rows `X[u, :]`
* `transforms::Vector{Tuple{Int64,Int64}}` - one `(process id, index)` pair per term: the process that holds the low-dimensional sub transformation and its position in that process's `mode.trafos`

# Constructor
    GroupedTransform( system, setting, X )

# Additional Constructor
    GroupedTransform( system, d, ds, N::Vector{Int}, X )
    GroupedTransform( system, U, N, X )

The constructor throws an error if `system` is unknown or if the nodes lie outside
`[-0.5, 0.5)` (`"exp"`, `"chui*"`) resp. `[0, 0.5]` (`"cos"`).
"""
struct GroupedTransform
    system::String
    setting::Vector{
        NamedTuple{(:u, :mode, :bandwidths),Tuple{Vector{Int},Module,Vector{Int}}},
    }
    X::Array{Float64}
    transforms::Vector{Tuple{Int64,Int64}}

    function GroupedTransform(
        system::String,
        setting::Vector{
            NamedTuple{(:u, :mode, :bandwidths),Tuple{Vector{Int},Module,Vector{Int}}},
        },
        X::Array{Float64},
    )
        if !haskey(systems, system)
            error("System not found.")
        end

        # The wavelet systems pass their order as an extra argument to `get_transform`.
        chui_order = Dict("chui1" => 1, "chui2" => 2, "chui3" => 3, "chui4" => 4)

        if system == "exp" || haskey(chui_order, system)
            lo, hi = extrema(X)
            if lo < -0.5 || hi >= 0.5
                error("Nodes must be between -0.5 and 0.5.")
            end
        elseif system == "cos"
            lo, hi = extrema(X)
            if lo < 0 || hi > 0.5
                error("Nodes must be between 0 and 0.5.")
            end
        end

        extra_args = haskey(chui_order, system) ? (chui_order[system],) : ()

        # With a single worker everything stays on the master process (id 1).
        # Otherwise terms are spread round-robin over the actual worker ids; a
        # hand-rolled counter starting at 2 skipped the last worker and breaks when
        # worker ids are not contiguous.
        pool = nworkers() == 1 ? [1] : workers()

        # Launch every sub-transform construction first, then wait for all of them.
        pending = Vector{Tuple{Int64,Future}}(undef, length(setting))
        for (idx, s) in enumerate(setting)
            w = pool[mod1(idx, length(pool))]
            pending[idx] = (
                w,
                remotecall(s[:mode].get_transform, w, s[:bandwidths], X[s[:u], :], extra_args...),
            )
        end
        transforms = Tuple{Int64,Int64}[(w, fetch(fut)) for (w, fut) in pending]

        return new(system, setting, X, transforms)
    end
end
# Convenience constructor: `d`, `ds` and `N` are handed unchanged to
# `get_setting(system, d, ds, N)`, and the resulting term setting is used together
# with the nodes `X` to build the transform.
GroupedTransform(system::String, d::Int, ds::Int, N::Vector{Int}, X::Array{Float64}) =
    GroupedTransform(system, get_setting(system, d, ds, N), X)
# Convenience constructor: the explicit list of terms `U` and the bandwidths `N` are
# handed unchanged to `get_setting(system, U, N)`, and the resulting term setting is
# used together with the nodes `X` to build the transform.
GroupedTransform(system::String, U::Vector{Vector{Int}}, N::Vector{Int}, X::Array{Float64}) =
    GroupedTransform(system, get_setting(system, U, N), X)
@doc raw"""
    *( F::GroupedTransform, fhat::GroupedCoefficients )::Vector{<:Number}

Overloads the `*` notation in order to achieve `f = F*fhat`.

Each term's sub transformation is applied on the process that owns it
(`F.transforms[i][1]`) to that term's coefficient block `fhat[u]`. The partial
results are summed on the caller. Throws an error if `F` and `fhat` do not share
the same setting.
"""
function Base.:*(F::GroupedTransform, fhat::GroupedCoefficients)::Vector{<:Number}
    if F.setting != fhat.setting
        error("The GroupedTransform and the GroupedCoefficients have different settings")
    end
    # Pull out the pieces each term needs before spawning. Otherwise `@spawnat` captures
    # (and serializes to the worker) the whole `F`, including the node array `X`, and
    # all of `fhat` once per term.
    partials = map(eachindex(F.transforms)) do i
        worker, trafo_idx = F.transforms[i]
        mode = F.setting[i][:mode]
        coeffs = fhat[F.setting[i][:u]]
        @spawnat worker (mode.trafos[trafo_idx] * coeffs)
    end
    return sum(fetch, partials)
end
@doc raw"""
    *( F::GroupedTransform, f::Vector{<:Number} )::GroupedCoefficients

Overloads the `*` notation in order to achieve the adjoint transform `fhat = F*f`.

The adjoint of every term's sub transformation is applied to `f` on the process
that owns it (`F.transforms[i][1]`). Each result is stored in the block `fhat[u]`
of a fresh `GroupedCoefficients(F.setting)`.
"""
function Base.:*(F::GroupedTransform, f::Vector{<:Number})::GroupedCoefficients
    # Only the term's module, its index and `f` are sent to the worker, not the whole `F`
    # (with its node array `X`), which `@spawnat` would capture otherwise.
    partials = map(eachindex(F.transforms)) do i
        worker, trafo_idx = F.transforms[i]
        mode = F.setting[i][:mode]
        @spawnat worker (mode.trafos[trafo_idx]' * f)
    end
    fhat = GroupedCoefficients(F.setting)
    for i in eachindex(partials)
        fhat[F.setting[i][:u]] = fetch(partials[i])
    end
    return fhat
end
@doc raw"""
    adjoint( F::GroupedTransform )::GroupedTransform

Overloads the `F'` notation and returns the very same GroupedTransform. Multiplying
a GroupedTransform already selects the normal transform or its adjoint from the type
of the right-hand side (`GroupedCoefficients` vs. `Vector`), so this exists only for
convenience.
"""
Base.adjoint(F::GroupedTransform)::GroupedTransform = F
@doc raw"""
    F::GroupedTransform[u::Vector{Int}]::LinearMap{<:Number} or SparseArray

This function overloads getindex of GroupedTransform such that you can do `F[[1,3]]` to obtain the transform of the corresponding ANOVA term defined by `u`.

* `"cos"` / `"exp"`: returns a `LinearMap{Float64}` / `LinearMap{ComplexF64}` of size
  `M × N` with `M = size(F.X, 2)` and `N = prod(bandwidths .- 1)`. Applying it or its
  adjoint runs the term's sub transformation on the process that owns it.
* `"chui1"` ... `"chui4"`: returns the term's matrix as a `SparseMatrixCSC{Float64,Int}`.

Throws an error if `u` is not one of the terms of `F.setting`.
"""
function Base.:getindex(F::GroupedTransform, u::Vector{Int})
    idx = findfirst(s -> s[:u] == u, F.setting)
    if isnothing(idx)
        error("This term is not contained")
    end

    s = F.setting[idx]
    worker, trafo_idx = F.transforms[idx]
    mode = s[:mode]

    if F.system == "cos"
        return _remote_term_map(
            Float64, worker, mode, trafo_idx, size(F.X, 2), prod(s[:bandwidths] .- 1),
        )
    elseif F.system == "exp"
        return _remote_term_map(
            ComplexF64, worker, mode, trafo_idx, size(F.X, 2), prod(s[:bandwidths] .- 1),
        )
    elseif F.system in ("chui1", "chui2", "chui3", "chui4")
        S = fetch(@spawnat worker mode.trafos[trafo_idx])
        return SparseMatrixCSC{Float64,Int}(S)
    end
end

# Wrap entry `trafo_idx` of `mode.trafos` on process `worker` as a `LinearMap{T}` of
# size `M × N`. The entry is looked up and applied on `worker` itself, because
# that process built it (see the GroupedTransform constructor). Indexing `trafos`
# on the caller would read a different process's table.
function _remote_term_map(::Type{T}, worker, mode::Module, trafo_idx, M, N) where {T}
    forward(fhat::Vector{T})::Vector{T} =
        remotecall_fetch((m, k, x) -> m.trafos[k] * x, worker, mode, trafo_idx, fhat)
    backward(f::Vector{T})::Vector{T} =
        remotecall_fetch((m, k, x) -> m.trafos[k]' * x, worker, mode, trafo_idx, f)
    return LinearMap{T}(forward, backward, M, N)
end
@doc raw"""
    get_matrix( F::GroupedTransform )::Matrix{<:Number}

This function returns the actual matrix of the transformation. It is the horizontal
concatenation, in the order of `F.setting`, of the term matrices
`s[:mode].get_matrix(s[:bandwidths], F.X[s[:u], :])`. `F.setting` must contain at
least one term. This is not available for the wavelet basis.
"""
function get_matrix(F::GroupedTransform)::Matrix{<:Number}
    # Build all term matrices, then concatenate once. Appending with `hcat` per
    # term re-copied the accumulated matrix every iteration (quadratic copying).
    term_matrices = [s[:mode].get_matrix(s[:bandwidths], F.X[s[:u], :]) for s in F.setting]
    return reduce(hcat, term_matrices)
end