# ML.NET - Samples - CreditCardFraudDetection (Binary Classification)

## Fraud detection in credit cards (binary classification)

| ML.NET version | API type          | Status                        | App Type    | Data type | Scenario            | ML Task                   | Algorithms                  |
|----------------|-------------------|-------------------------------|-------------|-----------|---------------------|---------------------------|-----------------------------|
| v1.5           | Dynamic API | Up-to-date | Jupyter Notebook | .csv file | Fraud Detection | Two-class classification | FastTree Binary Classification |

In this introductory sample, you'll see how to use [ML.NET](https://www.microsoft.com/net/learn/apps/machine-learning-and-ai/ml-dotnet) to predict whether a credit card transaction is fraudulent. In the world of machine learning, this type of prediction is known as **binary classification**.

## Problem

This problem is centered around predicting whether a credit card transaction (with its related info/variables) is fraudulent or not.

The input dataset of transactions contains only numerical input variables, which are the result of previous PCA (Principal Component Analysis) transformations. Unfortunately, due to confidentiality issues, the original features and additional background information are not available, but the way you build the model doesn't change.

Features V1, V2, ... V28 are the principal components obtained with PCA. The only features that have not been transformed with PCA are 'Time' and 'Amount'.

The feature 'Time' contains the seconds elapsed between each transaction and the first transaction in the dataset. The feature 'Amount' is the transaction amount; this feature can be used for example-dependent cost-sensitive learning. The feature 'Class' is the response variable and takes the value 1 in case of fraud and 0 otherwise.
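
One way to read "example-dependent cost-sensitive" is that each mistake is priced by the transaction it concerns. The sketch below assumes, purely for illustration, that a missed fraud costs the full transaction amount and that every flagged transaction costs a fixed review fee. These costs, the made-up transactions, and names such as `reviewCost` and `Cost` are hypothetical and do not come from the dataset or from ML.NET.

```csharp
using System;
using System.Linq;

// Hypothetical fixed cost of manually reviewing a flagged transaction.
double reviewCost = 5.0;

// (amount, actual fraud?, predicted fraud?) for a few made-up transactions.
var outcomes = new (double Amount, bool IsFraud, bool Predicted)[]
{
    (120.00, true,  true),   // caught fraud: only the review cost
    (950.00, true,  false),  // missed fraud: the amount is lost
    ( 15.50, false, true),   // false alarm: the review cost
    ( 42.00, false, false),  // correct negative: no cost
};

// Cost of a single decision under the assumptions stated above.
double Cost((double Amount, bool IsFraud, bool Predicted) o) =>
    o.Predicted ? reviewCost   // every flagged transaction is reviewed
    : o.IsFraud ? o.Amount     // an unflagged fraud loses the full amount
    : 0.0;                     // an unflagged legitimate transaction is free

double totalCost = outcomes.Sum(o => Cost(o));
Console.WriteLine($"Total cost: {totalCost}");
```

With these numbers the total is 960: two review fees plus the missed 950 transaction. Under such a cost model, a single large missed fraud can outweigh many small false alarms, which a plain error count would not show.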

The dataset is highly unbalanced: the positive class (frauds) accounts for 0.172% of all transactions.

Using this dataset, you build a model that analyzes a transaction's input variables and predicts a fraud value of true or false.

## DataSet

The training and testing data is based on a public [dataset available at Kaggle](https://www.kaggle.com/mlg-ulb/creditcardfraud) originally from Worldline and the Machine Learning Group (http://mlg.ulb.ac.be) of ULB (Université Libre de Bruxelles), collected and analysed during a research collaboration. 

The dataset contains transactions made by credit cards in September 2013 by European cardholders. It presents transactions that occurred over two days, with 492 frauds out of 284,807 transactions.
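
To make the imbalance concrete, the following sketch uses only the two counts quoted above to compute the fraud rate and the accuracy of a trivial model that labels every transaction as legitimate. The "always negative" model and the variable names are illustrative assumptions, not part of the sample.

```csharp
using System;

// Counts quoted in the dataset description above.
int totalTransactions = 284807;
int fraudTransactions = 492;

// Share of the positive (fraud) class.
double fraudRate = (double)fraudTransactions / totalTransactions;

// A hypothetical model that always answers "not fraud" is right on every
// legitimate transaction and wrong on every fraud.
double alwaysNegativeAccuracy = (double)(totalTransactions - fraudTransactions) / totalTransactions;

// It never flags anything, so it finds none of the frauds.
double alwaysNegativeRecall = 0.0 / fraudTransactions;

Console.WriteLine($"Fraud rate:               {fraudRate:P3}");
Console.WriteLine($"Always-negative accuracy: {alwaysNegativeAccuracy:P3}");
Console.WriteLine($"Always-negative recall:   {alwaysNegativeRecall:P0}");
```

The trivial model scores above 99.8% accuracy while catching no fraud at all, which is why accuracy alone says little on this dataset and metrics such as recall, precision, and the area under the precision-recall curve are worth reading alongside it.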

By: Andrea Dal Pozzolo, Olivier Caelen, Reid A. Johnson and Gianluca Bontempi. Calibrating Probability with Undersampling for Unbalanced Classification. In Symposium on Computational Intelligence and Data Mining (CIDM), IEEE, 2015

More details on current and past projects on related topics are available on http://mlg.ulb.ac.be/BruFence and http://mlg.ulb.ac.be/ARTML

## ML Task - [Binary Classification](https://en.wikipedia.org/wiki/Binary_classification)

Binary or binomial classification is the task of classifying the elements of a given set into two groups (predicting which group each one belongs to) on the basis of a classification rule. It applies in contexts that require a decision as to whether or not an item has some qualitative property or specified characteristic.
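
As a small illustration of a classification rule, the sketch below turns a probability-like score into one of two classes by comparing it with a cut-off. The `Classify` helper, the example scores, and both thresholds are assumptions made for this illustration; they are not taken from the sample's code.

```csharp
using System;

// Hypothetical rule: label a transaction as fraud when its estimated
// probability reaches the chosen threshold.
static bool Classify(float probability, float threshold = 0.5f) => probability >= threshold;

float[] probabilities = { 0.02f, 0.48f, 0.50f, 0.97f };

foreach (float p in probabilities)
{
    // The same score can land in different groups depending on the threshold.
    Console.WriteLine($"p = {p}: fraud at 0.5? {Classify(p)}, fraud at 0.9? {Classify(p, 0.9f)}");
}
```

Raising the threshold flags fewer transactions, trading missed frauds for fewer false alarms; lowering it does the opposite.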

If you would like to learn how to detect fraud using anomaly detection, visit the [Anomaly Detection Credit Card Fraud Detection sample](../AnomalyDetection_CreditCardFraudDetection).

## Solution

To solve this problem, first you need to build a machine learning model. Then you train the model on existing training data, evaluate how good its accuracy is, and lastly you consume the model (deploying the built model in a different app) to predict a fraud for a sample credit card transaction.

![Build -> Train -> Evaluate -> Consume](../shared_content/modelpipeline.png)
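
Before the full notebook code, here is a condensed sketch of the same build, train, evaluate, and consume flow on a tiny synthetic dataset, written for a .NET Interactive notebook cell. The `ToyTransaction` and `ToyPrediction` types, the generated rows, and the reduced trainer settings are assumptions made for illustration; the sample itself uses the credit card dataset and the classes declared in the following cells.

```csharp
#r "nuget:Microsoft.ML"
#r "nuget:Microsoft.ML.FastTree"

using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

// Toy input and output types (hypothetical, for this sketch only).
public class ToyTransaction
{
    public float Amount;
    public float V1;
    public bool Label;
}

public class ToyPrediction
{
    public bool PredictedLabel;
    public float Probability;
}

var toyContext = new MLContext(seed: 1);

// Synthetic, balanced data: "fraud" rows have a large Amount and a negative V1.
var rows = Enumerable.Range(0, 400).Select(i => new ToyTransaction
{
    Label = i % 2 == 0,
    Amount = i % 2 == 0 ? 500f + i : 10f + i % 50,
    V1 = i % 2 == 0 ? -2f : 1f
}).ToList();

IDataView toyData = toyContext.Data.LoadFromEnumerable(rows);
var split = toyContext.Data.TrainTestSplit(toyData, testFraction: 0.2, seed: 1);

// Build: combine the input columns into a feature vector and add the trainer.
var toyPipeline = toyContext.Transforms
    .Concatenate("Features", nameof(ToyTransaction.Amount), nameof(ToyTransaction.V1))
    .Append(toyContext.BinaryClassification.Trainers.FastTree(
        labelColumnName: nameof(ToyTransaction.Label),
        featureColumnName: "Features",
        numberOfLeaves: 4,
        numberOfTrees: 10));

// Train on the training split.
ITransformer toyModel = toyPipeline.Fit(split.TrainSet);

// Evaluate on the held-out split.
var toyMetrics = toyContext.BinaryClassification.Evaluate(
    toyModel.Transform(split.TestSet),
    labelColumnName: nameof(ToyTransaction.Label));
Console.WriteLine($"AUC on toy data: {toyMetrics.AreaUnderRocCurve}");

// Consume: score a single new observation.
var toyEngine = toyContext.Model.CreatePredictionEngine<ToyTransaction, ToyPrediction>(toyModel);
var toyPrediction = toyEngine.Predict(new ToyTransaction { Amount = 800f, V1 = -2f });
Console.WriteLine($"Predicted fraud: {toyPrediction.PredictedLabel}");
```

The toy data is deliberately easy to separate, so its metrics say nothing about the real dataset; the point is only the order of the steps.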


In [None]:
// ML.NET Nuget packages 
#r "nuget:Microsoft.ML"     

// ML.NET FastTree Nuget packages 
#r "nuget:Microsoft.ML.FastTree"

## Using directives

In [None]:
using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers;
using System.Collections.Generic;
using static Microsoft.ML.TrainCatalogBase;
using static Microsoft.ML.DataOperationsCatalog;
using System.Diagnostics;

## Declare data-classes for input data and predictions

In [None]:
public interface IModelEntity {
    void PrintToConsole();
}

public class TransactionObservation : IModelEntity
{
    // Note: the 'Time' column is loaded here but excluded from the feature vector in TrainModel, since we don't need it as a feature
    [LoadColumn(0)]
    public float Time;

    [LoadColumn(1)]
    public float V1;

    [LoadColumn(2)]
    public float V2;

    [LoadColumn(3)]
    public float V3;

    [LoadColumn(4)]
    public float V4;

    [LoadColumn(5)]
    public float V5;

    [LoadColumn(6)]
    public float V6;

    [LoadColumn(7)]
    public float V7;

    [LoadColumn(8)]
    public float V8;

    [LoadColumn(9)]
    public float V9;

    [LoadColumn(10)]
    public float V10;

    [LoadColumn(11)]
    public float V11;

    [LoadColumn(12)]
    public float V12;

    [LoadColumn(13)]
    public float V13;

    [LoadColumn(14)]
    public float V14;

    [LoadColumn(15)]
    public float V15;

    [LoadColumn(16)]
    public float V16;

    [LoadColumn(17)]
    public float V17;

    [LoadColumn(18)]
    public float V18;

    [LoadColumn(19)]
    public float V19;

    [LoadColumn(20)]
    public float V20;

    [LoadColumn(21)]
    public float V21;

    [LoadColumn(22)]
    public float V22;

    [LoadColumn(23)]
    public float V23;

    [LoadColumn(24)]
    public float V24;

    [LoadColumn(25)]
    public float V25;

    [LoadColumn(26)]
    public float V26;

    [LoadColumn(27)]
    public float V27;

    [LoadColumn(28)]
    public float V28;

    [LoadColumn(29)]
    public float Amount;

    [LoadColumn(30)]
    public bool Label;

    public void PrintToConsole() {
        Console.WriteLine($"Label: {Label}");
        Console.WriteLine($"Features: [V1] {V1} [V2] {V2} [V3] {V3} ... [V28] {V28} Amount: {Amount}");
    }       
}

public class TransactionFraudPrediction : IModelEntity
{
    public bool Label;
    public bool PredictedLabel;
    public float Score;
    public float Probability;

    public void PrintToConsole()
    {
        Console.WriteLine($"Predicted Label: {PredictedLabel}");
        Console.WriteLine($"Probability: {Probability}  ({Score})");
    }
}

public class Predictor
{
    private readonly string _modelfile;
    private readonly string _datasetFile;

    public Predictor(string modelfile, string datasetFile)
    {
        _modelfile = modelfile ?? throw new ArgumentNullException(nameof(modelfile));
        _datasetFile = datasetFile ?? throw new ArgumentNullException(nameof(datasetFile));
    }

    public void RunMultiplePredictions(int numberOfPredictions)
    {

        var mlContext = new MLContext();

        // Load data as input for predictions
        IDataView inputDataForPredictions = mlContext.Data.LoadFromTextFile<TransactionObservation>(_datasetFile, separatorChar: ',', hasHeader: true);

        Console.WriteLine($"Predictions from saved model:");

        ITransformer model = mlContext.Model.Load(_modelfile, out var inputSchema);

        var predictionEngine = mlContext.Model.CreatePredictionEngine<TransactionObservation, TransactionFraudPredictionWithContribution>(model);
        Console.WriteLine($"\n \n Test {numberOfPredictions} transactions, from the test datasource, that should be predicted as fraud (true):");

        mlContext.Data.CreateEnumerable<TransactionObservation>(inputDataForPredictions, reuseRowObject: false)
                    .Where(x => x.Label == true)
                    .Take(numberOfPredictions)
                    .ToList()
                    .ForEach(testData =>
                                {
                                    Console.WriteLine($"--- Transaction ---");
                                    testData.PrintToConsole();
                                    predictionEngine.Predict(testData).PrintToConsole();
                                    Console.WriteLine($"-------------------");
                                });


        Console.WriteLine($"\n \n Test {numberOfPredictions} transactions, from the test datasource, that should NOT be predicted as fraud (false):");

        mlContext.Data.CreateEnumerable<TransactionObservation>(inputDataForPredictions, reuseRowObject: false)
                   .Where(x => x.Label == false)
                   .Take(numberOfPredictions)
                   .ToList()
                   .ForEach(testData =>
                               {
                                   Console.WriteLine($"--- Transaction ---");
                                   testData.PrintToConsole();
                                   predictionEngine.Predict(testData).PrintToConsole(model.GetOutputSchema(inputDataForPredictions.Schema));
                                   Console.WriteLine($"-------------------");
                               });
    }

    private class TransactionFraudPredictionWithContribution : TransactionFraudPrediction
    {
        public float[] FeatureContributions { get; set; }

        public void PrintToConsole(DataViewSchema dataview)
        {
            base.PrintToConsole();
            VBuffer<ReadOnlyMemory<char>> slots = default;
            dataview.GetColumnOrNull("Features").Value.GetSlotNames(ref slots);
            var featureNames = slots.DenseValues().ToArray();
            Console.WriteLine($"Feature Contributions: " +
                              $"[{featureNames[0]}] {FeatureContributions[0]} " +
                              $"[{featureNames[1]}] {FeatureContributions[1]} " +
                              $"[{featureNames[2]}] {FeatureContributions[2]} ... " +
                              $"[{featureNames[27]}] {FeatureContributions[27]} " +
                              $"[{featureNames[28]}] {FeatureContributions[28]}");
        }
    }
}

### Constants

In [None]:
//File paths
string assetsPath = @"./datasets/CreditCardFraudDetection";
string zipDataSet = Path.Combine(assetsPath, "input", "creditcardfraud-dataset.zip");
string fullDataSetFilePath = Path.Combine(assetsPath, "input", "creditcard.csv");
string trainDataSetFilePath = Path.Combine(assetsPath, "output", "trainData.csv"); 
string testDataSetFilePath = Path.Combine(assetsPath, "output", "testData.csv");
string modelFilePath = Path.Combine(assetsPath, "output", "fastTree.zip");
string trainOutput = "./datasets/CreditCardFraudDetection/output";

### ConsoleHelper

In [None]:
public static class ConsoleHelper
{
    public static void PrintPrediction(string prediction)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"Predicted : {prediction}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintRegressionPredictionVersusObserved(string predictionCount, string observedCount)
    {
        Console.WriteLine($"-------------------------------------------------");
        Console.WriteLine($"Predicted : {predictionCount}");
        Console.WriteLine($"Actual:     {observedCount}");
        Console.WriteLine($"-------------------------------------------------");
    }

    public static void PrintRegressionMetrics(string name, RegressionMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} regression model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       LossFn:        {metrics.LossFunction:0.##}");
        Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
        Console.WriteLine($"*       Absolute loss: {metrics.MeanAbsoluteError:#.##}");
        Console.WriteLine($"*       Squared loss:  {metrics.MeanSquaredError:#.##}");
        Console.WriteLine($"*       RMS loss:      {metrics.RootMeanSquaredError:#.##}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintBinaryClassificationMetrics(string name, CalibratedBinaryClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*       Metrics for {name} binary classification model      ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"*       Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"*       Area Under Curve:      {metrics.AreaUnderRocCurve:P2}");
        Console.WriteLine($"*       Area under Precision recall Curve:  {metrics.AreaUnderPrecisionRecallCurve:P2}");
        Console.WriteLine($"*       F1Score:  {metrics.F1Score:P2}");
        Console.WriteLine($"*       LogLoss:  {metrics.LogLoss:#.##}");
        Console.WriteLine($"*       LogLossReduction:  {metrics.LogLossReduction:#.##}");
        Console.WriteLine($"*       PositivePrecision:  {metrics.PositivePrecision:#.##}");
        Console.WriteLine($"*       PositiveRecall:  {metrics.PositiveRecall:#.##}");
        Console.WriteLine($"*       NegativePrecision:  {metrics.NegativePrecision:#.##}");
        Console.WriteLine($"*       NegativeRecall:  {metrics.NegativeRecall:P2}");
        Console.WriteLine($"************************************************************");
    }

    public static void PrintMultiClassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*    Metrics for {name} multi-class classification model   ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"    AccuracyMacro = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    AccuracyMicro = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
        Console.WriteLine($"************************************************************");
    }
    
    public static void PrintRegressionFoldsAverageMetrics(string algorithmName, IReadOnlyList<CrossValidationResult<RegressionMetrics>> crossValidationResults)
    {
        var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
        var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
        var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
        var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
        var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Regression model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average L1 Loss:    {L1.Average():0.###} ");
        Console.WriteLine($"*       Average L2 Loss:    {L2.Average():0.###}  ");
        Console.WriteLine($"*       Average RMS:          {RMS.Average():0.###}  ");
        Console.WriteLine($"*       Average Loss Function: {lossFunction.Average():0.###}  ");
        Console.WriteLine($"*       Average R-squared: {R2.Average():0.###}  ");
        Console.WriteLine($"*************************************************************************************************************");
    }
    
    public static void PrintMulticlassClassificationFoldsAverageMetrics(
                                     string algorithmName,
                                   IReadOnlyList<CrossValidationResult<MulticlassClassificationMetrics>> crossValResults
                                                                       )
    {
        var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

        var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
        var microAccuracyAverage = microAccuracyValues.Average();
        var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
        var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

        var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
        var macroAccuracyAverage = macroAccuracyValues.Average();
        var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
        var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

        var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
        var logLossAverage = logLossValues.Average();
        var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
        var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

        var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
        var logLossReductionAverage = logLossReductionValues.Average();
        var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
        var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Multi-class Classification model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average MacroAccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLoss:          {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###})  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLossReduction: {logLossReductionAverage:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
        Console.WriteLine($"*************************************************************************************************************");
    }    

    public static double CalculateStandardDeviation (IEnumerable<double> values)
    {
        double average = values.Average();
        double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
        double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count()-1));
        return standardDeviation;
    }

    public static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        // Half-width of the 95% interval: 1.96 * standard error, where the standard error is s / sqrt(n)
        double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt(values.Count());
        return confidenceInterval95;
    }

    public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} clustering model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       Average Distance: {metrics.AverageDistance}");
        Console.WriteLine($"*       Davies Bouldin Index is: {metrics.DaviesBouldinIndex}");
        Console.WriteLine($"*************************************************");
    }    
    
    public static void PeekDataViewInConsole(MLContext mlContext, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
    {
        string msg = string.Format("Peek data in DataView: Showing {0} rows with the columns", numberOfRows.ToString());
        ConsoleWriteHeader(msg);

        //https://github.com/dotnet/machinelearning/blob/master/docs/code/MlNetCookBook.md#how-do-i-look-at-the-intermediate-data
        var transformer = pipeline.Fit(dataView);
        var transformedData = transformer.Transform(dataView);

        // 'transformedData' is evaluated lazily, so call Preview()
        // and iterate through the returned rows.

        var preViewTransformedData = transformedData.Preview(maxRows: numberOfRows);

        foreach (var row in preViewTransformedData.RowView)
        {
            var ColumnCollection = row.Values;
            string lineToPrint = "Row--> ";
            foreach (KeyValuePair<string, object> column in ColumnCollection)
            {
                lineToPrint += $"| {column.Key}:{column.Value}";
            }
            Console.WriteLine(lineToPrint + "\n");
        }
    }
    
    public static void PeekVectorColumnDataInConsole(MLContext mlContext, string columnName, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
    {
        string msg = string.Format("Peek data in DataView: Show {0} rows with just the '{1}' column", numberOfRows, columnName);
        ConsoleWriteHeader(msg);

        var transformer = pipeline.Fit(dataView);
        var transformedData = transformer.Transform(dataView);

        // Extract the requested column.
        var someColumnData = transformedData.GetColumn<float[]>(columnName)
                                                    .Take(numberOfRows).ToList();

        // print to console the peeked rows

        int currentRow = 0;
        someColumnData.ForEach(row => {
                                        currentRow++;
                                        // Join the values with a separator so individual floats stay readable
                                        string concatColumn = string.Join(" ", row);

                                        Console.WriteLine();
                                        string rowMsg = string.Format("**** Row {0} with '{1}' field value ****", currentRow, columnName);
                                        Console.WriteLine(rowMsg);
                                        Console.WriteLine(concatColumn);
                                        Console.WriteLine();
                                      });
    }
    
    public static void ConsoleWriteHeader(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine(" ");
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
        var maxLength = lines.Select(x => x.Length).Max();
        Console.WriteLine(new string('#', maxLength));
        Console.ForegroundColor = defaultColor;
    }

    public static void ConsoleWriterSection(params string[] lines)
    {
        var defaultColor = Console.ForegroundColor;
        Console.ForegroundColor = ConsoleColor.Blue;
        Console.WriteLine(" ");
        foreach (var line in lines)
        {
            Console.WriteLine(line);
        }
        var maxLength = lines.Select(x => x.Length).Max();
        Console.WriteLine(new string('-', maxLength));
        Console.ForegroundColor = defaultColor;
    }
    
}
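
The values printed by `PrintBinaryClassificationMetrics` can be related to the four confusion-matrix counts. The sketch below computes accuracy, precision, recall, and F1 from made-up counts using the standard definitions. It does not call ML.NET, and the counts are hypothetical, not results from this sample.

```csharp
using System;

// Made-up confusion-matrix counts for a hypothetical test set.
int truePositives = 80;     // frauds flagged as fraud
int falseNegatives = 20;    // frauds that were missed
int falsePositives = 10;    // legitimate transactions flagged as fraud
int trueNegatives = 56890;  // legitimate transactions left alone

double total = truePositives + falseNegatives + falsePositives + trueNegatives;

// Fraction of all decisions that were correct.
double accuracy = (truePositives + trueNegatives) / total;

// Of the transactions flagged as fraud, the fraction that really were fraud.
double precision = (double)truePositives / (truePositives + falsePositives);

// Of the real frauds, the fraction that was flagged.
double recall = (double)truePositives / (truePositives + falseNegatives);

// Harmonic mean of precision and recall.
double f1 = 2 * precision * recall / (precision + recall);

Console.WriteLine($"Accuracy:  {accuracy}");
Console.WriteLine($"Precision: {precision}");
Console.WriteLine($"Recall:    {recall}");
Console.WriteLine($"F1:        {f1}");
```

With these counts accuracy stays above 99.9% even though one fraud in five is missed, so on a heavily unbalanced dataset precision, recall, and F1 are more informative than accuracy.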

## Methods 

In [None]:
public static void PrepDatasets(MLContext mlContext, string fullDataSetFilePath, string trainDataSetFilePath, string testDataSetFilePath)
{
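    // Prepare the datasets unless both the train and test files already exist

    if (!File.Exists(trainDataSetFilePath) ||
        !File.Exists(testDataSetFilePath))
    {
        Console.WriteLine("===== Preparing train/test datasets =====");

        // Make sure the output folders exist before the split files are written
        Directory.CreateDirectory(Path.GetDirectoryName(trainDataSetFilePath));
        Directory.CreateDirectory(Path.GetDirectoryName(testDataSetFilePath));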

        //Load the original single dataset
        IDataView originalFullData = mlContext.Data.LoadFromTextFile<TransactionObservation>(fullDataSetFilePath, separatorChar: ',', hasHeader: true);
                     
        // Split the data 80:20 into train and test sets, train and evaluate.
        TrainTestData trainTestData = mlContext.Data.TrainTestSplit(originalFullData, testFraction: 0.2, seed: 1);
        IDataView trainData = trainTestData.TrainSet;
        IDataView testData = trainTestData.TestSet;

        // Inspect TestDataView to make sure there are true and false observations in the test dataset after splitting
        InspectData(mlContext, testData, 4);

        // save train split
        using (var fileStream = File.Create(trainDataSetFilePath))
        {
            mlContext.Data.SaveAsText(trainData, fileStream, separatorChar: ',', headerRow: true, schema: true);
        }

        // save test split 
        using (var fileStream = File.Create(testDataSetFilePath))
        {
            mlContext.Data.SaveAsText(testData, fileStream, separatorChar: ',', headerRow: true, schema: true);
        }
    }
}

public static (ITransformer model, string trainerName) TrainModel(MLContext mlContext, IDataView trainDataView)
{
    //Get all the feature column names (All except the Label and the IdPreservationColumn)
    string[] featureColumnNames = trainDataView.Schema.AsQueryable()
        .Select(column => column.Name)                               // Get all the column names
        .Where(name => name != nameof(TransactionObservation.Label)) // Do not include the Label column
        .Where(name => name != "IdPreservationColumn")               // Do not include the IdPreservationColumn/StratificationColumn
        .Where(name => name != "Time")                               // Do not include the Time column. Not needed as feature column
        .ToArray();

    // Create the data process pipeline
    IEstimator<ITransformer> dataProcessPipeline = mlContext.Transforms.Concatenate("Features", featureColumnNames)
                                    .Append(mlContext.Transforms.DropColumns(new string[] { "Time" }))
                                    .Append(mlContext.Transforms.NormalizeMeanVariance(inputColumnName: "Features",
                                                                         outputColumnName: "FeaturesNormalizedByMeanVar"));

    // (OPTIONAL) Peek data (such as 2 records) in training DataView after applying the ProcessPipeline's transformations into "Features" 
    ConsoleHelper.PeekDataViewInConsole(mlContext, trainDataView, dataProcessPipeline, 2);
    ConsoleHelper.PeekVectorColumnDataInConsole(mlContext, "Features", trainDataView, dataProcessPipeline, 1);
  
    // Set the training algorithm
    var trainer = mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: nameof(TransactionObservation.Label),
                                                                                        featureColumnName: "FeaturesNormalizedByMeanVar",
                                                                                        numberOfLeaves: 20,
                                                                                        numberOfTrees: 100,
                                                                                        minimumExampleCountPerLeaf: 10,
                                                                                        learningRate: 0.2);

    var trainingPipeline = dataProcessPipeline.Append(trainer);

    ConsoleHelper.ConsoleWriteHeader("=============== Training model ===============");

    var model = trainingPipeline.Fit(trainDataView);

    ConsoleHelper.ConsoleWriteHeader("=============== End of training process ===============");

    // Append feature contribution calculator in the pipeline. This will be used
    // at prediction time for explainability. 
    var fccPipeline = model.Append(mlContext.Transforms
        .CalculateFeatureContribution(model.LastTransformer)
        .Fit(dataProcessPipeline.Fit(trainDataView).Transform(trainDataView)));

    return (fccPipeline, fccPipeline.ToString());

}

private static void EvaluateModel(MLContext mlContext, ITransformer model, IDataView testDataView, string trainerName)
{
    // Evaluate the model and show accuracy stats
    Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
    var predictions = model.Transform(testDataView);

    var metrics = mlContext.BinaryClassification.Evaluate(data: predictions, 
                                                          labelColumnName: nameof(TransactionObservation.Label), 
                                                          scoreColumnName: "Score");

    ConsoleHelper.PrintBinaryClassificationMetrics(trainerName, metrics);
}

public static void InspectData(MLContext mlContext, IDataView data, int records)
{
    //We want to make sure we have True and False observations

    Console.WriteLine($"Show {records} fraud transactions (true)");
    ShowObservationsFilteredByLabel(mlContext, data, label: true, count: records);

    Console.WriteLine($"Show {records} NOT-fraud transactions (false)");
    ShowObservationsFilteredByLabel(mlContext, data, label: false, count: records);
}

public static void ShowObservationsFilteredByLabel(MLContext mlContext, IDataView dataView, bool label = true, int count = 2)
{
    // Convert to an enumerable of user-defined type. 
    var data = mlContext.Data.CreateEnumerable<TransactionObservation>(dataView, reuseRowObject: false)
                                    .Where(x => x.Label == label)
                                    // Take a couple values as an array.
                                    .Take(count)
                                    .ToList();

    // print to console
    data.ForEach(row => { row.PrintToConsole(); });
}

public static void UnZipDataSet(string zipDataSet, string destinationFile)
{
    if (!File.Exists(destinationFile))
    {
        var destinationDirectory = Path.GetDirectoryName(destinationFile);
        ZipFile.ExtractToDirectory(zipDataSet, $"{destinationDirectory}");
    }
}

private static void SaveModel(MLContext mlContext, ITransformer model, string modelFilePath, DataViewSchema trainingDataSchema)
{
    Console.WriteLine(modelFilePath);
    mlContext.Model.Save(model,trainingDataSchema, modelFilePath);

    Console.WriteLine("Saved model to " + modelFilePath);
}

public static void CopyModelAndDatasetFromTrainingProject(string trainOutput, string assetsPath)
{
     if (!File.Exists(Path.Combine(trainOutput, "testData.csv")) ||
         !File.Exists(Path.Combine(trainOutput, "fastTree.zip")))
     {
         Console.WriteLine("***** YOU NEED TO RUN THE TRAINING PROJECT IN THE FIRST PLACE *****");
         Console.WriteLine("=============== Continue ===============");
         Environment.Exit(0);
     }

     // Copy files from the training output into the input folder used for predictions.
     // Path.Combine keeps the path valid on Windows, Linux, and macOS.
     var destinationDirectory = Path.Combine(assetsPath, "input");
     foreach (var file in Directory.GetFiles(trainOutput))
     {
         var fileDestination = Path.Combine(destinationDirectory, Path.GetFileName(file));

         // Overwrite any stale copy left by a previous run
         File.Copy(file, fileDestination, overwrite: true);
     }
}


### Trainer

In [None]:
// Unzip the original dataset, as it is too large for the GitHub repo if not zipped
UnZipDataSet(zipDataSet, fullDataSetFilePath);

// Create a common ML.NET context.
// Seed set to any number so you have a deterministic environment for repeatable results
MLContext mlContext = new MLContext(seed: 1);

// Prepare data and create Train/Test split datasets
PrepDatasets(mlContext, fullDataSetFilePath, trainDataSetFilePath, testDataSetFilePath);

// Load Datasets
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<TransactionObservation>(trainDataSetFilePath, separatorChar: ',', hasHeader: true);
IDataView testDataView = mlContext.Data.LoadFromTextFile<TransactionObservation>(testDataSetFilePath, separatorChar: ',', hasHeader: true);

// Train Model
(ITransformer model, string trainerName) = TrainModel(mlContext, trainingDataView);

// Evaluate quality of Model
EvaluateModel(mlContext, model, testDataView, trainerName);

// Save model
SaveModel(mlContext, model, modelFilePath, trainingDataView.Schema);

Console.WriteLine("=============== Continue ===============");

### Predictor

In [None]:
CopyModelAndDatasetFromTrainingProject(trainOutput, assetsPath);

var inputDatasetForPredictions = Path.Combine(assetsPath,"input", "testData.csv");
var modelFilePath = Path.Combine(assetsPath, "input", "fastTree.zip");

// Create model predictor to perform a few predictions
var modelPredictor = new Predictor(modelFilePath,inputDatasetForPredictions);

modelPredictor.RunMultiplePredictions(numberOfPredictions:5);

Console.WriteLine("=============== The End ===============");