# Introduction

This notebook sketches experimental data containers: separate training, validation, and test `DataSet`s are combined in a single `DataLoaders` wrapper so that they can be iterated over together.

# DataSet

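Simple dataset types that hold training, validation, and test data. Each split gets its own concrete subtype of the abstract `DataSet`, so methods can dispatch on the kind of split.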

In [1]:
# Common supertype of the training, validation, and test containers below.
abstract type DataSet end

In [2]:
# Container for training data; its type is what training-specific methods dispatch on.
struct TrainDataSet <: DataSet
    x  # untyped; presumably the inputs -- TODO confirm
    y  # untyped; presumably the targets -- TODO confirm
end

In [3]:
# Container for validation data; its type is what validation-specific methods dispatch on.
struct ValidDataSet <: DataSet
    x  # untyped; presumably the inputs -- TODO confirm
    y  # untyped; presumably the targets -- TODO confirm
end

In [4]:
# Container for test data; its type is what test-specific methods dispatch on.
struct TestDataSet <: DataSet
    x  # untyped; presumably the inputs -- TODO confirm
    y  # untyped; presumably the targets -- TODO confirm
end

In [5]:
subtypes(DataSet)

3-element Array{Any,1}:
 TestDataSet 
 TrainDataSet
 ValidDataSet

# DataLoaders

In [6]:
# Wrapper that groups any number of `DataSet`s so they can be handled as one collection.
#
# The datasets are kept, in the order given, as a tuple in the field `ds`.
mutable struct DataLoaders
    # A plain `::DataSet` annotation cannot work here: the varargs constructor collects
    # its arguments into a tuple, so the field holds a tuple of datasets rather than a
    # single dataset. `Tuple{Vararg{DataSet}}` is the annotation that matches.
    ds::Tuple{Vararg{DataSet}}
    # use additional flags, e.g., turn on/off specific datasets?

    # Only `DataSet` arguments are accepted; anything else raises a `MethodError` at
    # construction time. Calling with no arguments yields an empty tuple.
    DataLoaders(ds::DataSet...) = new(ds)
end

In [7]:
# From https://github.com/denizyuret/Knet.jl/blob/master/tutorial/25.iterators.ipynb
# Iteration is delegated to the wrapped tuple `dls.ds`: called without a state it starts
# from the beginning, otherwise it resumes from the state `s`. The inner result, either
# `nothing` or an `(item, state)` pair, is passed through unchanged.
Base.iterate(dls::DataLoaders, s...) = iterate(dls.ds, s...)

# `collect` needs a length for the iterator; report how many datasets are wrapped.
function Base.length(dls::DataLoaders)
    return length(dls.ds)
end

In [8]:
#(dl::DataLoaders)(x) = (for d in dl.datasets; summary(d); end)

In [9]:
#(dl::DataLoaders) = (d for d in dl.datasets)

# Test

## Creation

In [10]:
# Three training folds plus one validation set and one test set, each filled with
# two-element dummy vectors.
fold1 = TrainDataSet([1.0, 2.0], [3.0, 4.0])
fold2 = TrainDataSet([5.0, 6.0], [7.0, 8.0])
fold3 = TrainDataSet([9.0, 10.0], [11.0, 12.0])
valid = ValidDataSet([13.0, 14.0], [15.0, 16.0])
test = TestDataSet([17.0, 18.0], [19.0, 20.0])

TestDataSet([17.0, 18.0], [19.0, 20.0])

In [11]:
dls = DataLoaders(fold1, fold2, fold3, valid, test)

DataLoaders((TrainDataSet([1.0, 2.0], [3.0, 4.0]), TrainDataSet([5.0, 6.0], [7.0, 8.0]), TrainDataSet([9.0, 10.0], [11.0, 12.0]), ValidDataSet([13.0, 14.0], [15.0, 16.0]), TestDataSet([17.0, 18.0], [19.0, 20.0])))

In [12]:
dls

DataLoaders((TrainDataSet([1.0, 2.0], [3.0, 4.0]), TrainDataSet([5.0, 6.0], [7.0, 8.0]), TrainDataSet([9.0, 10.0], [11.0, 12.0]), ValidDataSet([13.0, 14.0], [15.0, 16.0]), TestDataSet([17.0, 18.0], [19.0, 20.0])))

In [13]:
dls.ds

(TrainDataSet([1.0, 2.0], [3.0, 4.0]), TrainDataSet([5.0, 6.0], [7.0, 8.0]), TrainDataSet([9.0, 10.0], [11.0, 12.0]), ValidDataSet([13.0, 14.0], [15.0, 16.0]), TestDataSet([17.0, 18.0], [19.0, 20.0]))

In [14]:
length(dls), length(dls.ds)

(5, 5)

In [15]:
collect(dls)

5-element Array{Any,1}:
 TrainDataSet([1.0, 2.0], [3.0, 4.0])    
 TrainDataSet([5.0, 6.0], [7.0, 8.0])    
 TrainDataSet([9.0, 10.0], [11.0, 12.0]) 
 ValidDataSet([13.0, 14.0], [15.0, 16.0])
 TestDataSet([17.0, 18.0], [19.0, 20.0]) 

## Iteration

In [16]:
@which iterate(dls)

In [17]:
# Print every wrapped dataset, one per line, in iteration order.
foreach(println, dls)

TrainDataSet([1.0, 2.0], [3.0, 4.0])
TrainDataSet([5.0, 6.0], [7.0, 8.0])
TrainDataSet([9.0, 10.0], [11.0, 12.0])
ValidDataSet([13.0, 14.0], [15.0, 16.0])
TestDataSet([17.0, 18.0], [19.0, 20.0])


In [18]:
# Collect the concrete type of each dataset in iteration order.
map(typeof, dls)

5-element Array{DataType,1}:
 TrainDataSet
 TrainDataSet
 TrainDataSet
 ValidDataSet
 TestDataSet 

In [21]:
# Same listing as before, restricted to the training datasets.
[typeof(ds) for ds in dls if ds isa TrainDataSet]

3-element Array{DataType,1}:
 TrainDataSet
 TrainDataSet
 TrainDataSet

In [19]:
# Placeholder method selected for training datasets; it only reports which branch ran.
fit(d::TrainDataSet) = println("Train!")

# Placeholder method selected for validation datasets; it only reports which branch ran.
fit(d::ValidDataSet) = println("Valid!")

# Placeholder method selected for test datasets; it only reports which branch ran.
fit(d::TestDataSet) = println("Test!")

fit (generic function with 3 methods)

In [20]:
# Call `fit` on every dataset purely for its side effect. `foreach` discards the results,
# so no throwaway array of `nothing` values is built, and it returns `nothing`, so the
# cell still displays no value.
foreach(fit, dls)

Train!
Train!
Train!
Valid!
Test!
