-
Notifications
You must be signed in to change notification settings - Fork 10
/
regularizer.jl
52 lines (40 loc) · 1.51 KB
/
regularizer.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# ------------------------------------------
# regularizers
using ACE1
#export regularizer_params
"""
`regularizer_params(; type = "laplacian", kwargs...)` : returns a dictionary containing the
complete set of parameters required to construct one of the regularizers.
All parameters are passed as keyword arguments and the kind of
parameters required depends on `type`.
## Laplacian Parameters (default)
* `type = "laplacian"`
* `rlap_scal = 3.0`
"""
function regularizer_params(; type = "laplacian", kwargs...)
    # `type` is caller-supplied input, so validate it with an explicit exception
    # rather than `@assert` (assertions are meant for internal invariants and may
    # be disabled). Throws `ArgumentError` when `type` is not a registered key.
    if !haskey(regularizers, type)
        known = join(sort!(collect(keys(regularizers))), ", ")
        throw(ArgumentError(
            "unknown regularizer type $(repr(type)); available types: $(known)"))
    end
    # Each registry entry is a tuple; its second element is the function that
    # builds the parameter dictionary for that regularizer type.
    params_builder = regularizers[type][2]
    return params_builder(; kwargs...)
end
"""
`laplacian_regularizer_params(; kwargs...)` : returns a dictionary containing the
complete set of parameters required to construct a laplacian regularizer.
All parameters are passed as keyword arguments.
### Parameters
* `rlap_scal = 3.0`
"""
function laplacian_regularizer_params(; rlap_scal = 3.0)
    # TODO: check that value is reasonable
    # The "type" entry records which regularizer these parameters belong to;
    # "rlap_scal" is stored exactly as passed in.
    params = Dict("type" => "laplacian", "rlap_scal" => rlap_scal)
    return params
end
"""
    laplacian_regularizer(basis::JuLIP.MLIPs.IPSuperBasis; rlap_scal)

Return a `Diagonal` matrix whose diagonal is the concatenation of
`ACE1.scaling(B, rlap_scal)` over the entries `B` of `basis.BB`.

# Keywords
* `rlap_scal` : required; passed as the second argument to `ACE1.scaling`.
"""
function laplacian_regularizer(basis::JuLIP.MLIPs.IPSuperBasis; rlap_scal)
    # One scaling result per entry of `basis.BB` (broadcast call).
    scalings = ACE1.scaling.(basis.BB, rlap_scal)
    # Concatenate the per-entry results into a single diagonal.
    diagonal_entries = vcat(scalings...)
    return Diagonal(diagonal_entries)
end
"""
    generate_regularizer(basis::JuLIP.MLIPs.IPSuperBasis, params::Dict)

Construct the regularizer described by `params` for `basis`.

`params["type"]` selects an entry of the `regularizers` registry; the first
element of that entry is called as `constructor(basis; kwargs...)`, where the
keyword arguments are all remaining entries of `params` with their keys
converted to `Symbol`s. `params` itself is not modified.
"""
function generate_regularizer(basis::JuLIP.MLIPs.IPSuperBasis, params::Dict)
    constructor = regularizers[params["type"]][1]

    # Copy the parameters into a Symbol-keyed dictionary so they can be splatted
    # as keyword arguments.
    options = Dict{Symbol,Any}()
    for (name, value) in params
        options[Symbol(name)] = value
    end
    # "type" only selects the constructor; it is not one of its keywords.
    delete!(options, :type)

    return constructor(basis; options...)
end
# Registry of available regularizers, keyed by type name. Each value is a tuple
# `(constructor, params_function)`: element 1 builds the regularizer from a basis,
# element 2 builds the corresponding parameter dictionary.
# Declared `const` so that functions reading this global see a concretely typed
# binding instead of an untyped (`Any`) global.
const regularizers = Dict(
    "laplacian" => (laplacian_regularizer, laplacian_regularizer_params),
)