In [1]:
import os
import numpy as np

from PIL import Image, ImageChops
import operator
from collections import Counter
from matplotlib import pyplot as plt
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
from tqdm import tqdm
tqdm.pandas()
from skimage.transform import resize

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report

import pickle as pk

## Train ML model

In [3]:
train_path = 'dataset'
l = []

for subpath in tqdm(os.listdir(train_path)):
    for file in os.listdir(train_path+'/'+subpath):
        path = train_path+'/'+subpath+'/'+file
        im = Image.open(path)
        im_arr = (np.array(ImageChops.invert(Image.open(path)).resize((28,28)))/255).flatten()
        im_arr = np.append(im_arr, int(subpath))
        l.append(im_arr)

dataset = pd.DataFrame(l)
dataset = dataset.rename(columns={784:'label'})
dataset.to_csv('dataset.csv',index=False)

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:08<00:00,  1.12it/s]


In [4]:
dataset = pd.read_csv('dataset.csv')
X = np.array(dataset.drop('label',axis=1))
y = np.array(dataset.label)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print('training set population',Counter(y_train))

Counter({7.0: 818,
         8.0: 816,
         9.0: 820,
         6.0: 813,
         4.0: 815,
         1.0: 810,
         5.0: 826,
         2.0: 816,
         3.0: 799,
         0.0: 795})

In [5]:
%%time

# cross-validator : Stratified Shuffle Split 
sss = StratifiedShuffleSplit(n_splits = 2, test_size = 0.2, random_state = 42) 

# Functions to be used in the pipeline
skb = SelectKBest(f_classif)

### Define classifier ###
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
clf = LinearDiscriminantAnalysis()

# definition of the pipeline
pipeline = Pipeline(steps = [
    ('SKB',skb),
    ('LDA',clf)
])   

# parameters to tune 
param_grid = {
    'SKB__k':['all'],
    'LDA__solver':['svd'],
} 

# exhaustive search over specified parameter
grid = GridSearchCV(pipeline, param_grid, verbose = 1, cv = sss)

# training classifier
print (" > training classifier on training set:")
grid.fit(X_train, y_train)

# best classifier using the cross-validator and the Stratified Shuffle Split 
clf = grid.best_estimator_

# predicition with the classifier
print (" > testing classifier on testing set:")
y_pred = clf.predict(X_test)

# print grid parameters
print ("\n > Best grid search:")
print (grid.best_params_)

# dump classifier in a pickle file
print ("\n > Classifier dumped")
with open("digit_classifier.pkl", 'wb') as fid:
    pk.dump(clf, fid)
    
print(classification_report(y_test, y_pred))

 > training classifier on training set:
Fitting 2 folds for each of 1 candidates, totalling 2 fits


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    8.3s finished


 > testing classifier on testing set:

 > Best grid search:
{'LDA__solver': 'svd', 'SKB__k': 'all'}

 > Classifier dumped
              precision    recall  f1-score   support

         0.0       0.94      0.98      0.96       221
         1.0       0.80      0.97      0.88       206
         2.0       0.98      0.94      0.96       200
         3.0       0.96      0.95      0.96       217
         4.0       0.98      0.93      0.95       201
         5.0       0.95      0.93      0.94       190
         6.0       0.96      0.93      0.95       203
         7.0       0.98      0.93      0.96       198
         8.0       0.95      0.95      0.95       200
         9.0       0.97      0.91      0.94       196

    accuracy                           0.94      2032
   macro avg       0.95      0.94      0.94      2032
weighted avg       0.95      0.94      0.94      2032

Wall time: 13.7 s
