Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch fb #19

Merged
merged 7 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions docs/src/inference.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
# Inference

## Baum-Welch algoritm (forward-backward)

The Baum-Welch algorithm computes the probability to be in a state `i`
at time $n$:
```math
p(z_n = i | x_1, ..., x_N)
```
It is implemented by the [`αβrecursion`](@ref) function.
## Basic algorithms

```@docs
αβrecursion
pdfposteriors
```
4 changes: 2 additions & 2 deletions src/MarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ include("fsm.jl")
export FSM
export addstate!
export determinize
export link!
export addarc!
export minimize
export renormalize!
export setinit!
Expand All @@ -52,7 +52,7 @@ Inference algorithms.

include("algorithms.jl")

export stateposteriors
export pdfposteriors
export bestpath

end
152 changes: 81 additions & 71 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,119 +1,129 @@
# SPDX-License-Identifier: MIT

const Abstract3DTensor{T} = AbstractArray{T,3} where T

#======================================================================
Forward recursion
======================================================================#

function αrecursion!(α::AbstractMatrix{T}, π::AbstractVector{T},
Aᵀ::AbstractMatrix{T}, lhs::AbstractMatrix) where T
N = size(lhs, 2)
function αrecursion(π::AbstractVector{SF},
Tᵀ::AbstractMatrix{SF},
lhs::AbstractMatrix{SF}) where SF <: Semifield
S, N = length(π), size(lhs, 2)
α = similar(lhs, SF, S, N)
buffer = similar(lhs[:, 1])

@views elmul!(α[:,1], π, lhs[:,1])
@views for n in 2:N
matmul!(buffer, Aᵀ, α[:,n-1])
matmul!(buffer, Tᵀ, α[:,n-1])
elmul!(α[:,n], buffer, lhs[:,n])
end
α
end

function αrecursion!(α::Abstract3DTensor{T}, π::AbstractVector{T},
Aᵀ::AbstractMatrix{T}, lhs::Abstract3DTensor{T}) where T
N = size(lhs, 2)
buffer = similar(α[:,1,:])
@views elmul!(α[:,1,:], π, lhs[:,1,:])
@views for n in 2:N
matmul!(buffer, Aᵀ, α[:,n-1,:])
elmul!(α[:,n,:], buffer, lhs[:,n,:])
end
α
end

#======================================================================
Backward recursion
======================================================================#

function βrecursion!(β::AbstractMatrix{T}, ω::AbstractVector{T},
A::AbstractMatrix{T}, lhs::AbstractMatrix{T}) where T
N = size(lhs, 2)
function βrecursion(ω::AbstractVector{SF},
T::AbstractMatrix{SF},
lhs::AbstractMatrix{SF}) where SF <: Semifield
S, N = length(ω), size(lhs, 2)
β = similar(lhs, SF, S, N)
buffer = similar(lhs[:, 1])

β[:, end] = ω
@views for n in N-1:-1:1
elmul!(buffer, β[:,n+1], lhs[:,n+1])
matmul!(β[:,n], A, buffer)
end
β
end

function βrecursion!(β::Abstract3DTensor{T}, ω::AbstractVector{T},
A::AbstractMatrix{T}, lhs::Abstract3DTensor{T}) where T
N = size(lhs, 2)
buffer = fill!(similar(β[:,1,:]), one(T))
@views elmul!(β[:,end,:], ω, buffer)
@views for n in N-1:-1:1
elmul!(buffer, β[:,n+1,:], lhs[:,n+1,:])
matmul!(β[:,n,:], A, buffer)
matmul!(β[:,n], T, buffer)
end
β
end

#======================================================================
Generic forward-backward algorithm
Specialized algorithms
======================================================================#

function αβrecursion(π::AbstractVector{T}, ω::AbstractVector{T},
A::AbstractMatrix{T}, Aᵀ::AbstractMatrix,
lhs::AbstractMatrix{T}) where T
S, N = length(π), size(lhs, 2)
α = similar(lhs, T, S, N)
β = similar(lhs, T, S, N)

αrecursion!(α, π, Aᵀ, lhs)
βrecursion!(β, ω, A, lhs)
"""
pdfposteriors(cfsm, lhs)
pdfposteriors(union(cfsm1, cfsm2, ...), batch_lhs)

Calculate the conditional posterior of "assigning" the \$n\$th frame
to the \$i\$th pdf. The output is a tuple `γ, ttl` where `γ` is a matrix
(# pdf x # frames) and `ttl` is the total probability of the sequence.
This function can also be caused in "batch-mode" by providing a union
of compiled fsm and a 3D tensor containing the per-state, per-frame
and per-batch values.
"""
pdfposteriors

function pdfposteriors(in_cfsm::CompiledFSM,
in_lhs::AbstractMatrix{T}) where T <: Real
# Convert the FSM and the likelihood matrix to the Log-semifield.
SF = LogSemifield{T}
cfsm = convert(CompiledFSM{SF}, in_cfsm)
lhs = copyto!(similar(in_lhs, SF), in_lhs)

γ = similar(lhs, T, S, N)
elmul!(γ, α, β)
sums = sum(γ, dims = 1)
eldiv!(γ, γ, sums)
S = size(cfsm.C, 1)
N = size(lhs, 2)

γ, minimum(sums)
end
# Expand the likelihood matrix to get the per-state likelihoods.
state_lhs = matmul!(similar(lhs, S, N), cfsm.C, lhs)

function αβrecursion(π::AbstractVector{T}, ω::AbstractVector{T},
A::AbstractMatrix{T}, Aᵀ::AbstractMatrix,
lhs::Abstract3DTensor{T}) where T
S, N, B = size(lhs)
α = similar(lhs, T, S, N, B)
β = similar(lhs, T, S, N, B)
α = αrecursion(cfsm.π, cfsm.Tᵀ, state_lhs)
β = βrecursion(cfsm.ω, cfsm.T, state_lhs)
state_γ = elmul!(similar(state_lhs), α, β)

αrecursion!(α, π, Aᵀ, lhs)
βrecursion!(β, ω, A, lhs)
# Transform the per-state γs to per-likelihoods γs.
γ = matmul!(lhs, cfsm.Cᵀ, state_γ) # re-use `lhs` memory.

γ = similar(lhs, T, S, N, B)
elmul!(γ, α, β)
# Re-normalize the γs.
sums = sum(γ, dims = 1)
eldiv!(γ, γ, sums)

γ, dropdims(minimum(sums, dims = (1, 2)), dims = (1, 2))
# Convert the result to the Real-semiring
out = copyto!(similar(in_lhs), γ)

exp.(out), convert(T, minimum(sums))
end

#======================================================================
Specialized algorithms
======================================================================#
function pdfposteriors(in_ucfsm::UnionCompiledFSM,
in_lhs::AbstractArray{T,3}) where T <: Real

_convert(T, ttl::Number) = convert(T, ttl)
_convert(T, ttl::AbstractVector) = copyto!(similar(ttl, T), ttl)
# Reshape the 3D tensor to have a matrix of size BK x N where
# B is the number of elements in the batch.
in_lhs_matrix = vcat(eachslice(in_lhs, dims = 3)...)

function stateposteriors(in_cfsm, in_lhs::AbstractArray{T}) where T <: Real
# Convert the FSM and the likelihood matrix to the Log-semifield.
SF = LogSemifield{T}
cfsm = convert(CompiledFSM{SF}, in_cfsm)
lhs = copyto!(similar(in_lhs, SF), in_lhs)
γ, ttl = αβrecursion(cfsm.π, cfsm.ω, cfsm.A, cfsm.Aᵀ, lhs)
cfsm = convert(CompiledFSM{SF}, in_ucfsm.cfsm)
lhs = copyto!(similar(in_lhs_matrix, SF), in_lhs_matrix)

S = size(cfsm.C, 1)
K, N = size(in_lhs)

# Expand the likelihood matrix to get the per-state likelihoods.
state_lhs = matmul!(similar(lhs, S, N), cfsm.C, lhs)

α = αrecursion(cfsm.π, cfsm.Tᵀ, state_lhs)
β = βrecursion(cfsm.ω, cfsm.T, state_lhs)
state_γ = elmul!(similar(state_lhs), α, β)

# Transform the per-state γs to per-likelihoods γs.
γ = matmul!(lhs, cfsm.Cᵀ, state_γ) # re-use `lhs` memory.
γ = permutedims(reshape(γ, K, :, N), (1, 3, 2))

# Re-normalize each element of the batch.
sums = sum(γ, dims = 1)
eldiv!(γ, γ, sums)

# Convert the result to the Real-semiring
out = copyto!(similar(in_lhs), γ)
exp.(out), _convert(T, ttl)

ttl = dropdims(minimum(sums, dims = (1, 2)), dims = (1, 2))
exp.(out), copyto!(similar(ttl, T), ttl)
end

@deprecate stateposteriors(cfsm, lhs) pdfposteriors(cfsm, lhs)

function bestpath(in_cfsm, in_lhs::AbstractArray{T}) where T <: Real
SF = TropicalSemifield{T}
cfsm = convert(CompiledFSM{SF}, in_cfsm)
Expand Down
Loading