//*****************************************************************************************
//* *
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
//* *
//*****************************************************************************************
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using SentimentModel.Model.DataModels;
using Microsoft.ML.Trainers;
namespace SentimentModel.ConsoleApp
{
public static class ModelBuilder
{
// Training data (SMALL DATASET in GitHub repo). The path is relative to the folder that
// contains the built assembly; it is resolved through GetAbsolutePath before use.
private const string TRAIN_DATA_FILEPATH = @"../../../../../Data/yelp_labelled.tsv";
// Alternative dataset locations (swap the constant above to use one of them):
// SMALL DATASET in AZURE FILES: @"X:\yelp\yelp_labelled.tsv"
// LARGE DATASET in AZURE FILES: @"X:\twitter\Twittersentiment-1Million.tsv"

// Destination of the serialized model, also relative to the assembly folder.
// Never reassigned, so it is a compile-time constant like the training path.
private const string MODEL_FILEPATH = @"../../../../SentimentModel.Model/MLModel.zip";

// MLContext shared across the model creation workflow objects.
// A fixed random seed gives repeatable/deterministic results across multiple trainings.
// readonly: the shared context must not be swapped out after initialization.
private static readonly MLContext mlContext = new MLContext(seed: 1);
/// <summary>
/// Runs the whole model-creation workflow: loads the training data, builds the
/// pipeline, cross-validates it, trains on the full data set and saves the model.
/// </summary>
public static void CreateModel()
{
    // 1. Load the tab-separated training file (it has a header row).
    string trainingDataPath = GetAbsolutePath(TRAIN_DATA_FILEPATH);
    var trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
        path: trainingDataPath,
        hasHeader: true,
        separatorChar: '\t',
        allowQuoting: true,
        allowSparse: false);

    // 2. Assemble the transforms + trainer.
    var trainingPipeline = BuildTrainingPipeline(mlContext);

    // 3. Report quality metrics before the final fit.
    Evaluate(mlContext, trainingDataView, trainingPipeline);

    // 4. Fit the pipeline on the complete data set.
    var mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

    // 5. Persist the fitted model together with the input schema.
    string modelPath = GetAbsolutePath(MODEL_FILEPATH);
    SaveModel(mlContext, mlModel, modelPath, trainingDataView.Schema);
}
/// <summary>
/// Builds the (unfitted) training pipeline: text featurization of the "text" column,
/// min-max normalization of the resulting "Features" column, a cache checkpoint and
/// a calibrated SGD binary classifier that learns the "sentiment" label.
/// </summary>
/// <param name="mlContext">Context used to create the transforms and the trainer.</param>
/// <returns>The estimator chain ready to be fitted or cross-validated.</returns>
public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
{
    // Data transformations: "text" -> "text_tf" -> "Features" (normalized in place).
    var featurization = mlContext.Transforms.Text.FeaturizeText("text_tf", "text")
        .Append(mlContext.Transforms.CopyColumns("Features", "text_tf"))
        .Append(mlContext.Transforms.NormalizeMinMax("Features", "Features"))
        .AppendCacheCheckpoint(mlContext);

    // Training algorithm and its hyper-parameters.
    var sgdOptions = new SgdCalibratedTrainer.Options
    {
        L2Regularization = 1E-06f,
        ConvergenceTolerance = 0.0001f,
        NumberOfIterations = 10,
        Shuffle = true,
        LabelColumnName = "sentiment",
        FeatureColumnName = "Features"
    };
    var sgdTrainer = mlContext.BinaryClassification.Trainers.SgdCalibrated(sgdOptions);

    return featurization.Append(sgdTrainer);
}
/// <summary>
/// Fits the given pipeline on the training data, logging the start and end of training
/// to the console.
/// </summary>
/// <param name="mlContext">ML.NET context (not used by the fit itself).</param>
/// <param name="trainingDataView">Data the pipeline is fitted on.</param>
/// <param name="trainingPipeline">Estimator chain to fit.</param>
/// <returns>The fitted transformer (the trained model).</returns>
public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
{
    Console.WriteLine("=============== Training model ===============");

    var fittedModel = trainingPipeline.Fit(trainingDataView);

    Console.WriteLine("=============== End of training process ===============");
    return fittedModel;
}
/// <summary>
/// Cross-validates the pipeline with 5 folds on the single available data set
/// (there is no separate evaluation set) and prints the averaged accuracy metrics.
/// </summary>
/// <param name="mlContext">Context providing the binary-classification catalog.</param>
/// <param name="trainingDataView">Data that is split into folds.</param>
/// <param name="trainingPipeline">Estimator chain fitted once per fold.</param>
private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
{
    Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");

    PrintBinaryClassificationFoldsAverageMetrics(
        mlContext.BinaryClassification.CrossValidateNonCalibrated(
            trainingDataView,
            trainingPipeline,
            numberOfFolds: 5,
            labelColumnName: "sentiment"));
}
/// <summary>
/// Persists the trained model and its input schema to a .ZIP file and reports the
/// destination on the console.
/// </summary>
/// <param name="mlContext">Context whose model catalog performs the save.</param>
/// <param name="mlModel">Fitted model to serialize.</param>
/// <param name="modelRelativePath">Destination path; it is passed through GetAbsolutePath before use.</param>
/// <param name="modelInputSchema">Schema of the data the model expects as input.</param>
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
{
    Console.WriteLine("=============== Saving the model ===============");

    // Resolve the destination once and reuse it for both the save and the log line.
    string destination = GetAbsolutePath(modelRelativePath);
    mlContext.Model.Save(mlModel, modelInputSchema, destination);

    Console.WriteLine("The model is saved to {0}", destination);
}
/// <summary>
/// Combines the folder that contains the assembly defining <c>Program</c> with the
/// given path.
/// </summary>
/// <param name="relativePath">Path to append to the assembly folder.</param>
/// <returns>The combined path (not normalized; ".." segments are kept as-is).</returns>
public static string GetAbsolutePath(string relativePath)
{
    var assemblyFile = new FileInfo(typeof(Program).Assembly.Location);
    return Path.Combine(assemblyFile.Directory.FullName, relativePath);
}
/// <summary>
/// Writes a small console report with the accuracy and the area under the ROC curve
/// of a binary classification model, both formatted as percentages with two decimals.
/// </summary>
/// <param name="metrics">Metrics to print.</param>
public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics)
{
    // Header
    Console.WriteLine("************************************************************");
    Console.WriteLine("* Metrics for binary classification model ");
    Console.WriteLine("*-----------------------------------------------------------");

    // Values
    Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}");
    Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}");

    // Footer
    Console.WriteLine("************************************************************");
}
/// <summary>
/// Prints the accuracy averaged over all cross-validation folds, together with its
/// sample standard deviation and 95% confidence interval.
/// </summary>
/// <param name="crossValResults">Per-fold cross-validation results.</param>
public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>> crossValResults)
{
    // Materialize the per-fold accuracies once: the LINQ projection is deferred and
    // would otherwise be re-evaluated by every aggregate below.
    double[] accuracyValues = crossValResults.Select(result => result.Metrics.Accuracy).ToArray();

    double accuracyAverage = accuracyValues.Average();
    double accuracyStdDeviation = CalculateStandardDeviation(accuracyValues);
    double accuracyConfidenceInterval95 = CalculateConfidenceInterval95(accuracyValues);

    Console.WriteLine("*************************************************************************************************************");
    Console.WriteLine("* Metrics for Binary Classification model ");
    Console.WriteLine("*------------------------------------------------------------------------------------------------------------");
    // All three values use "0.###": the "#" placeholder would drop the leading zero
    // (".012") and print nothing at all for a value of exactly 0.
    Console.WriteLine($"* Average Accuracy: {accuracyAverage:0.###} - Standard deviation: ({accuracyStdDeviation:0.###}) - Confidence Interval 95%: ({accuracyConfidenceInterval95:0.###})");
    Console.WriteLine("*************************************************************************************************************");
}
/// <summary>
/// Computes the sample standard deviation (n - 1 denominator) of the given values.
/// </summary>
/// <param name="values">Values to measure; the sequence is enumerated exactly once.</param>
/// <returns>
/// The sample standard deviation, or <see cref="double.NaN"/> when there is a single
/// value (0 / 0), since the sample deviation is undefined in that case.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
/// <exception cref="InvalidOperationException"><paramref name="values"/> is empty.</exception>
public static double CalculateStandardDeviation(IEnumerable<double> values)
{
    // Snapshot the input: it may be a deferred (or one-shot) sequence, and the
    // computation needs the mean, the squared deviations and the count.
    double[] samples = values.ToArray();

    double mean = samples.Average();
    double sumOfSquaredDeviations = samples.Sum(sample => (sample - mean) * (sample - mean));

    return Math.Sqrt(sumOfSquaredDeviations / (samples.Length - 1));
}
/// <summary>
/// Computes the half-width of the 95% confidence interval of the mean of the given
/// values, using the normal approximation: 1.96 * s / sqrt(n), where s is the sample
/// standard deviation and n the number of values.
/// </summary>
/// <param name="values">Values whose mean is being estimated.</param>
/// <returns>The half-width of the interval (mean ± returned value).</returns>
public static double CalculateConfidenceInterval95(IEnumerable<double> values)
{
    // Snapshot once so the deviation and the count are taken from the same data.
    double[] samples = values.ToArray();

    // Standard error of the mean. The sample standard deviation already carries the
    // (n - 1) correction, so the divisor here is sqrt(n), not sqrt(n - 1).
    double standardError = CalculateStandardDeviation(samples) / Math.Sqrt(samples.Length);

    return 1.96 * standardError;
}
}
}