-
Notifications
You must be signed in to change notification settings - Fork 53
/
ImageClassificationPredictor.cs
84 lines (67 loc) · 3.71 KB
/
ImageClassificationPredictor.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
using System;
using System.IO;
using chapter12.wpf.ML.Base;
using chapter12.wpf.ML.Objects;
using Microsoft.ML;
namespace chapter12.wpf.ML
{
    /// <summary>
    /// Image classifier built by transfer learning on the Inception TensorFlow graph.
    /// Each image is loaded, resized, converted to pixels and scored by the graph.
    /// The graph's <c>softmax2_pre_activation</c> output is used as the feature vector
    /// for an L-BFGS maximum-entropy multiclass trainer. The fitted pipeline is saved to
    /// <c>chapter12.mdl</c> and reloaded on later runs instead of retraining.
    /// </summary>
    public class ImageClassificationPredictor : BaseML
    {
        // Training Variables
        private static readonly string _assetsPath = Path.Combine(Environment.CurrentDirectory, "assets");
        private static readonly string _imagesFolder = Path.Combine(_assetsPath, "images");

        // Training labels, read with LoadFromTextFile (no header row). The column layout is
        // declared on ImageDataInputItem; presumably "<image path>\t<label>" — confirm there.
        private readonly string _trainTagsTsv = Path.Combine(_imagesFolder, "tags.tsv");
        private readonly string _inceptionTensorFlowModel = Path.Combine(_assetsPath, "inception", "tensorflow_inception_graph.pb");

        // Tensor names inside the Inception graph: TF_SOFTMAX is the output used as features,
        // INPUT is the graph input (and also the name of the image/pixel column in the pipeline).
        private const string TF_SOFTMAX = "softmax2_pre_activation";
        private const string INPUT = "input";

        // Location of the persisted ML.NET model; if present, Initialize loads it instead of training.
        private static readonly string ML_NET_MODEL = Path.Combine(Environment.CurrentDirectory, "chapter12.mdl");

        // Fitted pipeline; null until Initialize() succeeds.
        private ITransformer _model;

        // Image preprocessing parameters applied before the pixels are fed to the Inception graph.
        private struct InceptionSettings
        {
            public const int ImageHeight = 224;
            public const int ImageWidth = 224;
            public const float Mean = 117;
            public const float Scale = 1;
            public const bool ChannelsLast = true;
        }

        /// <summary>
        /// Classifies the image at <paramref name="filePath"/>.
        /// </summary>
        /// <param name="filePath">
        /// Image path. The LoadImages step is configured with <c>imageFolder: assets/images</c>,
        /// so relative paths are presumably resolved against that folder — confirm.
        /// </param>
        /// <returns>The prediction produced by the fitted model.</returns>
        public ImageDataPredictionItem Predict(string filePath) =>
            Predict(new ImageDataInputItem
            {
                ImagePath = filePath
            });

        /// <summary>
        /// Prepares the model for prediction.
        /// If <c>chapter12.mdl</c> exists it is loaded; otherwise a new model is trained from
        /// <c>assets/images/tags.tsv</c> and saved to <c>chapter12.mdl</c>.
        /// </summary>
        /// <returns>
        /// <c>(true, string.Empty)</c> on success. On any failure it returns <c>(false, details)</c>,
        /// where <c>details</c> is the failing exception's full <see cref="Exception.ToString"/> text.
        /// Exceptions are reported through the tuple rather than thrown.
        /// </returns>
        public (bool Success, string Exception) Initialize()
        {
            try
            {
                if (File.Exists(ML_NET_MODEL))
                {
                    // The saved input schema is not needed: the prediction engine is built from the CLR types.
                    _model = MlContext.Model.Load(ML_NET_MODEL, out _);
                    return (true, string.Empty);
                }

                IEstimator<ITransformer> pipeline = MlContext.Transforms.LoadImages(outputColumnName: INPUT, imageFolder: _imagesFolder, inputColumnName: nameof(ImageDataInputItem.ImagePath))
                    .Append(MlContext.Transforms.ResizeImages(outputColumnName: INPUT, imageWidth: InceptionSettings.ImageWidth, imageHeight: InceptionSettings.ImageHeight, inputColumnName: INPUT))
                    .Append(MlContext.Transforms.ExtractPixels(outputColumnName: INPUT, interleavePixelColors: InceptionSettings.ChannelsLast, offsetImage: InceptionSettings.Mean, scaleImage: InceptionSettings.Scale))
                    .Append(MlContext.Model.LoadTensorFlowModel(_inceptionTensorFlowModel)
                        .ScoreTensorFlowModel(outputColumnNames: new[] { TF_SOFTMAX }, inputColumnNames: new[] { INPUT }, addBatchDimensionInput: true))
                    .Append(MlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelKey", inputColumnName: nameof(ImageDataPredictionItem.Label)))
                    // Cache the featurized rows right before the trainer. L-BFGS makes several passes
                    // over the data; without a checkpoint here, each pass would reload the images and
                    // re-run the TensorFlow graph. A checkpoint at the end of the chain has no
                    // downstream estimator and therefore no effect.
                    .AppendCacheCheckpoint(MlContext)
                    .Append(MlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelKey", featureColumnName: TF_SOFTMAX))
                    .Append(MlContext.Transforms.Conversion.MapKeyToValue(nameof(ImageDataPredictionItem.PredictedLabelValue), "PredictedLabel"));

                IDataView trainingData = MlContext.Data.LoadFromTextFile<ImageDataInputItem>(path: _trainTagsTsv, hasHeader: false);

                _model = pipeline.Fit(trainingData);

                MlContext.Model.Save(_model, trainingData.Schema, ML_NET_MODEL);

                return (true, string.Empty);
            }
            catch (Exception ex)
            {
                return (false, ex.ToString());
            }
        }

        /// <summary>
        /// Runs a single prediction for <paramref name="image"/> against the fitted model.
        /// </summary>
        /// <param name="image">The input row to classify.</param>
        /// <returns>The prediction produced by the fitted model.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="image"/> is null.</exception>
        /// <exception cref="InvalidOperationException">
        /// No model is available, because <see cref="Initialize"/> was not called or did not succeed.
        /// </exception>
        public ImageDataPredictionItem Predict(ImageDataInputItem image)
        {
            if (image == null)
            {
                throw new ArgumentNullException(nameof(image));
            }

            if (_model == null)
            {
                throw new InvalidOperationException($"{nameof(Initialize)} must complete successfully before calling {nameof(Predict)}.");
            }

            // PredictionEngine is IDisposable. A fresh engine is created per call, so it must be
            // disposed here or each prediction leaks one.
            using (var predictor = MlContext.Model.CreatePredictionEngine<ImageDataInputItem, ImageDataPredictionItem>(_model))
            {
                return predictor.Predict(image);
            }
        }
    }
}