Skip to content

Commit

Permalink
RFC
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinJohannesNilsen committed May 18, 2021
1 parent 2573c7a commit a4cfef5
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 1 deletion.
Binary file removed model/src/RFC_2_plot.png
Binary file not shown.
Binary file modified model/src/RFC_3_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
112 changes: 112 additions & 0 deletions model/src/rfc_visualisation.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion model/utils/Precision.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,7 @@
"output_cleared": false,
"deepnote_cell_type": "code"
},
"source": "model_type: \"ANN\" or \"CNN\" or \"LSTM\" or \"RFC\" = \"CNN\"\nids = [1]\nPRINT_CLASSIFICATION_MATRIX = False\n\nfor run_i in ids:\n if model_type != \"RFC\":\n model = load_model(f\"../models/{model_type}_{AMOUNT_OF_SENSORS}_sensor{'er' if AMOUNT_OF_SENSORS > 1 else ''}_{run_i}.h5\")\n else:\n model = load(f\"../models/RFC_model_{AMOUNT_OF_SENSORS}.joblib\")\n print(\"Classification accuracy:\")\n classification_dict = dict()\n accuracy_list = list() \n\n for key in x_test_dict:\n x_test_numpy = x_test_dict[key].drop([' TimeStamp (s)', 'Pose'], axis=1).values\n if model_type == \"RFC\": \n x_test_numpy = x_test_dict[key].drop([' TimeStamp (s)', 'Pose'], axis=1)\n if model_type == \"CNN\":\n x_test_numpy = x_test_numpy.reshape(x_test_numpy.shape[0],x_test_numpy.shape[1], 1)\n classify = model.predict(x_test_numpy)\n classifications = [i.argmax() for i in classify]\n if model_type == 'RFC': classifications = classify\n annotated_positions = y_test_dict[key].to_numpy()\n correct_classifications = (classifications == annotated_positions).sum()\n accuracy_list.append(round(correct_classifications/len(classifications)*100,2))\n print(f\"{key}: {accuracy_list[-1]}%\")\n classification_dict[key] = classifications\n print(f\"Average accuracy: {round(sum(accuracy_list)/len(accuracy_list),2)}%\")\n\n if PRINT_CLASSIFICATION_MATRIX:\n width=16\n height=4*len(x_test_dict)\n fig, axes = plt.subplots(len(x_test_dict), 4, figsize=(width, height))\n for i, key in enumerate(x_test_dict):\n cols = [f\"{key}: Model classifications\", f\"{key}: Model heatmap\", f\"{key}: Annotated classifications\", f\"{key}: Annotated heatmap\"]\n df_predict = pd.DataFrame({' Timestamp (s)': x_test_dict[key][' TimeStamp (s)'],'Pose':classification_dict[key]})\n sns.lineplot(ax=axes[i, 0], data=df_predict,x=' Timestamp (s)',y='Pose')\n sns.heatmap(ax=axes[i, 1], data=confusion_matrix(y_test_dict[key], classification_dict[key]), cmap=\"YlGnBu\", annot=True, fmt=\"d\")\n sns.lineplot(ax=axes[i, 2], data=x_test_dict[key], x=\" TimeStamp (s)\", y='Pose')\n sns.heatmap(ax=axes[i, 3], data=confusion_matrix(y_test_dict[key], y_test_dict[key].to_numpy()), cmap=\"YlGnBu\", annot=True, fmt=\"d\")\n for ax, col in zip(axes[i], cols): ax.set_title(col)\n fig.tight_layout()\n plt.show()\n fig.savefig(f\"{model_type}_plot.png\")",
"source": "model_type: \"ANN\" or \"CNN\" or \"LSTM\" or \"RFC\" = \"CNN\"\nids = [1]\nPRINT_CLASSIFICATION_MATRIX = False\n\nfor run_i in ids:\n if model_type != \"RFC\":\n model = load_model(f\"../models/{model_type}_{AMOUNT_OF_SENSORS}_sensor{'er' if AMOUNT_OF_SENSORS > 1 else ''}_{run_i}.h5\")\n else:\n model = load(f\"../models/RFC_model_{AMOUNT_OF_SENSORS}.joblib\")\n print(\"Classification accuracy:\")\n classification_dict = dict()\n accuracy_list = list() \n\n for key in x_test_dict:\n x_test_numpy = x_test_dict[key].drop([' TimeStamp (s)', 'Pose'], axis=1).values\n if model_type == \"RFC\": \n x_test_numpy = x_test_dict[key].drop([' TimeStamp (s)', 'Pose'], axis=1)\n if model_type == \"CNN\":\n x_test_numpy = x_test_numpy.reshape(x_test_numpy.shape[0],x_test_numpy.shape[1], 1)\n classify = model.predict(x_test_numpy)\n classifications = [i.argmax() for i in classify]\n if model_type == 'RFC': classifications = classify\n annotated_positions = y_test_dict[key].to_numpy()\n correct_classifications = (classifications == annotated_positions).sum()\n accuracy_list.append(round(correct_classifications/len(classifications)*100,2))\n print(f\"{key}: {accuracy_list[-1]}%\")\n classification_dict[key] = classifications\n print(f\"Average accuracy: {round(sum(accuracy_list)/len(accuracy_list),2)}%\")\n \n\n if PRINT_CLASSIFICATION_MATRIX:\n width=16\n height=4*len(x_test_dict)\n fig, axes = plt.subplots(len(x_test_dict), 4, figsize=(width, height))\n for i, key in enumerate(x_test_dict):\n cols = [f\"{key}: Model classifications\", f\"{key}: Model heatmap\", f\"{key}: Annotated classifications\", f\"{key}: Annotated heatmap\"]\n df_predict = pd.DataFrame({' Timestamp (s)': x_test_dict[key][' TimeStamp (s)'],'Pose':classification_dict[key]})\n sns.lineplot(ax=axes[i, 0], data=df_predict,x=' Timestamp (s)',y='Pose')\n sns.heatmap(ax=axes[i, 1], data=confusion_matrix(y_test_dict[key], classification_dict[key]), cmap=\"YlGnBu\", annot=True, fmt=\"d\")\n sns.lineplot(ax=axes[i, 2], data=x_test_dict[key], x=\" TimeStamp (s)\", y='Pose')\n sns.heatmap(ax=axes[i, 3], data=confusion_matrix(y_test_dict[key], y_test_dict[key].to_numpy()), cmap=\"YlGnBu\", annot=True, fmt=\"d\")\n for ax, col in zip(axes[i], cols): ax.set_title(col)\n fig.tight_layout()\n plt.show()\n fig.savefig(f\"{model_type}_plot.png\")",
"execution_count": 5,
"outputs": [
{
Expand Down

0 comments on commit a4cfef5

Please sign in to comment.