## vHDPMM as HDPMM

We generate data from a Chinese Restaurant Franchise (CRF) prior over an infinite mixture of Gaussians.

Results may differ slightly from those reported in the paper, because the data generation is random.

In [4]:
# Distributed (standard library) provides `addprocs` and `@everywhere`.
using Distributed
# Presumably the source of `mutualinfo`, which `results_stats` calls.
using Clustering
# Launch 4 additional worker processes.
addprocs(4)
# Load the package on every process, not only on the one running this cell.
@everywhere using VersatileHDPMixtureModels

In [2]:
"""
    generate_data(dim, groups_count, sample_count, var, α = 10, γ = 1)

Generate a synthetic grouped data set for the experiments in this notebook.

A draw is taken with `hdp_prior_crf_draws(sample_count, groups_count, α, γ)`; its second
element is passed, together with `dim` and `var`, to
`generate_grouped_gaussian_from_hdp_group_counts`.

# Returns
The `(pts, labels)` pair produced by `generate_grouped_gaussian_from_hdp_group_counts`.
"""
function generate_data(dim, groups_count, sample_count, var, α = 10, γ = 1)
    draws = hdp_prior_crf_draws(sample_count, groups_count, α, γ)
    # Only the second element of the draw is needed by the generator.
    group_counts = draws[2]
    points, point_labels =
        generate_grouped_gaussian_from_hdp_group_counts(group_counts, dim, var)
    return points, point_labels
end

"""
    results_stats(pred_dict, gt_dict)

Return the mean of `mutualinfo(pred_dict[i], gt_dict[i])` over `i = 1:length(pred_dict)`.

# Arguments
- `pred_dict`: collection of predicted labelings, indexable by `1:length(pred_dict)`.
- `gt_dict`: collection of reference labelings, indexable by the same integers.

# Returns
The average score as a floating-point number. For an empty `pred_dict` the result is
`NaN` (`0.0 / 0`).
"""
function results_stats(pred_dict, gt_dict)
    group_count = length(pred_dict)
    # Start from a float so the accumulator does not change type on the first `+=`.
    avg_nmi = 0.0
    for i in 1:group_count
        avg_nmi += mutualinfo(pred_dict[i], gt_dict[i])
    end
    return avg_nmi / group_count
end


"""
    run_and_compare(pts, labels, gdim, iters = 100)

Fit a model to `pts` with `hdp_fit` and score its global predictions against `labels`.

# Arguments
- `pts`: the data passed to `hdp_fit`.
- `labels`: reference labels, compared group by group with the model's predictions.
- `gdim`: dimension passed to `create_default_priors`.
- `iters`: iteration count passed to `hdp_fit` (default `100`).

# Returns
The average score computed by `results_stats`.
"""
function run_and_compare(pts, labels, gdim, iters = 100)
    # Only the first prior returned here is used for the fit; the second is discarded.
    gprior, _ = create_default_priors(gdim, 0, :niw)
    # NOTE(review): 10 and 1 are presumably the α and γ concentration parameters,
    # matching the defaults of `generate_data` — confirm against the package docs.
    model = hdp_fit(pts, 10, 1, gprior, iters)
    model_results = get_model_global_pred(model[1])
    # Predictions first, reference labels second, as in `results_stats(pred_dict, gt_dict)`.
    return results_stats(model_results, labels)
end

run_and_compare (generic function with 2 methods)

### D:3, N:100

In [5]:
# dim = 3, groups_count = 4, sample_count = 100, var = 100.0; α and γ keep their defaults (10, 1).
gdim = 3
pts,labels = generate_data(gdim,4,100,100.0)
# `iters` is left at its default of 100; the returned score is the value displayed by this cell.
run_and_compare(pts,labels,gdim)

Iteration: 1|| Global Counts: [4]|| iter time: 0.05887603759765625
Iteration: 2|| Global Counts: [4]|| iter time: 0.011872053146362305
Iteration: 3|| Global Counts: [4]|| iter time: 0.010411977767944336
Iteration: 4|| Global Counts: [4]|| iter time: 0.013432025909423828
Iteration: 5|| Global Counts: [4, 4]|| iter time: 0.006949186325073242
Iteration: 6|| Global Counts: [4, 4]|| iter time: 0.008085966110229492
Iteration: 7|| Global Counts: [4, 4]|| iter time: 0.008291006088256836
Iteration: 8|| Global Counts: [4, 4]|| iter time: 0.00836491584777832
Iteration: 9|| Global Counts: [4, 4]|| iter time: 0.009179115295410156
Iteration: 10|| Global Counts: [4, 4, 4]|| iter time: 0.009495973587036133
Iteration: 11|| Global Counts: [4, 4, 5]|| iter time: 0.010766029357910156
Iteration: 12|| Global Counts: [6, 4, 5]|| iter time: 0.015597105026245117
Iteration: 13|| Global Counts: [6, 4, 6]|| iter time: 0.013074159622192383
Iteration: 14|| Global Counts: [6, 4, 5]|| iter time: 0.013733148574829102


0.8627687962673293

### D:8, N:5000

In [6]:
# dim = 8, groups_count = 4, sample_count = 5000, var = 15.0; α and γ keep their defaults (10, 1).
gdim = 8
pts,labels = generate_data(gdim,4,5000,15.0)
# `iters` is left at its default of 100; the returned score is the value displayed by this cell.
run_and_compare(pts,labels,gdim)

Iteration: 1|| Global Counts: [4]|| iter time: 0.05582594871520996
Iteration: 2|| Global Counts: [4]|| iter time: 0.01994800567626953
Iteration: 3|| Global Counts: [4]|| iter time: 0.024904966354370117
Iteration: 4|| Global Counts: [4]|| iter time: 0.02165985107421875
Iteration: 5|| Global Counts: [4, 4]|| iter time: 0.023563146591186523
Iteration: 6|| Global Counts: [4, 4]|| iter time: 0.03844785690307617
Iteration: 7|| Global Counts: [4, 4]|| iter time: 0.03132796287536621
Iteration: 8|| Global Counts: [4, 4]|| iter time: 0.029201984405517578
Iteration: 9|| Global Counts: [4, 4]|| iter time: 0.02880096435546875
Iteration: 10|| Global Counts: [4, 4, 4, 4]|| iter time: 0.037361860275268555
Iteration: 11|| Global Counts: [4, 4, 4, 4]|| iter time: 0.03480410575866699
Iteration: 12|| Global Counts: [4, 4, 4, 4]|| iter time: 0.0396578311920166
Iteration: 13|| Global Counts: [4, 4, 4, 4]|| iter time: 0.03880500793457031
Iteration: 14|| Global Counts: [4, 4, 4, 4]|| iter time: 0.035665988922

0.923011870101779

### D:15, N:50000

In [7]:
# dim = 15, groups_count = 4, sample_count = 50000, var = 5.0, with explicit α = 10 and γ = 0.5.
gdim = 15
pts,labels = generate_data(gdim,4,50000,5.0,10,0.5)
# `iters` is left at its default of 100; the returned score is the value displayed by this cell.
run_and_compare(pts,labels,gdim)

Iteration: 1|| Global Counts: [4]|| iter time: 0.25215816497802734
Iteration: 2|| Global Counts: [4]|| iter time: 0.3607611656188965
Iteration: 3|| Global Counts: [4]|| iter time: 0.19218683242797852
Iteration: 4|| Global Counts: [4]|| iter time: 0.2593369483947754
Iteration: 5|| Global Counts: [4, 4]|| iter time: 0.19620800018310547
Iteration: 6|| Global Counts: [4, 4]|| iter time: 0.244581937789917
Iteration: 7|| Global Counts: [4, 4]|| iter time: 0.3289511203765869
Iteration: 8|| Global Counts: [4, 4]|| iter time: 0.2682380676269531
Iteration: 9|| Global Counts: [4, 4]|| iter time: 0.2240128517150879
Iteration: 10|| Global Counts: [4, 4, 4, 4]|| iter time: 0.2401740550994873
Iteration: 11|| Global Counts: [4, 4, 4, 4]|| iter time: 0.471937894821167
Iteration: 12|| Global Counts: [4, 4, 4, 4]|| iter time: 0.4046669006347656
Iteration: 13|| Global Counts: [4, 4, 4, 4]|| iter time: 0.3314208984375
Iteration: 14|| Global Counts: [4, 4, 4, 4]|| iter time: 0.5108201503753662
Iteration: 15

0.9999339048960051