Skip to content

Commit

Permalink
Added option to run a mixture of distribution (indenpendent observati…
Browse files Browse the repository at this point in the history
…ons)
  • Loading branch information
vidaurre committed Oct 7, 2019
1 parent 90ef635 commit f7c8e9f
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 79 deletions.
26 changes: 14 additions & 12 deletions eval/GammaEntropy.m
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
function Entr = GammaEntropy(Gamma,Xi,T,order)
% Entropy of the state time courses
Entr = 0; K = size(Gamma,2);
for tr=1:length(T)
for tr = 1:length(T)
t = sum(T(1:tr-1)) - (tr-1)*order + 1;
Gamma_nz = Gamma(t,:);
Gamma_nz(Gamma_nz==0) = realmin;
if any(isinf(log(Gamma_nz(:)))), Gamma_nz(Gamma_nz==0) = eps; end
Entr = Entr - sum(Gamma_nz.*log(Gamma_nz));
t = (sum(T(1:tr-1)) - (tr-1)*(order+1) + 1) : ((sum(T(1:tr)) - tr*(order+1)));
Xi_nz = Xi(t,:,:);
Xi_nz(Xi_nz==0) = realmin;
if any(isinf(log(Xi_nz(:)))), Xi_nz(Xi_nz==0) = eps; end
Psi=zeros(size(Xi_nz)); % P(S_t|S_t-1)
for k = 1:K
sXi = sum(permute(Xi_nz(:,k,:),[1 3 2]),2);
Psi(:,k,:) = Xi_nz(:,k,:)./repmat(sXi,[1 1 K]);
if ~isempty(Xi)
t = (sum(T(1:tr-1)) - (tr-1)*(order+1) + 1) : ((sum(T(1:tr)) - tr*(order+1)));
Xi_nz = Xi(t,:,:);
Xi_nz(Xi_nz==0) = realmin;
if any(isinf(log(Xi_nz(:)))), Xi_nz(Xi_nz==0) = eps; end
Psi = zeros(size(Xi_nz)); % P(S_t|S_t-1)
for k = 1:K
sXi = sum(permute(Xi_nz(:,k,:),[1 3 2]),2);
Psi(:,k,:) = Xi_nz(:,k,:)./repmat(sXi,[1 1 K]);
end
Psi(Psi==0) = realmin;
if any(isinf(log(Psi(:)))), Psi(Psi==0) = eps; end
Entr = Entr - sum(Xi_nz(:).*log(Psi(:))); % entropy of hidden states
end
Psi(Psi==0) = realmin;
if any(isinf(log(Psi(:)))), Psi(Psi==0) = eps; end
Entr = Entr - sum(Xi_nz(:).*log(Psi(:))); % entropy of hidden states
end
if isnan(Entr(:))
error(['Error computing entropy of the state time courses - ' ...
Expand Down
87 changes: 51 additions & 36 deletions eval/GammaavLL.m
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
function avLL = GammaavLL(hmm,Gamma,Xi,T)
% average loglikelihood for state time course
% if isfield(hmm.train,'grouping')

% if isfield(hmm.train,'grouping') % DEPRECATED
% Q = length(unique(hmm.train.grouping));
% else
% Q = 1;
% end
Q = 1;
N = length(T);
Q = 1;
N = length(T);
order = (sum(T) - size(Gamma,1))/N;
avLL = 0; K = size(Gamma,2);
jj = zeros(N,1); % reference to first time point of the segments
for in = 1:N
jj(in) = sum(T(1:in-1)) - order*(in-1) + 1;
end
% avLL initial state

% avLL initial state % DEPRECATED
% if Q>1
% for i = 1:Q
% PsiDir_alphasum = psi(sum(hmm.Dir_alpha(:,i)));
Expand All @@ -25,43 +23,60 @@
% end
% else
% PsiDir_alphasum = psi(sum(hmm.Dir_alpha));
% for l = 1:K
% for l = 1:K
% if ~hmm.train.Pistructure(l), continue; end
% avLL = avLL + sum(Gamma(jj,l)) * (psi(hmm.Dir_alpha(l)) - PsiDir_alphasum);
% end
% end
PsiDir_alphasum = psi(sum(hmm.Dir_alpha));
for l = 1:K
if ~hmm.train.Pistructure(l), continue; end
avLL = avLL + sum(Gamma(jj,l)) * (psi(hmm.Dir_alpha(l)) - PsiDir_alphasum);
end
% avLL remaining time points
for i = 1:Q
if Q > 1
ii = find(hmm.train.grouping==i)';
else
ii = 1:length(T);

if ~isempty(Xi) % a proper HMM

jj = zeros(N,1); % reference to first time point of the segments
for in = 1:N
jj(in) = sum(T(1:in-1)) - order*(in-1) + 1;
end
PsiDir2d_alphasum = zeros(K,1);
for l = 1:K, PsiDir2d_alphasum(l) = psi(sum(hmm.Dir2d_alpha(l,:,i))); end
for k = 1:K
for l = 1:K
if ~hmm.train.Pstructure(l,k), continue; end
if Q==1
avLL = avLL + sum(Xi(:,l,k)) * (psi(hmm.Dir2d_alpha(l,k))-PsiDir2d_alphasum(l));
if isnan(avLL)
error(['Error computing log likelihood of the state time courses - ' ...
'Out of precision?'])
end
else
for n = ii
t = (1:T(n)-1-order) + sum(T(1:n-1)) - (order+1)*(n-1) ;
avLL = avLL + sum(Xi(t,l,k)) * ...
(psi(hmm.Dir2d_alpha(l,k,i))-PsiDir2d_alphasum(l));
PsiDir_alphasum = psi(sum(hmm.Dir_alpha));
% first time point
for l = 1:K
if ~hmm.train.Pistructure(l), continue; end
avLL = avLL + sum(Gamma(jj,l)) * (psi(hmm.Dir_alpha(l)) - PsiDir_alphasum);
end
% avLL remaining time points
for i = 1:Q
if Q > 1
ii = find(hmm.train.grouping==i)';
else
ii = 1:length(T);
end
PsiDir2d_alphasum = zeros(K,1);
for l = 1:K, PsiDir2d_alphasum(l) = psi(sum(hmm.Dir2d_alpha(l,:,i))); end
for k = 1:K
for l = 1:K
if ~hmm.train.Pstructure(l,k), continue; end
if Q==1
avLL = avLL + sum(Xi(:,l,k)) * (psi(hmm.Dir2d_alpha(l,k))-PsiDir2d_alphasum(l));
if isnan(avLL)
error(['Error computing log likelihood of the state time courses - ' ...
'Out of precision?'])
end
else
for n = ii
t = (1:T(n)-1-order) + sum(T(1:n-1)) - (order+1)*(n-1) ;
avLL = avLL + sum(Xi(t,l,k)) * ...
(psi(hmm.Dir2d_alpha(l,k,i))-PsiDir2d_alphasum(l));
end
end
end
end
end

else % Simple mixture of distributions

PsiDir_alphasum = psi(sum(hmm.Dir_alpha));
for k = 1:K
avLL = avLL + sum(Gamma(:,k)) * (psi(hmm.Dir_alpha(k)) - PsiDir_alphasum);
end

end

end
10 changes: 6 additions & 4 deletions eval/evalfreeenergy.m
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

if nargin<8, todo = ones(1,5); end

mixture_model = isfield(hmm.train,'id_mixture') && hmm.train.id_mixture;

K = length(hmm.state);
if (nargin<7 || isempty(XX)) && todo(2)==1
setxx; % build XX and get orders
Expand Down Expand Up @@ -86,7 +88,7 @@
hmm.Omega.Gam_shape,hmm.prior.Omega.Gam_shape);
KLdiv = [KLdiv OmegaKL];
end
for k=1:K
for k = 1:K
hs=hmm.state(k);
pr=hmm.state(k).prior;
setstateoptions;
Expand Down Expand Up @@ -240,11 +242,11 @@
C = hmm.Omega.Gam_shape * hmm.Omega.Gam_irate;
avLL = avLL + (-ltpi-ldetWishB+PsiWish_alphasum+0.5*sum(regressed)*log(2));
end
for k=1:K
hs=hmm.state(k);
for k = 1:K
hs = hmm.state(k);
setstateoptions;
if strcmp(train.covtype,'diag')
ldetWishB=0;
ldetWishB = 0;
PsiWish_alphasum=0;
for n=1:ndim
if ~regressed(n), continue; end
Expand Down
15 changes: 12 additions & 3 deletions hmmfe.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

% to fix potential compatibility issues with previous versions
hmm = versCompatibilityFix(hmm);
mixture_model = isfield(hmm.train,'id_mixture') && hmm.train.id_mixture;

if nargin<4, Gamma = []; end
if nargin<5, Xi = []; end
if nargin<6 || isempty(preproc), preproc = 1; end
if nargin<7 , grouping = ones(length(T),1); end
if size(grouping,1)==1, grouping = grouping'; end
Expand Down Expand Up @@ -115,8 +118,10 @@
end

% get state time courses
if nargin < 5 || isempty(Gamma) || isempty(Xi)
[Gamma,Xi] = hmmdecode(data,T,hmm,0,residuals,0);
if isempty(Gamma) || isempty(Xi)
if ~(mixture_model && ~isempty(Gamma)) % we have Gamma and Xi is not needed
[Gamma,Xi] = hmmdecode(data,T,hmm,0,residuals,0);
end
end

if stochastic_learn
Expand All @@ -143,7 +148,11 @@
t = (1:(sum(Ti)-length(Ti)*maxorder)) + tacc;
t2 = (1:(sum(Ti)-length(Ti)*(maxorder+1))) + tacc2;
tacc = tacc + length(t); tacc2 = tacc2 + length(t2);
fell = fell + sum(evalfreeenergy(X,Ti,Gamma(t,:),Xi(t2,:,:),hmm,residuals,XX,[0 1 0 0 0])); % state KL
if ~isempty(Xi)
fell = fell + sum(evalfreeenergy(X,Ti,Gamma(t,:),Xi(t2,:,:),hmm,residuals,XX,[0 1 0 0 0])); % state KL
else
fell = fell + sum(evalfreeenergy(X,Ti,Gamma(t,:),[],hmm,residuals,XX,[0 1 0 0 0])); % state KL
end
end
fe = fe + fell;
else
Expand Down
Loading

0 comments on commit f7c8e9f

Please sign in to comment.