From 5fda57dd7a90fe4961619c16b8b3a4a4d5965d98 Mon Sep 17 00:00:00 2001 From: Norman Date: Wed, 3 Feb 2021 17:39:51 -0800 Subject: [PATCH] Replace _getargs with groupargs --- src/DiffinDiffsBase.jl | 2 - src/StatsProcedures.jl | 96 +++++++++++++++++++++++++++++++---------- src/procedures.jl | 18 ++++---- test/StatsProcedures.jl | 52 ++++++++++++++-------- test/procedures.jl | 12 ++---- test/runtests.jl | 4 +- test/testutils.jl | 4 +- 7 files changed, 125 insertions(+), 63 deletions(-) diff --git a/src/DiffinDiffsBase.jl b/src/DiffinDiffsBase.jl index fcaeea5..e1e6560 100644 --- a/src/DiffinDiffsBase.jl +++ b/src/DiffinDiffsBase.jl @@ -47,11 +47,9 @@ export cb, treat, StatsStep, - namedargs, AbstractStatsProcedure, SharedStatsStep, PooledStatsProcedure, - pool, StatsSpec, proceed, @specset, diff --git a/src/StatsProcedures.jl b/src/StatsProcedures.jl index 1dcfc08..7ce18af 100644 --- a/src/StatsProcedures.jl +++ b/src/StatsProcedures.jl @@ -11,8 +11,8 @@ An instance of `StatsStep` is callable. # Methods (step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false) -Call an instance of function of type `F` with arguments -formed by updating `NamedTuple` returned by `[`namedargs(step)`](@ref)` with `ntargs`. +Call an instance of function of type `F` with arguments extracted from `ntargs` +via [`groupargs`](@ref) and [`combinedargs`](@ref). A message with the name of the `StatsStep` is printed to `stdout` if a keyword `verbose` takes the value `true` @@ -28,23 +28,75 @@ struct StatsStep{Alias, F<:Function} end _f(::StatsStep{A,F}) where {A,F} = F.instance """ - namedargs(s::StatsStep) + required(s::StatsStep) -Return a `NamedTuple` with keys showing the names of arguments -accepted by `s` and values representing the defaults. +Return a tuple of `Symbol`s representing the names of arguments +used to form [`groupargs`](@ref) that do not have defaults. +See also [`default`](@ref) and [`transformed`](@ref). 
""" -namedargs(s::StatsStep) = error("method for $(typeof(s)) is not defined") +required(::StatsStep) = () -_getargs(ntargs::NamedTuple, s::StatsStep) = _update(ntargs, namedargs(s)) -_update(a::NamedTuple{N1}, b::NamedTuple{N2}) where {N1,N2} = - NamedTuple{N2}(map(n->getfield(sym_in(n, N1) ? a : b, n), N2)) +""" + default(s::StatsStep) + +Return a `NamedTuple` of arguments with keys showing the names +and values representing the defaults to be used to form [`groupargs`](@ref). +See also [`required`](@ref) and [`transformed`](@ref). +""" +default(::StatsStep) = NamedTuple() + +""" + transformed(s::StatsStep, ntargs::NamedTuple) + +Return a tuple of arguments transformed from fields in `ntargs` +to be used to form [`groupargs`](@ref). +See also [`required`](@ref) and [`default`](@ref). +""" +transformed(::StatsStep, ::NamedTuple) = () + +_get(a::NTuple{N,Symbol}, @nospecialize(nt::NamedTuple)) where N = + map(s->getfield(nt, s), a) +_get(a::NamedTuple{S1}, nt::NamedTuple{S2}) where {S1,S2} = + map(s->getfield(ifelse(sym_in(s, S2), nt, a), s), S1) -_combinedargs(::StatsStep, ::Any) = () +""" + groupargs(s::StatsStep, ntargs::NamedTuple) + +Return a tuple of arguments that allow classifying multiple `ntargs`s into groups. +Equality (defined by `isequal`) of the returned tuples across `ntargs`s implies that +it is possible to execute step `s` only once +to obtain results for these `ntargs`s. + +This function is important for [`proceed`](@ref) to work properly. +However, in most cases, there is no need to define new methods +for concrete [`StatsStep`](@ref)s. +Instead, one should define methods for +[`required`](@ref), [`default`](@ref) or [`transformed`](@ref). +See also [`combinedargs`](@ref). +""" +groupargs(s::StatsStep, @nospecialize(ntargs::NamedTuple)) = + (_get(required(s), ntargs)..., _get(default(s), ntargs)..., + transformed(s, ntargs)...) 
+ +""" + combinedargs(s::StatsStep, allntargs::Any) + +Return a tuple of arguments obtained by combining a collection of arguments +across multiple specifications. + +The element type of `allntargs` can be assumed to be `NamedTuple`. +This function allows combining arguments that differ +across specifications in the same group classified based on [`groupargs`](@ref) +into objects that are accepted by the call of `s`. +See also [`proceed`](@ref). +""" +combinedargs(::StatsStep, ::Any) = () -function (step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false) where {A,F} +function (step::StatsStep{A,F})(@nospecialize(ntargs::NamedTuple); + verbose::Bool=false) where {A,F} haskey(ntargs, :verbose) && (verbose = ntargs.verbose) verbose && printstyled("Running ", step, "\n", color=:green) - ret = F.instance(_getargs(ntargs, step)..., _combinedargs(step, (ntargs,))...) + ret = F.instance(groupargs(step, ntargs)..., combinedargs(step, (ntargs,))...) if ret isa Tuple{<:NamedTuple, Bool} return merge(ntargs, ret[1]) else @@ -77,7 +129,7 @@ all subtypes of `AbstractStatsProcedure`. 
""" abstract type AbstractStatsProcedure{Alias, T<:NTuple{N,StatsStep} where N} end -_result(::Type{<:AbstractStatsProcedure}, ntargs::NamedTuple) = ntargs +result(::Type{<:AbstractStatsProcedure}, @nospecialize(ntargs::NamedTuple)) = ntargs length(::AbstractStatsProcedure{A,T}) where {A,T} = length(T.parameters) eltype(::Type{<:AbstractStatsProcedure}) = StatsStep @@ -131,8 +183,8 @@ end _sharedby(::SharedStatsStep{T,I}) where {T,I} = I _f(s::SharedStatsStep) = _f(s.step) -_getargs(ntargs::NamedTuple, s::SharedStatsStep) = _getargs(ntargs, s.step) -_combinedargs(s::SharedStatsStep, v::AbstractArray) = _combinedargs(s.step, v) +groupargs(s::SharedStatsStep, @nospecialize(ntargs::NamedTuple)) = groupargs(s.step, ntargs) +combinedargs(s::SharedStatsStep, v::AbstractArray) = combinedargs(s.step, v) show(io::IO, s::SharedStatsStep) = print(io, s.step) @@ -303,7 +355,7 @@ Otherwise, the last value returned by the last [`StatsStep`](@ref) is returned. struct StatsSpec{Alias, T<:AbstractStatsProcedure} args::NamedTuple StatsSpec(name::Union{Symbol,String}, - T::Type{<:AbstractStatsProcedure}, args::NamedTuple) = + T::Type{<:AbstractStatsProcedure}, @nospecialize(args::NamedTuple)) = new{Symbol(name),T}(args) end @@ -332,8 +384,7 @@ _procedure(::StatsSpec{A,T}) where {A,T} = T function (sp::StatsSpec{A,T})(; verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T} args = verbose ? merge(sp.args, (verbose=true,)) : sp.args - ntall = foldl(|>, T(), init=args) - ntall = _result(T, ntall) + ntall = result(T, foldl(|>, T(), init=args)) if keepall return ntall else @@ -404,9 +455,9 @@ function proceed(sps::AbstractVector{<:StatsSpec}; ntask = 0 verbose && printstyled("Running ", step, "...") taskids = vcat((gids[steps.procs[i]] for i in _sharedby(step))...) 
- tasks = groupview(r->_getargs(r, step), view(traces, taskids)) + tasks = groupview(r->groupargs(step, r), view(traces, taskids)) for (ins, subtb) in pairs(tasks) - ret = _f(step)(ins..., _combinedargs(step, subtb)...) + ret = _f(step)(ins..., combinedargs(step, subtb)...) if ret isa Tuple{<:NamedTuple, Bool} ret, share = ret else @@ -435,7 +486,7 @@ function proceed(sps::AbstractVector{<:StatsSpec}; ntask_total > 1 ? " tasks" : " task", " for ", nprocs, nprocs > 1 ? " procedures)\n" : " procedure)\n", bold=true, color=:green) for i in 1:nsps - traces[i] = _result(_procedure(sps[i]), traces[i]) + traces[i] = result(_procedure(sps[i]), traces[i]) end if keepall return traces @@ -519,7 +570,8 @@ For end users, `Macro`s that generate `Expr`s for these function calls should be Optional default arguments are merged with the arguments provided for each individual specification -and supersede the default values specified for each procedure through [`namedargs`](@ref). +and supersede any default value associated with each [`StatsStep`](@ref) +via [`default`](@ref). These default arguments should be specified in the same pattern as how arguments are specified for each specification inside the code block, as `@specset` processes these arguments by calling diff --git a/src/procedures.jl b/src/procedures.jl index 56127c3..5817bf9 100644 --- a/src/procedures.jl +++ b/src/procedures.jl @@ -36,7 +36,8 @@ for some preliminary checks of the input data. """ const CheckData = StatsStep{:CheckData, typeof(checkdata)} -namedargs(::CheckData) = (data=nothing, subset=nothing, weightname=nothing) +required(::CheckData) = (:data,) +default(::CheckData) = (subset=nothing, weightname=nothing) function _overlaptime(tr::DynamicTreatment, tr_rows::BitArray, data) control_time = Set(view(getcolumn(data, tr.time), .!tr_rows)) @@ -74,8 +75,8 @@ and find rows with data from treated units. See also [`CheckVars`](@ref). 
""" function checkvars!(data, tr::AbstractTreatment, pr::AbstractParallel, - yterm::AbstractTerm, treatname::Symbol, treatintterms::Terms, - xterms::Terms, esample::BitArray) + yterm::AbstractTerm, treatname::Symbol, esample::BitArray, + treatintterms::Terms, xterms::Terms) treatvars = union([treatname], (termvars(t) for t in (tr, pr, treatintterms))...) for v in treatvars @@ -106,8 +107,8 @@ Call [`DiffinDiffsBase.checkvars!`](@ref) to exclude invalid rows for relevant v """ const CheckVars = StatsStep{:CheckVars, typeof(checkvars!)} -namedargs(::CheckVars) = (data=nothing, tr=nothing, pr=nothing, - yterm=nothing, treatname=nothing, treatintterms=(), xterms=(), esample=nothing) +required(::CheckVars) = (:data, :tr, :pr, :yterm, :treatname, :esample) +default(::CheckVars) = (treatintterms=(), xterms=()) """ makeweights(args...) @@ -115,13 +116,13 @@ namedargs(::CheckVars) = (data=nothing, tr=nothing, pr=nothing, Construct a generic `Weights` vector. See also [`MakeWeights`](@ref). """ -function makeweights(data, weightname::Symbol, esample::BitArray) +function makeweights(data, esample::BitArray, weightname::Symbol) weights = Weights(convert(Vector{Float64}, view(getcolumn(data, weightname), esample))) all(isfinite, weights) || error("data column $weightname contain not-a-number values") (weights=weights,), true end -function makeweights(data, weightname::Nothing, esample::BitArray) +function makeweights(data, esample::BitArray, weightname::Nothing) weights = uweights(sum(esample)) (weights=weights,), true end @@ -134,7 +135,8 @@ The returned object named `weights` may be shared across multiple specifications """ const MakeWeights = StatsStep{:MakeWeights, typeof(makeweights)} -namedargs(::MakeWeights) = (data=nothing, weightname=nothing, esample=nothing) +required(::MakeWeights) = (:data, :esample) +default(::MakeWeights) = (weightname=nothing,) _getsubcolumns(data, name::Symbol, idx=Colon()) = columntable(NamedTuple{(name,)}((disallowmissing(view(getcolumn(data, 
name), idx)),))) diff --git a/test/StatsProcedures.jl b/test/StatsProcedures.jl index c7cc09a..d43a32a 100644 --- a/test/StatsProcedures.jl +++ b/test/StatsProcedures.jl @@ -1,43 +1,57 @@ -using DiffinDiffsBase: _f, _getargs, _update, +using DiffinDiffsBase: _f, _get, groupargs, _sharedby, _show_args, _args_kwargs, _parse!, proceed -import DiffinDiffsBase: _getargs, _combinedargs +import DiffinDiffsBase: required, default, transformed, combinedargs testvoidstep(a::String) = NamedTuple(), false const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep)} -namedargs(::TestVoidStep) = (a="a",) +required(::TestVoidStep) = (:a,) testregstep(a::String, b::String) = (c=a*b,), false const TestRegStep = StatsStep{:TestRegStep, typeof(testregstep)} -namedargs(::TestRegStep) = (a="a", b="b") +default(::TestRegStep) = (a="a", b="b") testlaststep(a::String, c::String) = (result=a*c,), false const TestLastStep = StatsStep{:TestLastStep, typeof(testlaststep)} -namedargs(::TestLastStep) = (a="a", c="b") +default(::TestLastStep) = (a="a",) +transformed(::TestLastStep, ntargs::NamedTuple) = (ntargs.c,) testarraystep(a::String) = (result=[a],), false const TestArrayStep = StatsStep{:TestArrayStep, typeof(testarraystep)} -namedargs(::TestArrayStep) = (a="a",) +default(::TestArrayStep) = (a="a",) testcombinestep(a::String, bs::String...) 
= (c=collect(bs),), true const TestCombineStep = StatsStep{:TestCombineStep, typeof(testcombinestep)} -namedargs(::TestCombineStep) = (a="a", b=nothing) -_getargs(ntargs::NamedTuple, s::TestCombineStep) = _update((a=ntargs.a,), (a="a",)) -_combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs] +default(::TestCombineStep) = (a="a",) +combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs] testinvalidstep(a::String, b::String) = b, false const TestInvalidStep = StatsStep{:TestInvalidStep, typeof(testinvalidstep)} -namedargs(::TestInvalidStep) = (a="a",b="b") +default(::TestInvalidStep) = (a="a",b="b") const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testinvalidstep)} @testset "StatsStep" begin + @testset "_get" begin + @test _get((), NamedTuple()) == () + @test _get((:a,), (a=1, b=2)) == (1,) + @test_throws ErrorException _get((:a,), (b=2,)) + + @test _get(NamedTuple(), NamedTuple()) == () + @test _get((a=1,), (b=2,)) == (1,) + @test _get((a=1,), (a=1, b=2)) == (1,) + @test _get((a=1, b=2), (a=2,)) == (2, 2) + end + @testset "args" begin - @test _getargs(NamedTuple(), TestRegStep()) == (a="a", b="b") - @test _getargs((a="a1",), TestRegStep()) == (a="a1", b="b") - @test _getargs((c="c",), TestRegStep()) == (a="a", b="b") + @test groupargs(TestVoidStep(), (a="a",)) == ("a",) + @test_throws ErrorException groupargs(TestVoidStep(), (b="b",)) + + @test groupargs(TestRegStep(), NamedTuple()) == ("a", "b") + @test groupargs(TestRegStep(), (a="a1",)) == ("a1", "b") + @test groupargs(TestRegStep(), (c="c",)) == ("a", "b") - @test_throws ErrorException namedargs(TestUnnamedStep()) - @test _combinedargs(TestRegStep(), (a="a",)) == () + @test groupargs(TestUnnamedStep(), (a="a", b="b")) == () + @test combinedargs(TestRegStep(), (a="a",)) == () end @testset "teststeps" begin @@ -118,7 +132,7 @@ end @test _sharedby(s1) == (1,) @test _sharedby(s2) == (2,3) @test _f(s1) == testregstep - @test _getargs(NamedTuple(), s1) == (a="a", b="b") + @test 
groupargs(s1, NamedTuple()) == ("a", "b") @test sprint(show, s1) == "TestRegStep" @test sprint(show, MIME("text/plain"), s1) == @@ -240,7 +254,7 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a @testset "proceed" begin s1 = StatsSpec("s1", RP, (a="a", b="b")) - s2 = StatsSpec("s2", RP, NamedTuple()) + s2 = StatsSpec("s2", RP, (a="a",)) s3 = StatsSpec("s3", RP, (a="a", b="b1")) s4 = StatsSpec("s4", UP, (a="a", b="b")) s5 = StatsSpec("s5", IP, (a="a", b="b")) @@ -262,8 +276,8 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a @test proceed([s1], keepall=true) == [(a="a", b="b", c="ab", result="aab")] @test proceed([s2]) == ["aab"] - @test proceed([s2], keep=:a) == [(result="aab",)] - @test proceed([s2], keepall=true) == [(c="ab", result="aab",)] + @test proceed([s2], keep=:b) == [(result="aab",)] + @test proceed([s2], keepall=true) == [(a="a", c="ab", result="aab",)] @test proceed([s1,s4], keep=[:a, :result]) == [(a="a", result="aab"), (a="a",)] diff --git a/test/procedures.jl b/test/procedures.jl index 7888af6..a43b23c 100644 --- a/test/procedures.jl +++ b/test/procedures.jl @@ -23,13 +23,12 @@ "CheckData (StatsStep that calls DiffinDiffsBase.checkdata)" @test _f(CheckData()) == checkdata - @test namedargs(CheckData()) == (data=nothing, subset=nothing, weightname=nothing) hrs = exampledata("hrs") nt = (data=hrs, subset=nothing, weightname=nothing) @test CheckData()(nt) == merge(nt, (esample=trues(size(hrs,1)),)) @test CheckData()((data=hrs,)) == (data=hrs, esample=trues(size(hrs,1))) - @test_throws ArgumentError CheckData()() + @test_throws ErrorException CheckData()() end end @@ -37,7 +36,7 @@ end @testset "checkvars!" 
begin hrs = exampledata("hrs") nt = (data=hrs, tr=dynamic(:wave, -1), pr=nevertreated(11), yterm=term(:oop_spend), - treatname=:wave_hosp, treatintterms=(), xterms=(), esample=trues(size(hrs,1))) + treatname=:wave_hosp, esample=trues(size(hrs,1)), treatintterms=(), xterms=()) @test checkvars!(nt...) == ((esample=trues(size(hrs,1)), tr_rows=hrs.wave_hosp.!=11), false) @@ -78,8 +77,6 @@ end "CheckVars (StatsStep that calls DiffinDiffsBase.checkvars!)" @test _f(CheckVars()) == checkvars! - @test namedargs(CheckVars()) == (data=nothing, tr=nothing, pr=nothing, - yterm=nothing, treatname=nothing, treatintterms=(), xterms=(), esample=nothing) hrs = exampledata("hrs") nt = (data=hrs, tr=dynamic(:wave, -1), pr=nevertreated(11), yterm=term(:oop_spend), @@ -90,14 +87,14 @@ end treatname=:wave_hosp, esample=trues(size(hrs,1))) @test CheckVars()(nt) == merge(nt, (esample=trues(size(hrs,1)), tr_rows=hrs.wave_hosp.!=11)) - @test_throws MethodError CheckVars()() + @test_throws ErrorException CheckVars()() end end @testset "MakeWeights" begin @testset "makeweights" begin hrs = exampledata("hrs") - nt = (data=hrs, weightname=nothing, esample=trues(size(hrs,1))) + nt = (data=hrs, esample=trues(size(hrs,1)), weightname=nothing) r, s = makeweights(nt...) 
@test r.weights isa UnitWeights && sum(r.weights) == size(hrs,1) && s @@ -112,7 +109,6 @@ end "MakeWeights (StatsStep that calls DiffinDiffsBase.makeweights)" @test _f(MakeWeights()) == makeweights - @test namedargs(MakeWeights()) == (data=nothing, weightname=nothing, esample=nothing) hrs = exampledata("hrs") nt = (data=hrs, esample=trues(size(hrs,1))) diff --git a/test/runtests.jl b/test/runtests.jl index 88f7b31..88aa273 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,11 +3,11 @@ using DiffinDiffsBase using DataFrames using DiffinDiffsBase: unpack, @unpack, hastreat, parse_treat, - _f, checkdata, checkvars!, makeweights, _getsubcolumns, parse_didargs + _f, groupargs, pool, checkdata, checkvars!, makeweights, _getsubcolumns, parse_didargs using StatsBase: Weights, UnitWeights using StatsModels: termvars -import DiffinDiffsBase: valid_didargs, namedargs +import DiffinDiffsBase: required, valid_didargs include("testutils.jl") diff --git a/test/testutils.jl b/test/testutils.jl index fe2b486..c90208b 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -25,11 +25,11 @@ tpara(c::ConstantTerm) = TestParallel{ParallelCondition,ParallelStrength}(c.n) teststep(tr::AbstractTreatment, pr::AbstractParallel) = ((str=sprint(show, tr), spr=sprint(show, pr)), false) const TestStep = StatsStep{:TestStep, typeof(teststep)} -namedargs(::TestStep) = (tr=nothing, pr=nothing) +required(::TestStep) = (:tr, :pr) testresult(::AbstractTreatment, ::String) = ((result="testresult",), false) const TestResult = StatsStep{:TestResult, typeof(testresult)} -namedargs(::TestResult) = (tr=nothing, str=nothing) +required(::TestResult) = (:tr, :str) const TestDID = DiffinDiffsEstimator{:TestDID, Tuple{TestStep,TestResult}} const NotImplemented = DiffinDiffsEstimator{:NotImplemented, Tuple{}}