```
Copyright 2021 IBM Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# Logistic Regression on MNIST8M Dataset

## Background 

The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal effort on preprocessing and formatting.
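
Each MNIST digit is a 28 × 28 grayscale image. A linear model such as logistic regression sees it as a flat vector of 28 × 28 = 784 pixel features, which is the feature count reported for this dataset. The minimal NumPy sketch below uses a random stand-in image, not real MNIST data, to show the flattening:

```python
import numpy as np

# Hypothetical stand-in for one MNIST digit: a 28x28 grayscale image with
# pixel intensities in 0..255 (real pixels would come from the dataset).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

# Flatten row by row into a single 784-element feature vector.
features = image.reshape(-1)
print(features.shape)  # (784,)

# A batch of images becomes an (n_examples, 784) design matrix.
batch = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
X = batch.reshape(len(batch), -1)
print(X.shape)  # (5, 784)

# Row-major flattening: pixel (r, c) ends up in column r * 28 + c.
assert features[3 * 28 + 7] == image[3, 7]
```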

## Source

We use an inflated version of the dataset (`mnist8m`) from the paper:

Gaëlle Loosli, Stéphane Canu and Léon Bottou: *Training Invariant Support Vector Machines using Selective Sampling*, in [Large Scale Kernel Machines](https://leon.bottou.org/papers/lskm-2007), Léon Bottou, Olivier Chapelle, Dennis DeCoste, and Jason Weston (eds.), pp. 301–320, MIT Press, Cambridge, MA, 2007.

We download the pre-processed dataset from the [LIBSVM dataset repository](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/).
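
Datasets in the LIBSVM repository use the sparse LIBSVM (svmlight) text format: one example per line, a label followed by `index:value` pairs with 1-based indices, and zero-valued features left out. The notebook does not show how the `Mnist8m` helper parses the file. One plausible approach, which is an assumption here, is scikit-learn's `load_svmlight_file`, sketched on two made-up rows:

```python
import io

from sklearn.datasets import load_svmlight_file

# Two made-up rows in LIBSVM format: "<label> <index>:<value> ...".
# Indices are 1-based and zero-valued features are simply omitted.
sample = b"5 3:0.5 784:1.0\n0 1:0.25\n"

# zero_based=False declares 1-based indices; n_features=784 fixes the width
# even though most columns never appear in this tiny sample.
X, y = load_svmlight_file(io.BytesIO(sample), n_features=784, zero_based=False)

print(X.shape)  # (2, 784)
print(y)        # [5. 0.]
print(X[0, 2], X[0, 783], X[1, 0])  # 0.5 1.0 0.25
```

The result is a SciPy CSR sparse matrix, which stores only the three nonzero entries instead of 2 × 784 values.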
## Goal

The goal of this notebook is to illustrate how Snap ML can accelerate the training of a logistic regression model on this dataset, compared with scikit-learn.

## Code

```python
# IPython magic: move up two directories, into the examples/ folder
%cd ../../
```

```
/Users/tpa/Code/snapml-examples/examples
```


```python
# Local directory where the downloaded, preprocessed dataset is cached
CACHE_DIR = 'cache-dir'
```

```python
import time

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from snapml import LogisticRegression as SnapLogisticRegression

# Dataset helper module from the snapml-examples repository (hence the cd above)
from datasets import Mnist8m
```

```python
# Download mnist8m (or reuse the cached copy) and split it into train and test sets
X_train, X_test, y_train, y_test = Mnist8m(cache_dir=CACHE_DIR).get_train_test_split()
```

```
Downloading dataset...
Writing cached, preprocessed data.
```


```python
print("Number of examples: %d" % (X_train.shape[0]))
print("Number of features: %d" % (X_train.shape[1]))
print("Number of classes:  %d" % (len(np.unique(y_train))))
```

```
Number of examples: 6075000
Number of features: 784
Number of classes:  10
```
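
The notebook does not show how `Mnist8m.get_train_test_split` partitions the data. The LIBSVM page lists 8,100,000 examples for mnist8m, and the 6,075,000 training examples above are exactly 75% of that. This is consistent with, though not proof of, scikit-learn's default `train_test_split` fraction (`test_size=0.25`). The sketch checks that arithmetic and runs the default split on a small hypothetical array:

```python
import math

import numpy as np
from sklearn.model_selection import train_test_split

# Size of mnist8m as listed on the LIBSVM dataset page, and the training-set
# size printed above. Assumption: the helper splits the full dataset.
N_TOTAL = 8_100_000
N_TRAIN_OBSERVED = 6_075_000

# train_test_split defaults to test_size=0.25 and rounds the test count up.
n_test = math.ceil(0.25 * N_TOTAL)
n_train = N_TOTAL - n_test
print(n_train, n_train == N_TRAIN_OBSERVED)  # 6075000 True

# The same default split applied to a small hypothetical array.
X = np.arange(200).reshape(100, 2)
y = np.arange(100) % 10
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
print(X_tr.shape, X_te.shape)  # (75, 2) (25, 2)
```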


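```python
# scikit-learn baseline: one-vs-rest logistic regression without an intercept;
# n_jobs=4 parallelizes the per-class fits in one-vs-rest mode
lr = LogisticRegression(fit_intercept=False, n_jobs=4, multi_class='ovr')

# Time the fit only; prediction and scoring are not included
t0 = time.time()
lr.fit(X_train, y_train)
t_fit_sklearn = time.time() - t0

score_sklearn = accuracy_score(y_test, lr.predict(X_test))
print("Training time  (sklearn): %6.2f seconds" % (t_fit_sklearn))
print("Accuracy score (sklearn): %.4f" % (score_sklearn))
```

```
Training time  (sklearn): 3274.61 seconds
Accuracy score (sklearn): 0.8452
```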
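
With `multi_class='ovr'`, scikit-learn fits one binary logistic regression per class (that class versus all others) and predicts the class whose model gives the highest score. Recent scikit-learn releases deprecate the `multi_class` parameter, so this sketch uses the `OneVsRestClassifier` wrapper for the same one-vs-rest scheme. It reproduces the decision rule by hand on a small synthetic problem (`make_classification` data, not MNIST) and checks the result against `predict`:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Small synthetic 3-class problem standing in for MNIST's 10 digit classes.
X, y = make_classification(n_samples=300, n_features=20, n_informative=5,
                           n_classes=3, random_state=0)

# One binary "class k vs. rest" logistic regression per class, no intercept.
ovr = OneVsRestClassifier(LogisticRegression(fit_intercept=False, max_iter=1000))
ovr.fit(X, y)
print(len(ovr.estimators_))  # 3

# Stack the per-class weight vectors into an (n_classes, n_features) matrix.
W = np.vstack([est.coef_ for est in ovr.estimators_])

# Decision rule: predict the class whose binary model scores highest.
manual_pred = ovr.classes_[np.argmax(X @ W.T, axis=1)]
print(np.array_equal(manual_pred, ovr.predict(X)))  # True
```

Because the binary problems are independent, the per-class fits can run in parallel, which is what `n_jobs=4` does in the scikit-learn cell above.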


```python
# Snap ML model with the same fit_intercept and n_jobs settings
lr = SnapLogisticRegression(fit_intercept=False, n_jobs=4)

# Time the fit only, as in the scikit-learn cell
t0 = time.time()
lr.fit(X_train, y_train)
t_fit_snapml = time.time() - t0

score_snapml = accuracy_score(y_test, lr.predict(X_test))
print("Training time  (snapml): %6.2f seconds" % (t_fit_snapml))
print("Accuracy score (snapml): %.4f" % (score_snapml))
```

```
Training time  (snapml): 183.10 seconds
Accuracy score (snapml): 0.8452
```
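
The two timing cells above follow the same pattern. A hypothetical helper, `time_fit_and_score` (our name, not part of either library), could factor it out. It accepts any estimator with `fit` and `predict`, which covers both `LogisticRegression` classes used here. The sketch uses `time.perf_counter`, which is intended for measuring intervals, and small synthetic data so it runs quickly:

```python
import time

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def time_fit_and_score(model, X_train, y_train, X_test, y_test):
    """Hypothetical helper: return (fit time in seconds, test accuracy).

    Only the call to fit() is timed; prediction is excluded.
    """
    t0 = time.perf_counter()
    model.fit(X_train, y_train)
    t_fit = time.perf_counter() - t0
    score = accuracy_score(y_test, model.predict(X_test))
    return t_fit, score


# Synthetic stand-in data; on mnist8m you would pass the arrays loaded earlier.
X, y = make_classification(n_samples=2000, n_features=50, n_classes=4,
                           n_informative=10, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

t_fit, score = time_fit_and_score(
    LogisticRegression(fit_intercept=False, max_iter=1000), X_tr, y_tr, X_te, y_te)
print("Training time:  %6.2f seconds" % t_fit)
print("Accuracy score: %.4f" % score)
```

With Snap ML installed, `time_fit_and_score(SnapLogisticRegression(fit_intercept=False, n_jobs=4), X_train, y_train, X_test, y_test)` would repeat the measurement from the cell above.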


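```python
# Ratio of training times: values above 1 mean Snap ML trained faster
speed_up = t_fit_sklearn / t_fit_snapml
# Accuracy difference relative to the scikit-learn score
score_diff = (score_snapml - score_sklearn) / score_sklearn
print("Speed-up:                %.1f x" % (speed_up))
print("Relative diff. in score: %.4f" % (score_diff))
```

```
Speed-up:                17.9 x
Relative diff. in score: 0.0000
```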


## Disclaimer

Performance results always depend on the hardware and software environment. 

This notebook was run on the following machine:
* OS: macOS 11.1 (Big Sur)
* CPU: 2.3 GHz Quad-Core Intel Core i7
* Memory: 32 GB

The versions of the relevant software packages are given below:

```python
import snapml
import sklearn

print("scikit-learn version: %s" % (sklearn.__version__))
print("      snapml version: %s" % (snapml.__version__))
```

```
scikit-learn version: 0.23.2
      snapml version: 1.7.0
```
