**Notes on my problems encountered when trying to implement EM-FDS with Turing's `@model`**

1. For some reason not very clear to me, `S` and `T` take on zeros in corresponding entries that result in `NaN`s in `error_rates = T ./ S`. This is why the following loops was necessary.

```julia
    # ...
    error_rates = T ./ S
    for a in 1:A, o in 1:O
        if hasnans(error_rates[a, o, :])
            error_rates[a, o, :] = normalized_class_marginals
        end
    end
    # ...
```

2. This implementation takes very long to execute, about 2 hours even with relatively small parameters `num_particles = 20`, `num_samples = 20`, `num_iterations = 10`. In comparison, [em-gmm](./em-gmm.ipynb) is quite fast, (although still much slower either than the "traditional" `@model` version in [the tutorial](https://turing.ml/dev/tutorials/01-gaussian-mixture-model/) or the EM implementation without `@model`).

3. After just 6 iterations with `num_particles = 20`, `num_samples = 20`, the parameters (after passing them through softmax) take on absurd values. One class completely dominates the other (and the effect is even stronger for `num_particles = 50`, `num_samples = 50`)

```julia
trace = load("./trace_2022-09-23T12:54:18.jld")["trace"]
params = map(trace.opt) do o softmax(o.minimizer) end
```

```sh
10-element Vector{Vector{Float64}}:
 [0.5137429126151726, 0.48625708738482754]
 [0.5196796341353839, 0.4803203658646161]
 [0.5551475166649195, 0.44485248333508054]
 [0.6297826008835576, 0.37021739911644236]
 [0.888102591461752, 0.1118974085382481]
 [0.9989157111987694, 0.0010842888012305854]
 [0.9989157111987694, 0.0010842888012305854]
 [0.9989157111987694, 0.0010842888012305854]
 [0.9989157111987694, 0.0010842888012305854]
 [0.9989157111987694, 0.0010842888012305854]
```

4. It is possible that equations from [the paper](http://sentic.net/wisdom2018sinha.pdf) (primarily (3) and (4) from page 3) and more generally probabilistic assumptions of the algorithm were not translated correctly.

5. If you would like to investigate how model behaves on this dataset with a given choice of sampler parameters before running it, you can load the traces of optimization and chain from `./notebooks/em-gmm/traces`:
   - `trace_20p_20s.jld`
     - `num_particles = 20; num_samples = 20`
   - `trace_50p_50s.jld`
     - `num_particles = 50; num_samples = 50`

# EM-FDS

In [1]:
using DawidSkeneAlgorithms
import DawidSkeneAlgorithms: initialize_class_assignments, m_step # For `init_param`

# Load dataset

dataset = load_rte()
counts = dataset.x # [question x annotators x classes]
size(counts)

(800, 164, 2)

In [2]:
using LinearAlgebra: normalize

"""
Initialize parameters.
Reuses `m_step` and `initialize_class_assignments` defined for the manual implemention of the algorithm.
"""
function init_param(counts)
    class_marginals, error_rates = m_step(FDS(), counts, initialize_class_assignments(FDS(), counts))
    return normalize(class_marginals[:], 1)
end

third(x) = x[3]

"""
Element-wise multiplication in order to remove annotators that didn't answer the question.
"""
function onehot_choices_to_categorical(x::Array{<:Real, 3})::Matrix{Int}
    third.(argmax(x; dims=3)[:, :, 1]) .* maximum(x; dims=3)[:, :, 1]
end

Q, A, O = size(counts) # questions, answers, options
choices = onehot_choices_to_categorical(counts)
class_marginals = init_param(counts)

;

In [3]:
using Distributions
using LogExpFunctions
using IterTools
using Turing


isnanbool(x) = isnan(x) === true
hasnans(x::AbstractArray) = isnanbool.(x) |> any
allnans(x::AbstractArray) = isnanbool.(x) |> all

@model function em_fds(
    choices::Matrix{Int}, # [Q x A]
    unnormalized_class_marginals::Vector{<:Real} # [O] (probability vector)
)
    # Get dimensions
    Q, A = size(choices) # questions, answers
    O = length(unnormalized_class_marginals) # options

    # eq (4)
    normalized_class_marginals = softmax(unnormalized_class_marginals)
    p_Y = Categorical(normalized_class_marginals)
    
    # sample answer sheet
    Ys ~ filldist(p_Y, Q)

    # eq (3)

    # [A x O]
    S = map(product(1:A, 1:O)) do (a, o)
        # Sₐ⁽ᶜ⁾ = {i | Yᵢ = c ∧ a has answered quetsion i}
        length([q for q in 1:Q 
                  if Ys[q] == o
                  && choices[q, a] ≠ 0])
    end

    # [A x O x O]
    T = map(product(1:A, 1:O, 1:O)) do (a, o_true, o_answered)
        # T_{cₐ}⁽ᶜ⁾ = {i | Yᵢ = c ∧ a has answered cₐ on question i}
        length([q for q in 1:Q
                if Ys[q] == o_true 
                && choices[q, a] == o_answered])
    end

    #=
    When `NaN`s occur in `error_rates`, they are arranged along the third dimension,
    which is sampled as a probability vector.
    The minimum risk way of dealing with this is to replace such cases with
    priors/`normalized_class_marginals`.
    =#
    error_rates = T ./ S
    for a in 1:A, o in 1:O
        if hasnans(error_rates[a, o, :])
            error_rates[a, o, :] = normalized_class_marginals
        end
    end
    

    #TODO: try moving it into the loop, maybe avoids lines or add an if inside line 62 
    p_c_given_Y = [hasnans(error_rates[a, o, :]) ? nothing : Categorical(error_rates[a, o, :]) for a in 1:A, o in 1:O]
    for q in 1:Q, a in 1:A
        choices[q, a] == 0 && continue
        choices[q, a] ~ p_c_given_Y[a, Ys[q]]
    end
end


em_fds (generic function with 2 methods)

In [4]:
function get_latent(chn; keep_pct=.5, use_every=1)
    # [latent_dim, num_samples]
    chn.value.data[floor(Int, keep_pct * end) : use_every : end, 1:800] |> Matrix{Int} |> transpose
end

function make_obj(latent)
    function obj(unnormalized_class_marginals)
        normalized_class_marginals = normalize(unnormalized_class_marginals, 1)
        return -mean([
            logjoint(em_fds(choices, normalized_class_marginals), (; Ys=Ys))
            for Ys in eachcol(latent)
            ])
    end
end

make_obj (generic function with 1 method)

In [5]:
using Dates
using JLD
using Random
using Optim
using ProgressMeter

Random.seed!(42)

num_particles = 50
num_samples = 50
num_iterations = 4
sampler = PG(num_particles)

class_marginals = init_param(counts)
trace = (chn=[], opt=[])

@showprogress for i in 1:num_iterations
    # E-step
    chn = sample(em_fds(choices, class_marginals), sampler, num_samples; progress=false)
    push!(trace.chn, chn)
    
    # M-step
    opt = optimize(make_obj(get_latent(chn)), class_marginals)
    push!(trace.opt, opt)
    class_marginals = opt.minimizer

    # Report
    @show i
    @show opt.minimum
    @show chn.logevidence
    flush(stdout)
    display(class_marginals)
end

savename = "./em-gmm-traces/trace_$(num_particles)p_$(num_samples)s.jld"
save(savename, "trace", trace)
println("Saved as $savename")

i = 1
opt.minimum = 5375.146428809849
chn.logevidence = -4828.24796068154


2-element Vector{Float64}:
 0.7021132195158861
 0.4927106311963871

i = 2
opt.minimum = 5373.538804173876
chn.logevidence = -4829.785954218567


2-element Vector{Float64}:
 0.8807285577246271
 0.5212577017128099

[32mProgress:  50%|████████████████████▌                    |  ETA: 0:58:21[39m

i = 3
opt.minimum = 5374.868222671934
chn.logevidence = -4837.555133876674


2-element Vector{Float64}:
 1.3798581257964482
 0.6552944316104014

[32mProgress:  75%|██████████████████████████████▊          |  ETA: 0:29:07[39m

i = 4
opt.minimum = 5313.567909955982
chn.logevidence = -4834.967479359389


2-element Vector{Float64}:
 2.9746660901293502
 0.1441533556854021

[32mProgress: 100%|█████████████████████████████████████████| Time: 1:56:19[39m


Saved as ./em-gmm-traces/trace_50p_50s.jld
