# Train and visualize a model in TensorFlow - Part 2: scikit-learn

Throughout this tutorial we explain the **multilayer perceptron** algorithm, one of the simplest forms of an *artificial feed-forward neural network*. For this we use the 20 Newsgroups dataset obtained in the previous part of the tutorial.

We will implement the same algorithm in three different ways: using *scikit-learn*'s `MLPClassifier`, using *TensorFlow*'s API to write the neural network from scratch, and finally using *TensorFlow*'s `DNNClassifier`.

The idea is to compare how each approach serves a different purpose. This notebook covers the simplest one, using scikit-learn.

In [1]:
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from sklearn.metrics import accuracy_score, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

## Training

scikit-learn offers a simple API for machine learning, especially compared to TensorFlow. Its main limitation is that most of its models are shallow (e.g. logistic regression, SVM). `MLPClassifier` does provide a neural network classifier, but it is considerably slow to train and has no GPU support. The [documentation](http://scikit-learn.org/stable/modules/neural_networks_supervised.html) itself says the implementation of `MLPClassifier` is not intended for large-scale applications.

To keep things simple, we create a multilayer perceptron with a single hidden layer of size 5000 (half the size of the input) and see how it performs.
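To make the architecture concrete, here is a minimal NumPy sketch of the forward pass of a perceptron with one hidden layer. It is only an illustration, not what `MLPClassifier` does internally: the function names (`relu`, `softmax`, `mlp_forward`), the random weights and the toy sizes are all assumptions made for this example, and the softmax output is assumed as the usual choice for multi-class classification. No training happens here.

```python
import numpy as np


def relu(z):
    # Rectified Linear Unit: keeps positive values, zeroes out the rest
    return np.maximum(z, 0.0)


def softmax(z):
    # Subtract the row-wise maximum for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def mlp_forward(X, W1, b1, W2, b2):
    # Hidden layer: affine transformation followed by ReLU
    hidden = relu(X @ W1 + b1)
    # Output layer: affine transformation followed by softmax
    return softmax(hidden @ W2 + b2)


# Toy sizes (hypothetical); the tutorial uses a hidden layer of 5000 units
n_samples, n_features, n_hidden, n_classes = 4, 10, 5, 3

rng = np.random.default_rng(0)
X = rng.random((n_samples, n_features))
W1 = rng.normal(scale=0.1, size=(n_features, n_hidden))
b1 = np.zeros(n_hidden)
W2 = rng.normal(scale=0.1, size=(n_hidden, n_classes))
b2 = np.zeros(n_classes)

probs = mlp_forward(X, W1, b1, W2, b2)  # one probability row per sample
predicted = probs.argmax(axis=1)        # most probable class per sample
print(probs.shape, predicted.shape)
```

Each row of `probs` sums to 1, and the predicted class is simply the column with the highest probability. Training consists of adjusting `W1`, `b1`, `W2` and `b2` so that those probabilities match the labels, which is the part scikit-learn handles for us below.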

In [2]:
# Load the dataset
newsgroups = np.load('./resources/newsgroup.npz')

# Define the model
model = MLPClassifier(
    activation='relu',  # Rectified Linear Unit activation
    hidden_layer_sizes=(5000,),  # 1 hidden layer of size 5000
    max_iter=5,  # Each epoch takes a long time, so we keep it to 5
    batch_size=100,  # The batch size is set to 100 elements
    solver='adam')  # We use the adam solver

model.fit(newsgroups['train_data'],
          newsgroups['train_target'])

MLPClassifier(activation='relu', alpha=0.0001, batch_size=100, beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(5000,), learning_rate='constant',
       learning_rate_init=0.001, max_iter=5, momentum=0.9,
       nesterovs_momentum=True, power_t=0.5, random_state=None,
       shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,
       verbose=False, warm_start=False)

## Evaluation

Once the model is trained, we compute its accuracy and print its classification report (which shows the precision, recall, f1-score and support for each of the categories).
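As a rough illustration of what those metrics mean, the sketch below computes them by hand on a tiny made-up example. The helper `per_class_metrics` and the toy labels are hypothetical and exist only for this explanation; in the notebook itself scikit-learn computes everything.

```python
def per_class_metrics(y_true, y_pred, label):
    """Precision, recall, f1-score and support for a single class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == label and p == label)  # true positives
    fp = sum(1 for t, p in pairs if t != label and p == label)  # false positives
    fn = sum(1 for t, p in pairs if t == label and p != label)  # false negatives

    # Precision: of everything predicted as `label`, how much was correct
    precision = tp / (tp + fp) if tp + fp else 0.0
    # Recall: of everything that really is `label`, how much was found
    recall = tp / (tp + fn) if tp + fn else 0.0
    # f1-score: harmonic mean of precision and recall
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    # Support: number of true instances of `label`
    support = tp + fn
    return precision, recall, f1, support


# Toy labels (made up for this example)
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]

# Accuracy: fraction of predictions that match the true label
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print("Accuracy: %.2f" % accuracy)

for label in sorted(set(y_true)):
    precision, recall, f1, support = per_class_metrics(y_true, y_pred, label)
    print("%d  %.2f  %.2f  %.2f  %d" % (label, precision, recall, f1, support))
```

Accuracy is a single number for the whole test set, while the report breaks the result down per category, which makes it easier to spot the classes the model confuses.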

In [3]:
# Predict once and reuse the predictions for both metrics
predictions = model.predict(newsgroups['test_data'])

accuracy = accuracy_score(newsgroups['test_target'], predictions)

print("Accuracy: %.2f\n" % accuracy)

print("Classification Report\n=====================")
print(classification_report(newsgroups['test_target'], predictions))

Accuracy: 0.91

Classification Report
=====================
             precision    recall  f1-score   support

          0       0.93      0.96      0.94       160
          1       0.86      0.84      0.85       195
          2       0.79      0.89      0.84       197
          3       0.77      0.78      0.77       196
          4       0.91      0.83      0.87       192
          5       0.91      0.90      0.90       196
          6       0.86      0.85      0.85       194
          7       0.91      0.92      0.92       198
          8       0.96      0.96      0.96       199
          9       0.98      0.97      0.98       199
         10       0.98      0.96      0.97       200
         11       0.96      0.94      0.95       198
         12       0.86      0.89      0.87       196
         13       0.89      0.97      0.93       198
         14       1.00      0.92      0.96       197
         15       0.95      0.95      0.95       200
         16       0.96      0.90      0.93       182
       