Skip to content

Commit

Permalink
Extra argument for imageclassifier fit
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush-1506 committed Jun 8, 2020
1 parent 6b19067 commit 1c32af1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
15 changes: 13 additions & 2 deletions src/image.jl
Expand Up @@ -39,7 +39,13 @@ function MLJModelInterface.fit(model::ImageClassifier, verbosity::Int, X_, y_)
n_output = length(levels)
n_input = size(X_[1])

chain = Flux.Chain(fit(model.builder, n_input, n_output), model.finaliser)
if scitype(first(X_)) <: GrayImage{A, B} where A where B
n_channels = 1 # 1-D image
else
n_channels = 3 # 3-D color image
end

chain = Flux.Chain(fit(model.builder, n_input, n_output, n_channels), model.finaliser)

optimiser = deepcopy(model.optimiser)

Expand Down Expand Up @@ -82,7 +88,12 @@ function MLJModelInterface.update(model::ImageClassifier,
chain = old_chain
epochs = model.epochs - old_model.epochs
else
chain = Flux.Chain(fit(model.builder, n_input, n_output),
if scitype(first(X)) <: GrayImage{A, B} where A where B
n_channels = 1 # 1-D image
else
n_channels = 3 # 3-D color image
end
chain = Flux.Chain(fit(model.builder, n_input, n_output, n_channels),
model.finaliser)
data = collate(model, X, y)
epochs = model.epochs
Expand Down
45 changes: 43 additions & 2 deletions test/image.jl
Expand Up @@ -10,9 +10,14 @@ mutable struct mnistclassifier <: MLJFlux.Builder
filters2
end

mutable struct colorimageclassifier <: MLJFlux.Builder
kernel1
kernel2
end

@testset "ImageClassifier" begin

MLJFlux.fit(model::mynn, ip, op) =
MLJFlux.fit(model::mynn, ip, op, n_channels) =
Flux.Chain(Flux.Conv(model.kernel1, 1=>2),
Flux.Conv(model.kernel2, 2=>1),
x->reshape(x, :, size(x)[end]),
Expand Down Expand Up @@ -60,7 +65,7 @@ end
return reshape(x, :, size(x)[end])
end

function MLJFlux.fit(model::mnistclassifier, ip, op)
function MLJFlux.fit(model::mnistclassifier, ip, op, n_channels)
cnn_output_size = [3,3,32]

return Chain(
Expand All @@ -81,4 +86,40 @@ end

end

@testset "ColorImages" begin
MLJFlux.fit(model::colorimageclassifier, ip, op, n_channels) =
Flux.Chain(Flux.Conv(model.kernel1, n_channels=>2),
Flux.Conv(model.kernel2, 2=>1),
x->reshape(x, :, size(x)[end]),
Flux.Dense(16, op))

builder = colorimageclassifier((2,2), (2,2))
model = MLJFlux.ImageClassifier(builder=builder, epochs=10)

# collection of color images as a 4D array in WHCN format:
raw_images = rand(6, 6, 3, 50);

images = coerce(raw_images, ColorImage);
@test scitype(images) == AbstractVector{ColorImage{6,6}}

labels = categorical(rand(1:5, 50));

fitresult, cache, report = MLJBase.fit(model, 3, images, labels)

pred = MLJBase.predict(model, fitresult, images[1:6])

model.epochs = 15
MLJBase.update(model, 3, fitresult, cache, images, labels)

pred = MLJBase.predict(model, fitresult, images[1:6])

# try with batch_size > 1:
model = MLJFlux.ImageClassifier(builder=builder, epochs=10, batch_size=2)
fitresult, cache, report = MLJBase.fit(model, 3, images, labels);

# tests update logic, etc (see test_utililites.jl):
@test basictest(MLJFlux.ImageClassifier, images, labels,
model.builder, model.optimiser, 0.95)
end

true

0 comments on commit 1c32af1

Please sign in to comment.