### Missing Data Experiment

The raw data for this experiment can be downloaded from https://ifcs.boku.ac.at/repository/data/tetragonula_bee/index.html and is also supplied in `DATA/Bees/Tetragonula.csv`. The code below loads NumPy files containing the raw data and labels, with no preprocessing.

Only the vHDPMM and DPMM experiments are included here; the GMM experiment was run with sklearn in Python.

Note: here we demonstrate a single run, so no mean or standard deviation is reported. This matters most for the vHDPMM with missing data, because a new random partition was drawn on every run.
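To report a mean and standard deviation, the whole pipeline would have to be repeated over several seeds. The following is only a minimal sketch of that loop: `run_once` is a hypothetical callable (not part of this notebook) that takes a seed, runs one experiment, and returns its NMI, and `summarize_runs` is likewise a name introduced here for illustration.

```julia
using Statistics

# `run_once` is a hypothetical function: seed -> NMI of one full run.
# It is passed in as an argument so this sketch stays independent of the model code.
function summarize_runs(run_once, seeds)
    scores = [run_once(s) for s in seeds]
    return (mean = mean(scores), std = std(scores))
end

# Stand-in scorer used only to show the call shape; it is not an experiment result.
stats = summarize_runs(s -> 0.1 * s, 1:5)
println(stats)
```

In practice `run_once` would call `Random.seed!(seed)` before drawing the partition, so that each run can be reproduced.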

In [1]:
using NPZ
using DPMMSubClusters
using LinearAlgebra
using Clustering
using Random

#### Raw Data Loading

In [2]:
data = copy(npzread("DATA/Bees/bees_x.npy")')
labels = npzread("DATA/Bees/bees_y.npy");

In [3]:
size(data)

(13, 236)

In [4]:
DPMMPrior = DPMMSubClusters.niw_hyperparams(1,zeros(13),16,Matrix{Float64}(I, 13, 13)*0.1)
DPMMResults = DPMMSubClusters.fit(data,DPMMPrior,100.0,iters = 100, gt = labels, verbose = false);

│   caller = Distributions.Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:36
└ @ Distributions /home/dinari/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:36
│   caller = Distributions.Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:38
└ @ Distributions /home/dinari/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:38


#### DPMM NMI:

In [5]:
println(mutualinfo(DPMMResults[1], labels; normed =true))

0.6051868832576468
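For reference, the score above can be reproduced by hand from label counts. The sketch below assumes the symmetric-uncertainty normalization, 2·I(A;B) / (H(A) + H(B)), which is how the Clustering.jl documentation describes `normed = true`; if the installed version normalizes differently, the numbers will not match. The name `nmi_sketch` is introduced here for illustration only.

```julia
# Normalized mutual information between two label vectors,
# assuming the symmetric-uncertainty form 2*I(A;B) / (H(A) + H(B)).
function nmi_sketch(a::AbstractVector, b::AbstractVector)
    length(a) == length(b) || throw(ArgumentError("label vectors must have equal length"))
    n = length(a)
    joint = Dict{Tuple{eltype(a),eltype(b)},Int}()  # counts of (a-label, b-label) pairs
    ca = Dict{eltype(a),Int}()                      # marginal counts for a
    cb = Dict{eltype(b),Int}()                      # marginal counts for b
    for (x, y) in zip(a, b)
        joint[(x, y)] = get(joint, (x, y), 0) + 1
        ca[x] = get(ca, x, 0) + 1
        cb[y] = get(cb, y, 0) + 1
    end
    entropy(counts) = -sum(c / n * log(c / n) for c in values(counts))
    mi = sum(c / n * log(c * n / (ca[x] * cb[y])) for ((x, y), c) in joint)
    return 2 * mi / (entropy(ca) + entropy(cb))
end

# Identical partitions up to relabeling score 1; independent ones score 0.
println(nmi_sketch([1, 1, 2, 2], [2, 2, 1, 1]))
println(nmi_sketch([1, 1, 2, 2], [1, 2, 1, 2]))
```

The score depends only on how the points are grouped, not on the label values themselves, which is why it suits comparing cluster assignments against ground-truth labels.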


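We will partition the data into 4 groups at random and create two versions: one where all features exist, and one where a random number of features is missing. In the second version, each group keeps the first 6 features plus between 1 and 7 of the remaining 7, chosen at random.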

In [6]:
perm = randperm(size(data,2))
mixed_data = data[:,perm]
mixed_labels = labels[perm]

interval = Int(length(mixed_labels)/4)

group_indices = Int.(collect(1:interval:length(mixed_labels)))
labels_dict = Dict()
full_data_dict = Dict()
missing_data_dict = Dict()
features_count = rand(1:7,4)
base_features = collect(1:6)
for (i,v) in enumerate(group_indices)
    relevant_data = mixed_data[:,v:v+interval-1]
    labels_dict[i] = mixed_labels[v:v+interval-1]
    full_data_dict[i] = relevant_data
    chosen_features = Int.(vcat(base_features,(randperm(7).+6)[1:features_count[i]]))
    missing_data_dict[i] = relevant_data[chosen_features,:]
end
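The feature selection inside the loop can be checked in isolation on a toy matrix. This is a minimal sketch, not part of the experiment: `toy`, `base`, `extra` and `subset` are names introduced here, the seed is fixed only so the sketch is repeatable, and the choice of 3 extra features is arbitrary.

```julia
using Random

Random.seed!(0)                       # fixed seed, only to make the sketch repeatable
toy = reshape(1.0:13*8, 13, 8)        # 13 features x 8 samples, a stand-in for `data`
base = collect(1:6)                   # features kept by every group
extra = randperm(7)[1:3] .+ 6         # 3 of the remaining 7 features (rows 7 to 13)
subset = toy[vcat(base, extra), :]    # same indexing pattern as the loop above

println(size(subset))                 # 6 shared + 3 extra rows, all 8 columns
```

Every group therefore shares the first 6 rows, while the remaining rows differ from group to group in both number and identity.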

We will now run our model. First we add worker processes, as the model requires at least one worker process.

In [7]:
using Distributed
addprocs(4)

4-element Array{Int64,1}:
 2
 3
 4
 5
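Re-running the cell above adds four more workers each time. A minimal guard, assuming a target of 4 workers as in this notebook (the variable name `target_workers` is introduced here), might look like:

```julia
using Distributed

target_workers = 4   # matches the addprocs(4) call in this notebook

# nprocs() counts the master process too, so a fresh session reports 1.
if nprocs() == 1
    addprocs(target_workers)
end

println("workers: ", nworkers())
```

This only guards against the fresh-session case; if some workers already exist, it leaves the pool unchanged.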

In [8]:
cur_dir = pwd()
@everywhere cd("../")
include("hdp_shared_features.jl")

results_stats (generic function with 1 method)

In [9]:
local_priors = [niw_hyperparams(1.0, zeros(i),i+3,Matrix{Float64}(I, i, i)*0.1) for i in features_count]
global_hyper_params = niw_hyperparams(1.0, zeros(6), 9, Matrix{Float64}(I, 6, 6)*0.1)
constant_local_prior = niw_hyperparams(1.0, zeros(7), 10, Matrix{Float64}(I, 7, 7)*0.1)

niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1])

In [10]:
missing_data_results = vhdp_fit(missing_data_dict,6,100.0,100.0,100.0,global_hyper_params,local_priors,100)

│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:36
└ @ Distributions ~/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:36
│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:38
└ @ Distributions ~/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:38

Iteration: 1|| Global Counts: [6]|| iter time: 20.527177095413208
Iteration: 2|| Global Counts: [6]|| iter time: 0.051473140716552734
Iteration: 3|| Global Counts: [6]|| iter time: 0.011178970336914062
Iteration: 4|| Global Counts: [9]|| iter time: 0.5040509700775146
Iteration: 5|| Global Counts: [10, 10]|| iter time: 0.6376841068267822
Iteration: 6|| Global Counts: [5, 12]|| iter time: 0.3723890781402588
Iteration: 7|| Global Counts: [5, 11]|| iter time: 0.015729188919067383
Iteration: 8|| Global Counts: [5, 14]|| iter time: 0.01861715316772461
Iteration: 9|| Global Counts: [5, 12]|| iter time: 0.16479992866516113
Iteration: 10|| Global Counts: [5, 12, 5, 12]|| iter time: 0.10274100303649902
Iteration: 11|| Global Counts: [5, 11, 3, 8]|| iter time: 0.020714998245239258
Iteration: 12|| Global Counts: [5, 11, 3, 8]|| iter time: 0.02478480339050293
Iteration: 13|| Global Counts: [5, 11, 3, 8]|| iter time: 0.020206928253173828
Iteration: 14|| Global Counts: [5, 11, 3, 7]|| iter time: 0.01

Iteration: 91|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.031244993209838867
Iteration: 92|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.02950119972229004
Iteration: 93|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.029187917709350586
Iteration: 94|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.031241893768310547
Iteration: 95|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.03681683540344238
Iteration: 96|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.02806401252746582
Iteration: 97|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.02947092056274414
Iteration: 98|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.03415703773498535
Iteration: 99|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.03452801704406738
Iteration: 100|| Global Counts: [1, 4, 1, 8, 4, 4, 2, 4, 3, 1, 5]|| iter time: 0.029803037643432617


(hdp_shared_features(model_hyper_params(niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 9.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), niw_hyperparams[niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0], 8.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0], 7.0, [0.1 0.0 0.0 0.0; 0.0 0.1 0.0 0.0; 0.0 0.0 0.1 0.0; 0.0 0.0 0.0 0.1]), niw_hyperparams(1.0, [0.0], 4.0, [0.1]), niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1])], 100.0, 100.0, 100.0, 1.0, 1.0, 11, 7), Dict{Any,Any}(4 => local_group(model_hyper_params(niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 9.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0

#### Missing Data NMI:

In [11]:
NMI = 0.0
for i=1:4
    group_labels = create_global_labels(missing_data_results[1].groups_dict[i])
    NMI += mutualinfo(group_labels,labels_dict[i])
end
println(NMI/4)

0.8427130563971763
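The averaging loop is used for both the missing-data and the full-data results, so it could be wrapped in a small helper. The sketch below assumes Clustering.jl is installed and that both arguments map a group id to an integer label vector; `mean_group_nmi` is a name introduced here, not a function from the model code.

```julia
using Clustering
using Statistics

# Average NMI over groups. `predicted` and `truth` both map a group id
# to a vector of integer labels of the same length.
function mean_group_nmi(predicted::AbstractDict, truth::AbstractDict)
    return mean(mutualinfo(predicted[k], truth[k]; normed = true) for k in keys(truth))
end

# Toy check: the same partition under different label names.
toy_pred = Dict(1 => [1, 1, 2, 2], 2 => [1, 2, 2, 2])
toy_true = Dict(1 => [2, 2, 1, 1], 2 => [3, 1, 1, 1])
println(mean_group_nmi(toy_pred, toy_true))
```

With the notebook's variables, `predicted` would be built as `Dict(i => create_global_labels(missing_data_results[1].groups_dict[i]) for i in 1:4)` and `truth` would be `labels_dict`.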


In [12]:
full_data_results = vhdp_fit(full_data_dict,6,100.0,100.0,100.0,global_hyper_params,constant_local_prior,100)

Iteration: 1|| Global Counts: [8]|| iter time: 0.013607025146484375
Iteration: 2|| Global Counts: [8]|| iter time: 0.01246500015258789
Iteration: 3|| Global Counts: [11]|| iter time: 0.013159990310668945
Iteration: 4|| Global Counts: [14]|| iter time: 0.013863086700439453
Iteration: 5|| Global Counts: [16, 16]|| iter time: 0.01936507225036621
Iteration: 6|| Global Counts: [11, 18]|| iter time: 0.07648205757141113
Iteration: 7|| Global Counts: [10, 15]|| iter time: 0.02076888084411621
Iteration: 8|| Global Counts: [10, 15]|| iter time: 0.02622699737548828
Iteration: 9|| Global Counts: [11, 14]|| iter time: 0.02179408073425293
Iteration: 10|| Global Counts: [10, 14, 10]|| iter time: 0.020015954971313477
Iteration: 11|| Global Counts: [5, 17, 7]|| iter time: 0.024630069732666016
Iteration: 12|| Global Counts: [5, 17, 7]|| iter time: 0.02058100700378418
Iteration: 13|| Global Counts: [5, 17, 7]|| iter time: 0.02791905403137207
Iteration: 14|| Global Counts: [5, 17, 7]|| iter time: 0.027915

Iteration: 91|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.026715993881225586
Iteration: 92|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.034481048583984375
Iteration: 93|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.03551602363586426
Iteration: 94|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.028322935104370117
Iteration: 95|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.03683614730834961
Iteration: 96|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.028318166732788086
Iteration: 97|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.028203964233398438
Iteration: 98|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.03499293327331543
Iteration: 99|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.028522014617919922
Iteration: 100|| Global Counts: [4, 2, 4, 4, 1, 3, 10, 1, 5, 4]|| iter time: 0.02923583984375


(hdp_shared_features(model_hyper_params(niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 9.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), 100.0, 100.0, 100.0, 1.0, 1.0, 13, 7), Dict{Any,Any}(4 => local_group(model_hyper_params(niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 9.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), niw_hyperparams(1.0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0, [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), 100.0, 100.0, 100.0, 1.0, 1.0, 13, 7), [16.416400909423828 18.518800735473633 … 20.6205997467041 17.31730079650879; 0.0 15.515299797058105 … 12.112099647521973 0.0; … ; 11.01099967956543 10.010199546813965 … 12.812800407409668 15.414600372314453; 18.11840057373047 16.616600036621094 … 15.715700

#### Full Data NMI:

In [14]:
NMI = 0.0
for i=1:4
    group_labels = create_global_labels(full_data_results[1].groups_dict[i])
    NMI += mutualinfo(group_labels,labels_dict[i])
end
println(NMI/4)

0.8563517631792423
