-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,382 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%matplotlib inline\n", | ||
"import os\n", | ||
"os.environ['THEANO_FLAGS']='device=gpu0'\n", | ||
"\n", | ||
"import matplotlib\n", | ||
"import numpy as np\n", | ||
"np.random.seed(123)\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import lasagne\n", | ||
"import theano\n", | ||
"import theano.tensor as T\n", | ||
"# Prefer the cuDNN-backed layers when the dnn module can be imported;\n", | ||
"# otherwise fall back to the standard Lasagne conv/pool layers.\n", | ||
"try:\n", | ||
"    from lasagne.layers import dnn  # fails early if not available\n", | ||
"    conv = dnn.Conv2DDNNLayer\n", | ||
"    pool = dnn.MaxPool2DDNNLayer\n", | ||
"    print(\"Using DNN layers\")\n", | ||
"except ImportError:\n", | ||
"    # Only a failed import triggers the fallback; a bare except would also\n", | ||
"    # swallow KeyboardInterrupt/SystemExit and hide unrelated errors.\n", | ||
"    conv = lasagne.layers.Conv2DLayer\n", | ||
"    pool = lasagne.layers.MaxPool2DLayer\n", | ||
"    print(\"DNN not available, using standard (slower) conv and pool layers\")\n", | ||
"\n", | ||
"NUM_EPOCHS = 500\n", | ||
"BATCH_SIZE = 256\n", | ||
"LEARNING_RATE = 0.001\n", | ||
"DIM = 60\n", | ||
"NUM_CLASSES = 10\n", | ||
"mnist_cluttered = \"mnist_cluttered_60x60_6distortions.npz\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Spatial Transformer Network\n", | ||
"We use Lasagne to classify cluttered MNIST digits using the spatial transformer network introduced in [1]. The spatial transformer network applies a learned affine transformation to its input.\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"## Load data\n", | ||
"We test the spatial transformer network using cluttered MNIST data.\n", | ||
"\n", | ||
"**Download the data (41 MB) with:**" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!wget -N https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def load_data():\n", | ||
"    \"\"\"Load the cluttered MNIST archive and prepare it for the network.\n", | ||
"\n", | ||
"    Reads the ``.npz`` file named by ``mnist_cluttered``. For each of the\n", | ||
"    train/valid/test splits the label array is reduced to class indices with\n", | ||
"    ``argmax`` over the last axis and the image array is reshaped to\n", | ||
"    ``(num_samples, 1, DIM, DIM)``.\n", | ||
"\n", | ||
"    Returns a dict holding the floatX ``X_*`` arrays, the int32 ``y_*``\n", | ||
"    arrays, the number of examples per split, the input height and width,\n", | ||
"    and ``output_dim``.\n", | ||
"    \"\"\"\n", | ||
"    splits = {}\n", | ||
"    # The archive stays open until closed, so read every array inside the\n", | ||
"    # context manager and let it release the file handle afterwards.\n", | ||
"    with np.load(mnist_cluttered) as archive:\n", | ||
"        for split in ('train', 'valid', 'test'):\n", | ||
"            images = archive['x_' + split]\n", | ||
"            labels = np.argmax(archive['y_' + split], axis=-1)\n", | ||
"            # reshape for convolutions\n", | ||
"            images = images.reshape((images.shape[0], 1, DIM, DIM))\n", | ||
"            splits[split] = (images, labels)\n", | ||
"\n", | ||
"    X_train, y_train = splits['train']\n", | ||
"    X_valid, y_valid = splits['valid']\n", | ||
"    X_test, y_test = splits['test']\n", | ||
"\n", | ||
"    print(\"Train samples: %s\" % (X_train.shape,))\n", | ||
"    print(\"Validation samples: %s\" % (X_valid.shape,))\n", | ||
"    print(\"Test samples: %s\" % (X_test.shape,))\n", | ||
"\n", | ||
"    return dict(\n", | ||
"        X_train=lasagne.utils.floatX(X_train),\n", | ||
"        y_train=y_train.astype('int32'),\n", | ||
"        X_valid=lasagne.utils.floatX(X_valid),\n", | ||
"        y_valid=y_valid.astype('int32'),\n", | ||
"        X_test=lasagne.utils.floatX(X_test),\n", | ||
"        y_test=y_test.astype('int32'),\n", | ||
"        num_examples_train=X_train.shape[0],\n", | ||
"        num_examples_valid=X_valid.shape[0],\n", | ||
"        num_examples_test=X_test.shape[0],\n", | ||
"        input_height=X_train.shape[2],\n", | ||
"        input_width=X_train.shape[3],\n", | ||
"        output_dim=NUM_CLASSES,)\n", | ||
"data = load_data()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Show one training image to illustrate the cluttered input\n", | ||
"fig, ax = plt.subplots(figsize=(7, 7))\n", | ||
"sample = data['X_train'][101].reshape(DIM, DIM)\n", | ||
"ax.imshow(sample, cmap='gray', interpolation='none')\n", | ||
"ax.set_title('Cluttered MNIST', fontsize=20)\n", | ||
"ax.axis('off')\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Building the model\n", | ||
"We use a model where the localization network is a two-layer convolutional network which operates directly on the image input. The output from the localization network is a 6-dimensional vector specifying the parameters of the affine transformation. \n", | ||
"\n", | ||
"The localization network feeds into the transformer layer, which applies the transformation to the image input. In our setup the transformer layer downsamples the input by a factor of 3. \n", | ||
"\n", | ||
"Finally, a 2-layer convolutional network and 2 fully connected layers calculate the output probabilities. \n", | ||
"\n", | ||
"**The model**\n", | ||
"\n", | ||
"\n", | ||
" Input -> localization_network -> TransformerLayer -> output_network -> predictions\n", | ||
" | |\n", | ||
" >--------------------------------^" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def build_model(input_width, input_height, output_dim,\n", | ||
"                batch_size=BATCH_SIZE):\n", | ||
"    \"\"\"Build the spatial transformer classifier.\n", | ||
"\n", | ||
"    The localization network predicts six affine parameters from the input\n", | ||
"    image, the transformer layer resamples the input with them (with\n", | ||
"    ``downsample_factor=3.0``), and the classification network maps the\n", | ||
"    result to ``output_dim`` softmax probabilities.\n", | ||
"\n", | ||
"    ``batch_size`` is kept for backward compatibility but is not used; the\n", | ||
"    input layer leaves the batch dimension as ``None``.\n", | ||
"\n", | ||
"    Returns ``(l_out, l_trans1)``: the softmax output layer and the\n", | ||
"    transformer layer, so that the transformed images can be inspected.\n", | ||
"    \"\"\"\n", | ||
"    ini = lasagne.init.HeUniform()\n", | ||
"    # Lasagne's 2D layers expect (batch, channels, rows, columns), i.e. the\n", | ||
"    # height comes before the width.\n", | ||
"    l_in = lasagne.layers.InputLayer(shape=(None, 1, input_height, input_width),)\n", | ||
"\n", | ||
"    # Localization network\n", | ||
"    # The output bias is the flattened 2x3 matrix [[1, 0, 0], [0, 1, 0]] and\n", | ||
"    # the output weights are zero, so the network initially emits exactly\n", | ||
"    # this bias for every input.\n", | ||
"    b = np.zeros((2, 3), dtype='float32')\n", | ||
"    b[0, 0] = 1\n", | ||
"    b[1, 1] = 1\n", | ||
"    b = b.flatten()\n", | ||
"    loc_l1 = pool(l_in, pool_size=(2, 2))\n", | ||
"    loc_l2 = conv(\n", | ||
"        loc_l1, num_filters=20, filter_size=(5, 5), W=ini)\n", | ||
"    loc_l3 = pool(loc_l2, pool_size=(2, 2))\n", | ||
"    loc_l4 = conv(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)\n", | ||
"    loc_l5 = lasagne.layers.DenseLayer(\n", | ||
"        loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))\n", | ||
"    loc_out = lasagne.layers.DenseLayer(\n", | ||
"        loc_l5, num_units=6, b=b, W=lasagne.init.Constant(0.0),\n", | ||
"        nonlinearity=lasagne.nonlinearities.identity)\n", | ||
"\n", | ||
"    # Transformer network\n", | ||
"    l_trans1 = lasagne.layers.TransformerLayer(l_in, loc_out, downsample_factor=3.0)\n", | ||
"    print(\"Transformer network output shape:  %s\" % (l_trans1.output_shape,))\n", | ||
"\n", | ||
"    # Classification network\n", | ||
"    class_l1 = conv(\n", | ||
"        l_trans1,\n", | ||
"        num_filters=32,\n", | ||
"        filter_size=(3, 3),\n", | ||
"        nonlinearity=lasagne.nonlinearities.rectify,\n", | ||
"        W=ini,\n", | ||
"    )\n", | ||
"    class_l2 = pool(class_l1, pool_size=(2, 2))\n", | ||
"    class_l3 = conv(\n", | ||
"        class_l2,\n", | ||
"        num_filters=32,\n", | ||
"        filter_size=(3, 3),\n", | ||
"        nonlinearity=lasagne.nonlinearities.rectify,\n", | ||
"        W=ini,\n", | ||
"    )\n", | ||
"    class_l4 = pool(class_l3, pool_size=(2, 2))\n", | ||
"    class_l5 = lasagne.layers.DenseLayer(\n", | ||
"        class_l4,\n", | ||
"        num_units=256,\n", | ||
"        nonlinearity=lasagne.nonlinearities.rectify,\n", | ||
"        W=ini,\n", | ||
"    )\n", | ||
"\n", | ||
"    l_out = lasagne.layers.DenseLayer(\n", | ||
"        class_l5,\n", | ||
"        num_units=output_dim,\n", | ||
"        nonlinearity=lasagne.nonlinearities.softmax,\n", | ||
"        W=ini,\n", | ||
"    )\n", | ||
"\n", | ||
"    return l_out, l_trans1\n", | ||
"\n", | ||
"model, l_transform = build_model(DIM, DIM, NUM_CLASSES)\n", | ||
"model_params = lasagne.layers.get_all_params(model, trainable=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Symbolic inputs: image batch and integer label vector\n", | ||
"X = T.tensor4()\n", | ||
"y = T.ivector()\n", | ||
"\n", | ||
"# Learning rate lives in a shared variable so it can be changed between epochs\n", | ||
"sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))\n", | ||
"\n", | ||
"# Training graph: non-deterministic forward pass, mean cross-entropy, Adam updates\n", | ||
"output_train = lasagne.layers.get_output(model, X, deterministic=False)\n", | ||
"cost = T.nnet.categorical_crossentropy(output_train, y).mean()\n", | ||
"updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)\n", | ||
"train = theano.function([X, y], [cost, output_train], updates=updates)\n", | ||
"\n", | ||
"# Evaluation graph: deterministic pass that also returns the transformer\n", | ||
"# output so it can be plotted\n", | ||
"output_eval, transform_eval = lasagne.layers.get_output(\n", | ||
"    [model, l_transform], X, deterministic=True)\n", | ||
"eval = theano.function([X], [output_eval, transform_eval])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def train_epoch(X, y):\n", | ||
"    \"\"\"Run one pass over ``X``/``y`` in mini-batches of ``BATCH_SIZE``.\n", | ||
"\n", | ||
"    Calls the compiled ``train`` function on consecutive slices of the data\n", | ||
"    (the last batch may be smaller than ``BATCH_SIZE``).\n", | ||
"\n", | ||
"    Returns ``(mean_cost, accuracy)``: the mean of the per-batch costs and\n", | ||
"    the fraction of samples whose argmax prediction equals the label.\n", | ||
"    \"\"\"\n", | ||
"    num_samples = X.shape[0]\n", | ||
"    costs = []\n", | ||
"    correct = 0\n", | ||
"    # Contiguous slices avoid building an index list per batch and the copy\n", | ||
"    # made by fancy indexing; slicing clips at the end of the array.\n", | ||
"    for start in range(0, num_samples, BATCH_SIZE):\n", | ||
"        stop = start + BATCH_SIZE\n", | ||
"        X_batch = X[start:stop]\n", | ||
"        y_batch = y[start:stop]\n", | ||
"        cost_batch, output_train = train(X_batch, y_batch)\n", | ||
"        costs.append(cost_batch)\n", | ||
"        preds = np.argmax(output_train, axis=-1)\n", | ||
"        correct += np.sum(y_batch == preds)\n", | ||
"\n", | ||
"    return np.mean(costs), correct / float(num_samples)\n", | ||
"\n", | ||
"\n", | ||
"def eval_epoch(X, y):\n", | ||
"    \"\"\"Evaluate the model on all of ``X`` in a single call.\n", | ||
"\n", | ||
"    Returns ``(accuracy, transformed)``: the mean agreement between the\n", | ||
"    argmax predictions and ``y``, and the transformer output for ``X``.\n", | ||
"    \"\"\"\n", | ||
"    probabilities, transformed = eval(X)\n", | ||
"    accuracy = np.mean(np.argmax(probabilities, axis=-1) == y)\n", | ||
"    return accuracy, transformed" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"valid_accs, train_accs, test_accs = [], [], []\n", | ||
"try:\n", | ||
"    for n in range(NUM_EPOCHS):\n", | ||
"        train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])\n", | ||
"        valid_acc, valid_transform = eval_epoch(data['X_valid'], data['y_valid'])\n", | ||
"        # test_transform from the last completed epoch is reused for plotting\n", | ||
"        test_acc, test_transform = eval_epoch(data['X_test'], data['y_test'])\n", | ||
"        valid_accs.append(valid_acc)\n", | ||
"        test_accs.append(test_acc)\n", | ||
"        train_accs.append(train_acc)\n", | ||
"\n", | ||
"        # Multiply the learning rate by 0.7 every 20 epochs\n", | ||
"        if (n+1) % 20 == 0:\n", | ||
"            new_lr = sh_lr.get_value() * 0.7\n", | ||
"            print(\"New LR: %s\" % new_lr)\n", | ||
"            sh_lr.set_value(lasagne.utils.floatX(new_lr))\n", | ||
"\n", | ||
"        print(\"Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}\".format(\n", | ||
"            n, train_cost, train_acc, valid_acc, test_acc))\n", | ||
"except KeyboardInterrupt:\n", | ||
"    # Deliberate: interrupting the kernel stops training early while keeping\n", | ||
"    # the results collected so far.\n", | ||
"    pass" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Plot results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Error (1 - accuracy) per epoch for the training and validation sets\n", | ||
"plt.figure(figsize=(9, 9))\n", | ||
"for accs, label in ((train_accs, 'Training Error'),\n", | ||
"                    (valid_accs, 'Validation Error')):\n", | ||
"    plt.plot(1 - np.array(accs), label=label)\n", | ||
"plt.legend(fontsize=20)\n", | ||
"plt.xlabel('Epoch', fontsize=20)\n", | ||
"plt.ylabel('Error', fontsize=20)\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"# First three test images (left) next to the transformer output (right)\n", | ||
"plt.figure(figsize=(7, 14))\n", | ||
"side = DIM // 3\n", | ||
"for row in range(3):\n", | ||
"    panels = (\n", | ||
"        (data['X_test'][row].reshape(DIM, DIM), 'Original 60x60'),\n", | ||
"        (test_transform[row].reshape(side, side), 'Transformed 20x20'),\n", | ||
"    )\n", | ||
"    for col, (image, title) in enumerate(panels):\n", | ||
"        plt.subplot(3, 2, 2 * row + col + 1)\n", | ||
"        plt.imshow(image, cmap='gray', interpolation='none')\n", | ||
"        # only the top row carries the column titles\n", | ||
"        if row == 0:\n", | ||
"            plt.title(title, fontsize=20)\n", | ||
"        plt.axis('off')\n", | ||
"plt.tight_layout()\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"# References\n", | ||
"[1] Jaderberg, Max, et al. \"Spatial Transformer Networks.\" arXiv preprint arXiv:1506.02025 (2015).\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 2", | ||
"language": "python", | ||
"name": "python2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |