Skip to content

Commit

Permalink
swap %pylab to explict imports and %matplotlib inline, add model coeffs
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmontag authored and jeffakolb committed Jun 5, 2015
1 parent 21008d4 commit fa33560
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions ml-basis-expansion-101/Basis Expansion.ipynb
@@ -1,7 +1,7 @@
{
"metadata": {
"name": "",
"signature": "sha256:3549a4752602045c42ea70e7f5579fd904ef56de95d9ab983974fef1b10ea8f9"
"signature": "sha256:b8a3206d7df375d88d2ccbc5123abca03ea42d8513444b68a23f253f854f9dfb"
},
"nbformat": 3,
"nbformat_minor": 0,
Expand All @@ -12,7 +12,9 @@
"cell_type": "code",
"collapsed": false,
"input": [
"%pylab inline"
"import numpy as np\n",
"import matplotlib.pyplot as plt \n",
"%matplotlib inline"
],
"language": "python",
"metadata": {},
Expand Down Expand Up @@ -95,10 +97,9 @@
"input": [
"def f(n=100, sigma=1, mu=1):\n",
" # random normal vector\n",
" return sigma * random.randn(n) + mu\n",
" return sigma * np.random.randn(n) + mu\n",
" \n",
"plt.scatter(f(),f())\n",
"plt.show()"
"plt.scatter(f(),f())"
],
"language": "python",
"metadata": {},
Expand Down Expand Up @@ -235,8 +236,7 @@
"plt.plot(z,sigmoid(z,2,3), color=\"cyan\", label=\"compress, shift\")\n",
"plt.plot(z,sigmoid(z,-2), color=\"yellow\", label=\"shift\")\n",
"plt.plot(z,sigmoid(z,0,0.5), color=\"green\", label=\"expand\")\n",
"plt.legend(loc=2)\n",
"plt.show()"
"plt.legend(loc=2)"
],
"language": "python",
"metadata": {},
Expand Down Expand Up @@ -332,7 +332,7 @@
"\n",
"gx, gy = xx.ravel(), yy.ravel()\n",
"\n",
"Z = ext_m.decision_function(array([gx, gy, gx*gx, gy*gy, gx*gy]).transpose())\n",
"Z = ext_m.decision_function(np.array([gx, gy, gx*gx, gy*gy, gx*gy]).transpose())\n",
"Z = Z.reshape(xx.shape)\n",
"plt.imshow(Z, interpolation='nearest',\n",
" extent=(xx.min(), xx.max(), yy.min(), yy.max()), aspect='auto',\n",
Expand Down Expand Up @@ -374,8 +374,7 @@
"def b1(x, pars):\n",
" return np.exp(-(x-pars[0])*(x-pars[0])/(pars[1]*pars[1]))\n",
"\n",
"plt.plot(x, b1(x, (0,1)), color=\"green\")\n",
"plt.show()"
"plt.plot(x, b1(x, (0,1)), color=\"green\")"
],
"language": "python",
"metadata": {},
Expand All @@ -398,8 +397,7 @@
" return np.sum(c(x,pars))\n",
"\n",
"plt.plot(x,target(x) , color=\"red\")\n",
"#plt.plot(x, 10*b1(x, (1,1)), color=\"green\")\n",
"plt.show()"
"#plt.plot(x, 10*b1(x, (1,1)), color=\"green\")"
],
"language": "python",
"metadata": {},
Expand All @@ -420,8 +418,7 @@
"a = minimize(cost,[1,1,3])\n",
"print(a.x)\n",
"plt.plot(x,target(x), color=\"red\")\n",
"plt.plot(x, a.x[2]*b1(x, a.x), color=\"green\")\n",
"plt.show()"
"plt.plot(x, a.x[2]*b1(x, a.x), color=\"green\")"
],
"language": "python",
"metadata": {},
Expand Down Expand Up @@ -480,8 +477,7 @@
" 1,2,1])\n",
"print(a.x)\n",
"\n",
"plt.scatter(x, gx(x, a.x), color=\"green\")\n",
"plt.show()"
"plt.scatter(x, gx(x, a.x), color=\"green\")"
],
"language": "python",
"metadata": {},
Expand Down Expand Up @@ -628,6 +624,29 @@
"metadata": {},
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given this particular set of basis functions, we can also reach into the estimator above and find out what the relative weights were in fitting the model to this data:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# nb: must be same order as the X array that went into the model fit (ext_X)\n",
"feature_names = ['v11', 'v_11', 'v1_1', 'v_1_1', 'v22', 'v_22', 'v2_2', 'v_2_2']\n",
"\n",
"x = np.arange(len(feature_names))\n",
"plt.bar(x, ext_m.coef_.ravel() )\n",
"_ = plt.xticks(x + 0.5, feature_names, rotation=30)\n",
"_ = plt.ylabel(\"model coefficients\")"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit fa33560

Please sign in to comment.