diff --git a/docs/tutorials/runtime_comparison.ipynb b/docs/tutorials/runtime_comparison.ipynb new file mode 100644 index 0000000..38d1021 --- /dev/null +++ b/docs/tutorials/runtime_comparison.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "import time\n", + "\n", + "from sklearn.datasets import load_digits, fetch_openml\n", + "from sklearn.utils import check_random_state\n", + "\n", + "from sklearn.ensemble import RandomForestClassifier as RF\n", + "from oblique_forests.sporf import ObliqueForestClassifier as SPORF\n", + "from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF\n", + "\n", + "sns.set_palette('Set1')\n", + "mpl.rcParams.update({\n", + " \"axes.titlesize\": \"xx-large\",\n", + " \"axes.spines.top\": False,\n", + " \"axes.spines.right\": False,\n", + " \"xtick.bottom\": False,\n", + " \"ytick.left\": False,\n", + " \"image.cmap\": \"inferno\",\n", + "})\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Digits Dataset\n", + "These are 8x8 images of handwritten digits from `sklearn.datasets`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(100, 64) (100,)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_digits\n", + "\n", + "images, labels = load_digits(return_X_y=True)\n", + "\n", + "# Get 100 samples of 3s and 5s\n", + "n = 100\n", + "threes = np.where(labels == 3)[0][:(n // 2)]\n", + "fives = np.where(labels == 5)[0][:(n // 2)]\n", + "idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "X = images[idx]\n", + "y = labels[idx]\n", + "\n", + "# Apply random 
shuffling\n", + "permuted_idx = np.random.permutation(len(idx))\n", + "X = X[permuted_idx]\n", + "y = y[permuted_idx]\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T17:54:42.347403\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = images[threes].mean(axis=0).reshape(8, 8)\n", + "avg_5 = images[fives].mean(axis=0).reshape(8, 8)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "source": [ + "## `RF` vs `SPORF` runtimes" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def time_clf_digits(clf, ns, random_state=None):\n", + " runtimes = np.empty(len(ns))\n", + " rng = check_random_state(random_state)\n", + " images, labels = load_digits(return_X_y=True)\n", + " for i, n in enumerate(ns):\n", + " # Get only 3s and 5s\n", + " threes = np.where(labels == 3)[0][:(n // 2)]\n", 
+ " fives = np.where(labels == 5)[0][:(n // 2)]\n", + " idx = np.concatenate((threes, fives))\n", + " X = images[idx]\n", + " y = labels[idx]\n", + "\n", + " # Shuffle samples\n", + " permuted_idx = rng.permutation(len(idx))\n", + " X = X[permuted_idx].reshape(n, -1)\n", + " y = y[permuted_idx].reshape(n)\n", + "\n", + " # Begin timing\n", + " start = time.time()\n", + " clf.fit(X, y)\n", + " runtimes[i] = time.time() - start\n", + " return runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_clf(clf):\n", + " if isinstance(clf, RF):\n", + " return \"RF\"\n", + " elif isinstance(clf, SPORF):\n", + " return \"SPORF\"\n", + " elif isinstance(clf, MORF):\n", + " return \"MORF\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "ns = [10, 20, 40, 50, 100, 150, 200, 300]\n", + "\n", + "clfs = [\n", + " RF(random_state=0, n_jobs=1),\n", + " SPORF(random_state=0, n_jobs=1),\n", + " # MORF(random_state=0, image_height=8, image_width=8, n_jobs=1) # Too slow\n", + "]\n", + "\n", + "runtimes = []\n", + "for clf in clfs:\n", + " runtimes.append(time_clf_digits(clf, ns, random_state=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T17:54:56.738210\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "for runtime, clf in zip(runtimes, clfs):\n", + " ax.plot(ns, runtime, label=rename_clf(clf))\n", + "ax.legend()\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MNIST Dataset\n", + "These are 28x28 images of handwritten digits from `keras.datasets`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(100, 784) (100,)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import fetch_openml\n", + "\n", + "images, labels = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)\n", + "\n", + "n = 100\n", + "threes = np.where(labels == '3')[0][:(n // 2)]\n", + "fives = np.where(labels == '5')[0][:(n // 2)]\n", + "idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "X = images[idx]\n", + "y = labels[idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(idx))\n", + "X = X[permuted_idx]\n", + "y = y[permuted_idx]\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T17:55:20.302682\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = images[threes].mean(axis=0).reshape(28, 28)\n", + "avg_5 = images[fives].mean(axis=0).reshape(28, 28)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def time_clf_mnist(clf, ns, random_state=None):\n", + " runtimes = np.empty(len(ns))\n", + " rng = check_random_state(random_state)\n", + " images, labels = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)\n", + " for i, n in enumerate(ns):\n", + " # Get only 3s and 5s\n", + " threes = np.where(labels == '3')[0][:(n // 2)]\n", + " fives = np.where(labels == '5')[0][:(n // 2)]\n", + " idx = np.concatenate((threes, 
fives))\n", + " X = images[idx]\n", + " y = labels[idx]\n", + "\n", + " # Shuffle samples\n", + " permuted_idx = rng.permutation(len(idx))\n", + " X = X[permuted_idx]\n", + " y = y[permuted_idx]\n", + "\n", + " # Begin timing\n", + " start = time.time()\n", + " clf.fit(X, y)\n", + " runtimes[i] = time.time() - start\n", + " return runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "ns = [10, 20, 40, 50, 100, 150, 200, 300, 500]\n", + "\n", + "clfs = [\n", + " RF(random_state=0, n_jobs=1),\n", + " SPORF(random_state=0, n_jobs=1),\n", + " # MORF(random_state=0, image_height=8, image_width=8, n_jobs=1) # Too slow\n", + "]\n", + "\n", + "runtimes = []\n", + "for clf in clfs:\n", + " runtimes.append(time_clf_mnist(clf, ns, random_state=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T18:07:35.981486\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "for runtime, clf in zip(runtimes, clfs):\n", + " ax.plot(ns, runtime, label=rename_clf(clf))\n", + "ax.legend()\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python380jvsc74a57bd039ca1c7a169e56d6a333ccd59f8c6786beb2b8f5c3cc68b80d4610822621472b", + "display_name": "Python 3.8.0 64-bit ('ProgLearn': conda)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/docs/tutorials/test_feature_importance.ipynb b/docs/tutorials/test_feature_importance.ipynb new file mode 100644 index 0000000..db90ff5 --- /dev/null +++ b/docs/tutorials/test_feature_importance.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "# Feature Importance\n", + "One of the benefits to decision trees is that their results are fairly interpretable in that they allow for estimation of the relative importance of each feature.\n", + "\n", + "There are many approaches that have been suggested for quantifying importances. The `scikit-learn` implementation of the random forest quantifies these importances using the Gini impurity. For `SPORF` and `MORF`, we use a projection forest specific metric to quantify feature importance by computing the normalized count of the number of times a feature $k$ was used in projections across the ensemble of decision trees.\n", + "\n", + "Specifically, consider our forest $\\mathcal{T}$ to be a collection of decision trees $\\{T_i\\}_{i=1}^n$, where each decision tree $T_i \\in \\mathcal{T}$ is composed of many nodes $j$. Let $\\mathcal{A}$ be the set of unique atoms across all nodes and all trees in our forest. For each feature $k$, its importance $\\pi_k$ is computed as the number of times an atom assigns it a nonzero weighting, followed by a normalization of $|\\mathcal{A}|$.\n", + "$$\n", + "\\pi_k = \\frac{1}{|\\mathcal{A}|} \\sum_{T_i \\in \\mathcal{T}} \\sum_{j \\in T_i} \\mathbb{I}(a_{jk} \\not= 0)\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import accuracy_score\n", + "from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC\n", + "from oblique_forests.sporf import ObliqueForestClassifier as SPORF\n", + "from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF\n", + "\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "mpl.rcParams.update({\n", + " \"axes.titlesize\": \"xx-large\",\n", + " \"axes.spines.bottom\": False,\n", + " 
\"axes.spines.left\": False,\n", + " \"xtick.bottom\": False,\n", + " \"ytick.left\": False,\n", + " \"image.cmap\": \"inferno\",\n", + " \"image.aspect\": 1\n", + "})\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Digits Dataset\n", + "We visualize feature importances identified by `RF`, `SPORF`, and `MORF` on a subset of the scikit-learn digits dataset (`sklearn.datasets.load_digits`). We only consider threes and fives and use 100 8x8 images from each class." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(1797, 64) (1797,)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_digits\n", + "\n", + "images, labels = load_digits(return_X_y=True)\n", + "print(images.shape, labels.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(200, 64) (200,)\n" + ] + } + ], + "source": [ + "# Get 100 samples of 3s and 5s\n", + "num = 100\n", + "threes = np.where(labels == 3)[0][:num]\n", + "fives = np.where(labels == 5)[0][:num]\n", + "idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "X = images[idx]\n", + "y = labels[idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(idx))\n", + "X = X[permuted_idx]\n", + "y = y[permuted_idx]\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:14:56.804734\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = images[threes].mean(axis=0).reshape(8, 8)\n", + "avg_5 = images[fives].mean(axis=0).reshape(8, 8)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "clfs = [\n", + " RandomForestClassifier(random_state=0),\n", + " SPORF(random_state=0),\n", + " MORF(random_state=0, image_height=8, image_width=8)\n", + "]\n", + "for clf in clfs:\n", + " clf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_clf(clf):\n", + " if isinstance(clf, RandomForestClassifier):\n", + " return 
\"RF\"\n", + " elif isinstance(clf, SPORF):\n", + " return \"SPORF\"\n", + " elif isinstance(clf, MORF):\n", + " return \"MORF\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "error", + "ename": "AttributeError", + "evalue": "'Conv2DObliqueTreeClassifier' object has no attribute 'tree_'", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0max\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclfs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mimportances\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_importances_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0msns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheatmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimportances\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcmap\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'inferno'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msquare\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m ax.tick_params(\n", + 
"\u001b[0;32m/opt/anaconda3/envs/ProgLearn/lib/python3.8/site-packages/oblique_forests/morf.py\u001b[0m in \u001b[0;36mfeature_importances_\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;31m# 1. Find all unique atoms in the forest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;31m# 2. Compute number of times each atom appears across all trees\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m forest_projections = [\n\u001b[0m\u001b[1;32m 211\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_vec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtree\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/anaconda3/envs/ProgLearn/lib/python3.8/site-packages/oblique_forests/morf.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_vec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtree\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnode_count\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnode\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_vec\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Conv2DObliqueTreeClassifier' object has no attribute 'tree_'" + ] + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(19, 4))\n", + "\n", + "for clf, ax in zip(clfs, axs):\n", + " importances = clf.feature_importances_\n", + " sns.heatmap(importances.reshape(8, 8), cmap='inferno', square=True, ax=ax)\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + " ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MNIST Dataset\n", + "We visualize feature importances identified by `RF`, `SPORF`, and `MORF` on a subset of the MNIST dataset. We only consider threes and fives and use 100 28x28 images from each class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(60000, 28, 28)\n(60000,)\n(10000, 28, 28)\n(10000,)\n" + ] + } + ], + "source": [ + "# from sklearn.datasets import fetch_openml\n", + "from keras.datasets import mnist\n", + "\n", + "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n", + "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, sep='\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(200, 28, 28)\n(200,)\n(1010, 28, 28)\n(1010,)\n" + ] + } + ], + "source": [ + "# Get 100 samples of 3s and 5s\n", + "num = 100\n", + "threes = np.where(y_train == 3)[0][:num]\n", + "fives = np.where(y_train == 5)[0][:num]\n", + "train_idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "Xtrain = X_train[train_idx]\n", + "ytrain = y_train[train_idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(train_idx))\n", + "Xtrain = Xtrain[permuted_idx]\n", + "ytrain = ytrain[permuted_idx]\n", + "\n", + "# Subset test data\n", + "test_idx = np.where(y_test == 3)[0]\n", + "Xtest = X_test[test_idx]\n", + "ytest = y_test[test_idx]\n", + "\n", + "print(Xtrain.shape, ytrain.shape, Xtest.shape, ytest.shape, sep='\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:26:15.950727\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAEjCAYAAAD5QHrmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAwoklEQVR4nO3deZhcV33m8fdXvXdrae2LtdmSF1nGYGyMMYQYA3FiEh4YSCAkEEgymYGEhMmwJCSTQBIymQwBMhDIHpNglmCzhSwQTAwx4A2DV9mytVmyrKUldav3ruXMH7caF+WW3iPRbrWk7+d5+rF1++1bp6ruPXV/dW/VL1JKAgAAAIAzXelkDwAAAAAAZgOKIwAAAAAQxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAACcdiLiuohIEdF6Em57R0RcN9O3+1SpP47valq2IiJuiIi+xt9HxNyI+OuIeLy+/LqTMGT8ACiOTrKIuKq+80xExOKTPZ7ZKiLeEBE31Seb8Yh4LCL+OSKed7LHBpwqmG/yRMS76o/TVD/MOTjpTud9OSIW1vfBq56Cda9r2p8rEXEwIr4dER+MiKcfx+reJ+nH6/99raTP1Je/U9IvSPq7+vK/mM77gKfejL+bgCd5naTdkpZJ+mlJHzy5w5m1nilpj6QvSeqTtETSz0r6ekS8MqX0mWP9MQBJzDfH662S9jUte+hkDARocjrvywsl/W79/29+im7j85JukBSS5ku6SMXj+MsR8Ycppd9uyndJqjQte4Gkr6SU/nCK5fenlN45/cPGTKA4OokiolvSKyV9QNIlKia7kzLBRUR3SmnkZNx2jpTSm5uXRcSHJG2V9DY98Y4NgCkw35yQz6eUHjnZgwAazaZ9+RR2X0rpY40LIuJtkj4p6bciYmtK6e8mf5dSGptiHUsl9R9l+Z5pHKsiol1SNaVUnc71YmpcVndyvVzSXEnX138ui4gLJ38ZEe+vn/Jd3vyHEfG6+inhH2lY1hMRfxgRW+uXnu2JiA9HxIKmv705InZHxPkR8S8RcUTSv9R/97SI+JuIeDgiRiKiPyL+LSKeNdUdiIi31q8tHquflv7R+nXOO6bI/nBEfDkiBiJiNCJujYgfP7GHTkopDUs6KKn3RNcBnEGYb05gvomIeRHRcrx/BzyFjrkvT2F5RHy6vi8MRMT1EbG0MRARSyPiLyJiZ31/3hcRX42Iq5tyF0TEjRFxqL5f3RURr3UDjicuA7xqit997/
NJ9d8/XP/V7zZc/nZdQz5r7jleKaVBFWeP+uu3HQ232fiZondFRFJx1ulnGsb4+vrysyU9t2H5VQ3reXlEfCMihuo/N0XElU2Px+vrf3dtRPzviNgtaUzS6vrvl0TEn0XEriguq9xRz3VM8bjeUp9nv1qfY/dGxB803reG/PPqc/Tkc/tQRLy/KdMaEe+IiAfq83BfRHwsIlad8AM/C1EcnVyvk3RnSukhFad4B+vLJl0vqUXSq6b429eouNzjJkmq7xQ3SXqzpM/V//txSW+QdFPzTiOpW9JXVLy78VZJ/1Bffo2kp0v6lKRfk/ReSedL+lpEXNC4goj4HUn/V9IOFWdvvqziXZdnNg82Il5RH1+XpHdLeoekJOkLETHV/ZtSRCyoTwybIuJ9ki5UcakdgGNjvjnO+UbSXZIGJI3WC60n3RZwErh9udkXJXWq+CzMxyS9WtKXozgbMenTKgqDj0t6k4p98ZAa9q+I2CDpW5JeKOnDkn5D0rikv4+It07LPZM2q5gjJOmzKj6z873P7ZzA3HNcUkpHVFyJslbSBUeJfaY+Jkn6ZsMY76j/t0/Sloblm+tjf0v9bw9J+k0Vlw6ukPQfEfHcKW7nj1U81u9VMYcNRcQiSbdK+klJH5X0KyqOgd4m6cYp1rGi/vu7Jf16fYy/JennG0P1efFmSZsk/Zmkt6h4fF/WkAkV28nvSfpPFXP2hyVdK+mb9bGdHlJK/JyEH0krJVUlvaVh2UdVXENcalj2kKTbmv52iaSypA80LHu7pAlJz2zKvlTFQcEvNiy7ub7sN6YYV88UyxZL2i/pz5uWjUv6uqSWhuUvrK97R8OybhWTxQ1N621RsaM+KikyH7cd9fUnSaMqJsyuk/188sPPbP5hvjm++UbFgcFHVBxwvkzFQcyApBFJl53s55OfM/cnd1+uL7+uvn98qmn5m+vL31j/9/z6v99mbvsfJdUa9wFJ7ZJuq78eL25YvkPSdQ3/vqp+G1dNsd7m7IZ69l1TZLPnnqPch3X13B8cI/M/6pmXNix70njqyz52lPtzS9OyVfVxv7dp+Zz6nHRLw7LX19d9t6T2pvyHVRRXq5uW/2r9b17UNI4k6WVN2bsl3d7w77mSDqso6Hqbso2vDz9VX9+PN2UuUfF5rKM+pqfaD2eOTp6fVbGRfbJh2fWSzlLxgt+47PKIWN+w7FUqPi92fcOyV0u6U9KjEbF48kfFuzyjTeuc9JHmBam4VE1ScV1zwzsBt0tqvNTlRSomxT9LDdfAppRuknR/02pfJGmRpI81jW2BpH9Vcar4vCnGN5WfUfFu83+tj6lbxTtiAI6O+eY45puU0gdSSm9MKf19SulzKaV3S7pSRYH13mP9LfAUy92XG32g6d9/KWlI0k/U/z2q4sD9qqO9+x/FpaXXSvpqSunOyeUppYn6+jsl/chUfzvNTmTuOV6D9f/OnYZ1Tfovktokfbxp3J0qzoQ9J4rPkjX6m/rjK+l7Z25epeKs+WjTev69Hmu+/4+nlD7XtOxmSY1z/ItVfDzh/6SU+huDKaVawz9fraIIv7XptndJ2jbFbZ+y+EKGk+d1Kt7F7IyIdfVlW1VU7z+nJzb0j6u4LOQ1kn6/vuw1kh5OKd3RsL4LVFxCcuAot7e06d+HUkoDzaGImCfpD1Scsm3+7MH2hv+fHPPDerIt+v5LXSZPTX/2KGObHJ/9FqiU0jcaxnqditPLX5D0Q+5vgTMY882Tx3dc3zqXUro/Ir4o6WUR0ZVSGj2evwemSe6+3Oj7tvWU0ngUn9M7u/7viYh4u4rCf29E3KHiMtiPp5QerP/ZEkk9ql8i1uSB+n/PPsH7dDyOd+45EZNF0eAxU8dncl769jEyi1ScnZ60ven3S1R8k9+rNPXlz9KT7//OKTKH6+uZdG79v/ccY2xScR9W6eiP/Wnz2UyKo5MgIi5VcV2n9OSNX5JeHhFzU0qDKaVHIuJ21Q9W6pPhc1QcwDQqqXjn5H
eOcrOHm/59tBf2T6io/t+vJ663r6m4PrbxnYbJD/OlKdbR/EG/yTOUb5R0tG9+uu8oy48qpVSJiE9L+qOIODelNNWBE3BGY76Z0nHPN3U76+tfoKPfJ+ApcTz7ctPyqfab7w+k9KcR8VkVl6e9UMWlZb8ZEb+UGr617SjrOtb+mTOG4zmoPt6550Q8rf7f6TymmJyXXq7irN1UmouO5jlmch2fVXF53VSavyXveL7dzm0nJRVniP7bUX5/2syJFEcnx8+puH7+NSoOBBqtULHRv1JFAzGpOGX+pxFxiaQfqy/7eNPfPaLiWtGvnOigIqJXxWnzd6eU3tX0u99vik9OzOdJ+k7T785t+vfkAcqhH2R8RzF5Sd0P9C01wGmM+Wb6bFBxbf2haV4vkON49+VJF6goKCR970sN1qn4UP33pJQelfQhSR+K4pvfviXpPfX1HZA0rOJLkJpNnhXZcYyxTxYtzd9m2Vkf+/cN5Rjr+YHnnmOJiPkqCpidkh408eMxOS/tSSndfoLrOCDpiKTOab7/k0XgxSouWTyaRyQ9X9LNKaXmnk+nFT5zNMMiok3FdZs3pZQ+U7+evfHnIyo+nNf4zTOfUlH9v0bFt8nckVLa0rTqT0jaGBGvmeI2WyJiYfPyKVRVTErft11E8TWUz27KfkXFNcq/HA1fcxsRL9QT72xN+pKKifGdEdE1xfiOeRq8/tWRTxp/RMxR8RmkET35cwfAGY/55vjnm3pmqvnmCj3xmYupep4AT5kT3JcnvaXp37+k4osAvlhfd3fzvpJSOqyi2FlQ/3dVxWf2rm781sb6uH5NRdH25WPchR0q3li4umn5r+rJZ44mz6xM9abndMw9U4qIuSreCOqV9Hup/m0D0+QGFff/XRHxpBMTOfNS/Tn4R0nXRMTzplhHZ/0+HK9/V/H15e+oF4eN62w8M/8JFZccvn2K2476549OC5w5mnnXqrhu9PPHyHxBxUHA2pTSzpTSvoj4iopTmXP15IlOkv6kvu6PRdHLY/Jdog2SXiHpt1V8c81RpZQGI+ImSW+vT5RbVJxefoOK4mNuQ7YvIv5IxantmyLiRhXv/rxR0r1N2cGI+EUVB133R8Q/qPgA3wpJV6h416nxEppmcyTtiogb6uM4pOJdr59Tcf3rrzR+sBvA9zDfHP98I0nb65cYPaDiUr+LJf2CioO2Xzd/CzwVjntfblh+fkT8k4ri5kI9sd/8Tf3356n4OukbVGzzQyrOEFwj6W8b1vNOFV94clMUTdgPqCjYrlDxTXd9RxtYSulIRFwv6U31A+77VFyy+0Mqvl2yMbsvIh6V9OqI2KKin+H2lNJtmoa5p+6iiPhZFZcEzlUx9/ykis/ivCel9LfH+uPjlVLaERH/U9KfSvp2RPyjim/lXKXim/xqkl6QsarfUPHcfLX+uevvqLiC5vz6+F+h4gsXjmdsgxHxJhVf8353REx+++E6Fc/v5Hz5cRXf3vmeKHozfVXFpXRn15d/XNK7jue2Z62T/XV5Z9qPiu+hr0lacYzMi1S8o/rbDcteW19WkbT8KH/XqeL76+9T0TCsX8VXNv6xpDUNuZsl7T7KOpaq6EGyX8Up9G+quP74OjV8XW49GyreQdhZv727JP2oindINk+x7stVXCvbp+JdpkdVTOavNo9Zu6T31dd/WMXXCu+r/+2LT/Zzyg8/s/WH+eb455v63/6lioPHfhVnrHapuLTonJP9nPJzZv6cyL6sJ77Ke5WK/jQDKi7L+oSkZQ1/t0jS/6vvy0dUFEf3qug31PxV0htV9Oo5XN8PvyPpdVOMZYcavp67vqxXxQH4gIovO/i8in5CU2Wfr+ISr7H6fbiu4XdZc89RHqN1eqIdSFJxBvtwfT75oKRnHOXvfqCv8m743TUqzrD118e+XcU3D17TkHm9mr6We4rH8Y9VXA43Xp/jblfRcmChG4eKAiZNsfwFKs4iTbYteFDSnzRlSpJ+WcUXS4zUn8cHVF
yOeeHJ3k+m6yfqdxaYNhFxt6R9KaWZ+FpPAGcw5hsAwHTiM0c4YUe5nv9FKi5B+erMjwjA6Yr5BgAwEzhzhBMWEa9Wcf39F1RcFnORig967pd0cWpqJgYAJ4r5BgAwE/hCBvwg7lfxgcw3SVqs4hraGyS9kwMVANOM+QYA8JTjzBEAAAAAiM8cAQAAAIAkc1ldRHBaCTiz9aWUlszUjTHnAGe2lFL41PRgvgHOeFMe45jPHDU3LQZwZqnu9JnpxJwDnLmqM3x7zDfAmW3qYxwuqwMAAAAAURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSp9WQP4NQXmTlfh0b2un5wSSkjVZ3GW5y5+za9Nf90PgbA6aolI1PLyEzfvpszn+bNgzly7ltxi05kvCzP7Lin67ZwasrZt6WItqd4HE9IaTwjlTOX5O23kfEYzOwxVd7xVM64Z9LMH3eeOM4cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkEQTWMM30MptfFaKDpvp7TzHZubHcptZVltqM3MzxrOwPW/zWNnlM+0ZvciqGf3BBsu++dmjwxW/IkmPaI/NbB+/1WbKlQNZtwdMj7zGfi2lOTbT1jrXZiLjPbTu1kVZY3IqWc0dpZ6Sv71uzbOZWkYTyNaMl8nu5B9rSVqQemxmSP4x2NvymM0cmNhiM+Pl/TaT0oTNYDbyr5XT2bg1pbLNzOvcYDNLW3xmQW2hzaxs9fvkhrl5zVRXdvljinLNryun5exjI36++dehbRlrkvZWNtvMZaUX2kxZ/v5v1rdt5sj4Tpup1QZtZiZw5ggAAAAARHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAICkU7IJbE7TLl/ztbb4BoFz2s/ymZbFGeORNtbOs5lN831Dtot7h23mwiV7bWb9+Y/YTPeyQzYjSbWy34yqY+02M7jPN3YcPOKbVt6/e43NSNI9h1fZzNf6Xmoz30yfs5lqdcBmUkajNZy6Ivw+EPJzQGuL3wckaU77MpuZV/JNpddU/X6yqsM3lT6r20Z03txRH5K0omfIZrpafVPKCN95enC802b2jvqMJFVr/vYOT/gHatfIRpvZVvPNNDe3+9eB3UNftxmJ+Wv6ZDSfzzgOasmYJxZ1nm8zXTHfZiRpk862mcszekWfO9fv22cv7POZ9XfYzPBA3lx63/ZzbOaxYd909r4BP0+MV/145qVeH5K0Y8I3qJ/b419zrlrqt8lHR37YZu4c8M/trWOfshlJSmksK3eiOHMEAAAAAKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABA0qxrApvR/Cx8w6r21oU2s7TjAptZmdH8cGNHXhOxixdM2MyVK7f627vsHpvp3rTPZtK61TZTXXiuzUiSOntspL3vMZvpefA+m1ly0DfSXLLVN8GVpHUP+8ZuNa23mYMHr7KZB8e+YjPVmm8Ui1NXKbpsprXF70s5c5cknV31TRlXtPjmrRsW+Mal63p889aLlvr9cuMz/fwmSZ0rD/pQ1b/3Vz7iH+/auH/NGevPex3YvcPPu9v6lvoxyTcxH634cR8aXWkzj5X8ditJqTaYlTuT5TSCbin5bXJ+xzqbWSnfBHhd+K6sZ3X74zJJWj/HN12+et02mzln0xab6VjkXytH9/r79v
CuvIbxB8b8PtBf9ofTw/4hUt+E7wI7HnkNUJd0P81m7taDNrN2eJPNrOjy4z573G/bd5bz5tJyhSawAAAAAPCUozgCAAAAAFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEmzrAlsa0uvzbRlNEnsaVtiM8uqvvndoozmd4s6fINESVrdPWwzSxb7xoblYT+m8q45NhN7/G2V5vrGrZJUqmXFrFrJ1+q1Md9Er7XTN9yVpGXL9tvMWXt808a1sdhmHsnYbmkCeyoLnwi/ffe2+6aEG6q+ebEkrej0+8qqbr+es7r8/vS0pY/bTE4D6/aM5o65WhaP20x11DfBHdzlm7Ie6vMNJyVpIqMxa47uFj/pzm/3zTs7Rv14ejvztrfDow/ZTEpPbePG2a673b+ezGnNOH5Ja23mnJaFNjNR89vR4fG8Y5zuXt8ENCU/T97/Hd+4dMtB/xhdv9Nv2+PJj1mShsI3ud4fW21mU9poM2u6/bhXl/zxqyT1T6ywmU/2f8RmHqqdazMbMnq39rT6OWltx+V+RZK2126xmR/kmIozRwAAAAAgiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACTNsiawOVpLvmlfjr6WAzbTUfXNuPaOZnRRlPTgkXk207Jtvc2071xnMxNV32hruOwbRLZEXnfXtQsO2cz6DdtspjujAWR13D8nY0d8w1VJGhr0zXIHK/6xHKpVbKYUp9yuhuMyPe8ztUWnzXSV/DYpSe0l33CxM6OZ6IbewzZz1krfBLbUUbaZ6rC//5I0vNc3Xt6xdZ3N7D3SazPVjMaV4SOSpJJ8Q82c15SBst8GxjL6W7bLr2dZZDaB1eas3JmslNEIujP861JKfr/dUe23mbJ8g+eBqm8YL0l9u32j0DsOnmczNw5/2WZWtfjjl2t7fQPUavLrkaS/67/VZl7V8wKbecmqvowx+cnksWG/jUjSdw/74+W2Vj+X1pKftw5mHJt1Zbx0bQzfKFmStinv+PREceYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJAktZ7sATSq1cZsplwbtZnIqPkOh19PqeTXMzHmuwtL0kRtrs3sGFpkMzldz4cqvnNwe8l3YT6rO6/t+7yOcZtp6/CduEvtPjM+0GMzlfG8rte7+5bazPYh3/X5cBz2Y8rYbnHqCuXtK349fs7pasloMy6pN2M3WN7p97n1q3bZTFuXnwNU84/RwO6Vfj2SvnXfRTazO7OLvFPOGPeRct5zMljxz+/+Md+Nvm+8YjMDadhm+ksDNrOn+oDNSFJK5azcmWys0m8zwy3zbKZa8s//ULXPZyYet5nutrxjnFuqW2xmXTzLZv5+46U2c9Ga7TYzd+Fmm/nirVfYjCQtOrLWZt70DL+fLFnqn5MD+/3jvXahX48kTdTOsZllY5ts5taJL9tM6nuxzXSXfMnxXeXNN7XaSFbuRHHmCAAAAABEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgKRZ1gQ2x0TFN62bSL4pZ6nkOySOxAGbOdDaazOStLOywGZ6K8tspiLf/G1O8k3kntnuG409Y8GQzUjSFZfdZTPzrnjUZmp9vlYf2b/QZrbvWm0zkvTVPctt5u5B/xjskm82V6v5ZpuYraanwWst+X13pObnrtGU0Qla0oJ2P+6Ll+61mWUXbLOZjrX9fkAdfjy77zvXr0fS/tFum9kx7Of4nIarhyb8491XG7QZSepr8a8pQ8k3eByvHbGZiapvAlvNaE5drfn1FPxjeaZLyTdoHxj3r5UHq3nbm9PV7l8Dn9dyVda6fnilf/5fct6DNrPh3b7Ba3nFBTbT9skv2Mz7P+
UboErSP1zq567z3uebIKeMx3vp3d+0mcNfzmuWrZ2+CeyKqj9e2lu922a+Vr7eZjoyGgpPlP1rYCHvdfBEceYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACApFnWBLaWRnwoq8+cD9UyGuTlNH+sVH1TWkmqtI3bTFe7b956sTbYzHOX+vt/zXrfjG3TT3zNZiSpduXTfGZsvc3EwFabObxvic3csTevQdodh3zTti1xj82MVfptppb884/ZKuM9pPCZUvjptiXabKarpcWPR9JZ3X4+veiy79pM+zW+8fLE0qf79XzdzycDg3NtRpJaSr6ZZlvJz4MDZb+e++IhmxnUfpuRpMHhnTaTMuaKNG0NV/39p7nr9KlU+zNS/jkplebYTFuL35fWtl5iM5ctynsP/YVrH7GZ897mG9yOrf4hm2n/6I0284r3vN5m/mSTb8osSc/8kG9M2rbpv9vMRMUfL6Y9/jjosR1rbEaS9oz615yBUr/N9LSvsJkjY75Z+NjE4zbzVDd3zcWZIwAAAAAQxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJJmWRPY2ddszteObS2+QaIkndP6LJt5yfxlNvPy9dttZuMzfePSzp9eZDOVS95nM5LU3jrfZsr3/rnNpH7fIGz3/qU2c/+Ab6QpSbvCN6QrV30jzUpGQ+GUxrLGhFNTyG9zOfvJPPkmx22Zb2ltXOwbk3b+jN+fSpe+xWYqh++0mXb5JrArV+y1GUnq6vT704aBXps5q8s/3r0HN9nMvTU/d0vSo53+yRsY2+JXlHIaJdLgdfbxz1tkHJblNJSem9G4c0nNH7/0T+RNOC0tfnuLR3bZzKG/uM9m/tdX3mAzF8732/bVr/uczUhSreKb5aa/+nmbGblrnc1843Z/rPixbf74TZK+VX3AZsryc+lo+ZDN5DSvPpXmG84cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkDTrmsDOnJxGa+1tvkHi8o4Ls27vwhbfJPDH1uy2mUuuvdlmai+51GbG1z7fZnoymlZKUjmjCerEsgtspnvtHTazeH6/zSzvXGkzkrRo2DfA21UbtZlabSLj1iIjc+o0SDuzZDTcDP8+Uy1VbKY7zbGZntacbUnqaC3bTFR8Ztqs9Pvb8qvuzVrVsuEWm6kOddnM+VvW2sxZ911kMx17FtuMJKWRjTYz2uobLo6XfcNs5pNTU5KfJyKjCexI5aDNHG45YjNbh/Kaqt++e43N3P/B19jM53fNtZmvVfyxwjsXPt1m7r7xhTYjSR1f9A1O+4d+zGY2H/RNp2/r67aZhzOaskrSUOqzmSNj222mVGrPur3TCWeOAAAAAEAURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASDqTm8BGh810tPhmZEuqy7Nu7+yMfqqVmm9sONHXazNd2zbbTNse3/grVXwzOkkqzeu1mfY5PpNWr7KZ8190q81c0eeb90rS1kHfuHFbxTeUPVD2zfZC/rnNaf6H2SllNAuW/CQwHmM2M1rNa+45OOaboOq2h2wkPfyLNuPbFkrlbT5VHV2UsSZJJf8YtC0YtJmlP3S/zTy76t9DHKlebDOSNLjbN4Hsi7Nt5kDG9lapHs4aE049larftnM81rLFZvanjHlEUvlRvw+s6/FzwA1H/spmrul+g830l32z7H/f6fc1SVrU4Ru9j1b84fQDA76Z6oqums3URnxGksaq/TaTkm9wW6n616WcY+qU/HpmC84cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBA
AAAACSJN/S9zSV0xU4yXchHg3fqVySHhuZbzNf37PcZnb+8zU2s/Y/B2xmPKOb81hGRpKW9AzZzNMvv8tmep6212ZKGc26z1nq1yNJKx733eqXDqy2mSNte2xmrJwxptTiM6pmZJArpmkKLJV6bKZWq9jMvtpWm9kztjhrTLc8vtJmxq9/qc1Uav49tNFym19P8uupJt/VXpLmtfv5e82iAzZz/vPvtJkF6/z+ffbuVTYjSasPzbGZNf3rbWaobb/PVAdtJslvk5g+Ee0ZKb+ftLcutJmU/PHLoZH7bKa1ZZ7NSNJdnWWb2T98ns2kNGEzfbVhm/lEn9/+x2PMZiRpYXWRzbxooT84OTiebGZgws+BC5OfR3LlzQF+TCn55z9n+895/mcCZ44AAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIOk2bwOY0moro8JmM2nGgdDBrTLdN+GZj9xzs9ivKuLlaxrhzquI5KadhnXTRnDU2U6n6BqfPXf4lmyn1TF+DsAXtvqFqd/KN3ToymuSNlX3TxshoOuxbyOG4hN8T5nSstZnuFt8kcKTqd972kp8DDqR+m5Gkbx3wzWI3DyyzmbGq3+rGqn7bbSv5x3pZZ977devn+oaDPW1+rhjv8825S+3+trozmtJK0pw2/1i2ys+VXS0LbGY4fPPalGgCO5NCvlnysu5n2EyX/GvOnvK9fjzhx1PL3EZaMu7bkei3maU9l9vMluodNtOW/Fw6X8ttRpK26R6b+Wi/b8z63JZLMm7NN1w9FEMZ68k7NpmIjAbmaSTj1nIa1PttZLbgzBEAAAAAiOIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJs64JbEbzu/azbKanbYnNDE48bjMRfjwjtcM2I0mDyTcBHav020y15psN1pJvfpjT4La91TdIlKSFo9fazNCEb7pbHvBN1No7+22mlnwTNUkar/rHoCujaVl3yTdkHMxoOpzbbA/TJyXfuC6n4eZ8+TlndWywmbaan3OGwzeUlqSt2msz5Qk/V4zJNxzMmQe75R/HjcMX2IwkrZ/r9/HVy/z9b18waDNjfb02c2Q0o4G3pImaH/f8km++3aqMJuaR8fJOV+kZldNM8+D41mm5rYvarraZoTa//R9Mu7Nub7zm54mWkn89zTk2yTk26wh/PHGefAN7SRrSCpt5IH3bZr5de9hmXtBxvs0sK+cdm+3IeH3L2SYj4/glpbxG2KcKzhwBAAAAgCiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQNMuawLa2zLOZRe3n2Mzy6iqb2dXum5El1WxmYCKvQVq56puttZZ6bKatxWfKvu+XInxdvKBtrV+RpNXdviHb2Ut9Q8bWbt/ccmJfr80MjvjHSJJGqr4hYyWjS2IpoyFdzraEk8HvLKNV3+B0U1xsM1cv89NtZ8lvJzfv77IZSeqr+uZ+B0oHbCZn281phHxO7VybeXqvn5cl6TkrHrWZxWt8o+/KkG/eOnbEzycD43nPyXDFzzkt4TOljIbpzDmzT0+HP37pzDgOOjRyr8381Oq5NtPb3mkzN+9bbTOSdCTjwOOeeMhmDk1st5n5bf4Y72JdaDNrunMPgX3utsN7bObitstsZrDsjzkqafr27ZwGrxEZx8s0gQUAAACA0w/FEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABImmVNYNtb59vMpbHRZs5b6JvoHRi71GbuHzliM7szmslK0mhtwGZymvaVa76x45z2FTazJjbZzBXdi21Gkn7qHN8Id8OV37GZ1qX+vo18Z4nNjE2024wkTWT0USurYjMjNd8kNKVyzpAwCw2N7/Ih30
tRVyzbZzPrlvtGght3r/E3JmnnoG8m+djI2TYzUPbz6fw237hwWdeEzayf55vSStKG9dtsJtUymjyP+bli5Mgcm9k/mtcE9oh/CDRS8800h1Jf1u1hdlnQ5huq/sqS823m3v4rbWbHsB/POy6902aevbrXr0jSrkOLbGZz/3k2M1zxx3g9rX4fOTzhGyVvOZLXTPXr1dtt5iVdr7SZdXP8+YjtQ35M81vzjjt7av4Y7kjyjXlzmsCebudaTq97AwAAAAAniOIIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJKl1Zm7GdyqXpPaS70S+rNPXc5csGLCZrraKzWwc8B3m7zz4dJuRpL4J3xq9LN8ZuVTyj+WGnk6bed4S3z77mmd+3WYkaemLN9tMbf16m0l3bbWZg3uW2sy9B5fYjCQdGEs2s7dlj82MTBzMuDX/3Cb58WC6+S7qpeiwmZx9d+/QXJvZkDEvPefZd9iMJF1Z8tvT6GE/xw0d8eMul/1LycLFfj9p7fTzpCS1tJdtJjLuf6r5+fSRx1bbzJYjXTYjSQNlv51sKfl5cHhsn82kNJ4xopzXZuYlJzIPpTpTj82snztoM09f5J/bX9/st7Xq166wmbdfeafNSNLzL3rQZp6VMd8c3Odfv2/aep7NbB7w2+0t1VttRpLevOTZNnPxgsM285XHF9nMvDZ/jLt9dNRmJGmg9pjNlEp+flfy21LWMU7WnDQ7cOYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACApBlqAhvRlpUrhW/IWM3oR7eg0zfIOn/tTpu5cr5vxvbaibyHMFp8g6ycpoW1sn+M5p51wGY6z/ENGWsXbrAZSSovfK7NtD18j83s/dYmm/niA0+zmW/sz9ve7q76bWB/eYvNlCt+O8lrolbNyGCmpYzmdvtLfp/bcmSVzZyzb5nNLF6/y2YkqfM8v4/PzWg4uHQ0o7lfTu/WjKkyjea9X5cy5t2xvQtt5qHvXGQzdxzwTSn7xvIanT9a6beZ/uSf32p1xN9Y1pyD6RClvCbAOSrJ7wPP2PSAzfx119k2c+Mj/pjj/C/cZzOS1NvljxdK4ffbWhqwmaGJL9rMRW1X28ybFvsmuJK0qtsfU24f9A1uD437x/tbFX/MsXP0P2xGyjv2TmksYz3tGbeW0yzaH7/OluMgzhwBAAAAgCiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQNENNYFPKa+o0Xh2ymQeHffO7O/t8076ejnGbOT+jCeyiZzxsM5JUepofU3mVb6JWa/fN5kpjC2ymOjZsMy2H9tuMJMVd37CZ737WN2T7p23rbeb2g76x4d2622Ykqb/smy2Wq/5xqiXfIC7NksZmaOafl1Tzz++O8Tts5psHl9tMxFqb2TPQazOS9OwdvlHkonMftZn2c47YTMzPaBI44OfciT7fSFGS9j7g54qHd6+2me8cXGwzu4Z948Kto34bkaQdcb/NjE4cspmU/GOZspoyYjrkPB+SNB5+O7ll/zqb2bjHzyVLF/om0G+83GfeFL45uyTt2ufH1D/WbTNbBvwc8Piob27a2eK3/zltFZuRpNGKP1T+/B7/WvLNiRttpqO112aW9TzLZiTpwKifb7Lmkqwu3zkNXk+dxtScOQIAAAAAURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACBphprA5jRalKSxcp/NPNq902a+ceBcf1tV3yDw8EiPzVwy7JuyStLS2oM203bQN2TTYNlGasO+5h3ducxmdjxwnh+PpK9uv9Zmbjvgm7bdU3ncZh6v+aa741XftFLKa/Cakm8Sl5J/TkRDxlNWTj
PNStU3jP6u7rSZUp9v7rd31Dd5lqStg8+2mTX3X+QzvYdtprPNNwk8NDzXZgbGOm1GkrYNzbGZR4f9nLN7xDcl3FLb49eT7rMZSRqd8K9x1YxtKa/Ba07DReal6ZDXJFM6MLHFZu6tnm0zn37wApu5qNdvR0u7h2zmgg1bbUaSrrj2P2ymNu73yVW3XGYzu/oX2syhcT+XfG2fP8aTpC9lNPnuG99sM4u7NtrMeM0/J/uGb7MZSYqMxqwtJd90t5oxptNtvuHMEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAElSpHT0pkwRrTPasakUviHXvE7fIG1hyTd4nZN846sl8pneVt/UTJLmt4fN9ExTS9628E/bUMWP57HRvOa929I+m9mXttnM4IRvtpjT1CxXtTZqM7XkM7kNAE9N1W+nlHxXvmky03NOHr+v5GyXER0209G2yGbmti23GUmapyU+U5tvM93yzRQ7wt//2jFeayZNZDYMP1DyzVT75eelnIaLOU2lxzMamEuScppKy2dOX1WllPwON01mer5pa/X75NqOy23m7HSWzcxt9ftkT6t/f7wl89nIyVUzHu3Bsg/lNEHuK/vX5YdLD/oBSTow9oDNVKq+WXbO81+uHLCZjraVNiNJ1dq4zdRqYz6ThrNu79Q09TEOZ44AAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABImmVNYHNE+IaEEb4xaynabaal5Js2tpT8eqTpbV7q1FLZZiYqAzZTrfnmh4WZ20wi43lTqmWt68xutpiLJrCnqqx9JUvGe2gZzU1z5DR3LPh9PKsxb6nLZlJGs2jmkulyejeBzdHassBmSjkNpVvn2kx7aY5fT/iMJI0lf7wQGXNJyti3a8k3ix6aeMxmchq3FnKafPvjzpZSj81Uqv3TcluSlJJv8AqawAIAAADAUVEcAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkCS1nuwBHK+cjr85Gd+DWar4JsyYYSlNnOwhAKeE03tf8R3rc6TaqM+oMi23BeSoVA9Py3omKnunZT2QIqZnvqnWhnNSNpESB6dPNc4cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkHQKNoEFAJzpMholzsAoAJz+Tu+G2pgKZ44AAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIMk1gU6rETA0EwOwTEf82k7fHnAOcuZhvAMyko805kRJ9xAEAAACAy+oAAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkCT9f+fwy/DnOu16AAAAAElFTkSuQmCC\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = X_train[threes].mean(axis=0)\n", + "avg_5 = X_train[fives].mean(axis=0)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax 
in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "RandomForestClassifier(random_state=0)" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "clf = RandomForestClassifier(random_state=0)\n", + "clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_clf(clf):\n", + " if isinstance(clf, RandomForestClassifier):\n", + " return \"RF\"\n", + " elif isinstance(clf, SPORF):\n", + " return \"SPORF\"\n", + " elif isinstance(clf, MORF):\n", + " return \"MORF\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:26:17.992359\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(5, 4))\n", + "importances = clf.feature_importances_\n", + "sns.heatmap(importances.reshape(28, 28), cmap='inferno', square=True)\n", + "ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + ")\n", + "ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ObliqueForestClassifier(random_state=0)" + ] + }, + 
"metadata": {}, + "execution_count": 25 + } + ], + "source": [ + "clf = SPORF(n_estimators=100, random_state=0)\n", + "clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:28:42.502439\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(5, 4))\n", + "importances = clf.feature_importances_\n", + "sns.heatmap(importances.reshape(28, 28), cmap='inferno', square=True)\n", + "ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + ")\n", + "ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# morf = MORF(random_state=0, image_height=28, 
image_width=28)\n", + "# morf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# XXX: MORF fits too slowly\n", + "clfs = [\n", + " RandomForestClassifier(random_state=0),\n", + " SPORF(random_state=0),\n", + " MORF(random_state=0, image_height=28, image_width=28)\n", + "]\n", + "for clf in clfs:\n", + " clf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(19, 4))\n", + "\n", + "for clf, ax in zip(clfs, ax):\n", + " importances = clf.feature_importances_\n", + " sns.heatmap(importances.reshape(28, 28), cmap='inferno', square=True, ax=ax)\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + " ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python380jvsc74a57bd039ca1c7a169e56d6a333ccd59f8c6786beb2b8f5c3cc68b80d4610822621472b", + "display_name": "Python 3.8.0 64-bit ('ProgLearn': conda)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/oblique_forests/morf.py b/oblique_forests/morf.py index a32a121..b2f8fdd 100644 --- a/oblique_forests/morf.py +++ b/oblique_forests/morf.py @@ -1,4 +1,8 @@ +import numpy as np +from joblib import Parallel, delayed from sklearn.ensemble._forest import ForestClassifier +from sklearn.utils.validation import check_is_fitted +from 
sklearn.utils.fixes import _joblib_parallel_args from .tree.morf_tree import Conv2DObliqueTreeClassifier @@ -175,11 +179,55 @@ def __init__( # self.min_impurity_split = min_impurity_split # s-rerf params - # self.discontiguous_height = discontiguous_height - # self.discontiguous_width = discontiguous_width - # self.image_height = image_height - # self.image_width = image_width - # self.patch_height_max = patch_height_max - # self.patch_height_min = patch_height_min - # self.patch_width_max = patch_width_max - # self.patch_width_min = patch_width_min + self.discontiguous_height = discontiguous_height + self.discontiguous_width = discontiguous_width + self.image_height = image_height + self.image_width = image_width + self.patch_height_max = patch_height_max + self.patch_height_min = patch_height_min + self.patch_width_max = patch_width_max + self.patch_width_min = patch_width_min + + @property + def feature_importances_(self): + """ + Computes the importance of every unique feature used to make a split + in each tree of the forest. + + Parameters + ---------- + normalize : bool, default=True + A boolean to indicate whether to normalize feature importances. + + Returns + ------- + importances : array of shape [n_features] + Array of count-based feature importances. + """ + # TODO: Parallelize this and see if there is an equivalent way to express this better + # 1. Find all unique atoms in the forest + # 2. Compute number of times each atom appears across all trees + forest_projections = [ + node.proj_vec + for tree in self.estimators_ + if tree.tree_.node_count > 0 + for node in tree.tree_.nodes + if node.proj_vec is not None + ] + unique_projections, counts = np.unique( + forest_projections, axis=0, return_counts=True + ) + + if counts.sum() == 0: + return np.zeros(self.n_features_, dtype=np.float64) + + # 3. 
Count how many times each feature gets nonzero weight in unique projections + importances = np.zeros(self.n_features_) + for proj_vec, count in zip(unique_projections, counts): + importances[np.nonzero(proj_vec)] += count + + # 4. Normalize by number of unique projections + if len(unique_projections) > 0: + importances /= len(unique_projections) + + return importances diff --git a/oblique_forests/sporf.py b/oblique_forests/sporf.py index 557ae6d..515ad4e 100644 --- a/oblique_forests/sporf.py +++ b/oblique_forests/sporf.py @@ -1,6 +1,10 @@ +import numpy as np +from joblib import Parallel, delayed from sklearn.ensemble._forest import ForestClassifier +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.fixes import _joblib_parallel_args -from .tree.oblique_tree import ObliqueTreeClassifier +from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier class ObliqueForestClassifier(ForestClassifier): @@ -95,3 +99,47 @@ def __init__( # self.max_leaf_nodes = max_leaf_nodes # self.min_impurity_decrease = min_impurity_decrease # self.min_impurity_split = min_impurity_split + + @property + def feature_importances_(self): + """ + Computes the importance of every unique feature used to make a split + in each tree of the forest. + + Parameters + ---------- + normalize : bool, default=True + A boolean to indicate whether to normalize feature importances. + + Returns + ------- + importances : array of shape [n_features] + Array of count-based feature importances. + """ + # TODO: Parallelize this and see if there is an equivalent way to express this better + # 1. Find all unique atoms in the forest + # 2. 
Compute number of times each atom appears across all trees + forest_projections = [ + node.proj_vec + for tree in self.estimators_ + if tree.tree_.node_count > 0 + for node in tree.tree_.nodes + if node.proj_vec is not None + ] + unique_projections, counts = np.unique( + forest_projections, axis=0, return_counts=True + ) + + if counts.sum() == 0: + return np.zeros((self.n_features_), dtype=np.float64) + + # 3. Count how many times each feature gets nonzero weight in unique projections + importances = np.zeros((self.n_features_), dtype=np.float64) + for proj_vec, count in zip(unique_projections, counts): + importances[np.nonzero(proj_vec)] += count + + # 4. Normalize by number of unique projections + if len(unique_projections) > 0: + importances /= len(unique_projections) + + return importances \ No newline at end of file diff --git a/oblique_forests/tree/morf_tree.py b/oblique_forests/tree/morf_tree.py index 936cb9a..4dc31aa 100644 --- a/oblique_forests/tree/morf_tree.py +++ b/oblique_forests/tree/morf_tree.py @@ -265,7 +265,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): splitter = self._set_splitter(X, y) # create the Oblique tree - self.tree = ObliqueTree( + self.tree_ = ObliqueTree( splitter, self.min_samples_split, self.min_samples_leaf, @@ -273,7 +273,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): self.min_impurity_split, self.min_impurity_decrease, ) - self.tree.build() + self.tree_.build() return self diff --git a/oblique_forests/tree/oblique_tree.py b/oblique_forests/tree/oblique_tree.py index 16bf622..d439cab 100644 --- a/oblique_forests/tree/oblique_tree.py +++ b/oblique_forests/tree/oblique_tree.py @@ -3,6 +3,7 @@ from joblib import Parallel, delayed from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array, check_is_fitted, check_X_y +from sklearn.utils.fixes import _joblib_parallel_args from ._split import BaseObliqueSplitter from .oblique_base import 
BaseManifoldSplitter, Node, SplitInfo, StackRecord @@ -528,6 +529,37 @@ def predict(self, X, check_input=True): return predictions + def compute_feature_importances(self): + """ + Computes the importance of each feature (aka variable). + + Parameters + ---------- + unique_projections : ndarray of shape (n_proj, n_features) + Array of unique sampling projection vectors. + + Returns + ------- + importances : ndarray of shape (n_features,) + Normalized importance of each feature of the data matrix. + """ + projections = [ + node.proj_vec for node in self.nodes if node.proj_vec is not None + ] + unique_projections, counts = np.unique(projections, axis=0, return_counts=True) + + if counts.sum() == 0: + return np.zeros((self.splitter.n_features,)) + + importances = np.zeros((self.splitter.n_features,)) + for proj_vec, count in zip(unique_projections, counts): + importances[np.nonzero(proj_vec)] += count + + if len(unique_projections) > 0: + importances /= len(unique_projections) + + return importances + class ObliqueTreeClassifier(BaseEstimator): """ @@ -600,6 +632,7 @@ def __init__( # Max features self.max_features = max_features + self.n_jobs = n_jobs self.n_classes = None self.n_jobs = n_jobs @@ -640,7 +673,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): tree_func = self._tree_class() # instantiate the tree and build it - self.tree = tree_func( + self.tree_ = tree_func( splitter, self.min_samples_split, self.min_samples_leaf, @@ -648,7 +681,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): self.min_impurity_split, self.min_impurity_decrease, ) - self.tree.build() + self.tree_.build() return self @@ -666,7 +699,7 @@ def apply(self, X): pred_nodes : array of shape[n_samples] The indices for each test sample's final node in the oblique tree. 
""" - pred_nodes = self.tree.predict(X).astype(int) + pred_nodes = self.tree_.predict(X).astype(int) return pred_nodes def predict(self, X, check_input=True): @@ -689,7 +722,7 @@ def predict(self, X, check_input=True): pred_nodes = self.apply(X) for k in range(len(pred_nodes)): id = pred_nodes[k] - preds[k] = self.tree.nodes[id].label + preds[k] = self.tree_.nodes[id].label return preds @@ -713,7 +746,7 @@ def predict_proba(self, X, check_input=True): pred_nodes = self.apply(X) for k in range(len(preds)): id = pred_nodes[k] - preds[k] = self.tree.nodes[id].proba + preds[k] = self.tree_.nodes[id].proba return preds @@ -737,3 +770,53 @@ def predict_log_proba(self, X, check_input=True): # TODO: Actually do this function def _validate_X_predict(self, X, check_input=True): return X + + @property + def feature_importances_(self): + """ + Return the feature importances. + The importance of a feature is computed as the number of times it + is used in a projection across all split nodes + + Returns + ------- + feature_importances_ : ndarray of shape (n_features,) + Array of count-based feature importances. + """ + check_is_fitted(self) + + return self.tree_.compute_feature_importances() + + def compute_projection_counts(self, unique_projections=None): + """ + Counts the number of times each unique projection in the tree appears. + + Parameters + ---------- + unique_projections : ndarray of shape (n_proj,), optional + Array of unique projections to count, by default None + + Returns + ------- + projection_counts : ndarray of shape (n_proj,) + Counts of each unique projection used in this tree. 
+ """ + check_is_fitted(self) + + if unique_projections is None: + projections = [ + node.proj_vec + for node in self.tree_.nodes + if node.proj_vec is not None + ] + unique_projections, counts = np.unique(projections, axis=0, return_counts=True) + return counts, unique_projections + + # TODO: see if joblib will speed up at all for this for loop + n_proj = len(unique_projections) + counts = np.zeros(n_proj) + for node in self.tree_.nodes: + projection_idx = np.where((unique_projections == node.proj_vec).all(axis=1)) + counts[projection_idx] += 1 + + return counts, unique_projections \ No newline at end of file diff --git a/oblique_forests/tree/tests/test_morf_tree.py b/oblique_forests/tree/tests/test_morf_tree.py index 9cbd12c..32c1bf0 100644 --- a/oblique_forests/tree/tests/test_morf_tree.py +++ b/oblique_forests/tree/tests/test_morf_tree.py @@ -10,6 +10,10 @@ from sklearn.utils.validation import check_random_state from oblique_forests.tree.morf_split import Conv2DSplitter +from oblique_forests.sporf import ObliqueForestClassifier as SPORF +from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC +from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF +from oblique_forests.tree.morf_tree import Conv2DObliqueTreeClassifier # toy sample X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] @@ -43,7 +47,7 @@ def test_convolutional_splitter(): y[:25] = 0 splitter = Conv2DSplitter( - X, + X.reshape(n, -1), y, max_features=1, feature_combinations=1.5, @@ -52,4 +56,60 @@ def test_convolutional_splitter(): image_width=d, patch_height_max=2, patch_height_min=2, + patch_width_max=3, + patch_width_min=3, ) + + splitter.sample_proj_mat(splitter.indices) + + +if __name__ == "__main__": + + test_convolutional_splitter() + + # from sklearn.datasets import fetch_openml + from keras.datasets import mnist + import time + + (X_train, y_train), (X_test, y_test) = mnist.load_data() + + # Get 100 samples of 3s and 5s + num = 100 + threes = 
np.where(y_train == 3)[0][:num] + fives = np.where(y_train == 5)[0][:num] + train_idx = np.concatenate((threes, fives)) + + # Subset train data + Xtrain = X_train[train_idx] + ytrain = y_train[train_idx] + + # Apply random shuffling + permuted_idx = np.random.permutation(len(train_idx)) + Xtrain = Xtrain[permuted_idx] + ytrain = ytrain[permuted_idx] + + # Subset test data + test_idx = np.where(y_test == 3)[0] + Xtest = X_test[test_idx] + ytest = y_test[test_idx] + + print(f"-----{2 * num} samples") + + clf = OTC(random_state=0) + start = time.time() + clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain) + elapsed = time.time() - start + print(elapsed) + print(f"SPORF Tree: {elapsed} sec") + + clf = Conv2DObliqueTreeClassifier(image_height=28, image_width=28, random_state=0) + start = time.time() + clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain) + elapsed = time.time() - start + print(f"MORF Tree: {elapsed} sec") + + clf = SPORF(n_estimators=100, random_state=0) + start = time.time() + clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain) + elapsed = time.time() - start + print(f"SPORF: {elapsed} sec") \ No newline at end of file diff --git a/oblique_forests/tree/tests/test_splitter.py b/oblique_forests/tree/tests/test_splitter.py index 4ef97e9..30b2767 100644 --- a/oblique_forests/tree/tests/test_splitter.py +++ b/oblique_forests/tree/tests/test_splitter.py @@ -34,16 +34,16 @@ def test_argmin(self): assert 4 == j def test_matmul(self): - + b = BOS() A = np.zeros((3, 3), dtype=np.float64) B = np.ones((3, 3), dtype=np.float64) - + for i in range(3): for j in range(3): - A[i, j] = 3*i + j + 1 - + A[i, j] = 3 * i + j + 1 + res = b.test_matmul(A, B) C = np.ones((3, 3), dtype=np.float64) @@ -53,7 +53,6 @@ def test_matmul(self): assert_allclose(C, res) - def test_impurity(self): """ diff --git a/oblique_forests/tree/tests/test_sporf.py b/oblique_forests/tree/tests/test_sporf.py new file mode 100644 index 0000000..921fecd --- /dev/null +++ 
b/oblique_forests/tree/tests/test_sporf.py @@ -0,0 +1,561 @@ +import numpy as np +from numpy.testing import ( + assert_almost_equal, + assert_allclose, + assert_array_equal, + assert_array_almost_equal, +) + +import pytest +from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC +from oblique_forests.sporf import ObliqueForestClassifier as OFC + +from sklearn import datasets +from sklearn.metrics import accuracy_score + +""" +Sklearn test_tree.py stuff +""" +X_small = np.array( + [ + [ + 0, + 0, + 4, + 0, + 0, + 0, + 1, + -14, + 0, + -4, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 5, + 3, + 0, + -4, + 0, + 0, + 1, + -5, + 0.2, + 0, + 4, + 1, + ], + [ + -1, + -1, + 0, + 0, + -4.5, + 0, + 0, + 2.1, + 1, + 0, + 0, + -4.5, + 0, + 1, + ], + [ + -1, + -1, + 0, + -1.2, + 0, + 0, + 0, + 0, + 0, + 0, + 0.2, + 0, + 0, + 1, + ], + [ + -1, + -1, + 0, + 0, + 0, + 0, + 0, + 3, + 0, + 0, + 0, + 0, + 0, + 1, + ], + [ + -1, + -2, + 0, + 4, + -3, + 10, + 4, + 0, + -3.2, + 0, + 4, + 3, + -4, + 1, + ], + [ + 2.11, + 0, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -3, + 1, + ], + [ + 2.11, + 0, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0, + 0, + -2, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0, + 0, + -2, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -1, + 0, + ], + [ + 2, + 8, + 5, + 1, + 0.5, + -4, + 10, + 0, + 1, + -5, + 3, + 0, + 2, + 0, + ], + [ + 2, + 0, + 1, + 1, + 1, + -1, + 1, + 0, + 0, + -2, + 3, + 0, + 1, + 0, + ], + [ + 2, + 0, + 1, + 2, + 3, + -1, + 10, + 2, + 0, + -1, + 1, + 2, + 2, + 0, + ], + [ + 1, + 1, + 0, + 2, + 2, + -1, + 1, + 2, + 0, + -5, + 1, + 2, + 3, + 0, + ], + [ + 3, + 1, + 0, + 3, + 0, + -4, + 10, + 0, + 1, + -5, + 3, + 0, + 3, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 1, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -3, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 1, + 0, + 0, + -3.2, + 6, + 1.5, + 1, + -1, + -1, + ], + [ + 2.11, + 8, + -6, 
+ -0.5, + 0, + 10, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -1, + -1, + ], + [ + 2, + 0, + 5, + 1, + 0.5, + -2, + 10, + 0, + 1, + -5, + 3, + 1, + 0, + -1, + ], + [ + 2, + 0, + 1, + 1, + 1, + -2, + 1, + 0, + 0, + -2, + 0, + 0, + 0, + 1, + ], + [ + 2, + 1, + 1, + 1, + 2, + -1, + 10, + 2, + 0, + -1, + 0, + 2, + 1, + 1, + ], + [ + 1, + 1, + 0, + 0, + 1, + -3, + 1, + 2, + 0, + -5, + 1, + 2, + 1, + 1, + ], + [ + 3, + 1, + 0, + 1, + 0, + -4, + 1, + 0, + 1, + -2, + 0, + 0, + 1, + 0, + ], + ] +) + +y_small = [1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0] +y_small_reg = [ + 1.0, + 2.1, + 1.2, + 0.05, + 10, + 2.4, + 3.1, + 1.01, + 0.01, + 2.98, + 3.1, + 1.1, + 0.0, + 1.2, + 2, + 11, + 0, + 0, + 4.5, + 0.201, + 1.06, + 0.9, + 0, +] + +# toy sample +X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] +y = [-1, -1, -1, 1, 1, 1] +T = [[-1, -1], [2, 2], [3, 2]] +true_result = [-1, 1, 1] + +# also load the iris dataset +# and randomly permute it +iris = datasets.load_iris() +rng = np.random.RandomState(1) +perm = rng.permutation(iris.target.size) +iris.data = iris.data[perm] +iris.target = iris.target[perm] + +# also load the diabetes dataset +# and randomly permute it +diabetes = datasets.load_diabetes() +perm = rng.permutation(diabetes.target.size) +diabetes.data = diabetes.data[perm] +diabetes.target = diabetes.target[perm] + +# Ignoring digits dataset cause it takes a minute + + +def test_classification_toy(): + # Check classification on a toy dataset. 
+ clf = OTC(random_state=0) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + + """ + # Ignoring because max_features implemented differently + clf = OTC(max_features=1, random_state=0) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + """ + + +def test_xor(): + + # Check on a XOR problem + y = np.zeros((10, 10)) + y[:5, :5] = 1 + y[5:, 5:] = 1 + + gridx, gridy = np.indices(y.shape) + + X = np.vstack([gridx.ravel(), gridy.ravel()]).T + y = y.ravel() + + # Changing feature parameters from default 1.5 to 2 makes this test pass. + clf = OTC(random_state=0, feature_combinations=2) + clf.fit(X, y) + + assert accuracy_score(clf.predict(X), y) == 1 + + +def test_iris(): + + clf = OTC(random_state=0) + + clf.fit(iris.data, iris.target) + score = accuracy_score(clf.predict(iris.data), iris.target) + assert score > 0.9 + + +def test_diabetes(): + + """ + Diabetes should overfit with MSE = 0 for normal trees. + idk if this applies to sporf, so this is just a placeholder + to check consistency like iris. + """ + + clf = OTC(random_state=0) + + clf.fit(diabetes.data, diabetes.target) + score = accuracy_score(clf.predict(diabetes.data), diabetes.target) + assert score > 0.9 + + +def test_probability(): + + clf = OTC(random_state=0) + + clf.fit(iris.data, iris.target) + p = clf.predict_proba(iris.data) + + assert_array_almost_equal(np.sum(p, 1), np.ones(iris.data.shape[0])) + + assert_array_equal(np.argmax(p, 1), clf.predict(iris.data)) + + assert_almost_equal( + clf.predict_proba(iris.data), np.exp(clf.predict_log_proba(iris.data)) + ) + + +def test_pure_set(): + + clf = OTC(random_state=0) + + X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] + y = [1, 1, 1, 1, 1, 1] + + clf.fit(X, y) + assert_array_equal(clf.predict(X), y) + + +def test_importances(): + # Check variable importances. 
+ X, y = datasets.make_classification( + n_samples=5000, + n_features=10, + n_informative=3, + n_redundant=0, + n_repeated=0, + shuffle=False, + random_state=0, + ) + + clf = OTC(random_state=0) + + clf.fit(X, y) + importances = clf.feature_importances_ + n_important = np.sum(importances > 0.4) + + assert importances.shape[0] == 10, "Failed with SPORF" + assert n_important == 3, "Failed with SPORF" + + # Check on iris that importances are the same for all builders + clf = OTC(random_state=0) + clf.fit(iris.data, iris.target) + clf2 = OTC(random_state=0, max_depth=len(iris.data)) + clf2.fit(iris.data, iris.target) + + assert_array_equal(clf.feature_importances_, clf2.feature_importances_) + + +def test_importances_raises(): + # Check if variable importance before fit raises ValueError. + clf = OTC(random_state=0) + with pytest.raises(ValueError): + getattr(clf, "feature_importances_") \ No newline at end of file