Commit 48f2335
Merge pull request #343 from alan-turing-institute/small-stuff
Small stuff
ablaom committed Nov 19, 2019
2 parents 158ec65 + 78d9469 commit 48f2335
Showing 5 changed files with 85 additions and 23 deletions.
docs/src/common_mlj_workflows.md (2 changes: 1 addition & 1 deletion)

@@ -37,7 +37,7 @@ Loading a built-in supervised dataset:

```@example workflows
X, y = @load_iris;
first(X, 4)
selectrows(X, 1:4) # selectrows works for any Tables.jl table
```

```@example workflows
docs/src/machines.md (20 changes: 20 additions & 0 deletions)

@@ -88,6 +88,26 @@ the same fields listed above. See [Composing
Models](composing_models.md) for more on this advanced feature.


### Inspecting machines

There are two methods for inspecting the outcomes of training in
MLJ. To obtain a named tuple describing the learned parameters (in a
user-friendly way where possible), use `fitted_params(mach)`. All
other training-related outcomes are inspected with `report(mach)`.

```@example machines
X, y = @load_iris
pca = @load PCA
mach = machine(pca, X)
fit!(mach)
```

```@repl machines
fitted_params(mach)
report(mach)
```


### API Reference

```@docs
src/MLJ.jl (35 changes: 31 additions & 4 deletions)

@@ -38,19 +38,46 @@ export nrows, nfeatures, color_off, color_on,
Found, Continuous, Finite, Infinite,
OrderedFactor, Unknown,
Count, Multiclass, Binary, Scientific,
scitype, scitype_union, schema, scitypes,
scitype, scitype_union, schema, scitypes, autotype,
target_scitype, input_scitype, output_scitype,
predict, predict_mean, predict_median, predict_mode,
transform, inverse_transform, se, evaluate, fitted_params,
@constant, @more, HANDLE_GIVEN_ID, UnivariateFinite,
classes,
partition, unpack,
mav, mae, rms, rmsl, rmslp1, rmsp, l1, l2,
misclassification_rate, cross_entropy,
default_measure,
default_measure, measures,
@load_boston, @load_ames, @load_iris, @load_reduced_ames,
@load_crabs

# measures to be re-exported from MLJBase:
export mav, mae, rms, rmsl, rmslp1, rmsp, l1, l2
# -- confmat (measures/confusion_matrix)
export confusion_matrix, confmat
# -- finite (measures/finite)
export cross_entropy, BrierScore,
misclassification_rate, mcr, accuracy,
balanced_accuracy, bacc, bac,
matthews_correlation, mcc
# -- -- binary // order independent
export auc, roc_curve, roc
# -- -- binary // order dependent
export TruePositive, TrueNegative, FalsePositive, FalseNegative,
TruePositiveRate, TrueNegativeRate, FalsePositiveRate, FalseNegativeRate,
FalseDiscoveryRate, Precision, NPV, FScore,
# standard synonyms
TPR, TNR, FPR, FNR,
FDR, PPV,
Recall, Specificity, BACC,
# defaults and their synonyms
truepositive, truenegative, falsepositive, falsenegative,
truepositive_rate, truenegative_rate, falsepositive_rate,
falsenegative_rate, negativepredictive_value,
positivepredictive_value,
tp, tn, fp, fn, tpr, tnr, fpr, fnr,
falsediscovery_rate, fdr, npv, ppv,
recall, sensitivity, hit_rate, miss_rate,
specificity, selectivity, f1score, f1, fallout

# re-export from MLJModels:
export models, localmodels, @load, load, info,
ConstantRegressor, ConstantClassifier, # builtins/Constant.jl
src/tuning.jl (42 changes: 27 additions & 15 deletions)

@@ -105,6 +105,11 @@ If `measure` supports sample weights (`MLJ.supports_weights(measure)
In the case of two-parameter tuning, a Plots.jl plot of performance
estimates is returned by `plot(mach)` or `heatmap(mach)`.

Once a tuning machine `mach` has been trained as above, one can access
the learned parameters of the best model, using
`fitted_params(mach).best_fitted_params`. Similarly, the report of
training the best model is accessed via `report(mach).best_report`.
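
For example (a minimal sketch only: the `KNNRegressor` model, the
`ranges` keyword and the various parameter values are illustrative
assumptions, not requirements):

```julia
using MLJ

X, y = @load_boston

knn = @load KNNRegressor
r = range(knn, :K, lower=1, upper=10)
tuned_knn = TunedModel(model=knn, tuning=Grid(resolution=10),
                       resampling=CV(nfolds=3), ranges=r, measure=rms)
mach = machine(tuned_knn, X, y)
fit!(mach)

fitted_params(mach).best_fitted_params  # learned parameters of the best model
report(mach).best_report                # training report of the best model
```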
"""
function TunedModel(;model=nothing,
tuning=Grid(),
@@ -291,26 +296,25 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
# TODO: maybe avoid using machines here and use model fit/predict?
fitresult = machine(best_model, X, y)
fit!(fitresult, verbosity=verbosity-1)
best_report = fitresult.report
else
fitresult = tuned_model.model
best_report = missing
end

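    # fields common to both the full and the abbreviated report: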
pre_report = (parameter_names= permutedims(parameter_names), # row vector
parameter_scales=permutedims(scales), # row vector
best_measurement=best_measurement,
best_report=best_report)

if tuned_model.full_report
report = (# models=models,
# best_model=best_model,
parameter_names= permutedims(parameter_names), # row vector
parameter_scales=permutedims(scales), # row vector
parameter_values=A,
measurements=measurements,
best_measurement=best_measurement)
report = merge(pre_report,
(parameter_values=A,
measurements=measurements,))
else
report = (# models=[deepcopy(clone),][1:0], # empty vector
# best_model=best_model,
parameter_names= permutedims(parameter_names), # row vector
parameter_scales=permutedims(scales), # row vector
parameter_values=A[1:0,1:0], # empty matrix
measurements=[best_measurement, ][1:0], # empty vector
best_measurement=best_measurement)
report = merge(pre_report,
(parameter_values=missing,
measurements=missing,))
end

cache = nothing
@@ -319,7 +323,15 @@

end

MLJBase.fitted_params(::EitherTunedModel, fitresult) = (best_model=fitresult.model,)
function MLJBase.fitted_params(tuned_model::EitherTunedModel, fitresult)
if tuned_model.train_best
return (best_model=fitresult.model,
best_fitted_params=fitted_params(fitresult))
else
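        # `train_best=false` means the best model was never actually
        # trained, so there are no learned parameters to expose: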
return (best_model=fitresult.model,
best_fitted_params=missing)
end
end

MLJBase.predict(tuned_model::EitherTunedModel, fitresult, Xnew) = predict(fitresult, Xnew)
MLJBase.best(model::EitherTunedModel, fitresult) = fitresult.model
test/tuning.jl (9 changes: 6 additions & 3 deletions)

@@ -41,12 +41,15 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.2*rand(100);
tuned = machine(tuned_model, X, y)

fit!(tuned)
report(tuned)
r = report(tuned)
@test r.best_report isa NamedTuple{(:machines, :reports)}
tuned_model.full_report=true
fit!(tuned)
report(tuned)

b = fitted_params(tuned).best_model
fp = fitted_params(tuned)
@test fp.best_fitted_params isa NamedTuple{(:machines, :fitted_params)}
b = fp.best_model
@test b isa MLJ.SimpleDeterministicCompositeModel

measurements = tuned.report.measurements
# should be all different:
