Commit 27f214c

Merge c361812 into 3ffd02b
ayush-1506 committed May 27, 2020
2 parents 3ffd02b + c361812

Showing 7 changed files with 217 additions and 70 deletions.
6 changes: 4 additions & 2 deletions Project.toml

@@ -5,6 +5,7 @@ version = "0.1.0"
 
 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -14,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
-CategoricalArrays = "^0.7.3"
+CategoricalArrays = "^0.8.1"
 Flux = "^0.8.3"
 LossFunctions = "^0.5"
 MLJModelInterface = "^0.2.1"
@@ -25,10 +26,11 @@ julia = "1"
 [extras]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["LinearAlgebra", "MLJBase", "Random", "Statistics", "StatsBase", "Test"]
+test = ["LinearAlgebra", "MLJBase", "MLJScientificTypes", "Random", "Statistics", "StatsBase", "Test"]
1 change: 1 addition & 0 deletions src/MLJFlux.jl

@@ -8,6 +8,7 @@ using ProgressMeter
 using CategoricalArrays
 using Tables
 using Statistics
+using ColorTypes
 
 include("core.jl")
 include("regressor.jl")
62 changes: 58 additions & 4 deletions src/core.jl

@@ -173,29 +173,83 @@ end
 nrows(y::AbstractVector) = length(y)
 
 reformat(X) = reformat(X, scitype(X))
 
+# ---------------------------------
+# Reformatting tables
+
 reformat(X, ::Type{<:Table}) = MLJModelInterface.matrix(X)'
 
+# ---------------------------------
+# Reformatting images
+
+reformat(X, ::Type{<:GrayImage}) =
+    reshape(Float32.(X), size(X)..., 1)
+
+function reformat(X, ::Type{<:AbstractVector{<:GrayImage}})
+    ret = zeros(Float32, size(first(X))..., 1, length(X))
+    for idx = 1:size(ret, 4)
+        ret[:, :, :, idx] .= reformat(X[idx])
+    end
+    return ret
+end
+
+function reformat(X, ::Type{<:ColorImage})
+    ret = zeros(Float32, size(X)..., 3)
+    for w = 1:size(X)[1]
+        for h = 1:size(X)[2]
+            ret[w, h, :] .= Float32.([X[w, h].r, X[w, h].g, X[w, h].b])
+        end
+    end
+    return ret
+end
+
+function reformat(X, ::Type{<:AbstractVector{<:ColorImage}})
+    ret = zeros(Float32, size(first(X))..., 3, length(X))
+    for idx = 1:size(ret, 4)
+        ret[:, :, :, idx] .= reformat(X[idx])
+    end
+    return ret
+end
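
A quick sketch of what the image methods above produce (not part of the diff; `reformat` is internal to MLJFlux, so calling it qualified as below is an assumption, and the element types assume ColorTypes):

    using ColorTypes
    imgs = [rand(Gray{Float32}, 6, 6) for _ in 1:10]  # ten 6x6 grey images
    X = MLJFlux.reformat(imgs)  # dispatches on AbstractVector{<:GrayImage}
    size(X)                     # (6, 6, 1, 10): the WHCN layout Flux expects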

+# ------------------------------------------------------------
+# Reformatting vectors of "scalar" types
+
 reformat(y, ::Type{<:AbstractVector{<:Continuous}}) = y
 function reformat(y, ::Type{<:AbstractVector{<:Finite}})
     levels = y |> first |> MLJModelInterface.classes
     return hcat([Flux.onehot(ele, levels) for ele in y]...,)
 end
 
-get(Xmatrix::AbstractMatrix, b) = Xmatrix[:, b]
-get(y::AbstractVector, b) = y[b]
+function reformat(y, ::Type{<:AbstractVector{<:Count}})
+    levels = y |> first |> MLJModelInterface.classes
+    return hcat([Flux.onehot(ele, levels) for ele in y]...,)
+end
+
+function reformat(y, ::Type{<:AbstractVector{<:Multiclass}})
+    levels = y |> first |> MLJModelInterface.classes
+    return hcat([Flux.onehot(ele, levels) for ele in y]...,)
+end
+
+_get(Xmatrix::AbstractMatrix, b) = Xmatrix[:, b]
+_get(y::AbstractVector, b) = y[b]
+
+# each element in X is a single image of size (w, h, c)
+_get(X::AbstractArray{<:Any, 4}, b) = X[:, :, :, b]
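
For "scalar" targets each observation becomes a one-hot column, so `_get` selects a batch by column index. A minimal sketch, assuming Flux, CategoricalArrays and MLJModelInterface are loaded:

    import MLJModelInterface
    using CategoricalArrays, Flux

    y = categorical([:a, :b, :a])
    levels = MLJModelInterface.classes(first(y))  # the ordered class pool
    Y = hcat([Flux.onehot(ele, levels) for ele in y]...)
    Y[:, 1:2]  # what `_get(Y, 1:2)` returns: the first two observations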


 """
     collate(model, X, y)
 Return the Flux-friendly data object required by `MLJFlux.fit!`, given
 input `X` and target `y` in the form required by
 `MLJModelInterface.input_scitype(X)` and
-`MLJModelInterface.target_sictype(y)`. (The batch size used is given
+`MLJModelInterface.target_scitype(y)`. (The batch size used is given
 by `model.batch_size`.)
 """
 function collate(model, X, y)
     row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
     Xmatrix = reformat(X)
     ymatrix = reformat(y)
-    return [(get(Xmatrix, b), get(ymatrix, b)) for b in row_batches]
+    return [(_get(Xmatrix, b), _get(ymatrix, b)) for b in row_batches]
 end
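
Shapes produced by `collate` for the image case, as exercised by the new test below. A sketch against internal API: `imgs` is any 10-element vector of 6x6 grey images and `labels` any 10-element categorical vector with two classes (both hypothetical names):

    model = MLJFlux.ImageClassifier(batch_size=3)
    data = MLJFlux.collate(model, imgs, labels)
    length(data)          # 4 batches of sizes 3, 3, 3, 1
    size(first(data)[1])  # (6, 6, 1, 3) image tensor
    size(first(data)[2])  # (2, 3) one-hot target block
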
115 changes: 52 additions & 63 deletions src/image.jl

@@ -1,120 +1,109 @@
-mutable struct ImageClassifier{B<:MLJFlux.Builder,O,L} <: MLJModelInterface.Probabilistic
+mutable struct ImageClassifier{B<:Builder,F,O,L} <: MLJModelInterface.Probabilistic
     builder::B
+    finaliser::F
     optimiser::O    # mutable struct from Flux/src/optimise/optimisers.jl
     loss::L         # can be called as in `loss(yhat, y)`
-    n::Int          # number of epochs
+    epochs::Int     # number of epochs
     batch_size::Int # size of a batch
     lambda::Float64 # regularization strength
     alpha::Float64  # regularizaton mix (0 for all l2, 1 for all l1)
     optimiser_changes_trigger_retraining::Bool
 end
 
-ImageClassifier(; builder::B = Linear()
+ImageClassifier(; builder::B = Short()
+              , finaliser::F = Flux.softmax
               , optimiser::O = Flux.Optimise.ADAM()
               , loss::L = Flux.crossentropy
-              , n = 10
+              , epochs = 10
              , batch_size = 1
               , lambda = 0
               , alpha = 0
-              , optimiser_changes_trigger_retraining = false) where {B,O,L} =
-                  ImageClassifier{B,O,L}(builder
+              , optimiser_changes_trigger_retraining = false
+              ) where {B,F,O,L} =
+                  ImageClassifier{B,F,O,L}(builder
+                                         , finaliser
                                          , optimiser
                                          , loss
-                                         , n
+                                         , epochs
                                          , batch_size
                                          , lambda
                                          , alpha
-                                         , optimiser_changes_trigger_retraining)
-
-
-function make_minibatch(X, Y, idxs)
-    X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
-    for i in 1:length(idxs)
-        X_batch[:, :, :, i] = Float32.(X[idxs[i]])
-    end
-    Y_batch = Flux.onehotbatch(Y[idxs], 0:9)
-    return (X_batch, Y_batch)
-end
-
-# This will not only group into batches, but also convert to Flux
-# compatible tensors
-function collate(model::ImageClassifier, X, Y)
-    row_batches = Base.iterators.partition(1:length(X), model.batch_size)
-    return [make_minibatch(X, Y, i) for i in row_batches]
-end
+                                         , optimiser_changes_trigger_retraining
+                                         )
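
Construction with the new keyword names (`epochs` replaces `n`; `finaliser` defaults to `Flux.softmax` and is appended to the builder's chain). A sketch, assuming `Short` is the builder the package exports:

    clf = MLJFlux.ImageClassifier(builder=MLJFlux.Short(),
                                  epochs=5,
                                  batch_size=2)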

 function MLJModelInterface.fit(model::ImageClassifier, verbosity::Int, X_, y_)
 
-    data = collate(model, X_, y_, model.batch_size)
-
-    target_is_multivariate = y_ isa AbstractVector{<:Tuple}
+    data = collate(model, X_, y_)
 
-    a_target_element = first(y_)
-    levels = MLJModelInterface.classes(a_target_element)
+    levels = y_ |> first |> MLJModelInterface.classes
     n_output = length(levels)
     n_input = size(X_[1])
-    chain = fit(model.builder,n_input, n_output)
+
+    chain = Flux.Chain(fit(model.builder, n_input, n_output), model.finaliser)
 
     optimiser = deepcopy(model.optimiser)
 
     chain, history = fit!(chain, optimiser, model.loss,
-                          model.n, model.lambda, model.alpha,
+                          model.epochs, model.lambda, model.alpha,
                           verbosity, data)
 
-    cache = deepcopy(model), data, history
-    fitresult = (chain, target_is_multivariate, levels)
+    cache = deepcopy(model), data, history, n_input, n_output
+    fitresult = (chain, levels)
 
     report = (training_losses=[loss.data for loss in history])
 
     return fitresult, cache, report
 end

 # Xnew is an array of 3D values
 function MLJModelInterface.predict(model::ImageClassifier, fitresult, Xnew)
-    chain = fitresult[1]
-    ismulti = fitresult[2]
-    levels = fitresult[3]
-    return [MLJModelInterface.UnivariateFinite(levels, Flux.softmax(chain(Float64.(Xnew[i])).data)) for i in 1:length(Xnew)]
+
+    chain, levels = fitresult
+    X = reformat(Xnew)
+    [MLJModelInterface.UnivariateFinite(
+        levels,
+        vec(map(x -> x.data,
+                chain(X[:, :, :, idx:idx]))))
+     for idx = 1:length(Xnew)]
 end
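
End-to-end use through the MLJ machinery, mirroring the new test below. A sketch, assuming MLJBase and MLJScientificTypes (for `coerce` and `GrayImage`) are available:

    using MLJBase, MLJScientificTypes, CategoricalArrays
    X = coerce(rand(6, 6, 1, 10), GrayImage)  # a vector of ten grey images
    y = categorical(rand([:a, :b], 10))
    mach = machine(MLJFlux.ImageClassifier(epochs=2, batch_size=2), X, y)
    fit!(mach)
    yhat = predict(mach, X)  # ten UnivariateFinite distributions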

-function MLJModelInterface.update(model::ImageClassifier, verbosity::Int, old_fitresult, old_cache, X, y)
+function MLJModelInterface.update(model::ImageClassifier,
+                                  verbosity::Int,
+                                  old_fitresult,
+                                  old_cache,
+                                  X,
+                                  y)
 
-    old_model, data, old_history = old_cache
-    old_chain, target_is_multivariate = old_fitresult
-    levels = old_fitresult[3]
+    old_model, data, old_history, n_input, n_output = old_cache
+    old_chain, levels = old_fitresult
 
-    keep_chain = model.n >= old_model.n &&
-        model.loss == old_model.loss &&
-        model.batch_size == old_model.batch_size &&
-        model.lambda == old_model.lambda &&
-        model.alpha == old_model.alpha &&
-        model.builder == old_model.builder &&
-        (!model.optimiser_changes_trigger_retraining ||
-         model.optimiser == old_model.optimiser)
+    optimiser_flag = model.optimiser_changes_trigger_retraining &&
+        model.optimiser != old_model.optimiser
+
+    keep_chain = !optimiser_flag && model.epochs >= old_model.epochs &&
+        MLJModelInterface.is_same_except(model, old_model, :optimiser, :epochs)
 
     if keep_chain
         chain = old_chain
-        epochs = model.n - old_model.n
+        epochs = model.epochs - old_model.epochs
     else
-        n_input = Tables.schema(X).names |> length
-        n_output = length(levels)
-        chain = fit(model.builder, n_input, n_output)
-        epochs = model.n
+        chain = Flux.Chain(fit(model.builder, n_input, n_output),
+                           model.finaliser)
+        data = collate(model, X, y)
+        epochs = model.epochs
     end
 
     optimiser = deepcopy(model.optimiser)
 
     chain, history = fit!(chain, optimiser, model.loss, epochs,
-                          model.batch_size, model.lambda, model.alpha,
+                          model.lambda, model.alpha,
                           verbosity, data)
     if keep_chain
         history = vcat(old_history, history)
     end
-    fitresult = (chain, target_is_multivariate, levels)
-    cache = (deepcopy(model), data, history)
-    report = (training_losses=[loss.data for loss in history])
+
+    fitresult = (chain, levels)
+    cache = (deepcopy(model), data, history, n_input, n_output)
+    report = (training_losses=[loss.data for loss in history], )
 
     return fitresult, cache, report
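
With `is_same_except`, a change confined to `epochs` (or to the optimiser, when retraining is not triggered) now warm-restarts training on the existing chain instead of rebuilding it. Continuing the sketch above:

    mach.model.epochs = 4
    fit!(mach)  # `update` runs two more epochs, reusing the old chain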

@@ -124,5 +113,5 @@ MLJModelInterface.metadata_model(ImageClassifier,
                                  input=AbstractVector{<:MLJModelInterface.GrayImage},
                                  target=AbstractVector{<:Multiclass},
                                  path="MLJFlux.ImageClassifier",
-                                 descr = descr="A neural network model for making probabilistic predictions of a `GreyImage` target,
+                                 descr="A neural network model for making probabilistic predictions of a `GrayImage` target,
                                  given a table of `Continuous` features. ")
20 changes: 20 additions & 0 deletions test/core.jl

@@ -54,6 +54,26 @@ end
                   (Xmatrix'[:,4:6], ymatrix'[:,4:6]),
                   (Xmatrix'[:,7:9], ymatrix'[:,7:9]),
                   (Xmatrix'[:,10:10], ymatrix'[:,10:10])]
+
+    # ImageClassifier
+    Xmatrix = coerce(rand(6, 6, 1, 10), GrayImage)
+    y = categorical([:a, :b, :a, :a, :b, :a, :a, :a, :b, :a])
+    model = MLJFlux.ImageClassifier(batch_size=2)
+
+    data = MLJFlux.collate(model, Xmatrix, y)
+    @test first.(data) ==
+        [Float32.(cat(Xmatrix[1], Xmatrix[2], dims=4)),
+         Float32.(cat(Xmatrix[3], Xmatrix[4], dims=4)),
+         Float32.(cat(Xmatrix[5], Xmatrix[6], dims=4)),
+         Float32.(cat(Xmatrix[7], Xmatrix[8], dims=4)),
+         Float32.(cat(Xmatrix[9], Xmatrix[10], dims=4)),
+        ]
+
+    expected_y = [[1 0; 0 1], [1 1; 0 0], [0 1; 1 0], [1 1; 0 0], [0 1; 1 0]]
+    for i = 1:5
+        @test Int.(last.(data)[i]) == expected_y[i]
+    end
+
 end
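
The `expected_y` blocks above are just the one-hot columns for each pair of labels; for the first batch (:a, :b) with class pool [:a, :b]:

    using Flux
    Flux.onehotbatch([:a, :b], [:a, :b])  # Bool matrix [1 0; 0 1]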

 @testset "fit!" begin
