# Defining, Training, and Testing Models - Classification

![ml workflow](https://docs.google.com/drawings/d/e/2PACX-1vQ1XLwesZbm_TuDBPFRvbHa4XcjucvtExy3LXE05WnaAw-s6BDVQnnd4lAEUW1Qy6bs6FythuJdFVqP/pub?w=1165&h=662)

Moving on from Regression, let's now try out some **Classification**. As opposed to predicting a continuous value (like weight, temperature, or stock price), classification attempts to predict whether a certain sample belongs to one class or another (e.g., fraud or not fraud). 

For our example, we will try to predict the species of Iris flowers (our labels or classes) based on physical measurements of those flowers (our features). We will use the well-known Iris flower data set, which you can learn more about [here](https://en.wikipedia.org/wiki/Iris_flower_data_set). We will try two different kinds of models:

- [k Nearest Neighbors](https://youtu.be/UqYde-LULfs)
- [Decision Tree](https://www.cs.cmu.edu/afs/cs/academic/class/15381-s07/www/slides/041007decisionTrees1.pdf)

## Imports

In [None]:
import (
    "io/ioutil"
    "fmt"
    "os"
    "math"
    "math/rand"
    
    "github.com/kniren/gota/dataframe"
    "github.com/kniren/gota/series"
    "gonum.org/v1/plot"
    "gonum.org/v1/plot/plotter"
    "gonum.org/v1/plot/plotutil"
    "gonum.org/v1/plot/vg"
    "gonum.org/v1/gonum/stat"
    "gonum.org/v1/gonum/floats"
    "github.com/sjwhitworth/golearn/knn"
    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/evaluation"
    "github.com/sjwhitworth/golearn/trees"
)

## Convenience Functions

In [None]:
// GetGraph returns the bytes corresponding to a
// saved plot.
func GetGraph(graphName string) ([]byte, error) {
    
    // Open the file.
    infile, err := os.Open(graphName)
    if err != nil {
        return nil, err
    }
    
    // Close the file when the function returns, even if
    // reading fails.
    defer infile.Close()
    
    // Read in the contents of the file.
    bytes, err := ioutil.ReadAll(infile)
    if err != nil {
        return nil, err
    }
    
    return bytes, nil
}

## Import the Data

In [None]:
// Open the data file.
f, err := os.Open("../data/iris.csv")
if err != nil {
    fmt.Println(err)
}

// Read in the contents to a dataframe.
irisDF := dataframe.ReadCSV(f)

// Close the file.
f.Close()

In [None]:
// Output a summary of the dataset to stdout.
fmt.Println(irisDF)

## Profile our data set

### Count of each species label

In [None]:
// Define our unique species.
species := []string{"Iris-setosa", "Iris-virginica", "Iris-versicolor"}

// Count instances of the unique species.
for _, sp := range species {
        
    // Create a filter for the dataframe.
    filter := dataframe.F{
        Colname:    "species",
        Comparator: series.Eq,
        Comparando: sp,
    }
    
    // Filter the dataframe to see only the rows where
    // the species is equal to sp.
    filteredDF := irisDF.Filter(filter)
    if filteredDF.Err != nil {
        fmt.Println(filteredDF.Err)
    }
    
    // Output the count.
    fmt.Printf("%s count: %d\n", sp, filteredDF.Nrow())
}

### Distribution of numerical features

In [None]:
// Create a histogram for each of the float columns in the dataset and
// output summary statistics.
for _, colName := range irisDF.Names() {

    if colName != "species" {

        // Create a plotter.Values value and fill it with the
        // values from the respective column of the dataframe.
        plotVals := make(plotter.Values, irisDF.Nrow())
        summaryVals := make([]float64, irisDF.Nrow())
        for i, floatVal := range irisDF.Col(colName).Float() {
            plotVals[i] = floatVal
            summaryVals[i] = floatVal
        }

        // Make a plot and set its title.
        p, err := plot.New()
        if err != nil {
            fmt.Println(err)
        }
        p.Title.Text = fmt.Sprintf("Histogram of a %s", colName)

        // Create a histogram with 16 bins from the values
        // in this column.
        h, err := plotter.NewHist(plotVals, 16)
        if err != nil {
            fmt.Println(err)
        }

        // Normalize the histogram.
        h.Normalize(1)

        // Add the histogram to the plot.
        p.Add(h)

        // Save the plot to a PNG file.
        if err := p.Save(4*vg.Inch, 4*vg.Inch, colName+"_hist.png"); err != nil {
            fmt.Println(err)
        }

        // Calculate the summary statistics.
        meanVal := stat.Mean(summaryVals, nil)
        maxVal := floats.Max(summaryVals)
        minVal := floats.Min(summaryVals)
        stdDevVal := stat.StdDev(summaryVals, nil)

        // Output the summary statistics.
        fmt.Printf("\n%s\n", colName)
        fmt.Printf("Mean: %0.2f\n", meanVal)
        fmt.Printf("Min: %0.2f\n", minVal)
        fmt.Printf("Max: %0.2f\n", maxVal)
        fmt.Printf("StdDev: %0.2f\n\n", stdDevVal)
    }
}

In [None]:
// Read the plot data from the first histogram.
plotBytes, err := GetGraph("sepal_width_hist.png")
if err != nil {
    fmt.Println(err)
}
    
// Display the plot.
display.PNG(plotBytes)

In [None]:
// Read the plot data from the second histogram.
plotBytes, err := GetGraph("sepal_length_hist.png")
if err != nil {
    fmt.Println(err)
}
    
// Display the plot.
display.PNG(plotBytes)

In [None]:
// Read the plot data from the third histogram.
plotBytes, err := GetGraph("petal_width_hist.png")
if err != nil {
    fmt.Println(err)
}
    
// Display the plot.
display.PNG(plotBytes)

In [None]:
// Read the plot data from the fourth histogram.
plotBytes, err := GetGraph("petal_length_hist.png")
if err != nil {
    fmt.Println(err)
}
    
// Display the plot.
display.PNG(plotBytes)

## Defining our models 

![ml workflow](https://docs.google.com/drawings/d/e/2PACX-1vSbzVQ-fJeOxZvAzbbE3yjRdB8A5WyBmHC2jz2AJTKvCcyOvZghkMVRAOvLgoGdF0mbcNPxCqRCrdIZ/pub?w=770&h=344)

We will now define a kNN model and a decision tree model. The kNN algorithm calculates a "distance" between the input features and known observations in the feature space. It then chooses the *k* nearest of these observations based on that distance. The majority class of those *k* nearest neighbors is then taken to be the class corresponding to the input features.
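
To make the kNN idea concrete, here is a minimal sketch of the majority vote written from scratch. It is not how golearn implements kNN. The `labeledPoint` type, the `euclidean` and `predictKNN` helpers, and the handful of two-feature training points are all made up for illustration, and ties are broken alphabetically just to keep the result deterministic.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// labeledPoint is a hypothetical type pairing a feature vector with its class.
type labeledPoint struct {
	features []float64
	label    string
}

// euclidean returns the Euclidean distance between two equal-length vectors.
func euclidean(a, b []float64) float64 {
	var sum float64
	for i := range a {
		d := a[i] - b[i]
		sum += d * d
	}
	return math.Sqrt(sum)
}

// predictKNN returns the majority label among the k training points
// closest to query. Ties go to the alphabetically first label
// (an arbitrary choice made for this sketch).
func predictKNN(train []labeledPoint, query []float64, k int) string {
	// Sort a copy of the training points by distance to the query.
	sorted := make([]labeledPoint, len(train))
	copy(sorted, train)
	sort.SliceStable(sorted, func(i, j int) bool {
		return euclidean(sorted[i].features, query) < euclidean(sorted[j].features, query)
	})
	if k > len(sorted) {
		k = len(sorted)
	}

	// Count the votes of the k nearest neighbors.
	votes := make(map[string]int)
	for _, p := range sorted[:k] {
		votes[p.label]++
	}

	// Pick the label with the most votes.
	best, bestCount := "", 0
	for label, count := range votes {
		if count > bestCount || (count == bestCount && label < best) {
			best, bestCount = label, count
		}
	}
	return best
}

func main() {
	// Made-up (petal_length, petal_width) points, for illustration only.
	train := []labeledPoint{
		{[]float64{1.4, 0.2}, "Iris-setosa"},
		{[]float64{1.3, 0.2}, "Iris-setosa"},
		{[]float64{1.5, 0.3}, "Iris-setosa"},
		{[]float64{4.7, 1.4}, "Iris-versicolor"},
		{[]float64{4.5, 1.5}, "Iris-versicolor"},
		{[]float64{6.0, 2.5}, "Iris-virginica"},
	}

	fmt.Println(predictKNN(train, []float64{1.6, 0.2}, 3)) // Iris-setosa
	fmt.Println(predictKNN(train, []float64{4.6, 1.5}, 3)) // Iris-versicolor
}
```

Sorting every training point for each query is the simplest approach and is fine for a data set this small. Larger data sets usually call for a smarter neighbor search.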

In a decision tree algorithm, a tree of if/then statements is created based on the features and labeled points of a training set. The parameters of the model that are determined during training are the ranges and ordering that determine how the if/then splits happen.
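
As a rough picture of what a trained tree looks like, the sketch below hard-codes a tiny tree of if/then tests and walks it to classify a sample. The `node` type, the `exampleTree` helper, and the thresholds are hypothetical and hand-written for illustration. They are not the output of the ID3 algorithm or of golearn.

```go
package main

import "fmt"

// node is one if/then test in a hypothetical binary decision tree.
// Leaves carry a label and have no children.
type node struct {
	feature   int     // index of the feature to test
	threshold float64 // go left when the feature value is <= threshold
	left      *node
	right     *node
	label     string // class label, used only on leaves
}

// predict walks the tree from n down to a leaf and returns its label.
func (n *node) predict(x []float64) string {
	for n.left != nil && n.right != nil {
		if x[n.feature] <= n.threshold {
			n = n.left
		} else {
			n = n.right
		}
	}
	return n.label
}

// exampleTree builds a hand-written tree over the features
// [petal_length, petal_width]. The thresholds are made up for
// illustration, not learned from the data.
func exampleTree() *node {
	return &node{
		feature:   0,
		threshold: 2.5,
		left:      &node{label: "Iris-setosa"},
		right: &node{
			feature:   1,
			threshold: 1.7,
			left:      &node{label: "Iris-versicolor"},
			right:     &node{label: "Iris-virginica"},
		},
	}
}

func main() {
	tree := exampleTree()
	fmt.Println(tree.predict([]float64{1.4, 0.2})) // Iris-setosa
	fmt.Println(tree.predict([]float64{4.5, 1.5})) // Iris-versicolor
	fmt.Println(tree.predict([]float64{6.0, 2.5})) // Iris-virginica
}
```

Training a real tree amounts to choosing which feature each node tests, where each split falls, and in what order the tests are applied.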

**Note** - We haven't split the data into training and test sets yet, because we are going to use cross validation to evaluate/validate our models.
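
For intuition, here is a minimal sketch of how k-fold cross validation carves up a data set: each fold takes one turn as the test set while the remaining folds are used for training. The `kFoldIndices` and `trainTestSplit` helpers are hypothetical and are not part of golearn. They assign rows to folds round-robin without shuffling, whereas a real implementation would typically shuffle first.

```go
package main

import "fmt"

// kFoldIndices partitions the row indices 0..n-1 into k folds of
// nearly equal size, assigning rows to folds round-robin.
func kFoldIndices(n, k int) [][]int {
	folds := make([][]int, k)
	for i := 0; i < n; i++ {
		folds[i%k] = append(folds[i%k], i)
	}
	return folds
}

// trainTestSplit returns the training and test row indices when the
// fold at position held is held out for testing.
func trainTestSplit(folds [][]int, held int) (train, test []int) {
	for i, fold := range folds {
		if i == held {
			test = append(test, fold...)
		} else {
			train = append(train, fold...)
		}
	}
	return train, test
}

func main() {
	// The iris data set has 150 rows; we use 5 folds as in this notebook.
	folds := kFoldIndices(150, 5)
	for i := range folds {
		train, test := trainTestSplit(folds, i)
		// A model would be trained on the train rows and scored on the test rows.
		fmt.Printf("fold %d: %d training rows, %d test rows\n", i, len(train), len(test))
	}
}
```

Each of the 5 rounds trains on 120 rows and tests on the remaining 30, so every row is used for testing exactly once. The per-fold accuracies are then summarized by their mean and spread.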

In [None]:
// Define our kNN model.
knnModel := knn.NewKnnClassifier("euclidean", "linear", 2)

// This is to seed the random processes involved in building the
// decision tree.
rand.Seed(44111342)

// We will use the ID3 algorithm to build our decision tree.  Also, we
// will start with a parameter of 0.6 that controls the train-prune split.
tree := trees.NewID3DecisionTree(0.6)

## Using cross validation to train/evaluate/validate our models

In [None]:
// Read the iris data set into golearn "instances".
irisData, err := base.ParseCSVToInstances("../data/iris.csv", true)
if err != nil {
    fmt.Println(err)
}

In [None]:
// Use cross-fold validation to evaluate the kNN model
// on 5 folds of the data set.
cv, err := evaluation.GenerateCrossFoldValidationConfusionMatrices(irisData, knnModel, 5)
if err != nil {
    fmt.Println(err)
}

// Get the mean, variance and standard deviation of the accuracy for the
// cross validation.
mean, variance := evaluation.GetCrossValidatedMetric(cv, evaluation.GetAccuracy)
stdev := math.Sqrt(variance)

// Output the cross validation metrics to standard out.
fmt.Printf("\n\nkNN Accuracy:\n%.2f (+/- %.2f)\n\n", mean, stdev*2)

// Use cross-fold validation to train and evaluate the tree model
// on 5 folds of the data set.
cv, err = evaluation.GenerateCrossFoldValidationConfusionMatrices(irisData, tree, 5)
if err != nil {
    fmt.Println(err)
}

mean, variance = evaluation.GetCrossValidatedMetric(cv, evaluation.GetAccuracy)
stdev = math.Sqrt(variance)

fmt.Printf("Decision Tree Accuracy:\n%.2f (+/- %.2f)\n\n", mean, stdev*2)