|
16 | 16 | },
|
17 | 17 | {
|
18 | 18 | "cell_type": "code",
|
19 |
| - "execution_count": 2, |
| 19 | + "execution_count": 1, |
20 | 20 | "metadata": {},
|
21 | 21 | "outputs": [],
|
22 | 22 | "source": [
|
23 | 23 | "from sklearn.datasets import load_diabetes\n",
|
24 | 24 | "from sklearn.linear_model import Ridge\n",
|
25 | 25 | "from sklearn.metrics import mean_squared_error\n",
|
26 | 26 | "from sklearn.model_selection import train_test_split\n",
|
27 |
| - "import joblib" |
| 27 | + "import joblib\n", |
| 28 | + "import pandas as pd" |
28 | 29 | ]
|
29 | 30 | },
|
30 | 31 | {
|
|
36 | 37 | },
|
37 | 38 | {
|
38 | 39 | "cell_type": "code",
|
39 |
| - "execution_count": 3, |
| 40 | + "execution_count": 6, |
40 | 41 | "metadata": {},
|
41 | 42 | "outputs": [],
|
42 | 43 | "source": [
|
43 |
| - "X, y = load_diabetes(return_X_y=True)" |
| 44 | + "sample_data = load_diabetes()\n", |
| 45 | + "\n", |
| 46 | + "df = pd.DataFrame(\n", |
| 47 | + " data=sample_data.data,\n", |
| 48 | + " columns=sample_data.feature_names)\n", |
| 49 | + "df['Y'] = sample_data.target" |
44 | 50 | ]
|
45 | 51 | },
|
46 | 52 | {
|
47 | 53 | "cell_type": "code",
|
48 |
| - "execution_count": 4, |
| 54 | + "execution_count": 7, |
49 | 55 | "metadata": {},
|
50 | 56 | "outputs": [
|
51 | 57 | {
|
|
57 | 63 | }
|
58 | 64 | ],
|
59 | 65 | "source": [
|
60 |
| - "print(X.shape)" |
61 |
| - ] |
62 |
| - }, |
63 |
| - { |
64 |
| - "cell_type": "code", |
65 |
| - "execution_count": 5, |
66 |
| - "metadata": {}, |
67 |
| - "outputs": [ |
68 |
| - { |
69 |
| - "name": "stdout", |
70 |
| - "output_type": "stream", |
71 |
| - "text": [ |
72 |
| - "(442,)\n" |
73 |
| - ] |
74 |
| - } |
75 |
| - ], |
76 |
| - "source": [ |
77 |
| - "print(y.shape)" |
| 66 | + "print(df.shape)" |
78 | 67 | ]
|
79 | 68 | },
|
80 | 69 | {
|
81 | 70 | "cell_type": "code",
|
82 |
| - "execution_count": 8, |
| 71 | + "execution_count": 11, |
83 | 72 | "metadata": {},
|
84 | 73 | "outputs": [
|
85 | 74 | {
|
|
103 | 92 | " <thead>\n",
|
104 | 93 | " <tr style=\"text-align: right;\">\n",
|
105 | 94 | " <th></th>\n",
|
106 |
| - " <th>0</th>\n", |
107 |
| - " <th>1</th>\n", |
108 |
| - " <th>2</th>\n", |
109 |
| - " <th>3</th>\n", |
110 |
| - " <th>4</th>\n", |
111 |
| - " <th>5</th>\n", |
112 |
| - " <th>6</th>\n", |
113 |
| - " <th>7</th>\n", |
114 |
| - " <th>8</th>\n", |
115 |
| - " <th>9</th>\n", |
| 95 | + " <th>age</th>\n", |
| 96 | + " <th>sex</th>\n", |
| 97 | + " <th>bmi</th>\n", |
| 98 | + " <th>bp</th>\n", |
| 99 | + " <th>s1</th>\n", |
| 100 | + " <th>s2</th>\n", |
| 101 | + " <th>s3</th>\n", |
| 102 | + " <th>s4</th>\n", |
| 103 | + " <th>s5</th>\n", |
| 104 | + " <th>s6</th>\n", |
| 105 | + " <th>Y</th>\n", |
116 | 106 | " </tr>\n",
|
117 | 107 | " </thead>\n",
|
118 | 108 | " <tbody>\n",
|
|
128 | 118 | " <td>4.420000e+02</td>\n",
|
129 | 119 | " <td>4.420000e+02</td>\n",
|
130 | 120 | " <td>4.420000e+02</td>\n",
|
| 121 | + " <td>442.000000</td>\n", |
131 | 122 | " </tr>\n",
|
132 | 123 | " <tr>\n",
|
133 | 124 | " <td>mean</td>\n",
|
134 |
| - " <td>-3.639623e-16</td>\n", |
135 |
| - " <td>1.309912e-16</td>\n", |
136 |
| - " <td>-8.013951e-16</td>\n", |
137 |
| - " <td>1.289818e-16</td>\n", |
138 |
| - " <td>-9.042540e-17</td>\n", |
139 |
| - " <td>1.301121e-16</td>\n", |
140 |
| - " <td>-4.563971e-16</td>\n", |
141 |
| - " <td>3.863174e-16</td>\n", |
142 |
| - " <td>-3.848103e-16</td>\n", |
143 |
| - " <td>-3.398488e-16</td>\n", |
| 125 | + " <td>-3.634285e-16</td>\n", |
| 126 | + " <td>1.308343e-16</td>\n", |
| 127 | + " <td>-8.045349e-16</td>\n", |
| 128 | + " <td>1.281655e-16</td>\n", |
| 129 | + " <td>-8.835316e-17</td>\n", |
| 130 | + " <td>1.327024e-16</td>\n", |
| 131 | + " <td>-4.574646e-16</td>\n", |
| 132 | + " <td>3.777301e-16</td>\n", |
| 133 | + " <td>-3.830854e-16</td>\n", |
| 134 | + " <td>-3.412882e-16</td>\n", |
| 135 | + " <td>152.133484</td>\n", |
144 | 136 | " </tr>\n",
|
145 | 137 | " <tr>\n",
|
146 | 138 | " <td>std</td>\n",
|
|
154 | 146 | " <td>4.761905e-02</td>\n",
|
155 | 147 | " <td>4.761905e-02</td>\n",
|
156 | 148 | " <td>4.761905e-02</td>\n",
|
| 149 | + " <td>77.093005</td>\n", |
157 | 150 | " </tr>\n",
|
158 | 151 | " <tr>\n",
|
159 | 152 | " <td>min</td>\n",
|
|
167 | 160 | " <td>-7.639450e-02</td>\n",
|
168 | 161 | " <td>-1.260974e-01</td>\n",
|
169 | 162 | " <td>-1.377672e-01</td>\n",
|
| 163 | + " <td>25.000000</td>\n", |
170 | 164 | " </tr>\n",
|
171 | 165 | " <tr>\n",
|
172 | 166 | " <td>25%</td>\n",
|
|
180 | 174 | " <td>-3.949338e-02</td>\n",
|
181 | 175 | " <td>-3.324879e-02</td>\n",
|
182 | 176 | " <td>-3.317903e-02</td>\n",
|
| 177 | + " <td>87.000000</td>\n", |
183 | 178 | " </tr>\n",
|
184 | 179 | " <tr>\n",
|
185 | 180 | " <td>50%</td>\n",
|
|
193 | 188 | " <td>-2.592262e-03</td>\n",
|
194 | 189 | " <td>-1.947634e-03</td>\n",
|
195 | 190 | " <td>-1.077698e-03</td>\n",
|
| 191 | + " <td>140.500000</td>\n", |
196 | 192 | " </tr>\n",
|
197 | 193 | " <tr>\n",
|
198 | 194 | " <td>75%</td>\n",
|
|
206 | 202 | " <td>3.430886e-02</td>\n",
|
207 | 203 | " <td>3.243323e-02</td>\n",
|
208 | 204 | " <td>2.791705e-02</td>\n",
|
| 205 | + " <td>211.500000</td>\n", |
209 | 206 | " </tr>\n",
|
210 | 207 | " <tr>\n",
|
211 | 208 | " <td>max</td>\n",
|
|
219 | 216 | " <td>1.852344e-01</td>\n",
|
220 | 217 | " <td>1.335990e-01</td>\n",
|
221 | 218 | " <td>1.356118e-01</td>\n",
|
| 219 | + " <td>346.000000</td>\n", |
222 | 220 | " </tr>\n",
|
223 | 221 | " </tbody>\n",
|
224 | 222 | "</table>\n",
|
225 | 223 | "</div>"
|
226 | 224 | ],
|
227 | 225 | "text/plain": [
|
228 |
| - " 0 1 2 3 4 \\\n", |
| 226 | + " age sex bmi bp s1 \\\n", |
229 | 227 | "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n",
|
230 |
| - "mean -3.639623e-16 1.309912e-16 -8.013951e-16 1.289818e-16 -9.042540e-17 \n", |
| 228 | + "mean -3.634285e-16 1.308343e-16 -8.045349e-16 1.281655e-16 -8.835316e-17 \n", |
231 | 229 | "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n",
|
232 | 230 | "min -1.072256e-01 -4.464164e-02 -9.027530e-02 -1.123996e-01 -1.267807e-01 \n",
|
233 | 231 | "25% -3.729927e-02 -4.464164e-02 -3.422907e-02 -3.665645e-02 -3.424784e-02 \n",
|
234 | 232 | "50% 5.383060e-03 -4.464164e-02 -7.283766e-03 -5.670611e-03 -4.320866e-03 \n",
|
235 | 233 | "75% 3.807591e-02 5.068012e-02 3.124802e-02 3.564384e-02 2.835801e-02 \n",
|
236 | 234 | "max 1.107267e-01 5.068012e-02 1.705552e-01 1.320442e-01 1.539137e-01 \n",
|
237 | 235 | "\n",
|
238 |
| - " 5 6 7 8 9 \n", |
239 |
| - "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", |
240 |
| - "mean 1.301121e-16 -4.563971e-16 3.863174e-16 -3.848103e-16 -3.398488e-16 \n", |
241 |
| - "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", |
242 |
| - "min -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01 \n", |
243 |
| - "25% -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02 \n", |
244 |
| - "50% -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03 \n", |
245 |
| - "75% 2.984439e-02 2.931150e-02 3.430886e-02 3.243323e-02 2.791705e-02 \n", |
246 |
| - "max 1.987880e-01 1.811791e-01 1.852344e-01 1.335990e-01 1.356118e-01 " |
| 236 | + " s2 s3 s4 s5 s6 \\\n", |
| 237 | + "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", |
| 238 | + "mean 1.327024e-16 -4.574646e-16 3.777301e-16 -3.830854e-16 -3.412882e-16 \n", |
| 239 | + "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", |
| 240 | + "min -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01 \n", |
| 241 | + "25% -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02 \n", |
| 242 | + "50% -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03 \n", |
| 243 | + "75% 2.984439e-02 2.931150e-02 3.430886e-02 3.243323e-02 2.791705e-02 \n", |
| 244 | + "max 1.987880e-01 1.811791e-01 1.852344e-01 1.335990e-01 1.356118e-01 \n", |
| 245 | + "\n", |
| 246 | + " Y \n", |
| 247 | + "count 442.000000 \n", |
| 248 | + "mean 152.133484 \n", |
| 249 | + "std 77.093005 \n", |
| 250 | + "min 25.000000 \n", |
| 251 | + "25% 87.000000 \n", |
| 252 | + "50% 140.500000 \n", |
| 253 | + "75% 211.500000 \n", |
| 254 | + "max 346.000000 " |
247 | 255 | ]
|
248 | 256 | },
|
249 |
| - "execution_count": 8, |
| 257 | + "execution_count": 11, |
250 | 258 | "metadata": {},
|
251 | 259 | "output_type": "execute_result"
|
252 | 260 | }
|
253 | 261 | ],
|
254 | 262 | "source": [
|
255 |
| - "import pandas as pd\n", |
256 |
| - "features = pd.DataFrame(X)\n", |
257 |
| - "features.describe()" |
| 263 | + "# All data in a single dataframe\n", |
| 264 | + "df.describe()" |
258 | 265 | ]
|
259 | 266 | },
|
260 | 267 | {
|
|
266 | 273 | },
|
267 | 274 | {
|
268 | 275 | "cell_type": "code",
|
269 |
| - "execution_count": 3, |
| 276 | + "execution_count": 12, |
270 | 277 | "metadata": {},
|
271 | 278 | "outputs": [],
|
272 | 279 | "source": [
|
273 |
| - "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n", |
| 280 | + "X = df.drop('Y', axis=1).values\n", |
| 281 | + "y = df['Y'].values\n", |
| 282 | + "\n", |
| 283 | + "X_train, X_test, y_train, y_test = train_test_split(\n", |
| 284 | + " X, y, test_size=0.2, random_state=0)\n", |
274 | 285 | "data = {\"train\": {\"X\": X_train, \"y\": y_train},\n",
|
275 | 286 | " \"test\": {\"X\": X_test, \"y\": y_test}}"
|
276 | 287 | ]
|
|
284 | 295 | },
|
285 | 296 | {
|
286 | 297 | "cell_type": "code",
|
287 |
| - "execution_count": 4, |
| 298 | + "execution_count": 16, |
288 | 299 | "metadata": {},
|
289 | 300 | "outputs": [
|
290 | 301 | {
|
|
294 | 305 | " normalize=False, random_state=None, solver='auto', tol=0.001)"
|
295 | 306 | ]
|
296 | 307 | },
|
297 |
| - "execution_count": 4, |
| 308 | + "execution_count": 16, |
298 | 309 | "metadata": {},
|
299 | 310 | "output_type": "execute_result"
|
300 | 311 | }
|
301 | 312 | ],
|
302 | 313 | "source": [
|
303 |
| - "alpha = 0.5\n", |
| 314 | + "# experiment parameters\n", |
| 315 | + "args = {\n", |
| 316 | + " \"alpha\": 0.5\n", |
| 317 | + "}\n", |
304 | 318 | "\n",
|
305 |
| - "reg = Ridge(alpha=alpha)\n", |
306 |
| - "reg.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])" |
| 319 | + "reg_model = Ridge(**args)\n", |
| 320 | + "reg_model.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])" |
307 | 321 | ]
|
308 | 322 | },
|
309 | 323 | {
|
|
315 | 329 | },
|
316 | 330 | {
|
317 | 331 | "cell_type": "code",
|
318 |
| - "execution_count": 6, |
| 332 | + "execution_count": 18, |
319 | 333 | "metadata": {},
|
320 | 334 | "outputs": [
|
321 | 335 | {
|
322 | 336 | "name": "stdout",
|
323 | 337 | "output_type": "stream",
|
324 | 338 | "text": [
|
325 |
| - "mse: 3298.9096058070622\n" |
| 339 | + "{'mse': 3298.9096058070622}\n" |
326 | 340 | ]
|
327 | 341 | }
|
328 | 342 | ],
|
329 | 343 | "source": [
|
330 |
| - "preds = reg.predict(data[\"test\"][\"X\"])\n", |
331 |
| - "print(\"mse: \", mean_squared_error(preds, y_test))" |
| 344 | + "preds = reg_model.predict(data[\"test\"][\"X\"])\n", |
| 345 | + "mse = mean_squared_error(preds, y_test)\n", |
| 346 | + "metrics = {\"mse\": mse}\n", |
| 347 | + "print(metrics)" |
332 | 348 | ]
|
333 | 349 | },
|
334 | 350 | {
|
|
363 | 379 | ],
|
364 | 380 | "metadata": {
|
365 | 381 | "kernelspec": {
|
366 |
| - "display_name": "Python (storedna)", |
| 382 | + "display_name": "Python 3", |
367 | 383 | "language": "python",
|
368 |
| - "name": "storedna" |
| 384 | + "name": "python3" |
369 | 385 | },
|
370 | 386 | "language_info": {
|
371 | 387 | "codemirror_mode": {
|
|
377 | 393 | "name": "python",
|
378 | 394 | "nbconvert_exporter": "python",
|
379 | 395 | "pygments_lexer": "ipython3",
|
380 |
| - "version": "3.6.9" |
| 396 | + "version": "3.7.4" |
381 | 397 | }
|
382 | 398 | },
|
383 | 399 | "nbformat": 4,
|
|
0 commit comments