### TODO

- [ ] Make it work for FDS first
- [ ] Extend to other variants
- [ ] Extend even more
- [ ] Docstrings

## Module and data

In [1]:
# Load the module for class initialization functions

using ExpectationMaximization
import ExpectationMaximization: initialize_class_assignments, m_step

# Make sure that this notebook is run from the main directory.
# The last path component is compared (rather than a "/notebooks" suffix) so the
# check does not depend on the platform's path separator.

basename(pwd()) == "notebooks" && cd("..")
PROJECT_PATH = pwd()

# `joinpath` builds the path with the platform's separator
include(joinpath(PROJECT_PATH, "src", "load_datasets.jl"))

dataset = load_rte();

# [questions x one-hot classes]
class_assignments = initialize_class_assignments(FDS(), dataset.crowd_counts);

# [questions x annotators x classes]
counts = dataset.crowd_counts;

## Model and utils

In [43]:
using Distributions
# using Documenter
using FillArrays
using Turing

# Generative model sampled in the E-step of the FDS variant.
#
# NOTE(review): the previous body could not run — it referenced an undefined `w`
# and ended in a dangling `class_assignments[q, c] ~`. What follows is a
# reconstruction along the usual Dawid–Skene lines (one latent class per question,
# annotator label counts drawn from the row of that annotator's error-rate matrix
# selected by the latent class). Confirm it against eq. 1 and 2 before relying on it.
#
# Arguments
# - `counts`: 3-D array, [questions x annotators x classes].
# - `param`: `(class_marginals, error_rates)`, the tuple produced by `init_param`.
#   `error_rates[a, c, :]` is assumed to be the label distribution of annotator `a`
#   when the true class is `c` — TODO confirm the axis order against `m_step`.
#
# Returns the vector of latent class indices, one per question.
@model function em_fds(counts, param)
    # Dims
    n_questions, n_annotators, n_classes = size(counts)
    # Parameters ( class_marginals ≈ prior )
    class_marginals, error_rates = param
    # `vec` lets the marginals be passed either as a vector or as a single-row/column matrix
    prior = vec(class_marginals)
    length(prior) == n_classes || throw(
        DimensionMismatch("expected $n_classes class marginals, got $(length(prior))"),
    )

    # Latent true class of every question
    true_class = Vector{Int}(undef, n_questions)
    for q in 1:n_questions
        true_class[q] ~ Categorical(prior)
        for a in 1:n_annotators
            # `Multinomial` needs an integer number of trials
            n_labels = Int(sum(view(counts, q, a, :)))
            # No labels from this annotator for this question: nothing to observe
            n_labels == 0 && continue
            counts[q, a, :] ~ Multinomial(n_labels, error_rates[a, true_class[q], :])
        end
    end

    return true_class
end

"""
class_marginals - [questions] (eq. 4)
    (Number of questions having an answer as \$c\$) / (Total number of questions)
error_rates - [questions x choices x choices] (probability distribution over true choices (dim 3) given assigned choices (dim 2)) (eq. 3)
    \$ P(c_a | Y_q = c) = \frac{|T^{(c)}_{c_a}|}{|S^{(c)}_a|} \$
"""
function init_param(counts)::Tuple{AbstractArray{<:Real, 2}, AbstractArray{<:Real, 3}}
    class_marginals, error_rates = m_step(FDS(), counts, initialize_class_assignments(FDS(), counts))
    param = class_marginals, error_rates
    return param
end

# function get_latent(chain; keep_pct=.5, use_every=10)
#     chain.value.data. #...?
# end

init_param

In [44]:
# Dimensions of `error_rates` (the array bound by the `init_param(counts)` cell)
size(error_rates)

(164, 2, 2)

In [47]:
# 2-D slice of `error_rates` at the first index of its leading axis
error_rates[1, :, :]

2×2 Matrix{Float64}:
 0.958333  0.0416667
 0.1875    0.8125

In [26]:
# Unpack the tuple returned by `init_param` into its two components
class_marginals, error_rates = init_param(counts)
# Bare trailing expression so the notebook displays the array
error_rates

164×2×2 Array{Float64, 3}:
[:, :, 1] =
 0.958333  0.1875
 0.794872  0.0913978
 0.75      0.125
 0.84      0.0384615
 0.75      0.125
 0.73262   0.490798
 0.923077  0.145833
 0.778547  0.729084
 0.840376  0.78877
 0.78607   0.748603
 0.822222  0.0571429
 0.636364  0.0212766
 0.444444  0.0
 ⋮         
 1.0       0.125
 0.916667  0.0
 0.833333  0.0
 1.0       0.625
 0.916667  0.0
 1.0       0.0714286
 1.0       0.142857
 0.727273  0.0
 0.727273  0.0
 0.952381  0.0526316
 1.0       0.1
 0.9       0.1

[:, :, 2] =
 0.0416667  0.8125
 0.205128   0.908602
 0.25       0.875
 0.16       0.961538
 0.25       0.875
 0.26738    0.509202
 0.0769231  0.854167
 0.221453   0.270916
 0.159624   0.21123
 0.21393    0.251397
 0.177778   0.942857
 0.363636   0.978723
 0.555556   1.0
 ⋮          
 0.0        0.875
 0.0833333  1.0
 0.166667   1.0
 0.0        0.375
 0.0833333  1.0
 0.0        0.928571
 0.0        0.857143
 0.272727   1.0
 0.272727   1.0
 0.047619   0.947368
 0.0        0.9
 0.1        0.9

In [None]:
using Optim
using ProgressMeter
using Random  # provides `Random.seed!`, called below

# (Hyper)params

Random.seed!(42)

num_particles = 100
num_samples = 100
num_iterations = 4
sampler = PG(num_particles)

# Pass an `FDS()` instance, matching every other call in this notebook
class_assignments = initialize_class_assignments(FDS(), dataset.crowd_counts)
# Starting point for EM: the `(class_marginals, error_rates)` tuple that `em_fds` unpacks
param = init_param(counts)
trace = (chn=[], opt=[])

@showprogress for i in 1:num_iterations
    # `param` is reassigned at the end of each iteration; declaring it global keeps
    # that assignment from creating a loop-local variable when this runs as a script
    global param

    # E-step
    chn = sample(em_fds(counts, param), sampler, num_samples; progress=false)
    push!(trace.chn, chn)

    # M-step
    # NOTE(review): `make_obj` and `get_latent` are not defined anywhere in this
    # notebook (`get_latent` exists only as a commented-out stub), so this line
    # cannot run yet. `param` is a tuple here, so `make_obj` also has to agree with
    # `optimize` on how the parameters are represented — TODO implement.
    opt = optimize(make_obj(get_latent(chn)), param)
    push!(trace.opt, opt)

    param = opt.minimizer

    println("Iteration $i")
    display(param)
    println()
end