# Introduction

In [1]:
import Pkg
Pkg.activate("..")
import TOML
using MarkovModels
import MarkovModels: Semirings
using JLD2
using HDF5
using ProgressMeter

[32m[1m  Activating[22m[39m project at `/mnt/matylda3/ikocour/project/FAST-ASR/recipes/datasets/wsj`


In [2]:
config = TOML.parsefile("../conf/config.toml")

Dict{String, Any} with 7 entries:
  "graphs"     => Dict{String, Any}("units"=>"lang/units.txt", "sil_sym"=>"!SIL…
  "features"   => Dict{String, Any}("var_norm"=>false, "name"=>"mfcc_hires_16kH…
  "output"     => Dict{String, Any}("dir"=>"output")
  "model"      => Dict{String, Any}("hidden-dims"=>[384, 384, 384, 384, 384], "…
  "experiment" => Dict{String, Any}("dir"=>"./exp/nnet")
  "dataset"    => Dict{String, Any}("name"=>"wsj", "test"=>"test_eval92", "trai…
  "training"   => Dict{String, Any}("epochs"=>20, "update_freq"=>2, "optimizer"…

# Prepare refs

In [3]:
testdir = joinpath(config["dataset"]["dir"], config["dataset"]["name"], config["dataset"]["test"])
lexfile = joinpath("..", config["graphs"]["lexicon"])
textfile = joinpath(testdir, "text")

"/mnt/matylda3/ikocour/SCMR/WSJ_mix/data/wsj/test_eval92/text"

In [4]:
import Base.Iterators: flatten

In [5]:
function prepare_refs(lexfile, textfile)
    lexicon = open(lexfile) do f
        lexicon = Dict{String, Vector{String}}()
        for line in readlines(f)
            word, pron... = split(line)
            lexicon[word] = pron
        end
        lexicon
    end
    
    open(textfile) do f
        refs = Dict{String, Vector{String}}()
        for line in readlines(f)
            uttid, words... = split(line)
            phones = map(words) do w
                if w in keys(lexicon)
                    lexicon[w]
                else
                    w = replace(w, "," => "", ";" => "", ":" => "", "!" => "", "?" => "")
                    lexicon[w]
                end
            end
            refs[uttid] = flatten(phones) |> collect
        end
        refs
    end
end

prepare_refs (generic function with 1 method)

In [6]:
refs = prepare_refs(lexfile, textfile)

Dict{String, Vector{String}} with 333 entries:
  "440c040z" => ["DH", "IY", "T", "UW", "K", "AH", "M", "P", "AH", "N"  …  "L",…
  "443c040y" => ["B", "AH", "T", "AH", "T", "EH", "M", "P", "T", "S"  …  "F", "…
  "441c040y" => ["DH", "IY", "G", "EY", "T", "S", "AH", "N", "D", "K"  …  "AO",…
  "445c040i" => ["DH", "AH", "S", "IH", "N", "AO", "R", "D", "ER", "T"  …  "EH"…
  "442c040m" => ["IH", "T", "M", "EY", "B", "IY", "DH", "AH", "T", "AW"  …  "F"…
  "446c0409" => ["AH", "N", "AH", "DH", "ER", "S", "P", "OW", "K", "S"  …  "V",…
  "441c040g" => ["NSN", "Y", "EH", "S", "T", "ER", "D", "IY", "Z", "S"  …  "F",…
  "445c040z" => ["S", "OW", "AH", "N", "AW", "R", "L", "IY", "AA", "P"  …  "IY"…
  "444c040g" => ["NSN", "W", "AH", "T", "EH", "V", "ER", "DH", "IY", "K"  …  "T…
  "447c040i" => ["N", "AA", "K", "AH", "S", "OW", "N", "IY", "S", "AH"  …  "P",…
  "445c040o" => ["IH", "F", "DH", "IY", "S", "OW", "V", "IY", "EH", "T"  …  "R"…
  "445c040s" => ["DH", "IY", "K", "AH", "N", "S", "AH", "L", "

# Extract Hypothesis

In [7]:
graphsdir = joinpath("../", config["graphs"]["dir"], config["dataset"]["name"])
den_fsmfile = joinpath(graphsdir, "denominator_fsm.jld2")

outdir = joinpath(config["experiment"]["dir"], config["dataset"]["name"], config["output"]["dir"])
outfile = joinpath("..", outdir, "test.h5")

In [9]:
function decode(mfsm::MatrixFSM{SR}, in_lhs::AbstractArray) where SR<:TropicalSemiring
    μ = maxstateposteriors(mfsm, in_lhs)
    path = bestpath(mfsm, μ)
    mfsm.labels[path]
end

function decode_dataset(h5file::String, mfsm::MatrixFSM{SR}; kwargs...) where SR<:TropicalSemiring
    h5open(h5file, "r") do f
        decode_dataset(f, mfsm; kwargs...)
    end
end

function decode_dataset(output::HDF5.File, mfsm::MatrixFSM{SR}; acwt=1.0) where SR<:TropicalSemiring
    hyps = Dict{String, Vector{String}}()
    @showprogress 1 "Decoding dataset..." for k in keys(output)
        lhs = acwt * Array(output[k])
        hyp = decode(mfsm, lhs)
        hyps[k] = map(first, hyp)
    end
    return hyps
end

decode_dataset (generic function with 2 methods)

## Process hyps

In [10]:
import StatsBase: rle

process_hyp(hyp) = begin
    hyp = first(rle(hyp)) # remove repeated phones
    filter(x -> x != "SIL", hyp) # remove SIL
end

process_hyp (generic function with 1 method)

In [11]:
function prepare_hyps(h5file::String, graph::String; kwargs...)
    mfsm = jldopen(graph) do f 
        convert(MatrixFSM{TropicalSemiring{Float32}}, f["fsm"])
    end
    prepare_hyps(h5file, mfsm)
end

function prepare_hyps(h5file::String, mfsm::MatrixFSM; kwargs...)
    hyps = decode_dataset(h5file, mfsm; kwargs...)
    return Dict(uttid => process_hyp(hyp) for (uttid, hyp) in hyps)
end

hyps = prepare_hyps(outfile, den_fsmfile; acwt=1.0);

[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:04[39m


# Score

In [12]:
using WordErrorRate

function score(refs::Dict{String, T}, hyps::Dict{String, T}) where T<:Vector{String}
    wer_stats = Dict()
    for uttid in keys(refs)
        ref = refs[uttid]
        hyp = hyps[uttid]
        wer_stats[uttid] = WER(ref, hyp)
    end
    return wer_stats
end

function compute_wer(wer_stats::Dict)
    S = foldl(values(wer_stats); init=0) do a,b a + b.nsub end
    I = foldl(values(wer_stats); init=0) do a,b a + b.nins end
    D = foldl(values(wer_stats); init=0) do a,b a + b.ndel end
    N = foldl(values(wer_stats); init=0) do a,b a + length(b.ref) end
    return round((S+I+D)/N * 100; digits=2), N, S, I, D
end

function dump_wer(io::IO, wer_stats::Dict)
    for uttid in sort(collect(keys(wer_stats)))
        write(io, uttid)
        write(io, "\n")
        write(io, pralign(String, wer_stats[uttid]))
        write(io, "--\n")
    end
end

wer_stats = score(refs, hyps)
wer_result, nref, nsub, nins, ndel = compute_wer(wer_stats)
println("%WER $wer_result [ $(nsub+nins+ndel) / $nref, $nins ins, $ndel del, $nsub sub ]")

%WER 16.05 [ 3792 / 23623, 402 ins, 807 del, 2583 sub ]


In [None]:
dump_wer(stdout, wer_stats)

# Experiments with LM weight

In [16]:
refs = prepare_refs(lexfile, textfile)

mfsm = jldopen(den_fsmfile) do f 
    convert(MatrixFSM{TropicalSemiring{Float32}}, f["fsm"])
end

let best_wer = Inf, best_wer_stats = nothing, best_lmwt = nothing
    for lmwt in 0.8:0.05:1.2
        hyps = prepare_hyps(outfile, mfsm; acwt=1/lmwt)
        wer_stats = score(refs, hyps)
        wer_result, nref, nsub, nins, ndel = compute_wer(wer_stats)
        println("$lmwt: %WER $wer_result [ $(nsub+nins+ndel) / $nref, $nins ins, $ndel del, $nsub sub ]")
        if best_wer > wer_result
            best_wer = wer_result
            best_wer_stats = wer_stats
            best_lmwt = lmwt
        end
    end
end

println("Best WER: LMWT = $best_lmwt")
wer_result, nref, nsub, nins, ndel = compute_wer(wer_stats)
println("%WER $wer_result [ $(nsub+nins+ndel) / $nref, $nins ins, $ndel del, $nsub sub ]")
dump_wer(stdout, wer_stats)

[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


0.8: %WER 16.17 [ 3821 / 23623, 494 ins, 703 del, 2624 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


0.85: %WER 16.17 [ 3820 / 23623, 471 ins, 731 del, 2618 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


0.9: %WER 16.05 [ 3792 / 23623, 449 ins, 746 del, 2597 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


0.95: %WER 16.06 [ 3795 / 23623, 422 ins, 782 del, 2591 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


1.0: %WER 16.05 [ 3792 / 23623, 402 ins, 807 del, 2583 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


1.05: %WER 16.03 [ 3787 / 23623, 379 ins, 824 del, 2584 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


1.1: %WER 16.04 [ 3789 / 23623, 365 ins, 844 del, 2580 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


1.15: %WER 16.12 [ 3808 / 23623, 348 ins, 881 del, 2579 sub ]


[32mDecoding dataset... 100%|████████████████████████████████| Time: 0:00:01[39m


1.2: %WER 16.17 [ 3821 / 23623, 331 ins, 913 del, 2577 sub ]
Best WER: LMWT = 1.05
%WER 16.17 [ 3821 / 23623, 331 ins, 913 del, 2577 sub ]
440c0401
Scores: (#C #S #D #I) 95 12 4 2
REF:  D R AE V ** OW L AO S T M AH N TH AH G R IY D IH N P R IH N S AH P AH L T UW S EH L IH T S IH N L AE N D W AO T ER T R AE N S P * ER T EY SH AH N S T IY V AH D AO R IH NG AH N D P AY P F AE B R IH K EY SH AH N B IH Z N IH S IH Z F R ER AH N AH N D IH S K L OW Z D S AH M 
HYP:  T R EY V AH V  L AE S T M AH N TH ER G R IY D IH N P R IH N S AH P AH L T IH S EH L IH T S IH N L AH N D W AO T ER T R AE N S P R AH T EY SH AH N S T IY V IH D AO R IH NG AH N * P AY K F AE B * ER K EY SH AH N B IH Z N IH S AH Z F R ** AH N AH N D IH S K L OW Z * S AH M 
Eval: S   S    I  S    S                S                                       S                       S                             I S                        S                    D      S        D S                             S        D                        