Skip to content

Commit

Permalink
Support more complex comprehensions (#302)
Browse files Browse the repository at this point in the history
* better comprehensions
* allow Distributed.@distributed
  • Loading branch information
MarcMush committed Feb 28, 2024
1 parent 5a74c7f commit 7e2bbca
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 276 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ while true
next!(prog)
rand(1:2*10^8) == 1 && break
end
ProgressMeter.finish!(prog)
finish!(prog)
```

By default, `finish!` changes the spinner to a ``, but you can
Expand Down Expand Up @@ -421,7 +421,7 @@ p = Progress(n; enabled = false)
for iter in 1:10
x *= 2
sleep(0.5)
ProgressMeter.next!(p)
next!(p)
end
```

Expand Down
269 changes: 146 additions & 123 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,10 +782,6 @@ function showprogressdistributed(args...)
progressargs = args[1:end-1]
expr = Base.remove_linenums!(args[end])

if expr.head != :macrocall || expr.args[1] != Symbol("@distributed")
throw(ArgumentError("malformed @showprogress @distributed expression"))
end

distargs = filter(x -> !(x isa LineNumberNode), expr.args[2:end])
na = length(distargs)
if na == 1
Expand Down Expand Up @@ -846,7 +842,7 @@ function showprogressthreads(args...)
iters = loop.args[1].args[end]

p = gensym()
push!(loop.args[end].args, :(ProgressMeter.next!($p)))
push!(loop.args[end].args, :(next!($p)))

quote
$(esc(p)) = Progress(
Expand Down Expand Up @@ -890,146 +886,173 @@ function showprogress(args...)
end
progressargs = args[1:end-1]
expr = args[end]
if expr.head == :macrocall && expr.args[1] == Symbol("@distributed")
return showprogressdistributed(args...)
end
if expr.head == :macrocall && expr.args[1] == :(Threads.var"@threads")
return showprogressthreads(args...)

if !isa(expr, Expr)
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))
end
orig = expr = copy(expr)
if expr.args[1] == :|> # e.g. map(x->x^2) |> sum

if expr.head == :call && expr.args[1] == :|>
# e.g. map(x->x^2) |> sum
expr.args[2] = showprogress(progressargs..., expr.args[2])
return expr

elseif expr.head in (:for, :comprehension, :typed_comprehension)
return showprogress_loop(expr, progressargs)

elseif expr.head == :call
return showprogress_map(expr, progressargs)

elseif expr.head == :do && expr.args[1].head == :call
return showprogress_map(expr, progressargs)

elseif expr.head == :macrocall
macroname = expr.args[1]

if macroname in (Symbol("@distributed"), :(Distributed.@distributed).args[1])
# can be changed to `:(Distributed.var"@distributed")` if support for pre-1.3 is dropped
return showprogressdistributed(args...)

elseif macroname in (Symbol("@threads"), :(Threads.@threads).args[1])
return showprogressthreads(args...)
end
end

throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))
end

function showprogress_map(expr, progressargs)
metersym = gensym("meter")
kind = :invalid # :invalid, :loop, or :map

if isa(expr, Expr)
if expr.head == :for
outerassignidx = 1
loopbodyidx = lastindex(expr.args)
kind = :loop
elseif expr.head == :comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
kind = :loop
elseif expr.head == :typed_comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 2
kind = :loop
elseif expr.head == :call
kind = :map
elseif expr.head == :do
call = expr.args[1]
if call.head == :call
kind = :map
end
end

# isolate call to map
if expr.head == :do
call = expr.args[1]
else
call = expr
end

if kind == :invalid
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))
elseif kind == :loop
# As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree.
if expr.head !== :for
@assert length(expr.args) == loopbodyidx
expr = expr.args[outerassignidx] = copy(expr.args[outerassignidx])
@assert expr.head === :generator
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
end
# get args to map to determine progress length
mapargs = collect(Any, filter(call.args[2:end]) do a
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, identity) # to make args for ncalls line up
end

# Transform the first loop assignment
loopassign = expr.args[outerassignidx] = copy(expr.args[outerassignidx])
if loopassign.head === :block # this will happen in a for loop with multiple iteration variables
for i in 2:length(loopassign.args)
loopassign.args[i] = esc(loopassign.args[i])
end
loopassign = loopassign.args[1] = copy(loopassign.args[1])
end
@assert loopassign.head === :(=)
@assert length(loopassign.args) == 2
obj = loopassign.args[2]
loopassign.args[1] = esc(loopassign.args[1])
loopassign.args[2] = :(ProgressWrapper(iterable, $(esc(metersym))))

# Transform the loop body break and return statements
if expr.head === :for
expr.args[loopbodyidx] = showprogress_process_expr(expr.args[loopbodyidx], metersym)
end
# change call to progress_map
mapfun = call.args[1]
call.args[1] = :progress_map

# Escape all args except the loop assignment, which was already appropriately escaped.
for i in 1:length(expr.args)
if i != outerassignidx
expr.args[i] = esc(expr.args[i])
end
end
if orig !== expr
# We have additional escaping to do; this will occur for comprehensions with julia 0.5 or later.
for i in 1:length(orig.args)-1
orig.args[i] = esc(orig.args[i])
end
end
# escape args as appropriate
for i in 2:length(call.args)
call.args[i] = esc(call.args[i])
end
if expr.head == :do
expr.args[2] = esc(expr.args[2])
end

setup = quote
iterable = $(esc(obj))
$(esc(metersym)) = Progress(length(iterable), $(showprogress_process_args(progressargs)...))
end
# create appropriate Progress expression
lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...)))
progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...)))

if expr.head === :for
return quote
$setup
$expr
end
else
# We're dealing with a comprehension
return quote
begin
$setup
rv = $orig
next!($(esc(metersym)))
rv
end
end
# insert progress and mapfun kwargs
push!(call.args, Expr(:kw, :progress, progex))
push!(call.args, Expr(:kw, :mapfun, esc(mapfun)))

return expr
end

function showprogress_loop(expr, progressargs)
metersym = gensym("meter")
orig = expr = copy(expr)

if expr.head == :for
outerassignidx = 1
loopbodyidx = lastindex(expr.args)
elseif expr.head == :comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
elseif expr.head == :typed_comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 2
end
# As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree.
if expr.head !== :for
@assert length(expr.args) == loopbodyidx
expr = expr.args[outerassignidx] = copy(expr.args[outerassignidx])
if expr.head == :flatten
# e.g. [x for x in 1:10 for y in 1:x]
expr = expr.args[1] = copy(expr.args[1])
end
else # kind == :map
@assert expr.head === :generator
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
end

# isolate call to map
if expr.head == :do
call = expr.args[1]
else
call = expr
# Transform the first loop assignment
loopassign = expr.args[outerassignidx] = copy(expr.args[outerassignidx])

if loopassign.head === :filter
# e.g. [x for x=1:10, y=1:10 if x>y]
# y will be wrapped in ProgressWrapper
for i in 1:length(loopassign.args)-1
loopassign.args[i] = esc(loopassign.args[i])
end
loopassign = loopassign.args[end] = copy(loopassign.args[end])
end

# get args to map to determine progress length
mapargs = collect(Any, filter(call.args[2:end]) do a
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, identity) # to make args for ncalls line up
if loopassign.head === :block
# e.g. for x=1:10, y=1:x end
# x will be wrapped in ProgressWrapper
for i in 2:length(loopassign.args)
loopassign.args[i] = esc(loopassign.args[i])
end
loopassign = loopassign.args[1] = copy(loopassign.args[1])
end

@assert loopassign.head === :(=)
@assert length(loopassign.args) == 2
obj = loopassign.args[2]
loopassign.args[1] = esc(loopassign.args[1])
loopassign.args[2] = :(ProgressWrapper(iterable, $(esc(metersym))))

# change call to progress_map
mapfun = call.args[1]
call.args[1] = :progress_map
# Transform the loop body break and return statements
if expr.head === :for
expr.args[loopbodyidx] = showprogress_process_expr(expr.args[loopbodyidx], metersym)
end

# escape args as appropriate
for i in 2:length(call.args)
call.args[i] = esc(call.args[i])
# Escape all args except the loop assignment, which was already appropriately escaped.
for i in 1:length(expr.args)
if i != outerassignidx
expr.args[i] = esc(expr.args[i])
end
if expr.head == :do
expr.args[2] = esc(expr.args[2])
end
if orig !== expr
# We have additional escaping to do; this will occur for comprehensions with julia 0.5 or later.
for i in 1:length(orig.args)-1
orig.args[i] = esc(orig.args[i])
end
end

# create appropriate Progress expression
lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...)))
progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...)))

# insert progress and mapfun kwargs
push!(call.args, Expr(:kw, :progress, progex))
push!(call.args, Expr(:kw, :mapfun, esc(mapfun)))
setup = quote
iterable = $(esc(obj))
$(esc(metersym)) = Progress(length(iterable), $(showprogress_process_args(progressargs)...))
end

return expr
if expr.head === :for
return quote
$setup
$expr
end
else
# We're dealing with a comprehension
return quote
begin
$setup
rv = $orig
finish!($(esc(metersym)))
rv
end
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
@test ProgressMeter.durationstring(60*60*24*10 - 0.1) == "9 days, 23:59:59"
@test ProgressMeter.durationstring(60*60*24*10) == "10.00 days"

@test ProgressMeter.Progress(5, desc="Progress:", offset=Int16(5)).offset == 5
@test ProgressMeter.ProgressThresh(0.2, desc="Progress:", offset=Int16(5)).offset == 5
@test Progress(5, desc="Progress:", offset=Int16(5)).offset == 5
@test ProgressThresh(0.2, desc="Progress:", offset=Int16(5)).offset == 5

# test speed string formatting
for ns in [1, 9, 10, 99, 100, 999, 1_000, 9_999, 10_000, 99_000, 100_000, 999_999, 1_000_000, 9_000_000, 10_000_000, 99_999_000, 1_234_567_890, 1_234_567_890 * 10, 1_234_567_890 * 100, 1_234_567_890 * 1_000, 1_234_567_890 * 10_000, 1_234_567_890 * 100_000, 1_234_567_890 * 1_000_000, 1_234_567_890 * 10_000_000]
Expand Down

0 comments on commit 7e2bbca

Please sign in to comment.