Skip to content

Commit

Permalink
Merge pull request #457 from autonomio/TA444_fix_broken_example_notebook
Browse files Browse the repository at this point in the history
fixed recover best model example notebook
  • Loading branch information
mikkokotila authored Jan 26, 2020
2 parents 034c8c2 + 80f26fe commit 38b992c
Showing 1 changed file with 17 additions and 271 deletions.
288 changes: 17 additions & 271 deletions examples/Recover Best Models from Experiment Log.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,10 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '/Users/mikko/Documents/GitHub/talos/')\n",
"\n",
"import sys\n",
"sys.path.insert(0, '/Users/mikko/Documents/GitHub/wrangle/')\n",
"\n",
"import talos\n",
"import wrangle\n",
"from keras.models import Sequential\n",
Expand All @@ -51,17 +37,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10/10 [00:34<00:00, 3.28s/it]\n"
]
}
],
"outputs": [],
"source": [
"# load the data\n",
"x, y = talos.templates.datasets.iris()\n",
Expand Down Expand Up @@ -103,7 +81,7 @@
" x_val=x_val,\n",
" y_val=y_val,\n",
" model=iris_model,\n",
" experiment_name='reactivate',\n",
" experiment_name='minimal_iris',\n",
" params=p,\n",
" round_limit=10)"
]
Expand All @@ -117,50 +95,34 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total 96\r\n",
"-rw-r--r-- 1 mikko staff 16K Sep 26 20:17 092619200831.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 20:18 092619201824.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 22:13 092619221236.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 22:18 092619221803.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 22:31 092619223042.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 22:35 092619223459.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 22:35 092619223524.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 22:56 092619225556.csv\r\n",
"-rw-r--r-- 1 mikko staff 1.3K Sep 26 23:04 092619230425.csv\r\n"
]
}
],
"outputs": [],
"source": [
"# get the name of the experiment log\n",
"!ls -lhtr reactivate"
"!ls -lhtr minimal_iris"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case it will be the most recent one `092619223042.csv` so let's go ahead and recover the best models."
"What you want to do, is get the name of the `.csv` file you want to use, and use it as part of the input for `experiment_log` in the next step."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from talos.utils.recover_best_model import recover_best_model\n",
"\n",
"results, models = recover_best_model(x_train=x_train,\n",
" y_train=y_train,\n",
" x_val=x_val,\n",
" y_val=y_val,\n",
" experiment_log='reactivate/092619221803.csv',\n",
" experiment_log='minimal_iris/012620102735.csv',\n",
" input_model=iris_model,\n",
" n_models=5,\n",
" task='multi_label')"
Expand All @@ -175,170 +137,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>round_epochs</th>\n",
" <th>val_loss</th>\n",
" <th>val_acc</th>\n",
" <th>loss</th>\n",
" <th>acc</th>\n",
" <th>activation</th>\n",
" <th>batch_size</th>\n",
" <th>dropout</th>\n",
" <th>epochs</th>\n",
" <th>first_neuron</th>\n",
" <th>hidden_layers</th>\n",
" <th>losses</th>\n",
" <th>optimizer</th>\n",
" <th>shapes</th>\n",
" <th>crossval_mean_f1score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>10</td>\n",
" <td>0.028760</td>\n",
" <td>0.955556</td>\n",
" <td>0.031489</td>\n",
" <td>0.885714</td>\n",
" <td>relu</td>\n",
" <td>20</td>\n",
" <td>0.3</td>\n",
" <td>10</td>\n",
" <td>128</td>\n",
" <td>2</td>\n",
" <td>logcosh</td>\n",
" <td>Adam</td>\n",
" <td>brick</td>\n",
" <td>0.930236</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10</td>\n",
" <td>0.028902</td>\n",
" <td>0.933333</td>\n",
" <td>0.029437</td>\n",
" <td>0.866667</td>\n",
" <td>elu</td>\n",
" <td>50</td>\n",
" <td>0.3</td>\n",
" <td>10</td>\n",
" <td>128</td>\n",
" <td>2</td>\n",
" <td>logcosh</td>\n",
" <td>Adam</td>\n",
" <td>brick</td>\n",
" <td>0.900427</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>0.023247</td>\n",
" <td>0.888889</td>\n",
" <td>0.020156</td>\n",
" <td>0.942857</td>\n",
" <td>elu</td>\n",
" <td>30</td>\n",
" <td>0.3</td>\n",
" <td>10</td>\n",
" <td>128</td>\n",
" <td>2</td>\n",
" <td>logcosh</td>\n",
" <td>Adam</td>\n",
" <td>brick</td>\n",
" <td>0.980606</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>10</td>\n",
" <td>0.044717</td>\n",
" <td>0.866667</td>\n",
" <td>0.048760</td>\n",
" <td>0.780952</td>\n",
" <td>relu</td>\n",
" <td>30</td>\n",
" <td>0.4</td>\n",
" <td>10</td>\n",
" <td>128</td>\n",
" <td>2</td>\n",
" <td>logcosh</td>\n",
" <td>Adam</td>\n",
" <td>brick</td>\n",
" <td>0.529231</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>10</td>\n",
" <td>0.047718</td>\n",
" <td>0.866667</td>\n",
" <td>0.052877</td>\n",
" <td>0.761905</td>\n",
" <td>elu</td>\n",
" <td>30</td>\n",
" <td>0.3</td>\n",
" <td>10</td>\n",
" <td>64</td>\n",
" <td>1</td>\n",
" <td>logcosh</td>\n",
" <td>Adam</td>\n",
" <td>brick</td>\n",
" <td>0.860539</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" round_epochs val_loss val_acc loss acc activation \\\n",
"0 10 0.028760 0.955556 0.031489 0.885714 relu \n",
"9 10 0.028902 0.933333 0.029437 0.866667 elu \n",
"2 10 0.023247 0.888889 0.020156 0.942857 elu \n",
"4 10 0.044717 0.866667 0.048760 0.780952 relu \n",
"8 10 0.047718 0.866667 0.052877 0.761905 elu \n",
"\n",
" batch_size dropout epochs first_neuron hidden_layers losses \\\n",
"0 20 0.3 10 128 2 logcosh \n",
"9 50 0.3 10 128 2 logcosh \n",
"2 30 0.3 10 128 2 logcosh \n",
"4 30 0.4 10 128 2 logcosh \n",
"8 30 0.3 10 64 1 logcosh \n",
"\n",
" optimizer shapes crossval_mean_f1score \n",
"0 Adam brick 0.930236 \n",
"9 Adam brick 0.900427 \n",
"2 Adam brick 0.980606 \n",
"4 Adam brick 0.529231 \n",
"8 Adam brick 0.860539 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"results"
]
Expand All @@ -352,64 +153,9 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.0995691e-02, 7.6433158e-01, 1.9467270e-01],\n",
" [5.5384491e-02, 8.0346000e-01, 1.4115542e-01],\n",
" [1.4972138e-01, 7.5501865e-01, 9.5259979e-02],\n",
" [1.1326348e-01, 7.1604651e-01, 1.7069000e-01],\n",
" [3.7939314e-02, 6.4578909e-01, 3.1627160e-01],\n",
" [7.1953669e-02, 7.3772454e-01, 1.9032180e-01],\n",
" [1.9050457e-02, 5.2020442e-01, 4.6074522e-01],\n",
" [3.7893724e-02, 6.0179770e-01, 3.6030859e-01],\n",
" [9.7751850e-03, 3.3405378e-01, 6.5617102e-01],\n",
" [1.1282027e-02, 3.9736888e-01, 5.9134912e-01],\n",
" [3.4884610e-03, 2.8260693e-01, 7.1390456e-01],\n",
" [1.4066804e-02, 4.4939899e-01, 5.3653419e-01],\n",
" [9.8388308e-01, 1.5294192e-02, 8.2262064e-04],\n",
" [9.4405776e-03, 2.6316279e-01, 7.2739667e-01],\n",
" [5.3356937e-03, 2.5790238e-01, 7.3676193e-01],\n",
" [6.6629532e-03, 2.9248431e-01, 7.0085275e-01],\n",
" [1.0408297e-02, 2.8865287e-01, 7.0093888e-01],\n",
" [9.2165405e-03, 4.4728053e-01, 5.4350299e-01],\n",
" [6.1816044e-02, 7.1130496e-01, 2.2687899e-01],\n",
" [5.8197163e-02, 7.0744330e-01, 2.3435953e-01],\n",
" [9.7906607e-01, 1.9747239e-02, 1.1866244e-03],\n",
" [4.9119215e-02, 6.6068804e-01, 2.9019269e-01],\n",
" [9.7612959e-01, 2.2444952e-02, 1.4254292e-03],\n",
" [4.5029860e-02, 6.6044801e-01, 2.9452211e-01],\n",
" [9.9142039e-01, 8.2095861e-03, 3.7001175e-04],\n",
" [9.7515500e-01, 2.3289582e-02, 1.5554430e-03],\n",
" [5.3465478e-02, 6.7216325e-01, 2.7437130e-01],\n",
" [3.2891510e-03, 1.4500402e-01, 8.5170686e-01],\n",
" [4.0943369e-02, 7.4083865e-01, 2.1821795e-01],\n",
" [9.7216946e-01, 2.5830602e-02, 1.9999193e-03],\n",
" [5.3125862e-02, 7.0550537e-01, 2.4136871e-01],\n",
" [1.1228154e-01, 7.2705197e-01, 1.6066651e-01],\n",
" [2.1662652e-02, 5.3180271e-01, 4.4653463e-01],\n",
" [4.6057135e-02, 7.8316230e-01, 1.7078057e-01],\n",
" [3.8668580e-02, 5.8559459e-01, 3.7573683e-01],\n",
" [2.0084916e-02, 4.8307988e-01, 4.9683511e-01],\n",
" [2.2133207e-02, 5.7616937e-01, 4.0169743e-01],\n",
" [9.6933258e-01, 2.8047977e-02, 2.6194439e-03],\n",
" [6.2396564e-02, 6.9619435e-01, 2.4140903e-01],\n",
" [9.8845774e-01, 1.0924198e-02, 6.1809190e-04],\n",
" [1.1867868e-02, 2.8613701e-01, 7.0199513e-01],\n",
" [1.3917084e-02, 4.5086253e-01, 5.3522038e-01],\n",
" [3.7958041e-02, 7.3275220e-01, 2.2928977e-01],\n",
" [1.5825152e-02, 5.4144788e-01, 4.4272691e-01],\n",
" [9.7893941e-01, 1.9864958e-02, 1.1956602e-03]], dtype=float32)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"models[0].predict(x_val)"
]
Expand All @@ -431,7 +177,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.9"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 38b992c

Please sign in to comment.