## Imports

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report, accuracy_score
from sklearn.svm import SVC
from scipy import stats
import pickle

## Config

In [2]:
# True -> refit the SVM and overwrite SVM_MODEL_FILE; False -> reuse the pickled model.
RETRAIN_SVM = False
# Keyword arguments forwarded unchanged to sklearn's SVC.
SVM_BEST_PARAMS = dict(
    C=100,
    kernel="poly",
)
SVM_MODEL_FILE = "refitted_svm.pickle"

# True -> refit the logistic regression and overwrite LOGREG_MODEL_FILE.
RETRAIN_LOGREG = True
# Keyword arguments forwarded unchanged to sklearn's LogisticRegression.
LOGREG_BEST_PARAMS = dict(
    C=0.1,
    penalty="l1",
    solver="saga",
    max_iter=10_000,
)
LOGREG_MODEL_FILE = "refitted_log_reg.pickle"

# Side length in pixels of each square digit image; used to reshape a flat pixel row for plotting.
IMG_SIZE = 28

## Data processing

In [3]:
# Read the raw MNIST CSV into a plain 2-D NumPy array (one row per sample).
mnist = pd.read_csv("mnist.csv").to_numpy()

In [4]:
# Column 0 is the class label; every remaining column is a raw pixel value.
labels, digits = mnist[:, 0], mnist[:, 1:]

In [5]:
# Drop zero-variance (constant) pixel columns, then standardise the remaining ones.
# NOTE(review): both transformers are fitted on ALL rows, before the train/test split
# in the next cells, so test-set statistics leak into preprocessing. Consider fitting
# them on the training rows only.
var_thr = VarianceThreshold(threshold=0)
scaler = StandardScaler()
clean_digits = var_thr.fit_transform(digits)
scaled_digits = scaler.fit_transform(clean_digits)
# Re-attach the label as column 0. Stacking with the float features makes the labels floats.
scaled_with_label = np.column_stack((labels, scaled_digits))

In [6]:
# Hold out 5000 randomly chosen rows as the test set; the fixed seed makes the split reproducible.
np.random.seed(123)
n_samples = scaled_with_label.shape[0]
test_indices = np.random.choice(n_samples, 5000, replace=False)
# Everything not drawn for the test set is training data, kept in its original row order.
is_train = np.ones(n_samples, dtype=bool)
is_train[test_indices] = False
test_with_label = scaled_with_label[test_indices]
train_with_label = scaled_with_label[is_train]

In [7]:
# Column 0 is the label; the remaining columns form the feature vector.
train_y, train_x = train_with_label[:, 0], train_with_label[:, 1:]
test_y, test_x = test_with_label[:, 0], test_with_label[:, 1:]

## Train models

In [8]:
# Refit the SVM only when requested, and cache it on disk for the "Load models" step.
if RETRAIN_SVM:
    svm = SVC(**SVM_BEST_PARAMS).fit(train_x, train_y)
    with open(SVM_MODEL_FILE, mode="wb") as model_file:
        pickle.dump(svm, model_file)

In [None]:
# Refit the L1-regularised logistic regression only when requested, and cache it on disk.
# verbose=1 makes the saga solver print one progress line per epoch.
if RETRAIN_LOGREG:
    log_reg = LogisticRegression(**LOGREG_BEST_PARAMS, verbose=1, n_jobs=-1).fit(train_x, train_y)
    with open(LOGREG_MODEL_FILE, mode="wb") as model_file:
        pickle.dump(log_reg, model_file)

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.


Epoch 1, change: 1.00000000
Epoch 2, change: 0.25303147
Epoch 3, change: 0.14071018
Epoch 4, change: 0.09650606
Epoch 5, change: 0.07573142
Epoch 6, change: 0.06272633
Epoch 7, change: 0.05241695
Epoch 8, change: 0.04516563
Epoch 9, change: 0.03971793
Epoch 10, change: 0.03589769
Epoch 11, change: 0.03256611
Epoch 12, change: 0.02957907
Epoch 13, change: 0.02710460
Epoch 14, change: 0.02506183
Epoch 15, change: 0.02316398
Epoch 16, change: 0.02149445
Epoch 17, change: 0.02013654
Epoch 18, change: 0.01883913
Epoch 19, change: 0.01765415
Epoch 20, change: 0.01663536
Epoch 21, change: 0.01575306
Epoch 22, change: 0.01495789
Epoch 23, change: 0.01427611
Epoch 24, change: 0.01352995
Epoch 25, change: 0.01318266
Epoch 26, change: 0.01271818
Epoch 27, change: 0.01260028
Epoch 28, change: 0.01222355
Epoch 29, change: 0.01192004
Epoch 30, change: 0.01167464
Epoch 31, change: 0.01139874
Epoch 32, change: 0.01125245
Epoch 33, change: 0.01092049
Epoch 34, change: 0.01081547
Epoch 35, change: 0.010

Epoch 414, change: 0.00105571
Epoch 415, change: 0.00105238
Epoch 416, change: 0.00105000
Epoch 417, change: 0.00104775
Epoch 418, change: 0.00104534
Epoch 419, change: 0.00103685
Epoch 420, change: 0.00103597
Epoch 421, change: 0.00103452
Epoch 422, change: 0.00103081
Epoch 423, change: 0.00102733
Epoch 424, change: 0.00102551
Epoch 425, change: 0.00102205
Epoch 426, change: 0.00101980
Epoch 427, change: 0.00101491
Epoch 428, change: 0.00101282
Epoch 429, change: 0.00101073
Epoch 430, change: 0.00100777
Epoch 431, change: 0.00100420
Epoch 432, change: 0.00100217
Epoch 433, change: 0.00099990
Epoch 434, change: 0.00099760
Epoch 435, change: 0.00099434
Epoch 436, change: 0.00099200
Epoch 437, change: 0.00098820
Epoch 438, change: 0.00098712
Epoch 439, change: 0.00098399
Epoch 440, change: 0.00098139
Epoch 441, change: 0.00097991
Epoch 442, change: 0.00097695
Epoch 443, change: 0.00097519
Epoch 444, change: 0.00097349
Epoch 445, change: 0.00097005
Epoch 446, change: 0.00096803
Epoch 447,

Epoch 687, change: 0.00059196
Epoch 688, change: 0.00059107
Epoch 689, change: 0.00059043
Epoch 690, change: 0.00058976
Epoch 691, change: 0.00058848
Epoch 692, change: 0.00058773
Epoch 693, change: 0.00058683
Epoch 694, change: 0.00058632
Epoch 695, change: 0.00058520
Epoch 696, change: 0.00058378
Epoch 697, change: 0.00058293
Epoch 698, change: 0.00058206
Epoch 699, change: 0.00058126
Epoch 700, change: 0.00058064
Epoch 701, change: 0.00057896
Epoch 702, change: 0.00057834
Epoch 703, change: 0.00057772
Epoch 704, change: 0.00057623
Epoch 705, change: 0.00057527
Epoch 706, change: 0.00057461
Epoch 707, change: 0.00057349
Epoch 708, change: 0.00057280
Epoch 709, change: 0.00057166
Epoch 710, change: 0.00057076
Epoch 711, change: 0.00056995
Epoch 712, change: 0.00056841
Epoch 713, change: 0.00056804
Epoch 714, change: 0.00056710
Epoch 715, change: 0.00056604
Epoch 716, change: 0.00056469
Epoch 717, change: 0.00056379
Epoch 718, change: 0.00056326
Epoch 719, change: 0.00056226
Epoch 720,

Epoch 960, change: 0.00032793
Epoch 961, change: 0.00032713
Epoch 962, change: 0.00032639
Epoch 963, change: 0.00032565
Epoch 964, change: 0.00032443
Epoch 965, change: 0.00032394
Epoch 966, change: 0.00032325
Epoch 967, change: 0.00032259
Epoch 968, change: 0.00032198
Epoch 969, change: 0.00032109
Epoch 970, change: 0.00032026
Epoch 971, change: 0.00031987
Epoch 972, change: 0.00031892
Epoch 973, change: 0.00031806
Epoch 974, change: 0.00031737
Epoch 975, change: 0.00031671
Epoch 976, change: 0.00031580
Epoch 977, change: 0.00031498
Epoch 978, change: 0.00031421
Epoch 979, change: 0.00031360
Epoch 980, change: 0.00031281
Epoch 981, change: 0.00031227
Epoch 982, change: 0.00031126
Epoch 983, change: 0.00031088
Epoch 984, change: 0.00030991
Epoch 985, change: 0.00030940
Epoch 986, change: 0.00030850
Epoch 987, change: 0.00030790
Epoch 988, change: 0.00030670
Epoch 989, change: 0.00030524
Epoch 990, change: 0.00030430
Epoch 991, change: 0.00030335
Epoch 992, change: 0.00030274
Epoch 993,

Epoch 1225, change: 0.00019843
Epoch 1226, change: 0.00019836
Epoch 1227, change: 0.00019817
Epoch 1228, change: 0.00019810
Epoch 1229, change: 0.00019811
Epoch 1230, change: 0.00019800
Epoch 1231, change: 0.00019785
Epoch 1232, change: 0.00019773
Epoch 1233, change: 0.00019763
Epoch 1234, change: 0.00019745
Epoch 1235, change: 0.00019742
Epoch 1236, change: 0.00019745
Epoch 1237, change: 0.00019727
Epoch 1238, change: 0.00019707
Epoch 1239, change: 0.00019704
Epoch 1240, change: 0.00019690
Epoch 1241, change: 0.00019680
Epoch 1242, change: 0.00019672
Epoch 1243, change: 0.00019663
Epoch 1244, change: 0.00019651
Epoch 1245, change: 0.00019647
Epoch 1246, change: 0.00019635
Epoch 1247, change: 0.00019621
Epoch 1248, change: 0.00019620
Epoch 1249, change: 0.00019601
Epoch 1250, change: 0.00019586
Epoch 1251, change: 0.00019588
Epoch 1252, change: 0.00019574
Epoch 1253, change: 0.00019563
Epoch 1254, change: 0.00019563
Epoch 1255, change: 0.00019555
Epoch 1256, change: 0.00019540
Epoch 12

Epoch 1490, change: 0.00017723
Epoch 1491, change: 0.00017699
Epoch 1492, change: 0.00017694
Epoch 1493, change: 0.00017702
Epoch 1494, change: 0.00017683
Epoch 1495, change: 0.00017673
Epoch 1496, change: 0.00017659
Epoch 1497, change: 0.00017656
Epoch 1498, change: 0.00017646
Epoch 1499, change: 0.00017638
Epoch 1500, change: 0.00017626
Epoch 1501, change: 0.00017637
Epoch 1502, change: 0.00017619
Epoch 1503, change: 0.00017609
Epoch 1504, change: 0.00017605
Epoch 1505, change: 0.00017598
Epoch 1506, change: 0.00017608
Epoch 1507, change: 0.00017598
Epoch 1508, change: 0.00017589
Epoch 1509, change: 0.00017577
Epoch 1510, change: 0.00017578
Epoch 1511, change: 0.00017572
Epoch 1512, change: 0.00017566
Epoch 1513, change: 0.00017575
Epoch 1514, change: 0.00017558
Epoch 1515, change: 0.00017543
Epoch 1516, change: 0.00017536
Epoch 1517, change: 0.00017538
Epoch 1518, change: 0.00017527
Epoch 1519, change: 0.00017523
Epoch 1520, change: 0.00017514
Epoch 1521, change: 0.00017500
Epoch 15

### Load models

In [None]:
# pickle.load can execute arbitrary code: only load model files this notebook produced.
with open(SVM_MODEL_FILE, mode="rb") as model_file:
    svm = pickle.load(model_file)

In [None]:
# pickle.load can execute arbitrary code: only load model files this notebook produced.
with open(LOGREG_MODEL_FILE, mode="rb") as model_file:
    log_reg = pickle.load(model_file)

## Predict

In [None]:
# Predicted labels for every held-out test row.
svm_pred = svm.predict(X=test_x)

In [None]:
# Predicted labels for every held-out test row.
log_reg_pred = log_reg.predict(X=test_x)

## Analyses

In [None]:
# Per-class precision / recall / F1 (4 decimals) and overall accuracy on the test set.
print("Logistic regression")
report = classification_report(test_y, log_reg_pred, digits=4)
print(report)
accuracy = accuracy_score(test_y, log_reg_pred)
print(f"Accuracy: {accuracy}")

In [None]:
# Per-class precision / recall / F1 (4 decimals) and overall accuracy on the test set.
print("Support Vector machine")
report = classification_report(test_y, svm_pred, digits=4)
print(report)
accuracy = accuracy_score(test_y, svm_pred)
print(f"Accuracy: {accuracy}")

In [None]:
# sklearn convention: rows are true labels, columns are predicted labels.
confusion_matrix(y_true=test_y, y_pred=log_reg_pred)

In [None]:
# sklearn convention: rows are true labels, columns are predicted labels.
confusion_matrix(y_true=test_y, y_pred=svm_pred)

In [None]:
def mcnemar_test(pred_y_a, pred_y_b, correct):
    """McNemar's test for whether two classifiers differ in error rate on the same samples.

    Builds the 2x2 table of per-sample correctness (classifier A on rows,
    classifier B on columns; index 0 = wrong, 1 = right), prints it, and
    computes the uncorrected McNemar chi-squared statistic from the two
    discordant cells.

    Parameters
    ----------
    pred_y_a, pred_y_b : array-like
        Predictions of the two classifiers, aligned element-wise with ``correct``.
    correct : array-like
        Ground-truth labels.

    Returns
    -------
    dict
        ``{"chi": statistic, "df": 1, "p": p_value}``. When the two classifiers
        never disagree about correctness, the statistic is 0 and p is 1.
    """
    def create_comparison_table(pred_y_a, pred_y_b, correct):
        # Built explicitly rather than via sklearn's confusion_matrix, which returns
        # a 1x1 matrix when only one outcome occurs (e.g. both models always right).
        # That made the [0][1] / [1][0] lookups below fail.
        truth = np.asarray(correct)
        correct_a = np.asarray(pred_y_a) == truth
        correct_b = np.asarray(pred_y_b) == truth
        return np.array([
            [np.sum(~correct_a & ~correct_b), np.sum(~correct_a & correct_b)],
            [np.sum(correct_a & ~correct_b), np.sum(correct_a & correct_b)],
        ])

    table = create_comparison_table(pred_y_a, pred_y_b, correct)
    print(table)
    b = table[0][1]  # A wrong, B right
    c = table[1][0]  # A right, B wrong
    if b + c == 0:
        x2 = 0.0
    else:
        x2 = ((b - c) ** 2) / (b + c)
    # `chi2` was never imported by name; use it through the notebook's `from scipy import stats`.
    p = stats.chi2.sf(x2, df=1)
    return {"chi": x2, "df": 1, "p": p}


In [None]:
def show_n_misclassified(true_y, pred_y, cols=5, savefig=False, filename="plot.png",
                         n=None, figsize=(20, 20)):
    """Plot a grid of test images whose predicted label differs from the true label.

    Parameters
    ----------
    true_y, pred_y : array-like
        True and predicted labels for the held-out test set, in the same order
        as the notebook-level ``test_indices``.
    cols : int
        Number of images per grid row.
    savefig : bool
        If True, also write the figure to ``filename`` at 100 dpi.
    filename : str
        Output path used when ``savefig`` is True.
    n : int or None
        Show at most the first ``n`` misclassified images; ``None`` shows all of them.
    figsize : tuple
        Matplotlib figure size in inches.

    Notes
    -----
    Pixels are read from the notebook-level ``digits`` array (raw, unscaled rows)
    through ``test_indices``, so this must run after the data-processing cells.
    If nothing is misclassified, a message is printed and no figure is drawn or saved.
    """
    def label_text(label):
        # The labels come out of a float array (np.column_stack with scaled
        # features), so render 7.0 as "7".
        try:
            as_float = float(label)
        except (TypeError, ValueError):
            return str(label)
        return str(int(as_float)) if as_float.is_integer() else str(label)

    true_y = np.asarray(true_y)
    pred_y = np.asarray(pred_y)
    wrongly_classified = np.flatnonzero(true_y != pred_y)
    if n is not None:
        wrongly_classified = wrongly_classified[:n]
    if len(wrongly_classified) == 0:
        print("No misclassified samples to show.")
        return

    # Ceiling division: exactly enough rows, with no empty trailing row when the
    # number of errors is a multiple of `cols`.
    rows = (len(wrongly_classified) + cols - 1) // cols
    fig = plt.figure(figsize=figsize)
    gs1 = gridspec.GridSpec(rows, cols, figure=fig)

    for idx, error in enumerate(wrongly_classified):
        ax = fig.add_subplot(gs1[idx])
        ax.imshow(digits[test_indices[error]].reshape(IMG_SIZE, IMG_SIZE))

        ax.axis('off')
        ax.text(0.5, -0.1,
                f"True Label: {label_text(true_y[error])}, Pred Label: {label_text(pred_y[error])}",
                size=10, ha="center", transform=ax.transAxes)

    if savefig:
        fig.savefig(filename, dpi=100)
    plt.show()


In [None]:
# Grid of every test image the SVM got wrong, also saved to disk.
show_n_misclassified(
    true_y=test_y,
    pred_y=svm_pred,
    savefig=True,
    filename="misclassified_svm.png",
)

In [None]:
# Grid (20 images per row) of every test image the logistic regression got wrong.
# The file name follows the same "misclassified_<model>.png" pattern as the SVM plot.
show_n_misclassified(test_y, log_reg_pred, cols=20, savefig=True, filename="misclassified_logreg.png")