Skip to content

Commit

Permalink
Use StatsAPI interface for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kshedden committed Apr 2, 2024
1 parent c650749 commit c380410
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/MultivariateStats.jl
Expand Up @@ -4,13 +4,13 @@ module MultivariateStats
using SparseArrays
using Statistics: middle
using Distributions: cdf, FDist
using StatsAPI: RegressionModel
using StatsAPI: RegressionModel, HypothesisTest
using StatsBase: SimpleCovariance, CovarianceEstimator, AbstractDataTransform,
ConvergenceException, pairwise, pairwise!, CoefTable

import Statistics: mean, var, cov, covm, cor
import Base: length, size, show
import StatsAPI: fit, predict, coef, weights, dof, r2
import StatsAPI: fit, predict, coef, weights, dof, r2, pvalue
import LinearAlgebra: eigvals, eigvecs

export
Expand All @@ -30,6 +30,7 @@ module MultivariateStats
loadings, # model loadings
var, # model variance
tests, # hypothesis tests
pvalue, # p-values for hypothesis tests

# lreg
llsq, # Linear Least Square regression
Expand Down Expand Up @@ -68,6 +69,7 @@ module MultivariateStats

## cca
CCA, # Type: Canonical Correlation Analysis model
CCATest, # Type: hypothesis tests for CCA

ccacov, # CCA based on covariances
ccasvd, # CCA based on singular value decomposition of input data
Expand Down
46 changes: 37 additions & 9 deletions src/cca.jl
Expand Up @@ -341,14 +341,46 @@ function fit(::Type{CCA}, X::AbstractMatrix{T}, Y::AbstractMatrix{T};
return M::CCA
end

"""
    CCATest <: HypothesisTest

Results of the hypothesis tests computed from a fitted CCA.

Each field is a vector of length 3 whose entries hold, in this order, the
results for the Wilks, Pillai, and Lawley-Hotelling statistics.
"""
struct CCATest <: HypothesisTest
    stat::Vector{Float64}   # the test statistics themselves
    fstat::Vector{Float64}  # F statistics used to compute p-values
    df1::Vector{Float64}    # first degrees-of-freedom parameter of the F distribution
    df2::Vector{Float64}    # second degrees-of-freedom parameter of the F distribution
end

"""
    htype(type::AbstractString)

Map the name of a CCA test statistic to its position in the length-3 result
vectors: `"wilks"` -> 1, `"pillai"` -> 2, `"lawley"` -> 3.

Matching is case-insensitive. The full name `"lawley-hotelling"` is accepted
as an alias for `"lawley"`.

Throws an `ErrorException` if `type` is not one of the recognized names.
"""
function htype(type::AbstractString)
    t = lowercase(type)
    t == "wilks" && return 1
    t == "pillai" && return 2
    (t == "lawley" || t == "lawley-hotelling") && return 3
    # `error` already throws, so no enclosing `throw` is needed.
    error("Unknown type '$(type)'")
end

"""
    pvalue(ct::CCATest; type="Wilks")

Return the p-value of the test selected by `type` (`"Wilks"`, `"Pillai"`, or
`"Lawley"`, case-insensitive).

The p-value is one minus the CDF of an F distribution with parameters
`ct.df1[i]` and `ct.df2[i]`, evaluated at `ct.fstat[i]`, where `i` is the
index of the selected test.
"""
function pvalue(ct::CCATest; type="Wilks")
    idx = htype(type)
    fdist = FDist(ct.df1[idx], ct.df2[idx])
    return 1 - cdf(fdist, ct.fstat[idx])
end

"""
    dof(ct::CCATest; type="Wilks")

Return the pair `(df1, df2)` of degrees of freedom stored for the test
selected by `type` (`"Wilks"`, `"Pillai"`, or `"Lawley"`, case-insensitive).
"""
function dof(ct::CCATest; type="Wilks")
    idx = htype(type)
    return (ct.df1[idx], ct.df2[idx])
end

"""
Test hypotheses based on a fitted CCA.
Three test statistics (Wilks Lambda, Pillai's trace, and the
Lawley-Hotelling statistic) are used to test the null hypothesis
that canonical correlations k, k+1, ... are identically zero. By
default the null hypothesis is that all canonical correlations are
zero.
default `k = 1`, so the null hypothesis is that all canonical correlations
are zero.
**Keyword arguments:**
- `n`: The sample size, required if the CCA was fit using the :cov method.
Expand All @@ -366,7 +398,6 @@ function tests(cca::CCA; n=nothing, k=1)
@warn("Provided n is different from actual n")
end
n = isnothing(n) ? cca.nobs : n
name = ["Wilks", "Pillai", "Lawley-Hotelling"]

# Below are from Rencher and Christensen (2012)

Expand All @@ -377,11 +408,10 @@ function tests(cca::CCA; n=nothing, k=1)
# Wilks lambda
wilks = prod(1 .- r.^2)
w = n - (p + q + 3) / 2
t = sqrt((p^2*q^2 - 4) / (p^2 + q^2 - 5))
t = p*q == 2 ? 1.0 : sqrt((p^2*q^2 - 4) / (p^2 + q^2 - 5))
wilks_df1 = p*q
wilks_df2 = w*t - p*q/2 + 1
wilks_f = ((1 - wilks^(1/t)) / wilks^(1/t)) * (wilks_df2 / wilks_df1)
wilks_pval = 1 - cdf(FDist(wilks_df1, wilks_df2), wilks_f)

# Pillai's trace
pillai = sum(abs2, r)
Expand All @@ -391,19 +421,17 @@ function tests(cca::CCA; n=nothing, k=1)
pillai_f = (2*N + s + 1)*pillai / ((2*m + s + 1) * (s - pillai))
pillai_df1 = s*(2*m + s + 1)
pillai_df2 = s*(2*N + s + 1)
pillai_pval = 1 - cdf(FDist(pillai_df1, pillai_df2), pillai_f)

# Lawley-Hotelling
lawley = sum(r.^2 ./ (1 .- r.^2))
lawley_f = 2*(s*N + 1) * lawley / (s^2 * (2*m + s + 1))
lawley_df1 = s*(2*m + s + 1)
lawley_df2 = 2*(s*N + 1)
lawley_pval = 1 - cdf(FDist(lawley_df1, lawley_df2), lawley_f)

stat = [wilks, pillai, lawley]
fstat = [wilks_f, pillai_f, lawley_f]
df1 = [wilks_df1, pillai_df1, lawley_df1]
df2 = [wilks_df2, pillai_df2, lawley_df2]
pval = [wilks_pval, pillai_pval, lawley_pval]
return (Name=name, stat=stat, df1=df1, df2=df2, fstat=fstat, pval=pval)

return CCATest(stat, fstat, df1, df2)
end

0 comments on commit c380410

Please sign in to comment.