## vHDPMM as HDPMM

We generate data from a Chinese Restaurant Franchise (CRF) prior over an infinite mixture of Gaussians.

Results may vary a bit from the paper, as the data generation is random.

In [1]:
# Notebook setup: bring up worker processes and load the project sources.
using Distributed
# Launch 4 additional worker processes.
addprocs(4)
# Change the working directory to the parent folder on every process
# (master and workers), so the relative paths below resolve from there.
# NOTE(review): presumably the notebook lives one level below the source
# files — confirm against the repository layout.
@everywhere cd("../")
# These includes are evaluated on the calling process only (no @everywhere);
# any worker-side loading has to happen inside the included files — TODO confirm.
include("hdp_shared_features.jl")
include("gaussian_generator.jl")

generate_grouped_gaussian_from_hdp_group_counts (generic function with 1 method)

In [2]:
"""
    generate_data(dim, groups_count, sample_count, var, α = 10, γ = 1)

Generate a synthetic grouped data set and return it as `(pts, labels)`.

`hdp_prior_crf_draws(sample_count, groups_count, α, γ)` is called first; the
second element of its result is then handed, together with `dim` and `var`, to
`generate_grouped_gaussian_from_hdp_group_counts`, whose first two outputs are
returned.

NOTE(review): judging by the helper names, the second element of the prior
draw holds per-group counts and `var` is a variance for the generated
Gaussians — confirm against the helper definitions.
"""
function generate_data(dim, groups_count, sample_count, var, α = 10, γ = 1)
    prior_draw = hdp_prior_crf_draws(sample_count, groups_count, α, γ)
    # Only the second component of the prior draw feeds the data generator.
    group_counts = prior_draw[2]
    points, point_labels =
        generate_grouped_gaussian_from_hdp_group_counts(group_counts, dim, var)
    return (points, point_labels)
end

"""
    results_stats(pred_dict, gt_dict)

Return the mean of `mutualinfo(pred_dict[i], gt_dict[i])` over
`i = 1:length(pred_dict)`.

Both containers must be indexable by every integer in `1:length(pred_dict)`;
entries of `gt_dict` beyond that range are ignored. An empty `pred_dict`
yields `NaN` (floating-point `0 / 0`).

NOTE(review): `mutualinfo` is not defined in this notebook; presumably it is
the clustering-agreement score brought in by the included sources — confirm.
"""
function results_stats(pred_dict, gt_dict)
    group_count = length(pred_dict)
    # Start from a Float64 so the accumulator does not change type from Int to
    # floating point on the first addition (keeps the loop type-stable).
    total_nmi = 0.0
    for i in 1:group_count
        total_nmi += mutualinfo(pred_dict[i], gt_dict[i])
    end
    return total_nmi / group_count
end


"""
    run_and_compare(pts, labels, gdim, iters = 100)

Fit the model to `pts` and score the fit against `labels`.

Default priors are built with `create_default_priors(gdim, 0, :niw)`, the
model is fitted with `hdp_fit(pts, 10, 1, <first prior>, iters)`, global
predictions are read from the first element of the fit result via
`get_model_global_pred`, and the value of `results_stats(labels, predictions)`
is returned.

NOTE(review): the literals `10` and `1` passed to `hdp_fit` look like the
concentration hyperparameters (they match the `α`/`γ` defaults of
`generate_data`) — confirm against `hdp_fit`'s signature.
"""
function run_and_compare(pts, labels, gdim, iters = 100)
    # The second prior returned here is not used by this function.
    global_prior, local_prior = create_default_priors(gdim, 0, :niw)
    fit_result = hdp_fit(pts, 10, 1, global_prior, iters)
    predicted_labels = get_model_global_pred(fit_result[1])
    # `labels` goes in as the first argument of results_stats and the model
    # output as the second, exactly as in the original call.
    return results_stats(labels, predicted_labels)
end

run_and_compare (generic function with 2 methods)

### D:3, N:100

In [3]:
# Experiment: dim = 3, groups_count = 4, sample_count = 100, var = 100.0;
# α and γ fall back to generate_data's defaults (10 and 1).
gdim = 3
pts,labels = generate_data(gdim,4,100,100.0)
# Fit with the default 100 iterations; the cell's value is the score
# returned by run_and_compare.
run_and_compare(pts,labels,gdim)

│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:36
└ @ Distributions /home/dinari/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:36
│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:38
└ @ Distributions /home/dinari/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:38
│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:36
└ @ Distributions ~/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:36
│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:38
└ @ Distributions ~/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:38
│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:36
└ @ Distributions ~/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:36
│   caller = Dirichlet{Float64}(::Array{Float64,1}) at dirichlet.jl:38
└ @ Distributions ~/.julia/packages/Distributions/0Wogo/src/multivariate/dirichlet.jl:38
│   caller = Diric

Iteration: 1|| Global Counts: [4]|| iter time: 21.13436484336853
Iteration: 2|| Global Counts: [4]|| iter time: 0.0142669677734375
Iteration: 3|| Global Counts: [4]|| iter time: 0.011148929595947266
Iteration: 4|| Global Counts: [4]|| iter time: 0.011926889419555664
Iteration: 5|| Global Counts: [4]|| iter time: 0.19217491149902344
Iteration: 6|| Global Counts: [4]|| iter time: 0.013338088989257812
Iteration: 7|| Global Counts: [4]|| iter time: 0.011987924575805664
Iteration: 8|| Global Counts: [4]|| iter time: 0.011352062225341797
Iteration: 9|| Global Counts: [4]|| iter time: 0.010693073272705078
Iteration: 10|| Global Counts: [4]|| iter time: 0.011907100677490234
Iteration: 11|| Global Counts: [4]|| iter time: 0.011569023132324219
Iteration: 12|| Global Counts: [4]|| iter time: 0.011690139770507812
Iteration: 13|| Global Counts: [4]|| iter time: 0.015639781951904297
Iteration: 14|| Global Counts: [4]|| iter time: 0.015273094177246094
Iteration: 15|| Global Counts: [4]|| iter time: 0

Iteration: 95|| Global Counts: [7, 7, 9, 5, 7, 4, 13, 9, 7, 7, 1, 5, 14, 5, 3]|| iter time: 0.09143900871276855
Iteration: 96|| Global Counts: [10, 8, 11, 8, 8, 4, 15, 9, 10, 6, 1, 5, 9, 5, 3]|| iter time: 0.08422017097473145
Iteration: 97|| Global Counts: [8, 6, 10, 6, 9, 3, 14, 9, 7, 7, 1, 6, 6, 5, 3]|| iter time: 0.05897402763366699
Iteration: 98|| Global Counts: [8, 6, 7, 6, 10, 4, 14, 8, 8, 8, 1, 6, 6, 4, 3]|| iter time: 0.054937124252319336
Iteration: 99|| Global Counts: [8, 6, 5, 6, 11, 5, 16, 9, 9, 7, 1, 4, 8, 7, 2]|| iter time: 0.06593585014343262
Iteration: 100|| Global Counts: [9, 8, 6, 7, 11, 3, 19, 8, 11, 8, 1, 4, 10, 3, 3, 11]|| iter time: 0.06609892845153809


0.91795921383425

### D:8, N:5000

In [12]:
# Experiment: dim = 8, groups_count = 4, sample_count = 5000, var = 15.0;
# α and γ fall back to generate_data's defaults (10 and 1).
gdim = 8
pts,labels = generate_data(gdim,4,5000,15.0)
# Fit with the default 100 iterations; the cell's value is the score
# returned by run_and_compare.
run_and_compare(pts,labels,gdim)

Iteration: 1|| Global Counts: [4]|| iter time: 0.01793694496154785
Iteration: 2|| Global Counts: [4]|| iter time: 0.01627492904663086
Iteration: 3|| Global Counts: [4]|| iter time: 0.016991853713989258
Iteration: 4|| Global Counts: [4]|| iter time: 0.02559494972229004
Iteration: 5|| Global Counts: [4, 4]|| iter time: 0.03721117973327637
Iteration: 6|| Global Counts: [4, 4]|| iter time: 0.028136014938354492
Iteration: 7|| Global Counts: [4, 4]|| iter time: 0.024547100067138672
Iteration: 8|| Global Counts: [4, 4]|| iter time: 0.03251194953918457
Iteration: 9|| Global Counts: [4, 4]|| iter time: 0.027008771896362305
Iteration: 10|| Global Counts: [4, 4, 4, 4]|| iter time: 0.030237913131713867
Iteration: 11|| Global Counts: [4, 4, 4, 4]|| iter time: 0.03408694267272949
Iteration: 12|| Global Counts: [4, 4, 4, 4]|| iter time: 0.04360699653625488
Iteration: 13|| Global Counts: [4, 4, 4, 4]|| iter time: 0.02822399139404297
Iteration: 14|| Global Counts: [4, 4, 4, 4]|| iter time: 0.0406658649

Iteration: 66|| Global Counts: [4, 5, 4, 4, 4, 3, 3, 3, 1, 3, 5, 2, 9, 2, 2, 1, 1, 5, 1, 8, 4, 1, 1, 4, 1, 1, 1, 1, 1, 15, 4, 3, 7, 1, 1, 3, 13, 1, 1, 1, 7, 1, 4, 5, 1, 3, 4, 3, 3, 2, 3, 4]|| iter time: 0.1390681266784668
Iteration: 67|| Global Counts: [4, 6, 4, 4, 5, 3, 3, 3, 1, 3, 5, 2, 8, 2, 2, 1, 1, 5, 1, 8, 4, 1, 1, 3, 1, 1, 1, 1, 1, 8, 7, 2, 7, 1, 1, 3, 14, 1, 1, 1, 7, 1, 3, 6, 1, 2, 4, 3, 3, 2, 3, 4, 2]|| iter time: 0.16808581352233887
Iteration: 68|| Global Counts: [4, 4, 4, 4, 5, 3, 3, 3, 1, 3, 7, 2, 7, 3, 2, 1, 1, 5, 1, 8, 4, 1, 1, 3, 1, 1, 1, 1, 1, 8, 5, 1, 8, 1, 1, 3, 12, 1, 1, 1, 6, 1, 3, 5, 1, 2, 4, 3, 3, 1, 3, 4, 1]|| iter time: 0.13661503791809082
Iteration: 69|| Global Counts: [2, 5, 4, 5, 4, 3, 3, 3, 1, 3, 8, 2, 7, 2, 2, 1, 1, 6, 1, 7, 4, 1, 1, 3, 1, 1, 1, 1, 1, 6, 5, 1, 9, 1, 1, 3, 14, 1, 1, 1, 8, 1, 2, 5, 1, 2, 4, 2, 3, 1, 3, 6, 1]|| iter time: 0.13504600524902344
Iteration: 70|| Global Counts: [4, 4, 5, 5, 3, 3, 3, 1, 3, 8, 2, 9, 2, 2, 1, 1, 5, 1, 10, 4, 1, 1, 4, 1

Iteration: 99|| Global Counts: [5, 4, 4, 5, 3, 3, 3, 1, 3, 6, 2, 8, 3, 2, 1, 1, 4, 1, 4, 4, 1, 1, 5, 1, 1, 1, 1, 1, 8, 2, 4, 8, 1, 1, 3, 1, 1, 1, 1, 9, 1, 3, 3, 1, 3, 4, 3, 3, 1, 2, 8, 2, 3, 2, 1, 4, 2, 1, 5, 4, 2, 3, 5, 2, 4, 5, 4, 6, 2, 2, 4, 1]|| iter time: 0.19680094718933105
Iteration: 100|| Global Counts: [6, 4, 4, 4, 3, 3, 3, 1, 3, 6, 2, 7, 3, 2, 1, 1, 4, 1, 4, 4, 1, 1, 5, 1, 1, 1, 1, 1, 5, 2, 3, 7, 1, 1, 3, 1, 1, 1, 1, 9, 1, 3, 2, 1, 2, 4, 2, 3, 1, 2, 8, 1, 4, 4, 1, 3, 2, 2, 8, 4, 1, 3, 5, 2, 5, 5, 4, 5, 1, 2, 5, 2, 4, 5]|| iter time: 0.18584704399108887
11.334589719772339


0.962523875682653

### D:15, N:50000

In [13]:
# Experiment: dim = 15, groups_count = 4, sample_count = 50000, var = 5.0,
# with α = 10 and γ = 0.5 passed explicitly to generate_data.
gdim = 15
pts,labels = generate_data(gdim,4,50000,5.0,10,0.5)
# Fit with the default 100 iterations; run_and_compare itself always passes
# the literals 10 and 1 to hdp_fit, regardless of the γ used above.
run_and_compare(pts,labels,gdim)

Iteration: 1|| Global Counts: [4]|| iter time: 0.22115683555603027
Iteration: 2|| Global Counts: [4]|| iter time: 0.21135401725769043
Iteration: 3|| Global Counts: [4]|| iter time: 0.25061893463134766
Iteration: 4|| Global Counts: [4]|| iter time: 0.1435379981994629
Iteration: 5|| Global Counts: [4, 4]|| iter time: 0.3407571315765381
Iteration: 6|| Global Counts: [4, 4]|| iter time: 0.20615506172180176
Iteration: 7|| Global Counts: [4, 4]|| iter time: 0.1717989444732666
Iteration: 8|| Global Counts: [4, 4]|| iter time: 0.17197489738464355
Iteration: 9|| Global Counts: [4, 4]|| iter time: 0.16083312034606934
Iteration: 10|| Global Counts: [4, 4, 4, 4]|| iter time: 0.16608214378356934
Iteration: 11|| Global Counts: [4, 4, 4, 4]|| iter time: 0.20890498161315918
Iteration: 12|| Global Counts: [4, 4, 4, 4]|| iter time: 0.22586393356323242
Iteration: 13|| Global Counts: [4, 4, 4, 4]|| iter time: 0.20958709716796875
Iteration: 14|| Global Counts: [4, 4, 4, 4]|| iter time: 0.21988511085510254


Iteration: 66|| Global Counts: [4, 4, 4, 5, 4, 2, 1, 5, 6, 2, 2, 3, 1, 3, 6, 1, 2, 1, 12, 1, 1, 2, 6, 3, 1, 2, 2, 1, 2, 3, 5, 2, 2, 4, 3, 1, 2, 1, 2, 2, 1, 4, 9, 4, 2, 5, 4, 1, 2, 2, 2, 3, 1, 1, 1, 1, 1, 4, 4, 4, 3, 1, 1]|| iter time: 1.6176300048828125
Iteration: 67|| Global Counts: [4, 4, 4, 5, 4, 2, 1, 5, 5, 2, 2, 3, 1, 3, 5, 1, 2, 1, 13, 1, 1, 2, 6, 3, 1, 2, 2, 1, 2, 3, 4, 2, 2, 4, 3, 1, 2, 1, 3, 2, 1, 5, 7, 4, 2, 5, 5, 1, 2, 1, 2, 4, 1, 1, 1, 1, 1, 5, 5, 3, 3, 1, 2]|| iter time: 1.558840036392212
Iteration: 68|| Global Counts: [4, 4, 4, 6, 4, 2, 1, 4, 7, 2, 2, 3, 1, 3, 6, 1, 2, 2, 10, 1, 1, 2, 5, 4, 1, 2, 2, 1, 2, 3, 4, 2, 2, 4, 3, 1, 2, 1, 2, 2, 1, 5, 8, 4, 2, 6, 7, 1, 2, 1, 2, 4, 1, 1, 1, 1, 1, 4, 4, 3, 3, 1, 3, 5]|| iter time: 1.6090800762176514
Iteration: 69|| Global Counts: [4, 4, 4, 6, 4, 2, 1, 5, 5, 2, 2, 3, 1, 3, 6, 1, 2, 2, 10, 1, 1, 2, 5, 3, 1, 2, 2, 1, 2, 3, 5, 2, 2, 4, 3, 1, 2, 1, 2, 2, 1, 4, 9, 3, 2, 6, 8, 1, 2, 1, 2, 3, 1, 1, 1, 1, 1, 4, 4, 3, 3, 1, 4, 3]|| iter time

Iteration: 95|| Global Counts: [4, 4, 5, 5, 4, 2, 1, 4, 3, 2, 2, 3, 1, 3, 4, 1, 2, 2, 9, 1, 1, 2, 5, 4, 1, 2, 2, 1, 3, 3, 2, 2, 2, 4, 3, 1, 2, 1, 1, 2, 1, 1, 6, 3, 2, 1, 9, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 3, 4, 7, 3, 3, 4, 9, 2, 4, 2, 4, 1, 2, 4, 1, 4, 1, 5, 11, 1, 1, 1, 1]|| iter time: 1.9269239902496338
Iteration: 96|| Global Counts: [4, 4, 4, 6, 4, 2, 1, 5, 4, 2, 2, 3, 1, 3, 5, 1, 2, 2, 8, 1, 1, 2, 4, 4, 1, 2, 2, 1, 3, 3, 3, 2, 2, 4, 3, 1, 2, 1, 1, 2, 1, 1, 7, 3, 2, 1, 9, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 5, 3, 7, 3, 4, 4, 9, 2, 3, 2, 4, 1, 3, 4, 1, 3, 1, 4, 14, 1, 1, 1, 1]|| iter time: 1.851741075515747
Iteration: 97|| Global Counts: [4, 4, 4, 6, 4, 2, 1, 5, 3, 2, 2, 3, 1, 3, 5, 1, 2, 2, 4, 1, 1, 2, 5, 3, 1, 2, 2, 1, 3, 3, 3, 2, 2, 4, 3, 1, 2, 1, 1, 2, 1, 1, 6, 3, 2, 1, 12, 1, 2, 1, 2, 1, 1, 1, 1, 1, 3, 3, 2, 7, 3, 4, 4, 8, 3, 5, 2, 5, 1, 3, 4, 1, 5, 1, 5, 14, 1, 1, 1, 1]|| iter time: 1.8545219898223877
Iteration: 98|| Global Counts: [4, 4, 4, 4, 4, 2, 1, 4, 4, 2, 2, 3, 1, 3, 4, 1, 2, 2,

0.9039166624488165