Skip to content

Commit 9700e63

Browse files
committed
Add naive bayes classifier notebook
1 parent cca9911 commit 9700e63

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed

naive-bayes.ipynb

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import matplotlib.pyplot as plt\n",
10+
"import numpy as np\n",
11+
"import math"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 2,
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"def get_t(labels_path):\n",
21+
" with open(labels_path) as f:\n",
22+
" training_labels = [int(x) for x in f.read().splitlines()]\n",
23+
" return np.array(training_labels)"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": 3,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"def get_confusion_matrix(true_labels, predictions): \n",
33+
" conf_matrix = np.zeros((10, 10))\n",
34+
" for i, predicted_class in enumerate(predictions):\n",
35+
" conf_matrix[true_labels[i]][predicted_class] += 1\n",
36+
"\n",
37+
" return conf_matrix"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": 4,
43+
"metadata": {},
44+
"outputs": [],
45+
"source": [
46+
"def get_x(folder_name, number_images):\n",
47+
" x_input_points = np.zeros((0, 784))\n",
48+
" for i in range(1, number_images + 1): # +1 Since it's exclusive\n",
49+
" img_path = '{}/{}.jpg'.format(folder_name, i)\n",
50+
" x_input_points = np.append(x_input_points, plt.imread(img_path).reshape(1, 784), axis=0)\n",
51+
"\n",
52+
" return x_input_points"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": null,
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"def gaussian(x, mean, variance):\n",
62+
" deno = (2 * math.pi * variance) ** 0.5\n",
63+
" exp = -1 * ( (x-mean)**2 / (2*variance) )\n",
64+
" return (1/deno) * math.exp(exp)"
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": 5,
70+
"metadata": {},
71+
"outputs": [],
72+
"source": [
73+
"# Loading in a separate cell to avoid multiple loads.\n",
74+
"x_delta = get_x('Train', 2400)\n",
75+
"training_true_labels = get_t('Train/Training Labels.txt')"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": 6,
81+
"metadata": {},
82+
"outputs": [
83+
{
84+
"data": {
85+
"text/plain": [
86+
"(10, 784)"
87+
]
88+
},
89+
"execution_count": 6,
90+
"metadata": {},
91+
"output_type": "execute_result"
92+
}
93+
],
94+
"source": [
95+
"means = np.zeros((10, 784))\n",
96+
"variances = np.zeros((10, 784))\n",
97+
"\n",
98+
"classes_inputs = np.split(x_delta/255, 10)\n",
99+
"\n",
100+
"for i, c in enumerate(classes_inputs):\n",
101+
" means[i] = np.mean(c, axis=0)\n",
102+
" \n",
103+
"for i, c in enumerate(classes_inputs):\n",
104+
" variances[i] = np.var(c, axis=0)\n",
105+
" variances[i][variances[i] < 0.01] = 0.01\n",
106+
"means.shape"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": 9,
112+
"metadata": {},
113+
"outputs": [
114+
{
115+
"name": "stderr",
116+
"output_type": "stream",
117+
"text": [
118+
"/home/abdullah/.virtualenvs/ml/lib/python3.5/site-packages/ipykernel_launcher.py:8: RuntimeWarning: overflow encountered in double_scalars\n",
119+
" \n"
120+
]
121+
}
122+
],
123+
"source": [
124+
"x_delta_test = get_x('Test', 200)/255\n",
125+
"test_true_labels = get_t('Test/Test Labels.txt')\n",
126+
"\n",
127+
"probabilities = np.ones((200, 10))\n",
128+
"for i in range(200):\n",
129+
" for c in range(10):\n",
130+
" for f in range(784):\n",
131+
" probabilities[i][c] *= gaussian(x_delta_test[i][f], means[c][f], variances[c][f])"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": 10,
137+
"metadata": {},
138+
"outputs": [
139+
{
140+
"data": {
141+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD8CAYAAABaQGkdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAACwZJREFUeJzt3VuIXeUZxvHnceegiW0caYtkkpqkWNsobaJTTykWEqG1ilIoJYpCvWhuqkYRREvBy4KI6IUIaVQKRqXEXIiIWjxcCG3q5EA1GYV0THPGFMcoiub09mKmEMXMXpP5Ptfsl/8PhMx25fVlmL9r7z1r1jgiBCCn09peAEA9BA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYtNqDJ1zdifO6Z9efO7+t2cVn4m63OlUmRvHjhWf2Uu7fqZPdDg+d7fjqgR+Tv90PfrcucXn/ul7Pyo+E3V15vRVmXtsZKT4zF7adWO80ug4nqIDiRE4kBiBA4kROJAYgQOJETiQWKPAbf/C9ru2d9i+p/ZSAMroGrjtjqRHJF0tabGkG2wvrr0YgMlrcga/RNKOiBiOiMOSnpF0fd21AJTQJPB+SbtP+HjP2GNfYHuV7UHbgx9+UP7SPAATV+xNtohYExEDETFw1tl1rukFMDFNAt8raf4JH88bewzAFNck8DclnWd7oe0ZklZKeq7uWgBK6PrTZBFx1Patkl6S1JH0eERsq74ZgElr9OOiEfGCpBcq7wKgMK5kAxIjcCAxAgcSI3AgMQIHEnON3w/+TZ8dl3pF8bkv7dtafKYk/XzukipzgVo2xiv6KD7oeldVzuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGKNfjfZVFHr7qc/2Xqs+Mw3l/TW70jv9PW1vcKEHBsZKT6zlz4HPtTs64szOJAYgQOJETiQGIEDiRE4kBiBA4l1Ddz2fNuv2d5ue5vt1V/HYgAmr8n3wY9KuisiNtv+hqRNtv8WEdsr7wZgkrqewSNif0RsHvvzx5KGJPXXXgzA5E3oNbjtBZKWStpYYxkAZTW+VNX2mZKelXRHRHz0Ff9+laRVknS6ZhVbEMCpa3QGtz1do3Gvi4gNX3VMRKyJiIGIGJiumSV3BHCKmryLbkmPSRqKiAfrrwSglCZn8GWSbpa03PbWsX9+WXkvAAV0fQ0eEW9I8tewC4DCuJINSIzAgcQIHEiMwIHECBxIrKduulhLjRsk/mboQPGZkvTXH55TZW4tNW6OWIv75lSZe3R4Z/GZEc1uFMoZHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjLuqVlLr7qfDTy2pMnfRjVurzPXFF1SZG5u2lZ85cqj4TEk6uvzi4jPjn39vdBxncCAxAgcSI3AgMQIHEiNwIDECBxIjcCCxxoHb7tjeYvv5mgsBKGciZ/DVkoZqLQKgvEaB254n6RpJa+uuA6CkpmfwhyTdLen4yQ6wvcr2oO3BI/q8yHIAJqdr4LavlfR+RGwa77iIWBMRAxExMF0ziy0I4NQ1OYMvk3Sd7Z2SnpG03PaTVbcCUETXwCPi3oiYFxELJK2U9GpE3FR9MwCTxvfBgcQm9PPgEfG6pNerbAKgOM7gQGIEDiRG4EBiBA4kRuBAYtxVtcd8/48fVpl79JV5VebGivJ3P63l+KK5VebO3DJcfOZpnza7HJwzOJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGHdV7TFHh3fWGbyiztg/73qjytzfffenxWeeNryv+ExJOjYyUnxmxLFGx3EGBxIjcCAxAgcSI3AgMQIHEiNwILFGgds+y/Z62+/YHrJ9ee3FAExe0++DPyzpxYj4te0ZkmZV3AlAIV0Dtz1H0pWSfitJEXFY0uG6awEooclT9IWSDkp6wvYW22ttz668F4ACmgQ+TdJFkh6NiKWSPpF0z5cPsr3K9qDtwSNq9svJAdTVJPA9kvZExMaxj9drNPgviIg1ETEQEQPTNbPkjgBOUdfAI+KApN22zx97aIWk7VW3AlBE03fRb5O0buwd9GFJt9RbCUApjQKPiK2SBirvAqAwrmQDEiNwIDECBxIjcCAxAgcSI3AgMe6q2mN88QVV5sambVXm1rj7qSQNP7Wk+MxFN24tPlOSPv3VpcVnHn/1H42O4wwOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGJVbrroTkedOX3F5x4bGSk+s9fUujlir6lxg8R7//2v4jMl6f6B8i10Pv6s0XGcwYHECBxIjMCBxAgcSIzAgcQIHEiMwIHEGgVu+07b22y/bftp26fXXgzA5HUN3Ha/pNslDUTEhZI6klbWXgzA5DV9ij5N0hm2p0maJWlfvZUAlNI18IjYK+kBSbsk7Zd0KCJe/vJxtlfZHrQ9eDiaXUYHoK4mT9H7JF0vaaGkuZJm277py8dFxJqIGIiIgRm8RAemhCZP0a+S9F5EHIyII5I2SLqi7loASmgS+C5Jl9meZduSVkgaqrsWgBKavAbfKGm9pM2S3hr7O2sq7wWggEY/Dx4R90m6r/IuAArjSjYgMQIHEiNwIDECBxIjcCCxKndVhdTpK38nTUk6vmhulbm9drfWGp/f+wd+VnymJB34y3eKzzxyZ7N0OYMDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4k5IsoPtQ9K+k+DQ78l6b/FF6inl/btpV2l3tp3Kux6bkR8u9tBVQJvyvZgRAy0tsAE9dK+vbSr1Fv79tKuPEUHEiNwILG2A1/T8n9/onpp317aVeqtfXtm11ZfgwOoq+0zOICKWgvc9i9sv2t7h+172tqjG9vzbb9me7vtbbZXt71TE7Y7trfYfr7tXcZj+yzb622/Y3vI9uVt7zQe23eOfR28bftp26e3vdN4WgncdkfSI5KulrRY0g22F7exSwNHJd0VEYslXSbp91N41xOtljTU9hINPCzpxYj4gaQfawrvbLtf0u2SBiLiQkkdSSvb3Wp8bZ3BL5G0IyKGI+KwpGckXd/SLuOKiP0RsXnszx9r9Auwv92txmd7nqRrJK1te5fx2J4j6UpJj0lSRByOiA/b3aqraZLOsD1N0ixJ+1reZ1xtBd4vafcJH+/RFI9GkmwvkLRU0sZ2N+nqIUl3Szre9iJdLJR0UNITYy8n1tqe3fZSJxMReyU9IGmXpP2SDkXEy+1uNT7eZGvI9pmSnpV0R0R81PY+J2P7WknvR8SmtndpYJqkiyQ9GhFLJX0iaSq/H9On0WeaCyXNlTTb9k3tbjW+tgLfK2n+CR/PG3tsSrI9XaNxr4uIDW3v08UySdfZ3qnRlz7LbT/Z7kontUfSnoj4/zOi9RoNfqq6StJ7EXEwIo5I2iDpipZ3Gldbgb8p6TzbC23P0OgbFc+1tMu4bFujrxGHIuLBtvfpJiLujYh5EbFAo5/XVyNiSp5lIuKApN22zx97aIWk7S2u1M0uSZfZnjX2dbFCU/hNQWn0KdLXLiKO2r5V0ksafSfy8YjY1sYuDSyTdLOkt2xvHXvsDxHxQos7ZXKbpHVj/6MflnRLy/ucVERstL1e0maNfndli6b4VW1cyQYkxptsQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiT2PwxUclT/PpitAAAAAElFTkSuQmCC\n",
142+
"text/plain": [
143+
"<Figure size 432x288 with 1 Axes>"
144+
]
145+
},
146+
"metadata": {
147+
"needs_background": "light"
148+
},
149+
"output_type": "display_data"
150+
}
151+
],
152+
"source": [
153+
"predictions = np.argmax(probabilities, axis=1)\n",
154+
"conf_matrix = get_confusion_matrix(test_true_labels, predictions)\n",
155+
"plt.imshow(conf_matrix)\n",
156+
"plt.savefig('Confusion-naive.jpg')\n"
157+
]
158+
}
159+
],
160+
"metadata": {
161+
"file_extension": ".py",
162+
"kernelspec": {
163+
"display_name": "Python 3",
164+
"language": "python",
165+
"name": "python3"
166+
},
167+
"language_info": {
168+
"codemirror_mode": {
169+
"name": "ipython",
170+
"version": 3
171+
},
172+
"file_extension": ".py",
173+
"mimetype": "text/x-python",
174+
"name": "python",
175+
"nbconvert_exporter": "python",
176+
"pygments_lexer": "ipython3",
177+
"version": "3.5.2"
178+
},
179+
"mimetype": "text/x-python",
180+
"name": "python",
181+
"npconvert_exporter": "python",
182+
"pygments_lexer": "ipython3",
183+
"version": 3
184+
},
185+
"nbformat": 4,
186+
"nbformat_minor": 2
187+
}

0 commit comments

Comments
 (0)