# ML.Net - Samples - Deep Learning Image Classification

# Spam Detection for Text Messages

| ML.NET version | API type          | Status                        | App Type    | Data type | Scenario            | ML Task                   | Algorithms                  |
|----------------|-------------------|-------------------------------|-------------|-----------|---------------------|---------------------------|-----------------------------|
| v1.4           | Dynamic API | Might need to update project structure to match template | Jupyter-Notebbok | .tsv files | Spam detection | Two-class classification | Averaged Perceptron (linear learner) |

In this sample, you'll see how to use [ML.NET](https://www.microsoft.com/net/learn/apps/machine-learning-and-ai/ml-dotnet) to predict whether a text message is spam. In the world of machine learning, this type of prediction is known as **binary classification**.

## Problem

Our goal here is to predict whether a text message is spam (an irrelevant/unwanted message). We will use the [SMS Spam Collection Data Set](https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection) from UCI, which contains close to 6000 messages that have been classified as being "spam" or "ham" (not spam). We will use this dataset to train a model that can take in new message and predict whether they are spam or not.

This is an example of binary classification, as we are classifying the text messages into one of two categories.


## Solution

To solve this problem, first we will build an estimator to define the ML pipeline we want to use. Then we will train this estimator on existing data, evaluate how good it is, and lastly we'll consume the model to predict whether a few examples messages are spam.

![Build -> Train -> Evaluate -> Consume](../shared_content/modelpipeline.png)

### 1. Build Model

To build the model we will:

* Define how to read the spam dataset that will be downloaded from https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection. 

* Apply several data transformations:

    * Convert the label ("spam" or "ham") to a boolean ("true" represents spam) so we can use it with a binary classifier. 
    * Featurize the text message into a numeric vector so a machine learning trainer can use it

* Add a trainer (such as `StochasticDualCoordinateAscent`).

The initial code is similar to the following:


In [2]:
// ML.NET Nuget packages installation
#r "nuget:Microsoft.ML" 
#r "nuget:Microsoft.Extensions.ML" 
#r "nuget:Microsoft.ML.Vision" 
#r "nuget:SharpZipLib" 
#r "nuget:SciSharp.TensorFlow.Redist" 
//#r "nuget:SciSharp.TensorFlow.Redist-Windows-GPU" 

Installing package SciSharp.TensorFlow.Redist...

Installed package SharpZipLib version 1.2.0

Installed package Microsoft.Extensions.ML version 1.5.0

Installed package Microsoft.ML version 1.5.0

Installed package Microsoft.ML.Vision version 1.5.0

## Using C# Class

In [3]:
using System;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
using Microsoft.ML.Vision;
using System.Collections.Generic;
using static Microsoft.ML.TrainCatalogBase;
using static Microsoft.ML.DataOperationsCatalog;
using System.Diagnostics;
using System.Reflection;
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System.Threading;
using System.Threading.Tasks;
using static Microsoft.ML.Transforms.ValueToKeyMappingEstimator;

## Declare data-classes for input data and predictions

In [4]:
public class ImagePrediction
{
    [ColumnName("Score")]
    public float[] Score;

    [ColumnName("PredictedLabel")]
    public string PredictedLabel;
}

public class InMemoryImageData
{
    public InMemoryImageData(byte[] image, string label, string imageFileName)
    {
        Image = image;
        Label = label;
        ImageFileName = imageFileName;
    }

    public readonly byte[] Image;

    public readonly string Label;

    public readonly string ImageFileName;
}

public class ImageData
{
    public ImageData(string imagePath, string label)
    {
        ImagePath = imagePath;
        Label = label;
    }

    public readonly string ImagePath;

    public readonly string Label;
}

### FileUtils

In [5]:
public class FileUtils
{
    public static IEnumerable<(string imagePath, string label)> LoadImagesFromDirectory(
        string folder,
        bool useFolderNameasLabel)
    {
        var imagesPath = Directory
            .GetFiles(folder, "*", searchOption: SearchOption.AllDirectories)
            .Where(x => Path.GetExtension(x) == ".jpg" || Path.GetExtension(x) == ".png");

        return useFolderNameasLabel
            ? imagesPath.Select(imagePath => (imagePath, Directory.GetParent(imagePath).Name))
            : imagesPath.Select(imagePath =>
            {
                var label = Path.GetFileName(imagePath);
                for (var index = 0; index < label.Length; index++)
                {
                    if (!char.IsLetter(label[index]))
                    {
                        label = label.Substring(0, index);
                        break;
                    }
                }
                return (imagePath, label);
            });
    }

    public static IEnumerable<InMemoryImageData> LoadInMemoryImagesFromDirectory(
        string folder,
        bool useFolderNameAsLabel = true)
        => LoadImagesFromDirectory(folder, useFolderNameAsLabel)
            .Select(x => new InMemoryImageData(
                image: File.ReadAllBytes(x.imagePath),
                label: x.label,
                imageFileName: Path.GetFileName(x.imagePath)));

}

### Compress

In [6]:
public class Compress
{
    public static void ExtractGZip(string gzipFileName, string targetDir)
    {
        // Use a 4K buffer. Any larger is a waste.    
        byte[] dataBuffer = new byte[4096];

        using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
        {
            using (GZipInputStream gzipStream = new GZipInputStream(fs))
            {
                // Change this to your needs
                string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));

                using (FileStream fsOut = File.Create(fnOut))
                {
                    StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
                }
            }
        }
    }

    public static void UnZip(String gzArchiveName, String destFolder)
    {
        var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
        if (File.Exists(Path.Combine(destFolder, flag))) return;

        Console.WriteLine($"Extracting.");
        var task = Task.Run(() =>
        {
            ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
        });

        while (!task.IsCompleted)
        {
            Thread.Sleep(200);
            Console.Write(".");
        }

        File.Create(Path.Combine(destFolder, flag));
        Console.WriteLine("");
        Console.WriteLine("Extracting is completed.");
    }

    public static void ExtractTGZ(String gzArchiveName, String destFolder)
    {
        var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
        if (File.Exists(Path.Combine(destFolder, flag))) return;

        Console.WriteLine($"Extracting.");
        var task = Task.Run(() =>
        {
            using (var inStream = File.OpenRead(gzArchiveName))
            {
                using (var gzipStream = new GZipInputStream(inStream))
                {
                    using (TarArchive tarArchive = TarArchive.CreateInputTarArchive(gzipStream))
                        tarArchive.ExtractContents(destFolder);
                }
            }
        });

        while (!task.IsCompleted)
        {
            Thread.Sleep(200);
            Console.Write(".");
        }

        File.Create(Path.Combine(destFolder, flag));
        Console.WriteLine("");
        Console.WriteLine("Extracting is completed.");
    }
}

### Web

In [7]:
public class Web
{
    public static bool Download(string url, string destDir, string destFileName)
    {
        if (destFileName == null)
            destFileName = url.Split(Path.DirectorySeparatorChar).Last();

        Directory.CreateDirectory(destDir);

        string relativeFilePath = Path.Combine(destDir, destFileName);

        if (File.Exists(relativeFilePath))
        {
            Console.WriteLine($"{relativeFilePath} already exists.");
            return false;
        }

        var wc = new WebClient();
        Console.WriteLine($"Downloading {relativeFilePath}");
        var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
        while (!download.IsCompleted)
        {
            Thread.Sleep(1000);
            Console.Write(".");
        }
        Console.WriteLine("");
        Console.WriteLine($"Downloaded {relativeFilePath}");

        return true;
    }
}

### ConsoleHelper

In [8]:
public static class ConsoleHelper
{
    public static void PrintPrediction(string prediction)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"Predicted : {prediction}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintRegressionPredictionVersusObserved(string predictionCount, string observedCount)
    {
        Console.WriteLine($"-------------------------------------------------");
        Console.WriteLine($"Predicted : {predictionCount}");
        Console.WriteLine($"Actual:     {observedCount}");
        Console.WriteLine($"-------------------------------------------------");
    }

    public static void PrintRegressionMetrics(string name, RegressionMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} regression model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       LossFn:        {metrics.LossFunction:0.##}");
        Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
        Console.WriteLine($"*       Absolute loss: {metrics.MeanAbsoluteError:#.##}");
        Console.WriteLine($"*       Squared loss:  {metrics.MeanSquaredError:#.##}");
        Console.WriteLine($"*       RMS loss:      {metrics.RootMeanSquaredError:#.##}");
        Console.WriteLine($"*************************************************");
    }

    public static void PrintBinaryClassificationMetrics(string name, CalibratedBinaryClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*       Metrics for {name} binary classification model      ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"*       Accuracy: {metrics.Accuracy:P2}");
        Console.WriteLine($"*       Area Under Curve:      {metrics.AreaUnderRocCurve:P2}");
        Console.WriteLine($"*       Area under Precision recall Curve:  {metrics.AreaUnderPrecisionRecallCurve:P2}");
        Console.WriteLine($"*       F1Score:  {metrics.F1Score:P2}");
        Console.WriteLine($"*       LogLoss:  {metrics.LogLoss:#.##}");
        Console.WriteLine($"*       LogLossReduction:  {metrics.LogLossReduction:#.##}");
        Console.WriteLine($"*       PositivePrecision:  {metrics.PositivePrecision:#.##}");
        Console.WriteLine($"*       PositiveRecall:  {metrics.PositiveRecall:#.##}");
        Console.WriteLine($"*       NegativePrecision:  {metrics.NegativePrecision:#.##}");
        Console.WriteLine($"*       NegativeRecall:  {metrics.NegativeRecall:P2}");
        Console.WriteLine($"************************************************************");
    }

    public static void PrintMultiClassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"*    Metrics for {name} multi-class classification model   ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($"    AccuracyMacro = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    AccuracyMicro = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
        Console.WriteLine($"    LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
        Console.WriteLine($"************************************************************");
    }
    
    public static void PrintRegressionFoldsAverageMetrics(string algorithmName, IReadOnlyList<CrossValidationResult<RegressionMetrics>> crossValidationResults)
    {
        var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
        var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
        var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);
        var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
        var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Regression model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average L1 Loss:    {L1.Average():0.###} ");
        Console.WriteLine($"*       Average L2 Loss:    {L2.Average():0.###}  ");
        Console.WriteLine($"*       Average RMS:          {RMS.Average():0.###}  ");
        Console.WriteLine($"*       Average Loss Function: {lossFunction.Average():0.###}  ");
        Console.WriteLine($"*       Average R-squared: {R2.Average():0.###}  ");
        Console.WriteLine($"*************************************************************************************************************");
    }
    
    public static void PrintMulticlassClassificationFoldsAverageMetrics(
                                     string algorithmName,
                                   IReadOnlyList<CrossValidationResult<MulticlassClassificationMetrics>> crossValResults
                                                                       )
    {
        var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

        var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
        var microAccuracyAverage = microAccuracyValues.Average();
        var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
        var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

        var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
        var macroAccuracyAverage = macroAccuracyValues.Average();
        var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
        var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

        var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
        var logLossAverage = logLossValues.Average();
        var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
        var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

        var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
        var logLossReductionAverage = logLossReductionValues.Average();
        var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
        var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"*       Metrics for {algorithmName} Multi-class Classification model      ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"*       Average MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average MacroAccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLoss:          {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###})  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
        Console.WriteLine($"*       Average LogLossReduction: {logLossReductionAverage:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
        Console.WriteLine($"*************************************************************************************************************");
    }    

    public static double CalculateStandardDeviation (IEnumerable<double> values)
    {
        double average = values.Average();
        double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
        double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count()-1));
        return standardDeviation;
    }

    public static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count()-1));
        return confidenceInterval95;
    }

    public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
    {
        Console.WriteLine($"*************************************************");
        Console.WriteLine($"*       Metrics for {name} clustering model      ");
        Console.WriteLine($"*------------------------------------------------");
        Console.WriteLine($"*       Average Distance: {metrics.AverageDistance}");
        Console.WriteLine($"*       Davies Bouldin Index is: {metrics.DaviesBouldinIndex}");
        Console.WriteLine($"*************************************************");
    }    
}

### Methods

In [9]:
private static void EvaluateModel(MLContext mlContext, IDataView testDataset, ITransformer trainedModel)
{
    Console.WriteLine("Making predictions in bulk for evaluating model's quality...");

    // Measuring time
    var watch = Stopwatch.StartNew();

    var predictionsDataView = trainedModel.Transform(testDataset);

    var metrics = mlContext.MulticlassClassification.Evaluate(predictionsDataView, labelColumnName:"LabelAsKey",
    predictedLabelColumnName: "PredictedLabel");
    ConsoleHelper.PrintMultiClassClassificationMetrics("TensorFlow DNN Transfer Learning", metrics);

    watch.Stop();
    var elapsed2Ms = watch.ElapsedMilliseconds;

    Console.WriteLine($"Predicting and Evaluation took: {elapsed2Ms / 1000} seconds");
}

private static void TrySinglePrediction(string imagesFolderPathForPredictions, MLContext mlContext, ITransformer 
trainedModel)
{
    // Create prediction function to try one prediction
    var predictionEngine = mlContext.Model
        .CreatePredictionEngine<InMemoryImageData, ImagePrediction>(trainedModel);

    var testImages = FileUtils.LoadInMemoryImagesFromDirectory(
        imagesFolderPathForPredictions, false);

    var imageToPredict = testImages.First();

    var prediction = predictionEngine.Predict(imageToPredict);

    Console.WriteLine(
        $"Image Filename : [{imageToPredict.ImageFileName}], " +
        $"Scores : [{string.Join(",", prediction.Score)}], " +
        $"Predicted Label : {prediction.PredictedLabel}");
}


public static IEnumerable<ImageData> LoadImagesFromDirectory(
    string folder,
    bool useFolderNameAsLabel = true)
    => FileUtils.LoadImagesFromDirectory(folder, useFolderNameAsLabel)
        .Select(x => new ImageData(x.imagePath, x.label));

public static string DownloadImageSet(string imagesDownloadFolder)
{
    // get a set of images to teach the network about the new classes

    //SINGLE SMALL FLOWERS IMAGESET (200 files)
    const string fileName = "flower_photos_small_set.zip";
    var url = $"https://mlnetfilestorage.file.core.windows.net/imagesets/flower_images/flower_photos_small_set.zip?st=2019-08-07T21%3A27%3A44Z&se=2030-08-08T21%3A27%3A00Z&sp=rl&sv=2018-03-28&sr=f&sig=SZ0UBX47pXD0F1rmrOM%2BfcwbPVob8hlgFtIlN89micM%3D";
    Web.Download(url, imagesDownloadFolder, fileName);
    Compress.UnZip(Path.Join(imagesDownloadFolder, fileName), imagesDownloadFolder);

    //SINGLE FULL FLOWERS IMAGESET (3,600 files)
    //string fileName = "flower_photos.tgz";
    //string url = $"http://download.tensorflow.org/example_images/{fileName}";
    //Web.Download(url, imagesDownloadFolder, fileName);
    //Compress.ExtractTGZ(Path.Join(imagesDownloadFolder, fileName), imagesDownloadFolder);
    
    return Path.GetFileNameWithoutExtension(fileName);
}

public static void ConsoleWriteImagePrediction(string ImagePath, string Label, string PredictedLabel, float 
Probability)
{
    var defaultForeground = Console.ForegroundColor;
    var labelColor = ConsoleColor.Magenta;
    var probColor = ConsoleColor.Blue;

    Console.Write("Image File: ");
    Console.ForegroundColor = labelColor;
    Console.Write($"{Path.GetFileName(ImagePath)}");
    Console.ForegroundColor = defaultForeground;
    Console.Write(" original labeled as ");
    Console.ForegroundColor = labelColor;
    Console.Write(Label);
    Console.ForegroundColor = defaultForeground;
    Console.Write(" predicted as ");
    Console.ForegroundColor = labelColor;
    Console.Write(PredictedLabel);
    Console.ForegroundColor = defaultForeground;
    Console.Write(" with score ");
    Console.ForegroundColor = probColor;
    Console.Write(Probability);
    Console.ForegroundColor = defaultForeground;
    Console.WriteLine("");
}

private static void FilterMLContextLog(object sender, LoggingEventArgs e)
{
    if (e.Message.StartsWith("[Source=ImageClassificationTrainer;"))
    {
        Console.WriteLine(e.Message);
    }
}

### Constants

In [10]:
string outputMlNetModelFilePath = Path.Combine(@".\Datasets\Image_Classification", "outputs", "imageClassifier.zip");
//Console.WriteLine(outputMlNetModelFilePath);
string imagesFolderPathForPredictions = Path.Combine(@".\Datasets\Image_Classification", "inputs", "images-for-predictions", "FlowersForPredictions");
//Console.WriteLine(imagesFolderPathForPredictions);
string imagesDownloadFolderPath = Path.Combine(@".\Datasets\Image_Classification", "inputs", "images");
//Console.WriteLine(imagesDownloadFolderPath);

## Train

In [11]:
// 1. Download the image set and unzip
string finalImagesFolderName = DownloadImageSet(imagesDownloadFolderPath);
string fullImagesetFolderPath = Path.Combine(imagesDownloadFolderPath, finalImagesFolderName);

var mlContext = new MLContext(seed: 1);

// Specify MLContext Filter to only show feedback log/traces about ImageClassification
// This is not needed for feedback output if using the explicit MetricsCallback parameter
mlContext.Log += FilterMLContextLog;           

// 2. Load the initial full image-set into an IDataView and shuffle so it'll be better balanced
IEnumerable<ImageData> images = LoadImagesFromDirectory(folder: fullImagesetFolderPath, useFolderNameAsLabel: true);
IDataView fullImagesDataset = mlContext.Data.LoadFromEnumerable(images);
IDataView shuffledFullImageFilePathsDataset = mlContext.Data.ShuffleRows(fullImagesDataset);

// 3. Load Images with in-memory type within the IDataView and Transform Labels to Keys (Categorical)
IDataView shuffledFullImagesDataset = mlContext.Transforms.Conversion.
        MapValueToKey(outputColumnName: "LabelAsKey", inputColumnName: "Label", keyOrdinality: KeyOrdinality.ByValue)
    .Append(mlContext.Transforms.LoadRawImageBytes(
                                    outputColumnName: "Image",
                                    imageFolder: fullImagesetFolderPath,
                                    inputColumnName: "ImagePath"))
    .Fit(shuffledFullImageFilePathsDataset)
    .Transform(shuffledFullImageFilePathsDataset);

// 4. Split the data 80:20 into train and test sets, train and evaluate.
var trainTestData = mlContext.Data.TrainTestSplit(shuffledFullImagesDataset, testFraction: 0.2);
IDataView trainDataView = trainTestData.TrainSet;
IDataView testDataView = trainTestData.TestSet;

// 5. Define the model's training pipeline using DNN default values
//
var pipeline = mlContext.MulticlassClassification.Trainers
        .ImageClassification(featureColumnName: "Image",
                             labelColumnName: "LabelAsKey",
                             validationSet: testDataView)
    .Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
                                                          inputColumnName: "PredictedLabel"));

// 5.1 (OPTIONAL) Define the model's training pipeline by using explicit hyper-parameters
//
//var options = new ImageClassificationTrainer.Options()
//{
//    FeatureColumnName = "Image",
//    LabelColumnName = "LabelAsKey",
//    // Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250  
//    // you can try a different DNN architecture (TensorFlow pre-trained model). 
//    Arch = ImageClassificationTrainer.Architecture.MobilenetV2,
//    Epoch = 50,       //100
//    BatchSize = 10,
//    LearningRate = 0.01f,
//    MetricsCallback = (metrics) => Console.WriteLine(metrics),
//    ValidationSet = testDataView
//};

//var pipeline = mlContext.MulticlassClassification.Trainers.ImageClassification(options)
//        .Append(mlContext.Transforms.Conversion.MapKeyToValue(
//            outputColumnName: "PredictedLabel",
//            inputColumnName: "PredictedLabel"));

// 6. Train/create the ML model
Console.WriteLine("*** Training the image classification model with DNN Transfer Learning on top of the selected pre-trained model/architecture ***");

// Measuring training time
var watch = Stopwatch.StartNew();

//Train
ITransformer trainedModel = pipeline.Fit(trainDataView);

watch.Stop();
var elapsedMs = watch.ElapsedMilliseconds;

Console.WriteLine($"Training with transfer learning took: {elapsedMs / 1000} seconds");

// 7. Get the quality metrics (accuracy, etc.)
EvaluateModel(mlContext, testDataView, trainedModel);

// 8. Save the model to assets/outputs (You get ML.NET .zip model file and TensorFlow .pb model file)
mlContext.Model.Save(trainedModel, trainDataView.Schema, outputMlNetModelFilePath);
Console.WriteLine($"Model saved to: {outputMlNetModelFilePath}");

// 9. Try a single prediction simulating an end-user app
TrySinglePrediction(imagesFolderPathForPredictions, mlContext, trainedModel);

Console.WriteLine("Press any key to finish");

Unhandled exception: (19,34): error CS1061: ‘TransformsCatalog’ não contém uma definição para "LoadRawImageBytes" e não foi possível encontrar nenhum método de extensão "LoadRawImageBytes" que aceite um primeiro argumento do tipo ‘TransformsCatalog’ (você está se esquecendo de usar uma diretiva ou uma referência de assembly?)