Skip to content

Commit

Permalink
added @slice_head and @slice_tail
Browse files Browse the repository at this point in the history
  • Loading branch information
drizk1 committed Nov 18, 2023
1 parent 42fb9ee commit 6086a5d
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TidierData.jl currently supports the following top-level macros:
- `@mutate()` and `@transmute()`
- `@summarize()` and `@summarise()`
- `@filter()`
- `@slice()`, `@slice_sample()`, `@slice_min()`, and `@slice_max()`
- `@slice()`, `@slice_sample()`, `@slice_min()`, `@slice_max()`, `@slice_head()`, and `@slice_tail()`
- `@group_by()` and `@ungroup()`
- `@arrange()`
- `@pull()`
Expand Down
12 changes: 12 additions & 0 deletions docs/examples/UserGuide/slice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,15 @@ end
@chain df begin
@slice_max(b, prop = .5)
end

# ## Slice the tail

# `@slice_tail()` keeps rows from the end of the data frame. Here, `prop = .5` requests the last half of the rows.

@chain df begin
@slice_tail(prop = .5)
end

# ## Slice the head

# `@slice_head()` keeps rows from the beginning of the data frame. Here, `n = 3` requests the first three rows.

@chain df begin
@slice_head(n = 3)
end
2 changes: 1 addition & 1 deletion src/TidierData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export TidierData_set, across, desc, n, row_number, everything, starts_with, end
as_float, as_integer, as_string, is_float, is_integer, is_string, missing_if, replace_missing, @select, @transmute, @rename, @mutate, @summarize, @summarise, @filter,
@group_by, @ungroup, @slice, @arrange, @distinct, @pull, @left_join, @right_join, @inner_join, @full_join,
@pivot_wider, @pivot_longer, @bind_rows, @bind_cols, @clean_names, @count, @tally, @drop_missing, @glimpse, @separate,
@unite, @summary, @fill_missing, @slice_sample, @slice_min, @slice_max, @rename_with
@unite, @summary, @fill_missing, @slice_sample, @slice_min, @slice_max, @rename_with, @slice_head, @slice_tail

# Package global variables
const code = Ref{Bool}(false) # output DataFrames.jl code?
Expand Down
83 changes: 83 additions & 0 deletions src/docstrings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2568,6 +2568,89 @@ julia> @chain df begin
3 │ 0.2 2.0 0.2
```
"""

const docstring_slice_head =
"""
    @slice_head(df; n, prop)

Retrieve rows from the beginning of a DataFrame or GroupedDataFrame.

# Arguments
- `df`: The source data frame or grouped data frame from which to slice rows. For a grouped data frame, rows are taken from the beginning of each group and the result is regrouped by the same columns.
- `prop`: An optional proportion of rows to slice, between 0 and 1. The resulting number of rows is rounded down. An `ArgumentError` is thrown if `prop` is outside this range.
- `n`: An optional integer argument to specify the number of rows at the beginning of the dataframe to retrieve. Defaults to 1.

# Examples
```jldoctest
julia> df = DataFrame(
a = [missing, 0.2, missing, missing, 1, missing, 5, 6],
b = [0.3, 2, missing, 0.3, 6, 5, 7, 7],
c = [0.2, 0.2, 0.2, missing, 1, missing, 5, 6]);

julia> @chain df begin
@slice_head(n = 3)
end
3×3 DataFrame
Row │ a b c
│ Float64? Float64? Float64?
─────┼────────────────────────────────
1 │ missing 0.3 0.2
2 │ 0.2 2.0 0.2
3 │ missing missing 0.2

julia> @chain df begin
@slice_head(prop = .25)
end
2×3 DataFrame
Row │ a b c
│ Float64? Float64? Float64?
─────┼───────────────────────────────
1 │ missing 0.3 0.2
2 │ 0.2 2.0 0.2
```
"""

const docstring_slice_tail =
"""
    @slice_tail(df; n, prop)

Retrieve rows from the end of a DataFrame or GroupedDataFrame.

# Arguments
- `df`: The source data frame or grouped data frame from which to slice rows. For a grouped data frame, rows are taken from the end of each group and the result is regrouped by the same columns.
- `prop`: An optional proportion of rows to slice, between 0 and 1. The resulting number of rows is rounded down. An `ArgumentError` is thrown if `prop` is outside this range.
- `n`: An optional integer argument to specify the number of rows at the end of the dataframe to retrieve. Defaults to 1.

# Examples
```jldoctest
julia> df = DataFrame(
a = [missing, 0.2, missing, missing, 1, missing, 5, 6],
b = [0.3, 2, missing, 0.3, 6, 5, 7, 7],
c = [0.2, 0.2, 0.2, missing, 1, missing, 5, 6]);

julia> @chain df begin
@slice_tail(n = 3)
end
3×3 DataFrame
Row │ a b c
│ Float64? Float64? Float64?
─────┼────────────────────────────────
1 │ missing 5.0 missing
2 │ 5.0 7.0 5.0
3 │ 6.0 7.0 6.0

julia> @chain df begin
@slice_tail(prop = .25)
end
2×3 DataFrame
Row │ a b c
│ Float64? Float64? Float64?
─────┼──────────────────────────────
1 │ 5.0 7.0 5.0
2 │ 6.0 7.0 6.0
```
"""

const docstring_missing_if =
"""
missing_if(x, value)
Expand Down
94 changes: 94 additions & 0 deletions src/slice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,98 @@ macro slice_min(df, exprs...)
temp_df
end
end
end

"""
$docstring_slice_head
"""
macro slice_head(df, exprs...)
    # Collect every `name = value` argument into a `Dict(...)` call expression.
    # Keys are quoted symbols; values are escaped so they are evaluated in the
    # caller's scope when the generated code runs. Arguments that are not of the
    # form `name = value` are ignored.
    expr_dict = :(Dict())
    for expr in exprs
        if @capture(expr, lhs_ = rhs_)
            push!(expr_dict.args, :($(QuoteNode(lhs)) => $(esc(rhs))))
        end
    end

    return quote
        expr_dict = $expr_dict
        temp_df = $(esc(df))

        local n = get(expr_dict, :n, 1)
        # Track whether `prop` was actually supplied. Comparing the value against
        # the 1.0 default would make an explicit `prop = 1.0` indistinguishable
        # from "not given" and silently fall back to `n`.
        local has_prop = haskey(expr_dict, :prop)
        local prop_val = get(expr_dict, :prop, 1.0)
        # Written as a chained comparison so that NaN is rejected as well.
        if !(0.0 <= prop_val <= 1.0)
            throw(ArgumentError("Prop value should be between 0 and 1"))
        end

        if temp_df isa DataFrames.GroupedDataFrame
            # A grouped frame with zero groups has nothing to slice and is
            # passed through unchanged.
            if length(temp_df) > 0
                grouping_cols = DataFrames.groupcols(temp_df)
                group_heads = DataFrames.AbstractDataFrame[]
                for sdf in temp_df
                    # When `prop` is given it takes precedence over `n`, and the
                    # row count is computed per group (rounded down).
                    local group_n = n
                    if has_prop
                        group_n = floor(Int, DataFrames.nrow(sdf) * prop_val)
                    end
                    push!(group_heads, first(sdf, group_n))
                end
                # Stack the per-group heads and restore the original grouping.
                temp_df = DataFrames.groupby(reduce(vcat, group_heads), grouping_cols)
            end
        else
            if has_prop
                n = floor(Int, DataFrames.nrow(temp_df) * prop_val)
            end
            temp_df = first(temp_df, n)
        end

        temp_df
    end
end

"""
$docstring_slice_tail
"""
macro slice_tail(df, exprs...)
    # Collect every `name = value` argument into a `Dict(...)` call expression.
    # Keys are quoted symbols; values are escaped so they are evaluated in the
    # caller's scope when the generated code runs. Arguments that are not of the
    # form `name = value` are ignored.
    expr_dict = :(Dict())
    for expr in exprs
        if @capture(expr, lhs_ = rhs_)
            push!(expr_dict.args, :($(QuoteNode(lhs)) => $(esc(rhs))))
        end
    end

    return quote
        expr_dict = $expr_dict
        temp_df = $(esc(df))

        local n = get(expr_dict, :n, 1)
        # Track whether `prop` was actually supplied. Comparing the value against
        # the 1.0 default would make an explicit `prop = 1.0` indistinguishable
        # from "not given" and silently fall back to `n`.
        local has_prop = haskey(expr_dict, :prop)
        local prop_val = get(expr_dict, :prop, 1.0)
        # Written as a chained comparison so that NaN is rejected as well.
        if !(0.0 <= prop_val <= 1.0)
            throw(ArgumentError("Prop value should be between 0 and 1"))
        end

        if temp_df isa DataFrames.GroupedDataFrame
            # A grouped frame with zero groups has nothing to slice and is
            # passed through unchanged.
            if length(temp_df) > 0
                grouping_cols = DataFrames.groupcols(temp_df)
                group_tails = DataFrames.AbstractDataFrame[]
                for sdf in temp_df
                    # When `prop` is given it takes precedence over `n`, and the
                    # row count is computed per group (rounded down).
                    local group_n = n
                    if has_prop
                        group_n = floor(Int, DataFrames.nrow(sdf) * prop_val)
                    end
                    push!(group_tails, last(sdf, group_n))
                end
                # Stack the per-group tails and restore the original grouping.
                temp_df = DataFrames.groupby(reduce(vcat, group_tails), grouping_cols)
            end
        else
            if has_prop
                n = floor(Int, DataFrames.nrow(temp_df) * prop_val)
            end
            temp_df = last(temp_df, n)
        end

        temp_df
    end
end

0 comments on commit 6086a5d

Please sign in to comment.