-
Notifications
You must be signed in to change notification settings - Fork 25
/
varname.jl
136 lines (122 loc) · 4.29 KB
/
varname.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
```
struct VarName{sym}
    indexing :: String
end
```

A variable identifier. Every variable has a symbol `sym`, stored as the type parameter,
and indices `indexing`, stored as a string.

The Julia variable in the model corresponding to `sym` can refer to a single value or to a
hierarchical array structure of univariate, multivariate or matrix variables. `indexing`
stores the indices that access the random variable from that Julia variable.

Examples:
- `x[1] ~ Normal()` will generate a `VarName` with `sym == :x` and `indexing == "[1]"`.
- `x[:,1] ~ MvNormal(zeros(2))` will generate a `VarName` with `sym == :x` and
  `indexing == "[Colon(),1]"`.
- `x[:,1][2] ~ Normal()` will generate a `VarName` with `sym == :x` and
  `indexing == "[Colon(),1][2]"`.
"""
struct VarName{sym}
    indexing::String
end
"""
    @varname(var)

Build the `VarName` of the Julia variable (a symbol or an indexing expression) `var`.

For example, `@varname x[1,2][1+5][45][3]` returns `VarName{:x}("[1,2][6][45][3]")`.
The index expressions are evaluated in the caller's scope.
"""
macro varname(expr::Union{Expr, Symbol})
    generated = varname(expr)
    return esc(generated)
end
"""
    varname(expr)

Return the (unescaped) expression that `@varname` expands to for the variable
expression `expr`.

- For a `Symbol` such as `:x`, the result constructs `DynamicPPL.VarName{:x}("")`.
- For an indexing expression such as `:(x[a, b][c])`, the result constructs
  `DynamicPPL.VarName{:x}(indexing)`. `indexing` has the shape `"[a,b][c]"`, with each
  index replaced by `string` of its run-time value. An index whose value is `:`
  (i.e. `Colon()`) is rendered as `"Colon()"`.

Each index expression is evaluated exactly once by the generated code.

Throws an `ArgumentError` if `expr` is neither a `Symbol` nor a (possibly nested)
indexing expression whose innermost object is a `Symbol`.
"""
function varname(expr)
    expr isa Symbol && return quote
        DynamicPPL.VarName{$(QuoteNode(expr))}("")
    end
    # Work on a copy so the generated code never shares sub-expressions with `expr`.
    ex = deepcopy(expr)
    # One `"[i,j,...]"` string-building expression per bracket group. Groups are peeled
    # from the right, so `pushfirst!` keeps them in left-to-right source order.
    inds = Expr(:tuple)
    while Meta.isexpr(ex, :ref)
        if length(ex.args) >= 2
            strs = map(ex.args[2:end]) do idx
                # Bind the index to a fresh name: it is both compared against `:` and
                # stringified, and must not be evaluated twice (it may have side effects).
                tmp = gensym(:idx)
                :(let $tmp = $idx
                    $tmp === (:) ? "Colon()" : string($tmp)
                end)
            end
            pushfirst!(inds.args, :("[" * join($(Expr(:vect, strs...)), ",") * "]"))
        end
        ex = ex.args[1]
    end
    # Anything other than a bare symbol at the root (e.g. `x.a[1]`, `1[2]`) is rejected.
    ex isa Symbol || throw(ArgumentError("VarName: Mis-formed variable name $(expr)!"))
    return quote
        DynamicPPL.VarName{$(QuoteNode(ex))}(foldl(*, $inds, init = ""))
    end
end
"""
    @vsym(expr)

Expand to the variable symbol of `expr` as a literal; for example `@vsym x[1][2]`
evaluates to `:x`.
"""
macro vsym(expr::Union{Expr, Symbol})
    return vsym(expr)
end
"""
    vsym(expr::Union{Expr, Symbol})

Return the variable symbol of the variable expression `expr`, wrapped in a `QuoteNode` so
it can be spliced into generated code as a literal. For example, `vsym(:(x[1]))` returns
`QuoteNode(:x)`.

Throws an `ArgumentError` if `expr` is neither a `Symbol` nor a (possibly nested) indexing
expression whose innermost object is a `Symbol`.
"""
function vsym(expr::Union{Expr, Symbol})
    ex = expr
    # Peel off indexing layers (`x[...][...]`) until the root object is reached.
    # `Meta.isexpr` is false for non-`Expr` roots (e.g. the literal in `1[2]`).
    while Meta.isexpr(ex, :ref)
        ex = ex.args[1]
    end
    ex isa Symbol && return QuoteNode(ex)
    throw(ArgumentError("VarName: Mis-formed variable name $(expr)!"))
end
"""
    @vinds(expr)

Return a tuple of tuples of the indices in `expr`, one inner tuple per bracket group.
For example, `@vinds x[1,:][2]` returns `((1, Colon()), (2,))`. The index expressions
are evaluated in the caller's scope.
"""
macro vinds(expr::Union{Expr, Symbol})
    return esc(vinds(expr))
end
"""
    vinds(expr::Union{Expr, Symbol})

Return an `Expr(:tuple)` holding one tuple expression per bracket group of `expr`, in
source order. For example, `vinds(:(x[1, :][2]))` returns `:((1, :), (2,))`, which
evaluates to `((1, Colon()), (2,))`. A bare `Symbol` yields the empty tuple expression.

Throws an `ArgumentError` if `expr` is neither a `Symbol` nor a (possibly nested) indexing
expression whose innermost object is a `Symbol`.
"""
function vinds(expr::Union{Expr, Symbol})
    # Copy so the returned index expressions do not alias sub-expressions of `expr`.
    ex = deepcopy(expr)
    inds = Expr(:tuple)
    while Meta.isexpr(ex, :ref)
        # Groups are peeled from the right, so prepend to keep source order.
        pushfirst!(inds.args, Expr(:tuple, ex.args[2:end]...))
        ex = ex.args[1]
    end
    ex isa Symbol || throw(ArgumentError("VarName: Mis-formed variable name $(expr)!"))
    return inds
end
"""
    split_var_str(var_str, inds_as = Vector)

Split a variable string such as `"x[1:3,1:2][3,2]"` into the variable's symbol `"x"` and
its indexing.

- If `inds_as = String`, the indexing is returned as one string, e.g.
  `("x", "[1:3,1:2][3,2]")`.
- If `inds_as = Vector` (the default), the indexing is returned as a vector of vectors of
  stripped index strings, one inner vector per bracket group, e.g.
  `("x", [["1:3", "1:2"], ["3", "2"]])`. An index that is itself bracketed (one level),
  as in `"x[[1,2],3]"`, is kept whole: `("x", [["[1,2]", "3"]])`.

Without any `[`, the indexing is `""` (for `String`) or an empty vector (for `Vector`).
Positions are advanced with `nextind`/`prevind`, so non-ASCII names and indices such as
`"σ[α]"` are handled.

Throws an `ArgumentError` if `inds_as` is neither `String` nor `Vector`, or if `var_str`
is not a sequence of complete bracket groups after the symbol.
"""
function split_var_str(var_str, inds_as = Vector)
    ind = findfirst(isequal('['), var_str)
    if inds_as === String
        ind === nothing && return var_str, ""
        # `prevind` (not `ind - 1`) so a multi-byte final character of the name is kept whole.
        return var_str[1:prevind(var_str, ind)], var_str[ind:end]
    end
    inds_as === Vector ||
        throw(ArgumentError("split_var_str: inds_as must be String or Vector, got $(inds_as)"))

    inds = Vector{String}[]
    ind === nothing && return var_str, inds
    malformed() = throw(ArgumentError("split_var_str: malformed variable string $(repr(var_str))"))

    sym = var_str[1:prevind(var_str, ind)]
    stop = lastindex(var_str)
    i = ind  # byte index of the current bracket group's '['
    while i <= stop
        var_str[i] == '[' || malformed()
        group = String[]
        push!(inds, group)
        while true
            # Step past the '[' that opened the group or the ',' that ended the last index.
            i = nextind(var_str, i)
            i <= stop || malformed()
            if var_str[i] == '['
                # Nested bracketed index: keep everything up to the first ']' verbatim.
                j = findnext(isequal(']'), var_str, i)
                j === nothing && malformed()
                push!(group, strip(var_str[i:j]))
                i = nextind(var_str, j)
                i <= stop || malformed()
            else
                j = findnext(c -> c == ',' || c == ']', var_str, i)
                j === nothing && malformed()
                push!(group, strip(var_str[i:prevind(var_str, j)]))
                i = j
            end
            var_str[i] == ']' && break
        end
        i = nextind(var_str, i)  # move past the group's closing ']'
    end
    return sym, inds
end