Skip to content

Commit eef85a8

Browse files
committed
Add confusion matrix calculation
1 parent 743b0a6 commit eef85a8

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

main.ipynb

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@
8080
"metadata": {},
8181
"outputs": [],
8282
"source": [
83-
"# Labels matrix (), each row representing one class, and the columns represent the true label of each image \n",
83+
"# Labels matrix, each row representing one class, and the columns represent whether\n",
84+
"# the image belongs to that class (row index) or not (-1, 1) \n",
8485
"T = np.zeros((10, 2400))\n",
8586
"for i in range(len(T)):\n",
8687
" T[i] = np.where(t == i, 1, -1)\n",
@@ -117,8 +118,14 @@
117118
"metadata": {},
118119
"outputs": [],
119120
"source": [
120-
"out = np.dot(x_delta, WC)\n",
121-
"WC.shape"
121+
"predictions = np.dot(x_delta, weights_matrix)\n",
122+
"conf_matrix = np.zeros((10, 10))\n",
123+
"for i, row in enumerate(predictions):\n",
124+
" prediction = np.where(row == row.max())[0][0] # The index (class) of the maximum value (most probable class)\n",
125+
" true_value = t[i]\n",
126+
" conf_matrix[true_value][prediction] += 1\n",
127+
" \n",
128+
"conf_matrix"
122129
]
123130
},
124131
{
@@ -127,21 +134,10 @@
127134
"metadata": {},
128135
"outputs": [],
129136
"source": [
130-
"tested = out[1000]\n",
131-
"pred = np.where(zero == zero.max())\n",
137+
"to_be_tested = predictions[2030]\n",
138+
"pred = np.where(to_be_tested == to_be_tested.max())\n",
132139
"pred, pred[0][0]"
133140
]
134-
},
135-
{
136-
"cell_type": "code",
137-
"execution_count": null,
138-
"metadata": {
139-
"scrolled": false
140-
},
141-
"outputs": [],
142-
"source": [
143-
"out.shape"
144-
]
145141
}
146142
],
147143
"metadata": {

0 commit comments

Comments
 (0)