|
44 | 44 | "outputs": [], |
45 | 45 | "source": [ |
46 | 46 | "def initialize_weights_matrix():\n", |
| 47 | + " # 10 weight vectors (one per class), each all zeros except the first component, which is set to 1.\n", |
47 | 48 | " weight_matrix = np.zeros((785, 10))\n", |
48 | 49 | " weight_matrix[0] = np.ones(10)\n", |
49 | 50 | " return weight_matrix.T" |
|
90 | 91 | "source": [ |
91 | 92 | "# Loading in a separate cell to avoid multiple loads.\n", |
92 | 93 | "x_delta = get_x_delta('Train', 2400)\n", |
93 | | - "training_true_labels = get_t('Train/Training Labels.txt')\n", |
94 | | - "T = get_T_matrix(training_true_labels)" |
| 94 | + "T = get_T_matrix(get_t('Train/Training Labels.txt'))" |
95 | 95 | ] |
96 | 96 | }, |
97 | 97 | { |
|
100 | 100 | "metadata": {}, |
101 | 101 | "outputs": [], |
102 | 102 | "source": [ |
103 | | - "learning_rates = [1, 10 ** -1, 10 ** -2, 10 ** -3, 10 ** -4, 10 ** -5, 10 ** -6, 10 ** -7, 10 ** -8, 10 ** -9]\n", |
| 103 | + "# Train one weight matrix per learning rate, with 500 epochs for each class.\n", |
| 104 | + "learning_rates_powers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", |
104 | 105 | "resulting_weights = []\n", |
105 | | - "for eta in learning_rates:\n", |
106 | | - " weights_matrix = initialize_weights_matrix()\n", |
107 | | - " for class_number in range(0,10):\n", |
108 | | - " for epoch in range(0,500):\n", |
109 | | - " for img in range(0, 2400):\n", |
110 | | - " r = 1 if np.dot(weights_matrix[class_number], x_delta[img]) >=0 else -1\n", |
111 | | - " if r != T[:,img][class_number]:\n", |
112 | | - " weights_matrix[class_number] += eta * (np.multiply(x_delta[img], T[:,img][class_number]))\n", |
113 | | - " resulting_weights += [(eta, weights_matrix)]" |
| 106 | + "\n", |
| 107 | + "for e_p in learning_rates_powers: # Eta power \"e_p\": the learning rate is 10 ** -e_p\n", |
| 108 | + " w_m = initialize_weights_matrix()\n", |
| 109 | + " for c in range(0,10): # Class \"c\"\n", |
| 110 | + " for _ in range(0,500): # Iteration/Epoch\n", |
| 111 | + " for i in range(0, 2400): # Image \"i\"\n", |
| 112 | + " r = 1 if np.dot(w_m[c], x_delta[i]) >=0 else -1\n", |
| 113 | + " if r != T[:,i][c]:\n", |
| 114 | + " # Update weight-vector if image \"i\" is misclassified\n", |
| 115 | + " w_m[c] += (10 ** (-1 * e_p)) * (np.multiply(x_delta[i], T[:,i][c]))\n", |
| 116 | + " resulting_weights += [(e_p, w_m)]" |
114 | 117 | ] |
115 | 118 | }, |
116 | 119 | { |
|
121 | 124 | "source": [ |
122 | 125 | "x_delta_test = get_x_delta('Test', 200)\n", |
123 | 126 | "test_true_labels = get_t('Test/Test Labels.txt')\n", |
124 | | - "for eta, weights_matrix in resulting_weights:\n", |
125 | | - " conf_matrix = get_confusion_matrix(x_delta_test, weights_matrix.T, test_true_labels)\n", |
126 | | - " plt.imshow(conf_matrix)\n", |
127 | | - " plt.savefig('Confusion{}.jpg'.format(eta))" |
| 127 | + "for eta_power, weights_matrix in resulting_weights:\n", |
| 128 | + " confusion_matrix = get_confusion_matrix(x_delta_test, weights_matrix.T, test_true_labels)\n", |
| 129 | + " plt.imshow(confusion_matrix)\n", |
| 130 | + " plt.savefig('Confusion-{}.jpg'.format(eta_power))" |
128 | 131 | ] |
129 | 132 | } |
130 | 133 | ], |
|
0 commit comments