Skip to content

Commit

Permalink
Remove reversibility from transforms: 'Map', 'Replace', 'Sample' and 'Sort' (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv committed Dec 14, 2023
1 parent 41d588b commit e0e344e
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 177 deletions.
27 changes: 2 additions & 25 deletions src/transforms/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function Map(pairs::MapPair...)
Map(selectors, funs, targets)
end

isrevertible(::Type{Map}) = true
isrevertible(::Type{Map}) = false

_makename(snames, fun) = Symbol(join([snames; nameof(fun)], "_"))

Expand All @@ -73,10 +73,6 @@ function applyfeat(transform::Map, feat, prep)
names = collect(onames)
columns = Any[Tables.getcolumn(cols, nm) for nm in onames]

# replaced names and columns
rnames = empty(names)
rcolumns = empty(columns)

# mapped columns
mapped = map(selectors, funs, targets) do selector, fun, target
snames = selector(names)
Expand All @@ -88,9 +84,7 @@ function applyfeat(transform::Map, feat, prep)

for (name, column) in mapped
if name ∈ onames
push!(rnames, name)
i = findfirst(==(name), onames)
push!(rcolumns, columns[i])
columns[i] = column
else
push!(names, name)
Expand All @@ -100,22 +94,5 @@ function applyfeat(transform::Map, feat, prep)

𝒯 = (; zip(names, columns)...)
newfeat = 𝒯 |> Tables.materializer(feat)
newfeat, (onames, rnames, rcolumns)
end

function revertfeat(::Map, newfeat, fcache)
cols = Tables.columns(newfeat)

onames, rnames, rcolumns = fcache
ocolumns = map(onames) do name
if name ∈ rnames
i = findfirst(==(name), rnames)
rcolumns[i]
else
Tables.getcolumn(cols, name)
end
end

𝒯 = (; zip(onames, ocolumns)...)
𝒯 |> Tables.materializer(newfeat)
newfeat, nothing
end
30 changes: 5 additions & 25 deletions src/transforms/replace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ function Replace(pairs::Pair...)
Replace(selectors, preds, news)
end

isrevertible(::Type{<:Replace}) = true
isrevertible(::Type{<:Replace}) = false

function applyfeat(transform::Replace, feat, prep)
cols = Tables.columns(feat)
Expand All @@ -81,43 +81,23 @@ function applyfeat(transform::Replace, feat, prep)
name => reps
end

tuples = map(colreps) do (name, reps)
columns = map(colreps) do (name, reps)
x = Tables.getcolumn(cols, name)
if isnothing(reps)
x, nothing
x
else
# reversal dict
rev = Dict{Int,eltype(x)}()
y = map(enumerate(x)) do (i, v)
map(x) do v
for (pred, new) in reps
if pred(v)
rev[i] = v
return new
end
end
v
end
y, rev
end
end

columns = first.(tuples)
fcache = last.(tuples)

𝒯 = (; zip(names, columns)...)
newfeat = 𝒯 |> Tables.materializer(feat)
newfeat, fcache
end

function revertfeat(::Replace, newfeat, fcache)
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols)

columns = map(names, fcache) do name, rev
y = Tables.getcolumn(cols, name)
isnothing(rev) ? y : [get(rev, i, y[i]) for i in 1:length(y)]
end

𝒯 = (; zip(names, columns)...)
𝒯 |> Tables.materializer(newfeat)
newfeat, nothing
end
40 changes: 5 additions & 35 deletions src/transforms/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Sample(size::Int, weights::AbstractWeights; replace=false, ordered=false, rng=Ra

Sample(size::Int, weights; kwargs...) = Sample(size, Weights(collect(weights)); kwargs...)

isrevertible(::Type{<:Sample}) = true
isrevertible(::Type{<:Sample}) = false

function preprocess(transform::Sample, feat)
# retrieve valid indices
Expand All @@ -59,46 +59,16 @@ function preprocess(transform::Sample, feat)
sample(rng, inds, weights, size; replace, ordered)
end

# rejected indices
rinds = setdiff(inds, sinds)

sinds, rinds
sinds
end

function applyfeat(::Sample, feat, prep)
# preprocessed indices
sinds, rinds = prep
sinds = prep

# selected/rejected rows
# selected rows
srows = Tables.subset(feat, sinds, viewhint=true)
rrows = Tables.subset(feat, rinds, viewhint=true)

newfeat = srows |> Tables.materializer(feat)
newfeat, (sinds, rinds, rrows)
end

function revertfeat(::Sample, newfeat, fcache)
cols = Tables.columns(newfeat)
names = Tables.columnnames(cols)

sinds, rinds, rrows = fcache

# columns with selected rows in original order
uinds = indexin(sort(unique(sinds)), sinds)
columns = map(names) do name
y = Tables.getcolumn(cols, name)
[y[i] for i in uinds]
end

# insert rejected rows into columns
rrcols = Tables.columns(rrows)
for (name, x) in zip(names, columns)
r = Tables.getcolumn(rrcols, name)
for (i, v) in zip(rinds, r)
insert!(x, i, v)
end
end

𝒯 = (; zip(names, columns)...)
𝒯 |> Tables.materializer(newfeat)
newfeat, nothing
end
17 changes: 2 additions & 15 deletions src/transforms/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Sort(cols::C...; kwargs...) where {C<:Column} = Sort(selector(cols), values(kwar

Sort(; kwargs...) = throw(ArgumentError("cannot create Sort transform without arguments"))

isrevertible(::Type{<:Sort}) = true
isrevertible(::Type{<:Sort}) = false

function preprocess(transform::Sort, feat)
cols = Tables.columns(feat)
Expand All @@ -59,18 +59,5 @@ function applyfeat(::Sort, feat, prep)

newfeat = srows |> Tables.materializer(feat)

newfeat, sinds
end

function revertfeat(::Sort, newfeat, fcache)
# collect all rows
rows = Tables.rowtable(newfeat)

# reverting indices
sinds = fcache
rinds = sortperm(sinds)

rrows = view(rows, rinds)

rrows |> Tables.materializer(newfeat)
newfeat, nothing
end
20 changes: 2 additions & 18 deletions test/transforms/map.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testset "Map" begin
@test !isrevertible(Map(:a => sin))

a = [4, 7, 8, 5, 8, 1]
b = [1, 9, 1, 7, 9, 4]
c = [2, 8, 6, 3, 2, 2]
Expand All @@ -9,65 +11,47 @@
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :a_sin)
@test n.a_sin == sin.(t.a)
tₒ = revert(T, n, c)
@test t == tₒ

T = Map(:b => cos)
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :b_cos)
@test n.b_cos == cos.(t.b)
tₒ = revert(T, n, c)
@test t == tₒ

T = Map("c" => tan)
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :c_tan)
@test n.c_tan == tan.(t.c)
tₒ = revert(T, n, c)
@test t == tₒ

T = Map(:a => sin => :a)
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d)
@test n.a == sin.(t.a)
tₒ = revert(T, n, c)
@test t == tₒ

T = Map(:a => sin => "a")
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d)
@test n.a == sin.(t.a)
tₒ = revert(T, n, c)
@test t == tₒ

T = Map([2, 3] => ((b, c) -> 2b + c) => :op1)
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :op1)
@test n.op1 == @. 2 * t.b + t.c
tₒ = revert(T, n, c)
@test t == tₒ

T = Map([:a, :c] => ((a, c) -> 2a * 3c) => :op1)
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :op1)
@test n.op1 == @. 2 * t.a * 3 * t.c
tₒ = revert(T, n, c)
@test t == tₒ

T = Map(["c", "a"] => ((c, a) -> 3c / a) => :op1, "c" => tan)
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :op1, :c_tan)
@test n.op1 == @. 3 * t.c / t.a
@test n.c_tan == tan.(t.c)
tₒ = revert(T, n, c)
@test t == tₒ

T = Map(r"[abc]" => ((a, b, c) -> a^2 - 2b + c) => "op1")
n, c = apply(T, t)
@test Tables.schema(n).names == (:a, :b, :c, :d, :op1)
@test n.op1 == @. t.a^2 - 2 * t.b + t.c
tₒ = revert(T, n, c)
@test t == tₒ

# throws
@test_throws ArgumentError Map()
Expand Down
Loading

0 comments on commit e0e344e

Please sign in to comment.