forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BinaryClassificationExperiment.cs
73 lines (66 loc) · 2.8 KB
/
BinaryClassificationExperiment.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
namespace Microsoft.ML.Auto
{
/// <summary>
/// Settings for a binary classification experiment.
/// </summary>
public sealed class BinaryExperimentSettings : ExperimentSettings
{
    /// <summary>
    /// Metric the experiment optimizes for.
    /// Defaults to <see cref="BinaryClassificationMetric.Accuracy"/>.
    /// </summary>
    public BinaryClassificationMetric OptimizingMetric { get; set; } = BinaryClassificationMetric.Accuracy;

    /// <summary>
    /// Trainers the experiment may use. Initially contains every declared
    /// <see cref="BinaryClassificationTrainer"/> value. The property has no setter,
    /// but the collection it returns is a mutable list, so callers can add or remove entries.
    /// </summary>
    public ICollection<BinaryClassificationTrainer> Trainers { get; } = CreateDefaultTrainers();

    // Builds a new list (one per settings instance) holding every trainer, in Enum.GetValues order.
    private static List<BinaryClassificationTrainer> CreateDefaultTrainers()
    {
        Array declaredTrainers = Enum.GetValues(typeof(BinaryClassificationTrainer));
        return declaredTrainers.Cast<BinaryClassificationTrainer>().ToList();
    }
}
/// <summary>
/// Metrics a binary classification experiment can optimize for, and by which
/// the best run can be selected.
/// </summary>
public enum BinaryClassificationMetric
{
    /// <summary>Accuracy.</summary>
    Accuracy = 0,

    /// <summary>Area under the ROC curve.</summary>
    AreaUnderRocCurve = 1,

    /// <summary>Area under the precision-recall curve.</summary>
    AreaUnderPrecisionRecallCurve = 2,

    /// <summary>F1 score.</summary>
    F1Score = 3,

    /// <summary>Precision for the positive class.</summary>
    PositivePrecision = 4,

    /// <summary>Recall for the positive class.</summary>
    PositiveRecall = 5,

    /// <summary>Precision for the negative class.</summary>
    NegativePrecision = 6,

    /// <summary>Recall for the negative class.</summary>
    NegativeRecall = 7
}
/// <summary>
/// Trainers a binary classification experiment can choose from.
/// </summary>
public enum BinaryClassificationTrainer
{
    /// <summary>Averaged perceptron trainer.</summary>
    AveragedPerceptron = 0,

    /// <summary>Fast forest trainer.</summary>
    FastForest = 1,

    /// <summary>Fast tree trainer.</summary>
    FastTree = 2,

    /// <summary>LightGBM trainer.</summary>
    LightGbm = 3,

    /// <summary>Linear support vector machines trainer.</summary>
    LinearSupportVectorMachines = 4,

    /// <summary>L-BFGS logistic regression trainer.</summary>
    LbfgsLogisticRegression = 5,

    /// <summary>SDCA logistic regression trainer.</summary>
    SdcaLogisticRegression = 6,

    /// <summary>Calibrated SGD trainer.</summary>
    SgdCalibrated = 7,

    /// <summary>Symbolic SGD logistic regression trainer.</summary>
    SymbolicSgdLogisticRegression = 8
}
/// <summary>
/// Binary classification experiment built on <c>ExperimentBase</c>, reporting
/// <see cref="BinaryClassificationMetrics"/>.
/// </summary>
public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics>
{
    /// <summary>
    /// Wires the experiment up from <paramref name="settings"/>.
    /// </summary>
    /// <param name="context">ML context, forwarded both to the base class and to the metrics agent.</param>
    /// <param name="settings">
    /// Experiment settings. <see cref="BinaryExperimentSettings.OptimizingMetric"/> configures both the
    /// metrics agent and the optimizing-metric info, and <see cref="BinaryExperimentSettings.Trainers"/>
    /// is converted to trainer names via <c>TrainerExtensionUtil.GetTrainerNames</c>.
    /// </param>
    internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings)
        : base(
            context,
            new BinaryMetricsAgent(context, settings.OptimizingMetric),
            new OptimizingMetricInfo(settings.OptimizingMetric),
            settings,
            TaskKind.BinaryClassification,
            TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
    {
        // All initialization is performed by the base constructor.
    }
}
/// <summary>
/// Extension methods for choosing the best run from binary classification experiment results.
/// </summary>
public static class BinaryExperimentResultExtensions
{
    /// <summary>
    /// Returns the run that <c>BestResultUtil.GetBestRun</c> selects from <paramref name="results"/>,
    /// ranking by <paramref name="metric"/> and respecting whether that metric is maximized.
    /// </summary>
    /// <param name="results">Run details to choose from.</param>
    /// <param name="metric">Metric used for ranking; defaults to <see cref="BinaryClassificationMetric.Accuracy"/>.</param>
    /// <returns>The best run according to <paramref name="metric"/>.</returns>
    public static RunDetail<BinaryClassificationMetrics> Best(this IEnumerable<RunDetail<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
        // The metrics agent is constructed without an MLContext (null is passed).
        => BestResultUtil.GetBestRun(
            results,
            new BinaryMetricsAgent(null, metric),
            new OptimizingMetricInfo(metric).IsMaximizing);

    /// <summary>
    /// Returns the cross-validation run that <c>BestResultUtil.GetBestRun</c> selects from
    /// <paramref name="results"/>, ranking by <paramref name="metric"/> and respecting whether
    /// that metric is maximized.
    /// </summary>
    /// <param name="results">Cross-validation run details to choose from.</param>
    /// <param name="metric">Metric used for ranking; defaults to <see cref="BinaryClassificationMetric.Accuracy"/>.</param>
    /// <returns>The best cross-validation run according to <paramref name="metric"/>.</returns>
    public static CrossValidationRunDetail<BinaryClassificationMetrics> Best(this IEnumerable<CrossValidationRunDetail<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
        // The metrics agent is constructed without an MLContext (null is passed).
        => BestResultUtil.GetBestRun(
            results,
            new BinaryMetricsAgent(null, metric),
            new OptimizingMetricInfo(metric).IsMaximizing);
}
}