Skip to content

Commit cde7260

Browse files
committed
Get rid of os module
1 parent 8c0c089 commit cde7260

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

main.ipynb

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
"outputs": [],
88
"source": [
99
"import matplotlib.pyplot as plt\n",
10-
"import numpy as np\n",
11-
"import os"
10+
"import numpy as np"
1211
]
1312
},
1413
{
@@ -29,15 +28,13 @@
2928
"metadata": {},
3029
"outputs": [],
3130
"source": [
32-
"def get_x(folder_name):\n",
33-
" sorted_file_names = sorted([x for x in os.listdir(folder_name) if x.endswith('.jpg')], key=lambda x: int(x[:-4]))\n",
34-
"\n",
35-
" image_paths = [os.path.join(folder_name, x) for x in sorted_file_names]\n",
31+
"def get_x(folder_name, number_images):\n",
3632
" x_input_points = np.zeros((0, 784))\n",
37-
" for img in image_paths:\n",
38-
" x_input_points = np.append(x_input_points, plt.imread(img).reshape(1, 784), axis=0)\n",
33+
" for i in range(1, number_images + 1): # +1 Since it's exclusive\n",
34+
" img_path = '{}/{}.jpg'.format(folder_name, i)\n",
35+
" x_input_points = np.append(x_input_points, plt.imread(img_path).reshape(1, 784), axis=0)\n",
3936
"\n",
40-
" return x_input_points "
37+
" return x_input_points"
4138
]
4239
},
4340
{
@@ -104,9 +101,9 @@
104101
"outputs": [],
105102
"source": [
106103
"# Loading in a separate cell to avoid multiple loads.\n",
107-
"x_train = get_x('Train')\n",
104+
"x_train = get_x('Train', 2400)\n",
108105
"x_delta = get_x_delta(x_train)\n",
109-
"training_true_labels = get_t(os.path.join('Train', 'Training Labels.txt'))"
106+
"training_true_labels = get_t('Train/Training Labels.txt')"
110107
]
111108
},
112109
{
@@ -129,8 +126,8 @@
129126
"metadata": {},
130127
"outputs": [],
131128
"source": [
132-
"x_delta_test = get_x_delta(get_x('Test'))\n",
133-
"test_true_labels = get_t(os.path.join('Test', 'Test Labels.txt'))\n",
129+
"x_delta_test = get_x_delta(get_x('Test', 200))\n",
130+
"test_true_labels = get_t('Test/Test Labels.txt')\n",
134131
"conf_matrix = get_confusion_matrix(x_delta_test, weights_matrix, test_true_labels)\n",
135132
"plt.imshow(conf_matrix)\n",
136133
"plt.savefig('Confusion.jpg')\n",

0 commit comments

Comments
 (0)