Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ NLSolversBase = "7"
NLopt = "0.6, 1"
Optim = "1"
PrettyTables = "2"
ProximalAlgorithms = "0.5"
StatsBase = "0.33, 0.34"
Symbolics = "4, 5, 6"
SymbolicUtils = "1.4 - 1.5, 1.7, 2, 3"
Expand All @@ -48,7 +49,7 @@ test = ["Test"]
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "f3b72e0c-5f3e-4b3e-8f3e-3f4f3e3e3e3e"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"

[extensions]
SEMNLOptExt = "NLopt"
Expand Down
8 changes: 4 additions & 4 deletions docs/src/developer/observed.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ end
To compute some fit indices, you need to provide methods for

```julia
# Number of observed datapoints
n_obs(observed::MyObserved) = ...
# Number of manifest variables
n_man(observed::MyObserved) = ...
# Number of samples (observations) in the dataset
nsamples(observed::MyObserved) = ...
# Number of observed variables
nobserved_vars(observed::MyObserved) = ...
```

As always, you can add additional methods for properties that imply types and loss functions want to access, for example (from the `SemObservedCommon` implementation):
Expand Down
8 changes: 4 additions & 4 deletions docs/src/tutorials/inspection/inspection.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Model inspection

```@setup colored
using StructuralEquationModels
using StructuralEquationModels

observed_vars = [:x1, :x2, :x3, :y1, :y2, :y3, :y4, :y5, :y6, :y7, :y8]
latent_vars = [:ind60, :dem60, :dem65]
Expand Down Expand Up @@ -32,7 +32,7 @@ end

partable = ParameterTable(
graph,
latent_vars = latent_vars,
latent_vars = latent_vars,
observed_vars = observed_vars)

data = example_data("political_democracy")
Expand Down Expand Up @@ -128,8 +128,8 @@ BIC
χ²
df
minus2ll
n_man
n_obs
nobserved_vars
nsamples
nparams
p_value
RMSEA
Expand Down
3 changes: 2 additions & 1 deletion src/StructuralEquationModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ include("frontend/fit/summary.jl")
include("frontend/pretty_printing.jl")
# observed
include("observed/abstract.jl")
include("observed/covariance.jl")
include("observed/data.jl")
include("observed/covariance.jl")
include("observed/missing_pattern.jl")
include("observed/missing.jl")
include("observed/EM.jl")
# constructor
Expand Down
5 changes: 0 additions & 5 deletions src/additional_functions/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,6 @@ function sparse_outer_mul!(C, A, B::Vector, ind) #computes A*S*B -> C, where ind
end
end

# Compute the covariance matrix and the mean vector of `rows`.
#
# `rows` is a collection of vectors that are concatenated horizontally, so each
# element of `rows` becomes one column of the data matrix. The statistics are
# taken along dimension 2 of that matrix (presumably one column per observation;
# see `StatsBase.mean_and_cov`). `corrected` is forwarded to StatsBase unchanged.
#
# Note the return order: covariance first, then the mean flattened to a vector.
function cov_and_mean(rows; corrected = false)
    data = reduce(hcat, rows)
    μ, S = StatsBase.mean_and_cov(data, 2; corrected = corrected)
    return S, vec(μ)
end

# n²×(n(n+1)/2) matrix to transform a vector of lower
# triangular entries into a vectorized form of a n×n symmetric matrix,
# opposite of elimination_matrix()
Expand Down
71 changes: 15 additions & 56 deletions src/frontend/fit/fitmeasures/minus2ll.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,74 +31,33 @@ minus2ll(minimum::Number, obs, imp::Union{RAM, RAMSymbolic}, loss_ml::SemWLS) =
# compute likelihood for missing data - H0 -------------------------------------------------
# -2ll = (∑ log(2π)*(nᵢ * mᵢ)) + F*n
# -2 log-likelihood of the fitted (H0) model under FIML.
#
# `minimum` is the minimized objective value and `observed` is the observed-data
# object that exposes `nsamples(observed)` and a collection `observed.patterns`.
# `imp` and `loss_ml` are only used for dispatch.
function minus2ll(minimum::Number, observed, imp::Union{RAM, RAMSymbolic}, loss_ml::SemFIML)
    # rescale the objective by the total number of samples
    F = minimum * nsamples(observed)
    # constant term: log(2π) * Σ over patterns of (samples × measured variables)
    F += log(2π) * sum(pat -> nsamples(pat) * nmeasured_vars(pat), observed.patterns)
    return F
end

# compute likelihood for missing data - H1 -------------------------------------------------
# -2ll = ∑ (log(2π)*(nᵢ * mᵢ) + ln|Σᵢ| + (mᵢ - μᵢ)ᵀ Σᵢ⁻¹ (mᵢ - μᵢ) + tr(SᵢΣᵢ⁻¹))
# -2 log-likelihood of the saturated (H1) model for data with missing values.
#
# If the EM model attached to `observed` has not been fitted yet, `em_mvn` is run
# first. The EM mean and covariance are then forwarded, together with the
# per-pattern data stored in `observed`, to the 9-argument `minus2ll` method.
function minus2ll(observed::SemObservedMissing)
    # make sure the EM estimates exist before they are read
    if !observed.em_model.fitted
        em_mvn(observed)
    end

    em = observed.em_model
    return minus2ll(
        em.μ,
        em.Σ,
        nsamples(observed),
        pattern_rows(observed),
        observed.patterns,
        observed.obs_mean,
        observed.obs_cov,
        observed.pattern_nsamples,
        observed.pattern_nobs_vars,
    )
end

function minus2ll(
μ,
Σ,
N,
rows,
patterns,
obs_mean,
obs_cov,
pattern_nsamples,
pattern_nobs_vars,
)
F = 0.0
# fit EM-based mean and cov if not yet fitted
# FIXME EM could be very computationally expensive
observed.em_model.fitted || em_mvn(observed)

for i in 1:length(rows)
nᵢ = pattern_nsamples[i]
# missing pattern
pattern = patterns[i]
# observed data
Sᵢ = obs_cov[i]
Σ = observed.em_model.Σ
μ = observed.em_model.μ

F = sum(observed.patterns) do pat
# implied covariance/mean
Σᵢ = Σ[pattern, pattern]
ld = logdet(Σᵢ)
Σᵢ⁻¹ = inv(cholesky(Σᵢ))
meandiffᵢ = obs_mean[i] - μ[pattern]
Σᵢ = Σ[pat.measured_mask, pat.measured_mask]
Σᵢ_chol = cholesky!(Σᵢ)
ld = logdet(Σᵢ_chol)
Σᵢ⁻¹ = LinearAlgebra.inv!(Σᵢ_chol)
meandiffᵢ = pat.measured_mean - μ[pat.measured_mask]

F += F_one_pattern(meandiffᵢ, Σᵢ⁻¹, Sᵢ, ld, nᵢ)
F_one_pattern(meandiffᵢ, Σᵢ⁻¹, pat.measured_cov, ld, nsamples(pat))
end

F += sum(log(2π) * pattern_nsamples .* pattern_nobs_vars)
#F *= N
F += log(2π) * sum(pat -> nsamples(pat) * nmeasured_vars(pat), observed.patterns)

return F
end
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/specification/ParameterTable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ empty_partable_columns(nrows::Integer = 0) = Dict{Symbol, Vector}(
:param => fill(Symbol(), nrows),
)

# construct using the provided columns data or create and empty table
# construct using the provided columns data or create an empty table
function ParameterTable(
columns::Dict{Symbol, Vector} = empty_partable_columns();
columns::Dict{Symbol, Vector};
observed_vars::Union{AbstractVector{Symbol}, Nothing} = nothing,
latent_vars::Union{AbstractVector{Symbol}, Nothing} = nothing,
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
Expand Down
18 changes: 12 additions & 6 deletions src/frontend/specification/Sem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
############################################################################################

function Sem(;
specification = ParameterTable,
observed::O = SemObservedData,
imply::I = RAM,
loss::L = SemML,
Expand All @@ -12,7 +13,7 @@ function Sem(;

set_field_type_kwargs!(kwdict, observed, imply, loss, O, I)

observed, imply, loss = get_fields!(kwdict, observed, imply, loss)
observed, imply, loss = get_fields!(kwdict, specification, observed, imply, loss)

sem = Sem(observed, imply, loss)

Expand Down Expand Up @@ -59,6 +60,7 @@ Returns the loss part of a model.
loss(model::AbstractSemSingle) = model.loss

function SemFiniteDiff(;
specification = ParameterTable,
observed::O = SemObservedData,
imply::I = RAM,
loss::L = SemML,
Expand All @@ -68,7 +70,7 @@ function SemFiniteDiff(;

set_field_type_kwargs!(kwdict, observed, imply, loss, O, I)

observed, imply, loss = get_fields!(kwdict, observed, imply, loss)
observed, imply, loss = get_fields!(kwdict, specification, observed, imply, loss)

sem = SemFiniteDiff(observed, imply, loss)

Expand Down Expand Up @@ -96,23 +98,27 @@ function set_field_type_kwargs!(kwargs, observed, imply, loss, O, I)
end

# construct Sem fields
function get_fields!(kwargs, observed, imply, loss)
function get_fields!(kwargs, specification, observed, imply, loss)
if !isa(specification, SemSpecification)
specification = specification(; kwargs...)
end

# observed
if !isa(observed, SemObserved)
observed = observed(; kwargs...)
observed = observed(; specification, kwargs...)
end
kwargs[:observed] = observed

# imply
if !isa(imply, SemImply)
imply = imply(; kwargs...)
imply = imply(; specification, kwargs...)
end

kwargs[:imply] = imply
kwargs[:nparams] = nparams(imply)

# loss
loss = get_SemLoss(loss; kwargs...)
loss = get_SemLoss(loss; specification, kwargs...)
kwargs[:loss] = loss

return observed, imply, loss
Expand Down
11 changes: 11 additions & 0 deletions src/frontend/specification/StenoGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ function ParameterTable(
return ParameterTable(columns; latent_vars, observed_vars, params)
end

############################################################################################
### keyword only constructor (for call in `Sem` constructor)
############################################################################################

# FIXME: this keyword-only method conflicts with the empty ParameterTable
# constructor. It is kept for compatibility with the current `Sem` construction
# API; the proper fix is to move away from keyword-only constructors in general.
#
# When `graph` is given, the table is built from it; otherwise an empty table is
# created from `empty_partable_columns()`. All remaining keyword arguments are
# forwarded unchanged.
function ParameterTable(;
    graph::Union{AbstractStenoGraph, Nothing} = nothing,
    kwargs...,
)
    if isnothing(graph)
        return ParameterTable(empty_partable_columns(); kwargs...)
    end
    return ParameterTable(graph; kwargs...)
end

############################################################################################
### constructor for EnsembleParameterTable from graph
############################################################################################
Expand Down
Loading
Loading