This repository has been archived by the owner on Aug 16, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
/
ProbTable.lua
180 lines (168 loc) · 5.6 KB
/
ProbTable.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
local torch = require 'torch'
--- @module ProbTable
-- Implementation of a probability table backed by a Torch tensor.
-- The class is registered through `torch.class` under the name `tl.ProbTable`;
-- the methods defined in this file are attached to the table it returns.
-- NOTE(review): methods in this file call `tl.copy` and `table.combinations`,
-- which are not defined here -- presumably provided by the surrounding `tl`
-- library; confirm they are loaded before this module is used.
local ProbTable = torch.class('tl.ProbTable')
--- Constructor.
-- @arg {torch.tensor} P - probability Tensor, the `i`th dimension corresponds to the `i`th variable.
--   The tensor is stored by reference, not copied.
-- @arg {table[string]=} names - A table of names for the variables (a single string is
--   accepted as shorthand for a one-element table). By default these will be assigned
--   using the dimension indices `1..P:nDimension()`.
--
-- Example:
--
-- @code {lua}
-- local t = ProbTable(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'})
-- t:query{a=1, b=2} -- 0.8
-- t:query{a=2} -- Tensor{0.4, 0.6}
function ProbTable:__init(P, names)
  -- Fall back to numeric names when none are supplied.
  local labels = names or torch.range(1, P:nDimension()):totable()
  -- Treat a bare string as a single-variable name list.
  if type(labels) == 'string' then
    labels = {labels}
  end

  -- Build the ordered name list and its inverse (name -> dimension index).
  local ordered, lookup = {}, {}
  for position, label in ipairs(labels) do
    ordered[position] = label
    lookup[label] = position
  end

  self.names = ordered
  self.name2index = lookup
  self.P = P
end
--- Reports how many variables the table covers, taken as the number of
-- dimensions of the underlying tensor.
-- @returns {int} number of variables in the table
function ProbTable:size()
  local tensor = self.P
  return tensor:nDimension()
end
--- Looks up the probabilities for a (possibly partial) assignment.
-- Every key of `dict` must be a known variable name and every value must be a
-- number in `1..size of that variable's dimension`; otherwise an assertion
-- error is raised. Variables not mentioned in `dict` are left unconstrained.
-- @arg {table[string=int]} dict - an assignment to consider
-- @returns {torch.Tensor} probabilities for the assignments in `dict`
--   (a plain number when every variable is assigned, as in the example below).
--
-- Example:
--
-- @code {lua}
-- local t = ProbTable(torch.Tensor{{0.2, 0.8}, {0.4, 0.6}, {0.1, 0.9}}, {'a', 'b'})
-- t:query{a=1, b=2}
-- t:query{a=2}
--
-- The first query is `0.8`. The second query is `Tensor{0.4, 0.6}`
function ProbTable:query(dict)
  for name, value in pairs(dict) do
    -- tostring() keeps the message buildable for any key/value type, so the
    -- caller sees the intended assertion text rather than a concat error.
    local i = assert(self.name2index[name], tostring(name) .. ' is not a valid name')
    assert(
      type(value) == 'number' and value > 0 and value <= self.P:size(i),
      tostring(value) .. ' is out of range'
    )
  end

  -- Build the index in dimension order; `{}` selects the whole dimension.
  local ind = {}
  for _, name in ipairs(self.names) do
    ind[#ind + 1] = dict[name] or {}
  end
  return self.P[ind]
end
--- Creates an independent table: the tensor is cloned and the name list is
-- copied with `tl.copy`, then both are passed to a fresh `ProbTable`.
-- @returns {ProbTable} a copy
function ProbTable:clone()
  return ProbTable.new(self.P:clone(), tl.copy(self.names))
end
--- Renders the table as tab-separated text: a header row with the variable
-- names followed by `P`, a divider row of dashes, then one row per index
-- combination (as produced by `table.combinations`) with its probability.
-- @returns {string} string representation
function ProbTable:__tostring__()
  local pieces = {}

  -- Header row and the dashed divider underneath it.
  local rule = {}
  for _, name in ipairs(self.names) do
    pieces[#pieces + 1] = name .. '\t'
    rule[#rule + 1] = '-' .. '\t'
  end
  pieces[#pieces + 1] = 'P\n' .. table.concat(rule) .. '-\n'

  -- For each dimension, the list of valid indices 1..size.
  local axes = {}
  for axis, extent in ipairs(self.P:size():totable()) do
    axes[axis] = torch.range(1, extent):totable()
  end

  -- One output row per full assignment.
  for _, ind in ipairs(table.combinations(axes)) do
    for _, component in ipairs(ind) do
      pieces[#pieces + 1] = component .. '\t'
    end
    pieces[#pieces + 1] = self.P[ind] .. '\n'
  end

  return table.concat(pieces)
end
--- Returns a new table that is the product of two tables.
-- Neither operand is modified: the result is built from a clone of this
-- table's tensor and a copy of its names.
-- @arg {ProbTable} B - another table
-- @returns {ProbTable} product of this and another table
function ProbTable:mul(B)
  -- `table.unpack` may be absent on Lua 5.1-style runtimes; fall back to the
  -- global `unpack` there.
  local unpack = table.unpack or unpack

  -- allocate new P and names for the product ProbTable
  local P = self.P:clone()
  local names = tl.copy(self.names)
  local name2index = tl.copy(self.name2index)

  -- The new name order is arranged so that its leading names are exactly
  -- B.names, in B's order. That way B.P[ind] can be multiplied with P[ind]
  -- directly for every full index `ind` of B.
  local write = 1 -- index of the first slot not yet holding a B name
  for i, name in ipairs(B.names) do
    local old_i = name2index[name]
    if old_i then
      -- Shared variable: swap it into slot `write`, keeping the tensor axes,
      -- the name list and the reverse lookup in step.
      P = P:transpose(old_i, write)
      local displaced = names[write]
      names[write] = name
      names[old_i] = displaced
      name2index[displaced] = old_i
      name2index[name] = write
    else
      -- B-only variable: insert it at slot `write`, shifting later names up.
      table.insert(names, write, name)
      for j, n in ipairs(names) do name2index[n] = j end

      -- repeatTensor adds the new axis at the FRONT of the tensor.
      local sizes = torch.ones(P:nDimension() + 1)
      sizes[1] = B.P:size(i)
      P = P:repeatTensor(unpack(sizes:totable()))

      -- Move that leading axis to position `write` by bubbling it through
      -- adjacent transposes. A single transpose(1, write) would swap rather
      -- than insert, scrambling the already-aligned axes 1..write-1 whenever
      -- write >= 3.
      for k = 1, write - 1 do
        P = P:transpose(k, k + 1)
      end
    end
    write = write + 1
  end

  -- Enumerate every full assignment of B's variables and scale the matching
  -- slice (or scalar entry) of P by B's probability.
  local dims = B.P:size():totable()
  for axis, d in ipairs(dims) do dims[axis] = torch.range(1, d):totable() end
  for _, ind in ipairs(table.combinations(dims)) do
    local cell = P[ind]
    if type(cell) == 'number' then
      P[ind] = cell * B.P[ind]
    else
      -- `cell` is a sub-tensor; scale it in place.
      cell:mul(B.P[ind])
    end
  end
  return ProbTable.new(P, names)
end
--- Sums out one variable, modifying this probability table in place.
-- Raises an assertion error if `name` is not a variable of this table.
-- @arg {string} name - the variable to marginalize
-- @returns {ProbTable} this probability table with the variable `name` marginalized out
function ProbTable:marginalize(name)
  local dim = assert(self.name2index[name], tostring(name) .. ' is not a valid name')

  -- Sum over the variable's dimension and drop the resulting singleton axis.
  local reduced = self.P:sum(dim):squeeze(dim)
  -- If the reduction came back as a plain number, wrap it in a tensor again.
  if type(reduced) == 'number' then
    reduced = torch.Tensor{reduced}
  end
  self.P = reduced

  -- Forget the variable and shift the remaining names down by one.
  self.name2index[name] = nil
  table.remove(self.names, dim)
  for position = dim, #self.names do
    self.name2index[self.names[position]] = position
  end
  return self
end
--- Marginalizes this probability table in place to calculate a marginal.
-- Repeatedly sums out the first variable that is not `name` until the
-- underlying tensor has a single dimension left.
-- Raises an assertion error if `name` is not a variable of this table.
-- @arg {string} name - the variable to calculate
-- @returns {ProbTable} this probability table marginalizing all variables except `name`
function ProbTable:marginal(name)
  -- tostring() keeps the message buildable even for nil / non-string names,
  -- so the caller gets this assertion text rather than a concat error.
  assert(
    self.name2index[name],
    'Table does not contain variable with name ' .. tostring(name)
  )
  while self:size() > 1 do
    for i = 1, self:size() do
      local candidate = self.names[i]
      if candidate ~= name then
        -- marginalize() shifts the name list, so restart the scan afterwards.
        self:marginalize(candidate)
        break
      end
    end
  end
  return self
end
--- Rescales this table in place, dividing every entry by the sum of all
-- entries.
-- @returns {ProbTable} normalized table (this same object)
function ProbTable:normalize()
  local total = self.P:sum()
  self.P:div(total)
  return self
end
-- Hand the class table back to whoever `require`s this file.
return ProbTable