Skip to content

Commit

Permalink
Merge ed96311 into 7c3875d
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush-1506 committed Mar 31, 2020
2 parents 7c3875d + ed96311 commit 8d4773f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Expand Up @@ -13,7 +13,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CategoricalArrays = "^0.7.3"
Flux = "^0.8.3"
Flux = "^0.10.3"
LossFunctions = "^0.5"
MLJModelInterface = "^0.2"
ProgressMeter = "^1.1"
Expand Down
2 changes: 1 addition & 1 deletion src/classifier.jl
Expand Up @@ -76,7 +76,7 @@ function MLJModelInterface.predict(model::NeuralNetworkClassifier, fitresult, Xn

Xnew_ = MLJModelInterface.matrix(Xnew_)

return [MLJModelInterface.UnivariateFinite(MLJModelInterface.classes(levels), map(x->x.data, Flux.softmax(chain(Xnew_[i, :]))) |> vec) for i in 1:size(Xnew_, 1)]
return [MLJModelInterface.UnivariateFinite(MLJModelInterface.classes(levels), Flux.softmax(chain(Xnew_[i, :])) |> vec) for i in 1:size(Xnew_, 1)]

end

Expand Down
2 changes: 1 addition & 1 deletion src/regressor.jl
Expand Up @@ -114,7 +114,7 @@ function MLJModelInterface.predict(model::Union{NeuralNetworkRegressor, Multivar
Xnew_ = MLJModelInterface.matrix(Xnew_)

if ismulti
ypred = [map(x->x.data, chain(values.(Xnew_[i, :]))) for i in 1:size(Xnew_, 1)]
ypred = [chain(values.(Xnew_[i, :])) for i in 1:size(Xnew_, 1)]
return MLJModelInterface.table(reduce(hcat, y for y in ypred)', names=target_columns)
else
return [chain(values.(Xnew_[i, :]))[1] for i in 1:size(Xnew_, 1)]
Expand Down

0 comments on commit 8d4773f

Please sign in to comment.