/
parameters_matrix.jl
91 lines (70 loc) · 3.06 KB
/
parameters_matrix.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
nearest_orthogonal_matrix(X::AbstractMatrix{<:Union{Real,Complex}})
Project `X` onto the closest orthogonal matrix in Frobenius norm.
Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446
"""
@inline function nearest_orthogonal_matrix(X::AbstractMatrix{<:Union{Real,Complex}})
    # `@inline` is kept on purpose: it was found necessary for type inference.
    # Multiply the left singular vectors by the adjoint of the right singular vectors,
    # i.e. drop the singular values from the decomposition of `X`.
    F = svd(X)
    return F.U * F.V'
end
"""
orthogonal(X::AbstractMatrix{<:Real})
Produce a parameter whose `value` is constrained to be an orthogonal matrix. The argument `X`
need not be orthogonal.
The `value` is obtained by projecting `X` onto the nearest orthogonal matrix (in Frobenius
norm); the parametrisation is overparametrised as a consequence.
Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446
"""
function orthogonal(X::AbstractMatrix{<:Real})
    # `X` is stored as-is; no projection happens at construction time.
    return Orthogonal(X)
end
"""
    Orthogonal{TX<:AbstractMatrix{<:Real}} <: AbstractParameter

Parameter holding an unconstrained real matrix `X`. Its `value` is the orthogonal matrix
nearest to `X`; `X` itself need not be orthogonal.
"""
struct Orthogonal{TX<:AbstractMatrix{<:Real}} <: AbstractParameter
    X::TX
end

# Two `Orthogonal` parameters are equal when their unconstrained matrices are equal.
Base.:(==)(X::Orthogonal, Y::Orthogonal) = X.X == Y.X

# A custom `==` needs a matching `hash`, otherwise parameters that compare equal but wrap
# distinct array objects land in different `Dict`/`Set` buckets. The symbol seed keeps the
# hash distinct from that of the bare matrix.
Base.hash(X::Orthogonal, h::UInt) = hash(X.X, hash(:Orthogonal, h))
# Constrained value of the parameter: the stored matrix projected onto the nearest
# orthogonal matrix.
function value(X::Orthogonal)
    unconstrained = X.X
    return nearest_orthogonal_matrix(unconstrained)
end
# Flatten the wrapped matrix and return, alongside the flat vector, a closure that rebuilds
# an `Orthogonal` from a vector of the same element type `T`.
function flatten(::Type{T}, X::Orthogonal) where {T<:Real}
    flat, rebuild_matrix = flatten(T, X.X)
    function unflatten_Orthogonal(flat_new::Vector{T})
        return Orthogonal(rebuild_matrix(flat_new))
    end
    return flat, unflatten_Orthogonal
end
"""
positive_definite(X::AbstractMatrix{<:Real})
Produce a parameter whose `value` is constrained to be a positive-definite matrix. The argument `X` needs to
be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix); an `ArgumentError` is thrown otherwise.
The unconstrained parameter is the lower-triangular Cholesky factor of `X`, stored as a vector.
"""
function positive_definite(X::AbstractMatrix{<:Real})
    # Reject invalid input up front so the caller gets a descriptive error.
    if !isposdef(X)
        throw(ArgumentError("X is not positive-definite"))
    end
    # Only the lower triangle of the Cholesky factor is kept, packed into a vector.
    chol_lower = cholesky(X).L
    return PositiveDefinite(tril_to_vec(chol_lower))
end
"""
    PositiveDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter

Parameter holding the vector `L` of lower-triangular entries from which its `value` is
rebuilt as `A * A'`, where `A` is the lower-triangular matrix filled from `L`.
"""
struct PositiveDefinite{TL<:AbstractVector{<:Real}} <: AbstractParameter
    L::TL
end

# Two `PositiveDefinite` parameters are equal when their stored vectors are equal.
Base.:(==)(X::PositiveDefinite, Y::PositiveDefinite) = X.L == Y.L

# A custom `==` needs a matching `hash`, otherwise parameters that compare equal but wrap
# distinct array objects land in different `Dict`/`Set` buckets. The symbol seed keeps the
# hash distinct from that of the bare vector.
Base.hash(X::PositiveDefinite, h::UInt) = hash(X.L, hash(:PositiveDefinite, h))
# Product of `X` with its own adjoint.
function A_At(X)
    return X * adjoint(X)
end
# Constrained value of the parameter: rebuild the lower-triangular matrix from the stored
# vector and multiply it by its adjoint.
function value(X::PositiveDefinite)
    lower = vec_to_tril(X.L)
    return A_At(lower)
end
# Flatten the stored vector and return, alongside the flat vector, a closure that rebuilds
# a `PositiveDefinite` from a vector of the same element type `T`.
function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real}
    flat, rebuild_vector = flatten(T, X.L)
    function unflatten_PositiveDefinite(flat_new::Vector{T})
        return PositiveDefinite(rebuild_vector(flat_new))
    end
    return flat, unflatten_PositiveDefinite
end
# Convert a vector to a lower-triangular matrix.
# The entries of `v` fill the lower triangle (diagonal included) in column-major order;
# everything above the diagonal is zero. Throws an `ArgumentError` when `length(v)` is not
# of the form `n * (n + 1) / 2`, since no square lower triangle has that many entries.
function vec_to_tril(v::AbstractVector{T}) where {T}
    n_vec = length(v)
    # Infer the size of the matrix from the vector by solving n * (n + 1) / 2 == n_vec.
    # Integer arithmetic avoids floating-point rounding in the square root.
    n_tril = (isqrt(1 + 8 * n_vec) - 1) ÷ 2
    if n_tril * (n_tril + 1) ÷ 2 != n_vec
        throw(
            ArgumentError(
                "length of v ($n_vec) is not a triangular number, " *
                "so it cannot fill a lower-triangular matrix",
            ),
        )
    end
    L = zeros(T, n_tril, n_tril)
    L[tril!(trues(size(L)))] = v
    return L
end
# Reverse-mode rule for `vec_to_tril`: the cotangent of the input vector is the lower
# triangle of the output cotangent, packed back into a vector.
function ChainRulesCore.rrule(::typeof(vec_to_tril), v::AbstractVector{T}) where {T}
    function pullback_vec_to_tril(Δ)
        # `unthunk` materialises a possibly lazy cotangent before indexing into it.
        return NoTangent(), tril_to_vec(unthunk(Δ))
    end
    return vec_to_tril(v), pullback_vec_to_tril
end
# Pack the lower triangle (diagonal included) of a square matrix into a vector, in
# column-major order, leaving out the entries above the diagonal.
# Adapted from https://stackoverflow.com/questions/50651781/extract-lower-triangle-portion-of-a-matrix
function tril_to_vec(X::AbstractMatrix{T}) where {T}
    if size(X, 1) != size(X, 2)
        error("Matrix needs to be square")
    end
    # Boolean mask that is `true` on and below the diagonal.
    lower_mask = tril!(trues(size(X)))
    return X[lower_mask]
end