# Multiclass Classification

In this notebook we will use the Iris flower dataset, which is often called the "Hello world!" of machine learning.

Our aim is to train a model that predicts the iris species (setosa, versicolor or virginica) based on its petal and sepal measurements. In our case we are not dealing with a binary classification problem but with a multiclass classification problem: our model should be able to predict one of the three iris species.

Writing the code for training a multiclass logistic regression model works exactly the same way as for a binary problem. If you are interested in how logistic regression handles multiclass problems, have a look at the sklearn [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html).
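As a rough illustration, assuming a recent scikit-learn version where the default `lbfgs` solver fits a multinomial (softmax) model for targets with more than two classes, the probabilities returned by `predict_proba` can be reproduced by applying the softmax function to the per-class scores from `decision_function`. This is a sketch for intuition, not part of the notebook's workflow:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

# With the default lbfgs solver, one weight vector is learned per class
model = LogisticRegression(max_iter=1000).fit(X, y)

# One linear score per class: shape (n_samples, 3)
scores = model.decision_function(X)

# Softmax: exponentiate and normalise each row so the probabilities sum to 1
# (subtracting the row maximum first keeps np.exp numerically stable)
exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
softmax_probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)

print(model.coef_.shape)  # one row of weights per class, one column per feature
print(np.allclose(softmax_probs, model.predict_proba(X)))
```

The predicted class is then simply the class with the highest probability, which is why the training code looks identical to the binary case.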

After loading the data and defining our features and the target variable, we will split the dataset into 70% training data and 30% test data. Then
1. we will do a short exploratory data analysis (EDA) on the training data
2. we will train (fit) our logistic regression model on the training data and evaluate its performance on the remaining 30%.


At this point in the course the structure of the notebook should be familiar to you. By the end you should:
* know one of the most popular data sets in machine learning ;) 
* know how to apply logistic regression to a multiclass problem

## Setup and Data

We will begin by importing the required libraries and data. This time we will import the dataset directly from sklearn. Sklearn provides some of the most commonly used [toy data sets](https://scikit-learn.org/stable/datasets/toy_dataset.html) for learning and practicing machine learning.

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['font.size'] = 14
plt.rcParams['figure.figsize'] = (11, 7)

In [None]:
# Import data from sklearn.datasets
data = load_iris()
data
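`load_iris()` returns a scikit-learn `Bunch`, a dictionary-like object whose fields can also be accessed as attributes. A minimal sketch for inspecting the parts used in this notebook:

```python
from sklearn.datasets import load_iris

data = load_iris()

# A Bunch behaves like a dict, so its keys can be listed
print(list(data.keys()))

# Feature matrix: 150 flowers x 4 measurements in cm
print(data.data.shape)
print(data.feature_names)

# Integer-encoded target (0, 1, 2) and the species names the integers stand for
print(data.target.shape)
print(data.target_names)
```

In recent scikit-learn versions, passing `as_frame=True` to `load_iris` returns the features as a pandas DataFrame instead, which would make the manual conversion in the EDA section unnecessary.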

## Define features and Target

In [None]:
# Define features and target
X = data.data
y = data.target
target_names = data.target_names
target_names

## Train-Test split

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=15, shuffle=True, stratify=y)

# Check the shape of the data sets
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
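Because we pass `stratify=y`, the split keeps the class proportions of the full dataset in both subsets. A quick sketch that counts the samples per class with `np.bincount`, using the same split parameters as above:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=15, shuffle=True, stratify=y
)

# Samples per class (0 = setosa, 1 = versicolor, 2 = virginica)
print("full: ", np.bincount(y))        # [50 50 50]
print("train:", np.bincount(y_train))  # [35 35 35]
print("test: ", np.bincount(y_test))   # [15 15 15]
```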

## Exploratory Data Analysis

In [None]:
# Check data type
print("X_train data type:", type(X_train))
print("y_train data type:", type(y_train))

In [None]:
# Convert X_train to dataframe
df_X_train = pd.DataFrame(data=X_train, columns=data.feature_names)
df_X_train

In [None]:
# Convert y_train to dataframe
df_y_train = pd.DataFrame(data=y_train, columns=['species'])
df_y_train

In [None]:
# Combine df_X_train and df_y_train into df_train
df_train = pd.concat([df_X_train, df_y_train], axis=1, ignore_index=False)
df_train

In [None]:
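# Replace 0 with setosa, 1 with versicolor and 2 with virginica in the column species
# (assigning the result back avoids chained inplace modification, which newer pandas versions warn about)
df_train['species'] = df_train['species'].replace({0: target_names[0], 1: target_names[1], 2: target_names[2]})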
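Since the target is encoded as 0, 1 and 2, an alternative to `replace` is to index the `target_names` array with the integer labels directly (NumPy fancy indexing). A small sketch with a few hand-picked labels standing in for `y_train`:

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

target_names = load_iris().target_names  # ['setosa', 'versicolor', 'virginica']

# A few integer labels, used here as a stand-in for y_train
labels = np.array([0, 2, 1, 0])

# Each label selects the species name at that position of target_names
species = pd.Series(target_names[labels], name="species")
print(species.tolist())
```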

In [None]:
# Check result
df_train.head(2)

### Getting a feel for the data

Before we dive into the modelling part we will examine the data. 

In [None]:
# Checking the columns and data types 
df_train.info()

In [None]:
# Descriptive statistics 
df_train.describe()

In [None]:
# Checking for number of unique values
df_train.nunique()

In [None]:
# Checking for missing values
df_train.isnull().sum()

In [None]:
# Let's check whether the target variable is balanced
df_train['species'].value_counts()

The information above looks promising:
* We have four features, all of which are numerical.
* Our target variable *species* is categorical and balanced.
* None of our columns has missing data.
* Based on the table with the descriptive statistics, we can assume that there are no strong outliers.
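The "no strong outliers" impression can be checked a little more formally. A minimal sketch, assuming the common 1.5 × IQR rule of thumb (a convention chosen here for illustration, not something the notebook relies on), that counts the values outside the whiskers for each feature of the training data:

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X_train, _, _, _ = train_test_split(
    data.data, data.target, test_size=0.3, random_state=15, shuffle=True, stratify=data.target
)
df_X_train = pd.DataFrame(X_train, columns=data.feature_names)

# Interquartile range per feature
q1 = df_X_train.quantile(0.25)
q3 = df_X_train.quantile(0.75)
iqr = q3 - q1

# Values further than 1.5 * IQR beyond the quartiles are flagged as potential outliers
lower = q1 - 1.5 * iqr
upper = q3 + 1.5 * iqr
is_outlier = (df_X_train < lower) | (df_X_train > upper)

# Number of flagged values per feature
outlier_counts = is_outlier.sum()
print(outlier_counts)
```

A handful of flagged values is still compatible with "no strong outliers"; the rule only marks candidates worth a closer look.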

It's always a good idea to visualize the data to get more insight and to confirm our first assumptions.

### Visualisation 

We will start by plotting the target variable. Our training set is perfectly balanced, with 35 specimens of each species (the full dataset contains 50 per species, and the stratified split preserves these proportions). That's good, because we won't have to deal with an imbalanced data set when it comes to the modelling part.

In [None]:
# Plotting the target variable
plt.title('Species Count')
sns.countplot(x=df_train.species);

A pair plot is a good way to see the relationships between all pairs of columns at once. It's perfect for getting an overview of the data but not really suited for presenting our work to someone else.

We can see from the pair plot that the three species can be separated pretty well. There seem to be only minor overlaps between the species versicolor and virginica. This is good and will hopefully result in a good performance of our model.

In [None]:
sns.pairplot(df_train, hue="species", height=3, corner=True);

To confirm our first impression from the pairplot we can pick some of the feature combinations and visualize them again with bigger plots. 

In [None]:
plt.title('Comparison between sepal width and length on the basis of species')
sns.scatterplot(x=df_train['sepal length (cm)'], y=df_train['sepal width (cm)'], hue=df_train['species'], s=50);

In [None]:
plt.title('Comparison between petal width and length on the basis of species')
sns.scatterplot(x=df_train['petal length (cm)'], y=df_train['petal width (cm)'], hue=df_train['species'], s=50);

It is clearly visible that iris setosa is separable from the other two species, while iris versicolor and iris virginica overlap somewhat.

It's always worth looking at the correlations between our numerical features.
From the heatmap we can see that petal length and petal width are strongly correlated. Sepal length also shows a strong correlation with petal length and petal width.

In [None]:
# Correlation heatmap 
correlations = df_train.corr(numeric_only=True)
correlations

In [None]:
mask = np.triu(correlations)
sns.heatmap(correlations, vmax=1, vmin=-1, annot=True, mask=mask, cmap="YlGnBu");
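Instead of reading the strong correlations off the heatmap, they can also be listed programmatically. A sketch that rebuilds the training-set correlation matrix with the same split and keeps each unordered feature pair whose absolute correlation exceeds 0.8 (an arbitrary threshold chosen for illustration):

```python
from itertools import combinations

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X_train, _, _, _ = train_test_split(
    data.data, data.target, test_size=0.3, random_state=15, shuffle=True, stratify=data.target
)
correlations = pd.DataFrame(X_train, columns=data.feature_names).corr()

# Each unordered feature pair once (skips the diagonal and the mirrored half)
pairs = {(a, b): correlations.loc[a, b] for a, b in combinations(correlations.columns, 2)}

# Pairs above the |r| > 0.8 threshold, strongest first
strong_pairs = {
    pair: round(r, 2)
    for pair, r in sorted(pairs.items(), key=lambda item: abs(item[1]), reverse=True)
    if abs(r) > 0.8
}
for (a, b), r in strong_pairs.items():
    print(f"{a} vs {b}: {r}")
```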

## Feature Engineering

Not done yet: the model below uses the four raw measurements as features.

## Train Model(s)
Since our dataset is perfectly balanced, we use accuracy as the metric to evaluate the performance of our model.
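Accuracy is simply the fraction of predictions that match the true labels. A tiny sketch with made-up labels (purely illustrative, not from the iris data) shows that computing it by hand gives the same value as `accuracy_score`:

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up true and predicted labels for six samples (illustration only)
y_true = np.array([0, 1, 2, 2, 1, 0])
y_hat = np.array([0, 1, 1, 2, 1, 0])

# Fraction of positions where the prediction equals the truth: 5 of 6 correct
manual_accuracy = np.mean(y_true == y_hat)
print(manual_accuracy, accuracy_score(y_true, y_hat))
```

With a balanced target this single number is informative; with a strongly imbalanced one, a model that always predicts the majority class could still reach a high accuracy.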

In [None]:
# Train Logistic Regression
log_reg = LogisticRegression(max_iter=1000)
log_reg.fit(X_train, y_train)

## Evaluate Model(s)

In [None]:
# Make predictions
y_pred_train = log_reg.predict(X_train)
y_pred_test = log_reg.predict(X_test)

In [None]:
# Print accuracy of our model
print("Accuracy on train set:", round(accuracy_score(y_train, y_pred_train), 2))
print("Accuracy on test set:", round(accuracy_score(y_test, y_pred_test), 2))
print("--------"*10)

In [None]:
# Print classification report of our model
print(classification_report(y_test, y_pred_test))
print("--------"*10)

In [None]:
# Evaluate the model with a confusion matrix
cm = confusion_matrix(y_test, y_pred_test)
cm

In [None]:
fig, ax = plt.subplots()

sns.heatmap(cm, cmap='YlGnBu', 
            annot=True, fmt='d', 
            linewidths=.5, xticklabels=data.target_names, 
            yticklabels=data.target_names, ax=ax)
fig.supxlabel("predicted")
fig.supylabel("actual")
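In scikit-learn's confusion matrix the rows are the actual classes and the columns the predicted classes, so per-class recall and precision follow directly from it. A self-contained sketch that repeats the split and model from above and checks the hand-computed values against `recall_score` and `precision_score`:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=15, shuffle=True, stratify=y
)
y_pred_test = LogisticRegression(max_iter=1000).fit(X_train, y_train).predict(X_test)

# Rows = actual class, columns = predicted class
cm = confusion_matrix(y_test, y_pred_test)
correct = np.diag(cm)

# Recall: correct predictions / all actual members of the class (row sums)
recall_per_class = correct / cm.sum(axis=1)
# Precision: correct predictions / all predictions of the class (column sums)
precision_per_class = correct / cm.sum(axis=0)

print(recall_per_class.round(2))
print(precision_per_class.round(2))
```

These are the same per-class precision and recall values that appear in the classification report.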

Our model works pretty well: we reached an accuracy of 0.98 on the training set and 0.93 on the test set.
If we have a look at the confusion matrix, we can see that our model correctly classified all instances of iris setosa.
When it comes to the versicolor and virginica observations, the model gets a bit confused:
+ 1 out of the 15 versicolor observations is misclassified as virginica; the remaining 14 are correctly classified
+ 2 out of the 15 virginica observations are misclassified as versicolor; the remaining 13 are correctly classified
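The test accuracy can be read directly off the confusion matrix: the correctly classified observations lie on its diagonal, and the test set holds 45 observations (15 per species).

```latex
\text{accuracy} = \frac{15 + 14 + 13}{45} = \frac{42}{45} \approx 0.93
```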

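If we recall our observations from the EDA regarding the separability of the three species, this is not a big surprise: versicolor and virginica already overlapped in the scatter plots, while setosa was clearly separated.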