{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Backpropagation\n",
"## Instructions\n",
"\n",
"In this assignment, you will train a neural network to draw a curve.\n",
"The curve takes one input variable, the amount travelled along the curve from 0 to 1, and returns 2 outputs, the 2D coordinates of the position of points on the curve.\n",
"\n",
"To help capture the complexity of the curve, we shall use two hidden layers in our network with 6 and 7 neurons respectively.\n",
"\n",
"![Neural network with 2 hidden layers. There is 1 nodes in the zeroth layer, 6 in the first, 7 in the second, and 2 in the third.](readonly/bigNet.png \"The structure of the network we will consider in this assignment.\")\n",
"\n",
"You will be asked to complete functions that calculate the Jacobian of the cost function, with respect to the weights and biases of the network. Your code will form part of a stochastic steepest descent algorithm that will train your network.\n",
"\n",
"### Matrices in Python\n",
"Recall from assignments in the previous course in this specialisation that matrices can be multiplied together in two ways.\n",
"\n",
"Element wise: when two matrices have the same dimensions, matrix elements in the same position in each matrix are multiplied together\n",
"In python this uses the '$*$' operator.\n",
"```python\n",
"A = B * C\n",
"```\n",
"\n",
"Matrix multiplication: when the number of columns in the first matrix is the same as the number of rows in the second.\n",
"In python this uses the '$@$' operator\n",
"```python\n",
"A = B @ C\n",
"```\n",
"\n",
"This assignment will not test which ones to use where, but it will use both in the starter code presented to you.\n",
"There is no need to change these or worry about their specifics.\n",
"\n",
"### How to submit\n",
"To complete the assignment, edit the code in the cells below where you are told to do so.\n",
"Once you are finished and happy with it, press the **Submit Assignment** button at the top of this worksheet.\n",
"Test your code using the cells at the bottom of the notebook before you submit.\n",
"\n",
"Please don't change any of the function names, as these will be checked by the grading script."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feed forward\n",
"\n",
"In the following cell, we will define functions to set up our neural network.\n",
"Namely an activation function, $\\sigma(z)$, it's derivative, $\\sigma'(z)$, a function to initialise weights and biases, and a function that calculates each activation of the network using feed-forward.\n",
"\n",
"Recall the feed-forward equations,\n",
"$$ \\mathbf{a}^{(n)} = \\sigma(\\mathbf{z}^{(n)}) $$\n",
"$$ \\mathbf{z}^{(n)} = \\mathbf{W}^{(n)}\\mathbf{a}^{(n-1)} + \\mathbf{b}^{(n)} $$\n",
"\n",
"In this worksheet we will use the *logistic function* as our activation function, rather than the more familiar $\\tanh$.\n",
"$$ \\sigma(\\mathbf{z}) = \\frac{1}{1 + \\exp(-\\mathbf{z})} $$\n",
"\n",
"There is no need to edit the following cells.\n",
"They do not form part of the assessment.\n",
"You may wish to study how it works though.\n",
"\n",
"**Run the following cells before continuing.**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%run \"readonly/BackpropModule.ipynb\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# PACKAGE\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# PACKAGE\n",
"# First load the worksheet dependencies.\n",
"# Here is the activation function and its derivative.\n",
"sigma = lambda z : 1 / (1 + np.exp(-z))\n",
"d_sigma = lambda z : np.cosh(z/2)**(-2) / 4\n",
"\n",
"# This function initialises the network with it's structure, it also resets any training already done.\n",
"def reset_network (n1 = 6, n2 = 7, random=np.random) :\n",
" global W1, W2, W3, b1, b2, b3\n",
" W1 = random.randn(n1, 1) / 2\n",
" W2 = random.randn(n2, n1) / 2\n",
" W3 = random.randn(2, n2) / 2\n",
" b1 = random.randn(n1, 1) / 2\n",
" b2 = random.randn(n2, 1) / 2\n",
" b3 = random.randn(2, 1) / 2\n",
"\n",
"# This function feeds forward each activation to the next layer. It returns all weighted sums and activations.\n",
"def network_function(a0) :\n",
" z1 = W1 @ a0 + b1\n",
" a1 = sigma(z1)\n",
" z2 = W2 @ a1 + b2\n",
" a2 = sigma(z2)\n",
" z3 = W3 @ a2 + b3\n",
" a3 = sigma(z3)\n",
" return a0, z1, a1, z2, a2, z3, a3\n",
"\n",
"# This is the cost function of a neural network with respect to a training set.\n",
"def cost(x, y) :\n",
" return np.linalg.norm(network_function(x)[-1] - y)**2 / x.size"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Backpropagation\n",
"\n",
"In the next cells, you will be asked to complete functions for the Jacobian of the cost function with respect to the weights and biases.\n",
"We will start with layer 3, which is the easiest, and work backwards through the layers.\n",
"\n",
"We'll define our Jacobians as,\n",
"$$ \\mathbf{J}_{\\mathbf{W}^{(3)}} = \\frac{\\partial C}{\\partial \\mathbf{W}^{(3)}} $$\n",
"$$ \\mathbf{J}_{\\mathbf{b}^{(3)}} = \\frac{\\partial C}{\\partial \\mathbf{b}^{(3)}} $$\n",
"etc., where $C$ is the average cost function over the training set. i.e.,\n",
"$$ C = \\frac{1}{N}\\sum_k C_k $$\n",
"You calculated the following in the practice quizzes,\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{W}^{(3)}} =\n",
" \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}}\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{z}^{(3)}}\n",
" \\frac{\\partial \\mathbf{z}^{(3)}}{\\partial \\mathbf{W}^{(3)}}\n",
" ,$$\n",
"for the weight, and similarly for the bias,\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{b}^{(3)}} =\n",
" \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}}\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{z}^{(3)}}\n",
" \\frac{\\partial \\mathbf{z}^{(3)}}{\\partial \\mathbf{b}^{(3)}}\n",
" .$$\n",
"With the partial derivatives taking the form,\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}} = 2(\\mathbf{a}^{(3)} - \\mathbf{y}) $$\n",
"$$ \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{z}^{(3)}} = \\sigma'({z}^{(3)})$$\n",
"$$ \\frac{\\partial \\mathbf{z}^{(3)}}{\\partial \\mathbf{W}^{(3)}} = \\mathbf{a}^{(2)}$$\n",
"$$ \\frac{\\partial \\mathbf{z}^{(3)}}{\\partial \\mathbf{b}^{(3)}} = 1$$\n",
"\n",
"We'll do the J_W3 ($\\mathbf{J}_{\\mathbf{W}^{(3)}}$) function for you, so you can see how it works.\n",
"You should then be able to adapt the J_b3 function, with help, yourself."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# GRADED FUNCTION\n",
"\n",
"# Jacobian for the third layer weights. There is no need to edit this function.\n",
"def J_W3 (x, y) :\n",
" # First get all the activations and weighted sums at each layer of the network.\n",
" a0, z1, a1, z2, a2, z3, a3 = network_function(x)\n",
" # We'll use the variable J to store parts of our result as we go along, updating it in each line.\n",
" # Firstly, we calculate dC/da3, using the expressions above.\n",
" J = 2 * (a3 - y)\n",
" # Next multiply the result we've calculated by the derivative of sigma, evaluated at z3.\n",
" J = J * d_sigma(z3)\n",
" # Then we take the dot product (along the axis that holds the training examples) with the final partial derivative,\n",
" # i.e. dz3/dW3 = a2\n",
" # and divide by the number of training examples, for the average over all training examples.\n",
" J = J @ a2.T / x.size\n",
" # Finally return the result out of the function.\n",
" return J\n",
"\n",
"# In this function, you will implement the jacobian for the bias.\n",
"# As you will see from the partial derivatives, only the last partial derivative is different.\n",
"# The first two partial derivatives are the same as previously.\n",
"# ===YOU SHOULD EDIT THIS FUNCTION===\n",
"def J_b3 (x, y) :\n",
" # As last time, we'll first set up the activations.\n",
" a0, z1, a1, z2, a2, z3, a3 = network_function(x)\n",
" # Next you should implement the first two partial derivatives of the Jacobian.\n",
" # ===COPY TWO LINES FROM THE PREVIOUS FUNCTION TO SET UP THE FIRST TWO JACOBIAN TERMS===\n",
" J = 2 * (a3 - y)\n",
" J = J * d_sigma(z3)\n",
" # For the final line, we don't need to multiply by dz3/db3, because that is multiplying by 1.\n",
" # We still need to sum over all training examples however.\n",
" # There is no need to edit this line.\n",
" J = np.sum(J, axis=1, keepdims=True) / x.size\n",
" return J"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll next do the Jacobian for the Layer 2. The partial derivatives for this are,\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{W}^{(2)}} =\n",
" \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}}\n",
" \\left(\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{a}^{(2)}}\n",
" \\right)\n",
" \\frac{\\partial \\mathbf{a}^{(2)}}{\\partial \\mathbf{z}^{(2)}}\n",
" \\frac{\\partial \\mathbf{z}^{(2)}}{\\partial \\mathbf{W}^{(2)}}\n",
" ,$$\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{b}^{(2)}} =\n",
" \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}}\n",
" \\left(\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{a}^{(2)}}\n",
" \\right)\n",
" \\frac{\\partial \\mathbf{a}^{(2)}}{\\partial \\mathbf{z}^{(2)}}\n",
" \\frac{\\partial \\mathbf{z}^{(2)}}{\\partial \\mathbf{b}^{(2)}}\n",
" .$$\n",
"This is very similar to the previous layer, with two exceptions:\n",
"* There is a new partial derivative, in parentheses, $\\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{a}^{(2)}}$\n",
"* The terms after the parentheses are now one layer lower.\n",
"\n",
"Recall the new partial derivative takes the following form,\n",
"$$ \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{a}^{(2)}} =\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{z}^{(3)}}\n",
" \\frac{\\partial \\mathbf{z}^{(3)}}{\\partial \\mathbf{a}^{(2)}} =\n",
" \\sigma'(\\mathbf{z}^{(3)})\n",
" \\mathbf{W}^{(3)}\n",
"$$\n",
"\n",
"To show how this changes things, we will implement the Jacobian for the weight again and ask you to implement it for the bias."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# GRADED FUNCTION\n",
"\n",
"# Compare this function to J_W3 to see how it changes.\n",
"# There is no need to edit this function.\n",
"def J_W2 (x, y) :\n",
" #The first two lines are identical to in J_W3.\n",
" a0, z1, a1, z2, a2, z3, a3 = network_function(x) \n",
" J = 2 * (a3 - y)\n",
" # the next two lines implement da3/da2, first σ' and then W3.\n",
" J = J * d_sigma(z3)\n",
" J = (J.T @ W3).T\n",
" # then the final lines are the same as in J_W3 but with the layer number bumped down.\n",
" J = J * d_sigma(z2)\n",
" J = J @ a1.T / x.size\n",
" return J\n",
"\n",
"# As previously, fill in all the incomplete lines.\n",
"# ===YOU SHOULD EDIT THIS FUNCTION===\n",
"def J_b2 (x, y) :\n",
" a0, z1, a1, z2, a2, z3, a3 = network_function(x)\n",
" J = 2 * (a3 - y)\n",
" J = J * d_sigma(z3)\n",
" J = (J.T @ W3).T\n",
" J = J * d_sigma(z2)\n",
" J = np.sum(J, axis=1, keepdims=True) / x.size\n",
" return J"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Layer 1 is very similar to Layer 2, but with an addition partial derivative term.\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{W}^{(1)}} =\n",
" \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}}\n",
" \\left(\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{a}^{(2)}}\n",
" \\frac{\\partial \\mathbf{a}^{(2)}}{\\partial \\mathbf{a}^{(1)}}\n",
" \\right)\n",
" \\frac{\\partial \\mathbf{a}^{(1)}}{\\partial \\mathbf{z}^{(1)}}\n",
" \\frac{\\partial \\mathbf{z}^{(1)}}{\\partial \\mathbf{W}^{(1)}}\n",
" ,$$\n",
"$$ \\frac{\\partial C}{\\partial \\mathbf{b}^{(1)}} =\n",
" \\frac{\\partial C}{\\partial \\mathbf{a}^{(3)}}\n",
" \\left(\n",
" \\frac{\\partial \\mathbf{a}^{(3)}}{\\partial \\mathbf{a}^{(2)}}\n",
" \\frac{\\partial \\mathbf{a}^{(2)}}{\\partial \\mathbf{a}^{(1)}}\n",
" \\right)\n",
" \\frac{\\partial \\mathbf{a}^{(1)}}{\\partial \\mathbf{z}^{(1)}}\n",
" \\frac{\\partial \\mathbf{z}^{(1)}}{\\partial \\mathbf{b}^{(1)}}\n",
" .$$\n",
"You should be able to adapt lines from the previous cells to complete **both** the weight and bias Jacobian."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# GRADED FUNCTION\n",
"\n",
"# Fill in all incomplete lines.\n",
"# ===YOU SHOULD EDIT THIS FUNCTION===\n",
"def J_W1 (x, y) :\n",
" a0, z1, a1, z2, a2, z3, a3 = network_function(x)\n",
" J = 2 * (a3 - y)\n",
" J = J * d_sigma(z3)\n",
" J = (J.T @ W3).T\n",
" J = J * d_sigma(z2)\n",
" J = (J.T @ W2).T\n",
" J = J * d_sigma(z1)\n",
" J = J @ a0.T / x.size\n",
" return J\n",
"\n",
"# Fill in all incomplete lines.\n",
"# ===YOU SHOULD EDIT THIS FUNCTION===\n",
"def J_b1 (x, y) :\n",
" a0, z1, a1, z2, a2, z3, a3 = network_function(x)\n",
" J = 2 * (a3 - y)\n",
" J = J * d_sigma(z3)\n",
" J = (J.T @ W3).T\n",
" J = J * d_sigma(z2)\n",
" J = (J.T @ W2).T\n",
" J = J * d_sigma(z1)\n",
" J = np.sum(J, axis=1, keepdims=True) / x.size\n",
" return J"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test your code before submission\n",
"To test the code you've written above, run all previous cells (select each cell, then press the play button [ ▶| ] or press shift-enter).\n",
"You can then use the code below to test out your function.\n",
"You don't need to submit these cells; you can edit and run them as much as you like."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we generate training data, and generate a network with randomly assigned weights and biases."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"x, y = training_data()\n",
"reset_network()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, if you've implemented the assignment correctly, the following code will iterate through a steepest descent algorithm using the Jacobians you have calculated.\n",
"The function will plot the training data (in green), and your neural network solutions in pink for each iteration, and orange for the last output.\n",
"\n",
"It takes about 50,000 iterations to train this network.\n",
"We can split this up though - **10,000 iterations should take about a minute to run**.\n",
"Run the line below as many times as you like."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" this.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn it's self\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" // select the cell after this one\n",
" var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
" IPython.notebook.select(index + 1);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"\" width=\"640\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_training(x, y, iterations=10000, aggression=7, noise=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you wish, you can change parameters of the steepest descent algorithm (We'll go into more details in future exercises), but you can change how many iterations are plotted, how agressive the step down the Jacobian is, and how much noise to add.\n",
"\n",
"You can also edit the parameters of the neural network, i.e. to give it different amounts of neurons in the hidden layers by calling,\n",
"```python\n",
"reset_network(n1, n2)\n",
"```\n",
"\n",
"Play around with the parameters, and save your favourite result for the discussion prompt - *I ❤️ backpropagation*."
]
}
],
"metadata": {
"coursera": {
"course_slug": "maths4ml-calculus",
"graded_item_id": "oBlKd",
"launcher_item_id": "37rcX"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}