diff --git a/docs/tutorials/runtime_comparison.ipynb b/docs/tutorials/runtime_comparison.ipynb new file mode 100644 index 0000000..38d1021 --- /dev/null +++ b/docs/tutorials/runtime_comparison.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "import time\n", + "\n", + "from sklearn.datasets import load_digits, fetch_openml\n", + "from sklearn.utils import check_random_state\n", + "\n", + "from sklearn.ensemble import RandomForestClassifier as RF\n", + "from oblique_forests.sporf import ObliqueForestClassifier as SPORF\n", + "from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF\n", + "\n", + "sns.set_palette('Set1')\n", + "mpl.rcParams.update({\n", + " \"axes.titlesize\": \"xx-large\",\n", + " \"axes.spines.top\": False,\n", + " \"axes.spines.right\": False,\n", + " \"xtick.bottom\": False,\n", + " \"ytick.left\": False,\n", + " \"image.cmap\": \"inferno\",\n", + "})\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Digits Dataset\n", + "These are 8x8 images of handwritten digits from `sklearn.datasets`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(100, 64) (100,)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_digits\n", + "\n", + "images, labels = load_digits(return_X_y=True)\n", + "\n", + "# Get 100 samples of 3s and 5s\n", + "n = 100\n", + "threes = np.where(labels == 3)[0][:(n // 2)]\n", + "fives = np.where(labels == 5)[0][:(n // 2)]\n", + "idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "X = images[idx]\n", + "y = labels[idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(idx))\n", + "X = X[permuted_idx]\n", + "y = y[permuted_idx]\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T17:54:42.347403\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAEjCAYAAAD5QHrmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZ0klEQVR4nO3debwsZ13n8c+TBLghBAg7iAjKJoIKCoMLGhZBeaEDIwKDbCrKsKiMA4jIKLgg4wCCAi4zQlQIWwDBZRQJg7JJWBQFlADDDSBLiIEQCElIUvNH9SFN55yb3MvNqZC8369Xv849T1fX7+k+Xc/tb9VT1WOapgAAAC7rDlm6AwAAAJcEwhEAAEDCEQAAQCUcAQAAVMIRAABAJRwBAABUwhEAwKXOGOOYMcY0xjhsgdp7xxjH7Hbdi8vqdXzSRtt1xxjHjTFOWb9/jHHkGON/jzE+vmo/ZoEu8xUQjhY2xjh6tfGcPca4xtL9uaQaY/zYGOP41WBz1hjj38YYfzHG+O6l+wZfLYw3F80Y40mr12m7mzGHxV2at+UxxtVW2+DRF8O6b7ixPZ8zxvj3McY7xhi/M8b4lv1Y3TOqe6x+PrB6xar9CdVPVM9ftf/+wXwOXPx2fW8CF/Cg6qPVtav/XP3Ost25xLpN9bHqr6tTqmtWD6j+boxx72maXrGvBwOV8WZ/Pab65Ebb+5boCGy4NG/LV6t+efXv119MNV5VHVeN6irVLZtfx0eOMZ4yTdMTN5Y/vDpno+2O1WunaXrKNu3vmabpCQe/2+wG4WhBY4wrVveunlndunmwW2SAG2NccZqmM5aofVFM0/TTm21jjGdXH6we2/l7bIBtGG8OyKumafrA0p2AdZekbfmr2LunaXrBesMY47HVi6tfHGN8cJqm52/dN03Tmdus41rVZ3Zo/9hB7GtjjMtX507TdO7BXC/bM61uWfeqjqxeuLp9+xjjFlt3jjF+a3XI9zqbDxxjPGh1SPiua21HjDGeMsb44Grq2cfGGM8dYxy18djXjzE+Osa42RjjL8cYn63+cnXfrcYYfzjGeP8Y44wxxmfGGH81xrjtdk9gjPGY1dziM1eHpb9/Nc957zbLfu8Y4zVjjNPGGF8YY/z9GOMeB/bS1TRNn6/+vbrqga4DLkOMNwcw3owxrjzGOHR/HwcXo31uy9u4zhjjZatt4bQxxgvHGNdaX2CMca0xxu+PMU5abc+fHGO8boxxp43lbj7GePkY49TVdvXOMcYDL6zD4/xpgEdvc9+Xzk9a3f/+1V2/vDb97Zi15S/S2LO/pmk6vfno0WdWtcdazfVzip40xpiajzr96FofH7Jqv1H1XWvtR6+t515jjDeNMT63uh0/xvjOjdfjIavH3X2M8RtjjI9WZ1Zfu7r/mmOM54wxPjLmaZV7V8tdYZvX9Y2rcfZ1qzH2E2OMX1t/bmvLf/dqjN76275vjPFbG8scNsb4+THGe1fj8CljjBeMMa5/wC/8JZBwtKwHVW+fpul9zYd4T1+1bXlhdWh1320ee//m6R7HV602iuOrn67+dPXz2OrHquM3N5rqitVrm/duPKb6k1X73apvqV5S/Wz1tOpm1d+OMW6+voIxxi9V/7Pa23z05jXNe11us9nZMcYPr/p3ePXk6uerqXr1GGO757etMcZRq4Hhm8YYz6hu0TzVDtg3481+jjfVO6vTqi+sgtYFasECLmxb3vTn1Z7mc2FeUN2ves2Yj0ZseVlzMDi2ekTztnhqa9vXGOPG1VuqO1fPrR5fnVX98RjjMQflmdW/NI8RVa9sPmfnS+ftHMDYs1+mafps80yUr6tuvsNir1j1qerNa3182+rnKdWJa+3/sur7o1ePPbX6heapg9et/u8Y47u2qfObza/105rHsM+NMa5e/X31I9UfVY9q/gz02Orl26zjuqv731X93KqPv1j9+PpCq3Hx9dU3Vc+pHt38+t5zbZnR/D75leoNzWP2c6u7V29e9e3SYZomtwVu1fWqc6tHr7X9UfMc4kPW2t5XvXXjsdesvlg9c63tcdXZ1W02lv2h5g8FD11re/2q7fHb9OuIbdquUZ1c/d5G21nV31WHrrXfebXuvWttV2weLI7bWO+hzRvqh6txEV+3vav1T9UXmgfMw5f+e7q5XZJvxpv9G2+aPxj8bvMHzns2f4g5rTqj+val/55ul93bRd2WV+3HrLaPl2y0//Sq/eGr36+y+v2xF1L7pdV569tAdfnqrav/j6+x1r63Ombt96NXNY7eZr2by954teyTtln2Io89OzyHG66W+7V9LPNfV8v80FrbBfqzanvBDs/njRtt11/1+2kb7VdajUlvXGt7yGrd76ouv7H8c5vD1ddutP/M6jF32ejHVN1zY9l3VSes/X5k9enmQHfVjWXX/3+4z2p999hY5tbN52Pt+Jp+td0cOVrOA5rfZC9ea3th9TXN/+Gvt91ujPENa233bT5f7IVrbfer3l59eIxxja1b816eL2ysc8vvbjZM81S1ap7XvLYn4IRqfarLXZoHxedMa3Ngp2k6vnrPxmrvUl29esFG346q/k/zoeKbbtO/7fxo897mn1z16YrNe8SAnRlv9mO8mabpmdM0PXyapj+epulPp2l6cvWdzQHraft6LFzMLuq2vO6ZG7//QfW56gdXv3+h+YP70Tvt/R/z1NK7V6+bpuntW+3TNJ29Wv+e6q7bPfYgO5CxZ3+dvvp55EFY15b/VF2uOnaj33uaj4R9x5jPJVv3h6vXt/rSkZv7Nh81/8LGev5mtdjm8//4NE1/utH2+mp9jP++5tMT/sc0TZ9ZX3CapvPWfr1fcwj/+43aH6n+3za1v2q5IMNyHtS8F3PPGOOGq7YPNqf3B3f+G/3Y5mkh969+ddV2/+r90zS9bW19N2+eQvKpHepda+P3U6dpOm1zoTHGlatfaz5ku3nuwYfW/r3V5/d3QSf25VNdtg5Nv3KHvm3170KvAjVN05vW+npM8+HlV1d3uLDHwmWY8eaC/duvq85N0/SeMcafV/ccYxw+TdMX9ufxcJBc1G153Ze916dpOmvM5+ndaPX72WOMxzUH/0+MMd7WPA322Gma/nX1sGtWR7SaIrbhvaufNzrA57Q/9nfsORBboej0fS61f7bGpXfsY5mrNx+d3vKhjfuv2Xwlv/u2/fTnuuDzP2mbZT69Ws+Wm6x+/tM++lbzc7h+O7/2l5pzM4WjBYwxvq15Xmdd8M1fda8xxpHTNJ0+TdMHxhgntPqwshoMv6P5A8y6Q5r3nPzSDmU/vfH7Tv+xv6g5/f9W58+3P695fuz6noatk/mmbdaxeaLf1hHKh1c7Xfnp3Tu072iapnPGGC+rnjrGuMk0Tdt9cILLNOPNtvZ7vFk5abX+o9r5OcHFYn+25Y327babL19gmp41xnhl8/S0OzdPLfuFMcZPTWtXbdthXfvaPi9KH/bnQ/X+jj0H4larnwfzM8XWuHSv5qN229kMHZtjzNY6Xtk8vW47m1fJ25+r213Y++SQ5iNED9vh/kvNmCgcLePBzfPn79/8QWDddZvf9Pdu/gKxmg+ZP2uMcevqB1Ztx2487gPNc0Vfe6CdGmNctfmw+ZOnaXrSxn2/urH41sB80+ofNu67ycbvWx9QTv1K+reDrSl1X9FVauBSzHhz8Ny4eW79qQd5vXBR7O+2vOXmzYGi+tJFDW7YfFL9l0zT9OHq2dWzx3zlt7dUv75a36eqzzdfBGnT1lGRvfvo+1Zo2bya5Z5V37+sK/tYz1c89uzLGOMqzQHmpOpfL2Tx/bE1Ln1smqYTDnAdn6o+W+05yM9/KwR+c/OUxZ18oPqe6vXTNG1+59OlinOOdtkY43LN8zaPn6bpFav57Ou3320+OW/9yjMvaU7/92++mszbpmk6cWPVL6q+cYxx/21qHjrGuNpm+zbObR6Uvux9MebLUP6HjWVf2zxH+ZFj7TK3Y4w7d/6erS1/3TwwPmGMcfg2/dvnYfDVpSMv0P8xxpWaz0E6owuedwCXecab/R9vVstsN97cvvPPudjuO0/gYnOA2/KWR2/8/lPNFwL489W6r7i5rUzT9OnmsHPU6vdzm8/Zu9P6VRtX/frZ5tD2mn08hb3NOxbutNH+M13wyNHWkZXtdnoejLFnW2OMI5t3BF21+pVpdbWBg+S45uf/pDHGBQ5MXJRxafU3eGl1tzHGd2+zjj2r57C//qb58uU/vwqH6+tcPzL/ouYph4/bpvZYnX90qeDI0e67e/O80VftY5lXN38I+Lppmk6apumTY4zXNh/KPLILDnRVT1+t+wVj/i6Prb1EN65+uHpi85VrdjRN0+ljjOOrx60GyhObDy//WHP4OHJt2VPGGE9tPrR9/Bjj5c17fx5e/fPGsqePMR7a/KHrPWOMP2k+ge+61e2b9zqtT6HZdKXqI2OM41b9OLV5r9eDm+e/Pmr9xG7gS4w3+z/eVH1oNcXovc1T/b65+onmD20/dyGPhYvDfm/La+03G2P8WXO4uUXnbzd/uLr/ps2Xkz6u+T3/ueYjBHernre2nic0X/Dk+DF/CfunmgPb7ZuvdHfKTh2bpumzY4wXVo9YfeB+d/OU3Ts0X11yfdlPjjE+XN1vjHFi8/cZfmiaprd2EMaelVuOMR7QPCXwyOax50eaz8X59WmanrevB++vaZr2jjH+W/Ws6h1jjJc2X5Xz+s1X8juvuuNFWNXjm/82r1udd/0PzTNobrbq/w83X3Bhf/p2+hjjEc2XeX/XGGPr6oc3bP77bo2XxzZfvfPXx/zdTK9rnkp3o1X7sdWT9qf2JdbSl8u7rN2ar0N/XnXdfSxzl+Y9qk9ca3vgqu2c6jo7PG5P8/Xr3938hWGfab5k429WN1hb7vXVR3dYx7Wav4Pk5OZD6G9unn98TGuXy10tO5r3IJy0qvfO6vub95D8yzbrvl3zXNlTmvcyfbh5ML/fhbxml6+esVr/p5svK/zJ1WO/b+m/qZvbJfVmvNn/8Wb12D9o/vD4meYjVh9pnlr09Uv/Td0um7cD2ZY7/1Le12/+fprTmqdlvai69trjrl799mpb/mxzOPrn5u8b2ryU9Dc2f1fPp1fb4T9UD9qmL3tbuzz3qu2qzR/AT2u+2MGrmr9PaLtlv6d5iteZq+dwzNp9F2ns2eE1umHnfx3I1HwE+9Or8eR3qm/d4XFf0aW81+67W/MRts+s+v6h5isP3m1tmYe0cVnubV7H32yeDnfWaow7ofkrB652Yf1oDjDTNu13bD6KtPW1Bf9aPX1jmUOqRzZfWOKM1d/xvc3TMW+x9HZysG5j9WThoBljvKv65DRNu3FZT+AyzHgDwMHknCMO2A7z+e/SPAXldbvfI+DSyngDwG5w5IgDNsa4X/P8+1c3T4u5ZfOJnidX3zxtfJkYwIEy3gCwG1yQga/Ee5pPyHxEdY3mObTHVU/wQQU4yIw3AFzsHDkCAADIOUcAAADVhUyrG2NMF/xurt1z5J7NLz7fXSc99OxF61/pWb+0aP2zPvKXi9avetkPHMj3mR08/+X9f7Vo/bPP+cSi9afpnHHhSx08S4851zridovVrvrgo/YuWv8KN/z3ReuPw3f17batM9+33fdO7p6f/V/3WbT+8055zqL1d3PMWXq8OezQZd9rR1z+uovWv1W3X7T+uzth0fpVj77WBb7LdVd96PTl3v9Vf3TqcxetP01f3Ha8ceQIAAAg4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAACqOmzpDuzLdx16p0XrH3m75y1a/wM/tGz9b3jMxxetX/XAx3540frPfuR9Fq3/znOOXbT+Zc1/vOJtF61/+G88ctH65/3tkxet/9mXLFq+qo/vvf6i9T/w+TMXrc/uudqemyxa/95XWna8u/aecxetf9qnvnHR+lWvPfnsReu/9Yt/tmj9SypHjgAAABKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoKrDlu7AvrzhnNcuWv+uD3vwovWvdOiyf56XP+7kRetXffyvb7lo/b3jPYvWZ3cddflp0fpnfP6Di9Z/6SO+Z9H6zz7prEXrV5147lsWrX/GWW9YtD67575H3nbR+k/92xMXrf+g295k0fpv/5VXLVq/6iNv/NZF69/uNddbtP6p53xq0fo7ceQIAAAg4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAACqOmzpDuzL58/64KL133y5Lyxa/23fc5NF6x9yhycuWr/qTY/6m0Xrn3rGCxetz+766BnL1j/kEycsWv/Bv/x3i9a/7fNut2j9qoe+4Q6L1n9bexetz+45d1q2/hVu8oBF63/dEW9dtP6Ln3+/RetXvenkIxetf+oZz1m0/iWVI0cAAAAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFDVmKZp5zvHmOrQXezOJcthhx61aP2HX/P+i9Z/+ktes2j9qvGZUxetf4P733zR+h///JsWrT9N54zdrLf0mHOFy11vsdpVT7nBPRat/51f85FF61/nWicvWr/q+PfcatH6D3vfKxetf+55py1afzfHnKXHmz2Xv/5itaueccO7LVr/HafuWbT+yWeeu2j9qh+43hcXrf/4D79x0fqfPfN9i9bfabxx5AgAACDhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoKoxTdPOd44x1aG72J3N+nsWq111lT03XrT+aWd+YNH6j73ejy9av+q/v+mTi9b/7aOvsWj9X9z7+4vWn6Zzxm7WW3zM6bDFaldd+fCbLlr/Xle846L1//v3nrBo/aq9//Y1i9b/wX/8x0Xrn3HW3kXr7+aYs/R4s2ztqnMXrf6Iaz9y0fpHXf68RetX3eCIMxetf+JnD1+0/tM/9txF6+803jhyBAAAkHAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAlXAEAABQCUcAAACVcAQAAFAJRwAAAJVwBAAAUAlHAAAAVR22dAf25bBDj1y0/ou+6WaL1t9zua9ftP5t7/iSRetXnXvlBy1a/5Nnnr1ofXbX4Ve4/qL1f+/Gt1m0/l3v8BeL1r/y0Z9atH7VcY+/9aL1v3ju5xetz+45ZOxZtP550xmL1t/7uXMWrX/Cectvay/+thMXrf/2N9120fqXVI4cAQAAJBwBAABUwhEAAEAlHAEAAFTCEQAAQCUcAQAAVMIRAABAJRwBAABUwhEAAEAlHAEAAFTCEQAAQCUcAQAAVMIRAABAJRwBAABUwhEAAEAlHAEAAFTCEQAAQCUcAQAAVMIRAABAJRwBAABUwhEAAEAlHAEAAFTCEQAAQCUcAQAAVMIRAABAJRwBAABUwhEAAEAlHAEAAFTCEQAAQCUcAQAAVMIRAABAVWOapp3vHGOqQ3exO1/uiCt8w2K1qz79zL2L1j/kYc9btP65b3nqovWrXv6wb1u0/o++57hF60/TmQvXP2fsZr2lx5xrHLHs++3fnvzPi9Y/+z73XLT+2++77Pu96r7/9IlF65/8+RMWrb+03Rxzlh5vDjv0qMVqV5133tmL1p+msxat/8Sv/clF61cdcdh5i9b/hb3PX7T+NC39Htx+vHHkCAAAIOEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAKuEIAACgEo4AAAAq4QgAAKASjgAAACrhCAAAoBKOAAAAqhrTNO185xifqk7ave4AlzCnTNP0/btVzJgDl2nGG2A3bTvm7DMcAQAAXFaYVgcAAJBwBAAAUAlHAAAAlXAEAABQCUcAAABV/X8n7tbU+lNDGgAAAABJRU5ErkJggg==\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = images[threes].mean(axis=0).reshape(8, 8)\n", + "avg_5 = images[fives].mean(axis=0).reshape(8, 8)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "source": [ + "## `RF` vs `SPORF` runtimes" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def time_clf_digits(clf, ns, random_state=None):\n", + " runtimes = np.empty(len(ns))\n", + " rng = check_random_state(random_state)\n", + " images, labels = load_digits(return_X_y=True)\n", + " for i, n in enumerate(ns):\n", + " # Get only 3s and 5s\n", + " threes = np.where(labels == 3)[0][:(n // 2)]\n", + " fives = np.where(labels == 5)[0][:(n // 2)]\n", + " idx = np.concatenate((threes, fives))\n", + " X = images[idx]\n", + " y = labels[idx]\n", + "\n", + " # Shuffle samples\n", + " permuted_idx = rng.permutation(len(idx))\n", + " X = X[permuted_idx].reshape(n, -1)\n", + " y = y[permuted_idx].reshape(n)\n", + "\n", + " # Begin timing\n", + " start = time.time()\n", + " clf.fit(X, y)\n", + " runtimes[i] = time.time() - start\n", + " return runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_clf(clf):\n", + " if isinstance(clf, RF):\n", + " return \"RF\"\n", + " elif isinstance(clf, SPORF):\n", + " return \"SPORF\"\n", + " elif isinstance(clf, MORF):\n", + " return \"MORF\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "ns = [10, 20, 40, 50, 100, 150, 200, 300]\n", + "\n", + "clfs = [\n", + " RF(random_state=0, n_jobs=1),\n", + " SPORF(random_state=0, n_jobs=1),\n", + " # MORF(random_state=0, image_height=8, image_width=8, n_jobs=1) # Too slow\n", + "]\n", + "\n", + "runtimes = []\n", + "for clf in clfs:\n", + " runtimes.append(time_clf_digits(clf, ns, random_state=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T17:54:56.738210\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAn+ElEQVR4nO3deXwU9f3H8dd3Nyc5SUIIdwKEI1zhUIsHIqDifVa0rVpbe9i7/qxa7a/2slprD61YtVWr/jxrq4JWK5c3oiABEo4cnAGSEAK5r939/v5IiAkECJDNbHbfz8djH7s7Ozv7+e7szHtn9rszxlqLiIhIoHE5XYCIiEhnFFAiIhKQFFAiIhKQFFAiIhKQFFAiIhKQwvwx0blz59q33nrLH5MWEZHgYzob6JctqPLycn9MVkREQoh28YmISEBSQImISEBSQImISEDySyeJzjQ3N1NcXExDQ0NPvWTAioqKYvDgwYSHhztdiohIwOqxgCouLiYuLo709HSM6bTDRkiw1rJ3716Ki4vJyMhwuhwRkYDVY7v4GhoaSE5ODulwAjDGkJycrC1JEZGj6NHfoEI9nA7Q+yAicnTqJCEiIgGpx36DCgRut5sJEybg8XjIyMjgmWeeITExka1btzJ27FhGjx7dNu4nn3xCRESEg9WKiASu8upGkmMj/LpHKKS2oKKjo8nJySE3N5ekpCTmz5/f9tiIESPIyclpuyicREQOVVnXxANvbeSyP73LiqK9fn2tkNqCam/69OmsXbvW6TJERHqFhmYv/1yxnaff30xNo4cLsgcxvF+sX1/TkYDa//Nf0Lw+r1unGZ41jsRf/aJL43q9XpYsWcLXv/71tmFFRUVkZ2cDcNppp3XYuhIRCVVen+Wttbt4bGkhpZUNnJqZwnfOHsXI/nF+f+2Q2oKqr68nOzubrVu3MnXqVM4+++y2xw7s4hMRkZb/bH5cWM5Di/IpKq1h7MB4fn7ZBKZmJPVYDY4EVFe3dLrbgd+gKisrufDCC5k/fz4/+MEPHKlFRCRQbdxVyUNv57NySwWD+kbz6y9OZHZWGi5Xz/5FJqS2oA5ISEjgwQcf5JJLLuGmm25yuhwRkYCwa18djywp4O11JST2CefH543h8mlDCA9zpj9dSAYUwOTJk5k0aRIvvPACZ5xxhtPliIg4prKuiSff3czLn27H7TJcf8Zwrj09ndgoZ48XGlIBVVNT0+H+woUL227n5ub2dDkiIo5qaPby0sfbeOr9LdQ3ebhg8iC+cdZIUuOjnC4NCLGAEhGRlp55b67ZxaNLC9hT1cjpo/vxnTmjGJ7q327jx0oBJSISIqy1LC8oZ/6ifIrKasgalMCvrpjI5PSe65l3LBRQIiIhYP3OSh56exOfbd3H4KQ+/PaqSZyV1T+gD16tgBIRCWLFFS098xbnltA3JoJbzh/LpdMGE+YO/CPdKaBERILQvtomnny3iH+v3EGYy8XXzhzOl0/NICaq96z2e0+lIiJyVA1NXp5fvpVnPtxCY7OPiyYP4sazRpISF+l0accspALq7rvv5rnnnsPtduNyuXj00Ue57bbb2L17N1FRUcTGxvLEE08wevRompqauPXWW1m4cCEul4usrCzmz5/P4MGDAZ26Q0QCi8fr442cXfx9WSF7qhuZMSaVm+ZkkuHnA7r6U8gE1PLly3n99df57LPPiIyMpLy8nKamJgCeffZZpk2bxmOPPcZPfvITFixYwB133EF1dTX5+fm43W6efPJJLr/8clasWIExpu2wSQDXX3898+fP58477wR0XD8R6TnWWj7I38PDi/LZsqeW8YMT+PUXJ5E9rK/TpZ2wwP+VrJvs3r2blJQUIiNbNnNTUlIYOHBgh3FmzJhBYWEhdXV1PPnkk/zpT3/C7XYDcMMNNxAZGcnSpUsPmfb06dPZuXOn/xshItJObvF+bnryU37y3Gq8Pss987L5242nBEU4gUNbUH96cwP5JdXdOs1RaXH8+Lyxh338nHPO4Ve/+hWjRo1izpw5zJs3jzPPPLPDOAsXLmTChAkUFhYydOhQ4uPjOzw+bdo08vLymD17dtswnbpDRHra9r21PLK4gKXrS+kbE8FPLhjLJVN7R8+8YxEyu/hiY2NZtWoV77//PsuWLWPevHnce++9AHz5y18mOjqa9PR0/vKXv1BRUdHpfwOstW3DdeoOEelpFTWNPPFuEa+sLCYizMXXZ47gS6emExMZnKtyR1p1pC0df3K73cycOZOZM2cyYcIEnnrqKeDz36AOSEpKYtu2bVRXVxMX9/lJuT777DMuuugiQKfuEJGeU9/k4fmPtvF/H26h0ePj4imDuXHmCJJ7Yc+8YxFc24NHsGnTJgoKCtru5+TkMGzYsE7HjYmJ4frrr+fmm2/G6/UC8PTTT1NXV8esWbM6jHvg1B33338/zc3N/muAiIQcj9fHqyt3cOUD7/PYskJOGpHMc989jdsuygr6cIIQ2sVXU1PD97//ffbv309YWBgjR47kscce48orr+x0/HvuuYdbbrmFUaNG4XK5GDNmDK+88kqnu/506g4R6U7WWt7ftIf5i/LZVl7LxKGJ3DMvm4lDg6PzQ1cZa23XRjTGDawEdlprLzzSuNOmTbMrV67sMGzDhg2MHevMrr1ApPdDRDqzbsd+/vL2JtZu38+wlBi+MyeTGWNSA/qYed2g08YdyxbUD4ENQPzRRhQRkWOzvbyWhxfn886GMpJjI7j9oiwunDwo6HrmHYsuBZQxZjBwAXA3cLNfKxIRCSF7axp5/J0iXltVTGSYi2+eNZJrTh1GdETI/AJzWF19B/4M3ArEHWW8I2rfTTuUdXW3qogEr7pGD899tJVnP9pKk8fHZdMG87UzR5AUG/ydH7rqqAFljLkQKLPWrjLGzDzeF4qKimLv3r0kJyeHdEhZa9m7dy9RUYFxSmUR6Vker48Fn+3k7+8UUlHTxKys/nx7TiZDk2OcLi3gdGUL6jTgYmPM+UAUEG+M+T9r7VeO5YUGDx5McXExe/bsOZ46g0pUVFTbQWdFJDRYa3l3YxkPL8pn+946sof15b6rJzN+SKLTpQWsLvfiA2jdgrrleHrxiYiEqjXb9/HQ2/ms27Gf9H4xfPfsUZw+ql9I7006yAn34hMRkWOwdU8Nf11cwLsby0iJi+SnF4/jguyBId0z71gcU0BZa98B3vFLJSIiQaK8upHH3ylkwWc7iQx38a1ZI7l6unrmHSu9WyIi3aS20cOzH27huY+20ez1cflJQ/jamSPoG6MTlh4PBZSIyAlqOWZeMY+/W8S+2iZmj0vj27NHMkQ9806IAkpE5DhZa1m2vpSHFxdQXFHH5PS+3P+lyYwbnOh0aUFBASUichxWb63goUX55BVXMjw1lj98eQqnZqaoZ143UkCJiByDLWU1zF+czweb9tAvPpI7LxnH+dmDcLsUTN1NASUi0gV7qhr427JCXl+9k+iIMG6ancm8LwwjKsLtdGlBSwElInIEtQ0envlwC88v34rXZ7nqlGF8dcZwEtUzz+8UUCIinWj2+Hhl5Q6eeLeI/XXNnDMhjW/NymRQUh+nSwsZCigRkXastSzJK+WRJfkUV9QzLSOJ750zijEDE5wuLeQooEREWq3aUsH8RZtYv7OKEf1j+dNXpvCFkeqZ5xQFlIiEvKLSah5eXMCH+XtIjY/ify8bz9yJA9Uzz2EKKBEJWWWVLT3z3sjZSZ/IML579ii+eMpQosLVMy8QKKBEJOTUNDTzzAdbeGH5NnzWMu8LLT3zEvqoZ14gUUCJSMho8vh45dMdPPFeEZV1zZw7cQDfmpXJwL7RTpcmnVBAiUjQ8/ksS/JK+OuSAnbtq+ek4cl89+xRjBkY73RpcgQKKBEJais37+WhRfls3FVFZlocD1w7lVNGpjhdlnSBAkpEglJhaTXzF+WzvKCctIQo7rp8AudOGIBLPfN6DQWUiASV0sp6Hl1ayJtrdhEbGcb3zxnFlScPJVI983odBZSIBIXq+maeen8zL63YDsCXpqdz3RkZ6pnXiymgRKRXa/L4ePmT7fzjvSKqGzzMnTiQb84ayYBE9czr7RRQItIr+XyWt3N388iSAkr2N/CFkS098zLT1DMvWCigRKTX+aRoLw8t2kT+7mpGDYjjjovHc/KIZKfLkm6mgBKRXiN/dxXzF+WzomgvAxKj+eUVEzh7vHrmBSsFlIgEvN3763lsaSFvrd1FXFQ4Pzx3NFecPJSIMJfTpYkfKaBEJGBV1Tfz1Hub+ecn2zHAV07L4LrTM4iLDne6NOkBCigRCTiNzd7WnnmbqWn0cP6klp55/RPUMy+UKKBEJGD4fJb/rtvNo0sKKKlsYHpmCt+ZM4rMtDinSxMHKKBEJCCsKCznoUX5FJRUM2ZgPD+7dDzThqtnXihTQImIozbtruKht/P5dPNeBvaN5tdXTmT2uDT1zBMFlIg4Y9e+eh5dWsB/1+4moU84P547hstOGqKeedJGASUiPaqyrol/vLeZlz/ZjssYrj8jg2tPzyA2Sj3zpCMFlIj0iIZmL/9csZ2n3t9MXaOHC7IH8Y2zRpKaEOV0aRKgFFAi4lden+XNNbt4bGkhZVUNnDaqH9+Zk8mI/uqZJ0emgBIRv7DWsrywnPmL8ikqrSFrUDy/uGICU9KTnC5NegkFlIh0u427Knno7XxWbqlgcFI0v/niJGaP648x6pknXaeAEpFus7OijkeWFLAot4TEPuH8z/ljuHTqEMLVM0+OgwJKRE5YZV0TT7xbxL8+3UGYy8UNM4bzldMyiInSKkaOnz49InJC6ps8fO+plRSVVnPRlMHcOHME/eLVM09OnAJKRI6btZZ7FuRRWFrNH748hVMz+zldkgQR7RgWkeP2/PJtvL2uhG/PylQ4SbdTQInIcVm5eS/zF+VzVlZ/rjsjw+lyJAgpoETkmO3eX8/P/rmGocl9+Nml49V9XPxCASUix6Sh2cvtL+TQ7LX87urJxETqp2zxj6MGlDEmyhjziTFmjTEmzxjzy54oTEQCj7WW+xauZ9PuKn55xQSGpsQ4XZIEsa589WkEZllra4wx4cAHxpg3rbUf+7k2EQkwL3+ynf+s2cWNM0dw+uhUp8uRIHfUgLLWWqCm9W5468X6sygRCTw52/bx57c2cfrofnztzBFOlyMhoEu/QRlj3MaYHKAMWGStXeHXqkQkoJRVNvDTF3MY1DeaX1w+QWe7lR7RpYCy1nqttdnAYOBkY8x4v1YlIgGjyePj9hdzaGz28rtrJuvEgtJjjqkXn7V2P/AOMNcfxYhI4Ln/jfWs31nJ/142gYx+sU6XIyGkK734+hljEltvRwNzgI1+rktEAsCrK3ew4LOdfHXGcM7K6u90ORJiutKLbwDwlDHGTUugvWStfd2/ZYmI09bt2M/9/9nA9MwUvnHWSKfLkRDUlV58a4HJPVCLiASI8upGfvpiDmkJUfzyiom41SlCHKAjSYhIB80eH3e8lENNg4d7r55MfLQ6RYgzFFAi0sGf39rI2u37+dml4xnZP87pciSEKaBEpM3rq3fyr0938OXT0pkzPs3pciTEKaBEBID1Oyu57/X1nDQ8mZtmZzpdjogCSkSgoqaR21/IITk2kt98cSJhbq0axHn6FIqEOI/Xx53/XENlXRP3Xp1NQp8Ip0sSARRQIiHvobfzWb11H7dfPI7RA+KdLkekjQJKJIS9uWYXL3y8jau/MIzzJg10uhyRDhRQIiFq0+4q7l2Qx+T0vnzvnFFOlyNyCAWUSAjaX9vEbS+sJqFPBHd/cZI6RUhA0qdSJMR4vD5+9vIaKmpaOkUkxUY6XZJIpxRQIiHmr0sKWLm5glsvzCJrUILT5YgclgJKJIQsyt3Nsx9u5YqThnDh5EFOlyNyRAookRBRWFrN3a/mMXFoIj+aO8bpckSOSgElEgIq65q47fnVxEaFcc9V2YSHadGXwKdPqUiQ8/osd/1rHaVVDdwzL5vkOHWKkN5BASUS5P62rJCPC8u55fyxTBiS6HQ5Il2mgBIJYsvWl/KP9zZz8ZRBXDptiNPliBwTBZRIkNpSVsOvX1nHuMEJ3HJBltPliBwzBZRIEKppaOa2F1YTFeHmnnnZRKhThPRC+tSKBBmfz/KLf69j5756fntVNqnxUU6XJHJcFFAiQeaJd4v4YNMefjR3NNnD+jpdjshxU0CJBJEPNpXx93eKOD97IFeePNTpckROiAJKJEhsL6/lrn+tY/SAeG69MAtjjNMliZwQBZRIEKht9HDrC6sJdxvuvTqbqHC30yWJnLAwpwsQkRNjreXXr6xjx946HrxuKgMSo50uSaRbaAtKpJd7+v0tvLOhjO+ePYqpGclOlyPSbRRQIr3Y8oI9PLK0gHMmpHHN9GFOlyPSrRRQIr1UcUUdP395LSP7x3HHxePVKUKCjgJKpBeqb/Jw2wurMQbunZdNVIQ6RUjwUUCJ9DLWWu5+La/lWHtXTmJQUh+nSxLxCwWUSC/z3EdbWZxbwrdnZ3LKyBSnyxHxGwWUSC/ySdFe5i/KZ1ZWf649PcPpckT8SgEl0kvs2lfP/768hvR+sfzsUnWKkOCngBLpBRqavdz+4mq8Psu9V2fTJ1L/sZfgp4ASCXDWWu5dkEdBSTW/vGIiQ5NjnC5JpEcooEQC3EsrtvPW2t18Y+ZIThvVz+lyRHqMAkokgH22tYIH/7uJGWNS+eqM4U6XI9KjFFAiAaq0sp47X1rD4KQ+3HXZBFwudYqQ0KKAEglAjc1efvpiDo0eL7+7OpuYKHWKkNCjgBIJMNZa7n9jA+t3VnHXZRNI7xfrdEkijlBAiQSYV1YWs3D1Tm6YMZwzx/Z3uhwRxyigRALImu37+OObGzg1M4UbzxrpdDkijlJAiQSIPVUN3PFiDmkJ0fzyiom41SlCQtxRA8oYM8QYs8wYs8EYk2eM+WFPFCYSSpo9Pu54aQ11TS2dIuKiw50uScRxXeka5AH+x1r7mTEmDlhljFlkrV3v59pEQsYf39zIuh37ufuqSYzoH+d0OSIB4ahbUNba3dbaz1pvVwMbgEH+LkwkVCxYVcwrK3dw7ekZzB6X5nQ5IgHjmH6DMsakA5OBFX6pRiTE5BXv5/dvrOfkEcl8e3am0+WIBJQuB5QxJhb4F/Aja22V/0oSCQ17axq5/cUcUuKi+PWV6hQhcrAuBZQxJpyWcHrWWvtv/5YkEvw8Xh93vrSGqvpmfnd1Ngl9IpwuSSTgdKUXnwEeBzZYa//o/5JEgt8D/91EzrZ93HHxOEYNiHe6HJGA1JUtqNOAa4FZxpic1sv5fq5LJGj9J2cn/1yxnaunD+PciQOdLkckYB21m7m19gNAO8dFusHGXZX8buF6pqT35Xtnj3K6HJGApiNJiPSQfbVN3P5CDokxEdx9VTZhbi1+IkeiJUSkB3i8Pv73n2uoqG3i3nnZ9I1RpwiRo1FAifSAhxcXsHJLBbdemMXYQQlOlyPSKyigRPzs7XW7ee6jrVx58hAunKyDsIh0lQJKxI8KSqq4+7VcJg1N5IfnjnG6HJFeRQEl4ieVdU3c9kIO8VHh/PaqbMLDtLiJHIuuHM1cRI6R12f5+ctrKatq4JEbTiY5LtLpkkR6HX2lE/GDx5YWsKJoL7ecP5bxQxKdLkekV1JAiXSzpXklPPX+Fi6dOphLpw1xuhyRXksBJdKNNpfV8OtXcxk/OIGbzx/rdDkivZp+gxI5QR6vj9Xb9rE4t4Rl60uIjnDz23nZRKhThMgJUUCJHAevz7Jm+4FQKmVfbRPREW7OGN2Pa08fTmp8lNMlivR6CiiRLvL5LOuK97eFUnl1I5HhLk4f1Y/Z49M4dWQ/oiLcTpcpEjQUUCJHYK0lr7iSxXklLM0rpayqgYgwF9MzU5gzPo3TR/UjOkKLkYg/aMkSOYi1lo27qlicV8KSvBJK9jcQ7jacMjKF75ydyRmjU4mJ1KIj4m9aykRoCaWCkmoW57aE0s599bhdhlNGJPONs0YyY3QqcdHhTpcpElIUUBLSiko/D6Xte+twuwzTMpK4/ozhnDk2lYQ+Oi2GiFMUUBJytu6paQulLXtqcRmYkp7El05N58yx/XWuJpEAoYCSkLBjb21bKBWW1mAMZA/tyy0XDOWssf11rDyRAKSAkqC1a18di3NLWJxXQv7uagAmDEnkx+eNYVZWf/rpv0oiAU0BJUGlZH89S/JKWZK3m/U7qwDIGpTAD84dzexx/emfEO1whSLSVQoo6fXKqhpYllfK4rwS1u3YD8CYgfF89+xRzB6XxsC+CiWR3kgBJb3S3upGlq4vYUleKWu278NayEyL49uzM5k9rj9DkmOcLlFETpACSnqNfbVNLFtfypK8ElZvrcBnYXhqLN+YOZLZ49MYlqJQEgkmCigJaJV1TbyzoYwleSWs2lKB12cZlhLDDTNGMHt8GsNTY50uUUT8RAElAae6vpl3N5axOLeETzfvxeuzDE6K5iunZTBnfBoj+8dijHG6TBHxMwWUBITaBg/vbWoJpRVF5Xi8lgGJ0VwzfRhzxg9g9IA4hZJIiFFAiWPqGj18mL+HxbklLC8sp8njIzU+ii+ePJQ549PIGpSgUBIJYQoo6VENTV4+KmgJpQ8L9tDY7CMlLpJLpw5mzvg0xg9OxOVSKImIAkp6QGOzl+WF5SzJLeGD/D3UN3npGxPBhdmDmDM+jUlD+yqUROQQCijxiyaPjxVFLaH03qYy6hq9JPYJZ+7EAcwel8bk9CTcCiUROQIFlHQray1vrtnFg//dxP66ZuKjw5g9Lo0549OYmp5EmNvldIki0ksooKTblFbWc+/C9SwvKGfi0ER+PmM4J2UkEx6mUBKRY6eAkhNmreW1VcU8+PYmfD64+bwxXHnyUP2uJCInRAElJ2RnRR33LMhj5ZYKpmUk8dOLxzEoqY/TZYlIEFBAyXHx+Swvf7KdhxcX4HLB7RdlccnUwfrfkoh0GwWUHLPt5bX85rVc1m7fz/TMFG6/KEvnWRKRbqeAki7zeH08v3wbf1tWSFS4i7sun8DciQO01SQifqGAki4pLK3m7ldz2bCrijPHpnLrBVkkx0U6XZaIBDEFlBxRs8fH0x9s5sn3NhMXFc7dV01iVlZ/bTWJiN8poOSwNu6q5Dev5lJYWsM5EwZw83ljSIyJcLosEQkRCig5RGOzlyfeLeL/PtxK35gI7rtmMjPGpDpdloiEGAWUdLBux35+82ou28pruWjyIH5w7mjiosOdLktEQtBRA8oY8wRwIVBmrR3v/5LECQ1NXh5ZWsCLH2+jf3wUD1w7lVNGpjhdloiEsK5sQf0DeAh42r+liFNWbangngW5FFfUc8VJQ/jO2aOIidTGtYg466hrIWvte8aY9B6oRXpYbaOH+Yvy+fenOxicFM3DN5zElPQkp8sSEQH0G1TI+riwnHsX5FFa1cA104fxrVmZREW4nS5LRKSNAirEVNU38+B/N/H66p2k94vhb18/hfFDEp0uS0TkEAqoEPLexjLue309+2qbuP6M4XztzOFEhmurSUQCkwIqBOyrbeJPb27g7XUlZKbFcf+XpjBmYLzTZYmIHFFXupk/D8wEUowxxcBd1trH/V2YnDhrLYvzSvjDGxuoafTwzVkjue70DJ12XUR6ha704rumJwqR7lVe3cjv31jPuxvKyBoUz52XjGdE/zinyxIR6TLt4gsy1lreXLOLP7+1kcZmH98/ZxTzvjBMW00i0usooIJIaWU99y5cz/KCciYNTeTOS8YzNCXG6bJERI6LAioIWGt5bVUxD769CWvhf84fwxUnDcXl0ikxRKT3UkD1cjsr6vjtgjxWbalg2vAk7rh4HAP79nG6LBGRE6aA6qW8PsvLK7bz1yUFuF2Gn148jounDNKJBEUkaCigeqGte2q4+7U81u3Yz6mZKdx+0ThSE6KcLktEpFspoHoRj9fHcx9t5e/vFBEV7uKuyycwd+IAbTWJSFBSQPUShaXV/ObVXDbuqmLm2FR+ckEWyXGRTpclIuI3CqgA1+zx8dT7m/nH+5uJiwrnt1dNYta4NKfLEhHxOwVUANu4q5LfvJpLYWkN504cwI/njiExJsLpskREeoQCKgA1Nnt5/J0inv1oK0kxEfz+S5M5Y3Sq02WJiPQoBVSAWbt9H3e/lse28loumjKIH5wzmrjocKfLEhHpcQqoAFHf5OHRJYW8uGIb/ROieOC6qZwyIsXpskREHKOACgCrtuzlt6/lsXNfPVeePISb5owiJlKzRkRCm9aCDqpt8PDQonxeWbmDwUl9+OsNJzE5PcnpskREAoICyiHLC/Zw78L17Klq4EunpvPNs0YSFaHTr4uIHKCA6mFV9c088NZG3sjZRUa/GB678RTGD050uiwRkYCjgOpB720s43cL89hf18wNM4Zzw5kjiAjTiQRFRDqjgOoB+2qb+ON/NrAot4TMtDj++JWpjB4Q73RZIiIBLSgCavveWh54axP5JVWkxkfRPz6K1PgoUhOi6J/w+f3kuEjcPXgSP4/Xx7L1pfzhPxuoafTwzVkjue70DJ1+XUSkC3p1QDU0e3nm/S08/cFmwsNcnDE6lYqaJorKaviooJyGZm+H8d0uQ0pcZEuIJUR9HmYJn99Piok45jPR1jd52FZey5Y9tWwrr2Xrnhq2lteyY28dXp8la1A8P7t0AsNTY7uz+SIiQa3XBtRHBXv44382UFxRzzkTBvCDc0eT0u7o3tZaqhs8lFY2UFbV0HZdVtlAaVUDG3dV8f7GMho9vg7TDXMb+sUdCKzPw6x/QjSp8ZE0NPtaAmhPLVvLW65LKhvanu92GQb1jSa9XywzRqcyakA8M8emaqtJROQYGWttt0902rRpduXKld0+XYDSynr+/NYmlq0vZVhKDD+5YCzThicf17SstVTWNVPaSYAduF1W1UCz99D3KDLcxbCUGNJTYknvF0N6vxgyUmIZnNSHcHV8EBE5Fp3utuo1W1Aer48XP97G398pwmctN83O5JpT00+oF5wxhsSYCBJjIg7bacHns+yra2JPa4iFh7lIT4klLSHqmHcFiohI1/WKgMrZto/7Xl/P5rIaTh/dj5vPG8vAvtE98toulyE5NpLk2EjGDEzokdcUEZEAD6iKmkYeWpTPf3J2kZYYxX3XTGbGGJ12QkQkFARkQHl9ltdWFfPXxfnUN3u5/owMvjpjONERAVmuiIj4QUCu8TeXVfP7N9YzJT2JWy4YS0Y/dc8WEQk1ARlQmWnx/P3GU8galIAx6oggIhKKAjKgAMbpAKoiIiFNf9gREZGApIASEZGApIASEZGApIASEZGApIASEZGApIASEZGAFLDdzEVE5PhZa+Hgi8/Xdtu2jNRhWPtLp89vG79l+u6kvpho/x0XVQEl0g1s+wXd5wOf7/NhXm/LfZ8F62t7vGXYwfdbx2l9Dj6L7fAci/V6203HtpuOt+N925XnHFSLbfd423PsEert2Oa2dltL6xqw05WgtbTePnTl2PIetF44eMVo21aObY8fsoI98HgnK17fQSvfDvXxeQ1tz/F1mO7nNfo61mbpuPLvUF/nK3prfYdM90Attq02Dhsi9gjTxg+nUepM0pOPE33OOX6bfkAGlHfXbsqvuw4wLUeS6HDh89u03qblvnG7IcyNcYdBmBvcYZj212FhreOEgct96DC3G9P+utPHXGBcGJcLXC5wmZbrToYZl7ulvvYf6LYFuv0C2f4x2sb5fPihC+khjx0Y3rZQ0e612j+Hjh/4DiuUg75hHbxQdHgOHRfgttewnQ5vWdEeNPzAN7HDvDdtK8z2K5C2ettP5wjv4yHjH/x++Q4d1tl72+H9tYe8tyGp7fPe8pk3HFgWDlpmXS3L6CHLssv1+fJ84PFDnt/6uHF1si5onbbpZPqug8bDtCyfna1HjKtlvdFufGPoWMuBdUxn0+5QHwfV52p7btt0O2vngekfZtoHpmsOnvZB68LO38OWx02n86Z12u3bd8g02j1+0DTCs7L8+hELyIAizE3YsGGH+VbU7lsPdBzH68P6vNjGRmytB7we8HixB197PC3fUD2elpWRx9P2GB6Po03vUQd/GFs/rB1WNG3XnSzkbSsYV8cFo93wDh/qA8NdLow5dFiHBf1A0LdfgbgOTK/96x8Y/+AFq93wdsMOeX6H8Tm03YfUf2ClQ8fhB25Dy5ea9l9WjKvdsNb7HVbuh/nC43Yf4QuQ++jPOfA+H7jvdrdOxxzmtV3tXqvdc1zt50m754j4WUAGlDs1leTH/+7Ia7ftlvF4WnZxeDxYj7cl7Jo92EN20bR+O2+/C6Ztd0vrt/ODVp5tK+eDV4htAdBueIfHOguUw3yrOvixg6anYxyKSKALyIBykjGm5Vuj2935OYhFRKRHaDtdREQCUpcCyhgz1xizyRhTaIy53d9FiYiIHDWgjDFuYD5wHpAFXGOM8W/XDRERCXld2YI6GSi01m621jYBLwCX+LcsEREJdV0JqEHAjnb3i1uHiYiI+E1XAqqzzmwh+s9EERHpKV0JqGJgSLv7g4Fd/ilHRESkRVcC6lMg0xiTYYyJAK4GFvi3LBERCXVH/aOutdZjjPke8F/ADTxhrc3ze2UiIhLSjPXDgS6NMXuAbd0+YRERCUbl1tq5Bw/0S0CJiIicKB3qSEREApICSkREApICSkREApICSkREApICSkREApICSkREAlKPBFQonU/KGLPVGLPOGJNjjFnZOizJGLPIGFPQet3X6TpPlDHmCWNMmTEmt92ww7bTGPPT1vm/yRhzrjNVn7jDtPsXxpidrfM8xxhzfrvHgqXdQ4wxy4wxG4wxecaYH7YOD4V5fri2B/V8N8ZEGWM+McasaW33L1uH99w8t9b69ULL0SeKgOFABLAGyPL36zp1AbYCKQcNuw+4vfX27cDvnK6zG9o5A5gC5B6tnbScR2wNEAlktH4e3E63oRvb/Qvglk7GDaZ2DwCmtN6OA/Jb2xcK8/xwbQ/q+U7LgcJjW2+HAyuAL/TkPO+JLSidT6qlvU+13n4KuNS5UrqHtfY9oOKgwYdr5yXAC9baRmvtFqCQls9Fr3OYdh9OMLV7t7X2s9bb1cAGWk67Ewrz/HBtP5ygaLttUdN6N7z1YunBed4TARVq55OywNvGmFXGmG+2Dutvrd0NLR92INWx6vzrcO0Mhc/A94wxa1t3AR7Y5RGU7TbGpAOTaflGHVLz/KC2Q5DPd2OM2xiTA5QBi6y1PTrPeyKgQu18UqdZa6cA5wHfNcbMcLqgABDsn4G/AiOAbGA38IfW4UHXbmNMLPAv4EfW2qojjdrJsGBre9DPd2ut11qbTctplk42xow/wujd3u6eCKiQOp+UtXZX63UZ8Aotm7ilxpgBAK3XZc5V6FeHa2dQfwastaWtC7IP+Buf79YIqnYbY8JpWUE/a639d+vgkJjnnbU9VOY7gLV2P/AOMJcenOc9EVAhcz4pY0yMMSbuwG3gHCCXlvZe3zra9cBrzlTod4dr5wLgamNMpDEmA8gEPnGgPr84sLC2uoyWeQ5B1G5jjAEeBzZYa//Y7qGgn+eHa3uwz3djTD9jTGLr7WhgDrCRnpznPdQb5Hxaer4UAXc63TvFj+0cTksvljVA3oG2AsnAEqCg9TrJ6Vq7oa3P07Jbo5mWb05fP1I7gTtb5/8m4Dyn6+/mdj8DrAPWti6kA4Kw3afTsrtmLZDTejk/ROb54doe1PMdmAisbm1fLvDz1uE9Ns91ug0REQlIOpKEiIgEJAWUiIgEJAWUiIgEJAWUiIgEJAWUiIgEJAWUiIgEJAWUiIgEpP8HBdlb2srwwHoAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "for runtime, clf in zip(runtimes, clfs):\n", + " ax.plot(ns, runtime, label=rename_clf(clf))\n", + "ax.legend()\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MNIST Dataset\n", + "These are 28x28 images of handwritten digits from `keras.datasets`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(100, 784) (100,)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import fetch_openml\n", + "\n", + "images, labels = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)\n", + "\n", + "n = 100\n", + "threes = np.where(labels == '3')[0][:(n // 2)]\n", + "fives = np.where(labels == '5')[0][:(n // 2)]\n", + "idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "X = images[idx]\n", + "y = labels[idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(idx))\n", + "X = X[permuted_idx]\n", + "y = y[permuted_idx]\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T17:55:20.302682\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAEjCAYAAAD5QHrmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAvyUlEQVR4nO3deZwcd33m8efbPfeh2zpsSbZs2ZJt2VjG4YyDOQJZAiwm7EIgQLLJHgkJyWYTkrDZhGyOzbLZQDYENuc6CZiQEK5lw0Js4yXG+L5APnRYh2XdI819dvdv/6ge3LRn9Pwkj0fX5/16zUuammeqqrurflPfrur6RkpJAAAAAHCuK53qFQAAAACA0wHFEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAADOOhFxU0SkiGg5BcveFRE3zfdyny/15/GDTdNWRcRnIuJI488jojci/iwi9ten33QKVhnPAcXRKRYRN9R3nsmIWHaq1+d0FRE/FhG31gebiYh4OiL+T0R876leN+BMwXiTJyI+WH+eZvpizMEpdzbvyxGxpL4P3vA8zPuipv25EhF9EXF/RPxhRLzgBGb3+5LeUP/3XZI+W5/+AUk/Lul/1af/8Vw+Bjz/5v3dBDzLuyXtlbRC0g9L+sNTuzqnrWsl7ZP0FUlHJJ0n6UckfT0i3ppS+uzxfhmAJMabE/ULkg42TXviVKwI0ORs3peXSPr1+v9vf56W8QVJn5EUkhZK2qTieXxvRPxOSulXm/KdkipN014p6ZaU0u/MMH1LSukDc7/amA8UR6dQRHRJequkj0jarGKwOyUDXER0pZRGT8Wyc6SUfqZ5WkR8VNIOSb+oZ96xATADxpuT8oWU0vZTvRJAo9NpXz6DfTul9InGCRHxi5L+RtJ/jIgdKaX/Nf2zlNL4DPNYLql/lun75nBdFRFtkqoppepczhcz47K6U+tGSb2SPln/ui4irpj+YUR8uH7Kd2XzL0bEu+unhF/bMK07In4nInbULz3bFxEfi4jFTb97e0TsjYgNEfEPETEo6R/qP7sqIv48IrZFxGhE9EfE/42I75npAUTEL9SvLR6vn5b+gfp1zrtmyL4iIr4aEQMRMRYRd0XEG07uqZNSSiOS+iQtOtl5AOcQxpuTGG8iYkFElE/094Dn0XH35RmsjIi/q+8LAxHxyYhY3hiIiOUR8ccRsbu+Px+MiNsi4lVNuY0R8fcRcbS+Xz0QEe9yKxzPXAZ4www/+87nk+o/31b/0a83XP52U0M+a+w5USmlIRVnj/rry46GZTZ+puiDEZFUnHV6Z8M6/mh9+jpJL2+YfkPDfG6MiG9ExHD969aIeFnT8/Gj9d97fUT8l4jYK2lc0pr6z8+LiD+KiKeiuKxyVz3XPsPzekd9nL2tPsYeiIjfanxsDfnvrY/R06/tExHx4aZMS0T8UkQ8Wh+Hj0TEJyJi9Uk/8achiqNT692S7kspPaHiFO9Qfdq0T0oqS3rbDL/7DhWXe9wqSfWd4lZJPyPp8/V/b5b0Y5Jubd5pJHVJukXFuxu/IOmv69NfJ+kFkj4t6Wcl/Z6kDZL+X0RsbJxBRPyapP8maZeKszdfVfGuy7XNKxsRP1Rfv05JvyHplyQlSV+MiJke34wiYnF9YLgyIn5f0hUqLrUDcHyMNyc43kh6QNKApLF6ofWsZQGngNuXm31JUoeKz8J8QtLbJX01irMR0/5ORWFws6SfUrEvHlXD/hUR6yV9U9KrJX1M0i9LmpD0VxHxC3PyyKTHVIwRkvQ5FZ/Z+c7ndk5i7DkhKaVBFVeiXChp4yyxz9bXSZLubFjHe+v/HpG0tWH6Y/V1/7n67x6V9CsqLh1cJelrEfHyGZbzIRXP9e+pGMOGI2KppLsk/QtJfynpp1UcA/2ipL+fYR6r6j9/WNLP19fxP0r6V42h+rh4u6QrJf2RpJ9T8fy+uSETKraT/yzpn1SM2R+T9HpJd9bX7eyQUuLrFHxJOl9SVdLPNUz7SxXXEJcapj0h6e6m3z1P0pSkjzRMe7+kSUnXNmXfpOKg4Ccapt1en/bLM6xX9wzTlkk6JOl/Nk2bkPR1SeWG6a+uz3tXw7QuFYPFZ5rmW1axo+6RFJnP2676/JOkMRUDZuepfj354ut0/mK8ObHxRsWBwcdVHHC+WcVBzICkUUnXnerXk69z9yt3X65Pv6m+f3y6afrP1Kf/ZP37hfXvf9Es+28l1Rr3AUltku6u/z1e1jB9l6SbGr6/ob6MG2aYb3N2fT37wRmy2WPPLI/honrut46T+ff1zJsapj1rferTPjHL47mjadrq+nr/XtP0nvqYdEfDtB+tz/thSW1N+Y+pKK7WNE1/X/13XtO0HknSm5uyD0u6p+H7XknHVBR0i5qyjX8f/mV9fm9oymxW8XmsWZ/TM+2LM0enzo+o2Mj+pmHaJyVdoOIPfuO0F0XEJQ3T3qbi82KfbJj2dkn3SdoTEcumv1S8yzPWNM9pH2+ekIpL1SQV1zU3vBNwj6TGS11eo2JQ/KPUcA1sSulWSVuaZvsaSUslfaJp3RZL+rKKU8WXzbB+M3mnineb/3V9nbpUvCMGYHaMNycw3qSUPpJS+smU0l+llD6fUvoNSS9TUWD93vF+F3ie5e7LjT7S9P2fSBqW9Mb692MqDtxvmO3d/yguLX29pNtSSvdNT08pTdbn3yHptTP97hw7mbHnRA3V/+2dg3lNe4ukVkk3N613h4ozYS+N4rNkjf68/vxK+s6Zm7epOGs+1jSff6zHmh///pTS55um3S6pcYz/fhUfT/ivKaX+xmBKqdbw7dtVFOF3NS37KUlPzrDsMxY3ZDh13q3iXcyOiLioPm2Hiur9PXpmQ79ZxWUh75D0m/Vp75C0LaV0b8P8Nqq4hOTwLMtb3vT90ZTSQHMoIhZI+i0Vp2ybP3uws+H/0+u8Tc+2Vd99qcv0qenPzbJu0+tn7wKVUvpGw7repOL08hclXe9+FziHMd48e/1O6K5zKaUtEfElSW+OiM6U0tiJ/D4wR3L35Ubfta2nlCai+Jzeuvr3kxHxfhWF/4GIuFfFZbA3p5Qer//aeZK6Vb9ErMmj9X/XneRjOhEnOvacjOmiaOi4qRMzPS7df5zMUhVnp6ftbPr5eSru5Pc2zXz5s/Tsx797hsyx+nymXVr/95HjrJtUPIbVmv25P2s+m0lxdApExAtVXNcpPXvjl6QbI6I3pTSUUtoeEfeofrBSHwxfquIAplFJxTsnvzbLYo81fT/bH/ZPqaj+P6xnrrevqbg+tvGdhukP86UZ5tH8Qb/pM5Q/KWm2Oz99e5bps0opVSLi7yT9bkRcmlKa6cAJOKcx3szohMebut31+S/W7I8JeF6cyL7cNH2m/ea7Ayn9QUR8TsXlaa9WcWnZr0TEv0kNd22bZV7H2z9z1uFEDqpPdOw5GVfV/53LY4rpcelGFWftZtJcdDSPMdPz+JyKy+tm0nyXvBO5u53bTkoqzhD921l+ftaMiRRHp8Z7VFw//w4VBwKNVqnY6N+qooGYVJwy/4OI2Czpn9Wn3dz0e9tVXCt6y8muVEQsUnHa/DdSSh9s+tlvNsWnB+bLJD3Y9LNLm76fPkA5+lzWbxbTl9Q9p7vUAGcxxpu5s17FtfVH53i+QI4T3ZenbVRRUEj6zk0NLlLxofrvSCntkfRRSR+N4s5v35T02/X5HZY0ouImSM2mz4rsOs66TxctzXez7Kiv+3etynHm85zHnuOJiIUqCpjdkh438RMxPS7tSyndc5LzOCxpUFLHHD/+6SLwahWXLM5mu6Tvk3R7Sqm559NZhc8czbOIaFVx3eatKaXP1q9nb/z6uIoP5zXeeebTKqr/d6i4m8y9KaWtTbP+lKTLI+IdMyyzHBFLmqfPoKpiUPqu7SKK21C+uCl7i4prlN8bDbe5jYhX65l3tqZ9RcXA+IGI6Jxh/Y57Grx+68hnrX9E9Kj4DNKonv25A+Ccx3hz4uNNPTPTePMSPfOZi5l6ngDPm5Pcl6f9XNP3/0bFjQC+VJ93V/O+klI6pqLYWVz/vqriM3uvarxrY329flZF0fbV4zyEXSreWHhV0/T36dlnjqbPrMz0pudcjD0zioheFW8ELZL0n1P9bgNz5DMqHv8HI+JZJyZyxqX6a/C3kl4XEd87wzw66o/hRP2jituX/1K9OGycZ+OZ+U+puOTw/TMsO+qfPzorcOZo/r1exXWjXzhO5osqDgIuTCntTikdjIhbVJzK7NWzBzpJ+u/1eX8iil4e0+8SrZf0Q5J+VcWda2aVUhqKiFslvb8+UG5VcXr5x1QUH70N2SMR8bsqTm3fGhF/r+Ldn5+U9K2m7FBE/ISKg64tEfHXKj7At0rSS1S869R4CU2zHklPRcRn6utxVMW7Xu9Rcf3rTzd+sBvAdzDenPh4I0k765cYPariUr+rJf24ioO2nze/CzwfTnhfbpi+ISL+t4ri5go9s9/8ef3nl6m4nfRnVGzzwyrOELxO0l80zOcDKm54cmsUTdgPqyjYXqLiTndHZluxlNJgRHxS0k/VD7i/reKS3etV3F2yMXswIvZIentEbFXRz3BnSuluzcHYU7cpIn5ExSWBvSrGnn+h4rM4v51S+ovj/fKJSintioj/IOkPJN0fEX+r4q6cq1Xcya8m6ZUZs/plFa/NbfXPXT+o4gqaDfX1/yEVN1w4kXUbioifUnGb94cjYvruhxepeH2nx8ubVdy987ej6M10m4pL6dbVp98s6YMnsuzT1qm+Xd659qXiPvQ1SauOk3mNindUf7Vh2rvq0yqSVs7yex0q7l//bRUNw/pV3LLxQ5LWNuRul7R3lnksV9GD5JCKU+h3qrj++CY13C63ng0V7yDsri/vAUk/oOIdksdmmPeLVFwre0TFu0x7VAzmbzfPWZuk36/P/5iK2wofrP/u95/q15Qvvk7XL8abEx9v6r/7JyoOHvtVnLF6SsWlRRef6teUr3Pz62T2ZT1zK+/VKvrTDKi4LOtTklY0/N5SSf+jvi8PqiiOvqWi31DzraQvV9Gr51h9P3xQ0rtnWJddarg9d33aIhUH4AMqbnbwBRX9hGbKfp+KS7zG64/hpoafZY09szxHF+mZdiBJxRnsY/Xx5A8lXTPL7z2nW3k3/Ox1Ks6w9dfXfaeKOw++riHzo2q6LfcMz+OHVFwON1Ef4+5R0XJgiVsPFQVMmmH6K1WcRZpuW/C4pP/elClJeq+KG0uM1l/HR1VcjnnFqd5P5uor6g8WmDMR8bCkgyml+bitJ4BzGOMNAGAu8ZkjnLRZrud/jYpLUG6b/zUCcLZivAEAzAfOHOGkRcTbVVx//0UVl8VsUvFBz0OSrk5NzcQA4GQx3gAA5gM3ZMBzsUXFBzJ/StIyFdfQfkbSBzhQATDHGG8AAM87zhwBAAAAgPjMEQAAAABIMpfVRUR6dm8uAOeKlCrhU3OHMQc4t83nmMN4A5zbZhtvOHMEAAAAAKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIElqOdUrAAA4F4RPRGvmvObqfb3a3Mwm5c0nqTI3ywNglG2iFB15swo/3kTGmJTS3Oz/SVN5uTQ5J8s7F3HmCAAAAABEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSawM6biDabWdp1lc0s0kqfqS22mSXRZTOLW/M2j/ULks2Uff9HDU75Wn00o4faluFRH5L0ROkRm+kbedhmaOyI+RSZjQvLpW6baSl3PtfVkSR1tSy1mfbosZmcRoq5RmvHbGZ4cr/N5DxHPa0rstapLN/kdqzq13uyNmIzE1OHbIYmkXBK4ccRSVmNUnvazreZVeWNNtOS/LHJeWmRzaxoa7cZSXrZcr+frOwcs5mS/LHSgTF/bPaXe8dtRpLuH/sbm1nY6Z/vWvJNZ4cn9mbMx49bpwvOHAEAAACAKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJAkRUqzN6WKiCSV53F1cszN+pRLviFhb/sam7kgNmQt76o23yTwhhW+sdeL1+y2mQs3breZ3mufthlVfcMySRp9zD+2kSOLbGb3rgttZiCjQdq3+nxDSknaMewbMn5txDc2e2LkSxlLq2ZkTj8pVTLa986d03PMmRulUq/N5DRBbW3x85HyGpMu0HKbuUxrbebiHt+UcVl7zWYWtubtJ8MV/zz1TfjtaCxjcVN+tXV4Im+9R6u+YfSofMPJwdKQzRxKu/x8Jv34Vqn4prTS3DXDns8x53Qcb0IZzdczGq7mWNJ5pc1cmC7PmtcLe3zz+bdceMBmLlu7x2aWrvHNm7s3+2VVns5rAnvkW+ttZnKyzWYe2XmJX9aYbzr9SL8/DpKkbw9M2Mzt439lM9d3vMtm9pb22cz+qS02Mz7lXzdp7hpYzzbecOYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJOV0G5s/Oc3PInzTro6282xmTcsLbObC5JsoXrck7ym8dqlvpPeKF91rMz0/7R/b5Lp32sx4i2+Cm/ofsxlJar/W57oeeNAvr+qb8Q32L7CZlYuO2owkXXHUN4ut7VltM326xmaOjNyfs0o4i+U0eG0pd9vM2tbNWcu7KK2ymYt7fCPkS3unbGZB66jNLOscs5nzF+Q1HO3OmNdUxT+2tlbfSDAl35N0aNS/bpL09MAim9k74sfmQ+M+s2Vgpc080PGEzewf82O3JFWqea/duSyiw2bKJb8tLe241GYurvkG9Zd1+O3o4t68BscbF/bbzHWbH7aZjl/xjWlL5/+wzYwf+IbNtH/hCzYjSeUW3+B42w7fKPZPty60md3piM20atxmJGms5MflBR1+va9b5I+7X1JaZzO3HPWNgrek221GkiamfNPZ54IzRwAAAAAgiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACSdZk1g21qX20xnyxKbWZ7RjOoS+QZ5i9v907OswzcRlKQ1GY1J2xeO2EzLlr0+s2+XX6FazWf2+mZkkjS1P6OhbKtvftfW5Rs7tgx32czklG/+KEkrFgzYzBULfdOybcO+KefXS9ttplrz64MzV5Lf5xa1rbWZDXFB1vIu6PZNldd0+eaGPa2+CezFC30D0PMW+cyChYM2I0ntnb4J4rJXbPMzWtBpI2N3+zFn5JD/uyRJayv+b8qRw7459a7DvkF5e9mPXWOHfQPIaod//SXpwNhDNlOrDWXN62y1uNM3Zl0SvvH4hljjM4v8e98DGS/tsUk/jkhSR9mPJWOD/lih55O32czwni02s2+3fx7v2PUem5Gk+4/645fPj95hMxPVh2xmZGKXzUTkHeMo+ddkU9eNNtORsQlctdiP3YNTfkyaGn65X5ikx2p+O6lU+7LmNRPOHAEAAACAKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJB0ujWBLXfbTGdpoc10JN+0ry/5hqNtFb8+e0babEaSFhxaZTNHb/cN0qpf8/VsteYz41X/0i9o840WJWlxl29eu3r1Ppsplao209ntX7fu3mGbkaRjfb5xY0arXCUlm+lp9407B8ZoAntWS35rags/dnWU897Tymnc11H2225vi+8U2Z7RKHbx4n6/Pr1+LJGk8SE/Nu//yiabGRxYYDPt7RM209UzajOSFOG3gQj/mrS3+OaOHWW/rM6y30jWVy6zGUnan+7Nyp3Lukq+CebSqm8C3Nvux4B9/k+lBqf8dvTgSF7j3v2jfr3/dpdv8PnI/77UZh6buMVmVndeYjNJu21GknYP++Vd3fU2m1nX5o85+kp+vOkr+YbakvTk5N0205p8Q9meVn9sdmjcN8pd0OrHtms7ltuMJG0Zn8zKnSzOHAEAAACAKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkqSWU70CjSarvjv6VHncZgZL/TZzOA3bzHB1pc0c6V9oM5K0dbDdZiYyOhUf1qDN9KYumzm/rdNmVnYushlJ2rTId4fv6vTtujs7/Gu7YGm/zUyO+k7NkjQ00m0zxyb8LlKT79QcvA9xziuV2mwmqWYzbaXIWt7Sdt/VfEHrlM2sW9JnM8MTfp8bH/OZo0cX24wk7T7iu6hXk9/netp8N/qU/PM9NuXHbkk6PO7H3T0jPrN31D+2fWMVm3k6HbGZHbV7bUaSUvLj97luuOqf79E4z2b6Jv2+tLX0pM0M1A7YzETVH3NI0khls81sbr3AZj50hR8Dtw/8kM30Zoxt3zjcYzOS9KfDX7OZF3Qsy5qXc+2SnL8TK7Lm9dTIjTbzd8O32Mzv7vfLurh0rc0szzgOuk93+4VJqtWGsnIniyM2AAAAABDFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkk6zJrC15Jt2DU36blT9Nd/8rFbzjTsPl3yjtfaWRTYjSS3yTWBzGoXmNIlcVD7fZnorl9jMunLZZiRp4xLf2G7V2qdtplrxm2OU/OM/dmyRzUjSXfvW2MxDx5LNHCz5xz9Z8U2HcXbLGXNGa8dsZrji9wFJqvpNV5csOmozF63bZTNtnb6ZaveFvuHk7juvsRkpr3nrvYd8o8QdQ7556/4x30x3ezpoM5I0oH02k7MNTNb8eDIx5edTS745d0p+u0WeavLb7dOx1Wa2V/xrOzb5lM10tvm/ga9o/Wc2I0nX+961uvGKR2zm0n/+TZu54ZBvdD92cInNfOPTb7AZSbqo5zU2876rt9nMmnV7bGbxlf74tTLoG9hL0u3/59U2M/akzzyYcdz9+NQ/2cwjWWPSiM3MB84cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkHSaNYGdqviGhHl8074ctapvRlWtDmTNq71tpc20lXxjr1XljTazqXSBzbz2fN/87/XXPGAzkrTy1d/OyjnVI75R7tEt62zmif3+8UvS/UfbbGZ3zTd4HYycJrB52wnOVOET4YfbsYzmjn0ZjSQlaWmbf+9r4/rtNrPsX/pxeeKyzTZT+gffAPHgUd+4UZK29fvcI8d8g9f7JvfazJMTd9pMtTZuM5JUqw1l5XB2Gp7Y7UPJN3lO8h2e21p8E+T28gKb6c5sBn/FYj9OrP9B3+B17I3vsJm2px62mZve+WKbqWU0ypakh358i820/PQmm+m45FdsZmz8ab9Cv/lhn5F0ZNw3y815DhbUem0mMs61nC4NXnNw5ggAAAAARHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAICk06wJ7Fw1b507vrFjubwwa05dLUtt5kXxEpt55Qrfseu68/bZzEu+/+s2U37japuRpLErf91mWnZ+1Wduu8tm9uxZYzO3H1xsM5L0+NigzRwo7bSZwSn/fFeq/TmrhDNUKK9R4lwYVV7D0bayb6q87MWP++W9/D/YTPeCK20mLbzFZpb0+n1SklaO+obZL1zq/7ylPt8wujO9xmaekn8eJWkwo+nsVOVw1rxw5klpck7mE9Hhl5VxPFVLUzazc6o/Z5U0MuWbLk8d9E1nW/70Czaz464X2MwH9nzZZrbeuMxmJKnlZ67JyjmV+z9iMyMfG7OZv7j9LVnL+/ohn3k6+fFmX/hm4WOTGc1rzyCcOQIAAAAAURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACDptGsCO39K4ZsItrb4Bq9tZT8fSeou+Saw12f0I3vT5VtsZv2bfDPV6nWbbaZ21bv9CknqynieRte91mZaL9vhMy0Vm5mq2YgkaaDU7zNZTRsHMpbmm/fi7JbkN8zWUqfNdMk3gJSksYpvTDt12DdlLLf7gWl0xO+7/pFJazdty0hJS1f4xoVXDvTazJon19tMz/4lNlMavtxmJOnpdt+Y92jGfGgUe67zY0nONlKtjtrMSNfGrDXaMbTKZv760zfaTIT/W/n+3XfbzPvOe5PNlFv+yWYkqfyp/+dDJf+a1Kb8IfeD3/JNp+85knde487qbTYzPOGPcUqltoyl+abDZxLOHAEAAACAKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJB0DjeBVfi6sCWjIWNveWXW4tZXL7aZqRQ2094x4Rd2vm9aWOvwzWur+zIan0kaWbLBZiL8plZZudZmNr3Br9PLDlxgM5L0wE7fvDYytpNa8o30cHZL8s2JlaZspJr8fPrKx3JWSQfGfVPGHXdfbTMX/acP20zbcr9OI09njJUZDSAlqWdln830XnDIZl5YzmjcKD++dR1ebDOSdF//FTbzeLvfTo7Wxm2mVhvKWieceVKatJkI37gzJX88sXX0K1nrdPOR19nMJfJjwJdH/txmfnbVv7WZH1jjm5vu3Xu+zUjSwDF/rJAyjt/a2vzrdvGq/TbTtWOFzUjS4NBWm2lt8U2+azW/3mcbzhwBAAAAgCiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJKkllO9AqdKTvfwMd+oXN0tS7OW1xd+ed865uf1yfuvs5lXHPBdqNvKFZtpKQ/ajCQtO+/LNnP+Kx+xmdqLNttMabmv5y9e3GczkrR+31qbOVBZbzMHa77L+FTlcNY6YX6VotuHYm7eQ0q1MZupZGT69HTW8h45ttxmyo9eYTMb915oM92tvoN6T0Z3+K42vy9JUi2jG/2SXj9+LVnmx4rLV+21mZ1DPTYjSWs7OmxmeOIym5lo849teNxvS0n+7wDmTrm00GYi5uawrJpxjBPRPifLkqSdE3fbzIHWJTYTpU6beah/3Gbu7ffLGgk/H0m6ptMfB7zu/AGbuXTZQZs5NurHkgu6bESS1D68ymaqGccvKflxIjLKiTNpvOHMEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAElnaRPYUqnXZ+ao0Vrf2BNZuaGWAzZzIHzz1jsPt9rMHx/xmRb55m8rqr6BmCRt7vVNC39k3/k28+IVd/iFtZRtpLs9r7HbJb3JZrYe8es92LLfZipV3yAuJd8kE3Mrp+Hi+s4bbGZJbZHN3Df1DzaTMy6dl9bYjCSNparN3HnEZ2454pdVkx9zLuvwTa6vWpTXJPDyRf02U6lmjBU9IzbTUvbP0eruUZuRpG1DC/zykl/vrrJ/Lkdin83kNHfE3CmV2mzm5a1vspmQb4L8tbE/8/Mp+eOATR2vtxlJ6k6+wfG+sm9gXWubspm7ql+xmctaXm4zF5d9o1hJ2jbq9+9vPumb7n6oxS+vK6NZ9tEJ//pL0obW623m2+P+71KSf02S/Dh5JuHMEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEmnWRPYyFidpd0vsJlF8s1U9049bDMTU777YUp5DUfHJn0T0LHJp2wm5BsEJuU09vNNxPZEV8Z8pInwTes2H1luM9fc51//9k0TNrN0WZ/NSFL3zkv8vKLHZvaXltnMsPbaTBJNYOdbLfntaVXNv74r2nwzxQ79c5sZzdgGask3L5akY+EbnO6PnX6dqn5/Gq/0+2WVr7SZytErbEaSLujyz/fla3bbzIJlx2ym7wnfuHGkkventJzRu3Fxxrh7MCOT0+BYeZsS5shU5ajN7Ozw22176rSZ7+v8cZvpCH880duSt22P12o2czj5/bYUvqF0KXKakvr1uagn7/zAVS2+we1n+/zz9CfbFtvMf7p2l8289Dw/tkvStl3++KWl1G0zk5W8JtdnE84cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkHSaNYHtaDvfZtbWNtjMQmU00StttZmJlNNMNZfv/tdS9s0GldHYrFL1jQ0jo9FaW6tvWCZJ7cnPq7clo2lb8rV61fdS1dhIXvPaHBk9G9WtRTaT5B8/5l9OE+f9Jd8MemP7apt5w2rf4LVvwjdJ/NrBvGG7r+Ib3E6kYZup1Px8lrT7hsobapfazLpFNiJJunTpIZtZtna/zaSa38NT8pnJqm+mKeX1XC2HX141zU0TTMyvZd3X2MzRim8CWw4/Tnz0Kv93ubvdj3+ff/Iim5Gkp0b83++9ow/bzOSUP37Z0PVam7mq7TybWdmRd4y3fsGgzdzdt9RmJjMa5e7u9/MZnvKvrSSVM0acpKrNdLWvsZmxiZxG93N5TP384swRAAAAAIjiCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASadZE9hFrb7R1DWdvjHpotaMxlf932cz97T7RnuTlSGbkaRyyTdt62zxTWAHJ3yDuIg2m+lpv9BmNsZLbEaSXr/cN129ftM3baZtRb/NpMm8Zos5yuG3k2pGE7VjaZ/NpOQbgOL0tHX0yzazrPZOm3nzhSM2c8NVvkniS/evtBlJOjbaYzPfOHC9zewY9k1JL1/omxuu6/GPf3G7bwApSevW77SZ1u4xmxk97P+eTFR8w8X9Y37MzTVe840Sh5NvgqtEE9j5k9MuXFpXu8Jm3nuxf8/6q/u7beaBw76558+//89s5qonl9mMJB3esdZmLr33jTYzUvHP5Q2rfGPuzlbfBPrxo77hqiTdc8SPE7dNfMpm3rf8h22mo+z3//UL+21Gklr3+fUuhR/fpqp+7C6V/DZZrQ3YzOmCM0cAAAAAIIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSpZT4WEpmLaQ/f0X11l+/6vXmp77J+9eJ2m7nm2A/azEP94zYjScOasJma/GMb7lxnM8tr59nMixd32MyNF++2GUnadN1tNtN9w6ifUesSG6k+MmQz2/Zf4Jcl6ciE3y53lHbYTP/ok1nLw5kppUmb2VbaYjPfOPBCm1m99LDNXPXKu2xGklrX+q7mr/K7kyYPLrSZoaeX20xL25TNlNt9JtdYX8Z6H/WZnceW2szAVGSt0+BkspntZT/mjI4dsplayhhzMSci/PGEJC0I/3f32jVbbeY11/hjnJ/4wktt5slf+3c289/e9TmbkaS17/F/B3/tzQ/ZTBryx0G7vvI9NvPlxzbZzK/vzRtLN+ham/n4pW+2mcuWPGUzK5f22cynHrnaZiRpIGMMSMk/31MV/3dJKmdkzhycOQIAAAAAURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACBpvprAljqzcqPJNzar+h56umT5fptZktEo9nWtviHhyECvXyFJ7R2+CWy15mvV0eFum+np9Y3GllziM+0bB21GkipX+QZpU22++V3rXXfYzC2feYvN3HFwmc1I0jeP+ga+B6q+uWet5ptt4syV08R6ojZsM1894reTxds32MzbMsYuSVq+5lEfuuZiG2mr+HFw2dPbbSaN+GaDtf68ZpqTGQ1eBw74cWDPgVU28/igH3OPZTR3laRHp3wzxWNVPzZXqjljc9464blrLS/Kyh1LfgwYHfPHS5e/5QGb+eLL/d+uf/qrN9rMJX/s9xFJqnzcH+Ns1o02s6S1zWbuqPrHP16702be1PlKm5GkG9f6xqzXXOLHwKGhHpv56L2bbeaegYzu3ZJ2ZDQnn6wcyJhTToPXakbmzMGZIwAAAAAQxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJKkSGn2RnERkfKaPzl58+huv8hmXlJ+lc28ZbVvNviKi7fZzOqNO2ymY+VRm5Gk8lULbGZiw/fYTHXJpX5hVd/ctDzgGw0qVXxGUkv/QZuZ/LRvNPal//tqm/nyPt/88aHxQzYjSU9mNJKbqPjXt1Ltz1jamdmQMaVKzOfy5m7MmTsRvilhW4tvOLq07RKbWVNbZzMvXJDXePqNa/w+t2m9b1y44vpv+4X5IVfV/i6b6d++2s9I0rceu9xmhif96/b0iG/KuHXIN6b9xkBeY97t6X6bGZrYaTMpTWYt70w0n2POXI03OY2iJWlV90tt5q29V9vMe7/nQZu5+E1320ztistspvSEP1aSpMH7/L47dGSxzTy4zTfCznGcQ9vvuPyCjOMgSRe97CGb+ce//0Gb+eBjfjt5YOzTNhPK20VSRmPWcskfm+Y0uk/KO1483cw23nDmCAAAAABEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgKR5awKbyy8rp1Hs1aXvtZlXLO22mZctP2Iz112xxWYkafHlu2ymZUNGrdrrG3Zl2eMbRA7ef2HWrO65b7PN3Hlwuc/0+SZij8ajNnNsarfNSFKt5pdXSxM2U6nmNYA8E9EENq8JbGt5ic10tvrM4vIam7mger7NSNKa9k6buXyh79561eJ+m+lqnbKZiYpvgDg46RuuSlJ/ToPXMZ/ZNeyX9fDUfpvZPnGHn5GkqYym0spo3Hg2OxObwGYvL6NZ7OKuK23m+vLLbOata33jzpdteMxm1ny/b5YuSbHWN8JWzY83U/f7BseTx3wj7L6nVtnMIzt9Y25J+svtvnntHbV7bObo6ON+YeGPAyPzvEZ7q1/vlPxrMja5J2t5ZyKawAIAAADAcVAcAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAICk064JbA7fH65U6rGZ9hbfHGtF2+U+U11hM5K0utWv08pO/9hmf7VOzKjvf6qxSt7S9k/6RqlPlnfYTN/kkzZTqY3ZTGRus3kNXgcz5nT2Nm2kCWxe48Yo+Yar86217BslLmpbazMrkm8GvUILbaa77F/XoWrGwCTpWPINLg+UnrKZkVqfzQxPPG0z1dqAzSDP2dwENkfOeJMyjgQ62nwT1E3lG2zmkra8xvMXdPlMd4tf710j/j37p8bGbWZ7xjFHbjPVoeohmxmc8M3nU87xS/hG2D3tq21GkqYyljc+uc9mkvLG5TMRTWABAAAA4DgojgAAAABAFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABA0hnZBHY+ZTScjYzOZ5JKpbacVNa8nAg/n2rNN1Gr1UYzlzhXTVD9853T4PVsblg232gCe+YqRbfNpIxGyDlOz33Ob0elkh+/88bBs7cR9Hw715vAzi//VLe2LMuaU7Xq95OWjMbUOdpb/HwmKkM2M1U5krW8vPHNb0flUo/NVGs5jed9M13koQksAAAAABwHxREAAAAAiOIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJUqQ0e6ddukcD57b57FYvMeYA57r5HHMYb4Bz22zjDWeOAAAAAEAURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASVKklGb/YcRhSbvnb3UAnGaOpJR+YL4WxpgDnNMYbwDMpxnHnOMWRwAAAABwruCyOgAAAAAQxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkSf8fcwltMYdwPGwAAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = images[threes].mean(axis=0).reshape(28, 28)\n", + "avg_5 = images[fives].mean(axis=0).reshape(28, 28)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def time_clf_mnist(clf, ns, random_state=None):\n", + " runtimes = np.empty(len(ns))\n", + " rng = check_random_state(random_state)\n", + " images, labels = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)\n", + " for i, n in enumerate(ns):\n", + " # Get only 3s and 5s\n", + " threes = np.where(labels == '3')[0][:(n // 2)]\n", + " fives = np.where(labels == '5')[0][:(n // 2)]\n", + " idx = np.concatenate((threes, fives))\n", + " X = images[idx]\n", + " y = labels[idx]\n", + "\n", + " # Shuffle samples\n", + " permuted_idx = rng.permutation(len(idx))\n", + " X = X[permuted_idx]\n", + " y = y[permuted_idx]\n", + "\n", + " # Begin timing\n", + " start = time.time()\n", + " clf.fit(X, y)\n", + " runtimes[i] = time.time() - start\n", + " return runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "ns = [10, 20, 40, 50, 100, 150, 200, 300, 500]\n", + "\n", + "clfs = [\n", + " RF(random_state=0, n_jobs=1),\n", + " SPORF(random_state=0, n_jobs=1),\n", + " # MORF(random_state=0, image_height=8, image_width=8, n_jobs=1) # Too slow\n", + "]\n", + "\n", + "runtimes = []\n", + "for clf in clfs:\n", + " runtimes.append(time_clf_mnist(clf, ns, random_state=0))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-05-04T18:07:35.981486\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqw0lEQVR4nO3deXiU5b3G8e+PEJIQEiDsJOyGJWwBAi644IKigrhWWqtobU/rUo92c+sR1INaa7VVqba2Kh61VEUFcamAuCOrLAkhJECAQCAhQEgg+zznj4w22ggBMnlnMvfnunLNzJNZ7nlIcvPOPPO+5pxDREQk2LTwOoCIiEh9VFAiIhKUVFAiIhKUVFAiIhKUVFAiIhKUWnodAGDChAnuvffe8zqGiIh4w+obDIotqD179ngdQUREgkxQFJSIiMi3qaBERCQoHbGgzCzazJaZ2RozyzCze/3jCWa2wMyy/aft69zmTjPLMbMsMzsvkE9ARESap4YskqgAznLOlZpZJPCpmb0LXAoscs49ZGZ3AHcAt5tZCjAFGAx0BxaaWX/nXM3RBKuqqiIvL4/y8vKjekLNVXR0NElJSURGRnodRUSkSRyxoFztzvpK/Rcj/V8OmAyM84/PAj4EbvePz3bOVQBbzCwHGAMsOZpgeXl5xMXF0bt3b8zqXeARNpxzFBUVkZeXR58+fbyOIyLSJBr0HpSZRZjZaqAAWOCcWwp0cc7lA/hPO/uvnghsr3PzPP/YUSkvL6dDhw5hX04AZkaHDh20NSkiYaVBBeWcq3HOpQJJwBgzG3KYq9fXKMe0y3SV079pLkQk3BzVKj7n3H5qX8qbAOw2s24A/tMC/9XygB51bpYE7DzeoCIiEl4asoqvk5m185+PAc4BNgDzgKn+q00F5vrPzwOmmFmUmfUBkoFljZy7SURERJCamsqQIUOYNGkS+/fvByA3N5eYmBhSU1O//qqsrPQ2rIhIE/L5HHtKKgL6GA3ZguoGLDaztcByat+Dmg88BIw3s2xgvP8yzrkM4BVgPfAecNPRruALFjExMaxevZr09HQSEhKYOXPm19/r168fq1ev/vqrVatWHiYVEWk6BQfKueX/VnDjc8sorwrcn/eGrOJbC4yoZ7wIOPs7bjMDmHHc6YLIySefzNq1a72OISLiqQ8zd/PA3Awqq33cOmEAUS0Dt7+HoNhZ7JHsv2c6VeszGvU+I1MG0+6+6Q26bk1NDYsWLeL666//emzTpk2kpqYCMHbs2G9sXYmINDeHKqr543sbmLdqBwO7x3PvZcPo1TE2oI8ZEgXllbKyMlJTU8nNzWXUqFGMHz/+6+999RKfiEhzt35HMdPmrCVv7yGuObUPPznzBCIDuOX0lZAoqIZu6TS2r96DKi4uZuLEicycOZNbbrnFkywiIk2txud48dMt/HVxDh3aRPHk1NGM6pPQZI8fEgXltbZt2/L4448zefJkbrjhBq/jiIgE3K79Zdz7+jq+3LqPswd34fZJg4mPadpdramgGmjEiBEMHz6c2bNnc9ppp3kdR0QkYBak5/O7t9bj8zn+55IhXDC8uyc7C1BBHUZpaek3Lr/11ltfn09PT2/qOCIiAXWwvJo/vJPJO2t2MjipLfdeNoykhNae5VFBiYgI67bvZ/qcteTvL+P6M/px3Rl9aRnh7SEDVVAiImGsusbHrE828+xHm+kcH81TPxrD8J7tj3zDJqCCEhEJUzv3HWLanHWs276fCcO68asLB9EmOniOOaeCEhEJM8453lubz+/fXo9h3HvZUM4b1t3rWP9BBSUiEkZKyqp4eP56FqTvYnjPdky7dBjd28d4HateKigRkTDxZe5e7n19HYUlFfz0rBO45rS+RLQI3mPNebtEIwTMmDGDwYMHM2zYMFJTU1m6dCnjxo1jwIABDB8+nLFjx5KVlQVAZWUlt956K/369SM5OZnJkyeTl5f39X3p8B0i4oXqGh9PLczmxueX0zLC+Mv1Y7jujH5BXU6ggjqsJUuWMH/+fFatWsXatWtZuHAhPXrUHovxpZdeYs2aNUydOpVf//rXANx1112UlJSwceNGsrOzufjii7n00ktxrvaAwjp8h4g0tW1FB/mvvy9l1iebmZiayKyfncKQpHZex2oQFdRh5Ofn07FjR6KiogDo2LEj3bt/843E008/nZycHA4dOsRzzz3HY489RkREBADXXXcdUVFRfPDBB/9x3yeffDI7duwI/JMQkbDknGPeyjymPr2EvL2HeOB7w7n74iHERoXOOzshkfSxdzPZuKukUe+zf9c4bjt/0GGvc+6553LffffRv39/zjnnHK688krOOOOMb1znrbfeYujQoeTk5NCzZ0/i4+O/8f20tDQyMjI4++x/HzpLh+8QkUAqPlTJg/My+DCzgFF9Eph2yVA6t432OtZRC4mC8kqbNm1YuXIln3zyCYsXL+bKK6/koYceAuCqq64iJiaG3r1788QTT7B3795691XlnPt6XIfvEJFAW765iPteX8e+Q5XcPL4/PzilNy2C/L2m7xISBXWkLZ1AioiIYNy4cYwbN46hQ4cya9YsoPY9qLS0tK+vl5CQwNatWykpKSEuLu7r8VWrVjFp0iRAh+8QkcCprPbxlw+yefnzXHp2iOX3PxjJwO7xR75hENN7UIeRlZVFdnb215dXr15Nr1696r1ubGwsU6dO5Re/+AU1NTUAvPDCCxw6dIizzjrrG9f96vAdjzzyCFVVVYF7AiISFrYUlvLjv33BS5/lcvGoHsz66ckhX06ggjqs0tJSpk6dSkpKCsOGDWP9+vVMnz79O6//4IMPEh0dTf/+/UlOTubVV1/ljTfeqPelv7qH7xARORbOOV5fvp1r/7KEguJyHv7+CG6flEJ0qwivozUK+2oJtJfS0tLcihUrvjGWmZnJoEHevbQXjDQnIvKVfQcrmTE3nU+zCjmxXwf+55KhdIyL8jrWsar3TbKQeA9KRET+bUl2Ife/mU5JWRW3TRjIFSf2DNmFEIejghIRCREVVTXMXLCRV5Zuo2/nNvzp6jSSu8Yd+YYhKqgLqu4S7XAXDC/Fioh3cnaXMO21tWwqKOV7J/bkxvH9iY5sHu81fZegLajo6GiKioro0KFD2JeUc46ioiKio0Pvg3Yicnx8PserS7cxc+FG2kS35NEfjuSU5E5ex2oSQVtQSUlJ5OXlUVhY6HWUoBAdHU1SUpLXMUSkCe0pqeB/31zHFzlFjO3fibsnDyahTcguhDhqQVtQkZGR9OnTx+sYIiKe+HhDAQ/MTaessoZfXziIS0f3CLtXk4K2oEREwlF5ZQ1/+lcWb6zYTnLXOO67bBh9OrfxOpYnVFAiIkEiK/8A97y2lq17DvKDU3rzs7OTadUyfPencMRnbmY9zGyxmWWaWYaZ/bd/fLqZ7TCz1f6vC+rc5k4zyzGzLDM7L5BPQEQk1Pl8jpc+28L1z3zBwYpqHr8mjVvOGxDW5QQN24KqBn7pnFtlZnHASjNb4P/eY865R+pe2cxSgCnAYKA7sNDM+jvnahozuIhIc1BwoJz7Xl/Hii17OWNQZ+66aDBtW+uApdCAgnLO5QP5/vMlZpYJJB7mJpOB2c65CmCLmeUAY4AljZBXRKTZ+GD9Lh6al0FlteOuiwYzaWRi2C2EOJyj2n40s97ACGCpf+hmM1trZs+aWXv/WCKwvc7N8jh8oYmIhJVDFdXMmJvOXf9cQ2L71rzws5O5aFSSyulbGlxQZtYGmAPc6pw7ADwF9ANSqd3C+sNXV63n5toNgogIsH5HMdc8vYT5X+5g6ml9eebHJ9KzY6zXsYJSg1bxmVkkteX0knPudQDn3O46338GmO+/mAf0qHPzJGBno6QVEQlRNT7H/326hWcW59AxLoqZ145mZO8Er2MFtSMWlNVuc/4dyHTOPVpnvJv//SmAS4B0//l5wMtm9ii1iySSgWWNmlpEJITk7y/j3tfXsXrrPs4Z0pXfTEwhPibS61hBryFbUGOBq4F1ZrbaP3YX8H0zS6X25btc4KcAzrkMM3sFWE/tCsCbtIJPRMLV++vyeXj+enzOMe3SoUwY1k3vNTVQ0B6wUEQklJWWV/HI25m8tzafoT3aMf3SoSQmtPY6VrDSAQtFRJrCmm37mD5nHbuLy7h+XD+uO70vLSPC+0O3x0IFJSLSSKprfDz30Wae+3gTXdrG8PSPxjCsZ/sj31DqpYISEWkEO/YeYtqctaTnFXP+8O786oJBxEbrT+zx0OyJiBwH5xzvrtnJI29n0qKFcd/lwzh3aDevYzULKigRkWN0oKyKh+evZ2H6LlJ7tWfapUPp1i7G61jNhgpKROQYrMrdy72vr2NPSQU/OzuZq0/tQ0QLLR9vTCooEZGjUFXt428f5vDCp1tIbN+av14/hsFJ7byO1SypoEREGmjbnoNMm7OWzJ0HmDQykdsmDKR1lP6MBopmVkTkCJxzzFu1g8fe3UCrli148MpUzkzp4nWsZk8FJSJyGMWHKnlgXgYfZRaQ1jeBey4ZSuf4aK9jhQUVlIjId1i2qYj731jHvkOV/PzcAXz/5F600EKIJqOCEhH5lspqH08vyublz3Pp1TGWR64ayYBu8V7HCjsqKBGROrYUlHLPnLVk7yrh0tE9uOXcAUS3ivA6VlhSQYmIULsQYs7y7TzxryxiWkXw+x+M4LQBnb2OFdZUUCIS9vaWVjBjbgafbSzkpBM68D8XD6VDXJTXscKeCkpEwtrn2YXc/0Y6Byuque38gVwxpqcWQgQJFZSIhKXyqhpmLtjIq0u30a9zG56YmsYJXeK8jiV1qKBEJOxk7yph2py1bC4o5cqTenLjOf2JitRCiGCjghKRsOHzOV5ZupWZCzYSFxPJYz8cycnJnbyOJd9BBSUiYWFPSQX3v7GOpZuKOHVAJ+6ePIT2sa28jiWHoYISkWbv4w0FzJibTnlVDb+ZmMIlaUmYaSFEsFNBiUizVVZZzeP/yuKNFXn07xrHvZcPo0+nNl7HkgZSQYlIs7Rh5wGmzVnLtqKDXDW2Nz89K5lWLVt4HUuOggpKRJoVn8/x0ue5/OWDbNrHtuKJa9JI69vB61hyDFRQItJs7C4u47430lm5ZS9npnThjkkptG2thRChSgUlIs3Cooxd/O6tDKpqHHdPHszEEYlaCBHiVFAiEtIOVlTz2LsbmP/lDlIS45l+2TB6doj1OpY0AhWUiISs9Lz9TJ+zlh37yrj29L78eFw/WkZoIURzoYISkZBT43PM+mQzf/9wE53iovjztaMZ0TvB61jSyFRQIhJSdu4r497X17Jm237GD+nKbyamEBcT6XUsCYAjbgubWQ8zW2xmmWaWYWb/7R9PMLMFZpbtP21f5zZ3mlmOmWWZ2XmBfAIiEj7+tXYnVz/1Odm7S5h26VDuu3yYyqkZa8gWVDXwS+fcKjOLA1aa2QLgWmCRc+4hM7sDuAO43cxSgCnAYKA7sNDM+jvnagLzFESkuSstr+L3b2fyr7X5DO3RjnsvG0r39q29jiUBdsSCcs7lA/n+8yVmlgkkApOBcf6rzQI+BG73j892zlUAW8wsBxgDLGns8CLS/K3euo/pr6+l8EAFPzmzH1NP66uFEGHiqN6DMrPewAhgKdDFX1445/LNrLP/aonAF3VulucfExFpsOoaH89+tInnP95M13YxPP2jMQzt0c7rWNKEGlxQZtYGmAPc6pw7cJgPwNX3DXcM2UQkTOXtPcS0OWvJyCvmgtTu/PL8QcRGa01XuGnQv7iZRVJbTi855173D+82s27+raduQIF/PA/oUefmScDOxgosIs2Xc463V+/k0XcyadHCuP+KYYwf0s3rWOKRhqziM+DvQKZz7tE635oHTPWfnwrMrTM+xcyizKwPkAwsa7zIItIcHSir4revruF/30xnQLd4XrzhFJVTmGvIFtRY4GpgnZmt9o/dBTwEvGJm1wPbgCsAnHMZZvYKsJ7aFYA3aQWfiBzOyi1F3Pt6OkWlFdx4TjJXje1DRAvtRy/cmXPevz2UlpbmVqxY4XUMEWliVdU+/ro4hxc/20JSQmvuu2wYgxLbeh1Lml69/xvRu44i4omtew5yz2tryco/wORRSdw6YQAxrfQnSf5NPw0i0qScc8xdmcdj720gqmUED01JZdygLl7HkiCkghKRJrP/YCUPzMvg4w0FjO7bgXsuGUKn+GivY0mQUkGJSJNYmrOH+95Yx4GyKm45bwBTTupFCy2EkMNQQYlIQFVU1fDUomxmL9lK706xPPbDUfTvFu91LAkBKigRCZjNBaXc89oacnaXcvmYHtw8fgDRrSK8jiUhQgUlIo3OOcdry7bx5PsbaR3Vkkd+MIJTB3Q+8g1F6lBBiUijKiqtYMab6XyevYeTkzvy24uH0KFNlNexJASpoESk0Xy2sZD/fTOdgxXV/PKCgVw+pieH2bG0yGGpoETkuJVX1fDk+1m8tmw7/bq04cmpafTrEud1LAlxKigROS7Zuw5wz2tr2VJ4kCkn9+KGs5OJitRCCDl+KigROSY+n2P2F1t5auFG4mMi+ePVozjphI5ex5JmRAUlIket8EA5972RzvLNRZw2oBN3TR5C+9hWXseSZkYFJSJH5cPM3TwwN4OK6hpun5jCxWlJWgghAaGCEpEGKaus5o/vZTF3ZR4DusVz72VD6d2pjdexpBlTQYnIEWXuKGbanLVs33uIq0/tw3+deQKRLY94QG6R46KCEpHvVONzvPTZFv7yQQ4JbVrx5NQ0RvXp4HUsCRMqKBGp1+7iMqa/vo4vc/dxVkoXbp+UQtvWWgghTUcFJSL/YWH6Ln73VgbVPsdvLx7ChandtRBCmpwKSkS+drCimj+8k8k7q3eSktiWey8bSo8OsV7HkjClghIRANK372fanLXk7y/jR2f05Udn9KNlhBZCiHdUUCJhrrrGx6xPNvPsR5vpFB/Fn68bQ2qv9l7HElFBiYSzbUUH+d8301m7bT/nDu3GbyYOok10pNexRAAVlEhYqq7xMXvJVp5ZnENkyxZMv2woE4Z19zqWyDeooETCTPauEh6Ym07mzgOcNqATv5mYQqf4aK9jifwHFZRImKis9jHr4808/8lm4mMiuf+KYZwzuKuWj0vQUkGJhIH0vP08MDeDzQWlnDesG7dNGEg77X1cgpwKSqQZK6+s4ekPsvnnF1vpFBfNH64aydj+nbyOJdIgKiiRZmrlliIemJvBjn1lXJKWxM3jBxAbrV95CR36aRVpZkrLq3ji/Y3MXZlHUkIMM68dzag+CV7HEjlqR/yYuJk9a2YFZpZeZ2y6me0ws9X+rwvqfO9OM8sxsywzOy9QwUXkP32aVcD3Z37GW6vyuOqU3rx4w1iVk4SshmxBPQ88CbzwrfHHnHOP1B0wsxRgCjAY6A4sNLP+zrmaRsgqIt9h38FKHns3k/fX7aJf5zY8dGUqg5PaeR1L5LgcsaCccx+bWe8G3t9kYLZzrgLYYmY5wBhgybFHFJHv4pxjQfouHn0nk9KKan48rh9TT+urgwlKs3A870HdbGbXACuAXzrn9gGJwBd1rpPnHxORRlZwoJyH56/n06xCUhLjuXvyEPp1ifM6lkijOdaCegq4H3D+0z8APwLq+8SfO8bHEJF6OOeYuzKPJ97fSLXPx8/PHcCUk3sR0UIfuJXm5ZgKyjm3+6vzZvYMMN9/MQ/oUeeqScDOY04nIt+Qt/cQD87LYOWWvYzs3Z47Lxqs4zVJs3VMBWVm3Zxz+f6LlwBfrfCbB7xsZo9Su0giGVh23ClFwlyNz/HKF1t5+oNsIloYt09MYfKoJFpoq0masSMWlJn9AxgHdDSzPGAaMM7MUql9+S4X+CmAcy7DzF4B1gPVwE1awSdyfDYXlDJjbjoZecWM7d+J2yem0Lmtdu4qzZ855/1bRGlpaW7FihVexxAJKlXVPv7v0y08+/EmYqNa8ovzB3Lu0G7auas0R/X+UGtPEiJBKHNHMTPmppOzu5TxQ7py2/kDSWgT5XUskSalghIJIuVVNTyzOId/fJ5LQpsoHv7+CE4f2NnrWCKeUEGJBIkvc/fywLwMthcdYtLIRG45dwBxMTr8uoQvFZSIxw6WVzNz4UZeX76d7u1jeGJqGqP7dvA6lojnVFAiHvo8u5DfvbWeggPlTDmpFz89+wRiWunXUgRUUCKeKD5UyR/fy+LdNTvp3SmWv15/IkN7tPM6lkhQUUGJNCHnHB+s380jb2dyoKyK607vy3Vn9KOVdu4q8h9UUCJNZE9JBb9/ez0fZRYwsHs8j18ziuSu8V7HEglaKiiRAHPO8fbqnfzpvQ1UVPu4aXx/vn9yL1pGaKtJ5HBUUCIBtHNfGQ+9lcGyTUUM79mOuycPoWdH7dxVpCFUUCIB4PM5Xlu2jacWZWPAry4cxKVpPbRzV5GjoIISaWS5haXMmJvBuu37OemEDtw+aTDd2sV4HUsk5KigRBpJdY2PFz/L5e8f5hDTKoJ7LhnC+cO7a+euIsdIBSXSCLLyDzDjzXQ27irhrJQu/PLCQXTQzl1FjosKSuQ4VFTV8OxHm3jxs1zato7kwStTOTOli9exRJoFFZTIMVqzbR8PzM1g656DXJjanVvOG0Db1q28jiXSbKigRI7SoYpqnlqUzWvLttElPpo/Xj2Kk07o6HUskWZHBSVyFJZu2sOD8zLYXVzOZaN7csM5ycRG6ddIJBD0myXSAAfKqnj8X1nM/3IHvTrG8tR1Y0jt1d7rWCLNmgpK5Ag+zNzN7+evZ/+hKqae1ocfndGPqMgIr2OJNHsqKJHvUFRawaPvZLIoYzf9u8bxh6tGMbC7du4q0lRUUCLf4pzjvbX5PPbuBsoqq/nZ2cn8cGxv7dxVpImpoETq2LW/jN/NX8+S7D0M7dGOuyYPpk+nNl7HEglLKigRanfu+saK7cxcsBGfg9vOH8jlY3oSoZ27inhGBSVhb1vRQR6cm8GXW/cxum8H7rwohe7tW3sdSyTsqaAkbFXX+PjHkq38bXEOkS1bcPfkwUwckaidu4oECRWUhKXsXSXMmJvOhp0HOGNgZ3514SA6xUd7HUtE6lBBSViprPbx/MebmPXJFuJjIpnxveGcldJFW00iQUgFJWEjfft+ZsxNZ0vhQSYM68Zt5w/Uzl1FgtgRP9hhZs+aWYGZpdcZSzCzBWaW7T9tX+d7d5pZjpllmdl5gQou0lDllTX88b0N/OTvSzlYUcOjV41k+mXDVE4iQa4hnzx8HpjwrbE7gEXOuWRgkf8yZpYCTAEG+2/zZzPTPmHEM6XlVfz8hRXMXrKVS9J68I+bxnJK/05exxKRBjhiQTnnPgb2fmt4MjDLf34WcHGd8dnOuQrn3BYgBxjTOFFFjk7xoUp+PmsF63cU88D3hvObiSnERutVbZFQcaz7buninMsH8J929o8nAtvrXC/PPybSpPaWVnDT88vZVFDKw98fwVmDu3odSUSOUmP/d7K+pVCukR9D5LAKD5Rz86wV7Cou45EfjGRMvw5eRxKRY3CsW1C7zawbgP+0wD+eB/Soc70kYOexxxM5Ovn7y7jhuWUUHijnT1enqZxEQtixFtQ8YKr//FRgbp3xKWYWZWZ9gGRg2fFFFGmY7UUHueHZZRQfquLxqWk6oKBIiDviS3xm9g9gHNDRzPKAacBDwCtmdj2wDbgCwDmXYWavAOuBauAm51xNgLKLfG1LYSk/n7WCqhofT147mgHddNwmkVBnznn/FlFaWppbsWKF1zEkRGXvOsAtL6zEDJ6cOpq+nXV4DJEQU++uXHQENglpmTuKuen55URGtODp68aonESaEX0oRELWmm37+MWLq4iPiWTmtWk6RIZIM6OCkpC0cksRv3r5SzrFRfHk1NF0bqs9kYs0NyooCTlLsgu5Y/ZqEhNa88Q1aXSIi/I6kogEgApKQspHmbu5+9U19O3chsevTqNdrHb4KtJcqaAkZCxIz2f6nHUM7B7PYz8cRXxMpNeRRCSAVFASEt5evYMZb6YzvGd7HrlqJLFR+tEVae70Wy5B7/Xl23l4/nrG9OvAw1NGEN1KR3ARCQcqKAlqs5fk8sf3shjbvxMPfG84UZEqJ5FwoYKSoPX8x5t5elE2Z6Z04b7LhhHZUp8rFwknKigJOs45/vpBDs99vJnzhnXjfy4eQssIlZNIuFFBSVBxzvHE+xt5+fNcLhqZyO2TBhPRot7ddIlIM6eCkqDh8zn+8E4mc5Zv5/IxPfnF+QNpoXISCVsqKAkKNT7Hg/MymP/lDq4a25ubx/fHTOUkEs5UUOK56hof972xjvfX7eL6cf348bh+KicRUUGJt6qqffz2tTV8lFnAjeckc81pfb2OJCJBQgUlnimvquGuf67m8+w93Hb+QK48qZfXkUQkiKigxBNlldX8+uUvWZm7lzsmpXBxWg+vI4lIkFFBSZM7WF7NL15aybrt+7nnkqGcP7y715FEJAipoKRJFR+q5LYXV5KVX8L9Vwzn7MFdvY4kIkFKBSVNZt/BSm55YQW5haU8NCWV0wZ09jqSiAQxFZQ0iT0lFdw8azn5+8t45AcjOfGEjl5HEpEgp4KSgNu1v4ybZ62gqLSCx344ipG9E7yOJCIhQAUlAbVj7yFumrWc0vJqHr8mjaE92nkdSURChApKAia3sJSfz1pBRbWPJ6eOZmD3eK8jiUgIUUFJQOTsLuHns1YA8OfrRnNClziPE4lIqFFBSaPaVnSQt1bt4M0V24luFcGTU0fTq2Os17FEJASpoOS4lVfVsHj9buatyuPL3H1EtDDG9u/Ef583gMSE1l7HE5EQpYKSY5aVf4B5K/P417p8SsurSUpozY3nJHNBaiId46K8jiciIU4FJUelpKyK99flM2/VDrLyDxDVsgVnpnTholFJjOjVXofJEJFGc1wFZWa5QAlQA1Q759LMLAH4J9AbyAW+55zbd3wxxUvOOVZv3ce8VTv4IGMXFdU++neN41cXDuK8od2Ii4n0OqKINEONsQV1pnNuT53LdwCLnHMPmdkd/su3N8LjSBMrKqngnTU7eWtVHtuKDhEb1ZILRyRy0cgkLRkXkYALxEt8k4Fx/vOzgA9RQYWM6hofSzcVMW9lHp9uLKTG50jt1Z5rT+/LWSldiW4V4XVEEQkTx1tQDnjfzBzwF+fcX4Euzrl8AOdcvplpj6AhYOe+Q7y1agfzV++g8EAF7WNb8f2TezFpZJKWiYuIJ463oMY653b6S2iBmW1ojFDSNCqqavh4QwFzV+WxYvNeWhicdEJHfnnBIE7t34mWES28jigiYey4Cso5t9N/WmBmbwBjgN1m1s2/9dQNKGiEnNKIcnaX8NaqPN5dk8+Bsiq6tYvhv846gYmpiXRuG+11PBER4DgKysxigRbOuRL/+XOB+4B5wFTgIf/p3MYIKsfnYEU1C/zLw9fvKCYywjhjUBcuGplEWp8EWrTQ8nARCS7HswXVBXjD/7mXlsDLzrn3zGw58IqZXQ9sA644/phyLJxzpOcVM29lHgszdlFWWUPfzm24dcIAJgzrTrvYVl5HFBH5TsdcUM65zcDwesaLgLOPJ5Qcn30HK3nXvzx8S+FBYlpFMH5IVy4alcTgxLb6MK2IhATtSaKZ8PkcyzcXMW9VHh9tKKC6xjEkqS13XTSYs4d0JTZK/9QiElr0VyvEOef4MLOApxdls3XPQdq2juTyMT25aGQSfTu38TqeiMgxU0GFsFW5e5m5YCMZecX06RTL/ZcP44xBXWjVUsvDRST0qaBCUM7uEp5amM1nGwvpHB/Nby8ewvnDuxOhlXgi0oyooELIrv1l/HVxDu+u2UmbqJbcPL4/l5/Yk+hI7X5IRJofFVQIKD5UyaxPtvDasm0AXHVKb645rS/x2ou4iDRjKqggVl5ZwytLt/LCp1s4VFHNhamJ/PjMfnRpG+N1NBGRgFNBBaHqGh9vr97J3xbnUFhSwakDOnHD2cn06xLndTQRkSajggoizjk+ySrkzws3klt4kKE92nH/FcNJ7dXe62giIk1OBRUkVm/dx8wFG1m3fT+9OsbyuympnD6ws/b6ICJhSwXlsS0Fpfx54UY+ySqkU1wUd140mAtTu+tQFyIS9lRQHikoLueZxTm8vXoHraNacuM5yXzvxF46Yq2IiJ8KqokdKKvihU828+rSbfic48qTenHt6X1p21p7FhcRqUsF1UQqqmp4bdk2Zn2ymZLyas4f1p2fnHUC3dppybiISH1UUAFW43O8u2YnzyzOYXdxOackd+SGc/qT3FVLxkVEDkcFFSDOOT7bWMifF2azuaCUlMS23HPJUEb1SfA6mohISFBBBcC67fuZuWAjq7fuo2eH1jxw5XDOHNRFS8ZFRI6CCqoR5RaW8tSibD7KLKBDm1bcPjGFSSMTtWRcROQYqKAaQeGBcv724Sbmf7mDqMgW/PSsE5hyci9iWml6RUSOlf6CHofS8ir+79NcZn+RS43PccWYnkw9vS/tY7VkXETkeKmgjkFltY85y7bx3MebOVBWxXnDuvHTs06ge/vWXkcTEWk2VFBHocbneH9dPn/5IJtd+8s5sV8HbhzfnwHd4r2OJiLS7KigGsA5x5KcPfx5wUZydpcysHs8d08ewui+HbyOJiLSbKmgjmD9jmKefD+LVbn7SEqI4f4rhnF2SldatNCScRGRQFJB1WPfwUo+XL+bhRm7WLllL+1jW/GrCwYxeVQSkS21ZFxEpCmooPyKD1XyUWbB16VU43P06hjLz85O5ooTexIbpakSEWlKYf1Xt7S8io83FLAwfRdLNxVR43MkJcTww7F9OGdIF07oEqe9P4iIeCTsCupgRTWfZtWW0hc5e6iqcXRtF82Uk3sxfkhXBnSLVymJiASBsCiosspqPtu4h4Xp+SzJ3kNFtY9O8VFcNqYn5wzuyuCktiolEZEgE7CCMrMJwJ+ACOBvzrmHAvVYdTnn2L73EBl5xaRv3096XjE5u0uo8TkS2rRi0sgkzhnSlWE92mklnohIEAtIQZlZBDATGA/kAcvNbJ5zbn0gHq+y2seLn22pLaW8/RQfqgKgdVQEKYltufrUPozum0BqrwQiVEoiIiEhUFtQY4Ac59xmADObDUwGAlJQkRHGq5/k0DYSxsYbKYkRpMQbvVpDhJUAJbB5O1Wboepo79y5Omfdd37Pe8GUhfrnpsFjDby/o7jP//i3O5Kj/rc9iusHOstRR2+kn50gu5+j/jf/7jtqnPtprN/RxsjTSFGsdQytJ09unDurR6AKKhHYXudyHnBigB4LM+OpF24j8mDJN8b3B+oBRUSEiO7dQ7Kg6nsdLaD/xe/+6j/A56v/m+47EjVU3QUU315MEUyLK4IpC3xHnv8cq3eBSr1jR/E4DR07nIBeP7BZjnrRT2P96DTWz6Dupwnu5/jvwwJ8rLtAFVQe0KPO5SRgZ4AeC4BWw4cH8u5FRKSJBar+lgPJZtbHzFoBU4B5AXosERFphgKyBeWcqzazm4F/UbvM/FnnXEYgHktERJqngH0Oyjn3DvBOoO5fRESaN+2aW0REgpIKSkREgpIKSkREgpIKSkREgpIKSkREgpIKSkREgpI12g4VjyeEWSGw1escIiLiiT3OuQnfHgyKghIREfk2vcQnIiJBSQUlIiJBSQUlIiJBSQUlIiJBSQUlIiJBSQUlIiJBKWQKyswmmFmWmeWY2R1e5/GCmT1rZgVmll5nLMHMFphZtv+0fZ3v3emfrywzO8+b1IFnZj3MbLGZZZpZhpn9t3887OcGwMyizWyZma3xz8+9/nHNj5+ZRZjZl2Y2339ZcwOYWa6ZrTOz1Wa2wj/WdHPjnAv6L2oPergJ6Au0AtYAKV7n8mAeTgdGAul1xh4G7vCfvwP4nf98in+eooA+/vmL8Po5BGheugEj/efjgI3+5x/2c+N/vga08Z+PBJYCJ2l+vjFHvwBeBub7L2tuap9vLtDxW2NNNjehsgU1Bshxzm12zlUCs4HJHmdqcs65j4G93xqeDMzyn58FXFxnfLZzrsI5twXIoXYemx3nXL5zbpX/fAmQCSSiuQHA1Sr1X4z0fzk0PwCYWRJwIfC3OsOam+/WZHMTKgWVCGyvcznPPybQxTmXD7V/qIHO/vGwnDMz6w2MoHYrQXPj538JazVQACxwzml+/u2PwG8AX50xzU0tB7xvZivN7L/8Y002NwE75Hsjs3rGtI+mwwu7OTOzNsAc4Fbn3AGz+qag9qr1jDXruXHO1QCpZtYOeMPMhhzm6mEzP2Y2EShwzq00s3ENuUk9Y81ybvzGOud2mllnYIGZbTjMdRt9bkJlCyoP6FHnchKw06MswWa3mXUD8J8W+MfDas7MLJLacnrJOfe6f1hz8y3Ouf3Ah8AEND8AY4GLzCyX2rcOzjKzF9HcAOCc2+k/LQDeoPYluyabm1ApqOVAspn1MbNWwBRgnseZgsU8YKr//FRgbp3xKWYWZWZ9gGRgmQf5As5qN5X+DmQ65x6t862wnxsAM+vk33LCzGKAc4ANaH5wzt3pnEtyzvWm9u/KB865H6K5wcxizSzuq/PAuUA6TTk3Xq8SOYrVJBdQuzprE3C313k8moN/APlAFbX/W7ke6AAsArL9pwl1rn+3f76ygPO9zh/AeTmV2pcS1gKr/V8XaG6+fq7DgC/985MO3OMf1/x8c57G8e9VfGE/N9Suml7j/8r46u9uU86NDrchIiJBKVRe4hMRkTCjghIRkaCkghIRkaCkghIRkaCkghIRkaCkghIRkaCkghIRkaD0/yDwdC5TKvacAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "for runtime, clf in zip(runtimes, clfs):\n", + " ax.plot(ns, runtime, label=rename_clf(clf))\n", + "ax.legend()\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python380jvsc74a57bd039ca1c7a169e56d6a333ccd59f8c6786beb2b8f5c3cc68b80d4610822621472b", + "display_name": "Python 3.8.0 64-bit ('ProgLearn': conda)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/docs/tutorials/test_feature_importance.ipynb b/docs/tutorials/test_feature_importance.ipynb new file mode 100644 index 0000000..db90ff5 --- /dev/null +++ b/docs/tutorials/test_feature_importance.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Feature Importance\n", + "One of the benefits to decision trees is that their results are fairly interpretable in that they allow for estimation of the relative importance of each feature.\n", + "\n", + "There are many approaches that have been suggested for quantifying importances. The `scikit-learn` implementation of the random forest quantifies these importances using the Gini impurity. For `SPORF` and `MORF`, we use a projection forest specific metric to quantify feature importance by computing the normalized count of the number of times a feature $k$ was used in projections across the ensemble of decision trees.\n", + "\n", + "Specifically, consider our forest $\\mathcal{T}$ to be a collection of decision trees $\\{T_i\\}_{i=1}^n$, where each decision tree $T_i \\in \\mathcal{T}$ is composed of many nodes $j$. Let $\\mathcal{A}$ be the set of unique atoms across all nodes and all trees in our forest. For each feature $k$, its importance $\\pi_k$ is computed as the number of times an atom assigns it a nonzero weighting, followed by a normalization of $|\\mathcal{A}|$.\n", + "$$\n", + "\\pi_k = \\frac{1}{|\\mathcal{A}|} \\sum_{T_i \\in \\mathcal{T}} \\sum_{j \\in T_i} \\mathbb{I}(a_{jk} \\not= 0)\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.metrics import accuracy_score\n", + "from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC\n", + "from oblique_forests.sporf import ObliqueForestClassifier as SPORF\n", + "from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF\n", + "\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "\n", + "mpl.rcParams.update({\n", + " \"axes.titlesize\": \"xx-large\",\n", + " \"axes.spines.bottom\": False,\n", + " \"axes.spines.left\": False,\n", + " \"xtick.bottom\": False,\n", + " \"ytick.left\": False,\n", + " \"image.cmap\": \"inferno\",\n", + " \"image.aspect\": 1\n", + "})\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Digits Dataset\n", + "We visualize feature importances identified by `RF`, `SPORF`, and `MORF` on a subset of the MNIST dataset. We only consider threes and fives and use 100 8x8 images from each class." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(1797, 64) (1797,)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_digits\n", + "\n", + "images, labels = load_digits(return_X_y=True)\n", + "print(images.shape, labels.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(200, 64) (200,)\n" + ] + } + ], + "source": [ + "# Get 100 samples of 3s and 5s\n", + "num = 100\n", + "threes = np.where(labels == 3)[0][:num]\n", + "fives = np.where(labels == 5)[0][:num]\n", + "idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "X = images[idx]\n", + "y = labels[idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(idx))\n", + "X = X[permuted_idx]\n", + "y = y[permuted_idx]\n", + "\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:14:56.804734\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAEjCAYAAAD5QHrmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaUUlEQVR4nO3debRlVX0n8O+vqsACLGVGFDtgUAlOrW3UGAcUFNvYNgSNiIrY0WShcWjbObZiHJLYOMXZpQYVcUJRW40iEGI7omgbAQGlqQKDIhWgAJmL3X+c8+B6672qekXV2yX1+ax1V/H2Pff89r3vnc35nrPPudVaCwAAwJZuUe8OAAAAbA6EIwAAgAhHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAHCbU1XHVFWrqiUdai+vqmMWuu6mMn6OR0217V5Vx1fVysnnq2pZVX2wqn45th/TocvcCsJRZ1W137jxXF9VO/fuz+aqqp5VVSePg811VfVvVfXlqnpY777B7wrjzfqpqqPGz2m2hzGH7m7L23JV7Thug/ttgnXvObU931hV/15Vp1fVO6vqfvNY3VuTPGH89xlJPje2vyrJnyf5x7H9/RvzPbDpLfjRBNZweJJfJNktyVOTvLNvdzZbD0hyUZKvJVmZZJckT0/yjap6Umvtc2t7MZDEeDNfL0ly8VTbOT06AlNuy9vyjkleO/73qZuoxheSHJ+kktwxyb0zfI7Pq6o3tdZePbX8NklunGp7VJKTWmtvmqX9zNbaqzZ+t1kIwlFHVbVtkicleXuS+2cY7LoMcFW1bWvt6h6110dr7fnTbVX1riTnJXlpbjliA8zCeLNBvtBa+3nvTsCkzWlb/h12Rmvt2MmGqnppkk8m+euqOq+19o8zz7XWrp1lHbsmuXyO9os2Yl9TVVsnWd1aW70x18vsTKvr6+Aky5J8fHw8sKr2nXmyqt42nvK90/QLq+rw8ZTwYyfatquqN1XVeePUs4uq6j1VtcPUa0+tql9U1T2r6itVdUWSr4zP3aeqPlRVP6uqq6vq8qr6alX94WxvoKpeMs4tvnY8Lf24cZ7z8lmWfWRVnVhVq6rqmqr6blU9YcM+uqS19psk/55k+w1dB2xBjDcbMN5U1R2qavF8Xweb0Fq35Vncqao+M24Lq6rq41W16+QCVbVrVb2/qlaM2/PFVXVKVT16arl9quqzVXXpuF39sKqesa4O1y3TAPeb5bmbr08an//Z+NRrJ6a/HTOx/HqNPfPVWrsyw9mjy8faNVFz8pqio6qqZTjr9LSJPh4xtu+V5I8n2vebWM/BVfWtqrpqfJxcVQ+d+jyOGF/3+Kr626r6RZJrk9x1fH6Xqnp3VV1Yw7TK5eNyt5vlc/3mOM6eMo6xv6qqN0y+t4nlHzaO0TO/23Oq6m1TyyypqpdX1VnjOLyyqo6tqj02+IPfDAlHfR2e5AettXMynOK9cmyb8fEki5M8ZZbXHpZhusfJSTJuFCcneX6Sz4//HpfkWUlOnt5okmyb5KQMRzdekuRjY/uBSe6X5FNJXpjk6CT3TPIvVbXP5Aqq6jVJ/leS5RnO3pyY4ajLA6Y7W1WHjP3bJsnrkrw8SUvyxaqa7f3Nqqp2GAeGe1XVW5Psm2GqHbB2xpt5jjdJfphkVZJrxqC1Ri3oYF3b8rQvJVma4VqYY5McmuTEGs5GzPhMhmBwXJLnZtgWL83E9lVVeyf5TpL9k7wnySuSXJfko1X1ko3yzpKfZhgjkuSEDNfs3HzdzgaMPfPSWrsiw0yU30uyzxyLfW7sU5J8e6KP3x//XZnk3In2n459f9H42kuTvDLD1MHdk/xzVf3xLHXenOGzPjrDGHZVVe2U5LtJnpzkI0n+KsM+0EuTfHaWdew+Pv/jJC8e+/jXSf7b5ELjuHhqknsleXeSF2X4fA+aWKYy/J38TZL/k2HMfk+Sxyf59ti324bWmkeHR5I7J1md5EUTbR/JMId40UTbOUm+N/XaXZLckOTtE20vS3J9kgdMLfvEDDsFz55oO3Vse8Us/dpulradk/w6yfum2q5L8o0kiyfa9x/XvXyibdsMg8XxU+tdnGFDvSBJrefntnxcf0tyTYYBc5vev08Pj835YbyZ33iTYcfgvRl2OA/KsBOzKsnVSR7Y+/fpseU+1ndbHtuPGbePT021P39sP3L8+Y7jzy9dR+1PJ7lpchtIsnWS743/P955on15kmMmft5vrLHfLOudXnbvcdmjZll2vceeOd7DnuNyb1jLMv99XOaJE21r9GdsO3aO9/PNqbY9xn4fPdV++3FM+uZE2xHjun+cZOup5d+TIVzddar9BeNrDpjqR0ty0NSyP05y2sTPy5JcliHQbT+17OT/H/5sXN8Tppa5f4brseb8TH/XHs4c9fP0DH9kn5xo+3iSu2T4H/5k24Oq6vcn2p6S4Xqxj0+0HZrkB0kuqKqdZx4ZjvJcM7XOGe+dbmjDVLUkw7zmiSMBpyWZnOpyQIZB8d1tYg5sa+3kJGdOrfaAJDslOXaqbzsk+acMp4rvMUv/ZvO0DEebnzP2adsMR8SAuRlv5jHetNbe3lo7srX20dba51trr0vy0AwB6+i1vRY2sfXdlie9fernDyS5Ksl/GX++JsOO+35zHf2vYWrp45Oc0lr7wUx7a+36cf1Lkzx2ttduZBsy9szXleO/yzbCumb8aZKtkhw31e+lGc6E/VEN15JN+tD4+Sa5+czNUzKcNb9maj1fHxebfv+/bK19fqrt1CSTY/xjMlye8PettcsnF2yt3TTx46EZQvh3p2pfmOT/zVL7d5YbMvRzeIajmEuras+x7bwM6f2ZueUP/bgM00IOS/L6se2wJD9rrX1/Yn37ZJhCcskc9Xad+vnS1tqq6YWq6g5J3pDhlO30tQfnT/z3TJ9/ljWdm9+e6jJzavqEOfo207913gWqtfatib4ek+H08heTPHxdr4UtmPFmzf7N665zrbUzq+pLSQ6qqm1aa9fM5/Wwkazvtjzpt/7WW2vX1XCd3l7jz9dX1csyBP9fVdX3M0yDPa61dvb4sl2SbJdxitiUs8Z/99rA9zQf8x17NsRMKLpyrUvNz8y4dPpaltkpw9npGedPPb9Lhjv5PSWzT39O1nz/K2ZZ5rJxPTPuPv77r2vpWzK8hz0y92d/m7k2UzjqoKr+U4Z5ncmaf/xJcnBVLWutXdla+3lVnZZxZ2UcDP8oww7MpEUZjpy8Zo6yl039PNf/2D+RIf2/LbfMt78pw/zYySMNMxfztVnWMX2h38wZyiOTzHXnpzPmaJ9Ta+3GqvpMkr+rqru31mbbcYItmvFmVvMeb0YrxvXvkLnfE2wS89mWp9pn225+e4HW3lFVJ2SYnrZ/hqllr6yqv2gTd22bY11r2z7Xpw/z2ame79izIe4z/rsx9ylmxqWDM5y1m8106JgeY2bWcUKG6XWzmb5L3nzubreuv5NFGc4Q/eUcz99mxkThqI9nZpg/f1iGHYFJu2f4o39Shi8QS4ZT5u+oqvsn+c9j23FTr/t5hrmiJ21op6pq+wynzV/XWjtq6rnXTy0+MzDfI8mPpp67+9TPMzsol96a/s1hZkrdrbpLDdyGGW82nr0zzK2/dCOvF9bHfLflGftkCBRJbr6pwZ4ZLqq/WWvtgiTvSvKuGu789p0kbxzXd0mS32S4CdK0mbMiy9fS95nQMn03y6Vj33+rK2tZz60ee9amqu6YIcCsSHL2Ohafj5lx6aLW2mkbuI5LklyRZOlGfv8zIfC+GaYszuXnSR6R5NTW2vR3Pt2muOZogVXVVhnmbZ7cWvvcOJ998vHeDBfnTd555lMZ0v9hGe4m8/3W2rlTq/5Ekj+oqsNmqbm4qnacbp/F6gyD0m/9XdRwG8oHTy17UoY5ys+ridvcVtX+ueXI1oyvZRgYX1VV28zSv7WeBh9vHblG/6vq9hmuQbo6a153AFs84838x5txmdnGm4fklmsuZvvOE9hkNnBbnvGiqZ//IsONAL40rnvb6W2ltXZZhrCzw/jz6gzX7D168q6NY79emCG0nbiWt7A8w4GFR0+1vyBrnjmaObMy20HPjTH2zKqqlmU4ELR9kr9p490GNpLjM7z/o6pqjRMT6zMujb+DTyc5sKoeNss6lo7vYb6+nuH25S8fw+HkOifPzH8iw5TDl81Su8brj24TnDlaeI/PMG/0C2tZ5osZdgJ+r7W2orV2cVWdlOFU5rKsOdAlyVvGdR9bw3d5zBwl2jvJIUleneHONXNqrV1ZVScnedk4UJ6b4fTyszKEj2UTy66sqr/LcGr75Kr6bIajP0cm+cnUsldW1bMz7HSdWVUfy3AB3+5JHpLhqNPkFJppt09yYVUdP/bj0gxHvZ6ZYf7rX01e2A3czHgz//EmSc4fpxidlWGq332T/HmGnbYXr+O1sCnMe1ueaL9nVf3vDOFm39yy3XxofP4eGW4nfXyGv/mrMpwhODDJhyfW86oMNzw5uYYvYb8kQ2B7SIY73a2cq2OttSuq6uNJnjvucJ+RYcruwzPcXXJy2Yur6oIkh1bVuRm+z/D81tr3shHGntG9q+rpGaYELssw9jw5w7U4b2ytfXhtL56v1tryqvofSd6R5PSq+nSGu3LukeFOfjcledR6rOoVGX43p4zXXf8owwyae479PyTDDRfm07crq+q5GW7z/uOqmrn74Z4Zfr8z4+VxGe7e+cYavpvplAxT6fYa249LctR8am+2et8ub0t7ZLgP/U1Jdl/LMgdkOKL66om2Z4xtNya50xyvW5rh/vVnZPjCsMsz3LLxzUn+w8Rypyb5xRzr2DXDd5D8OsMp9G9nmH98TCZulzsuWxmOIKwY6/0wyeMyHCH56SzrflCGubIrMxxluiDDYH7oOj6zrZO8dVz/ZRluK3zx+NrH9P6denhsrg/jzfzHm/G1H8iw83h5hjNWF2aYWnS33r9Tjy3zsSHbcm65lfceGb6fZlWGaVmfSLLbxOt2SvIP47Z8RYZw9JMM3zc0fSvpP8jwXT2Xjdvhj5IcPktflmfi9txj2/YZdsBXZbjZwRcyfJ/QbMs+IsMUr2vH93DMxHPrNfbM8RntmVu+DqRlOIN92TievDPJf5zjdbfqVt4Tzx2Y4Qzb5WPfz89w58EDJ5Y5IlO35Z7lc3xzhulw141j3GkZvnJgx3X1I0OAabO0PyrDWaSZry04O8lbppZZlOR5GW4scfX4ezwrw3TMfXtvJxvrUeObhY2mqn6c5OLW2kLc1hPYghlvANiYXHPEBptjPv8BGaagnLLwPQJuq4w3ACwEZ47YYFV1aIb591/MMC3m3hku9Px1kvu2qS8TA9hQxhsAFoIbMnBrnJnhgsznJtk5wxza45O8yo4KsJEZbwDY5Jw5AgAAiGuOAAAAkqxjWl1VdT2tVNmqZ/nss2yN638X1O3utrRr/bTpL+BeeJef2/fM5oprV3Wt39L9S6hXttZ2WahivcecRdV3m7vfbqu71q9lte6FNqXF098F2cE1fb/f9d8uWrDNbVYX3/DrrvVbawv2R9h7vFnzu08X1rJFG/x9qRvFb9oVXevf1K7vWj9J9tqm7+/gmhv7niP51Q2XdK2fOfZx1nHNUd8Nd6sl6/zC4E3quAdPf/H6wrrbp/btWj83XLXuZTaxz+/fd2fx2ef8U9f6N9zYe+BYvWLdy2xMfcec7W63Z9f633hm3zC+9NF9D0it3n6nrvWTZPGZZ3Wt/4pXP6Vr/bdc9N6O1Rf6YFDf8WbRomXrXmgTetDSQ7rW/+7qvjeZvPr6C7vWT5I37f2ErvXPuLzv3+AbL3xf1/rJjbPu45hWBwAAEOEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSJEt6d2BtHrP0oK71933dv3Stv+rlF3Stv8PBV3StnySHPufSrvVf8+pHdq2/4qqTutbf0hyy3WO71t/mbx/UtX775F92rf/rD+/StX6S/HzFn3Stf/pl13Wtz8LZbus7da2/VfU9Pv7IJQd0rX/D4ta1fpJ8/oK+u+En3/i9rvWT/r+D2ThzBAAAEOEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJkiW9O7A2y2+6pGv9Fx786K71z7vqhq71v/yUr3atnyTnnfjgrvUvvv7srvVZWL+/bHXX+r+56pyu9T/4Pw/vWv9jv7yya/0kOXf1d7rWv/q6FV3rJ61z/S3HQds+tmv9D3/w2K71Dzr8qV3rf/noD3atnyTnf6XvPs4Xv9Z3P3tz5cwRAABAhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACRJlvTuwNr89NoTu9a/uO7Rtf439tuma/1F+7++a/0k+efnfaNr/Wuv/1bX+iysX12zuGv9RZee07X+C171+a71D/jUw7vWT5Kjf/AnXesfd8OxXeuvvmlV1/pbkrts27rWr4P+oWv93Zae1rX+l97/5K71k+SVP7uma/3rbzyza/3NlTNHAAAAEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJAkqdba3E/Wkrmf3AJsveROXeu/do//2rX+Sz9yStf6SVLXXt21/n3/9EFd6//0Nyd0rZ+sPr219sCFqtZ7zNlp2/v3LJ/37L1v1/p32/GSrvVvvGlx1/pJct6lO3et/6Lzz+paf+VvTu9YfXVaa7VQ1XqPN3fe7uE9y+fD++7atf77zu5bf987Ltif2pz2XtZ3H+fIn5/Ytf51N1zUtf5c+zjOHAEAAEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgSVKttbmfrCVzP7kAlizeoWf57Lj07l3rr7z6zK71n7r9EV3rJ8n7v7+ia/3jn7hT1/pHnPWRrvWT1ae31h64UNV6jzlVW/csnx23uVfX+ods99Cu9Z9zr591rZ8kK69a1rX+kWdf1bX+8qu+1rH66rTWaqGqbenjTWvXd63/4t2f27X+Dre7qWv9JLnHHa7sWv9fL7tD1/pvvPC9XevPtY/jzBEAAECEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJEmW9O7A2myz1c5d65/wgL71r73+MV3r33W307vWT5K25C5d619+/dZd6yfVuf6WZZut79y1/tv3uk/X+g+952ld6+/xsP/btX6SfPQDT+9a/9Ibz+han4XT2g1d699h6T271v/aqou71r/bol261k+Sx+15Sdf6Z69a1rX+5rqP48wRAABAhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACRJlvTuwNpcc8OlXevfebeLu9bf9dinda1/46qfdq2fJBcccXbX+n//y191re/4xcJaXFt1rf9nz/pk1/qrX3B01/ornrS4a/0kOfrCvv/fufLa87rWZ+EsWbx91/p/uOgRXet/Z/XXu9Y/bNcDu9ZPkpVXLeta/30rf9K1/ua6j7N59goAAGCBCUcAAAARjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIIlwBAAAkEQ4AgAASCIcAQAAJBGOAAAAkghHAAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQRDgCAABIIhwBAAAkEY4AAACSCEcAAABJhCMAAIAkwhEAAEAS4QgAACCJcAQAAJBEOAIAAEgiHAEAACQRjgAAAJIIRwAAAEmEIwAAgCTCEQAAQBLhCAAAIElSrbXefQA2U1X11dba43r3A7jtM94AC2muMUc4AgAAiGl1AAAASYQjAACAJMIRAABAEuEIAAAgiXAEAACQJPn/C1rwJ177YtkAAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = images[threes].mean(axis=0).reshape(8, 8)\n", + "avg_5 = images[fives].mean(axis=0).reshape(8, 8)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "clfs = [\n", + " RandomForestClassifier(random_state=0),\n", + " SPORF(random_state=0),\n", + " MORF(random_state=0, image_height=8, image_width=8)\n", + "]\n", + "for clf in clfs:\n", + " clf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_clf(clf):\n", + " if isinstance(clf, RandomForestClassifier):\n", + " return \"RF\"\n", + " elif isinstance(clf, SPORF):\n", + " return \"SPORF\"\n", + " elif isinstance(clf, MORF):\n", + " return \"MORF\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "output_type": "error", + "ename": "AttributeError", + "evalue": "'Conv2DObliqueTreeClassifier' object has no attribute 'tree_'", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0max\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclfs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mimportances\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_importances_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0msns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheatmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimportances\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcmap\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'inferno'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msquare\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m ax.tick_params(\n", + "\u001b[0;32m/opt/anaconda3/envs/ProgLearn/lib/python3.8/site-packages/oblique_forests/morf.py\u001b[0m in \u001b[0;36mfeature_importances_\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;31m# 1. Find all unique atoms in the forest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;31m# 2. Compute number of times each atom appears across all trees\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m forest_projections = [\n\u001b[0m\u001b[1;32m 211\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_vec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtree\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/anaconda3/envs/ProgLearn/lib/python3.8/site-packages/oblique_forests/morf.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_vec\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtree\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimators_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnode_count\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_vec\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Conv2DObliqueTreeClassifier' object has no attribute 'tree_'" + ] + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(19, 4))\n", + "\n", + "for clf, ax in zip(clfs, axs):\n", + " importances = clf.feature_importances_\n", + " sns.heatmap(importances.reshape(8, 8), cmap='inferno', square=True, ax=ax)\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + " ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MNIST Dataset\n", + "We visualize feature importances identified by `RF`, `SPORF`, and `MORF` on a subset of the MNIST dataset. We only consider threes and fives and use 100 28x28 images from each class." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(60000, 28, 28)\n(60000,)\n(10000, 28, 28)\n(10000,)\n" + ] + } + ], + "source": [ + "# from sklearn.datasets import fetch_openml\n", + "from keras.datasets import mnist\n", + "\n", + "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n", + "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, sep='\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(200, 28, 28)\n(200,)\n(1010, 28, 28)\n(1010,)\n" + ] + } + ], + "source": [ + "# Get 100 samples of 3s and 5s\n", + "num = 100\n", + "threes = np.where(y_train == 3)[0][:num]\n", + "fives = np.where(y_train == 5)[0][:num]\n", + "train_idx = np.concatenate((threes, fives))\n", + "\n", + "# Subset train data\n", + "Xtrain = X_train[train_idx]\n", + "ytrain = y_train[train_idx]\n", + "\n", + "# Apply random shuffling\n", + "permuted_idx = np.random.permutation(len(train_idx))\n", + "Xtrain = Xtrain[permuted_idx]\n", + "ytrain = ytrain[permuted_idx]\n", + "\n", + "# Subset test data\n", + "test_idx = np.where(y_test == 3)[0]\n", + "Xtest = X_test[test_idx]\n", + "ytest = y_test[test_idx]\n", + "\n", + "print(Xtrain.shape, ytrain.shape, Xtest.shape, ytest.shape, sep='\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:26:15.950727\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAEjCAYAAAD5QHrmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAwoklEQVR4nO3deZhcV33m8fdXvXdrae2LtdmSF1nGYGyMMYQYA3FiEh4YSCAkEEgymYGEhMmwJCSTQBIymQwBMhDIHpNglmCzhSwQTAwx4A2DV9mytVmyrKUldav3ruXMH7caF+WW3iPRbrWk7+d5+rF1++1bp6ruPXV/dW/VL1JKAgAAAIAzXelkDwAAAAAAZgOKIwAAAAAQxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAACcdiLiuohIEdF6Em57R0RcN9O3+1SpP47valq2IiJuiIi+xt9HxNyI+OuIeLy+/LqTMGT8ACiOTrKIuKq+80xExOKTPZ7ZKiLeEBE31Seb8Yh4LCL+OSKed7LHBpwqmG/yRMS76o/TVD/MOTjpTud9OSIW1vfBq56Cda9r2p8rEXEwIr4dER+MiKcfx+reJ+nH6/99raTP1Je/U9IvSPq7+vK/mM77gKfejL+bgCd5naTdkpZJ+mlJHzy5w5m1nilpj6QvSeqTtETSz0r6ekS8MqX0mWP9MQBJzDfH662S9jUte+hkDARocjrvywsl/W79/29+im7j85JukBSS5ku6SMXj+MsR8Ycppd9uyndJqjQte4Gkr6SU/nCK5fenlN45/cPGTKA4OokiolvSKyV9QNIlKia7kzLBRUR3SmnkZNx2jpTSm5uXRcSHJG2V9DY98Y4NgCkw35yQz6eUHjnZgwAazaZ9+RR2X0rpY40LIuJtkj4p6bciYmtK6e8mf5dSGptiHUsl9R9l+Z5pHKsiol1SNaVUnc71YmpcVndyvVzSXEnX138ui4gLJ38ZEe+vn/Jd3vyHEfG6+inhH2lY1hMRfxgRW+uXnu2JiA9HxIKmv705InZHxPkR8S8RcUTSv9R/97SI+JuIeDgiRiKiPyL+LSKeNdUdiIi31q8tHquflv7R+nXOO6bI/nBEfDkiBiJiNCJujYgfP7GHTkopDUs6KKn3RNcBnEGYb05gvomIeRHRcrx/BzyFjrkvT2F5RHy6vi8MRMT1EbG0MRARSyPiLyJiZ31/3hcRX42Iq5tyF0TEjRFxqL5f3RURr3UDjicuA7xqit997/NJ9d8/XP/V7zZc/nZdQz5r7jleKaVBFWeP+uu3HQ232fiZondFRFJx1ulnGsb4+vrysyU9t2H5VQ3reXlEfCMihuo/N0XElU2Px+vrf3dtRPzviNgtaUzS6vrvl0TEn0XEriguq9xRz3VM8bjeUp9nv1qfY/dGxB803reG/PPqc/Tkc/tQRLy/KdMaEe+IiAfq83BfRHwsIlad8AM/C1EcnVyvk3RnSukhFad4B+vLJl0vqUXSq6b429eouNzjJkmq7xQ3SXqzpM/V//txSW+QdFPzTiOpW9JXVLy78VZJ/1Bffo2kp0v6lKRfk/ReSedL+lpEXNC4goj4HUn/V9IOFWdvvqziXZdnNg82Il5RH1+XpHdLeoekJOkLETHV/ZtSRCyoTwybIuJ9ki5UcakdgGNjvjnO+UbSXZIGJI3WC60n3RZwErh9udkXJXWq+CzMxyS9WtKXozgbMenTKgqDj0t6k4p98ZAa9q+I2CDpW5JeKOnDkn5D0rikv4+It07LPZM2q5gjJOmzKj6z873P7ZzA3HNcUkpHVFyJslbSBUeJfaY+Jkn6ZsMY76j/t0/Sloblm+tjf0v9bw9J+k0Vlw6ukPQfEfHcKW7nj1U81u9VMYcNRcQiSbdK+klJH5X0KyqOgd4m6cYp1rGi/vu7Jf16fYy/JennG0P1efFmSZsk/Zmkt6h4fF/WkAkV28nvSfpPFXP2hyVdK+mb9bGdHlJK/JyEH0krJVUlvaVh2UdVXENcalj2kKTbmv52iaSypA80LHu7pAlJz2zKvlTFQcEvNiy7ub7sN6YYV88UyxZL2i/pz5uWjUv6uqSWhuUvrK97R8OybhWTxQ1N621RsaM+KikyH7cd9fUnSaMqJsyuk/188sPPbP5hvjm++UbFgcFHVBxwvkzFQcyApBFJl53s55OfM/cnd1+uL7+uvn98qmn5m+vL31j/9/z6v99mbvsfJdUa9wFJ7ZJuq78eL25YvkPSdQ3/vqp+G1dNsd7m7IZ69l1TZLPnnqPch3X13B8cI/M/6pmXNix70njqyz52lPtzS9OyVfVxv7dp+Zz6nHRLw7LX19d9t6T2pvyHVRRXq5uW/2r9b17UNI4k6WVN2bsl3d7w77mSDqso6Hqbso2vDz9VX9+PN2UuUfF5rKM+pqfaD2eOTp6fVbGRfbJh2fWSzlLxgt+47PKIWN+w7FUqPi92fcOyV0u6U9KjEbF48kfFuzyjTeuc9JHmBam4VE1ScV1zwzsBt0tqvNTlRSomxT9LDdfAppRuknR/02pfJGmRpI81jW2BpH9Vcar4vCnGN5WfUfFu83+tj6lbxTtiAI6O+eY45puU0gdSSm9MKf19SulzKaV3S7pSRYH13mP9LfAUy92XG32g6d9/KWlI0k/U/z2q4sD9qqO9+x/FpaXXSvpqSunOyeUppYn6+jsl/chUfzvNTmTuOV6D9f/OnYZ1Tfovktokfbxp3J0qzoQ9J4rPkjX6m/rjK+l7Z25epeKs+WjTev69Hmu+/4+nlD7XtOxmSY1z/ItVfDzh/6SU+huDKaVawz9fraIIv7XptndJ2jbFbZ+y+EKGk+d1Kt7F7IyIdfVlW1VU7z+nJzb0j6u4LOQ1kn6/vuw1kh5OKd3RsL4LVFxCcuAot7e06d+HUkoDzaGImCfpD1Scsm3+7MH2hv+fHPPDerIt+v5LXSZPTX/2KGObHJ/9FqiU0jcaxnqditPLX5D0Q+5vgTMY882Tx3dc3zqXUro/Ir4o6WUR0ZVSGj2evwemSe6+3Oj7tvWU0ngUn9M7u/7viYh4u4rCf29E3KHiMtiPp5QerP/ZEkk9ql8i1uSB+n/PPsH7dDyOd+45EZNF0eAxU8dncl769jEyi1ScnZ60ven3S1R8k9+rNPXlz9KT7//OKTKH6+uZdG79v/ccY2xScR9W6eiP/Wnz2UyKo5MgIi5VcV2n9OSNX5JeHhFzU0qDKaVHIuJ21Q9W6pPhc1QcwDQqqXjn5HeOcrOHm/59tBf2T6io/t+vJ663r6m4PrbxnYbJD/OlKdbR/EG/yTOUb5R0tG9+uu8oy48qpVSJiE9L+qOIODelNNWBE3BGY76Z0nHPN3U76+tfoKPfJ+ApcTz7ctPyqfab7w+k9KcR8VkVl6e9UMWlZb8ZEb+UGr617SjrOtb+mTOG4zmoPt6550Q8rf7f6TymmJyXXq7irN1UmouO5jlmch2fVXF53VSavyXveL7dzm0nJRVniP7bUX5/2syJFEcnx8+puH7+NSoOBBqtULHRv1JFAzGpOGX+pxFxiaQfqy/7eNPfPaLiWtGvnOigIqJXxWnzd6eU3tX0u99vik9OzOdJ+k7T785t+vfkAcqhH2R8RzF5Sd0P9C01wGmM+Wb6bFBxbf2haV4vkON49+VJF6goKCR970sN1qn4UP33pJQelfQhSR+K4pvfviXpPfX1HZA0rOJLkJpNnhXZcYyxTxYtzd9m2Vkf+/cN5Rjr+YHnnmOJiPkqCpidkh408eMxOS/tSSndfoLrOCDpiKTOab7/k0XgxSouWTyaRyQ9X9LNKaXmnk+nFT5zNMMiok3FdZs3pZQ+U7+evfHnIyo+nNf4zTOfUlH9v0bFt8nckVLa0rTqT0jaGBGvmeI2WyJiYfPyKVRVTErft11E8TWUz27KfkXFNcq/HA1fcxsRL9QT72xN+pKKifGdEdE1xfiOeRq8/tWRTxp/RMxR8RmkET35cwfAGY/55vjnm3pmqvnmCj3xmYupep4AT5kT3JcnvaXp37+k4osAvlhfd3fzvpJSOqyi2FlQ/3dVxWf2rm781sb6uH5NRdH25WPchR0q3li4umn5r+rJZ44mz6xM9abndMw9U4qIuSreCOqV9Hup/m0D0+QGFff/XRHxpBMTOfNS/Tn4R0nXRMTzplhHZ/0+HK9/V/H15e+oF4eN62w8M/8JFZccvn2K2476549OC5w5mnnXqrhu9PPHyHxBxUHA2pTSzpTSvoj4iopTmXP15IlOkv6kvu6PRdHLY/Jdog2SXiHpt1V8c81RpZQGI+ImSW+vT5RbVJxefoOK4mNuQ7YvIv5IxantmyLiRhXv/rxR0r1N2cGI+EUVB133R8Q/qPgA3wpJV6h416nxEppmcyTtiogb6uM4pOJdr59Tcf3rrzR+sBvA9zDfHP98I0nb65cYPaDiUr+LJf2CioO2Xzd/CzwVjntfblh+fkT8k4ri5kI9sd/8Tf3356n4OukbVGzzQyrOEFwj6W8b1vNOFV94clMUTdgPqCjYrlDxTXd9RxtYSulIRFwv6U31A+77VFyy+0Mqvl2yMbsvIh6V9OqI2KKin+H2lNJtmoa5p+6iiPhZFZcEzlUx9/ykis/ivCel9LfH+uPjlVLaERH/U9KfSvp2RPyjim/lXKXim/xqkl6QsarfUPHcfLX+uevvqLiC5vz6+F+h4gsXjmdsgxHxJhVf8353REx+++E6Fc/v5Hz5cRXf3vmeKHozfVXFpXRn15d/XNK7jue2Z62T/XV5Z9qPiu+hr0lacYzMi1S8o/rbDcteW19WkbT8KH/XqeL76+9T0TCsX8VXNv6xpDUNuZsl7T7KOpaq6EGyX8Up9G+quP74OjV8XW49GyreQdhZv727JP2oindINk+x7stVXCvbp+JdpkdVTOavNo9Zu6T31dd/WMXXCu+r/+2LT/Zzyg8/s/WH+eb455v63/6lioPHfhVnrHapuLTonJP9nPJzZv6cyL6sJ77Ke5WK/jQDKi7L+oSkZQ1/t0jS/6vvy0dUFEf3qug31PxV0htV9Oo5XN8PvyPpdVOMZYcavp67vqxXxQH4gIovO/i8in5CU2Wfr+ISr7H6fbiu4XdZc89RHqN1eqIdSFJxBvtwfT75oKRnHOXvfqCv8m743TUqzrD118e+XcU3D17TkHm9mr6We4rH8Y9VXA43Xp/jblfRcmChG4eKAiZNsfwFKs4iTbYteFDSnzRlSpJ+WcUXS4zUn8cHVFyOeeHJ3k+m6yfqdxaYNhFxt6R9KaWZ+FpPAGcw5hsAwHTiM0c4YUe5nv9FKi5B+erMjwjA6Yr5BgAwEzhzhBMWEa9Wcf39F1RcFnORig967pd0cWpqJgYAJ4r5BgAwE/hCBvwg7lfxgcw3SVqs4hraGyS9kwMVANOM+QYA8JTjzBEAAAAAiM8cAQAAAIAkc1ldRHBaCTiz9aWUlszUjTHnAGe2lFL41PRgvgHOeFMe45jPHDU3LQZwZqnu9JnpxJwDnLmqM3x7zDfAmW3qYxwuqwMAAAAAURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSp9WQP4NQXmTlfh0b2un5wSSkjVZ3GW5y5+za9Nf90PgbA6aolI1PLyEzfvpszn+bNgzly7ltxi05kvCzP7Lin67ZwasrZt6WItqd4HE9IaTwjlTOX5O23kfEYzOwxVd7xVM64Z9LMH3eeOM4cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkEQTWMM30MptfFaKDpvp7TzHZubHcptZVltqM3MzxrOwPW/zWNnlM+0ZvciqGf3BBsu++dmjwxW/IkmPaI/NbB+/1WbKlQNZtwdMj7zGfi2lOTbT1jrXZiLjPbTu1kVZY3IqWc0dpZ6Sv71uzbOZWkYTyNaMl8nu5B9rSVqQemxmSP4x2NvymM0cmNhiM+Pl/TaT0oTNYDbyr5XT2bg1pbLNzOvcYDNLW3xmQW2hzaxs9fvkhrl5zVRXdvljinLNryun5exjI36++dehbRlrkvZWNtvMZaUX2kxZ/v5v1rdt5sj4Tpup1QZtZiZw5ggAAAAARHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAICkU7IJbE7TLl/ztbb4BoFz2s/ymZbFGeORNtbOs5lN831Dtot7h23mwiV7bWb9+Y/YTPeyQzYjSbWy34yqY+02M7jPN3YcPOKbVt6/e43NSNI9h1fZzNf6Xmoz30yfs5lqdcBmUkajNZy6Ivw+EPJzQGuL3wckaU77MpuZV/JNpddU/X6yqsM3lT6r20Z03txRH5K0omfIZrpafVPKCN95enC802b2jvqMJFVr/vYOT/gHatfIRpvZVvPNNDe3+9eB3UNftxmJ+Wv6ZDSfzzgOasmYJxZ1nm8zXTHfZiRpk862mcszekWfO9fv22cv7POZ9XfYzPBA3lx63/ZzbOaxYd909r4BP0+MV/145qVeH5K0Y8I3qJ/b419zrlrqt8lHR37YZu4c8M/trWOfshlJSmksK3eiOHMEAAAAAKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABA0qxrApvR/Cx8w6r21oU2s7TjAptZmdH8cGNHXhOxixdM2MyVK7f627vsHpvp3rTPZtK61TZTXXiuzUiSOntspL3vMZvpefA+m1ly0DfSXLLVN8GVpHUP+8ZuNa23mYMHr7KZB8e+YjPVmm8Ui1NXKbpsprXF70s5c5cknV31TRlXtPjmrRsW+Mal63p889aLlvr9cuMz/fwmSZ0rD/pQ1b/3Vz7iH+/auH/NGevPex3YvcPPu9v6lvoxyTcxH634cR8aXWkzj5X8ditJqTaYlTuT5TSCbin5bXJ+xzqbWSnfBHhd+K6sZ3X74zJJWj/HN12+et02mzln0xab6VjkXytH9/r79vCuvIbxB8b8PtBf9ofTw/4hUt+E7wI7HnkNUJd0P81m7taDNrN2eJPNrOjy4z573G/bd5bz5tJyhSawAAAAAPCUozgCAAAAAFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEmzrAlsa0uvzbRlNEnsaVtiM8uqvvndoozmd4s6fINESVrdPWwzSxb7xoblYT+m8q45NhN7/G2V5vrGrZJUqmXFrFrJ1+q1Md9Er7XTN9yVpGXL9tvMWXt808a1sdhmHsnYbmkCeyoLnwi/ffe2+6aEG6q+ebEkrej0+8qqbr+es7r8/vS0pY/bTE4D6/aM5o65WhaP20x11DfBHdzlm7Ie6vMNJyVpIqMxa47uFj/pzm/3zTs7Rv14ejvztrfDow/ZTEpPbePG2a673b+ezGnNOH5Ja23mnJaFNjNR89vR4fG8Y5zuXt8ENCU/T97/Hd+4dMtB/xhdv9Nv2+PJj1mShsI3ud4fW21mU9poM2u6/bhXl/zxqyT1T6ywmU/2f8RmHqqdazMbMnq39rT6OWltx+V+RZK2126xmR/kmIozRwAAAAAgiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACTNsiawOVpLvmlfjr6WAzbTUfXNuPaOZnRRlPTgkXk207Jtvc2071xnMxNV32hruOwbRLZEXnfXtQsO2cz6DdtspjujAWR13D8nY0d8w1VJGhr0zXIHK/6xHKpVbKYUp9yuhuMyPe8ztUWnzXSV/DYpSe0l33CxM6OZ6IbewzZz1krfBLbUUbaZ6rC//5I0vNc3Xt6xdZ3N7D3SazPVjMaV4SOSpJJ8Q82c15SBst8GxjL6W7bLr2dZZDaB1eas3JmslNEIujP861JKfr/dUe23mbJ8g+eBqm8YL0l9u32j0DsOnmczNw5/2WZWtfjjl2t7fQPUavLrkaS/67/VZl7V8wKbecmqvowx+cnksWG/jUjSdw/74+W2Vj+X1pKftw5mHJt1Zbx0bQzfKFmStinv+PREceYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJAktZ7sATSq1cZsplwbtZnIqPkOh19PqeTXMzHmuwtL0kRtrs3sGFpkMzldz4cqvnNwe8l3YT6rO6/t+7yOcZtp6/CduEvtPjM+0GMzlfG8rte7+5bazPYh3/X5cBz2Y8rYbnHqCuXtK349fs7pasloMy6pN2M3WN7p97n1q3bZTFuXnwNU84/RwO6Vfj2SvnXfRTazO7OLvFPOGPeRct5zMljxz+/+Md+Nvm+8YjMDadhm+ksDNrOn+oDNSFJK5azcmWys0m8zwy3zbKZa8s//ULXPZyYet5nutrxjnFuqW2xmXTzLZv5+46U2c9Ga7TYzd+Fmm/nirVfYjCQtOrLWZt70DL+fLFnqn5MD+/3jvXahX48kTdTOsZllY5ts5taJL9tM6nuxzXSXfMnxXeXNN7XaSFbuRHHmCAAAAABEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgKRZ1gQ2x0TFN62bSL4pZ6nkOySOxAGbOdDaazOStLOywGZ6K8tspiLf/G1O8k3kntnuG409Y8GQzUjSFZfdZTPzrnjUZmp9vlYf2b/QZrbvWm0zkvTVPctt5u5B/xjskm82V6v5ZpuYraanwWst+X13pObnrtGU0Qla0oJ2P+6Ll+61mWUXbLOZjrX9fkAdfjy77zvXr0fS/tFum9kx7Of4nIarhyb8491XG7QZSepr8a8pQ8k3eByvHbGZiapvAlvNaE5drfn1FPxjeaZLyTdoHxj3r5UHq3nbm9PV7l8Dn9dyVda6fnilf/5fct6DNrPh3b7Ba3nFBTbT9skv2Mz7P+UboErSP1zq567z3uebIKeMx3vp3d+0mcNfzmuWrZ2+CeyKqj9e2lu922a+Vr7eZjoyGgpPlP1rYCHvdfBEceYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACApFnWBLaWRnwoq8+cD9UyGuTlNH+sVH1TWkmqtI3bTFe7b956sTbYzHOX+vt/zXrfjG3TT3zNZiSpduXTfGZsvc3EwFabObxvic3csTevQdodh3zTti1xj82MVfptppb884/ZKuM9pPCZUvjptiXabKarpcWPR9JZ3X4+veiy79pM+zW+8fLE0qf79XzdzycDg3NtRpJaSr6ZZlvJz4MDZb+e++IhmxnUfpuRpMHhnTaTMuaKNG0NV/39p7nr9KlU+zNS/jkplebYTFuL35fWtl5iM5ctynsP/YVrH7GZ897mG9yOrf4hm2n/6I0284r3vN5m/mSTb8osSc/8kG9M2rbpv9vMRMUfL6Y9/jjosR1rbEaS9oz615yBUr/N9LSvsJkjY75Z+NjE4zbzVDd3zcWZIwAAAAAQxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQRHEEAAAAAJJmWRPY2ddszteObS2+QaIkndP6LJt5yfxlNvPy9dttZuMzfePSzp9eZDOVS95nM5LU3jrfZsr3/rnNpH7fIGz3/qU2c/+Ab6QpSbvCN6QrV30jzUpGQ+GUxrLGhFNTyG9zOfvJPPkmx22Zb2ltXOwbk3b+jN+fSpe+xWYqh++0mXb5JrArV+y1GUnq6vT704aBXps5q8s/3r0HN9nMvTU/d0vSo53+yRsY2+JXlHIaJdLgdfbxz1tkHJblNJSem9G4c0nNH7/0T+RNOC0tfnuLR3bZzKG/uM9m/tdX3mAzF8732/bVr/uczUhSreKb5aa/+nmbGblrnc1843Z/rPixbf74TZK+VX3AZsryc+lo+ZDN5DSvPpXmG84cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkDTrmsDOnJxGa+1tvkHi8o4Ls27vwhbfJPDH1uy2mUuuvdlmai+51GbG1z7fZnoymlZKUjmjCerEsgtspnvtHTazeH6/zSzvXGkzkrRo2DfA21UbtZlabSLj1iIjc+o0SDuzZDTcDP8+Uy1VbKY7zbGZntacbUnqaC3bTFR8Ztqs9Pvb8qvuzVrVsuEWm6kOddnM+VvW2sxZ911kMx17FtuMJKWRjTYz2uobLo6XfcNs5pNTU5KfJyKjCexI5aDNHG45YjNbh/Kaqt++e43N3P/B19jM53fNtZmvVfyxwjsXPt1m7r7xhTYjSR1f9A1O+4d+zGY2H/RNp2/r67aZhzOaskrSUOqzmSNj222mVGrPur3TCWeOAAAAAEAURwAAAAAgieIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASDqTm8BGh810tPhmZEuqy7Nu7+yMfqqVmm9sONHXazNd2zbbTNse3/grVXwzOkkqzeu1mfY5PpNWr7KZ8190q81c0eeb90rS1kHfuHFbxTeUPVD2zfZC/rnNaf6H2SllNAuW/CQwHmM2M1rNa+45OOaboOq2h2wkPfyLNuPbFkrlbT5VHV2UsSZJJf8YtC0YtJmlP3S/zTy76t9DHKlebDOSNLjbN4Hsi7Nt5kDG9lapHs4aE049larftnM81rLFZvanjHlEUvlRvw+s6/FzwA1H/spmrul+g830l32z7H/f6fc1SVrU4Ru9j1b84fQDA76Z6oqums3URnxGksaq/TaTkm9wW6n616WcY+qU/HpmC84cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSJN/S9zSV0xU4yXchHg3fqVySHhuZbzNf37PcZnb+8zU2s/Y/B2xmPKOb81hGRpKW9AzZzNMvv8tmep6212ZKGc26z1nq1yNJKx733eqXDqy2mSNte2xmrJwxptTiM6pmZJArpmkKLJV6bKZWq9jMvtpWm9kztjhrTLc8vtJmxq9/qc1Uav49tNFym19P8uupJt/VXpLmtfv5e82iAzZz/vPvtJkF6/z+ffbuVTYjSasPzbGZNf3rbWaobb/PVAdtJslvk5g+Ee0ZKb+ftLcutJmU/PHLoZH7bKa1ZZ7NSNJdnWWb2T98ns2kNGEzfbVhm/lEn9/+x2PMZiRpYXWRzbxooT84OTiebGZgws+BC5OfR3LlzQF+TCn55z9n+895/mcCZ44AAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIOk2bwOY0moro8JmM2nGgdDBrTLdN+GZj9xzs9ivKuLlaxrhzquI5KadhnXTRnDU2U6n6BqfPXf4lmyn1TF+DsAXtvqFqd/KN3ToymuSNlX3TxshoOuxbyOG4hN8T5nSstZnuFt8kcKTqd972kp8DDqR+m5Gkbx3wzWI3DyyzmbGq3+rGqn7bbSv5x3pZZ977devn+oaDPW1+rhjv8825S+3+trozmtJK0pw2/1i2ys+VXS0LbGY4fPPalGgCO5NCvlnysu5n2EyX/GvOnvK9fjzhx1PL3EZaMu7bkei3maU9l9vMluodNtOW/Fw6X8ttRpK26R6b+Wi/b8z63JZLMm7NN1w9FEMZ68k7NpmIjAbmaSTj1nIa1PttZLbgzBEAAAAAiOIIAAAAACRRHAEAAACAJIojAAAAAJBEcQQAAAAAkiiOAAAAAEASxREAAAAASKI4AgAAAABJs64JbEbzu/azbKanbYnNDE48bjMRfjwjtcM2I0mDyTcBHav020y15psN1pJvfpjT4La91TdIlKSFo9fazNCEb7pbHvBN1No7+22mlnwTNUkar/rHoCujaVl3yTdkHMxoOpzbbA/TJyXfuC6n4eZ8+TlndWywmbaan3OGwzeUlqSt2msz5Qk/V4zJNxzMmQe75R/HjcMX2IwkrZ/r9/HVy/z9b18waDNjfb02c2Q0o4G3pImaH/f8km++3aqMJuaR8fJOV+kZldNM8+D41mm5rYvarraZoTa//R9Mu7Nub7zm54mWkn89zTk2yTk26wh/PHGefAN7SRrSCpt5IH3bZr5de9hmXtBxvs0sK+cdm+3IeH3L2SYj4/glpbxG2KcKzhwBAAAAgCiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQNMuawLa2zLOZRe3n2Mzy6iqb2dXum5El1WxmYCKvQVq56puttZZ6bKatxWfKvu+XInxdvKBtrV+RpNXdviHb2Ut9Q8bWbt/ccmJfr80MjvjHSJJGqr4hYyWjS2IpoyFdzraEk8HvLKNV3+B0U1xsM1cv89NtZ8lvJzfv77IZSeqr+uZ+B0oHbCZn281phHxO7VybeXqvn5cl6TkrHrWZxWt8o+/KkG/eOnbEzycD43nPyXDFzzkt4TOljIbpzDmzT0+HP37pzDgOOjRyr8381Oq5NtPb3mkzN+9bbTOSdCTjwOOeeMhmDk1st5n5bf4Y72JdaDNrunMPgX3utsN7bObitstsZrDsjzkqafr27ZwGrxEZx8s0gQUAAACA0w/FEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABImmVNYNtb59vMpbHRZs5b6JvoHRi71GbuHzliM7szmslK0mhtwGZymvaVa76x45z2FTazJjbZzBXdi21Gkn7qHN8Id8OV37GZ1qX+vo18Z4nNjE2024wkTWT0USurYjMjNd8kNKVyzpAwCw2N7/Ih30tRVyzbZzPrlvtGght3r/E3JmnnoG8m+djI2TYzUPbz6fw237hwWdeEzayf55vSStKG9dtsJtUymjyP+bli5Mgcm9k/mtcE9oh/CDRS8800h1Jf1u1hdlnQ5huq/sqS823m3v4rbWbHsB/POy6902aevbrXr0jSrkOLbGZz/3k2M1zxx3g9rX4fOTzhGyVvOZLXTPXr1dtt5iVdr7SZdXP8+YjtQ35M81vzjjt7av4Y7kjyjXlzmsCebudaTq97AwAAAAAniOIIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACAJKl1Zm7GdyqXpPaS70S+rNPXc5csGLCZrraKzWwc8B3m7zz4dJuRpL4J3xq9LN8ZuVTyj+WGnk6bed4S3z77mmd+3WYkaemLN9tMbf16m0l3bbWZg3uW2sy9B5fYjCQdGEs2s7dlj82MTBzMuDX/3Cb58WC6+S7qpeiwmZx9d+/QXJvZkDEvPefZd9iMJF1Z8tvT6GE/xw0d8eMul/1LycLFfj9p7fTzpCS1tJdtJjLuf6r5+fSRx1bbzJYjXTYjSQNlv51sKfl5cHhsn82kNJ4xopzXZuYlJzIPpTpTj82snztoM09f5J/bX9/st7Xq166wmbdfeafNSNLzL3rQZp6VMd8c3Odfv2/aep7NbB7w2+0t1VttRpLevOTZNnPxgsM285XHF9nMvDZ/jLt9dNRmJGmg9pjNlEp+flfy21LWMU7WnDQ7cOYIAAAAAERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAEkURwAAAAAgieIIAAAAACRRHAEAAACApBlqAhvRlpUrhW/IWM3oR7eg0zfIOn/tTpu5cr5vxvbaibyHMFp8g6ycpoW1sn+M5p51wGY6z/ENGWsXbrAZSSovfK7NtD18j83s/dYmm/niA0+zmW/sz9ve7q76bWB/eYvNlCt+O8lrolbNyGCmpYzmdvtLfp/bcmSVzZyzb5nNLF6/y2YkqfM8v4/PzWg4uHQ0o7lfTu/WjKkyjea9X5cy5t2xvQtt5qHvXGQzdxzwTSn7xvIanT9a6beZ/uSf32p1xN9Y1pyD6RClvCbAOSrJ7wPP2PSAzfx119k2c+Mj/pjj/C/cZzOS1NvljxdK4ffbWhqwmaGJL9rMRW1X28ybFvsmuJK0qtsfU24f9A1uD437x/tbFX/MsXP0P2xGyjv2TmksYz3tGbeW0yzaH7/OluMgzhwBAAAAgCiOAAAAAEASxREAAAAASKI4AgAAAABJFEcAAAAAIIniCAAAAAAkURwBAAAAgCSKIwAAAACQNENNYFPKa+o0Xh2ymQeHffO7O/t8076ejnGbOT+jCeyiZzxsM5JUepofU3mVb6JWa/fN5kpjC2ymOjZsMy2H9tuMJMVd37CZ737WN2T7p23rbeb2g76x4d2622Ykqb/smy2Wq/5xqiXfIC7NksZmaOafl1Tzz++O8Tts5psHl9tMxFqb2TPQazOS9OwdvlHkonMftZn2c47YTMzPaBI44OfciT7fSFGS9j7g54qHd6+2me8cXGwzu4Z948Kto34bkaQdcb/NjE4cspmU/GOZspoyYjrkPB+SNB5+O7ll/zqb2bjHzyVLF/om0G+83GfeFL45uyTt2ufH1D/WbTNbBvwc8Piob27a2eK3/zltFZuRpNGKP1T+/B7/WvLNiRttpqO112aW9TzLZiTpwKifb7Lmkqwu3zkNXk+dxtScOQIAAAAAURwBAAAAgCSKIwAAAACQRHEEAAAAAJIojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACBphprA5jRalKSxcp/NPNq902a+ceBcf1tV3yDw8EiPzVwy7JuyStLS2oM203bQN2TTYNlGasO+5h3ducxmdjxwnh+PpK9uv9Zmbjvgm7bdU3ncZh6v+aa741XftFLKa/Cakm8Sl5J/TkRDxlNWTjPNStU3jP6u7rSZUp9v7rd31Dd5lqStg8+2mTX3X+QzvYdtprPNNwk8NDzXZgbGOm1GkrYNzbGZR4f9nLN7xDcl3FLb49eT7rMZSRqd8K9x1YxtKa/Ba07DReal6ZDXJFM6MLHFZu6tnm0zn37wApu5qNdvR0u7h2zmgg1bbUaSrrj2P2ymNu73yVW3XGYzu/oX2syhcT+XfG2fP8aTpC9lNPnuG99sM4u7NtrMeM0/J/uGb7MZSYqMxqwtJd90t5oxptNtvuHMEQAAAACI4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIojgCAAAAAElSpHT0pkwRrTPasakUviHXvE7fIG1hyTd4nZN846sl8pneVt/UTJLmt4fN9ExTS9628E/bUMWP57HRvOa929I+m9mXttnM4IRvtpjT1CxXtTZqM7XkM7kNAE9N1W+nlHxXvmky03NOHr+v5GyXER0209G2yGbmti23GUmapyU+U5tvM93yzRQ7wt//2jFeayZNZDYMP1DyzVT75eelnIaLOU2lxzMamEuScppKy2dOX1WllPwON01mer5pa/X75NqOy23m7HSWzcxt9ftkT6t/f7wl89nIyVUzHu3Bsg/lNEHuK/vX5YdLD/oBSTow9oDNVKq+WXbO81+uHLCZjraVNiNJ1dq4zdRqYz6ThrNu79Q09TEOZ44AAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABImmVNYHNE+IaEEb4xaynabaal5Js2tpT8eqTpbV7q1FLZZiYqAzZTrfnmh4WZ20wi43lTqmWt68xutpiLJrCnqqx9JUvGe2gZzU1z5DR3LPh9PKsxb6nLZlJGs2jmkulyejeBzdHassBmSjkNpVvn2kx7aY5fT/iMJI0lf7wQGXNJyti3a8k3ix6aeMxmchq3FnKafPvjzpZSj81Uqv3TcluSlJJv8AqawAIAAADAUVEcAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkCS1nuwBHK+cjr85Gd+DWar4JsyYYSlNnOwhAKeE03tf8R3rc6TaqM+oMi23BeSoVA9Py3omKnunZT2QIqZnvqnWhnNSNpESB6dPNc4cAQAAAIAojgAAAABAEsURAAAAAEiiOAIAAAAASRRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkHQKNoEFAJzpMholzsAoAJz+Tu+G2pgKZ44AAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkERxBAAAAACSKI4AAAAAQBLFEQAAAABIMk1gU6rETA0EwOwTEf82k7fHnAOcuZhvAMyko805kRJ9xAEAAACAy+oAAAAAQBRHAAAAACCJ4ggAAAAAJFEcAQAAAIAkiiMAAAAAkCT9f+fwy/DnOu16AAAAAElFTkSuQmCC\n" + }, + "metadata": {} + } + ], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n", + "avg_3 = X_train[threes].mean(axis=0)\n", + "avg_5 = X_train[fives].mean(axis=0)\n", + "diff = np.abs(avg_3 - avg_5)\n", + "\n", + "axs[0].imshow(avg_3)\n", + "axs[0].set_title(\"Average 3\")\n", + "\n", + "axs[1].imshow(avg_5)\n", + "axs[1].set_title(\"Average 5\")\n", + "\n", + "axs[2].imshow(diff)\n", + "axs[2].set_title(\"Absolute Difference\")\n", + "for ax in axs:\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "RandomForestClassifier(random_state=0)" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "clf = RandomForestClassifier(random_state=0)\n", + "clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def rename_clf(clf):\n", + " if isinstance(clf, RandomForestClassifier):\n", + " return \"RF\"\n", + " elif isinstance(clf, SPORF):\n", + " return \"SPORF\"\n", + " elif isinstance(clf, MORF):\n", + " return \"MORF\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:26:17.992359\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT8AAAEYCAYAAAAqD/ElAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAe20lEQVR4nO3df7RdZX3n8ffnntybhJ/htyGJkmKqjbVFtJC1tFNHZUqoNbhaXdCWX8tZEYdUra7VZmzXSJ22g0yFJQ4TGsYo+IuiwhBpnIjUH6VDMMhQJITINQS4yTVBfoSQm+Tec853/jj74uFwzn72Se5Nbs7+vFh7nXv2893Pfk4Svvd59n7OsxURmJmVTd+hboCZ2aHg5GdmpeTkZ2al5ORnZqXk5GdmpeTkZ2al5ORnZqXk5AdIerukaNrqkp6TdJekcwrEN2+/KHC+LZLumZxPc3BJ+pikSw91O8y6Ne1QN2CK+TzwfaACzAc+CPwfSb8bEd/NiW+2dzIbOAV9DBgEvniI22HWFSe/l1sXEV8efyPp68DDwJ8B7ZLfy+LLQpKA6RFRtkRvPcTD3hwRsQH4BfDayT7X+FBY0psk/UDSiKSnJC3Lyk+TdIekFyT9QtKnJfV1qONsSf+a1bFV0n+V9IpfdJIWZ3G7s3rXSjq7Jea0bDj/N5IulbQB2AdcICmAOcDvNA37t2THDUi6UtJ9kp6RtFfSBkkfzpJn8zmuzI59o6RrJe3I2v5tSa9p0+4TJX1W0uOS9kkalnS7pDe0xP2OpO9I2ilpj6R1kt69X39B1nPc88shaRZwHPBYh5CjJJ3Ysm9XROzbz1OeCqwBvgzcClwCfE7SCPBfsrK/AM4H/hz4GbCyTR3fBr6abYuBvwJOBD40HiTp/cAtwCbgSmAAuBz4gaR3RsS/ttS7BDgBWAE8DTwKXARcB2wH/jaLezF7PQZYBnwd+AoQwH8APkvjz/Sv23z+VcBzwKeA2TSG1F8Gfrup3ScB9wGvpjHU/jEwC/j3wJuBDVncHwD/CNybnasKXAislnRhRPxjm/NbmURE6Tfg7TT+5/wojSRxMnA28J1s/4c7xLfbLi1wvi3APW32BbCkad8JwB6gDvxp0/5+YBvw4w51fKhl/zey/Quz99Oy458CZjXFzQV2Afc37TstO3Yv8Jo2n2UI+H6b/RUaQ+PW/V8AXgAGmvZdmZ3j9pbYjzW3O9v3v7J9v9embmWvR9DosX+jTZvWA0+Ox3or7+Zh78tdS6NXsx1YB7yNRo/rcx3irwHOadnWHsD5hyPijvE3EfEMjZ5ZADc27R8DfkT74fguGjdiWtsJMD7kewuNntU/RMTzTfUO0egtvlnSqS11rImIJ4p+kIioRdYDljRN0nFZL/lu4GjgdW0OW9Hy/nvZ6+lZPX3AH9L4xfFPbc45vkTRu2j84vhyNkQ+MTv3cTR6xfOAXy36Waw3edj7ctfQ+J9jBo3e3Z8BRzX9T9VqY7S/C7y/2iWX54Ed8cqbC88Dx0iqREStaf+WiBhtid2Uvc7PXk/LXje2Od8jTbHbmvY/3rnZ7Un6E+DjwBtp9LqaHdfmkNbP/1z2enz2ehJwLPBQ4tSvz15vz4k5mV/+uVgJOfm9XHMyu1PSTuBTku6JiG8dhPPXutwPoJb33SzQ2C5WHcr2dFEvkt4HfInGL5PPAT8HRoEzgU/T/mZbp8/Z7Wccr/tDNKbhtPNwog7rcU5++f47jbl+V0v6p4ioH+oGFTBf0kBL7298iDnee9uSvS4Ebms5/vUtMSmdEtEfZed7d/Ofm6TTC9bbztPATuA3EnHjCe/ZCe6ZWw/xNb8c2VDzGhoJ4Q8PcXOKOhr4QMu+j2Wvd2av9wPDwFJJx4wHZdf5/hh4ICK2UcyLtB/CjvfiXuq1SZoJfLhgva+QJdFvAL8taXFredMUmrU0hsyfyM7ZGnfy/rbBeod7fmn/APwl8JeSvp5z/W+qeBz422zO2yM0prq8G7gxIh4BiIiqpI/SmOqyTtIqfjnVpR/4SBfnWw/8iaRPAj8FXswuEdwO/AGwRtJtNK7bXcovp8Lsr0/QuKGxWtIXgAdoJPx3AF8Dbo6IXZL+I42pLhskfYnGne3ZwCIav8wOpAdqPcDJLyEidkv6HPBJ4PeB1Ye4SSnbaMxnu4ZGD/A54O9otP8lEXGrpBdpJPZP0eip3Qu8PyLWdXG+/0zjzurHaSShJ4BvRcRXJB0P/CmNuX3DNKa5/F/grv39cBGxI5uIPf73cSmNaS330ujRjsfdJumtWfuuyNq2HXiQxme2ktPU78hYUdm3K4Yi4m2Hui1mU52v+ZlZKTn5mVkpOfmZWSn5mp+ZlVLu3V6p35nR7DAWMdb67Ziu1PhKVzmgwh8f0PkOJk91MbOO6vW8b1a+UuUwupDm5GdmHUVUD3UTJo2Tn5l19PIFg3qLk5+ZdVR3z8/MysjDXjMrJSc/MyulqDv5mVkZuednZmXkYa+ZlVN97FC3YNI4+ZlZR+75mVk5+YaHmZVSDye/w+hryGZ2sKm2t6utUJ3SuZI2SRqUtLxNuSRdl5U/JOnMbP8MST+S9G+SNkj666Zjjpd0l6THstd2TxR8GSc/M+usXu1uS5BUAa6n8VTBhcCFkha2hC0GFmTbUmBFtn8f8I6I+E3gDOBcSYuysuXA3RGxALg7e5/Lyc/MOpvg5AecBQxGxOaIGKXx+NQlLTFLaDyCNLInCc6SNDt7P/7o0/5si6Zjbsp+vgk4P9UQJz8z60hR7W6Tlkq6v2lb2lLlHBrPUB43lO0rFCOpIulBYAdwV0Tcl8WcEhHDANlr8sH0vuFhZp11uZhpRKwEVuaEtFvpuXW16I4x0Vhj6wxJs4DbJf16RDzcVSMzTn5m1pEm/m7vEDCv6f1cYFu3MRHxvKTvA+cCDwPbs6HxsKTZNHqGuTzsNbPO6rXutrT1wAJJ8yUNABcAq1tiVgMXZ3d9FwE7s6R2UtbjQ9JM4F3Ao03HXJL9fAlwR6oh7vmZWWcT3POLiKqkZcBaoAKsiogNki7Pym8A1gDnAYPACHBZdvhs4KbsjnEfcGtE3JmVXQXcKukDwJPA+1JtyX10pZ/eZnZ4O9Cnt40M/n5XOeCI137LT28zsx7Q5Q2Pw4mTn5l1JCc/MyslJ79ykIrc/M6PmaglgJJtyblW+1IdfdNzy+v1Yt/FPBjUdmrXL8UrpoLZweCen5mVk5OfmZWRe35mVk5OfmZWRqr6GR5mVkbu+ZlZGSnqh7oJk8bJz8w6c8/PzEqp7p7flJeaJNsISsQUmDgcHPgk5r6+GenzxOgBn4cJGLKcdOSbkzH76i/mlu/a89NkHVJ/ojz9T7VeHylwnvx6ipwnNZG9p5516+RnZmXkeX5mVk7u+ZlZKTn5mVkpOfmZWSn5mp+ZlZHc8zOzUnLyO/SmVY7JLa8XmFvVp4Hc8qOnz03WsXt0e255kP7Hctz0X0nGvL7+xtzyHX3PJOv4eT1/fl2twJ/ZSDV9nj1j+Y9InTbt+GQdtdqu/IACbU39GwGo9M3MLZ87/U3JOh7f/c+55eo7IllHkTmJU4KTn5mVkpOfmZVS1Tc8zKyM3PMzs1Kq9+6Do5z8zKwzr+dnZqXknp+ZlVIPJ78iT+k2s7KqR3dbAZLOlbRJ0qCk5W3KJem6rPwhSWdm++dJ+p6kjZI2SPpI0zFXStoq6cFsOy/VjinR8xuYdlIyplrPnwSrCfgolQILWc6bkb+4565ITwp+bt/mZMxY/8Lc8kd335Gs4+Qjfyu3/K9O/c1kHdcOP5mMOWpa/nmeqj+crGOscnRueV9isVOAWuxLxuwZHcot31z9brKOpNh74HVMERN9yU9SBbgeOAcYAtZLWh0RjzSFLQYWZNvZwIrstQp8PCIekHQ08GNJdzUde21E/H3RtrjnZ2adTXzP7yxgMCI2R2O58luAJS0xS4Cbo2EdMEvS7IgYjogHACJiF7ARmLO/H83Jz8w6q3e3SVoq6f6mbWlLjXOAp5reD/HKBJaMkXQa8Cbgvqbdy7Jh8ipJx6U+mpOfmXXWZfKLiJUR8ZambWVLje0epNPaZcyNkXQU8E3goxHxQrZ7BXA6cAYwDHwm9dGc/Myss+hySxsC5jW9nwtsKxqjxpOuvgl8JSJue6mZEdsjohYRdeBGGsPrXE5+ZtZR1NXVVsB6YIGk+ZIGgAuA1S0xq4GLs7u+i4CdETEsScDngY0RcU3zAZJmN719L5C8yzYl7vaa2RQ1wXd7I6IqaRmwFqgAqyJig6TLs/IbgDXAecAgMAJclh3+VuAi4CeSHsz2fSIi1gBXSzqDRv9zC/DBVFuc/Myss2K9ua5kyWpNy74bmn4O4Io2x91D++uBRMRF3bZjSiS/am1nMiZiLLd82rT8eWIAM/rzF9V8duSR3HKA6oz8uWRv4OxkHU/PeFU6hl/klvcVWDBzLPbkll89/Fiyjm0j9yVjUovEnjTz15N17K0+n1uuSvoKzZ7R1ktHbepJ/JOfOXBqso7UnMPfnf6uZB237fpCbnm9PjXmCka1d6+MTYnkZ2ZTVEx8z2+qcPIzs44K3sQ4LDn5mVlndQ97zayM3PMzszIKX/Mzs1LysNfMysg3PMysnJz8JlcU+A6N+mbmlo/V0ouIFplMnTKtb3pu+b17bilQS3ooMaP/xNzyWTNem6xj99j23PKdtcFkHfOOeFsyZvu+DbnlL4ylJx/PmDYrt3ysnj9hG6BSYOJ36s++yL/FfYkJ2bePfilZx1SZxJzia35mVk6+5mdmZeRrfmZWSh72mlk5edhrZmXkYa+ZlZKHvWZWSlGtHOomTBonPzPryD2/SRZRPeCYIisbT5+W/yjPvYlJwQDP792c347EqsYARwzMTsa8q/8dyZiU4RjJLf9xrMktBxiJ55IxHznl/Nzy047MX/0aYPOLM3LLN+6sJev4Tu22ZEwtMVn62P65yTqeruav+F1JTMgHaDyvu7N6ovyg8Q0PMysj3/Aws1LysNfMSsk9PzMrpQhf8zOzMnLPz8zKyNf8zKyUfM1vChD5fwmpeVMAe0a3HnA7BhKLbp434/eSdewYTc97u3Xn2bnlz33wb5J1LPpq/oKnJ1QWJOv49GvSi6aeckT+/MhF/+7eZB33/cui3PIfPHtUso73HvW+ZMy39tyZW76nlp7XmJqDVyuwaG4QyZipwNf8zKyUernn17tp3cwOWIS62oqQdK6kTZIGJS1vUy5J12XlD0k6M9s/T9L3JG2UtEHSR5qOOV7SXZIey17zv86Fk5+Z5Zjo5CepAlwPLAYWAhdKWtgSthhYkG1LgRXZ/irw8Yj4NWARcEXTscuBuyNiAXB39j6Xk5+ZdRR1dbUVcBYwGBGbo3Gh/hZgSUvMEuDmaFgHzJI0OyKGI+IBgIjYBWwE5jQdc1P2803A+amGOPmZWUf1WqWrTdJSSfc3bUtbqpwDPNX0fohfJrDCMZJOA94E3JftOiUihgGy15NTn803PMyso27n+UXESmBlTki7CltvfefGSDoK+Cbw0Yh4oasGNnHyM7OOJmGS8xAwr+n9XKD1wc4dYyT100h8X4mI5jXMto8PjSXNBnakGuJhr5l1NAnX/NYDCyTNlzQAXACsbolZDVyc3fVdBOzMkpqAzwMbI+KaNsdckv18CXBHqiFToueXmsAMUKkcnVtere137/cl0yrHJGNO7X9jbvnzY2PJOr550Q+TMT97z8bc8ku/+9ZkHT+v/yC3/D0zz0vW8aNnkiFc+9nv5ZZvvPp1yTpufTx/gddnK08k69heHU7GjFafzy0fqT2ZrCO1YO3hMoG5iInu+UVEVdIyYC1QAVZFxAZJl2flNwBrgPOAQWAEuCw7/K3ARcBPJD2Y7ftERKwBrgJulfQB4EkgOeN9SiQ/M5uaJuMbHlmyWtOy74amnwO4os1x99D+eiAR8Qzwzm7a4eRnZh3VvbCBmZVRL3+9zcnPzDryklZmVkpOfmZWSk5+ZlZKda/nN7mKzIuaiHl8qfmE9freZB39kT/H62d96Xlib/7qvGTMsXFibvmjY/mLcgJUE4tq/rD/J8k6/tsJ85Mxt370d3LLt+w+MlnHl57/Zm75a6aflaxj6577kzH1qCZj0nVMkQeKHwS+4WFmpeRhr5mVkpOfmZWSJzmbWSnV677hYWYl5J6fmZWSr/mZWSk5+ZlZKXnY2yNSk6ml9MXdTXvW5pYfP/NXk3XsHns6GbO1uj63fHr/8ck6ZvbnT5Q+vbYgWcf/3JyeFLy5b1NueTX2Jeuo1fNjtuxdl6yjT/0HHDNaH0nWUSbu+ZlZKTn5mVkpedhrZqXknp+ZlZJ7fmZWSlHgyYqHKyc/M+vIw14zKyUPe82slNzzOwykVmmG9CTnKLRCb/5E6BdHtydrqNb3JGNq9RcT5enVkfuU/9f7L/tuTtZx2pHvSMaMJT7Prn3p1a1ricnFfZXjknVU67uTMSRWcu5T/krdUK6VnN3zM7NSqvkZHmZWRh72mlkpedhrZqXUyz2/3h3Qm9kBq3e5FSHpXEmbJA1KWt6mXJKuy8ofknRmU9kqSTskPdxyzJWStkp6MNvOS7XDyc/MOopQV1uKpApwPbAYWAhcKGlhS9hiYEG2LQVWNJV9ETi3Q/XXRsQZ2bYm1RYnPzPrqB7qaivgLGAwIjZHY27ZLcCSlpglwM3RsA6YJWk2QET8EHh2Ij5b71zzU4E/+Mif55csB4L8eWKj1WcKNGMsGaPEHL0i9lV3JhpSS9axde//S8aM1Z7LLa/0HZGsQ4n5dfUCC6IWmac5MO2E3PKxav5ngfSitxFFB4BTX7ff7ZW0lEZvbdzKiFjZ9H4O8FTT+yHg7JZq2sXMAYYTp18m6WLgfuDjEZH7l+men5l11G3PLyJWRsRbmraVLVW2y6atvY4iMa1WAKcDZ9BIkp9Jfbbe6fmZ2YSrpwdD3RoC5jW9nwts24+Yl4mIl75aJelG4M5UQ9zzM7OOAnW1FbAeWCBpvhrXOi4AVrfErAYuzu76LgJ2RkTukHf8mmDmvcDDnWLHuednZh1N9CTniKhKWgasBSrAqojYIOnyrPwGYA1wHjAIjACXjR8v6WvA24ETJQ0Bn4yIzwNXSzqDxvB4C/DBVFuc/MysowL3APejzlhDI8E177uh6ecAruhw7IUd9l/UbTuc/Myso7pXcjazMurlr7c5+ZlZR17YYApILVaqIh9FEzD5NLUYZt+MZBX1As3oS0wMrtbSC3fWU4t7FpgYnprADOlJzKmFShtNSS0imp6YUGRi+L6x/MVmi/z9Tcb8j6mq5uRnZmXUO99VeSUnPzPryNf8zKyUfM3PzEqpl69uOvmZWUfu+ZlZKfmGh5mVkm94TAFSf255kQVCD/QcAErMaSvyQOtCMbX8mGmVWck6Ci3wmm7JAddQZP5dciHSAnUUeeB4kFoUNT2fMHr6StjLuednZqXknp+ZlVIvf5nFyc/MOurh3OfkZ2adeaqLmZWSb3iYWSn5hoeZlVKthy/6OfmZWUdexn6SpRYqBQjyFxEtMvH0qBm/kmhHJVnHyGj+Q+Pr9T3JOopILapZq79w4OdQ/oRtKLAgKlCv780tP2L63GQdtXr+5ONq4hwAY9VnkjGpCdfJydak/7320iRoT3Uxs1KajKe3TRVOfmbWkYe9ZlZK7vmZWSl5np+ZlZJveJhZKfVw7nPyM7POevm7vemVG82stCK624qQdK6kTZIGJS1vUy5J12XlD0k6s6lslaQdkh5uOeZ4SXdJeix7PS7VjinR8ysyKXRm/6m55TOmJT8rz49syA9QepLzta+9NLf85q0jyToerd2TjDm2P39i8K7qz5N17KvuzC2vVp9N1tHXd2QyJjUxePe+J5N1KPV7uMBKzkVUUitxF5hMXU9MuO8lE33DQ1IFuB44BxgC1ktaHRGPNIUtBhZk29nAiuwV4IvA/wBubql6OXB3RFyVJdTlwF/ktcU9PzPraBJ6fmcBgxGxORq/NW8BlrTELAFujoZ1wCxJsxvtiR8C7X5rLwFuyn6+CTg/1RAnPzPrqN7lJmmppPubtqUtVc4Bnmp6P5Tt6zam1SkRMQyQvZ6c+mxTYthrZlNTt1NdImIlsDInpN0dlNazFIk5YE5+ZtbRJCxpNQTMa3o/F9i2HzGttkuaHRHD2RB5R6ohHvaaWUeTcM1vPbBA0nxJA8AFwOqWmNXAxdld30XAzvEhbY7VwCXZz5cAd6Qa4uRnZh11e80vJSKqwDJgLbARuDUiNki6XNLlWdgaYDMwCNwI/Kfx4yV9DbgXeJ2kIUkfyIquAs6R9BiNO8lXpdriYa+ZdTQZX2+LiDU0Elzzvhuafg7gig7HXthh/zPAO7tpx6Qnv9TikQCNXwb5XjXwhtzybfv+rUBb+hPlA8k6vrw1f7HSC151VLKOL2x/czLm/cefklv++WfzPwvASF9+W54rMM9v5kDypllyIdJZA69J1pGatzgymp7XOGvmryVjdu59LLe8yL/FMvHX28yslLywgZmVktfzM7NS8np+ZlZKHvaaWSn1cO5z8jOzztzzM7NS8g0PMysl3/A4AEUmjarAs0Hrib+Gv3v1e5J1/PnjX8s/R6QXsnxgz9dzyx/dlr8IKcDIvqFkzDW1/InBe8aeTtbRl5jUXakcm6xjbuWNyZjpfTNyywfH7kvWMTKav+BpJTFhG6BW5N9aaiJ71JJ1FFl8t1fUe7jr556fmXXUu6nPyc/McviGh5mVUs3DXjMrI9/wMLNSCvf8zKyM3PMzs1Jyz2+SFZk39fTYT3PLvzacv/gnwG/MeHdu+U9r65J11GMst3z3vi3JOgamnZSM2T36VG759GknJOvYV30mt/yY6fOTdTwxdn8yZu/Y9tzyI6e/OllHRH4fo554MDrAyGjqMQ+Q6suUaQ5fEe75mVkpeZKzmZVSL/eEnfzMrCMPe82slOru+ZlZGfman5mVkq/5mVkpedhrZqXk5DcFiL7c8h/vyV+oFGBa5bjc8lfP+K1kHQt1Wm75lv78icUAj+z5djJm/hFvzy0/tn58so5tA5tzy58eeThZR6VvZoGY/IVGR6u7knWkFrSNApOciyxGe6DtgN4eCrbq5c962CQ/Mzv4qkqvbH24yu9OmVmp1YmutiIknStpk6RBScvblEvSdVn5Q5LOTB0r6UpJWyU9mG3npdrhnp+ZdRQTPM1ZUgW4HjgHGALWS1odEY80hS0GFmTb2cAK4OwCx14bEX9ftC3u+ZlZR5PQ8zsLGIyIzdG4kHsLsKQlZglwczSsA2ZJml3w2MKc/Myso7rqXW2Slkq6v2lb2lLlHKB5yaKhbF+RmNSxy7Jh8ipJ+Xc3cfIzsxz1Lv+LiJUR8ZambWVLle1up7d2GTvF5B27AjgdOAMYBj6T+my+5mdmHaWel70fhoB5Te/nAtsKxgx0OjYiXlpUUtKNwJ2phhw2yW9k3xP5Aaok66jVd+eWP7EnvZjp4/V/zi3v6zsiWce0ypHJmM0vrs0tV+KB5EX0JR42DlBN/JlBeg5ePfWgcNLzyQYqJybrGK2mH+SepPQ8P3r4+66tJvqGB7AeWCBpPrAVuAD4o5aY1TSGsLfQuOGxMyKGJT3d6VhJsyNifDXb9wLJSayHTfIzs4OvrolNfhFRlbQMWAtUgFURsUHS5Vn5DcAa4DxgEBgBLss7Nqv6akln0BgGbwE+mGqL8tbol/qnzK+45Mz7Aj0/KT/X9yndE6rXR/LrKNDzq/RNT8aMju3ILT9YPb96VJMxqZ6fCvT8Un+uRZb+n4ien5S+DJ5acn8qiRgr0JXtbMHR53eVAx7b9b8P6HwHk3t+ZtZR0Lvf8HDyM7OOJuGGx5Th5GdmHTn5mVkpedhrZqXknp+ZldIkzPObMnom+UWhKRn5MXXSi2FOqxyTW16r7UzWUa2lZw+kpsz0FZk4nPwzSU/rSE1BKaLI303KWC29SOxEOJymsRwMNcYOdRMmTc8kPzObePXwNT8zKyEPe82slHy318xKqd7D10Cd/MysIw97zayUwjc8zKyMPMnZzEqpl+c9HjbJb6o8Ob5ae+GgnCc1ubjOgU8+Ppz08v+EU5nv9ppZKfXyLx0nPzPryHd7zayUfLfXzErJw14zKyUPe82slIo8ve9w5eRnZh35mp+ZlZSHvWZWQr7hYWal5BseZlZSTn5mVkYe9ppZGfXysDf97EIzK7F6l1uapHMlbZI0KGl5m3JJui4rf0jSmaljJR0v6S5Jj2Wvx6Xa4eRnZp1FdLclSKoA1wOLgYXAhZIWtoQtBhZk21JgRYFjlwN3R8QC4O7sfS4nPzPrKLr8r4CzgMGI2BwRo8AtwJKWmCXAzdGwDpglaXbi2CXATdnPNwHnpxqSe80vYkxFPo2Z9aZuc4CkpTR6a+NWRsTKpvdzgKea3g8BZ7dU0y5mTuLYUyJiuNHmGJZ0cqqtvuFhZhMmS3Qrc0LaJdPWLmOnmCLHFuZhr5kdTEPAvKb3c4FtBWPyjt2eDY3JXnekGuLkZ2YH03pggaT5kgaAC4DVLTGrgYuzu76LgJ3ZkDbv2NXAJdnPlwB3pBriYa+ZHTQRUZW0DFgLVIBVEbFB0uVZ+Q3AGuA8YBAYAS7LOzar+irgVkkfAJ4E3pdqi6LA7Wkzs17jYa+ZlZKTn5mVkpOfmZWSk5+ZlZKTn5mVkpOfmZWSk5+ZldL/B1FFHkLwDQQ1AAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(5, 4))\n", + "importances = clf.feature_importances_\n", + "sns.heatmap(importances.reshape(28, 28), cmap='inferno', square=True)\n", + "ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + ")\n", + "ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ObliqueForestClassifier(random_state=0)" + ] + }, + "metadata": {}, + "execution_count": 25 + } + ], + "source": [ + "clf = SPORF(n_estimators=100, random_state=0)\n", + "clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2021-04-21T03:28:42.502439\n image/svg+xml\n \n \n Matplotlib v3.4.1, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT8AAAEYCAYAAAAqD/ElAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAmX0lEQVR4nO3df5RdZX3v8fdnhhkmIWICGBoSKhSDil6NaAHRWpY/E1xd0XZ5hbaCVBtpwWprf1Cv94q9vbfWij/o5SY3llSo3nJZVWuWRgGxaNUiQcRIjMgIQQZiIgkEQjJMMud7/zh79HA4e3/3MBNCZn9eWXvNnP08+znPPufMk2fv/T3frYjAzKxp+g50B8zMDgQPfmbWSB78zKyRPPiZWSN58DOzRvLgZ2aN5MHPzBrpoB/8JD1H0j9JGpY0Kulnkm6W9FFJCzrqnSEpOpaWpB2SrpF0Rknbb5D0ZUn3S3pU0mZJqyX9So+6x3W1H5IekvQNSWfVrN+5zEn2+wZJI5N/xZ56JP2epHcf6H5YsxxyoDswFZJOA/4NeBD4R+BO4BnAC4DfBz4PbOna7HLgBqAfeBZwPvAVSa+NiK8W7QpYDbwd+C7wd8D9wHOBtwG/I+k3I+KaHt36PPAvgICFRRv/LGlWRPxjRf1uo/krMGP8HrAI+NgB7oc1yEE9+AH/FRgHfjUiHjMLKmZO/T22uTEiPtVR77PALcCfAV8tVr+b9qC1GviDiGh11P848DXgakn/KSJ+0tX+bV3t/yMwDPwp7QG622PqN0nxH8KeA90Pa6aD/bD3BGC4e+ADiIhdEbEzayAivgtsL9pC0hDwXuAu4J2dA19R/x7as8XDaQ+YWftbgR/SnmXuVxOHwpKOl/QFSQ9L2irpv6vtyOIUwY7ikPzyYn97tXFicUpgV3Eq4TJJh/V4ztMkXVu094ikb0pa1qNeSPqUpNcXpyVGgb+QtBl4GfDMzsP+ju3+RNLXJG0rTj0MS/qApIGu9t9abPs6Se+TdE9xGuSbkl7Yoz+HSforST8s6m0r9uPlXfVeKOlzkrYX9TZIOm+y74099RzsM7+7gFdK+rWI+Pcn0oCkI4F5wI+KVS8HjgJWRsRYr20i4lpJ9wDLgXcm7R9C+5Bue0mVWZKO6lq3OyJ219yFx7UHfAW4Fvhz4I3A+4BdwJuBTcB/AX6d9uHmVtqDfa82vla08VLgD4HjgTM79u1lwPW0Twn8HbAHOA/4oqSzIuLqrnZfXGy/Evg/wD3ArcDfAkcAf9xjf/4MuAZYC+ym/f68D/jl4rm6/TUQwEeK/fhT4F8lLY6IfUW/Z9E+XfKrwNXA/wKGgNOL1+UbRb3TgeuAHwMfAh4GfgNYI+moiPi7Hs9vB4uIOGgX4BXAXtof9ltonzP6HWB+j7pnFPXeTXtwm0/7w/61Yv07i3p/VDx+Y/Lca4t6c4rHxxWPLynafwawBPi/xfqPdG0/Ub/XcnGNfb8BGOmxLoB3dawbAO4DWsAlXfXXA9tL2vjbrvWXFOuXday7CXgE+OWOdU8H7qZ9rnWgY/3Evr20x758A9hcsp+H9Vj3ftqnOxZ2rHtr0f564JCO9b9ZrD+zY937inUX9GhbEz+BjcB/dLZXlH222O+nH+i/AS9PfDngHZjyDsBLgKuABzr+wPYCl3b98Z1B74FmF+3ZwsSHfuIP41XJ836qqLeweFw2mO0DPg4Mdm0/Uf9TwKu7ll+psd9lg984MNS1/l+L51rctf5jxfp5XW0EsKCr7jHF+v9dPD66ePwPPfr2l0XZ6R3rArilZF9KB7+OOv3AXNr/sbyiaO83OsonBr8VXdvNo+M/t2LdBtoDdF/F871gYrviOTuXtxdlrzvQn38vT3w52A97iYibgbOKK7TPBl5F+/DpnbQPx/6qa5OPAF+iPRPaCWyMiM4rqw8VPw9Pnnqi/KGu9Z8GPgkM0h6Y/5L2LHBvSTubI+IryXNNxrau/YH21XCA7oszE+uPoP2fx4SHI+IxV8kj4j5JD9M+9KXj56YeffhBR51vday/q7LnPUhaCvw32q/lQFfxvB6b3N35ICIeaH80OKJj9WLgq9F1PrfLc4qflxZLL/MrtrenuIN+8JsQ7f+ufwj8UNJVtM/TnMvjB79NyWAz8Ye7BPhcRb0lwD0R8XDX+s7BbJ2ku4AraR8+/X22H9Ng/AmUqetxWZLH7npldVVSNqkru0Uo0xeBb9P+z2wEeJR2CNEn6X3Bbqr7OGGi7YuBb5bU2Zi0YU9hM2bw6xQR2yX9GHjeE9j8m8AO4Hcl/Y/ocdFD0quBY2mfKM/68k+S/gh4v6RP9hgsn4oOl7Sgc/Yn6RhgDr+YvW0ufp7UY/vndNXJlA1EbwbGaJ+C+PnAKel1NdstcwfwfEl9FbO/4eLnnmmemdtTxEEd6iLp1ZIeF8un9jcwnkvvQ7JKxR/Z3wC/AnxM0mNeI0kLaV+pfIj2Fc46/gY4EviDyfbnAHp31+P3FD+/ABARPwUmTjksmqgk6WnAO4Cf0r74UMcu2ufzuo3THhh//h4X73caYpS4GngmsKK7oDh9Au0LaLcD7+pxNR5Jz5hiH+wAO9hnfh8D5kr6PHAb7YsLJ9I+3B2kffHiibiE9ozmD4BTi8Po7fziGx6HAr8Vjw9wLvM52ofkfyLp7+OpH9j7AHC22l8PvJF2qMvvAtdGxLqOen9MOyTmRkkraX8r5TzaA8tZEVF2nrPbemBpEUD+baAVEVfR/vbLnwBflXQF7dCVNzP1/7QvoR0CtFLSRGjLIO2r/98F/mdEtIp4vmuBH0i6nPY3iI4CXkQ7zOnQKfbDDqQDfcVlKgvwOtqzsNto/8HuBe6l/XWxl3bVPYP2LOLtk2j/N2l/+LfTPtd0N/AJ4IQedY8r2v/rkrbeWpT/UZ36Nfp2A72v9o70qPvJ4rm6QzYuLtY/q7sN2v+JXEN7VraddmzenB5tn0Y7Fu5h2nF436QjrKSjXgCfKtmXw2mHBO2gfSEqOsreTPvq7J7ivf0Y7dMZAby1x+v76pLnvrhr3dOAD9I+NzwGbAO+DLysq95zaF+R31LUu7fY3z880J9/L1NbJsI7zID2NzxoD4aLsrpmB7OD+pyfmdkT5cHPzBrJg5+ZNZLP+ZlZI1WGukgD6cg4cMiRleWD/U9LO3H4wDFpnSEqExtz165r0zb6+oYqy2cN/FLaxth4dYxynf3ds/enaZ1WqzqX6bzZz0/beGD3bZXl2esBT95rMn/gxMryB1v3pW3UMdRX/a3F0Vb3txUfb2x8V2V5nc9z9jzZc9S1a/SOXt/KqW2cT09qdtTP70zp+Z5MB3ucn5ntR61W1bclH6//IDqR5sHPzEoVKRBnJA9+ZlYqYnIzv4OJBz8zK9XyzM/MmsiHvWbWSB78zKyRouXBz8yaqKkzv8MOPS5tYLC/Ovg4C7SFekG/Tx+qvu1tnb5mwbiPPLo5bSMLDN7TeiRto7/vcbe/nXSdOkGw2WtSZ3/rvDdZIHSdNkamIVC6zv5kr0n2ea5jy65vpXWyz1Gdz8jefWV3Q50+Puw1s2Zq1c1He/Dx4GdmpTzzM7Nm8gUPM2skD35m1kQar84udDDz4Gdm5TzzM7NGmsGD30GUfcvMnmyKfZNaarUpLZV0u6RhSRf1KJekS4vyDZJOLtYPSbpJ0vckbZT0gY5tLpZ0r6Rbi+XMrB+VM786wamD/fs/+Bhg19iWpB95EOx4EoCcZaWu8zx1gmR3jg6ndfIg5/w1y9R5b+rI+lInYHfO4ILK8jrB8tMRlF/nvXnm7FdUls8+5MVpGz969Ia0TuZ5h/3WlNtITTKZaUZSP3AZ8Bra94deL2ltRPygo9oyYHGxnEr7ntGn0r539isjYpekAeAbkr4UETcW2300Ij5cty8+7DWzUpr+w95TgOGIuBNA0lXAcqBz8FsOXBntGwzdKGmupAURsQWY+GrTQLE84ZsQ+bDXzMq1xie1SFoh6eaOZUVXiwuBezoejxTratWR1C/pVmAbcF1EfLuj3oXFYfIaSfOyXfPgZ2blWvsmtUTE6oh4SceyuqvFXjc46p69ldaJiPGIWAIsAk6RNHEnr5XACcASYAtwSbZrPuw1s1Ka5nN+tGdxx3Y8XgR035ovrRMRD0q6AVgK3BYRWyfKJH0C+ELWEc/8zKzcJA97a1gPLJZ0vKRB4CxgbVedtcA5xVXf04CdEbFF0jMkzQWQNAt4NfDD4nHnFbM3AukVMs/8zKzUdM/8ImKfpAuBa4B+YE1EbJR0flG+ClgHnAkMA7uB84rNFwBXFFeM+4CrI2JihvchSUtoHx5vBt6R9cWDn5mVm/7DXiJiHe0BrnPdqo7fA7igx3YbgBeVtPmWyfajcvBrtfLv9WXxVwvmnJ62cf/47Wmdo4aeXVn+0N7u0wb7RxbTVic2cjrUSWSZJcw8PHlNAbbuviWtk31Ojp/z2rSNUaqTs86b/fzKcqgXowfViVezxKwAI6PfqSyvE3OayT7vALuZeqxnZj+c83vK8MzPzMp58DOzJvLMz8yayYOfmTWR9vkeHmbWRJ75mVkTKVoHugv7jQc/MyvnmZ+ZNVKroTO/OoGlmekIkoU8iPmRRzenbWQB11t2fSttI02YWSPA9fCBY9I6o62HKsvHpiGQ9o2HLUnrfIk8Oeu2vT+qLM8CmCF/f+sEj9cJUM7USUab1amTeDULQL9/NA/8n45g6lRTBz8zazbH+ZlZM3nmZ2aN5MHPzBrJg5+ZNZLP+ZlZE8kzPzNrpKYOfkN9h6cNZPFIR88+OW2jTixgFsdX5wbOm/dWJ6Gs08asqI7P2t73s7SN2ZHHZ91HdZzfdMQK3vTwzrSNbePVMXwAz+1/WWX5vdyVtpHF8dWJ4asTC/jcWa9P62R2aEtl+XTEx9ZJzFonGe2UNXXwM7OG8+BnZo20zxc8zKyJPPMzs0ZqxYHuwX7jm5abWbloTW6pQdJSSbdLGpZ0UY9ySbq0KN8g6eRi/ZCkmyR9T9JGSR/o2OYISddJuqP4OS/rhwc/MyvXisktieKG45cBy4CTgLMlndRVbRmwuFhWACuL9Y8Cr4yIFwJLgKWSTivKLgKuj4jFwPXF40oe/Mys3DQPfsApwHBE3BkRY8BVwPKuOsuBK6PtRmCupAXF44n8aAPFEh3bXFH8fgXwhqwjHvzMrNwkBz9JKyTd3LGs6GpxIXBPx+ORYl2tOpL6Jd0KbAOui4hvF3WOjogtAMXP+dmuVV7wqJPcc+CQIyvLF48/J21jaHaeQHJktDpA+QUDR6dtXP3rx1aWX/advI3M13bliVl36+G0zjFUB7BmAdsA8wdOrCxfqDzYena8Mq1zwqzqwG/2HJ+2ccSsBZXl90Wd5J7PSutkAcpZYDjA2L7q5KwvP+S1aRs3RfXf1tOH8n2pk/B0qiZ7C4+IWA2srqiiXpvVrRMR48ASSXOBz0l6fkTk2WN78MzPzMpN/2HvCNA5C1kEdKfxTutExIPADcDSYtVWSQsAip/bso548DOzcq1JLrn1wGJJx0saBM4C1nbVWQucU1z1PQ3YGRFbJD2jmPEhaRbwauCHHducW/x+LvD5rCOO8zOzctMc4xwR+yRdCFwD9ANrImKjpPOL8lXAOuBMYBjYDZxXbL4AuKK4YtwHXB0RXyjKPghcLeltwE+AN2V98eBnZuX2Q4xzRKyjPcB1rlvV8XsAF/TYbgPwopI2twOvmkw/PPiZWalo9br2MDN48DOzcjP3q70e/MysQlNnflkMH+Q3Tr6j/4eV5VAvmWmWFPXl8x9N2/jPX8viHrembWTqxKPNVZ6ItE4s4FQ9sG8srfN93ZTWuePRJOltjZiCoezm6DXOPc3ty1/X7AbrdZLEZolGfzSwOW1j16PV8YZPFbFv5gaEeOZnZuWioTM/M2s2X/Aws2Zq+bDXzJrIMz8za6LwOT8zayQf9ppZE/mCh5k1U1MHvyyAGWBsvDoYd7Q/Tw753FmvT+ssiuqA6z+/+9q0jUydANdpSSA5mFfZNVYdBHvUUHWyU4AHW91p0h7rjNl5wszto/nzzGolyUxruLfvrsry7PUAGBpKgq3JE7yOUp2oFOA1s86pLH/t0XlE9oe3Vj9PraSq43lfp8rn/MysmXzOz8yayOf8zKyRfNhrZs3kw14zayIf9ppZI/mw18waKfb1H+gu7Dce/Mys1Eye+al9o6SSQg2k0ZpZtuc6gdJ1ZMGpsyN/ns37vjPlfmSB0Atbx6dt/JjvpXWG+qoDdus8z9Ijq9u4/IG8H2ccsiSt87y545XlO8by2cM1O6uzaM+KPJB6dhya1vmPvdW3c1009OK0jSx4/BjlgeE7VB20vWXXt9I2+vry12R8/OEpjV4PvvNZk7p/29y/Hz5oRkvP/Mys1Ey+4DFzr2Ob2ZRFaFJLHZKWSrpd0rCki3qUS9KlRfkGSScX64+V9G+SNknaKOldHdtcLOleSbcWy5lZPzzzM7NS0z3zk9QPXAa8BhgB1ktaGxE/6Ki2DFhcLKcCK4uf+4D3RMQtkp4GfEfSdR3bfjQiPly3L575mVmpiL5JLTWcAgxHxJ0RMQZcBSzvqrMcuDLabgTmSloQEVsi4pZ2v+JhYBOw8Inumwc/MyvX0qQWSSsk3dyxrOhqcSFwT8fjER4/gKV1JB0HvAj4dsfqC4vD5DWS5mW75sHPzEpN9pxfRKyOiJd0LKu7mux1HN19RbmyjqQ5wGeAd0fERO6vlcAJwBJgC3BJtm8+52dmpfbD1d4R4NiOx4uA7tih0jqSBmgPfJ+OiM/+vJ8RP4+VkvQJ4AtZR/b74JclO63rwf7q2Kq7R4fTNmYN/FJleZ2+vuCQ51eWHzGYv6Sb9uRJKH/2yJ9Wlv/76Z+tLAf4ve//rLJ88fhz0jb+6tV5bOQznnlvZXn/rEfTNm7++G9Ulr/26IG0jUu2fTOtM2dwQWX5seOL0jZm9009dnXr7lsqy7P4WYDx1iNT7kem5nm8yVgPLJZ0PHAvcBbw21111tI+hL2K9oWOnRGxRZKAy4FNEfGRzg0mzgkWD98I3JZ1xDM/Mys13TO/iNgn6ULgGqAfWBMRGyWdX5SvAtYBZwLDwG7gvGLzlwFvAb4v6dZi3XsjYh3wIUlLaB8ebwbekfXFg5+ZldofX28rBqt1XetWdfwewAU9tvsGvc8HEhFvmWw/PPiZWamZ/N1eD35mVmomf73Ng5+ZlWqNO6WVmTWQD3vNrJE8+JlZIzX2nN90BFq2WqNpGwvmnJ7WuX/09ik/T5aIlDyOlmMOrf7/4uXzd6dtXPpr+f85t7/+8sry3/5+dQJRgCHmVJbf0z+StrH+x4vTOi8cG6wsv2G4OhEtwHU3rK8s/8zbX5C2ccH4r6V1vrz9ocry7+umtI2de6oD6rNg+jqejADmOjzzM7NG2g/f8HjK8OBnZqVanvmZWRM19pyfmTWbz/mZWSN58DOzRvLgZ2aN1JrBV3srb1o+Z2hxesPiLAFonXil/r7D0jpZO8+c/Yq0jbt3f72y/OWHdudUfLzsptcvHei+F8vj7Vae3HOPquMWdytPvJrdXPuCI/O4uM/t2JbWWUQeD5rZ09pXWf6d1lfTNo47JL/heHbj+ulKvpvJPs91YgUfeXRzWidi75SmbhuXvWpSNy1/3peuP2imip75mVkpH/aaWSN58DOzRnKQs5k1Uqs1cy94ePAzs1Ke+ZlZI/mcn5k10kwe/GbuAb2ZTVkrNKmlDklLJd0uaVjSRT3KJenSonyDpJOL9cdK+jdJmyRtlPSujm2OkHSdpDuKn/OyflTO/Pbs/Wmtnany9KFnpXUe2J3eXJ3DDj2usvyuXdembWTJWesksjxq6NmV5dtjZ9rGka2np3VuGf1MZXmd13VsfFdleZbYE2DzeHVQMMB9/dVJU09Rnqw2TazaSptg054vpnXqvG6Z7O+iTtB+VqdOsHWdZMNTNd0zP0n9wGXAa4ARYL2ktRHxg45qy4DFxXIqsLL4uQ94T0TcIulpwHckXVdsexFwfUR8sBhQLwL+oqovnvmZWakITWqp4RRgOCLujIgx4Cqg+2tRy4Ero+1GYK6kBRGxJSJuafcrHgY2AQs7trmi+P0K4A1ZRzz4mVmp/XDYuxC4p+PxCL8YwGrXkXQc8CLg28WqoyNiC0Dxc37WEQ9+ZlZqsjM/SSsk3dyxrOhqstcI2f394co6kuYAnwHeHRH5uZsSvtprZqUmG+cXEauB1RVVRoBjOx4vArozcJTWkTRAe+D7dER8tqPO1olDY0kLgDQjh2d+ZlYq0KSWGtYDiyUdL2kQOAtY21VnLXBOcdX3NGBnMagJuBzYFBEf6bHNucXv5wLV6ZfwzM/MKkz31d6I2CfpQuAaoB9YExEbJZ1flK8C1gFnAsPAbuC8YvOXAW8Bvi/p1mLdeyNiHfBB4GpJbwN+Arwp64sHPzMrtT++3lYMVuu61q3q+D2AC3ps9w16nw8kIrYDr5pMPzz4mVmpmfwNj8rBr06w5pzBBZXlu8a2pG3UCdY8fOCYyvI6QaFZ9tydo8NpG6847K2V5TexIW1jcysPHJ6OYNz5AydWlu+J6mzRddqAPGP0BuVB7Jns/QcYTIKt66gT2F8ny3JmOjJGZ39708GJDcyskcZn8D08PPiZWanGHvaaWbP5sNfMGskzPzNrpBrJdA5aHvzMrJRnfmbWSI0951cnjiiL48uSf9Z1/+jtleV1+prF8dWJ3/rK2JfTOpk6MV6PPLq5sryvbyh/oqTKSI0YzMH+p+XPk3Wj7/C0ztbdt1SWHz375LSNLHlru071a1/nebK+1onRzGIS68Sc1nldp6rm93UPSp75mVmpxs78zKzZWt2Z9mYQD35mVsqHvWbWSD7sNbNGCh/2mlkTtXzYa2ZN5CBnM2skn/OrkAUXZ8HJACceekZaZ3Sw+g51dZKmZslZpyPBZJ2g4CypKuQJXqcj+LhOG3WSiGbv8da91UHBkAeYj7byOxTWSWaavcd1Pq9ZIHSdvmZarTzR7JZd35ry82TGPfiZWRM5sYGZNZLP+ZlZI/mcn5k10gwO82Pm3p3EzKasFZrUUoekpZJulzQs6aIe5ZJ0aVG+QdLJHWVrJG2THntLQEkXS7pX0q3FcmbWDw9+ZlaqNcklI6kfuAxYBpwEnC3ppK5qy4DFxbICWNlR9klgaUnzH42IJcWyrqTOz3nwM7NSEZrUUsMpwHBE3BkRY8BVwPKuOsuBK6PtRmCupAXt/sTXgR3TsW+V5/yOUZ6IdIeq4+uGhvKEiz8avSGtMx2y+Lo6yUynIxawTgxXlhCzTlxjlhC1zs3iGcirZOrsb/a61ok3rGWwurjO65rF1x126HFpG9nN0efNfn7axpORzHSyoS6SVtCerU1YHRGrOx4vBO7peDwCnNrVTK86C4HszblQ0jnAzcB7IuKBqsqe+ZlZqcnO/CJidUS8pGNZ3dVkr+lh93WVOnW6rQROAJbQHiQvyfbNV3vNrNR+SGY6Ahzb8XgRcN8TqPMYEbF14ndJnwC+kHXEMz8zKxWTXGpYDyyWdLykQeAsYG1XnbXAOcVV39OAnRFRecg7cU6w8EbgtrK6EzzzM7NS0x3kHBH7JF0IXAP0A2siYqOk84vyVcA64ExgGNgNnDexvaR/Bs4AjpI0Arw/Ii4HPiRpCe0xeDPwjqwvHvzMrNT++G5vEYayrmvdqo7fA7igZNuzS9a/ZbL98OBnZqX83V4za6TxGfz9Ng9+ZlaqsWnsN+/7TtpAFnxaJznk3n3b0zpZ4GgW0AuwYM7paZ3MQ9VX3GsF42YBrpAHsJ4w+MK0jVtan6ksP2ooD2I/IqqT1QLcT/V7XCeYOkusunV3nhA1CwyvI0vOCzA2DYlks8S6dYKtd6Uxv1Pn+/aaWSP57m1m1kiNPew1s2bzzM/MGsn38DCzRvIFDzNrpBk89nnwM7NyvoGRmTWSL3hUeGhvddBvlj0ZoK9vKK3z3P6XVZZ/75A8w3IWKFsnk/Nxh7y4snzT7i+mbUxHMO72vp+ldY6efXJl+RBz0jY27cn3J3vdstcM4L6oDpQe7M9fsxPIA7+/N/bl5HnyAOb5AydWlo+M5l8OyP4usiBoqNfXqfIFDzNrJM/8zKyRPPMzs0ZyqIuZNZJTWplZI/mcn5k1ks/5mVkjNfacXxbPBHD37q9XlteJaauTuHFhEtO0fSiPJXuwVR2TODa+K21jVlTHJD531uvTNhaRJ/f8EZsry0fJ+5rF8WXvHcAzZ78irXPaIdXv8do96yrLIY/jvLfvrrSNH7e+l9bJ4useqZFY964kcW6WeBfg8BqJZDN1Es1O1Qwe+zzzM7NyM3nm55uWm1mpiMktdUhaKul2ScOSLupRLkmXFuUbJJ3cUbZG0jZJt3Vtc4Sk6yTdUfycl/XDg5+ZlWpNcslI6gcuA5YBJwFnSzqpq9oyYHGxrABWdpR9Eljao+mLgOsjYjFwffG4kgc/MyvVisktNZwCDEfEnRExBlwFLO+qsxy4MtpuBOZKWgAQEV8HdvRodzlwRfH7FcAbso548DOzUjHJpYaFwD0dj0eKdZOt0+3oiNgCUPycn3XEg5+ZlZrszE/SCkk3dywruprslSCwe9ysU2fKfLXXzEpN9hseEbEaWF1RZQQ4tuPxInjczbDr1Om2VdKCiNhSHCJvy/rqmZ+ZlZruCx7AemCxpOMlDQJnAWu76qwFzimu+p4G7Jw4pK2wFji3+P1c4PNZRypnfnWSMmZBzHUCmOskZbwpNlSWZ0lVAQ4fOKayvE5f7x2oDrat048dST8gT7x68tBvpW0cqeqEmLNn5a/72Ucdlda5ZFt1sHSdZKab9n2zsnywRuLVnaPDaZ1MncS6WfLWsfE8se6WR7815X6MTkNS3Exrmr/cGxH7JF0IXAP0A2siYqOk84vyVcA64ExgGNgNnDexvaR/Bs4AjpI0Arw/Ii4HPghcLeltwE+AN2V98WGvmZXaHzHOEbGO9gDXuW5Vx+8BXFCy7dkl67cDr5pMPzz4mVmpmfwNDw9+ZlZqfAbntPLgZ2alnNLKzBopPPMzsybyzM/MGqmxM7868XcP7L4trZOZM7j/kzICjLYeqiw/qkaCyftHq2+uXaeNha3j0zqjQ9V9/TF54k6i+ibe91G9LwAfGLkhrbMoSSS7eW8eL5olM900Xh0HCNNzo+8sFhTyz9EjSbLT6TIdf3sZz/zMrJGmO8j5qcSDn5mVihmcyN6Dn5mV8mGvmTVSyzM/M2sin/Mzs0byOT8zayQf9ppZIzV28KuTlDG7O/2evT9N2xjqOzytk5k/cGJaZ3ZUB7juIE9mmgVkZwGwANv7fpbW2bm7OjHnM2e/Im3je6Nfriw/8dAz0jZ2DOXvzd27q5OZ1gk+3kR1EHOd9/fYvkVpna/vWVNZXifIOUt6WycRaabOa5YlVZ0OPuw1s0bap/ED3YX9xoOfmZVq7GGvmTVbzOAwZw9+ZlbKMz8za6SWPPMzswZq+bDXzJqosYNfnWSmWVxUnTi/LEEowAsHl1aW10l2mfWlTmzVeOuRyvI68XdHtp6R1pk96/XVFWqciskSq9ZJMlrnhuNDyT5ncYAAY0n5g/35zeAfVF5nwZzTK8vr3HQ+e13rfJ6zv60nKyFqZn9c8JC0FPg47ZuW/0NEfLCrXEX5mbRvWv7WiLilaltJFwO/D0wE0b63uD9wKc/8zKzUdJ/zk9QPXAa8BhgB1ktaGxE/6Ki2DFhcLKcCK4FTa2z70Yj4cN2+9E15b8xsxmpN8l8NpwDDEXFnRIwBVwHLu+osB66MthuBuZIW1Ny2Ng9+ZlYqGJ/UImmFpJs7lhVdTS4E7ul4PFKsq1Mn2/ZCSRskrZE0L9s3D35mVmqyM7+IWB0RL+lYVnc1qR5P030Gu6xO1bYrgROAJcAW4JJs33zOz8xK7YervSPAsR2PFwHdV5nK6gyWbRsRWydWSvoE8IWsI575mVmpyR721rAeWCzpeEmDwFnA2q46a4Fz1HYasDMitlRtW5wTnPBGIL2vp2d+ZlZqumd+EbFP0oXANbTDVdZExEZJ5xflq4B1tMNchmmHupxXtW3R9IckLaF9GLwZeEfWFw9+ZlZqf8T5FfF367rWrer4PYAL6m5brH/LZPtROfjVCbTMEp4+fehZaRt1kplmiTnrBGRnAchZslOAHapOZFmnjTrJTI8dr07M+R97P5+28dKB6iiAhbOWpG38y67/l9bZu297Zfl0JPccG9815TbqtFMnKH+wf8409KP672be7OenbUxHEuDMOHv3+3McKJ75mVmpVjiZqZk1kPP5mVkj1byCe1Dy4GdmpVrhmZ+ZNZAPe82skcIXPMysiRqbzNTMmi1m8Dk/tYOpe5sztDjNF1wnKDRTJ4NyFkg7cMiRaRtzBhdUlu8cHU7byPpaJ9i6jiwIto7p2N/p0GqNpnWyQOhZA7+UtjEdr9l0yD6rUO/zmsneX4Adj3y3VyaU2p4269mTun3bw3tun9LzPZk88zOzUjN55ufBz8xK+WqvmTWSr/aaWSP5sNfMGsmHvWbWSK3Yd6C7sN948DOzUo0951cvaWN1/NX8gRPTFu7e/fW0zoI5p1eWj7YeStvIElnWiUfL4vzqxD3WSfCaxawdNfTstI2H9nbfF+ax6sRXjrceSetkMXh1XpOsjTqfxTp1stjGOvGEmTqvWVanznuza6w6se708GGvmTWQL3iYWSP5goeZNZQHPzNrIh/2mlkTzeTD3r4D3QEzeyprTXLJSVoq6XZJw5Iu6lEuSZcW5RsknZxtK+kISddJuqP4OS/rhwc/MysXMbklIakfuAxYBpwEnC3ppK5qy4DFxbICWFlj24uA6yNiMXB98biSBz8zKxWT/FfDKcBwRNwZEWPAVcDyrjrLgSuj7UZgrqQFybbLgSuK368A3pB1pPKc31QTIZrZwS1i76TGAEkraM/WJqyOiNUdjxcC93Q8HgFO7WqmV52FybZHR8SWdp9ji6T5WV99wcPMpk0x0K2uqNJrMO2eMpbVqbNtbT7sNbMn0whwbMfjRUD39zDL6lRtu7U4NKb4uS3riAc/M3syrQcWSzpe0iBwFrC2q85a4Jziqu9pwM7ikLZq27XAucXv5wKfzzriw14ze9JExD5JFwLXAP3AmojYKOn8onwVsA44ExgGdgPnVW1bNP1B4GpJbwN+Arwp60vl3dvMzGYqH/aaWSN58DOzRvLgZ2aN5MHPzBrJg5+ZNZIHPzNrJA9+ZtZI/x+f+i9SBt0MLgAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(5, 4))\n", + "importances = clf.feature_importances_\n", + "sns.heatmap(importances.reshape(28, 28), cmap='inferno', square=True)\n", + "ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + ")\n", + "ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# morf = MORF(random_state=0, image_height=28, image_width=28)\n", + "# morf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# XXX: MORF fits too slowly\n", + "clfs = [\n", + " RandomForestClassifier(random_state=0),\n", + " SPORF(random_state=0),\n", + " MORF(random_state=0, image_height=28, image_width=28)\n", + "]\n", + "for clf in clfs:\n", + " clf.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 3, figsize=(19, 4))\n", + "\n", + "for clf, ax in zip(clfs, ax):\n", + " importances = clf.feature_importances_\n", + " sns.heatmap(importances.reshape(28, 28), cmap='inferno', square=True, ax=ax)\n", + " ax.tick_params(\n", + " axis=\"both\",\n", + " which=\"both\",\n", + " bottom=False,\n", + " left=False,\n", + " labelbottom=False,\n", + " labelleft=False,\n", + " )\n", + " ax.set_title(f\"{rename_clf(clf)} Importance\")\n", + "fig.tight_layout();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "name": "python380jvsc74a57bd039ca1c7a169e56d6a333ccd59f8c6786beb2b8f5c3cc68b80d4610822621472b", + "display_name": "Python 3.8.0 64-bit ('ProgLearn': conda)" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/oblique_forests/morf.py b/oblique_forests/morf.py index a32a121..b2f8fdd 100644 --- a/oblique_forests/morf.py +++ b/oblique_forests/morf.py @@ -1,4 +1,8 @@ +import numpy as np +from joblib import Parallel, delayed from sklearn.ensemble._forest import ForestClassifier +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.fixes import _joblib_parallel_args from .tree.morf_tree import Conv2DObliqueTreeClassifier @@ -175,11 +179,55 @@ def __init__( # self.min_impurity_split = min_impurity_split # s-rerf params - # self.discontiguous_height = discontiguous_height - # self.discontiguous_width = discontiguous_width - # self.image_height = image_height - # self.image_width = image_width - # self.patch_height_max = patch_height_max - # self.patch_height_min = patch_height_min - # self.patch_width_max = patch_width_max - # self.patch_width_min = patch_width_min + self.discontiguous_height = discontiguous_height + self.discontiguous_width = discontiguous_width + self.image_height = image_height + self.image_width = image_width + self.patch_height_max = patch_height_max + self.patch_height_min = patch_height_min + self.patch_width_max = patch_width_max + self.patch_width_min = patch_width_min + + @property + def feature_importances_(self): + """ + Computes the importance of every unique feature used to make a split + in each tree of the forest. + + Parameters + ---------- + normalize : bool, default=True + A boolean to indicate whether to normalize feature importances. + + Returns + ------- + importances : array of shape [n_features] + Array of count-based feature importances. + """ + # TODO: Parallelize this and see if there is an equivalent way to express this better + # 1. Find all unique atoms in the forest + # 2. Compute number of times each atom appears across all trees + forest_projections = [ + node.proj_vec + for tree in self.estimators_ + if tree.tree_.node_count > 0 + for node in tree.tree_.nodes + if node.proj_vec is not None + ] + unique_projections, counts = np.unique( + forest_projections, axis=0, return_counts=True + ) + + if counts.sum() == 0: + return np.zeros(self.n_features_, dtype=np.float64) + + # 3. Count how many times each feature gets nonzero weight in unique projections + importances = np.zeros(self.n_features_) + for proj_vec, count in zip(unique_projections, counts): + importances[np.nonzero(proj_vec)] += count + + # 4. Normalize by number of unique projections + if len(unique_projections) > 0: + importances /= len(unique_projections) + + return importances diff --git a/oblique_forests/sporf.py b/oblique_forests/sporf.py index 557ae6d..515ad4e 100644 --- a/oblique_forests/sporf.py +++ b/oblique_forests/sporf.py @@ -1,6 +1,10 @@ +import numpy as np +from joblib import Parallel, delayed from sklearn.ensemble._forest import ForestClassifier +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.fixes import _joblib_parallel_args -from .tree.oblique_tree import ObliqueTreeClassifier +from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier class ObliqueForestClassifier(ForestClassifier): @@ -95,3 +99,47 @@ def __init__( # self.max_leaf_nodes = max_leaf_nodes # self.min_impurity_decrease = min_impurity_decrease # self.min_impurity_split = min_impurity_split + + @property + def feature_importances_(self): + """ + Computes the importance of every unique feature used to make a split + in each tree of the forest. + + Parameters + ---------- + normalize : bool, default=True + A boolean to indicate whether to normalize feature importances. + + Returns + ------- + importances : array of shape [n_features] + Array of count-based feature importances. + """ + # TODO: Parallelize this and see if there is an equivalent way to express this better + # 1. Find all unique atoms in the forest + # 2. Compute number of times each atom appears across all trees + forest_projections = [ + node.proj_vec + for tree in self.estimators_ + if tree.tree_.node_count > 0 + for node in tree.tree_.nodes + if node.proj_vec is not None + ] + unique_projections, counts = np.unique( + forest_projections, axis=0, return_counts=True + ) + + if counts.sum() == 0: + return np.zeros((self.n_features_), dtype=np.float64) + + # 3. Count how many times each feature gets nonzero weight in unique projections + importances = np.zeros((self.n_features_), dtype=np.float64) + for proj_vec, count in zip(unique_projections, counts): + importances[np.nonzero(proj_vec)] += count + + # 4. Normalize by number of unique projections + if len(unique_projections) > 0: + importances /= len(unique_projections) + + return importances \ No newline at end of file diff --git a/oblique_forests/tree/morf_tree.py b/oblique_forests/tree/morf_tree.py index 936cb9a..4dc31aa 100644 --- a/oblique_forests/tree/morf_tree.py +++ b/oblique_forests/tree/morf_tree.py @@ -265,7 +265,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): splitter = self._set_splitter(X, y) # create the Oblique tree - self.tree = ObliqueTree( + self.tree_ = ObliqueTree( splitter, self.min_samples_split, self.min_samples_leaf, @@ -273,7 +273,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): self.min_impurity_split, self.min_impurity_decrease, ) - self.tree.build() + self.tree_.build() return self diff --git a/oblique_forests/tree/oblique_tree.py b/oblique_forests/tree/oblique_tree.py index 16bf622..d439cab 100644 --- a/oblique_forests/tree/oblique_tree.py +++ b/oblique_forests/tree/oblique_tree.py @@ -3,6 +3,7 @@ from joblib import Parallel, delayed from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array, check_is_fitted, check_X_y +from sklearn.utils.fixes import _joblib_parallel_args from ._split import BaseObliqueSplitter from .oblique_base import BaseManifoldSplitter, Node, SplitInfo, StackRecord @@ -528,6 +529,37 @@ def predict(self, X, check_input=True): return predictions + def compute_feature_importances(self): + """ + Computes the importance of each feature (aka variable). + + Parameters + ---------- + unique_projections : ndarray of shape (n_proj, n_features) + Array of unique sampling projection vectors. + + Returns + ------- + importances : ndarray of shape (n_features,) + Normalized importance of each feature of the data matrix. + """ + projections = [ + node.proj_vec for node in self.nodes if node.proj_vec is not None + ] + unique_projections, counts = np.unique(projections, axis=0, return_counts=True) + + if counts.sum() == 0: + return np.zeros((self.splitter.n_features,)) + + importances = np.zeros((self.splitter.n_features,)) + for proj_vec, count in zip(unique_projections, counts): + importances[np.nonzero(proj_vec)] += count + + if len(unique_projections) > 0: + importances /= len(unique_projections) + + return importances + class ObliqueTreeClassifier(BaseEstimator): """ @@ -600,6 +632,7 @@ def __init__( # Max features self.max_features = max_features + self.n_jobs = n_jobs self.n_classes = None self.n_jobs = n_jobs @@ -640,7 +673,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): tree_func = self._tree_class() # instantiate the tree and build it - self.tree = tree_func( + self.tree_ = tree_func( splitter, self.min_samples_split, self.min_samples_leaf, @@ -648,7 +681,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): self.min_impurity_split, self.min_impurity_decrease, ) - self.tree.build() + self.tree_.build() return self @@ -666,7 +699,7 @@ def apply(self, X): pred_nodes : array of shape[n_samples] The indices for each test sample's final node in the oblique tree. """ - pred_nodes = self.tree.predict(X).astype(int) + pred_nodes = self.tree_.predict(X).astype(int) return pred_nodes def predict(self, X, check_input=True): @@ -689,7 +722,7 @@ def predict(self, X, check_input=True): pred_nodes = self.apply(X) for k in range(len(pred_nodes)): id = pred_nodes[k] - preds[k] = self.tree.nodes[id].label + preds[k] = self.tree_.nodes[id].label return preds @@ -713,7 +746,7 @@ def predict_proba(self, X, check_input=True): pred_nodes = self.apply(X) for k in range(len(preds)): id = pred_nodes[k] - preds[k] = self.tree.nodes[id].proba + preds[k] = self.tree_.nodes[id].proba return preds @@ -737,3 +770,53 @@ def predict_log_proba(self, X, check_input=True): # TODO: Actually do this function def _validate_X_predict(self, X, check_input=True): return X + + @property + def feature_importances_(self): + """ + Return the feature importances. + The importance of a feature is computed as the number of times it + is used in a projection across all split nodes + + Returns + ------- + feature_importances_ : ndarray of shape (n_features,) + Array of count-based feature importances. + """ + check_is_fitted(self) + + return self.tree_.compute_feature_importances() + + def compute_projection_counts(self, unique_projections=None): + """ + Counts the number of times each unique projection in the tree appears. + + Parameters + ---------- + unique_projections : ndarray of shape (n_proj,), optional + Array of unique projections to count, by default None + + Returns + ------- + projection_counts : ndarray of shape (n_proj,) + Counts of each unique projection used in this tree. + """ + check_is_fitted(self) + + if unique_projections is None: + projections = [ + node.proj_vec + for node in self.tree_.nodes + if node.proj_vec is not None + ] + unique_projections, counts = np.unique(projections, axis=0, return_counts=True) + return counts, unique_projections + + # TODO: see if joblib will speed up at all for this for loop + n_proj = len(unique_projections) + counts = np.zeros(n_proj) + for node in self.tree_.nodes: + projection_idx = np.where((unique_projections == node.proj_vec).all(axis=1)) + counts[projection_idx] += 1 + + return counts, unique_projections \ No newline at end of file diff --git a/oblique_forests/tree/tests/test_morf_tree.py b/oblique_forests/tree/tests/test_morf_tree.py index 9cbd12c..32c1bf0 100644 --- a/oblique_forests/tree/tests/test_morf_tree.py +++ b/oblique_forests/tree/tests/test_morf_tree.py @@ -10,6 +10,10 @@ from sklearn.utils.validation import check_random_state from oblique_forests.tree.morf_split import Conv2DSplitter +from oblique_forests.sporf import ObliqueForestClassifier as SPORF +from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC +from oblique_forests.morf import Conv2DObliqueForestClassifier as MORF +from oblique_forests.tree.morf_tree import Conv2DObliqueTreeClassifier # toy sample X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] @@ -43,7 +47,7 @@ def test_convolutional_splitter(): y[:25] = 0 splitter = Conv2DSplitter( - X, + X.reshape(n, -1), y, max_features=1, feature_combinations=1.5, @@ -52,4 +56,60 @@ def test_convolutional_splitter(): image_width=d, patch_height_max=2, patch_height_min=2, + patch_width_max=3, + patch_width_min=3, ) + + splitter.sample_proj_mat(splitter.indices) + + +if __name__ == "__main__": + + test_convolutional_splitter() + + # from sklearn.datasets import fetch_openml + from keras.datasets import mnist + import time + + (X_train, y_train), (X_test, y_test) = mnist.load_data() + + # Get 100 samples of 3s and 5s + num = 100 + threes = np.where(y_train == 3)[0][:num] + fives = np.where(y_train == 5)[0][:num] + train_idx = np.concatenate((threes, fives)) + + # Subset train data + Xtrain = X_train[train_idx] + ytrain = y_train[train_idx] + + # Apply random shuffling + permuted_idx = np.random.permutation(len(train_idx)) + Xtrain = Xtrain[permuted_idx] + ytrain = ytrain[permuted_idx] + + # Subset test data + test_idx = np.where(y_test == 3)[0] + Xtest = X_test[test_idx] + ytest = y_test[test_idx] + + print(f"-----{2 * num} samples") + + clf = OTC(random_state=0) + start = time.time() + clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain) + elapsed = time.time() - start + print(elapsed) + print(f"SPORF Tree: {elapsed} sec") + + clf = Conv2DObliqueTreeClassifier(image_height=28, image_width=28, random_state=0) + start = time.time() + clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain) + elapsed = time.time() - start + print(f"MORF Tree: {elapsed} sec") + + clf = SPORF(n_estimators=100, random_state=0) + start = time.time() + clf.fit(Xtrain.reshape(Xtrain.shape[0], -1), ytrain) + elapsed = time.time() - start + print(f"SPORF: {elapsed} sec") \ No newline at end of file diff --git a/oblique_forests/tree/tests/test_splitter.py b/oblique_forests/tree/tests/test_splitter.py index 4ef97e9..30b2767 100644 --- a/oblique_forests/tree/tests/test_splitter.py +++ b/oblique_forests/tree/tests/test_splitter.py @@ -34,16 +34,16 @@ def test_argmin(self): assert 4 == j def test_matmul(self): - + b = BOS() A = np.zeros((3, 3), dtype=np.float64) B = np.ones((3, 3), dtype=np.float64) - + for i in range(3): for j in range(3): - A[i, j] = 3*i + j + 1 - + A[i, j] = 3 * i + j + 1 + res = b.test_matmul(A, B) C = np.ones((3, 3), dtype=np.float64) @@ -53,7 +53,6 @@ def test_matmul(self): assert_allclose(C, res) - def test_impurity(self): """ diff --git a/oblique_forests/tree/tests/test_sporf.py b/oblique_forests/tree/tests/test_sporf.py new file mode 100644 index 0000000..921fecd --- /dev/null +++ b/oblique_forests/tree/tests/test_sporf.py @@ -0,0 +1,561 @@ +import numpy as np +from numpy.testing import ( + assert_almost_equal, + assert_allclose, + assert_array_equal, + assert_array_almost_equal, +) + +import pytest +from oblique_forests.tree.oblique_tree import ObliqueTreeClassifier as OTC +from oblique_forests.sporf import ObliqueForestClassifier as OFC + +from sklearn import datasets +from sklearn.metrics import accuracy_score + +""" +Sklearn test_tree.py stuff +""" +X_small = np.array( + [ + [ + 0, + 0, + 4, + 0, + 0, + 0, + 1, + -14, + 0, + -4, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 5, + 3, + 0, + -4, + 0, + 0, + 1, + -5, + 0.2, + 0, + 4, + 1, + ], + [ + -1, + -1, + 0, + 0, + -4.5, + 0, + 0, + 2.1, + 1, + 0, + 0, + -4.5, + 0, + 1, + ], + [ + -1, + -1, + 0, + -1.2, + 0, + 0, + 0, + 0, + 0, + 0, + 0.2, + 0, + 0, + 1, + ], + [ + -1, + -1, + 0, + 0, + 0, + 0, + 0, + 3, + 0, + 0, + 0, + 0, + 0, + 1, + ], + [ + -1, + -2, + 0, + 4, + -3, + 10, + 4, + 0, + -3.2, + 0, + 4, + 3, + -4, + 1, + ], + [ + 2.11, + 0, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -3, + 1, + ], + [ + 2.11, + 0, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0, + 0, + -2, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0, + 0, + -2, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 11, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -1, + 0, + ], + [ + 2, + 8, + 5, + 1, + 0.5, + -4, + 10, + 0, + 1, + -5, + 3, + 0, + 2, + 0, + ], + [ + 2, + 0, + 1, + 1, + 1, + -1, + 1, + 0, + 0, + -2, + 3, + 0, + 1, + 0, + ], + [ + 2, + 0, + 1, + 2, + 3, + -1, + 10, + 2, + 0, + -1, + 1, + 2, + 2, + 0, + ], + [ + 1, + 1, + 0, + 2, + 2, + -1, + 1, + 2, + 0, + -5, + 1, + 2, + 3, + 0, + ], + [ + 3, + 1, + 0, + 3, + 0, + -4, + 10, + 0, + 1, + -5, + 3, + 0, + 3, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 1, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -3, + 1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 1, + 0, + 0, + -3.2, + 6, + 1.5, + 1, + -1, + -1, + ], + [ + 2.11, + 8, + -6, + -0.5, + 0, + 10, + 0, + 0, + -3.2, + 6, + 0.5, + 0, + -1, + -1, + ], + [ + 2, + 0, + 5, + 1, + 0.5, + -2, + 10, + 0, + 1, + -5, + 3, + 1, + 0, + -1, + ], + [ + 2, + 0, + 1, + 1, + 1, + -2, + 1, + 0, + 0, + -2, + 0, + 0, + 0, + 1, + ], + [ + 2, + 1, + 1, + 1, + 2, + -1, + 10, + 2, + 0, + -1, + 0, + 2, + 1, + 1, + ], + [ + 1, + 1, + 0, + 0, + 1, + -3, + 1, + 2, + 0, + -5, + 1, + 2, + 1, + 1, + ], + [ + 3, + 1, + 0, + 1, + 0, + -4, + 1, + 0, + 1, + -2, + 0, + 0, + 1, + 0, + ], + ] +) + +y_small = [1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0] +y_small_reg = [ + 1.0, + 2.1, + 1.2, + 0.05, + 10, + 2.4, + 3.1, + 1.01, + 0.01, + 2.98, + 3.1, + 1.1, + 0.0, + 1.2, + 2, + 11, + 0, + 0, + 4.5, + 0.201, + 1.06, + 0.9, + 0, +] + +# toy sample +X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] +y = [-1, -1, -1, 1, 1, 1] +T = [[-1, -1], [2, 2], [3, 2]] +true_result = [-1, 1, 1] + +# also load the iris dataset +# and randomly permute it +iris = datasets.load_iris() +rng = np.random.RandomState(1) +perm = rng.permutation(iris.target.size) +iris.data = iris.data[perm] +iris.target = iris.target[perm] + +# also load the diabetes dataset +# and randomly permute it +diabetes = datasets.load_diabetes() +perm = rng.permutation(diabetes.target.size) +diabetes.data = diabetes.data[perm] +diabetes.target = diabetes.target[perm] + +# Ignoring digits dataset cause it takes a minute + + +def test_classification_toy(): + # Check classification on a toy dataset. + clf = OTC(random_state=0) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + + """ + # Ignoring because max_features implemented differently + clf = OTC(max_features=1, random_state=0) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + """ + + +def test_xor(): + + # Check on a XOR problem + y = np.zeros((10, 10)) + y[:5, :5] = 1 + y[5:, 5:] = 1 + + gridx, gridy = np.indices(y.shape) + + X = np.vstack([gridx.ravel(), gridy.ravel()]).T + y = y.ravel() + + # Changing feature parameters from default 1.5 to 2 makes this test pass. + clf = OTC(random_state=0, feature_combinations=2) + clf.fit(X, y) + + assert accuracy_score(clf.predict(X), y) == 1 + + +def test_iris(): + + clf = OTC(random_state=0) + + clf.fit(iris.data, iris.target) + score = accuracy_score(clf.predict(iris.data), iris.target) + assert score > 0.9 + + +def test_diabetes(): + + """ + Diabetes should overfit with MSE = 0 for normal trees. + idk if this applies to sporf, so this is just a placeholder + to check consistency like iris. + """ + + clf = OTC(random_state=0) + + clf.fit(diabetes.data, diabetes.target) + score = accuracy_score(clf.predict(diabetes.data), diabetes.target) + assert score > 0.9 + + +def test_probability(): + + clf = OTC(random_state=0) + + clf.fit(iris.data, iris.target) + p = clf.predict_proba(iris.data) + + assert_array_almost_equal(np.sum(p, 1), np.ones(iris.data.shape[0])) + + assert_array_equal(np.argmax(p, 1), clf.predict(iris.data)) + + assert_almost_equal( + clf.predict_proba(iris.data), np.exp(clf.predict_log_proba(iris.data)) + ) + + +def test_pure_set(): + + clf = OTC(random_state=0) + + X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] + y = [1, 1, 1, 1, 1, 1] + + clf.fit(X, y) + assert_array_equal(clf.predict(X), y) + + +def test_importances(): + # Check variable importances. + X, y = datasets.make_classification( + n_samples=5000, + n_features=10, + n_informative=3, + n_redundant=0, + n_repeated=0, + shuffle=False, + random_state=0, + ) + + clf = OTC(random_state=0) + + clf.fit(X, y) + importances = clf.feature_importances_ + n_important = np.sum(importances > 0.4) + + assert importances.shape[0] == 10, "Failed with SPORF" + assert n_important == 3, "Failed with SPORF" + + # Check on iris that importances are the same for all builders + clf = OTC(random_state=0) + clf.fit(iris.data, iris.target) + clf2 = OTC(random_state=0, max_depth=len(iris.data)) + clf2.fit(iris.data, iris.target) + + assert_array_equal(clf.feature_importances_, clf2.feature_importances_) + + +def test_importances_raises(): + # Check if variable importance before fit raises ValueError. + clf = OTC(random_state=0) + with pytest.raises(ValueError): + getattr(clf, "feature_importances_") \ No newline at end of file