Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/DiffinDiffsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,9 @@ export cb,
treat,

StatsStep,
namedargs,
AbstractStatsProcedure,
SharedStatsStep,
PooledStatsProcedure,
pool,
StatsSpec,
proceed,
@specset,
Expand Down
96 changes: 74 additions & 22 deletions src/StatsProcedures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ An instance of `StatsStep` is callable.
# Methods
(step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false)

Call an instance of function of type `F` with arguments
formed by updating `NamedTuple` returned by `[`namedargs(step)`](@ref)` with `ntargs`.
Call an instance of function of type `F` with arguments extracted from `ntargs`
via [`groupargs`](@ref) and [`combinedargs`](@ref).

A message with the name of the `StatsStep` is printed to `stdout`
if a keyword `verbose` takes the value `true`
Expand All @@ -28,23 +28,75 @@ struct StatsStep{Alias, F<:Function} end
_f(::StatsStep{A,F}) where {A,F} = F.instance

"""
namedargs(s::StatsStep)
required(s::StatsStep)

Return a `NamedTuple` with keys showing the names of arguments
accepted by `s` and values representing the defaults.
Return a tuple of `Symbol`s representing the names of arguments
used to form [`groupargs`](@ref) that do not have defaults.
See also [`default`](@ref) and [`transformed`](@ref).
"""
namedargs(s::StatsStep) = error("method for $(typeof(s)) is not defined")
required(::StatsStep) = ()

_getargs(ntargs::NamedTuple, s::StatsStep) = _update(ntargs, namedargs(s))
_update(a::NamedTuple{N1}, b::NamedTuple{N2}) where {N1,N2} =
NamedTuple{N2}(map(n->getfield(sym_in(n, N1) ? a : b, n), N2))
"""
default(s::StatsStep)

Return a `NamedTuple` of arguments with keys showing the names
and values representing the defaults to be used to form [`groupargs`](@ref).
See also [`required`](@ref) and [`transformed`](@ref).
"""
function default(::StatsStep)
    # Fallback: a step has no default arguments unless a method says otherwise.
    return NamedTuple()
end

"""
transformed(s::StatsStep, ntargs::NamedTuple)

Return a tuple of arguments transformed from fields in `ntargs`
to be used to form [`groupargs`](@ref).
See also [`required`](@ref) and [`default`](@ref).
"""
function transformed(::StatsStep, ::NamedTuple)
    # Fallback: no transformed arguments are contributed.
    return ()
end

# Look up each requested field name in `nt`, returning the values as a tuple
# in the same order as `names`. `getfield` throws if a name is absent from `nt`.
function _get(names::NTuple{N,Symbol}, @nospecialize(nt::NamedTuple)) where N
    return map(names) do name
        getfield(nt, name)
    end
end

# For every key of `defaults`, take the value from `nt` when `nt` has that key
# and fall back to the value stored in `defaults` otherwise.
# Keys of `nt` that are not in `defaults` are ignored.
function _get(defaults::NamedTuple{S1}, nt::NamedTuple{S2}) where {S1,S2}
    return map(S1) do name
        source = sym_in(name, S2) ? nt : defaults
        getfield(source, name)
    end
end

_combinedargs(::StatsStep, ::Any) = ()
"""
groupargs(s::StatsStep, ntargs::NamedTuple)

Return a tuple of arguments that allow classifying multiple `ntargs`s into groups.
Equality (defined by `isequal`) of the returned tuples across `ntargs`s implies that
it is possible to execute step `s` only once
to obtain results for these `ntargs`s.

This function is important for [`proceed`](@ref) to work properly.
However, in most cases, there is no need to define new methods
for concrete [`StatsStep`](@ref)s.
Instead, one should define methods for
[`required`](@ref), [`default`](@ref) or [`transformed`](@ref).
See also [`combinedargs`](@ref).
"""
function groupargs(s::StatsStep, @nospecialize(ntargs::NamedTuple))
    # Values for the names listed by `required(s)`, read from `ntargs`.
    reqvals = _get(required(s), ntargs)
    # Values for the keys of `default(s)`, taken from `ntargs` when present.
    defvals = _get(default(s), ntargs)
    # Any additional values derived from `ntargs` by `transformed`.
    transvals = transformed(s, ntargs)
    return (reqvals..., defvals..., transvals...)
end

"""
combinedargs(s::StatsStep, allntargs::Any)

Return a tuple of arguments obtained by combining a collection of arguments
across multiple specifications.

The element type of `allntargs` can be assumed to be `NamedTuple`.
This function allows combining arguments that differ
across specifications in the same group classified based on [`groupargs`](@ref)
into objects that are accepted by the call of `s`.
See also [`proceed`](@ref).
"""
function combinedargs(::StatsStep, ::Any)
    # Fallback: nothing is combined across specifications.
    return ()
end

function (step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false) where {A,F}
function (step::StatsStep{A,F})(@nospecialize(ntargs::NamedTuple);
verbose::Bool=false) where {A,F}
haskey(ntargs, :verbose) && (verbose = ntargs.verbose)
verbose && printstyled("Running ", step, "\n", color=:green)
ret = F.instance(_getargs(ntargs, step)..., _combinedargs(step, (ntargs,))...)
ret = F.instance(groupargs(step, ntargs)..., combinedargs(step, (ntargs,))...)
if ret isa Tuple{<:NamedTuple, Bool}
return merge(ntargs, ret[1])
else
Expand Down Expand Up @@ -77,7 +129,7 @@ all subtypes of `AbstractStatsProcedure`.
"""
abstract type AbstractStatsProcedure{Alias, T<:NTuple{N,StatsStep} where N} end

_result(::Type{<:AbstractStatsProcedure}, ntargs::NamedTuple) = ntargs
result(::Type{<:AbstractStatsProcedure}, @nospecialize(ntargs::NamedTuple)) = ntargs

length(::AbstractStatsProcedure{A,T}) where {A,T} = length(T.parameters)
eltype(::Type{<:AbstractStatsProcedure}) = StatsStep
Expand Down Expand Up @@ -131,8 +183,8 @@ end

_sharedby(::SharedStatsStep{T,I}) where {T,I} = I
_f(s::SharedStatsStep) = _f(s.step)
_getargs(ntargs::NamedTuple, s::SharedStatsStep) = _getargs(ntargs, s.step)
_combinedargs(s::SharedStatsStep, v::AbstractArray) = _combinedargs(s.step, v)
groupargs(s::SharedStatsStep, @nospecialize(ntargs::NamedTuple)) = groupargs(s.step, ntargs)
combinedargs(s::SharedStatsStep, v::AbstractArray) = combinedargs(s.step, v)

show(io::IO, s::SharedStatsStep) = print(io, s.step)

Expand Down Expand Up @@ -303,7 +355,7 @@ Otherwise, the last value returned by the last [`StatsStep`](@ref) is returned.
struct StatsSpec{Alias, T<:AbstractStatsProcedure}
args::NamedTuple
StatsSpec(name::Union{Symbol,String},
T::Type{<:AbstractStatsProcedure}, args::NamedTuple) =
T::Type{<:AbstractStatsProcedure}, @nospecialize(args::NamedTuple)) =
new{Symbol(name),T}(args)
end

Expand Down Expand Up @@ -332,8 +384,7 @@ _procedure(::StatsSpec{A,T}) where {A,T} = T
function (sp::StatsSpec{A,T})(;
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
ntall = foldl(|>, T(), init=args)
ntall = _result(T, ntall)
ntall = result(T, foldl(|>, T(), init=args))
if keepall
return ntall
else
Expand Down Expand Up @@ -404,9 +455,9 @@ function proceed(sps::AbstractVector{<:StatsSpec};
ntask = 0
verbose && printstyled("Running ", step, "...")
taskids = vcat((gids[steps.procs[i]] for i in _sharedby(step))...)
tasks = groupview(r->_getargs(r, step), view(traces, taskids))
tasks = groupview(r->groupargs(step, r), view(traces, taskids))
for (ins, subtb) in pairs(tasks)
ret = _f(step)(ins..., _combinedargs(step, subtb)...)
ret = _f(step)(ins..., combinedargs(step, subtb)...)
if ret isa Tuple{<:NamedTuple, Bool}
ret, share = ret
else
Expand Down Expand Up @@ -435,7 +486,7 @@ function proceed(sps::AbstractVector{<:StatsSpec};
ntask_total > 1 ? " tasks" : " task", " for ", nprocs,
nprocs > 1 ? " procedures)\n" : " procedure)\n", bold=true, color=:green)
for i in 1:nsps
traces[i] = _result(_procedure(sps[i]), traces[i])
traces[i] = result(_procedure(sps[i]), traces[i])
end
if keepall
return traces
Expand Down Expand Up @@ -519,7 +570,8 @@ For end users, `Macro`s that generate `Expr`s for these function calls should be

Optional default arguments are merged
with the arguments provided for each individual specification
and supersede the default values specified for each procedure through [`namedargs`](@ref).
and supersede any default value associated with each [`StatsStep`](@ref)
via [`default`](@ref).
These default arguments should be specified in the same pattern as
how arguments are specified for each specification inside the code block,
as `@specset` processes these arguments by calling
Expand Down
18 changes: 10 additions & 8 deletions src/procedures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ for some preliminary checks of the input data.
"""
const CheckData = StatsStep{:CheckData, typeof(checkdata)}

namedargs(::CheckData) = (data=nothing, subset=nothing, weightname=nothing)
required(::CheckData) = (:data,)
default(::CheckData) = (subset=nothing, weightname=nothing)

function _overlaptime(tr::DynamicTreatment, tr_rows::BitArray, data)
control_time = Set(view(getcolumn(data, tr.time), .!tr_rows))
Expand Down Expand Up @@ -74,8 +75,8 @@ and find rows with data from treated units.
See also [`CheckVars`](@ref).
"""
function checkvars!(data, tr::AbstractTreatment, pr::AbstractParallel,
yterm::AbstractTerm, treatname::Symbol, treatintterms::Terms,
xterms::Terms, esample::BitArray)
yterm::AbstractTerm, treatname::Symbol, esample::BitArray,
treatintterms::Terms, xterms::Terms)

treatvars = union([treatname], (termvars(t) for t in (tr, pr, treatintterms))...)
for v in treatvars
Expand Down Expand Up @@ -106,22 +107,22 @@ Call [`DiffinDiffsBase.checkvars!`](@ref) to exclude invalid rows for relevant v
"""
const CheckVars = StatsStep{:CheckVars, typeof(checkvars!)}

namedargs(::CheckVars) = (data=nothing, tr=nothing, pr=nothing,
yterm=nothing, treatname=nothing, treatintterms=(), xterms=(), esample=nothing)
required(::CheckVars) = (:data, :tr, :pr, :yterm, :treatname, :esample)
default(::CheckVars) = (treatintterms=(), xterms=())

"""
makeweights(args...)

Construct a generic `Weights` vector.
See also [`MakeWeights`](@ref).
"""
function makeweights(data, weightname::Symbol, esample::BitArray)
function makeweights(data, esample::BitArray, weightname::Symbol)
weights = Weights(convert(Vector{Float64}, view(getcolumn(data, weightname), esample)))
all(isfinite, weights) || error("data column $weightname contain not-a-number values")
(weights=weights,), true
end

function makeweights(data, weightname::Nothing, esample::BitArray)
function makeweights(data, esample::BitArray, weightname::Nothing)
weights = uweights(sum(esample))
(weights=weights,), true
end
Expand All @@ -134,7 +135,8 @@ The returned object named `weights` may be shared across multiple specifications
"""
const MakeWeights = StatsStep{:MakeWeights, typeof(makeweights)}

namedargs(::MakeWeights) = (data=nothing, weightname=nothing, esample=nothing)
required(::MakeWeights) = (:data, :esample)
default(::MakeWeights) = (weightname=nothing,)

_getsubcolumns(data, name::Symbol, idx=Colon()) =
columntable(NamedTuple{(name,)}((disallowmissing(view(getcolumn(data, name), idx)),)))
Expand Down
52 changes: 33 additions & 19 deletions test/StatsProcedures.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,57 @@
using DiffinDiffsBase: _f, _getargs, _update,
using DiffinDiffsBase: _f, _get, groupargs,
_sharedby, _show_args, _args_kwargs, _parse!, proceed
import DiffinDiffsBase: _getargs, _combinedargs
import DiffinDiffsBase: required, default, transformed, combinedargs

# Test step that accepts a string and contributes no new fields.
function testvoidstep(a::String)
    return (NamedTuple(), false)
end
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep)}
namedargs(::TestVoidStep) = (a="a",)
required(::TestVoidStep) = (:a,)

# Test step that returns the concatenation of `a` and `b` under the key `c`.
function testregstep(a::String, b::String)
    joined = string(a, b)
    return ((c=joined,), false)
end
const TestRegStep = StatsStep{:TestRegStep, typeof(testregstep)}
namedargs(::TestRegStep) = (a="a", b="b")
default(::TestRegStep) = (a="a", b="b")

# Test step that returns the concatenation of `a` and `c` under the key `result`.
function testlaststep(a::String, c::String)
    joined = string(a, c)
    return ((result=joined,), false)
end
const TestLastStep = StatsStep{:TestLastStep, typeof(testlaststep)}
namedargs(::TestLastStep) = (a="a", c="b")
default(::TestLastStep) = (a="a",)
transformed(::TestLastStep, ntargs::NamedTuple) = (ntargs.c,)

# Test step whose `result` is a one-element vector holding `a`.
function testarraystep(a::String)
    wrapped = [a]
    return ((result=wrapped,), false)
end
const TestArrayStep = StatsStep{:TestArrayStep, typeof(testarraystep)}
namedargs(::TestArrayStep) = (a="a",)
default(::TestArrayStep) = (a="a",)

# Test step that gathers all trailing string arguments into `c`;
# the second returned value is `true`, unlike the other test steps here.
function testcombinestep(a::String, bs::String...)
    gathered = collect(bs)
    return ((c=gathered,), true)
end
const TestCombineStep = StatsStep{:TestCombineStep, typeof(testcombinestep)}
namedargs(::TestCombineStep) = (a="a", b=nothing)
_getargs(ntargs::NamedTuple, s::TestCombineStep) = _update((a=ntargs.a,), (a="a",))
_combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs]
default(::TestCombineStep) = (a="a",)
combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs]

# Test step whose first returned element is a plain `String` rather than
# a `NamedTuple`.
function testinvalidstep(a::String, b::String)
    return (b, false)
end
const TestInvalidStep = StatsStep{:TestInvalidStep, typeof(testinvalidstep)}
namedargs(::TestInvalidStep) = (a="a",b="b")
default(::TestInvalidStep) = (a="a",b="b")

const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testinvalidstep)}

@testset "StatsStep" begin
@testset "_get" begin
@test _get((), NamedTuple()) == ()
@test _get((:a,), (a=1, b=2)) == (1,)
@test_throws ErrorException _get((:a,), (b=2,))

@test _get(NamedTuple(), NamedTuple()) == ()
@test _get((a=1,), (b=2,)) == (1,)
@test _get((a=1,), (a=1, b=2)) == (1,)
@test _get((a=1, b=2), (a=2,)) == (2, 2)
end

@testset "args" begin
@test _getargs(NamedTuple(), TestRegStep()) == (a="a", b="b")
@test _getargs((a="a1",), TestRegStep()) == (a="a1", b="b")
@test _getargs((c="c",), TestRegStep()) == (a="a", b="b")
@test groupargs(TestVoidStep(), (a="a",)) == ("a",)
@test_throws ErrorException groupargs(TestVoidStep(), (b="b",))

@test groupargs(TestRegStep(), NamedTuple()) == ("a", "b")
@test groupargs(TestRegStep(), (a="a1",)) == ("a1", "b")
@test groupargs(TestRegStep(), (c="c",)) == ("a", "b")

@test_throws ErrorException namedargs(TestUnnamedStep())
@test _combinedargs(TestRegStep(), (a="a",)) == ()
@test groupargs(TestUnnamedStep(), (a="a", b="b")) == ()
@test combinedargs(TestRegStep(), (a="a",)) == ()
end

@testset "teststeps" begin
Expand Down Expand Up @@ -118,7 +132,7 @@ end
@test _sharedby(s1) == (1,)
@test _sharedby(s2) == (2,3)
@test _f(s1) == testregstep
@test _getargs(NamedTuple(), s1) == (a="a", b="b")
@test groupargs(s1, NamedTuple()) == ("a", "b")

@test sprint(show, s1) == "TestRegStep"
@test sprint(show, MIME("text/plain"), s1) ==
Expand Down Expand Up @@ -240,7 +254,7 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a

@testset "proceed" begin
s1 = StatsSpec("s1", RP, (a="a", b="b"))
s2 = StatsSpec("s2", RP, NamedTuple())
s2 = StatsSpec("s2", RP, (a="a",))
s3 = StatsSpec("s3", RP, (a="a", b="b1"))
s4 = StatsSpec("s4", UP, (a="a", b="b"))
s5 = StatsSpec("s5", IP, (a="a", b="b"))
Expand All @@ -262,8 +276,8 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a
@test proceed([s1], keepall=true) == [(a="a", b="b", c="ab", result="aab")]

@test proceed([s2]) == ["aab"]
@test proceed([s2], keep=:a) == [(result="aab",)]
@test proceed([s2], keepall=true) == [(c="ab", result="aab",)]
@test proceed([s2], keep=:b) == [(result="aab",)]
@test proceed([s2], keepall=true) == [(a="a", c="ab", result="aab",)]

@test proceed([s1,s4], keep=[:a, :result]) ==
[(a="a", result="aab"), (a="a",)]
Expand Down
12 changes: 4 additions & 8 deletions test/procedures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,20 @@
"CheckData (StatsStep that calls DiffinDiffsBase.checkdata)"

@test _f(CheckData()) == checkdata
@test namedargs(CheckData()) == (data=nothing, subset=nothing, weightname=nothing)

hrs = exampledata("hrs")
nt = (data=hrs, subset=nothing, weightname=nothing)
@test CheckData()(nt) == merge(nt, (esample=trues(size(hrs,1)),))
@test CheckData()((data=hrs,)) == (data=hrs, esample=trues(size(hrs,1)))
@test_throws ArgumentError CheckData()()
@test_throws ErrorException CheckData()()
end
end

@testset "CheckVars" begin
@testset "checkvars!" begin
hrs = exampledata("hrs")
nt = (data=hrs, tr=dynamic(:wave, -1), pr=nevertreated(11), yterm=term(:oop_spend),
treatname=:wave_hosp, treatintterms=(), xterms=(), esample=trues(size(hrs,1)))
treatname=:wave_hosp, esample=trues(size(hrs,1)), treatintterms=(), xterms=())
@test checkvars!(nt...) == ((esample=trues(size(hrs,1)),
tr_rows=hrs.wave_hosp.!=11), false)

Expand Down Expand Up @@ -78,8 +77,6 @@ end
"CheckVars (StatsStep that calls DiffinDiffsBase.checkvars!)"

@test _f(CheckVars()) == checkvars!
@test namedargs(CheckVars()) == (data=nothing, tr=nothing, pr=nothing,
yterm=nothing, treatname=nothing, treatintterms=(), xterms=(), esample=nothing)

hrs = exampledata("hrs")
nt = (data=hrs, tr=dynamic(:wave, -1), pr=nevertreated(11), yterm=term(:oop_spend),
Expand All @@ -90,14 +87,14 @@ end
treatname=:wave_hosp, esample=trues(size(hrs,1)))
@test CheckVars()(nt) ==
merge(nt, (esample=trues(size(hrs,1)), tr_rows=hrs.wave_hosp.!=11))
@test_throws MethodError CheckVars()()
@test_throws ErrorException CheckVars()()
end
end

@testset "MakeWeights" begin
@testset "makeweights" begin
hrs = exampledata("hrs")
nt = (data=hrs, weightname=nothing, esample=trues(size(hrs,1)))
nt = (data=hrs, esample=trues(size(hrs,1)), weightname=nothing)
r, s = makeweights(nt...)
@test r.weights isa UnitWeights && sum(r.weights) == size(hrs,1) && s

Expand All @@ -112,7 +109,6 @@ end
"MakeWeights (StatsStep that calls DiffinDiffsBase.makeweights)"

@test _f(MakeWeights()) == makeweights
@test namedargs(MakeWeights()) == (data=nothing, weightname=nothing, esample=nothing)

hrs = exampledata("hrs")
nt = (data=hrs, esample=trues(size(hrs,1)))
Expand Down
Loading