-
Notifications
You must be signed in to change notification settings - Fork 5
/
InferenceOnTest_NPDSLDA.m~
95 lines (78 loc) · 3.14 KB
/
InferenceOnTest_NPDSLDA.m~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
function [testmodel, perfval, probassign] = InferenceOnTest_NPDSLDA(trmodel, data, MAXESTEPITER)
%INFERENCEONTEST_NPDSLDA Run the inference updates on test data and score the result.
%
%   [testmodel, perfval, probassign] = InferenceOnTest_NPDSLDA(trmodel, data, MAXESTEPITER)
%
%   Inputs
%     trmodel      - model handed to inittestmodel_NPDSLDA to build the test model
%                    (presumably the trained model -- confirm against the caller).
%     data         - test data struct. This function reads data.classlabels directly;
%                    all other fields are consumed by the helper functions.
%     MAXESTEPITER - number of outer iterations to run.
%
%   Outputs
%     testmodel  - test model after all updates have been applied.
%     perfval    - struct with fields:
%                    acc           first output of cal_accuracy_NPDSLDA
%                    wacc          acc weighted by the number of test items per class
%                    confmat       confusion matrix from cal_accuracy_NPDSLDA
%                    multiclassacc 100 * trace(confmat) / sum of all confmat entries
%                    binacc        fourth output of cal_accuracy_NPDSLDA
%     probassign - third output of cal_accuracy_NPDSLDA.
%
%   Each outer iteration performs 10 sweeps of the small-phi, zeta, a and b
%   updates and then a single mu_n update.

%% inference on test data
probassign = [];   % placeholder; overwritten by cal_accuracy_NPDSLDA below
testmodel = inittestmodel_NPDSLDA(trmodel, data);
count1 = 0;
maxvalue = -100000000000000000000;   % only referenced by the disabled likelihood checks below
disp('inference on test data starts');

while (count1 < MAXESTEPITER)
    disp('count from E-step in only inference');
    count2 = 0;
    while (count2 < 10)
        %%%%%%%%%% small-phi update %%%%%%%%%%%%%%%%
        [arg1, arg2] = prepargsforsmallphi(testmodel);
        testmodel.smallphi = update_smallphi(testmodel, data, arg1, arg2, count2);
        %disp('small_phi update done');
        % Disabled debug check: the likelihood should not decrease after the update.
        % value = likelihood_NPDSLDA(testmodel, data)
        % if (compareval(value, maxvalue))
        %     maxvalue = value;
        % else
        %     error('Incorrect after small_phi');
        % end

        %%%%%%%%%% zeta update %%%%%%%%%%%%%%%%
        [arg1, arg2, arg3] = prepargsforzeta(testmodel);
        [testmodel.zeta, testmodel.ss_lambda] = update_zeta(testmodel, data, arg1, arg2, arg3, count2);
        % sumzeta is recomputed from the freshly updated zeta.
        testmodel.sumzeta = sum_zeta(testmodel, data);
        %disp('zeta update done');
        % value = likelihood_NPDSLDA(testmodel, data)
        % if (compareval(value, maxvalue))
        %     maxvalue = value;
        % else
        %     error('Incorrect after zeta');
        % end

        %%%%%%%%%% a update %%%%%%%%%%%%%%%%
        testmodel = update_a(testmodel, data);
        %disp('a update done');
        % if(count1==0 && count2==0)
        % else
        %     value = likelihood_NPDSLDA(testmodel, data)
        %     if (compareval(value, maxvalue))
        %         maxvalue = value;
        %     else
        %         error('Incorrect after a');
        %     end
        % end

        %%%%%%%%%% b update %%%%%%%%%%%%%%%%
        testmodel = update_b(testmodel);
        %disp('b update done');
        % if(count1==0 && count2==0)
        % else
        %     value = likelihood_NPDSLDA(testmodel, data)
        %     if (compareval(value, maxvalue))
        %         maxvalue = value;
        %     else
        %         error('Incorrect after b');
        %     end
        % end

        count2 = count2 + 1;
    end

    %%%%%%%%%% mu_n update (once per outer iteration) %%%%%%%%%%%%%%%%
    testmodel = update_mun(testmodel, data);
    %disp('mu_n update done');
    % value = likelihood_NPDSLDA(testmodel, data)
    % if (compareval(value, maxvalue))
    %     maxvalue = value;
    % else
    %     error('Incorrect after mu_n');
    % end

    count1 = count1 + 1;
end

%% performance measures
[accval, confmat, probassign, binacc] = cal_accuracy_NPDSLDA(testmodel, data);
perfval.acc = accval;

% Number of test items per distinct class label, ordered by sorted label value.
% Counting through the index map returned by unique is correct regardless of
% whether classlabels is sorted or is a row/column vector. The previous
% first-occurrence-difference approach required sorted labels and undercounted
% the last class by one.
[uniqlabels, firstidx, classidx] = unique(data.classlabels); %#ok<ASGLU>
clscount = accumarray(classidx(:), 1);

% NOTE(review): assumes accval is a column vector with one entry per distinct
% label in data.classlabels, in sorted-label order -- confirm against
% cal_accuracy_NPDSLDA.
perfval.wacc = 100*(perfval.acc')*clscount/sum(clscount);
perfval.confmat = confmat;
perfval.multiclassacc = 100*sum(diag(confmat))/sum(sum(confmat));
perfval.binacc = binacc;
end