Skip to content

Issue with CRF loss function #1087

@opus111

Description

@opus111

Here is a file that reproduces the problem. This code is copied from the TextAnalysis package and slightly altered for Flux 0.10. The version in GitHub works with Flux 0.9.

`#=
This code is copied from CRF of TextAnalysis

The current version in GitHub works with Flux 0.9

https://github.com/JuliaText/TextAnalysis.jl/tree/master/src/CRF
=#

using Flux

"""
    log_sum_exp(z)
    log_sum_exp(z, m)

Return `log.(sum(exp.(z), dims = 1))` computed with a shift `m` subtracted before
exponentiation and added back afterwards. The reduction runs along the first
dimension, so that dimension of the result has length 1.

The one-argument method uses the maximum of `z` along the first dimension as the shift.
"""
function log_sum_exp(z)
    shift = maximum(z; dims=1)
    return log_sum_exp(z, shift)
end

function log_sum_exp(z, m)
    # Exponentiate the shifted values, reduce along dims 1, then undo the shift.
    shifted = exp.(z .- m)
    return log.(sum(shifted; dims=1)) .+ m
end

"""
    CRF{S}

Container for the transition scores of a CRF layer.

# Fields
- `W::S`: transition-score matrix. The one-argument constructor builds it with size
  `(n + 2, n + 2)`; `preds_first` reads row `n + 1` and `preds_last` reads column
  `n + 2` (presumably the start and stop states).
- `n::Int`: number of labels, not counting the two extra rows/columns of `W`.
"""
mutable struct CRF{S}
    W::S    # Transition Scores
    n::Int  # Num Labels
end

"""
    CRF(n::Integer)

Build a `CRF` for `n` labels with a random `Float32` transition matrix of size
`(n + 2, n + 2)`. Column `n + 1` and row `n + 2` are set to `-10000`, which makes
those transitions effectively impossible in log space.
"""
function CRF(n::Integer)
    W = rand(Float32, n + 2, n + 2)
    W[:, n + 1] .= -10000
    W[n + 2, :] .= -10000
    # The auto-generated outer constructor only matches `n::Int`, so convert here;
    # otherwise `CRF(Int32(3))` and other non-`Int` integers raise a MethodError.
    return CRF(W, Int(n))
end

# Register CRF with Flux's functor machinery. The `(W,)` tuple names `W` as the only
# field Flux should treat as a child, so `n` is left out of parameter collection.
Flux.@functor CRF (W,)

"""
    preds_first(c::CRF, y)

Return the entry of `c.W` in row `c.n + 1` at the column selected by the one-hot
(or score) vector `y`, i.e. the position of its largest element.
"""
function preds_first(c::CRF, y)
    label = Flux.onecold(y, 1:length(y))
    return c.W[c.n + 1, label]
end
"""
    preds_last(c::CRF, y)

Return the entry of `c.W` in column `c.n + 2` at the row selected by the one-hot
(or score) vector `y`, i.e. the position of its largest element.
"""
function preds_last(c::CRF, y)
    label = Flux.onecold(y, 1:length(y))
    return c.W[label, c.n + 2]
end
"""
    preds_single(c::CRF, y, y_prev)

Return `c.W[i, j]`, where `i` is the position of the largest element of `y_prev`
and `j` is the position of the largest element of `y`.
"""
function preds_single(c::CRF, y, y_prev)
    from = Flux.onecold(y_prev, 1:length(y_prev))
    to = Flux.onecold(y, 1:length(y))
    return c.W[from, to]
end

"""
    forward_score(c::CRF, x, init_α)

Run the forward recursion over the sequence `x` and return a scalar score.

Starting from `init_α`, each step adds the transition scores `c.W`, the transposed
input `x[t]` and the previous forward variable, then reduces with `log_sum_exp`
along the first dimension. The final step adds column `c.n + 2` of `c.W` and
reduces once more; the first element of that result is returned.

`x` must be non-empty, because `x[1]` is read unconditionally.
"""
function forward_score(c::CRF, x, init_α)
    # First step uses the caller-supplied initial forward variable.
    α = log_sum_exp((c.W .+ transpose(x[1])) .+ init_α)
    for t in 2:length(x)
        step_scores = c.W .+ transpose(x[t])
        # `α` comes back from log_sum_exp reduced along dims 1; transpose it so it
        # lines up with the first dimension of `c.W` again.
        α = log_sum_exp(step_scores .+ transpose(α))
    end
    closing = c.W[:, c.n + 2] + transpose(α)
    return log_sum_exp(closing)[1]
end

"""
    score_sequence(c::CRF, x, label_seq)

Return the score of the labelling `label_seq` for the inputs `x`.

The total is the sum of `preds_first` for the first label, `preds_single` for each
consecutive label pair, the entry of each `x[i]` picked out by `label_seq[i]`
(via `Flux.onecold`), and `preds_last` for the final label.
"""
function score_sequence(c::CRF, x, label_seq)
    total = preds_first(c, label_seq[1]) + Flux.onecold(label_seq[1], x[1])
    for t in 2:length(label_seq)
        transition = preds_single(c, label_seq[t], label_seq[t - 1])
        emission = Flux.onecold(label_seq[t], x[t])
        total += transition + emission
    end
    return total + preds_last(c, label_seq[end])
end

"""
    crf_loss(c::CRF, x, label_seq, init_α)

Return `forward_score(c, x, init_α)` minus `score_sequence(c, x, label_seq)`.
"""
function crf_loss(c::CRF, x, label_seq, init_α)
    all_paths = forward_score(c, x, init_α)
    labelled_path = score_sequence(c, x, label_seq)
    return all_paths - labelled_path
end

# Reproduction script: build a small CRF, evaluate the loss on random inputs, and
# take the gradient of the loss with respect to the CRF's transition matrix.
label_count = 10
seq_length = 5
# The constructor adds 2 to `n`, so `W` ends up label_count × label_count.
crf = CRF(label_count - 2)
# All-but-impossible initial scores, except index label_count - 1 (row `n + 1` of W).
init_α = fill(-10000.0, label_count)
init_α[label_count - 1] = 0.0
label_seq = [Flux.onehot(i, 1:label_count) for i in 1:seq_length]
x = [rand(label_count) for _ in 1:seq_length]
println("crf_loss=$(crf_loss(crf, x, label_seq, init_α))")
# `gradient` differentiates with respect to the arguments that follow the function.
# A zero-argument closure on its own gives it nothing to differentiate against, so
# pass the CRF's parameters and look up the entry for `W`.
grads = gradient(() -> crf_loss(crf, x, label_seq, init_α), Flux.params(crf))
println("gradient(crf_loss)=$(grads[crf.W])")
`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions