Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace fit by build #42

Merged
merged 1 commit into from Jun 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -192,7 +192,7 @@ mutable struct MyNetwork <: MLJFlux.Builder
n2 :: Int
end

function MLJFlux.fit(nn::MyNetwork, n_in, n_out)
function MLJFlux.build(nn::MyNetwork, n_in, n_out)
return Chain(Dense(n_in, nn.n1), Dense(nn.n1, nn.n2), Dense(nn.n2, n_out))
end
```
Expand All @@ -205,8 +205,8 @@ More generally, defining a new builder means defining a new struct
and defining a new `MLJFlux.build` method with one of these signatures:

```julia
MLJFlux.fit(builder::MyNetwork, n_in, n_out)
MLJFlux.fit(builder::MyNetwork, n_in, n_out, n_channels) # for use with `ImageClassifier`
MLJFlux.build(builder::MyNetwork, n_in, n_out)
MLJFlux.build(builder::MyNetwork, n_in, n_out, n_channels) # for use with `ImageClassifier`
```

This method must return a `Flux.Chain` instance, `chain`, subject to the
Expand Down
4 changes: 2 additions & 2 deletions src/classifier.jl
Expand Up @@ -40,7 +40,7 @@ function MLJModelInterface.fit(model::NeuralNetworkClassifier,
n_input = Tables.schema(X).names |> length
levels = MLJModelInterface.classes(y[1])
n_output = length(levels)
chain = Flux.Chain(fit(model.builder, n_input, n_output),
chain = Flux.Chain(build(model.builder, n_input, n_output),
model.finaliser)

data = collate(model, X, y)
Expand Down Expand Up @@ -86,7 +86,7 @@ function MLJModelInterface.update(model::NeuralNetworkClassifier,
chain = old_chain
epochs = model.epochs - old_model.epochs
else
chain = Flux.Chain(fit(model.builder, n_input, n_output),
chain = Flux.Chain(build(model.builder, n_input, n_output),
model.finaliser)
data = collate(model, X, y)
epochs = model.epochs
Expand Down
4 changes: 2 additions & 2 deletions src/core.jl
Expand Up @@ -135,7 +135,7 @@ mutable struct Linear <: Builder
σ
end
Linear(; σ=Flux.relu) = Linear(σ)
fit(builder::Linear, n::Integer, m::Integer) =
build(builder::Linear, n::Integer, m::Integer) =
Flux.Chain(Flux.Dense(n, m, builder.σ))

# baby example 2:
Expand All @@ -145,7 +145,7 @@ mutable struct Short <: Builder
σ
end
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid) = Short(n_hidden, dropout, σ)
function fit(builder::Short, n, m)
function build(builder::Short, n, m)
n_hidden =
builder.n_hidden == 0 ? round(Int, sqrt(n*m)) : builder.n_hidden
return Flux.Chain(Flux.Dense(n, n_hidden, builder.σ),
Expand Down
4 changes: 2 additions & 2 deletions src/image.jl
Expand Up @@ -45,7 +45,7 @@ function MLJModelInterface.fit(model::ImageClassifier, verbosity::Int, X_, y_)
n_channels = 3 # 3-D color image
end

chain = Flux.Chain(fit(model.builder, n_input, n_output, n_channels), model.finaliser)
chain = Flux.Chain(build(model.builder, n_input, n_output, n_channels), model.finaliser)

optimiser = deepcopy(model.optimiser)

Expand Down Expand Up @@ -93,7 +93,7 @@ function MLJModelInterface.update(model::ImageClassifier,
else
n_channels = 3 # 3-D color image
end
chain = Flux.Chain(fit(model.builder, n_input, n_output, n_channels),
chain = Flux.Chain(build(model.builder, n_input, n_output, n_channels),
model.finaliser)
data = collate(model, X, y)
epochs = model.epochs
Expand Down
4 changes: 2 additions & 2 deletions src/regressor.jl
Expand Up @@ -74,7 +74,7 @@ function MLJModelInterface.fit(model::Regressor, verbosity::Int, X, y)
end

n_output = length(target_column_names)
chain = fit(model.builder, n_input, n_output)
chain = build(model.builder, n_input, n_output)

optimiser = deepcopy(model.optimiser)

Expand Down Expand Up @@ -110,7 +110,7 @@ function MLJModelInterface.update(model::Regressor,
chain = old_chain
epochs = model.epochs - old_model.epochs
else
chain = fit(model.builder, n_input, n_output)
chain = build(model.builder, n_input, n_output)
data = collate(model, X, y)
epochs = model.epochs
end
Expand Down
2 changes: 1 addition & 1 deletion test/core.jl
Expand Up @@ -111,7 +111,7 @@ end

@testset "dropout" begin
model = MLJFlux.Short()
chain = MLJFlux.fit(model, 5, 1)
chain = MLJFlux.build(model, 5, 1)

input = rand(5,1)
# At the moment, Dropout is active:
Expand Down
4 changes: 2 additions & 2 deletions test/image.jl
Expand Up @@ -10,7 +10,7 @@ mutable struct mnistclassifier <: MLJFlux.Builder
filters2
end

MLJFlux.fit(model::mynn, ip, op, n_channels) =
MLJFlux.build(model::mynn, ip, op, n_channels) =
Flux.Chain(Flux.Conv(model.kernel1, n_channels=>2),
Flux.Conv(model.kernel2, 2=>1),
x->reshape(x, :, size(x)[end]),
Expand Down Expand Up @@ -60,7 +60,7 @@ end
return reshape(x, :, size(x)[end])
end

function MLJFlux.fit(model::mnistclassifier, ip, op, n_channels)
function MLJFlux.build(model::mnistclassifier, ip, op, n_channels)
cnn_output_size = [3,3,32]

return Chain(
Expand Down