# ML.NET - Samples - Large Datasets

# Spam Detection for Text Messages

| ML.NET version | API type          | Status                        | App Type    | Data type | Scenario            | ML Task                   | Algorithms                  |
|----------------|-------------------|-------------------------------|-------------|-----------|---------------------|---------------------------|-----------------------------|
| v1.4           | Dynamic API | Might need to update project structure to match template | Jupyter Notebook | .tsv files | Spam detection | Two-class classification | Averaged Perceptron (linear learner) |

In this sample, you'll see how to use [ML.NET](https://www.microsoft.com/net/learn/apps/machine-learning-and-ai/ml-dotnet) to predict whether a text message is spam. In the world of machine learning, this type of prediction is known as **binary classification**.

## Problem

Our goal here is to predict whether a text message is spam (an irrelevant or unwanted message). We will use the [SMS Spam Collection Data Set](https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection) from UCI, which contains close to 6,000 messages labeled as "spam" or "ham" (not spam). We will use this dataset to train a model that can take in new messages and predict whether they are spam.

This is an example of binary classification, as we are classifying the text messages into one of two categories.


## Solution

To solve this problem, we will first build an estimator to define the ML pipeline we want to use. Then we will train this estimator on existing data, evaluate how good it is, and lastly consume the model to predict whether a few example messages are spam.

![Build -> Train -> Evaluate -> Consume](../shared_content/modelpipeline.png)

### 1. Build Model

To build the model we will:

* Define how to read the spam dataset that will be downloaded from https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection.

* Apply two data transformations:

    * Convert the label ("spam" or "ham") to a boolean (`true` represents spam) so we can use it with a binary classifier.
    * Featurize the text message into a numeric vector so a machine learning trainer can use it.

* Add a trainer (such as `StochasticDualCoordinateAscent`).
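The three steps listed above can be sketched on a toy dataset as follows. This is only an illustration under stated assumptions: the `SmsMessage` and `SmsPrediction` classes, the column names and the six messages are invented for the sketch, and `SdcaLogisticRegression` is used as a stand-in SDCA-based binary trainer from ML.NET 1.x. It is not the pipeline that the rest of this notebook runs.

```csharp
#r "nuget:Microsoft.ML"

using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

// Hypothetical input/output classes for this sketch (not part of the notebook).
public class SmsMessage
{
    public string LabelText;   // "spam" or "ham"
    public string Message;
}

public class SmsPrediction
{
    [ColumnName("PredictedLabel")]
    public bool IsSpam;
}

var sketchContext = new MLContext(seed: 1);

// Tiny made-up training set; a real run would load the UCI file instead.
var toyMessages = new List<SmsMessage>
{
    new SmsMessage { LabelText = "spam", Message = "WINNER! Claim your free prize now" },
    new SmsMessage { LabelText = "spam", Message = "Free entry, text WIN to claim cash" },
    new SmsMessage { LabelText = "spam", Message = "Urgent: your prize is waiting, call now" },
    new SmsMessage { LabelText = "ham",  Message = "Are we still meeting for lunch today?" },
    new SmsMessage { LabelText = "ham",  Message = "I'll call you when I get home" },
    new SmsMessage { LabelText = "ham",  Message = "Can you send me the report tomorrow?" },
};
IDataView toyData = sketchContext.Data.LoadFromEnumerable(toyMessages);

// Step 1 of the transformations: map the text label to a boolean (true = spam).
var labelMap = new Dictionary<string, bool> { ["spam"] = true, ["ham"] = false };

var spamPipeline = sketchContext.Transforms.Conversion.MapValue("Label", labelMap, "LabelText")
    // Step 2: turn the message text into a numeric feature vector.
    .Append(sketchContext.Transforms.Text.FeaturizeText("Features", "Message"))
    // Step 3: add a binary classification trainer.
    .Append(sketchContext.BinaryClassification.Trainers.SdcaLogisticRegression(
        labelColumnName: "Label", featureColumnName: "Features"));

// Train on the toy data and score one new message.
var spamModel = spamPipeline.Fit(toyData);
var spamEngine = sketchContext.Model.CreatePredictionEngine<SmsMessage, SmsPrediction>(spamModel);
var spamResult = spamEngine.Predict(new SmsMessage { Message = "Free prize, claim now" });
Console.WriteLine($"Predicted spam: {spamResult.IsSpam}");
```

With so few rows the prediction itself is not meaningful; the point is only the shape of the estimator chain.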

The initial code is similar to the following:


In [3]:
// ML.NET Nuget packages installation
#r "nuget:Microsoft.ML" 
#r "nuget:SharpZipLib" 

Installed package SharpZipLib version 1.2.0

## Using C# Class

In [4]:
using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using Microsoft.ML;
using Microsoft.ML.Data;
using System.Collections.Generic;
using static Microsoft.ML.TrainCatalogBase;
using static Microsoft.ML.DataOperationsCatalog;
using System.Diagnostics;

## Declare data-classes for input data and predictions

In [18]:
public class UrlData
{
    [LoadColumn(0)]
    public string LabelColumn;
    
    [LoadColumn(1, 3231961)]
    [VectorType(3231961)]
    public float[] FeatureVector;
}

public class UrlPrediction
{
    // ColumnName attribute is used to change the column name from
    // its default value, which is the name of the field.
    [ColumnName("PredictedLabel")]
    public bool Prediction;
    
    public float Score;
}

### Constants

In [20]:
static string originalDataDirectoryPath = @"./datasets/LargeDatasets/OriginalUrlData";
static string originalDataPath = @"./datasets/LargeDatasets/OriginalUrlData/url_svmlight";
static string preparedDataPath = @"./datasets/LargeDatasets/PreparedUrlData/url_svmlight";

### ConsoleHelper

In [21]:
public static class ConsoleHelper
{
    public static void PrintPrediction(string prediction)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"Predicted : {prediction}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintRegressionPredictionVersusObserved(string predictionCount, string observedCount)
    {
        Console.WriteLine($"-------------------------------------------------");
        Console.WriteLine($"Predicted : {predictionCount}");
        Console.WriteLine($"Actual:     {observedCount}");
        Console.WriteLine($"-------------------------------------------------");
    }

    public static void PrintRegressionMetrics(string name, RegressionMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} regression model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       LossFn:        {metrics.LossFunction:0.##}");
        Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
        Console.WriteLine($"*       Absolute loss: {metrics.MeanAbsoluteError:#.##}");
        Console.WriteLine($"*       Squared loss:  {metrics.MeanSquaredError:#.##}");
        Console.WriteLine($"*       RMS loss:      {metrics.RootMeanSquaredError:#.##}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintBinaryClassificationMetrics(string name, CalibratedBinaryClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*       Metrics for {name} binary classification model      ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"*       Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"*       Area Under Curve:      {metrics.AreaUnderRocCurve:P2}");
        Console.WriteLine($"*       Area under Precision recall Curve:  {metrics.AreaUnderPrecisionRecallCurve:P2}");
        Console.WriteLine($"*       F1Score:  {metrics.F1Score:P2}");
        Console.WriteLine($"*       LogLoss:  {metrics.LogLoss:#.##}");
        Console.WriteLine($"*       LogLossReduction:  {metrics.LogLossReduction:#.##}");
        Console.WriteLine($"*       PositivePrecision:  {metrics.PositivePrecision:#.##}");
        Console.WriteLine($"*       PositiveRecall:  {metrics.PositiveRecall:#.##}");
        Console.WriteLine($"*       NegativePrecision:  {metrics.NegativePrecision:#.##}");
        Console.WriteLine($"*       NegativeRecall:  {metrics.NegativeRecall:P2}");
        Console.WriteLine($"************************************************************");
    }

    public static void PrintMultiClassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*    Metrics for {name} multi-class classification model   ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"    AccuracyMacro = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    AccuracyMicro = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
        Console.WriteLine($"************************************************************");
    }
    
    public static void PrintRegressionFoldsAverageMetrics(string algorithmName, IReadOnlyList<CrossValidationResult<RegressionMetrics>> crossValidationResults)
    {
        var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
        var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
        var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
        var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
        var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Regression model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average L1 Loss:    {L1.Average():0.###} ");
        Console.WriteLine($"*       Average L2 Loss:    {L2.Average():0.###}  ");
        Console.WriteLine($"*       Average RMS:          {RMS.Average():0.###}  ");
        Console.WriteLine($"*       Average Loss Function: {lossFunction.Average():0.###}  ");
        Console.WriteLine($"*       Average R-squared: {R2.Average():0.###}  ");
        Console.WriteLine($"*************************************************************************************************************");
    }
    
    public static void PrintMulticlassClassificationFoldsAverageMetrics(
                                     string algorithmName,
                                   IReadOnlyList<CrossValidationResult<MulticlassClassificationMetrics>> crossValResults
                                                                       )
    {
        var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

        var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
        var microAccuracyAverage = microAccuracyValues.Average();
        var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
        var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

        var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
        var macroAccuracyAverage = macroAccuracyValues.Average();
        var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
        var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

        var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
        var logLossAverage = logLossValues.Average();
        var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
        var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

        var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
        var logLossReductionAverage = logLossReductionValues.Average();
        var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
        var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Multi-class Classification model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average MacroAccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLoss:          {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###})  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLossReduction: {logLossReductionAverage:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
        Console.WriteLine($"*************************************************************************************************************");
    }    

    public static double CalculateStandardDeviation (IEnumerable<double> values)
    {
        double average = values.Average();
        double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
        double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count()-1));
        return standardDeviation;
    }

    public static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count()-1));
        return confidenceInterval95;
    }

    public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} clustering model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       Average Distance: {metrics.AverageDistance}");
        Console.WriteLine($"*       Davies Bouldin Index is: {metrics.DaviesBouldinIndex}");
        Console.WriteLine($"*************************************************");
    }    
    
    public static void PeekDataViewInConsole(MLContext mlContext, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
    {
        string msg = string.Format("Peek data in DataView: Showing {0} rows with the columns", numberOfRows.ToString());
        ConsoleWriteHeader(msg);

        //https://github.com/dotnet/machinelearning/blob/master/docs/code/MlNetCookBook.md#how-do-i-look-at-the-intermediate-data
        var transformer = pipeline.Fit(dataView);
        var transformedData = transformer.Transform(dataView);

        // 'transformedData' is a 'promise' of data (lazy loading). Call Preview
        // and iterate through the returned collection.

        var preViewTransformedData = transformedData.Preview(maxRows: numberOfRows);

        foreach (var row in preViewTransformedData.RowView)
        {
            var ColumnCollection = row.Values;
            string lineToPrint = "Row--> ";
            foreach (KeyValuePair<string, object> column in ColumnCollection)
            {
                lineToPrint += $"| {column.Key}:{column.Value}";
            }
            Console.WriteLine(lineToPrint + "\n");
        }
    }
    
    public static void PeekVectorColumnDataInConsole(MLContext mlContext, string columnName, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
    {
        string msg = string.Format("Peek data in DataView: Show {0} rows with just the '{1}' column", numberOfRows, columnName);
        ConsoleWriteHeader(msg);

        var transformer = pipeline.Fit(dataView);
        var transformedData = transformer.Transform(dataView);

        // Extract the requested column (named by 'columnName').
        var someColumnData = transformedData.GetColumn<float[]>(columnName)
                                                    .Take(numberOfRows).ToList();

        // print to console the peeked rows

        int currentRow = 0;
        someColumnData.ForEach(row => {
                                        currentRow++;
                                        String concatColumn = String.Empty;
                                        foreach (float f in row)
                                        {
                                            concatColumn += f.ToString();                                              
                                        }

                                        Console.WriteLine();
                                        string rowMsg = string.Format("**** Row {0} with '{1}' field value ****", currentRow, columnName);
                                        Console.WriteLine(rowMsg);
                                        Console.WriteLine(concatColumn);
                                        Console.WriteLine();
                                      });
    }
    
    public static void ConsoleWriteHeader(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine(" ");
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
        var maxLength = lines.Select(x => x.Length).Max();
        Console.WriteLine(new string('#', maxLength));
        Console.ForegroundColor = defaultColor;
    }

    public static void ConsoleWriterSection(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Blue;
        Console.WriteLine(" ");
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
        var maxLength = lines.Select(x => x.Length).Max();
        Console.WriteLine(new string('-', maxLength));
        Console.ForegroundColor = defaultColor;
    }
    
}
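To make the two statistics helpers in `ConsoleHelper` more concrete, here is a small worked example. It mirrors `CalculateStandardDeviation` and `CalculateConfidenceInterval95` as they are written above (note that the interval helper divides by the square root of `n - 1`, where many textbooks use `n`). The function names and the three per-fold accuracies are hypothetical values chosen for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sample standard deviation: divides by (n - 1), as CalculateStandardDeviation does.
static double SampleStdDev(IReadOnlyCollection<double> values)
{
    double mean = values.Average();
    double sumOfSquares = values.Sum(v => (v - mean) * (v - mean));
    return Math.Sqrt(sumOfSquares / (values.Count - 1));
}

// Half-width of the 95% interval, mirroring CalculateConfidenceInterval95,
// which divides by sqrt(n - 1) rather than the more common sqrt(n).
static double HalfWidth95(IReadOnlyCollection<double> values) =>
    1.96 * SampleStdDev(values) / Math.Sqrt(values.Count - 1);

// Hypothetical per-fold accuracies, invented for this example.
var foldAccuracies = new List<double> { 0.90, 0.92, 0.94 };

double sd = SampleStdDev(foldAccuracies);
double hw = HalfWidth95(foldAccuracies);
Console.WriteLine($"mean={foldAccuracies.Average():0.###} sd={sd:0.###} ci95=+/-{hw:0.###}");
```

For these three values the deviations from the mean are -0.02, 0 and 0.02, so the standard deviation is roughly 0.02 and the half-width is roughly 0.028.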

### Methods

In [22]:
public static void DownloadDataset(string originalDataDirectoryPath)
{
    if (!Directory.Exists(originalDataDirectoryPath))
    {
        Console.WriteLine("====Downloading and extracting data====");
        // The download is a gzipped tar archive, so keep that in the local file name.
        string archivePath = "url_svmlight.tar.gz";
        using (var client = new WebClient())
        {
            // The code below downloads a dataset from a third party (UCI), which may be governed by separate third-party terms.
            // By proceeding, you agree to those separate terms.
            client.DownloadFile("https://archive.ics.uci.edu/ml/machine-learning-databases/url/url_svmlight.tar.gz", archivePath);
        }

        // 'using' disposes the streams and the archive even if extraction throws.
        using (Stream inputStream = File.OpenRead(archivePath))
        using (Stream gzipStream = new GZipInputStream(inputStream))
        using (TarArchive tarArchive = TarArchive.CreateInputTarArchive(gzipStream))
        {
            tarArchive.ExtractContents(originalDataDirectoryPath);
        }
        Console.WriteLine("====Downloading and extracting is completed====");
    }
}

private static void PrepareDataset(string originalDataPath, string preparedDataPath)
{
    // Create the folder for the prepared data if it does not exist.
    if (!Directory.Exists(preparedDataPath))
    {
        Directory.CreateDirectory(preparedDataPath);
    }
    Console.WriteLine("====Preparing Data====");
    Console.WriteLine("");
    // The ML.NET API checks for a number-of-features column before the sparse matrix data.
    // So take all the files from originalDataPath, add the total number of features (3231961)
    // as the second column, and save those files in preparedDataPath.
    if (Directory.GetFiles(preparedDataPath).Length == 0)
    {
        var ext = new List<string> { ".svm" };
        var filesInDirectory = Directory.GetFiles(originalDataPath, "*.*", SearchOption.AllDirectories)
                                    .Where(s => ext.Contains(Path.GetExtension(s)));
        foreach (var file in filesInDirectory)
        {
            AddFeaturesColumn(Path.GetFullPath(file), preparedDataPath);
        }
    }
    Console.WriteLine("====Data Preparation is done====");
    Console.WriteLine("");
    Console.WriteLine("original data path= {0}", originalDataPath);
    Console.WriteLine("");
    Console.WriteLine("prepared data path= {0}", preparedDataPath);
    Console.WriteLine("");
}
        
private static void AddFeaturesColumn(string sourceFilePath,string preparedDataPath)
{
    string sourceFileName = Path.GetFileName(sourceFilePath);
    string preparedFilePath = Path.Combine(preparedDataPath, sourceFileName);

    //if the file does not exist in preparedFilePath then copy from sourceFilePath and then add new column
    if (!File.Exists(preparedFilePath))
    {
        File.Copy(sourceFilePath, preparedFilePath, true);
    }
    string newColumnData =  "3231961";            
    string[] CSVDump = File.ReadAllLines(preparedFilePath);            
    List<List<string>> CSV = CSVDump.Select(x => x.Split(' ').ToList()).ToList();
    for (int i = 0; i < CSV.Count; i++)
    {
        CSV[i].Insert(1, newColumnData);
    }
   
    File.WriteAllLines(preparedFilePath, CSV.Select(x => string.Join('\t', x)));
}

private static List<UrlData> CreateSingleDataSample(MLContext mlContext, IDataView dataView)
{
    // Here you could provide new test data (hardcoded or from the end-user application) instead of rows taken from the file.
    List<UrlData> sampleForPredictions = mlContext.Data.CreateEnumerable<UrlData>(dataView, false).Take(4).ToList();
    return sampleForPredictions;
}  
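To show what the preparation step does to a single row, the sketch below applies the same transformation as `AddFeaturesColumn` to one line. It also includes a streaming variant based on `File.ReadLines`, which may be gentler on memory for large files than `File.ReadAllLines`. The names `AddFeatureCount` and `AddFeatureCountToFile` and the sample row are hypothetical; they are not part of the notebook.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

// Rewrites one space-separated svmlight row as a tab-separated row with the
// feature count inserted as the second column.
static string AddFeatureCount(string svmLightLine, string featureCount)
{
    var fields = svmLightLine.Split(' ').ToList();
    fields.Insert(1, featureCount);
    return string.Join("\t", fields);
}

// Hypothetical streaming variant: File.ReadLines yields rows lazily, so the
// whole source file is never held in memory at once.
// sourcePath and targetPath must be different files.
static void AddFeatureCountToFile(string sourcePath, string targetPath, string featureCount)
{
    File.WriteAllLines(
        targetPath,
        File.ReadLines(sourcePath).Select(line => AddFeatureCount(line, featureCount)));
}

// A made-up row in the "label index:value ..." shape used by the dataset.
string sampleRow = "+1 4:0.0788 5:0.1241 11:0.4286";
Console.WriteLine(AddFeatureCount(sampleRow, "3231961"));
```

The printed row keeps the label first, has `3231961` as its second tab-separated field, and carries the original `index:value` pairs after it.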

## Load Data and Train the Model

In [24]:
//Step 1: Download dataset
DownloadDataset(originalDataDirectoryPath);

//Step 2: Prepare data by adding second column with value total number of features.
PrepareDataset(originalDataPath, preparedDataPath);

MLContext mlContext = new MLContext();

//Step 3: Common data loading configuration
var fullDataView = mlContext.Data.LoadFromTextFile<UrlData>(path: Path.Combine(preparedDataPath, "*"),
                                          hasHeader: false,
                                          allowSparse: true);

//Step 4: Divide the whole dataset into 80% training and 20% testing data.
TrainTestData trainTestData = mlContext.Data.TrainTestSplit(fullDataView, testFraction: 0.2, seed: 1);
IDataView trainDataView = trainTestData.TrainSet;
IDataView testDataView = trainTestData.TestSet;

//Step 5: Map label value from string to bool
var UrlLabelMap = new Dictionary<string, bool>();
UrlLabelMap["+1"] = true; //Malicious url
UrlLabelMap["-1"] = false; //Benign 
var dataProcessingPipeLine = mlContext.Transforms.Conversion.MapValue("LabelKey", UrlLabelMap, "LabelColumn");
ConsoleHelper.PeekDataViewInConsole(mlContext, trainDataView, dataProcessingPipeLine, 2);   

//Step 6: Append trainer to pipeline
var trainingPipeLine = dataProcessingPipeLine.Append(
    mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(labelColumnName: "LabelKey", featureColumnName: "FeatureVector"));                     

//Step 7: Train the model
Console.WriteLine("====Training the model=====");            
var mlModel = trainingPipeLine.Fit(trainDataView);
Console.WriteLine("====Completed Training the model=====");
Console.WriteLine("");


====Preparing Data====

====Data Preparation is done====

original data path= ./datasets/LargeDatasets/OriginalUrlData/url_svmlight

prepared data path= ./datasets/LargeDatasets/PreparedUrlData/url_svmlight



Unhandled exception: System.ArgumentOutOfRangeException: File does not exist at path: ./datasets/LargeDatasets/PreparedUrlData/url_svmlight\* (Parameter 'path')
   at Microsoft.ML.TextLoaderSaverCatalog.LoadFromTextFile[TInput](DataOperationsCatalog catalog, String path, Char separatorChar, Boolean hasHeader, Boolean allowQuoting, Boolean trimWhitespace, Boolean allowSparse)
   at Submission#30.<<Initialize>>d__0.MoveNext()
--- End of stack trace from previous location where exception was thrown ---
   at Microsoft.CodeAnalysis.Scripting.ScriptExecutionState.RunSubmissionsAsync[TResult](ImmutableArray`1 precedingExecutors, Func`2 currentExecutor, StrongBox`1 exceptionHolderOpt, Func`2 catchExceptionOpt, CancellationToken cancellationToken)
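
The exception above is raised by `LoadFromTextFile` when the wildcard path matches no file. One plausible cause, judging only from the cell output: no `====Downloading and extracting data====` line appears, so the download seems to have been skipped, and `PrepareDataset` writes nothing when it finds no `.svm` files. A guard like the one sketched below would surface that situation with a clearer message before loading. The helper name `GetPreparedDataPattern` and the temporary demo folder are assumptions made for this example.

```csharp
using System;
using System.IO;
using System.Linq;

// Hypothetical helper: returns the wildcard path for the loader, or throws a
// descriptive error when the prepared folder is missing or empty.
static string GetPreparedDataPattern(string preparedDataPath)
{
    if (!Directory.Exists(preparedDataPath) ||
        !Directory.EnumerateFiles(preparedDataPath).Any())
    {
        throw new InvalidOperationException(
            $"No prepared files found in '{preparedDataPath}'. " +
            "Check that the download and extraction step produced .svm files.");
    }
    return Path.Combine(preparedDataPath, "*");
}

// Demo on a temporary folder: first empty, then with one prepared file.
string demoDir = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
Directory.CreateDirectory(demoDir);
try
{
    GetPreparedDataPattern(demoDir);
}
catch (InvalidOperationException ex)
{
    Console.WriteLine(ex.Message);   // the folder is empty, so the guard throws
}
File.WriteAllText(Path.Combine(demoDir, "Day0.svm"), "+1\t3231961\t4:0.5");
string demoPattern = GetPreparedDataPattern(demoDir);
Console.WriteLine(demoPattern);
Directory.Delete(demoDir, recursive: true);
```

In the training cell, the result of such a helper could be passed as the `path` argument of `LoadFromTextFile` in place of `Path.Combine(preparedDataPath, "*")`.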