# Iterators
(c) Deniz Yuret, 2019

* Objective: Learning how to construct and use Julia iterators.
* Reading: [Interfaces](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration-1), [Collections](https://docs.julialang.org/en/v1/base/collections/#lib-collections-iteration-1), [Iteration Utilities](https://docs.julialang.org/en/v1/base/iterators), and [Generator expressions](https://docs.julialang.org/en/v1/manual/arrays/#Generator-Expressions-1) in the Julia manual.
* Prerequisites: [minibatch, Data](https://github.com/denizyuret/Knet.jl/blob/master/src/data.jl) from the [MNIST notebook](20.mnist.ipynb)
* New functions: 
[first](https://docs.julialang.org/en/v1/base/collections/#Base.first), 
[collect](https://docs.julialang.org/en/v1/base/collections/#Base.collect-Tuple{Any}), 
[take](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.take), 
[drop](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.drop), 
[cycle](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.cycle),
[ncycle](https://juliacollections.github.io/IterTools.jl/stable/#ncycle(xs,-n)-1),
[takenth](https://juliacollections.github.io/IterTools.jl/stable/#takenth(xs,-n)-1),
[takewhile](https://juliacollections.github.io/IterTools.jl/stable/#takewhile(cond,-xs)-1),
[Stateful](https://docs.julialang.org/en/v1/base/iterators/#Base.Iterators.Stateful), 
[iterate](https://docs.julialang.org/en/v1/base/collections/#lib-collections-iteration-1)

The `minibatch` function returns a `Knet.Data` object implemented as a Julia iterator that generates `(x,y)` minibatches. Iterators are lazy objects that generate their next element only when asked, which avoids spending time and memory creating and storing all the elements at once. We can even have infinite iterators! The training algorithms in Knet are also implemented as iterators so that:
1. We can monitor and report the training loss
2. We can take snapshots of the model during training
3. We can pause/terminate training when necessary
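
To make the three points above concrete, here is a minimal sketch of a training loop exposed as an iterator. It is not Knet's implementation: `ToyTrainer`, its fields, and the objective f(w) = w^2 are hypothetical and chosen only for illustration.

```julia
# Hypothetical toy "trainer": gradient descent on f(w) = w^2, exposed as an iterator.
mutable struct ToyTrainer
    w::Float64    # current parameter value
    lr::Float64   # learning rate
end

# Each call performs one update and yields the loss measured before the update.
# The iterator never ends on its own, so the caller decides when to stop.
function Base.iterate(t::ToyTrainer, state=nothing)
    loss = t.w^2
    t.w -= t.lr * 2 * t.w     # the gradient of w^2 is 2w
    return loss, nothing
end

Base.IteratorSize(::Type{ToyTrainer}) = Base.IsInfinite()

trainer = ToyTrainer(1.0, 0.25)
losses = Float64[]
for loss in trainer
    push!(losses, loss)       # monitor the training loss
    loss < 1e-3 && break      # terminate when the loss is small enough
end
```

Because each step is produced on demand, the caller can record losses, save `trainer.w` as a snapshot, or stop at any point without the trainer needing to know about any of it.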

Here are some things Julia can do with iterators:

In [1]:
# Set display width, load packages, import symbols
ENV["COLUMNS"]=72
using Base.Iterators: take, drop, cycle, Stateful
using IterTools: ncycle, takenth, takewhile
using MLDatasets: MNIST
using Knet

In [2]:
# Load MNIST data as an iterator of (x,y) minibatches
xtst,ytst = MNIST.testdata(Float32)
dtst = minibatch(xtst, ytst, 100)

100-element Knet.Train20.Data{Tuple{Array{Float32,3},Array{Int64,1}}}

In [3]:
# We can peek at the first element using first()
summary.(first(dtst))

("28×28×100 Array{Float32,3}", "100-element Array{Int64,1}")

In [4]:
# Iterators can be used in for loops
# Let's count the elements in dtst:
n = 0
for (x,y) in dtst; global n += 1; end
@show n;

n = 100


In [5]:
# Iterators can be converted to arrays using `collect`
# (avoid this unless necessary: it stores every element in memory at once. Prefer a for loop)
collect(dtst) |> summary

"100-element Array{Tuple{Array{Float32,3},Array{Int64,1}},1}"

In [6]:
# We can generate an iterator for multiple epochs using `ncycle`
# (an epoch is a single pass over the dataset)
n = 0
for (x,y) in ncycle(dtst,5); global n += 1; end
@show n;

n = 500


In [7]:
# We can generate partial epochs using `take` which takes the first n elements
n = 0
for (x,y) in take(dtst,20); global n += 1; end
@show n;

n = 20


In [8]:
# We can also generate partial epochs using `drop` which drops the first n elements
n = 0
for (x,y) in drop(dtst,20); global n += 1; end
@show n;

n = 80


In [9]:
# We can repeat forever using `cycle`
# You do not want to collect a cycle or run a for loop without break! 
n = 0
for (x,y) in cycle(dtst); (global n += 1) > 1234 && break; end
@show n;

n = 1235
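
The `take`, `drop`, and `cycle` wrappers can also be composed. The following small sketch uses a plain range instead of the MNIST minibatches so that the result is easy to check by hand; the range `1:3` is only a stand-in for a dataset.

```julia
using Base.Iterators: take, drop, cycle

# cycle(1:3) yields 1, 2, 3, 1, 2, 3, ... forever;
# drop skips the first two elements, take keeps the next five.
window = collect(take(drop(cycle(1:3), 2), 5))
```

Nothing is computed until `collect` asks for elements, which is why wrapping an infinite `cycle` is safe as long as a `take` (or a `break`) bounds it.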


In [10]:
# We can repeat until a condition is met using `takewhile`
# This is useful to train until convergence
n = 0
for (x,y) in takewhile(x->(n<56), dtst); global n += 1; end
@show n;

n = 56
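
As a rough sketch of the "train until convergence" idea, the example below stops as soon as a loss value falls to a threshold. The loss sequence 1, 1/2, 1/3, ... and the threshold 0.15 are made up for illustration. It uses `Iterators.takewhile` from Base (available since Julia 1.4) rather than the IterTools version imported above, so the block has no package dependencies.

```julia
# Hypothetical loss sequence standing in for the losses reported by a training run
losses = (1 / k for k in 1:1000)

# Keep consuming losses while they are still above the threshold
kept = collect(Iterators.takewhile(l -> l > 0.15, losses))
nsteps = length(kept)
```

Only the first seven losses are ever computed: six that pass the test and the one that fails it.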


In [11]:
# We can take every nth element using `takenth`
# This is useful to report progress every nth iteration
n = 0
for (x,y) in takenth(dtst,6); global n += 1; end
@show n;

n = 16


In [12]:
# We can construct new iterators using [Generator expressions](https://docs.julialang.org/en/v1/manual/arrays/#Generator-Expressions-1)
# The following example constructs an iterator over the squared norms of x in a dataset:
xnorm(data) = (sum(abs2,x) for (x,y) in data)
collect(xnorm(dtst))'

1×100 LinearAlgebra.Adjoint{Float32,Array{Float32,1}}:
 7990.35  7842.33  8162.68  7692.77  …  8494.0  7361.33  8643.01

In [13]:
# Every iterator implements the `iterate` function which returns
# the next element and state (or nothing if no elements left).
# Here is how the for loop for dtst is implemented:
n = 0; next = iterate(dtst)
while next !== nothing
    ((_x,_y), state) = next
    global n += 1
    global next = iterate(dtst,state)
end
@show n;

n = 100
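
The same protocol works for any iterable, not only `dtst`. Here is a minimal sketch that wraps the manual loop in a function and applies it to a range and a string; `manual_collect` is a hypothetical helper name, not a Base function.

```julia
# Hand-written equivalent of `for item in itr; push!(out, item); end`
function manual_collect(itr)
    out = Any[]
    next = iterate(itr)              # first call: no state argument
    while next !== nothing
        (item, state) = next
        push!(out, item)
        next = iterate(itr, state)   # later calls: pass the previous state
    end
    return out
end

from_range = manual_collect(1:3)
from_string = manual_collect("ab")
```

Inside a function the loop variables are local, so no `global` annotations are needed.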


In [14]:
# You can define your own iterator by declaring a new type and defining an `iterate` method for it.
# Here is another way to define an iterator over the squared norms of x in a dataset:
struct Xnorm; itr; end

function Base.iterate(f::Xnorm, s...)
    next = iterate(f.itr, s...)
    next === nothing && return nothing
    ((x,y),state) = next
    return sum(abs2,x), state
end

Base.length(f::Xnorm) = length(f.itr) # collect needs this

collect(Xnorm(dtst))'

1×100 LinearAlgebra.Adjoint{Any,Array{Any,1}}:
 7990.35  7842.33  8162.68  7692.77  …  8494.0  7361.33  8643.01
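
The `Any` element type in the output above appears because `eltype` falls back to `Any` for a type that does not define it. One possible way to get a concretely typed array is to define `Base.eltype` as well. The sketch below uses a hypothetical `SquaredNorms` type and a tiny hand-made dataset, and it assumes the x arrays hold `Float32` values.

```julia
# Hypothetical wrapper, analogous to Xnorm but with a declared element type
struct SquaredNorms{I}
    itr::I
end

function Base.iterate(f::SquaredNorms, s...)
    next = iterate(f.itr, s...)
    next === nothing && return nothing
    ((x, y), state) = next
    return sum(abs2, x), state
end

Base.length(f::SquaredNorms) = length(f.itr)      # lets collect preallocate
Base.eltype(::Type{<:SquaredNorms}) = Float32     # assumption: x holds Float32 values

# Tiny stand-in dataset of (x, y) pairs
data = [(Float32[1, 2], 0), (Float32[3, 4], 1)]
norms = collect(SquaredNorms(data))
```

With `eltype` defined, `collect` returns a `Vector{Float32}` instead of a `Vector{Any}`.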

In [15]:
# We can make an iterator `Stateful` so it remembers where it left off.
# (by default iterators start from the beginning)
dtst1 = dtst            # dtst1 will start from beginning every time
dtst2 = Stateful(dtst)  # dtst2 will remember where we left off
for (x,y) in dtst1; println(Int.(y[1:5])); break; end
for (x,y) in dtst1; println(Int.(y[1:5])); break; end
for (x,y) in dtst2; println(Int.(y[1:5])); break; end
for (x,y) in dtst2; println(Int.(y[1:5])); break; end

[7, 2, 1, 0, 4]
[7, 2, 1, 0, 4]
[7, 2, 1, 0, 4]
[6, 0, 5, 4, 9]
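
The same behavior can be checked without MNIST. The sketch below uses the range `1:10` as a stand-in dataset; `firsts` is a hypothetical helper that gathers the first `n` elements of an iterator.

```julia
using Base.Iterators: Stateful, take

# Gather the first n elements of an iterator with an explicit loop
function firsts(itr, n)
    out = Int[]
    for x in take(itr, n)
        push!(out, x)
    end
    return out
end

plain = 1:10                 # restarts from the beginning on every loop
stateful = Stateful(1:10)    # resumes where the previous loop stopped

a = firsts(plain, 3)
b = firsts(plain, 3)
c = firsts(stateful, 3)
d = firsts(stateful, 3)
```

Here `a` and `b` are identical, while `d` continues from where `c` stopped.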


In [16]:
# We can shuffle instances at every epoch using the keyword argument `shuffle=true`
# (by default elements are generated in the same order)
dtst1 = minibatch(xtst,ytst,100)              # dtst1 iterates in the same order
dtst2 = minibatch(xtst,ytst,100,shuffle=true) # dtst2 shuffles each time
for (x,y) in dtst1; println(Int.(y[1:5])); break; end
for (x,y) in dtst1; println(Int.(y[1:5])); break; end
for (x,y) in dtst2; println(Int.(y[1:5])); break; end
for (x,y) in dtst2; println(Int.(y[1:5])); break; end

[7, 2, 1, 0, 4]
[7, 2, 1, 0, 4]
[3, 3, 8, 0, 4]
[0, 1, 4, 2, 3]
