-
-
Notifications
You must be signed in to change notification settings - Fork 617
Description
Here is a file that reproduces the problem. The code is copied from the TextAnalysis.jl package and slightly altered for Flux 0.10. The version on GitHub works with Flux 0.9.
`#=
This code is copied from the CRF module of TextAnalysis.jl.
The current version on GitHub works with Flux 0.9:
https://github.com/JuliaText/TextAnalysis.jl/tree/master/src/CRF
=#
using Flux
"""
    log_sum_exp(z)
    log_sum_exp(z, m)

Compute `log(sum(exp(z)))` along the first dimension of `z`.

The shift `m` is subtracted before exponentiating and added back after the
logarithm. The one-argument method uses the maximum of `z` along the first
dimension as the shift.

The reduced dimension is kept, so a matrix input yields a `1 × size(z, 2)`
result and a vector input yields a one-element vector.
"""
function log_sum_exp(z, m)
    shifted = exp.(z .- m)
    return log.(sum(shifted; dims = 1)) .+ m
end

function log_sum_exp(z)
    shift = maximum(z; dims = 1)
    return log_sum_exp(z, shift)
end
"""
    CRF{S}

Holds the transition scores used by the CRF functions in this file.

# Fields
- `W::S`: transition score matrix. The other functions index it as
  `W[previous, current]`, with row `n + 1` used for the first step and
  column `n + 2` used for the last step.
- `n::Int`: number of labels, not counting the two extra indices above.
"""
mutable struct CRF{S}
    W::S    # Transition Scores
    n::Int  # Num Labels
end

"""
    CRF(n::Integer)

Build a `CRF` for `n` labels with a random `Float32` transition matrix of
size `(n + 2, n + 2)`.

Column `n + 1` and row `n + 2` are filled with `-10000`, which makes moving
into index `n + 1` or out of index `n + 2` effectively impossible.

# Throws
- `ArgumentError` if `n` is negative.
"""
function CRF(n::Integer)
    n >= 0 || throw(ArgumentError("number of labels must be non-negative, got $n"))
    # The default two-argument constructor only matches `n::Int`, so convert
    # here; otherwise any other `Integer` subtype (e.g. `Int32`) would hit a
    # `MethodError`.
    k = Int(n)
    W = rand(Float32, k + 2, k + 2)
    W[:, k + 1] .= -10000
    W[k + 2, :] .= -10000
    return CRF(W, k)
end
# Register `CRF` with Flux's functor machinery. Only the field `W` is listed,
# so `n` is left out of the fields Flux traverses.
# NOTE(review): the field-tuple form of `@functor` should be confirmed against
# the Flux/Functors version in use.
Flux.@functor CRF (W,)
"""
    preds_first(c::CRF, y)

Return the transition score from row `c.n + 1` of `c.W` to the label encoded
by `y`. The label index is recovered with `Flux.onecold`.
"""
function preds_first(c::CRF, y)
    label = Flux.onecold(y, 1:length(y))
    return c.W[c.n + 1, label]
end
"""
    preds_last(c::CRF, y)

Return the transition score from the label encoded by `y` to column
`c.n + 2` of `c.W`. The label index is recovered with `Flux.onecold`.
"""
function preds_last(c::CRF, y)
    label = Flux.onecold(y, 1:length(y))
    return c.W[label, c.n + 2]
end
"""
    preds_single(c::CRF, y, y_prev)

Return the transition score `c.W[previous, current]`, where `previous` and
`current` are the label indices encoded by `y_prev` and `y`.
"""
function preds_single(c::CRF, y, y_prev)
    previous = Flux.onecold(y_prev, 1:length(y_prev))
    current = Flux.onecold(y, 1:length(y))
    return c.W[previous, current]
end
"""
    forward_score(c::CRF, x, init_α)

Run the forward recursion in log space over the score vectors in `x` and
return a scalar.

The first step combines `c.W`, `x[1]` and `init_α`. Each later step combines
`c.W`, `x[t]` and the previous forward variable. The final step adds column
`c.n + 2` of `c.W` and reduces once more with `log_sum_exp`.

NOTE(review): presumably this is the log partition function of the CRF, and
`x[t]` and `init_α` are vectors of length `c.n + 2` — confirm against callers.
"""
function forward_score(c::CRF, x, init_α)
    # `log_sum_exp` reduces along dim 1, so `α` is a row that must be
    # transposed back before it is combined with `c.W` again.
    α = log_sum_exp((c.W .+ transpose(x[1])) .+ init_α)
    for t in 2:length(x)
        step_scores = transpose(x[t])
        α = log_sum_exp((c.W .+ step_scores) .+ transpose(α))
    end
    last_column = c.W[:, c.n + 2]
    total = log_sum_exp(last_column + transpose(α))
    return total[1]
end
"""
    score_sequence(c::CRF, x, label_seq)

Return the score of the label sequence `label_seq` given the per-step score
vectors `x`.

The total is the first-step transition (`preds_first`), then for every step
the transition from the previous label (`preds_single`) plus the entry of
`x[t]` selected by the step's label, and finally the last-step transition
(`preds_last`).
"""
function score_sequence(c::CRF, x, label_seq)
    # `Flux.onecold(label, values)` picks the entry of `values` at the
    # position marked by `label`.
    total = preds_first(c, label_seq[1]) + Flux.onecold(label_seq[1], x[1])
    for t in 2:length(label_seq)
        current = label_seq[t]
        previous = label_seq[t - 1]
        total += preds_single(c, current, previous) + Flux.onecold(current, x[t])
    end
    closing = preds_last(c, label_seq[end])
    return total + closing
end
"""
    crf_loss(c::CRF, x, label_seq, init_α)

Return `forward_score(c, x, init_α)` minus `score_sequence(c, x, label_seq)`.
"""
function crf_loss(c::CRF, x, label_seq, init_α)
    all_paths = forward_score(c, x, init_α)
    given_path = score_sequence(c, x, label_seq)
    return all_paths - given_path
end
# Reproduction script: build a small CRF with random inputs, then print the
# loss and its gradient.

# Size of each score vector and of each side of `crf.W`. `CRF(n)` allocates
# an (n + 2) x (n + 2) matrix, hence `label_count - 2` below.
label_count = 10
seq_length = 5
crf = CRF(label_count-2)
# Initial forward variable: -10000.0 everywhere except index `label_count - 1`
# (that is, `crf.n + 1`, the row `preds_first` reads from), which is 0.0.
init_α = fill(-10000.0,label_count)
init_α[label_count-1] = 0.0
# One one-hot vector per step; step `i` is given label `i`.
label_seq = [Flux.onehot(i,1:label_count) for i in 1:seq_length]
# One random score vector of length `label_count` per step.
x = [rand(label_count) for _ in 1:seq_length]
print("crf_loss=$(crf_loss(crf,x,label_seq,init_α))")
# NOTE(review): `gradient` receives only the zero-argument closure, with no
# parameters or arguments to differentiate with respect to. Confirm whether
# `Flux.params(crf)` was meant to be passed as well.
print("gradient(crf_loss)=$(gradient(() -> crf_loss(crf,x,label_seq,init_α)))")
`