|
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
9 | 9 | "import matplotlib.pyplot as plt\n", |
10 | | - "import numpy as np\n", |
11 | | - "import os" |
| 10 | + "import numpy as np" |
12 | 11 | ] |
13 | 12 | }, |
14 | 13 | { |
|
29 | 28 | "metadata": {}, |
30 | 29 | "outputs": [], |
31 | 30 | "source": [ |
32 | | - "def get_x(folder_name):\n", |
33 | | - " sorted_file_names = sorted([x for x in os.listdir(folder_name) if x.endswith('.jpg')], key=lambda x: int(x[:-4]))\n", |
34 | | - "\n", |
35 | | - " image_paths = [os.path.join(folder_name, x) for x in sorted_file_names]\n", |
| 31 | + "def get_x(folder_name, number_images):\n", |
36 | 32 | " x_input_points = np.zeros((0, 784))\n", |
37 | | - " for img in image_paths:\n", |
38 | | - " x_input_points = np.append(x_input_points, plt.imread(img).reshape(1, 784), axis=0)\n", |
| 33 | + " for i in range(1, number_images + 1): # +1 Since it's exclusive\n", |
| 34 | + " img_path = '{}/{}.jpg'.format(folder_name, i)\n", |
| 35 | + " x_input_points = np.append(x_input_points, plt.imread(img_path).reshape(1, 784), axis=0)\n", |
39 | 36 | "\n", |
40 | | - " return x_input_points " |
| 37 | + " return x_input_points" |
41 | 38 | ] |
42 | 39 | }, |
43 | 40 | { |
|
104 | 101 | "outputs": [], |
105 | 102 | "source": [ |
106 | 103 | "# Loading in a separate cell to avoid multiple loads.\n", |
107 | | - "x_train = get_x('Train')\n", |
| 104 | + "x_train = get_x('Train', 2400)\n", |
108 | 105 | "x_delta = get_x_delta(x_train)\n", |
109 | | - "training_true_labels = get_t(os.path.join('Train', 'Training Labels.txt'))" |
| 106 | + "training_true_labels = get_t('Train/Training Labels.txt')" |
110 | 107 | ] |
111 | 108 | }, |
112 | 109 | { |
|
129 | 126 | "metadata": {}, |
130 | 127 | "outputs": [], |
131 | 128 | "source": [ |
132 | | - "x_delta_test = get_x_delta(get_x('Test'))\n", |
133 | | - "test_true_labels = get_t(os.path.join('Test', 'Test Labels.txt'))\n", |
| 129 | + "x_delta_test = get_x_delta(get_x('Test', 200))\n", |
| 130 | + "test_true_labels = get_t('Test/Test Labels.txt')\n", |
134 | 131 | "conf_matrix = get_confusion_matrix(x_delta_test, weights_matrix, test_true_labels)\n", |
135 | 132 | "plt.imshow(conf_matrix)\n", |
136 | 133 | "plt.savefig('Confusion.jpg')\n", |
|
0 commit comments