# ML.NET - Samples - Customer Segmentation

## Customer Segmentation - Clustering sample

| ML.NET version | API type          | Status                        | App Type    | Data type | Scenario            | ML Task                   | Algorithms                  |
|----------------|-------------------|-------------------------------|-------------|-----------|---------------------|---------------------------|-----------------------------|
| v1.5          | Dynamic API | Up-to-date | Jupyter Notebook | .csv files | Customer segmentation | Clustering | K-means++ |

You want to **identify groups of customers with similar profiles** so that you can target them afterwards (for example, with a different marketing campaign for each identified group of customers with similar characteristics).

The problem to solve is how to identify different groups of customers with similar profiles and interests without any pre-existing category list. You are *not* classifying customers into a list of categories, because your customers are not *labeled*. You just need to group customers into clusters that the company can then use for other business purposes.

## Dataset

In this hypothetical case, the data to process comes from 'The Wine Company'. It consists of the history of offers/deals the company made in past marketing campaigns, plus the history of purchases made by customers.

The training dataset is located in the `datasets/CustomerSegmentation/inputs` folder and is split into two files. The offers file contains information about past marketing campaigns with specific offers/deals:

|Offer #|Campaign|Varietal|Minimum Qty (kg)|Discount (%)|Origin|Past Peak|
|-------|--------|--------|----------------|------------|------|---------|
|1|January|Malbec|72|56|France|FALSE|
|2|January|Pinot Noir|72|17|France|FALSE|
|3|February|Espumante|144|32|Oregon|TRUE|
|4|February|Champagne|72|48|France|TRUE|
|5|February|Cabernet Sauvignon|144|44|New Zealand|TRUE|

The transactions file contains information about customer purchases (related to the mentioned offers):

|Customer Last Name|Offer #|
|------------------|-------|
|Smith|2|
|Smith|24|
|Johnson|17|
|Johnson|24|
|Johnson|26|
|Williams|18|
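K-Means needs one numeric vector per customer, so the notebook turns the two files into a pivot table: one row per customer and one column per offer (`C1` to `C32` in the `PivotData` class), holding how many times the customer took that offer. The minimal sketch below shows the idea in plain C# using only the sample transactions listed above. `BuildPivot` is a hypothetical helper, not part of the sample, and it skips the join with the offers file, which in the notebook only drops transactions that do not reference a known offer.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sample transactions copied from the table above: (customer last name, offer #)
var transactions = new List<(string LastName, int OfferId)>
{
    ("Smith", 2), ("Smith", 24),
    ("Johnson", 17), ("Johnson", 24), ("Johnson", 26),
    ("Williams", 18),
};

// Hypothetical helper: one row per customer, one column per offer (1..offerCount).
// Each cell counts how many times the customer took that offer (0 if never).
static Dictionary<string, float[]> BuildPivot(
    IEnumerable<(string LastName, int OfferId)> rows, int offerCount)
{
    return rows
        .GroupBy(r => r.LastName)
        .ToDictionary(
            g => g.Key,
            g =>
            {
                var vector = new float[offerCount];
                foreach (var r in g)
                    vector[r.OfferId - 1] += 1f; // offer #1 is stored at index 0
                return vector;
            });
}

var pivot = BuildPivot(transactions, offerCount: 32);
foreach (var (name, vector) in pivot)
    Console.WriteLine($"{name}: {string.Join(",", vector)}");
```

Each printed line corresponds to one row of the `pivot.csv` file that the notebook generates before training, minus the trailing `LastName` column.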

This dataset comes from John Foreman's book titled [Data Smart](http://www.john-foreman.com/data-smart-book.html). 

## ML Task - [Clustering](https://en.wikipedia.org/wiki/Cluster_analysis)

The ML task to solve this kind of problem is called **Clustering**.

By applying ML clustering techniques, you can identify similar customers and group them into clusters without pre-existing categories or historical labeled data. Clustering is a good way to identify groups of 'related or similar things' without any pre-existing category list. That is precisely the main difference between *clustering* and *classification*.

The algorithm used for this task in this particular sample is *K-Means*. In short, this algorithm assigns the samples in the dataset to **k** clusters:
* *K-Means* does not figure out the optimal number of clusters, so **k** is a parameter of the algorithm
* *K-Means* minimizes the sum of squared distances between each point and the centroid (mean) of its cluster
* All points belonging to a cluster have similar properties (but these properties do not necessarily map directly to the features used for training, and are often the subject of further data analysis)
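To make the bullets above concrete, here is a minimal, dependency-free K-Means sketch in C#. K-means++ seeding (the variant listed in the table at the top) picks the initial centroids, then Lloyd iterations alternate between assigning each point to its nearest centroid and moving each centroid to the mean of its points. The `KMeans`, `Nearest` and `SqDist` helpers are hypothetical and illustrative only; ML.NET's `mlContext.Clustering.Trainers.KMeans` uses its own optimized implementation.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Squared Euclidean distance between two points of the same dimension.
static double SqDist(double[] a, double[] b)
{
    double sum = 0;
    for (int i = 0; i < a.Length; i++) sum += (a[i] - b[i]) * (a[i] - b[i]);
    return sum;
}

// Index of the centroid closest to point p (ties go to the lower index).
static int Nearest(double[] p, double[][] centroids)
{
    int best = 0;
    for (int c = 1; c < centroids.Length; c++)
        if (SqDist(p, centroids[c]) < SqDist(p, centroids[best])) best = c;
    return best;
}

// Hypothetical helper: K-Means with K-means++ seeding followed by Lloyd iterations.
static (double[][] Centroids, int[] Labels) KMeans(double[][] points, int k, int seed, int maxIterations = 100)
{
    var rng = new Random(seed);

    // K-means++ seeding: first centroid uniformly at random, each next one with
    // probability proportional to its squared distance to the nearest chosen centroid.
    var seeds = new List<double[]> { points[rng.Next(points.Length)] };
    while (seeds.Count < k)
    {
        double[] weights = points.Select(p => seeds.Min(s => SqDist(p, s))).ToArray();
        double target = rng.NextDouble() * weights.Sum();
        int chosen = 0;
        double cumulative = weights[0];
        while (cumulative <= target && chosen < points.Length - 1)
        {
            chosen++;
            cumulative += weights[chosen];
        }
        seeds.Add(points[chosen]);
    }

    double[][] centroids = seeds.Select(s => (double[])s.Clone()).ToArray();
    int[] labels = new int[points.Length];
    for (int iter = 0; iter < maxIterations; iter++)
    {
        // Assignment step: every point joins its nearest centroid.
        int[] newLabels = points.Select(p => Nearest(p, centroids)).ToArray();
        bool changed = iter == 0 || !newLabels.SequenceEqual(labels);
        labels = newLabels;
        if (!changed) break; // converged: assignments are stable

        // Update step: move each centroid to the mean of its members.
        for (int c = 0; c < k; c++)
        {
            double[][] members = points.Where((p, i) => labels[i] == c).ToArray();
            if (members.Length == 0) continue; // leave an empty cluster's centroid in place
            centroids[c] = Enumerable.Range(0, points[0].Length)
                                     .Select(d => members.Average(m => m[d]))
                                     .ToArray();
        }
    }
    return (centroids, labels);
}

// Usage: two well-separated groups of 2-D points.
double[][] data =
{
    new[] { 0.0, 0.0 }, new[] { 0.0, 1.0 }, new[] { 1.0, 0.0 },
    new[] { 10.0, 10.0 }, new[] { 10.0, 11.0 }, new[] { 11.0, 10.0 },
};
var (finalCentroids, clusterIds) = KMeans(data, k: 2, seed: 1);
Console.WriteLine(string.Join(",", clusterIds));
```

Cluster ids in this sketch are 0-based array indices. In the notebook, `SelectedClusterId` comes from ML.NET's `PredictedLabel` key column, whose values start at 1.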

Plotting a chart of the clusters helps you visually identify which number of clusters works best for your data, depending on how clearly separated the clusters appear. Once you decide on the number of clusters, you can give each cluster a meaningful name and use each customer group for any business purpose.

The following picture shows a sample data distribution with clusters, and how K-Means is able to recover those clusters.

![](../shared_content/k-means.png)

The figure above raises a question: how can we plot samples formed by many features in a 2-dimensional space? This problem is called "dimensionality reduction": each sample lives in a space with one dimension per feature (here, one per offer, 32 in total), so we need a function that "translates" each observation from that space into another space with far fewer features (in our case, only two: X and Y). Here we use a common technique called PCA (principal component analysis), but other techniques, such as SVD, can be used for the same purpose.
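As a rough illustration of the PCA step, the sketch below centers the data, builds the sample covariance matrix, finds the top principal directions with power iteration plus deflation, and projects every sample onto them. `PcaProject` is a hypothetical helper written for illustration; ML.NET's `ProjectToPrincipalComponents` (used later with `rank: 2`) has its own implementation, so its coordinates can differ, for example in the sign of each axis.

```csharp
using System;
using System.Linq;

// Hypothetical helper: project each row onto its top `rank` principal components.
static double[][] PcaProject(double[][] rows, int rank, int iterations = 500)
{
    int n = rows.Length, d = rows[0].Length;

    // 1. Center every feature on its mean.
    double[] mean = Enumerable.Range(0, d).Select(j => rows.Average(r => r[j])).ToArray();
    double[][] x = rows.Select(r => r.Select((value, j) => value - mean[j]).ToArray()).ToArray();

    // 2. Sample covariance matrix (d x d).
    var cov = new double[d, d];
    for (int a = 0; a < d; a++)
        for (int b = 0; b < d; b++)
            cov[a, b] = x.Sum(r => r[a] * r[b]) / (n - 1);

    // 3. Power iteration finds the dominant eigenvector; deflation removes it
    //    from the covariance matrix so the next pass finds the next one.
    var components = new double[rank][];
    for (int c = 0; c < rank; c++)
    {
        double[] v = Enumerable.Repeat(1.0 / Math.Sqrt(d), d).ToArray(); // arbitrary unit start vector
        double eigenvalue = 0;
        for (int it = 0; it < iterations; it++)
        {
            var w = new double[d];
            for (int a = 0; a < d; a++)
                for (int b = 0; b < d; b++)
                    w[a] += cov[a, b] * v[b];
            double norm = Math.Sqrt(w.Sum(t => t * t));
            if (norm < 1e-10) { v = new double[d]; eigenvalue = 0; break; } // no variance left
            v = w.Select(t => t / norm).ToArray();
            eigenvalue = norm;
        }
        components[c] = v;
        for (int a = 0; a < d; a++)
            for (int b = 0; b < d; b++)
                cov[a, b] -= eigenvalue * v[a] * v[b];
    }

    // 4. Coordinates of each centered row along the components (X and Y for rank 2).
    return x.Select(r => components.Select(comp => r.Zip(comp, (p, q) => p * q).Sum()).ToArray()).ToArray();
}

// Usage: four points on the line y = 2x.
double[][] samples =
{
    new[] { 0.0, 0.0 }, new[] { 1.0, 2.0 }, new[] { 2.0, 4.0 }, new[] { 3.0, 6.0 },
};
double[][] projected = PcaProject(samples, rank: 2);
foreach (var point in projected)
    Console.WriteLine($"X = {point[0]:F3}, Y = {point[1]:F3}");
```

Because all of the variance of these points lies along one direction, the X coordinates come out symmetric around zero and every Y coordinate is 0.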


To solve this problem, we first build an ML model. Then we train the model on existing data, evaluate how good it is, and finally consume the model to assign customers to clusters.

In [None]:
// Install the ML.NET and OxyPlot NuGet packages
#r "nuget:Microsoft.ML"
#r "nuget:OxyPlot.Core"  

In [None]:
using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using System.Collections.Generic;
using static Microsoft.ML.TrainCatalogBase;
using static Microsoft.ML.DataOperationsCatalog;
using System.Diagnostics;
using System.Globalization;
using OxyPlot;
using OxyPlot.Series;
using Microsoft.ML.Transforms;

### ConsoleHelper

In [None]:
public static class ConsoleHelper
{
    public static void PrintPrediction(string prediction)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"Predicted : {prediction}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintRegressionPredictionVersusObserved(string predictionCount, string observedCount)
    {
        Console.WriteLine($"-------------------------------------------------");
        Console.WriteLine($"Predicted : {predictionCount}");
        Console.WriteLine($"Actual:     {observedCount}");
        Console.WriteLine($"-------------------------------------------------");
    }

    public static void PrintRegressionMetrics(string name, RegressionMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} regression model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       LossFn:        {metrics.LossFunction:0.##}");
        Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
        Console.WriteLine($"*       Absolute loss: {metrics.MeanAbsoluteError:#.##}");
        Console.WriteLine($"*       Squared loss:  {metrics.MeanSquaredError:#.##}");
        Console.WriteLine($"*       RMS loss:      {metrics.RootMeanSquaredError:#.##}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintBinaryClassificationMetrics(string name, CalibratedBinaryClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*       Metrics for {name} binary classification model      ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"*       Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"*       Area Under Curve:      {metrics.AreaUnderRocCurve:P2}");
        Console.WriteLine($"*       Area under Precision recall Curve:  {metrics.AreaUnderPrecisionRecallCurve:P2}");
        Console.WriteLine($"*       F1Score:  {metrics.F1Score:P2}");
        Console.WriteLine($"*       LogLoss:  {metrics.LogLoss:#.##}");
        Console.WriteLine($"*       LogLossReduction:  {metrics.LogLossReduction:#.##}");
        Console.WriteLine($"*       PositivePrecision:  {metrics.PositivePrecision:#.##}");
        Console.WriteLine($"*       PositiveRecall:  {metrics.PositiveRecall:#.##}");
        Console.WriteLine($"*       NegativePrecision:  {metrics.NegativePrecision:#.##}");
        Console.WriteLine($"*       NegativeRecall:  {metrics.NegativeRecall:P2}");
        Console.WriteLine($"************************************************************");
    }

    public static void PrintMultiClassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*    Metrics for {name} multi-class classification model   ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"    AccuracyMacro = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    AccuracyMicro = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
        Console.WriteLine($"************************************************************");
    }
    
    public static void PrintRegressionFoldsAverageMetrics(string algorithmName, IReadOnlyList<CrossValidationResult<RegressionMetrics>> crossValidationResults)
    {
        var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
        var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
        var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
        var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
        var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Regression model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average L1 Loss:    {L1.Average():0.###} ");
        Console.WriteLine($"*       Average L2 Loss:    {L2.Average():0.###}  ");
        Console.WriteLine($"*       Average RMS:          {RMS.Average():0.###}  ");
        Console.WriteLine($"*       Average Loss Function: {lossFunction.Average():0.###}  ");
        Console.WriteLine($"*       Average R-squared: {R2.Average():0.###}  ");
        Console.WriteLine($"*************************************************************************************************************");
    }
    
    public static void PrintMulticlassClassificationFoldsAverageMetrics(
                                     string algorithmName,
                                   IReadOnlyList<CrossValidationResult<MulticlassClassificationMetrics>> crossValResults
                                                                       )
    {
        var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

        var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
        var microAccuracyAverage = microAccuracyValues.Average();
        var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
        var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

        var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
        var macroAccuracyAverage = macroAccuracyValues.Average();
        var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
        var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

        var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
        var logLossAverage = logLossValues.Average();
        var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
        var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

        var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
        var logLossReductionAverage = logLossReductionValues.Average();
        var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
        var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Multi-class Classification model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average MacroAccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLoss:          {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###})  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLossReduction: {logLossReductionAverage:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
        Console.WriteLine($"*************************************************************************************************************");
    }    

    public static double CalculateStandardDeviation (IEnumerable<double> values)
    {
        double average = values.Average();
        double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
        double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count()-1));
        return standardDeviation;
    }

    public static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        // Half-width of the 95% confidence interval for the mean: 1.96 * s / sqrt(n)
        double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt(values.Count());
        return confidenceInterval95;
    }

    public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} clustering model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       Average Distance: {metrics.AverageDistance}");
        Console.WriteLine($"*       Davies Bouldin Index is: {metrics.DaviesBouldinIndex}");
        Console.WriteLine($"*************************************************");
    }    
    
    public static void PeekDataViewInConsole(MLContext mlContext, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
    {
        string msg = string.Format("Peek data in DataView: Showing {0} rows with the columns", numberOfRows.ToString());
        ConsoleWriteHeader(msg);

        //https://github.com/dotnet/machinelearning/blob/master/docs/code/MlNetCookBook.md#how-do-i-look-at-the-intermediate-data
        var transformer = pipeline.Fit(dataView);
        var transformedData = transformer.Transform(dataView);

        // 'transformedData' is lazily evaluated (a 'promise' of data): call Preview
        // and iterate through the collection it returns.

        var preViewTransformedData = transformedData.Preview(maxRows: numberOfRows);

        foreach (var row in preViewTransformedData.RowView)
        {
            var ColumnCollection = row.Values;
            string lineToPrint = "Row--> ";
            foreach (KeyValuePair<string, object> column in ColumnCollection)
            {
                lineToPrint += $"| {column.Key}:{column.Value}";
            }
            Console.WriteLine(lineToPrint + "\n");
        }
    }
    
    public static void PeekVectorColumnDataInConsole(MLContext mlContext, string columnName, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
    {
        string msg = string.Format("Peek data in DataView: Show {0} rows with just the '{1}' column", numberOfRows, columnName);
        ConsoleWriteHeader(msg);

        var transformer = pipeline.Fit(dataView);
        var transformedData = transformer.Transform(dataView);

        // Extract the requested vector column (e.g. 'Features') for the first rows
        var someColumnData = transformedData.GetColumn<float[]>(columnName)
                                                    .Take(numberOfRows).ToList();

        // Print the peeked rows to the console, separating the vector values with spaces
        int currentRow = 0;
        someColumnData.ForEach(row => {
                                        currentRow++;
                                        string concatColumn = string.Join(" ", row);

                                        Console.WriteLine();
                                        string rowMsg = string.Format("**** Row {0} with '{1}' field value ****", currentRow, columnName);
                                        Console.WriteLine(rowMsg);
                                        Console.WriteLine(concatColumn);
                                        Console.WriteLine();
                                      });
    }
    
    public static void ConsoleWriteHeader(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine(" ");
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
        var maxLength = lines.Select(x => x.Length).Max();
        Console.WriteLine(new string('#', maxLength));
        Console.ForegroundColor = defaultColor;
    }

    public static void ConsoleWriterSection(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Blue;
        Console.WriteLine(" ");
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
        var maxLength = lines.Select(x => x.Length).Max();
        Console.WriteLine(new string('-', maxLength));
        Console.ForegroundColor = defaultColor;
    }
    
    public static void ConsoleWriteException(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Red;
        const string exceptionTitle = "EXCEPTION";
        Console.WriteLine(" ");
        Console.WriteLine(exceptionTitle);
        Console.WriteLine(new string('#', exceptionTitle.Length));
        Console.ForegroundColor = defaultColor;
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
    }
    
}

In [None]:
public class PivotData
{
    public float C1 { get; set; }
    public float C2 { get; set; }
    public float C3 { get; set; }
    public float C4 { get; set; }
    public float C5 { get; set; }
    public float C6 { get; set; }
    public float C7 { get; set; }
    public float C8 { get; set; }
    public float C9 { get; set; }
    public float C10 { get; set; }
    public float C11 { get; set; }
    public float C12 { get; set; }
    public float C13 { get; set; }
    public float C14 { get; set; }
    public float C15 { get; set; }
    public float C16 { get; set; }
    public float C17 { get; set; }
    public float C18 { get; set; }
    public float C19 { get; set; }
    public float C20 { get; set; }
    public float C21 { get; set; }
    public float C22 { get; set; }
    public float C23 { get; set; }
    public float C24 { get; set; }
    public float C25 { get; set; }
    public float C26 { get; set; }
    public float C27 { get; set; }
    public float C28 { get; set; }
    public float C29 { get; set; }
    public float C30 { get; set; }
    public float C31 { get; set; }
    public float C32 { get; set; }
    public string LastName { get; set; }

    public override string ToString()
    {
        return        $"{C1},{C2},{C3},{C4},{C5},{C6},{C7},{C8},{C9}," +
               $"{C10},{C11},{C12},{C13},{C14},{C15},{C16},{C17},{C18},{C19}," +
               $"{C20},{C21},{C22},{C23},{C24},{C25},{C26},{C27},{C28},{C29}," +
               $"{C30},{C31},{C32},{LastName}";
    }

    public static void SaveToCsv(IEnumerable<PivotData> salesData, string file)
    {
        var columns = "C1,C2,C3,C4,C5,C6,C7,C8,C9," +
                      "C10,C11,C12,C13,C14,C15,C16,C17,C18,C19," +
                      "C20,C21,C22,C23,C24,C25,C26,C27,C28,C29," +
                      $"C30,C31,C32,{nameof(LastName)}";

        File.WriteAllLines(file, salesData
            .Select(s => s.ToString())
            .Prepend(columns));
    }
}

public class PivotObservation
{
    public float[] Features;
    public string LastName;
    //public float[] PCAFeatures;
    //public float[] Score;
}

public class ClusteringPrediction
{
    [ColumnName("PredictedLabel")]
    public uint SelectedClusterId;
    [ColumnName("Score")]
    public float[] Distance;
    [ColumnName("PCAFeatures")]
    public float[] Location;
    [ColumnName("LastName")]
    public string LastName;
}

public class Offer
{
    //Offer #,Campaign,Varietal,Minimum Qty (kg),Discount (%),Origin,Past Peak
    public string OfferId { get; set; }
    public string Campaign { get; set; }
    public string Varietal { get; set; }
    public float Minimum { get; set; }
    public float Discount { get; set; }
    public string Origin { get; set; }
    public string LastPeak { get; set; }

    public static IEnumerable<Offer> ReadFromCsv(string file)
    {
        return File.ReadAllLines(file)
         .Skip(1) // skip header
         .Select(x => x.Split(','))
         .Select(x => new Offer()
         {
             OfferId = x[0],
             Campaign = x[1],
             Varietal = x[2],
             Minimum = float.Parse(x[3], CultureInfo.InvariantCulture),
             Discount = float.Parse(x[4], CultureInfo.InvariantCulture),
             Origin = x[5],
             LastPeak = x[6]
         });
    }
}

public class Transaction
{
    //Customer Last Name,Offer #
    //Smith,2
    public string LastName { get; set; }
    public string OfferId { get; set; }

    public static IEnumerable<Transaction> ReadFromCsv(string file)
    {
        return File.ReadAllLines(file)
         .Skip(1) // skip header
         .Select(x => x.Split(','))
         .Select(x => new Transaction()
         {
             LastName = x[0],
             OfferId = x[1],
         });
    }
}

public class DataHelpers
{
    public static IEnumerable<PivotData> PreProcessAndSave(string offersDataLocation, string transactionsDataLocation, string pivotDataLocation)
    {
        var preProcessData = PreProcess(offersDataLocation, transactionsDataLocation);
        PivotData.SaveToCsv(preProcessData, pivotDataLocation);
        return preProcessData;
    }

    public static IEnumerable<PivotData> PreProcess(string offersDataLocation, string transactionsDataLocation)
    {
        ConsoleHelper.ConsoleWriteHeader("Preprocess input files");
        Console.WriteLine($"Offers file: {offersDataLocation}");
        Console.WriteLine($"Transactions file: {transactionsDataLocation}");

        var offers = Offer.ReadFromCsv(offersDataLocation);
        var transactions = Transaction.ReadFromCsv(transactionsDataLocation);

        // inner join datasets
        var clusterData = (from of in offers
                           join tr in transactions on of.OfferId equals tr.OfferId
                           select new
                           {
                               of.OfferId,
                               of.Campaign,
                               of.Discount,
                               tr.LastName,
                               of.LastPeak,
                               of.Minimum,
                               of.Origin,
                               of.Varietal,
                               Count = 1,
                           }).ToArray();

        // pivot table (naive way)
        // based on code from https://stackoverflow.com/a/43091570
        var pivotDataArray =
            (from c in clusterData
             group c by c.LastName into gcs
             let lookup = gcs.ToLookup(y => y.OfferId, y => y.Count)
             select new PivotData()
             {
                 LastName = gcs.Key,
                 C1 = (float)lookup["1"].Sum(),
                 C2 = (float)lookup["2"].Sum(),
                 C3 = (float)lookup["3"].Sum(),
                 C4 = (float)lookup["4"].Sum(),
                 C5 = (float)lookup["5"].Sum(),
                 C6 = (float)lookup["6"].Sum(),
                 C7 = (float)lookup["7"].Sum(),
                 C8 = (float)lookup["8"].Sum(),
                 C9 = (float)lookup["9"].Sum(),
                 C10 = (float)lookup["10"].Sum(),
                 C11 = (float)lookup["11"].Sum(),
                 C12 = (float)lookup["12"].Sum(),
                 C13 = (float)lookup["13"].Sum(),
                 C14 = (float)lookup["14"].Sum(),
                 C15 = (float)lookup["15"].Sum(),
                 C16 = (float)lookup["16"].Sum(),
                 C17 = (float)lookup["17"].Sum(),
                 C18 = (float)lookup["18"].Sum(),
                 C19 = (float)lookup["19"].Sum(),
                 C20 = (float)lookup["20"].Sum(),
                 C21 = (float)lookup["21"].Sum(),
                 C22 = (float)lookup["22"].Sum(),
                 C23 = (float)lookup["23"].Sum(),
                 C24 = (float)lookup["24"].Sum(),
                 C25 = (float)lookup["25"].Sum(),
                 C26 = (float)lookup["26"].Sum(),
                 C27 = (float)lookup["27"].Sum(),
                 C28 = (float)lookup["28"].Sum(),
                 C29 = (float)lookup["29"].Sum(),
                 C30 = (float)lookup["30"].Sum(),
                 C31 = (float)lookup["31"].Sum(),
                 C32 = (float)lookup["32"].Sum()
             }).ToArray();

        Console.WriteLine($"Total rows: {pivotDataArray.Length}");

        return pivotDataArray;
    }
}

public class ClusteringModelScorer
{
    private readonly string _pivotDataLocation;

    private readonly string _plotLocation;
    private readonly string _csvlocation;
    private readonly MLContext _mlContext;
    private ITransformer _trainedModel;

    public ClusteringModelScorer(MLContext mlContext, string pivotDataLocation, string plotLocation, string csvlocation)
    {
        _pivotDataLocation = pivotDataLocation;
        _plotLocation = plotLocation;
        _csvlocation = csvlocation;
        _mlContext = mlContext;
    }

    public ITransformer LoadModel(string modelPath)
    {
        _trainedModel = _mlContext.Model.Load(modelPath, out var modelInputSchema);
        return _trainedModel;
    }

    public void CreateCustomerClusters()
    {
        var data = _mlContext.Data.LoadFromTextFile(path:_pivotDataLocation,
                        columns: new[]
                                    {
                                      new TextLoader.Column("Features", DataKind.Single, new[] {new TextLoader.Range(0, 31) }),
                                      new TextLoader.Column(nameof(PivotData.LastName), DataKind.String, 32)
                                    },
                        hasHeader: true,
                        separatorChar: ',');
      
        //Apply the trained model to the data to get the cluster assignments
        var transformedDataView = _trainedModel.Transform(data);

        var predictions = _mlContext.Data.CreateEnumerable<ClusteringPrediction>(transformedDataView, reuseRowObject: false)
                        .ToArray();

        SaveCustomerSegmentationCSV(predictions, _csvlocation);

        //Plot the clusters in a chart and save it as an SVG file
        SaveCustomerSegmentationPlotChart(predictions, _plotLocation);

        // Alternatively, open the chart in the default viewer (e.g. a web browser)
        //OpenChartInDefaultWindow(_plotLocation);

    }

    private static void SaveCustomerSegmentationCSV(IEnumerable<ClusteringPrediction> predictions, string csvlocation)
    {
        ConsoleHelper.ConsoleWriteHeader("CSV Customer Segmentation");
        using (var w = new System.IO.StreamWriter(csvlocation))
        {
            w.WriteLine($"LastName,SelectedClusterId");
            w.Flush();
            predictions.ToList().ForEach(prediction => {
                w.WriteLine($"{prediction.LastName},{prediction.SelectedClusterId}");
                w.Flush();
            });
        }

        Console.WriteLine($"CSV location: {csvlocation}");
    }

    private static void SaveCustomerSegmentationPlotChart(IEnumerable<ClusteringPrediction> predictions, string plotLocation)
    {
        ConsoleHelper.ConsoleWriteHeader("Plot Customer Segmentation");

        var plot = new PlotModel { Title = "Customer Segmentation", IsLegendVisible = true };

        var clusters = predictions.Select(p => p.SelectedClusterId).Distinct().OrderBy(x => x);

        foreach (var cluster in clusters)
        {
            var scatter = new ScatterSeries { MarkerType = MarkerType.Circle, MarkerStrokeThickness = 2, Title = $"Cluster: {cluster}"};
            var series = predictions
                .Where(p => p.SelectedClusterId == cluster)
                .Select(p => new ScatterPoint(p.Location[0], p.Location[1])).ToArray();
            scatter.Points.AddRange(series);
            plot.Series.Add(scatter);
        }

        plot.DefaultColors = OxyPalettes.HueDistinct(plot.Series.Count).Colors;

        var exporter = new SvgExporter { Width = 600, Height = 400 };
        using (var fs = new System.IO.FileStream(plotLocation, System.IO.FileMode.Create))
        {
            exporter.Export(plot, fs);
        }

        Console.WriteLine($"Plot location: {plotLocation}");
    }

    private static void OpenChartInDefaultWindow(string plotLocation)
    {
        Console.WriteLine("Showing chart...");
        var p = new Process();
        p.StartInfo = new ProcessStartInfo(plotLocation)
        {
            UseShellExecute = true
        };
        p.Start();
    }
}

In [None]:
string transactionsCsv = @"./datasets/CustomerSegmentation/inputs/transactions.csv";
string offersCsv = @"./datasets/CustomerSegmentation/inputs/offers.csv";
string pivotCsv = @"./datasets/CustomerSegmentation/inputs/pivot.csv";
string modelPath = @"./datasets/CustomerSegmentation/inputs/retailClustering.zip";
var plotSvg = @"./datasets/CustomerSegmentation/outputs/customerSegmentation.svg";
var plotCsv = @"./datasets/CustomerSegmentation/outputs/customerSegmentation.csv";

### Trainer

In [None]:
try
{
    //STEP 0: Special data pre-process in this sample creating the PivotTable csv file
    DataHelpers.PreProcessAndSave(offersCsv, transactionsCsv, pivotCsv);

    //Create the MLContext to share across components for deterministic results
    MLContext mlContext = new MLContext(seed: 1);  //Seed set to any number so you have a deterministic environment

    // STEP 1: Common data loading configuration
    var pivotDataView = mlContext.Data.LoadFromTextFile(path: pivotCsv,
                                columns: new[]
                                            {
                                            new TextLoader.Column("Features", DataKind.Single, new[] {new TextLoader.Range(0, 31) }),
                                            new TextLoader.Column(nameof(PivotData.LastName), DataKind.String, 32)
                                            },
                                hasHeader: true,
                                separatorChar: ',');

    //STEP 2: Configure data transformations in pipeline
    var dataProcessPipeline = mlContext.Transforms.ProjectToPrincipalComponents(outputColumnName: "PCAFeatures", inputColumnName: "Features", rank: 2)
     .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "LastNameKey", inputColumnName: nameof(PivotData.LastName), OneHotEncodingEstimator.OutputKind.Indicator));


    // (Optional) Peek data in training DataView after applying the ProcessPipeline's transformations
    ConsoleHelper.PeekDataViewInConsole(mlContext, pivotDataView, dataProcessPipeline, 10);
    ConsoleHelper.PeekVectorColumnDataInConsole(mlContext, "Features", pivotDataView, dataProcessPipeline, 10);

    //STEP 3: Create the training pipeline
    var trainer = mlContext.Clustering.Trainers.KMeans(featureColumnName: "Features", numberOfClusters: 3);
    var trainingPipeline = dataProcessPipeline.Append(trainer);

    //STEP 4: Train the model fitting to the pivotDataView
    Console.WriteLine("=============== Training the model ===============");
    ITransformer trainedModel = trainingPipeline.Fit(pivotDataView);

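    //STEP 5: Evaluate the model and show the clustering metrics
    //(there is no separate test set, so the model is evaluated on the same data it was trained on)
    Console.WriteLine("===== Evaluating the model on the training data =====");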
    var predictions = trainedModel.Transform(pivotDataView);
    var metrics = mlContext.Clustering.Evaluate(predictions, scoreColumnName: "Score", featureColumnName: "Features");

    ConsoleHelper.PrintClusteringMetrics(trainer.ToString(), metrics);

    //STEP 6: Save/persist the trained model to a .ZIP file
    mlContext.Model.Save(trainedModel, pivotDataView.Schema, modelPath);

    Console.WriteLine("The model is saved to {0}", modelPath);
}
catch (Exception ex)
{
    ConsoleHelper.ConsoleWriteException(ex.ToString());
}

Console.WriteLine("=============== The End ===============");
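
The evaluation step above prints the `AverageDistance` and `DaviesBouldinIndex` properties of the `ClusteringMetrics` returned by `mlContext.Clustering.Evaluate`. To show roughly what they measure, the sketch below computes both for a tiny hand-made clustering: the mean distance from each point to its centroid, and the Davies-Bouldin index, which averages, over clusters, the worst ratio of within-cluster scatter to between-centroid distance (lower is better for both). The helpers are hypothetical and use plain Euclidean distance; ML.NET's exact definitions (for example, squared versus plain distances) may differ, so the values are not expected to match the notebook's output.

```csharp
using System;
using System.Linq;

// Euclidean distance between two points of the same dimension.
static double Dist(double[] a, double[] b) =>
    Math.Sqrt(a.Zip(b, (x, y) => (x - y) * (x - y)).Sum());

// Hypothetical helper: mean distance from each point to its assigned centroid.
static double AverageDistance(double[][] points, int[] labels, double[][] centroids) =>
    points.Select((p, i) => Dist(p, centroids[labels[i]])).Average();

// Hypothetical helper: Davies-Bouldin index (assumes at least two clusters).
static double DaviesBouldin(double[][] points, int[] labels, double[][] centroids)
{
    int k = centroids.Length;

    // Scatter S_i: average distance of cluster i's members to centroid i.
    double[] scatter = Enumerable.Range(0, k)
        .Select(c => points.Where((p, i) => labels[i] == c)
                           .Select(p => Dist(p, centroids[c]))
                           .DefaultIfEmpty(0.0)
                           .Average())
        .ToArray();

    // For each cluster, keep the worst ratio (S_i + S_j) / d(c_i, c_j), then average.
    return Enumerable.Range(0, k)
        .Select(i => Enumerable.Range(0, k)
            .Where(j => j != i)
            .Max(j => (scatter[i] + scatter[j]) / Dist(centroids[i], centroids[j])))
        .Average();
}

// Usage: two tight clusters whose centroids are 10 units apart.
double[][] pts = { new[] { 0.0, 0.0 }, new[] { 0.0, 2.0 }, new[] { 10.0, 0.0 }, new[] { 10.0, 2.0 } };
int[] assigned = { 0, 0, 1, 1 };
double[][] centers = { new[] { 0.0, 1.0 }, new[] { 10.0, 1.0 } };
Console.WriteLine($"Average distance: {AverageDistance(pts, assigned, centers)}");
Console.WriteLine($"Davies-Bouldin index: {DaviesBouldin(pts, assigned, centers)}");
```

Here every point is 1 unit from its centroid and the centroids are 10 units apart, so the average distance is 1 and the Davies-Bouldin index is (1 + 1) / 10 = 0.2.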

In [None]:
try
{
    MLContext mlContext = new MLContext(seed: 1);  //Seed set to any number so you get deterministic results

    //Create the clusters: Create data files and plot a chart
    var clusteringModelScorer = new ClusteringModelScorer(mlContext, pivotCsv, plotSvg, plotCsv);
    clusteringModelScorer.LoadModel(modelPath);

    clusteringModelScorer.CreateCustomerClusters();
} catch (Exception ex)
{
    ConsoleHelper.ConsoleWriteException(ex.ToString());
}

