-
Notifications
You must be signed in to change notification settings - Fork 1
/
SpamDetector.cs
73 lines (59 loc) · 3.25 KB
/
SpamDetector.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
using System;
using System.Diagnostics;
using System.IO;
using System.Threading.Tasks;
using Microsoft.ML;
using SentimentAnalysisDemo.ML.Model;
using Volo.Abp;
using Volo.Abp.DependencyInjection;
namespace SentimentAnalysisDemo.ML;
/// <summary>
/// Checks free-text messages for spam using an ML.NET binary classifier
/// (text featurization + SDCA logistic regression) trained from <c>spam_data.csv</c>.
/// </summary>
/// <remarks>
/// The trained model is persisted to <c>spam_data_model.zip</c>. Outside of debug runs that file
/// is used as a cache: when it exists it is loaded instead of retraining on every call.
/// Delete the file to force a retrain after the training data changes.
/// </remarks>
public class SpamDetector : ISpamDetector, ITransientDependency
{
    private static readonly string DataPath = Path.Combine(Environment.CurrentDirectory, "ML", "Data", "spam_data.csv");
    private static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "ML", "Data", "spam_data_model.zip");

    /// <summary>
    /// Classifies <paramref name="text"/> and rejects it when the model considers it spam.
    /// </summary>
    /// <param name="text">The message to classify.</param>
    /// <returns>A task that completes once the message has been checked.</returns>
    /// <exception cref="UserFriendlyException">Thrown when the message is classified as spam.</exception>
    // NOTE: all work here is synchronous; `async` is kept so that failures (including the
    // spam exception) keep being delivered through the returned task, as before.
    public async Task CheckAsync(string text)
    {
        var mlContext = new MLContext();

        // Obtain the model: cached .ZIP when available (non-debug), freshly trained otherwise.
        ITransformer model = GetModel(mlContext);

        // Predict 👇
        var sentimentAnalyzeInput = new SentimentAnalyzeInput
        {
            Message = text
        };

        // The prediction engine is disposable; release it as soon as the prediction is done.
        using var predictionEngine = mlContext.Model.CreatePredictionEngine<SentimentAnalyzeInput, SentimentAnalyzeResult>(model);
        var result = predictionEngine.Predict(sentimentAnalyzeInput);

        if (IsSpam(result))
        {
            throw new UserFriendlyException("Spam detected! Please update the message!");
        }
    }

    /// <summary>
    /// Returns the persisted model when running outside debug and the .ZIP file exists;
    /// otherwise trains a new model (which also persists it).
    /// </summary>
    private static ITransformer GetModel(MLContext mlContext)
    {
        if (!DebugHelper.IsDebug && File.Exists(ModelPath))
        {
            // The input schema returned by Load is not needed for prediction.
            return mlContext.Model.Load(ModelPath, out DataViewSchema _);
        }

        return TrainAndSaveModel(mlContext);
    }

    /// <summary>
    /// Trains the classifier on 80% of the CSV data, evaluates it on the remaining 20%
    /// and saves the trained model to <c>ModelPath</c>.
    /// </summary>
    private static ITransformer TrainAndSaveModel(MLContext mlContext)
    {
        //Step 1: Load Data 👇
        IDataView dataView = mlContext.Data.LoadFromTextFile<SentimentAnalyzeInput>(DataPath, hasHeader: true, separatorChar: ',');

        //Step 2: Split data to train-test data 👇
        DataOperationsCatalog.TrainTestData trainTestSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
        IDataView trainingData = trainTestSplit.TrainSet; //80% of the data.
        IDataView testData = trainTestSplit.TestSet; //20% of the data.

        //Step 3: Common data process configuration with pipeline data transformations + choose and set the training algorithm 👇
        var estimator = mlContext.Transforms.Text.FeaturizeText(outputColumnName: "Features", inputColumnName: nameof(SentimentAnalyzeInput.Message))
            .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(labelColumnName: "Label", featureColumnName: "Features"));

        //Step 4: Train the model 👇
        ITransformer model = estimator.Fit(trainingData);

        //* Evaluate the model and report accuracy stats 👇
        // (the original author observed roughly 0.97 accuracy / 0.91 F1 on their data set)
        var predictions = model.Transform(testData);
        var metrics = mlContext.BinaryClassification.Evaluate(data: predictions, labelColumnName: "Label", scoreColumnName: "Score");
        Debug.WriteLine($"Spam model trained. Accuracy: {metrics.Accuracy}, F1 score: {metrics.F1Score}");

        //* Save/persist the trained model to a .ZIP file so later calls can reuse it. 👇
        mlContext.Model.Save(model, trainingData.Schema, ModelPath);

        return model;
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="result"/> is non-null, its prediction is positive
    /// and its probability is at least 0.5.
    /// </summary>
    private static bool IsSpam(SentimentAnalyzeResult result)
    {
        //1 -> spam / 0 -> ham (for 'Prediction' column)
        return result is { Prediction: true, Probability: >= 0.5f };
    }
}