In [1]:
import re
import sys

# Python ≥3.5 is required
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import sklearn
# Compare version numbers numerically: a plain string comparison is
# lexicographic, so e.g. "0.9" >= "0.20" would wrongly pass.
_sklearn_version = tuple(int(part) for part in re.findall(r"\d+", sklearn.__version__)[:2])
assert _sklearn_version >= (0, 20), "Scikit-Learn >= 0.20 is required"

# Common imports
import numpy as np
import os
import tarfile
import urllib
import pandas as pd

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# Task 3 modules
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

In [2]:
DOWNLOAD_ROOT = "http://www.macs.hw.ac.uk/%7Eek19/data/"
CURRENT_PATH = os.path.join(os.getcwd(), "datasets")
X_FILE = "x_train_gr_smpl.csv"
Y_FILE = "y_train_smpl.csv"


def fetch_data(download_root=DOWNLOAD_ROOT, current_path=CURRENT_PATH):
    """Download the feature and label CSVs into ``current_path`` unless already present.

    ``X_FILE`` and ``Y_FILE`` are checked and fetched independently. A directory
    holding only one of them is therefore completed rather than silently skipped.

    Parameters
    ----------
    download_root : str
        Base URL. Each file name is appended to it directly, so it should end
        with ``/``.
    current_path : str
        Local directory the files are saved into. It is created if missing.
    """
    # ``import urllib`` alone does not load the ``urllib.request`` submodule.
    import urllib.request

    os.makedirs(current_path, exist_ok=True)
    for file_name in (X_FILE, Y_FILE):
        target = os.path.join(current_path, file_name)
        if os.path.isfile(target):
            continue  # already downloaded; keep the local copy
        urllib.request.urlretrieve(download_root + file_name, target)

In [3]:
# Make sure both dataset CSVs are available locally before loading them.
fetch_data(download_root=DOWNLOAD_ROOT, current_path=CURRENT_PATH)

In [10]:
def _read_dataset_csv(file_name, current_path):
    """Read ``file_name`` from the ``current_path`` directory into a DataFrame."""
    csv_path = os.path.join(current_path, file_name)
    return pd.read_csv(csv_path)


def load_features_data(current_path=CURRENT_PATH):
    """Return the feature file (``X_FILE``) as a DataFrame."""
    return _read_dataset_csv(X_FILE, current_path)


def load_labels_data(current_path=CURRENT_PATH):
    """Return the label file (``Y_FILE``) as a DataFrame."""
    return _read_dataset_csv(Y_FILE, current_path)

In [11]:
# Feature matrix used as the model input below.
signs = load_features_data(current_path=CURRENT_PATH)

In [12]:
# Class labels; the fit() call below flattens them with `.values.ravel()`.
signs_y = load_labels_data()
# Features and labels are paired row by row, so their lengths must agree.
assert len(signs_y) == len(signs), (
    f"label/feature row mismatch: {len(signs_y)} labels vs {len(signs)} samples"
)

In [13]:
# Gaussian Naive Bayes classifier; the labels are flattened to a 1-D vector for fit().
model = GaussianNB()
model.fit(signs, signs_y.to_numpy().ravel())

GaussianNB()

In [14]:
# Flatten the labels exactly as in the fit() call so the metrics receive a 1-D vector.
expected = signs_y.values.ravel()
# NOTE(review): predictions are made on the same rows the model was trained on,
# so the metrics below measure training fit, not generalisation.
# TODO: evaluate on a held-out split (e.g. sklearn.model_selection.train_test_split).
predicted = model.predict(signs)

In [15]:
# Per-class precision/recall/F1, followed by the confusion matrix
# (sklearn convention: rows = true labels, columns = predicted labels).
print(
    classification_report(expected, predicted),
    confusion_matrix(expected, predicted),
    sep="\n",
)

              precision    recall  f1-score   support

           0       0.15      0.73      0.25       210
           1       0.41      0.21      0.28      2220
           2       0.43      0.17      0.24      2250
           3       0.44      0.22      0.29      1410
           4       0.39      0.15      0.22      1980
           5       0.13      0.21      0.16       210
           6       0.08      0.67      0.15       360
           7       0.09      0.54      0.15       240
           8       0.67      0.25      0.36       540
           9       0.25      0.27      0.26       270

    accuracy                           0.23      9690
   macro avg       0.30      0.34      0.24      9690
weighted avg       0.39      0.23      0.25      9690

[[154   5  18  12   4   0   0  17   0   0]
 [560 473 252  78 133  53 342 289  17  23]
 [138 373 378 178 259  99 378 387   0  60]
 [ 59  58  16 310  46  44 757  67  26  27]
 [ 63 205 178 113 293  75 883 151   0  19]
 [ 10   2   2   0   0  44 