Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions +nla/+edge/+test/SandwichEstimator.m
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,14 @@
stdErrInput.scanMetadata = scanMetadata;
stdErrInput.residual = residual;
stdErrInput.pinvDesignMtx = pinv(designMtx);

%sweRes.stdError = stdErrCalcObj.calculate(stdErrInput);
stdError = stdErrCalcObj.calculate(stdErrInput);
stdErrInput.contrasts = input.contrasts;

%change stdError to compute contrast SE
contrastSE = stdErrCalcObj.calculate(stdErrInput);


contrastCalc = input.contrasts * regressCoeffs;
contrastSE = sqrt((input.contrasts.^2) * (stdError.^2));


dof = obj.calcDegreesOfFreedom(designMtx);

Expand Down
5 changes: 5 additions & 0 deletions +nla/+gfx/PlotValue.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
classdef PlotValue
enumeration
PVALUE, STATISTIC
end
end
2 changes: 1 addition & 1 deletion +nla/+gfx/ProbPlotMethod.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
classdef ProbPlotMethod
enumeration
DEFAULT, LOG, NEGATIVE_LOG_10, STATISTIC
DEFAULT, LOG, NEGATIVE_LOG_10
end
end
6 changes: 4 additions & 2 deletions +nla/+helpers/+stdError/Guillaume.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

methods

function stdError = calculate(obj, sweStdErrInput)
function contrastStdError = calculate(obj, sweStdErrInput)



Expand Down Expand Up @@ -61,7 +61,9 @@

end

stdError = sqrt(betaCovar.v(betaCovar.getDiagElemIdxs,:));
stdErr = sqrt(betaCovar.v(betaCovar.getDiagElemIdxs,:));

contrastStdError = sqrt((sweStdErrInput.contrasts.^2) * (stdErr.^2));


end
Expand Down
13 changes: 12 additions & 1 deletion +nla/+helpers/+stdError/Heteroskedastic.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@

methods

function stdErr = calculate(obj, sweStdErrInput)
function contrastStdErr = calculate(obj, sweStdErrInput)

FORCE_USE_FAST_ALGO = true;
if FORCE_USE_FAST_ALGO
%There is a faster, but possibly larger memory
%implementation of this algorithm. Don't
fastAlgoObj = nla.helpers.stdError.Heteroskedastic_FAST();
contrastStdErr = fastAlgoObj.calculate(sweStdErrInput);
return;
end


%Calculation of standard error assuming heteroskedascticity
Expand All @@ -31,6 +40,8 @@
stdErr(:,fcEdgeIdx) = sqrt(correctionFactor * diag(betaCovariance));

end

contrastStdErr = sqrt((sweStdErrInput.contrasts.^2) * (stdErr.^2));

end

Expand Down
4 changes: 3 additions & 1 deletion +nla/+helpers/+stdError/Heteroskedastic_FAST.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

methods

function stdErr = calculate(obj, sweStdErrInput)
function contrastStdErr = calculate(obj, sweStdErrInput)
%Computes Standard Error, but accelerated using assumption of
%heteroskeadisticity for quicker computation
%
Expand Down Expand Up @@ -51,6 +51,8 @@
diagElemIdxsInFlatArr = 1:(numCovariates+1):numCovariates^2;
stdErr = sqrt(betaCovarianceFlat(diagElemIdxsInFlatArr,:));

contrastStdErr = sqrt((sweStdErrInput.contrasts.^2) * (stdErr.^2));

end

end
Expand Down
6 changes: 4 additions & 2 deletions +nla/+helpers/+stdError/Homoskedastic.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

methods

function stdErr = calculate(obj, sweStdErrInput)
function contrastStdErr = calculate(obj, sweStdErrInput)

%Calculation of standard error assuming homoskedasticity
%(errors are independent and identically distributed iid)
Expand All @@ -20,7 +20,9 @@
meanSqErr = sum(residual.^2) ./ degOfFree; %In regression, divide by dof instead of number of data points (per wikipedia)


stdErr = sqrt(diag(pinvDesignMtx * pinvDesignMtx')*meanSqErr);
stdErr = sqrt(diag(pinvDesignMtx * pinvDesignMtx')*meanSqErr);

contrastStdErr = sqrt((sweStdErrInput.contrasts.^2) * (stdErr.^2));


end
Expand Down
134 changes: 66 additions & 68 deletions +nla/+helpers/+stdError/UnconstrainedBlocks.m
Original file line number Diff line number Diff line change
@@ -1,69 +1,88 @@
classdef UnconstrainedBlocks < nla.helpers.stdError.AbstractSwEStdErrStrategy

properties

SPARSITY_THRESHOLD = 0.2;

end
properties (SetAccess = protected)
REQUIRES_GROUP = true;
end

methods

function stdErr = calculate(obj, sweStdErrInput)
%Computes Standard Error assuming unconstrained blocks
%Uses standard matrix multiplication to compute standard error,
%since it does not make assumption that V is sparse.

function contrastStdErr = calculate(obj, sweStdErrInput)


%rename variables for readability
pinvDesignMtx = sweStdErrInput.pinvDesignMtx;
residual = sweStdErrInput.residual;
groupIds = sweStdErrInput.scanMetadata.groupId;
unqGrps = unique(groupIds);

obj.throwErrorIfVEntirelyFull(unqGrps);

vSparsity = obj.computeVSparsity(groupIds);
[numCovariates, ~] = size(pinvDesignMtx);
[numObs, numFcEdges] = size(residual);

%Ben Kay 'Half Sandwich' algorithm seems to be at least as good
%or better than any of the other clever approaches so far.
%Might be able to beat it by using the clever approach and only
%computing the diagonal of covBat
FORCE_HALF_SW_ALGO = true;
stdErr = zeros(numCovariates, numFcEdges);

if FORCE_HALF_SW_ALGO
stdErrStrategy = nla.helpers.stdError.UnconstrainedBlocks_BenKay();
elseif vSparsity <= obj.SPARSITY_THRESHOLD
stdErrStrategy = nla.helpers.stdError.UnconstrainedBlocks_Sparse();
numNonzeroValuesInContrast = sum(sweStdErrInput.contrasts~=0);
if numNonzeroValuesInContrast > 1
WALD_TEST = true;
else
stdErrStrategy = nla.helpers.stdError.UnconstrainedBlocks_Dense();
WALD_TEST = false;
end

stdErr = stdErrStrategy.calculate(sweStdErrInput);

if ~WALD_TEST
covB = zeros(numCovariates,numFcEdges);
else
covB = zeros(numCovariates,numCovariates,numFcEdges);
end

%NOTE, optimized from swe_block.m from Benjamin Kay
%if NOT WALD_TEST, can just compute diagonals of cov(B), which
%makes the original halfsandwich of
%pinvDesignMtx(:,subjThisGrp) * residual(subjThisGrp, fcIdx)
%into a [numCovars x 1] matrix, and then squaring for the
%elements and multiplying by the one non-zero value of the
%contrast

if ~WALD_TEST
for grpIdx = 1:length(unqGrps)
thisGrpId = unqGrps(grpIdx);
subjThisGrp = groupIds == thisGrpId;
halfSandwich = pinvDesignMtx(:, subjThisGrp) * residual(subjThisGrp,:);

covB = covB + halfSandwich .* halfSandwich;

end

stdErr(:) = sqrt(covB);
contrastStdErr = sqrt((sweStdErrInput.contrasts.^2) * (stdErr.^2));
else
for grpIdx = 1:length(unqGrps)

thisGrpId = unqGrps(grpIdx);
subjThisGrp = groupIds == thisGrpId;
halfSandwich = pinvDesignMtx(:, subjThisGrp) * residual(subjThisGrp,:);

for fcEdgeIdx = 1:numFcEdges
covB(:,:,fcEdgeIdx) = covB(:,:,fcEdgeIdx) + ...
(halfSandwich(:,fcEdgeIdx) * halfSandwich(:,fcEdgeIdx)');
end
end

%Computation of contrast StdErr here
contrasts = sweStdErrInput.contrasts;
contrastStdErr = zeros(1,numFcEdges);
for fcEdgeIdx = 1:numFcEdges
contrastStdErr(fcEdgeIdx) = contrasts * covB(:,:,fcEdgeIdx) * contrasts';
end

end

end



end

methods (Access = protected)

function fractionNonZero = computeVSparsity(obj, groupIds)
%Do quick check of how many elements of V we expect to be
%nonzero given the group Ids of our observations.
%This calculation will only be fast if V is sparse, so we
%should determine how full V will be and warn user if this
%method will be slow.
unqGrps = unique(groupIds);
countInGrps = histcounts(groupIds,[unqGrps;Inf]);

numNonzeroElems = sum(countInGrps.^2);
totalElems = length(groupIds)*length(groupIds);

fractionNonZero = numNonzeroElems / totalElems;

end


methods (Access = protected)

function throwErrorIfVEntirelyFull(obj, uniqueGroupIds)
%If V is entirely full, throw an error
if length(uniqueGroupIds) == 1
Expand All @@ -76,30 +95,9 @@ function throwErrorIfVEntirelyFull(obj, uniqueGroupIds)
end

end

function [groupedPinv, groupedResidual, groupIds] = reorderDataByGroup(obj, origPinvDesignMtx, origResidual, origGrps)

[numCovariates, ~] = size(origPinvDesignMtx);

%Group all data in one matrix and sort by group
%NOTE: need to use transpose of pinvDesignMatrix!!!
allData = [origGrps, origPinvDesignMtx', origResidual];
sortedData= sortrows(allData,1);

pinvColRange = 2:(numCovariates+1);
residualColRange = (numCovariates+2):(size(allData,2));

groupedPinvTpose = sortedData(:,pinvColRange);
groupedPinv = groupedPinvTpose';
groupedResidual = sortedData(:,residualColRange);
groupIds = sortedData(:,1);

end




end



end
71 changes: 0 additions & 71 deletions +nla/+helpers/+stdError/UnconstrainedBlocks_BenKay.m

This file was deleted.

Loading
Loading