In [2]:
import os
import sys
sys.path.append(os.path.join(".."))

# Import teaching utils
import numpy as np
import utils.classifier_utils as clf_util

# Import sklearn metrics
from sklearn import metrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

In [3]:
# Fetch the MNIST dataset (OpenML name "mnist_784", version 1) as a separate
# (features, labels) pair rather than a Bunch object.
X, y = fetch_openml(name="mnist_784", version=1, return_X_y=True)

In [4]:
# Materialise features and labels as plain NumPy arrays for the rest of the
# notebook (presumably fetch_openml returned pandas objects here — TODO confirm
# against the installed scikit-learn's `as_frame` default).
X, y = np.array(X), np.array(y)

X

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [5]:
# Draw disjoint training and test subsets from the full dataset.
# Sizes and seed are named constants so they can be tuned in one place; the
# fixed seed makes the split, and therefore every downstream metric,
# reproducible across Restart & Run All.
TRAIN_SIZE = 7500   # number of training samples
TEST_SIZE = 2500    # number of held-out test samples
SPLIT_SEED = 42     # seed for the shuffle train_test_split performs

X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    train_size=TRAIN_SIZE,
                                                    test_size=TEST_SIZE,
                                                    random_state=SPLIT_SEED)

In [6]:
# Rescale pixel intensities from the 0-255 range into [0, 1].
# Not strictly required by logistic regression itself, but scikit-learn's
# 'sag'/'saga' solvers only converge quickly when features share roughly the
# same scale, and small values are cheaper to work with numerically.
X_train_scaled, X_test_scaled = X_train / 255.0, X_test / 255.0

In [12]:
# Unregularised multinomial logistic regression, fit on the *training* split.
# - penalty=None: the string 'none' was deprecated in scikit-learn 1.2 and
#   later removed; None is the supported spelling for "no regularisation".
# - tol=0.1: a much looser stopping tolerance than the default (1e-4), traded
#   for speed; expect a ConvergenceWarning-free but less tightly optimised fit.
# - multi_class is omitted: 'multinomial' is what the default already selects
#   for the 'saga' solver on multiclass targets, and the parameter itself is
#   deprecated since scikit-learn 1.5.
# - random_state: 'saga' shuffles the data, so a fixed seed makes the fitted
#   coefficients reproducible.
clf = LogisticRegression(penalty=None,
                         tol=0.1,
                         solver='saga',
                         random_state=42).fit(X_train_scaled, y_train)

In [24]:
# Predict a label for every rescaled, held-out test image.
y_pred = clf.predict(X=X_test_scaled)

In [26]:
# Per-class precision / recall / F1 report comparing the true test labels
# (y_test) against the model's predictions (y_pred). classification_report
# returns a preformatted string, so print() renders it better than display().
cm = metrics.classification_report(y_true=y_test, y_pred=y_pred)
print(cm)

              precision    recall  f1-score   support

           0       0.95      0.98      0.96       244
           1       0.90      0.97      0.93       287
           2       0.89      0.90      0.90       235
           3       0.90      0.87      0.89       281
           4       0.90      0.93      0.92       213
           5       0.87      0.84      0.86       215
           6       0.95      0.92      0.94       225
           7       0.95      0.91      0.93       257
           8       0.85      0.87      0.86       253
           9       0.91      0.88      0.90       290

    accuracy                           0.91      2500
   macro avg       0.91      0.91      0.91      2500
weighted avg       0.91      0.91      0.91      2500

