Skip to content

Commit

Permalink
fix linRegPred knRegPred
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Mar 11, 2017
1 parent 38c3f9b commit 23dc01d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 21 deletions.
28 changes: 10 additions & 18 deletions chapter02/logGauss.m
Expand Up @@ -7,21 +7,13 @@
% Output:
% y: 1 x n probability density in logrithm scale y=log p(x)
% Written by Mo Chen (sth4nth@gmail.com).
[d,k] = size(mu);
if all(size(sigma)==d) && k==1 % one mu and one dxd sigma
X = bsxfun(@minus,X,mu);
[R,p]= chol(sigma);
if p ~= 0
error('ERROR: sigma is not PD.');
end
Q = R'\X;
q = dot(Q,Q,1); % quadratic term (M distance)
c = d*log(2*pi)+2*sum(log(diag(R))); % normalization constant
y = -0.5*(c+q);
elseif size(sigma,1)==1 && size(sigma,2)==size(mu,2) % k mu and (k or one) scalar sigma
X2 = repmat(dot(X,X,1)',1,k);
D = bsxfun(@plus,X2-2*X'*mu,dot(mu,mu,1));
q = bsxfun(@times,D,1./sigma); % M distance
c = d*(log(2*pi)+2*log(sigma)); % normalization constant
y = -0.5*bsxfun(@plus,q,c);
end
d = size(X,1);
X = X-mu;
[U,p]= chol(sigma);
if p ~= 0
error('ERROR: sigma is not PD.');
end
Q = U'\X;
q = dot(Q,Q,1); % quadratic term (M distance)
c = d*log(2*pi)+2*sum(log(diag(U))); % normalization constant
y = -(c+q)/2;
3 changes: 1 addition & 2 deletions chapter03/linRegPred.m
Expand Up @@ -26,7 +26,6 @@
end

if nargin == 3 && nargout == 3
p = exp(logGauss(t,y,sigma));
% p = exp(-0.5*(((t-y)./sigma).^2+log(2*pi))-log(sigma));
p = exp(-0.5*(((t-y)./sigma).^2+log(2*pi))-log(sigma));
end

2 changes: 1 addition & 1 deletion chapter06/knRegPred.m
Expand Up @@ -25,5 +25,5 @@
end

if nargin == 3 && nargout == 3
p = exp(logGauss(t,y,sigma));
p = exp(-0.5*(((t-y)./sigma).^2+log(2*pi))-log(sigma));
end

0 comments on commit 23dc01d

Please sign in to comment.