Skip to content

Commit

Permalink
Spatial Transformer Recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
skaae committed Aug 19, 2015
1 parent 788a4cb commit 1da91a4
Showing 1 changed file with 382 additions and 0 deletions.
382 changes: 382 additions & 0 deletions examples/spatial_transformer_network.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,382 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"# THEANO_FLAGS must be set before theano is imported below.\n",
"os.environ['THEANO_FLAGS']='device=gpu0'\n",
"\n",
"import matplotlib\n",
"import numpy as np\n",
"np.random.seed(123)\n",
"import matplotlib.pyplot as plt\n",
"import lasagne\n",
"import theano\n",
"import theano.tensor as T\n",
"\n",
"# Prefer the layers from lasagne.layers.dnn; fall back to the standard ones otherwise.\n",
"try:\n",
"    from lasagne.layers import dnn  # fails early if not available\n",
"    conv = dnn.Conv2DDNNLayer\n",
"    pool = dnn.MaxPool2DDNNLayer\n",
"    print(\"Using DNN layers\")\n",
"except Exception:\n",
"    # Catch Exception instead of using a bare except so that\n",
"    # KeyboardInterrupt and SystemExit are not swallowed by the fallback.\n",
"    conv = lasagne.layers.Conv2DLayer\n",
"    pool = lasagne.layers.MaxPool2DLayer\n",
"    print(\"DNN not available, using standard (slower) conv and pool layers\")\n",
"\n",
"NUM_EPOCHS = 500       # number of passes over the training set\n",
"BATCH_SIZE = 256       # mini-batch size used by train_epoch\n",
"LEARNING_RATE = 0.001  # initial learning rate for the optimizer\n",
"DIM = 60               # side length of the square input images\n",
"NUM_CLASSES = 10       # number of output classes\n",
"mnist_cluttered = \"mnist_cluttered_60x60_6distortions.npz\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Spatial Transformer Network\n",
"We use Lasagne to classify cluttered MNIST digits using the spatial transformer network introduced in [1]. The spatial transformer network applies a learned affine transformation to its input.\n",
"\n",
"\n",
"\n",
"## Load data\n",
"We test the spatial transformer network using cluttered MNIST data.\n",
"\n",
"**Download the data (41 MB) with:**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!wget -N https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def load_data():\n",
"    \"\"\"Load the cluttered MNIST archive named by ``mnist_cluttered``.\n",
"\n",
"    For each of the train/valid/test splits the label array is reduced to\n",
"    class indices with ``argmax`` over its last axis, and the images are\n",
"    reshaped to ``(num_samples, 1, DIM, DIM)`` for the convolution layers.\n",
"\n",
"    Returns:\n",
"        dict with keys ``X_<split>`` (floatX images), ``y_<split>`` (int32\n",
"        labels), ``num_examples_<split>``, ``input_height``, ``input_width``\n",
"        and ``output_dim``.\n",
"    \"\"\"\n",
"    splits = {}\n",
"    # The context manager closes the archive's file handle once the arrays are read.\n",
"    with np.load(mnist_cluttered) as archive:\n",
"        for split in ('train', 'valid', 'test'):\n",
"            images = archive['x_' + split]\n",
"            labels = np.argmax(archive['y_' + split], axis=-1)\n",
"            # reshape for convolutions\n",
"            images = images.reshape((images.shape[0], 1, DIM, DIM))\n",
"            splits[split] = (images, labels)\n",
"\n",
"    print(\"Train samples: {0}\".format(splits['train'][0].shape))\n",
"    print(\"Validation samples: {0}\".format(splits['valid'][0].shape))\n",
"    print(\"Test samples: {0}\".format(splits['test'][0].shape))\n",
"\n",
"    result = dict(\n",
"        input_height=splits['train'][0].shape[2],\n",
"        input_width=splits['train'][0].shape[3],\n",
"        output_dim=NUM_CLASSES,\n",
"    )\n",
"    for split, (images, labels) in splits.items():\n",
"        result['X_' + split] = lasagne.utils.floatX(images)\n",
"        result['y_' + split] = labels.astype('int32')\n",
"        result['num_examples_' + split] = images.shape[0]\n",
"    return result\n",
"\n",
"data = load_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Show one training image as a sanity check of the loaded data.\n",
"fig, ax = plt.subplots(figsize=(7, 7))\n",
"sample = data['X_train'][101].reshape(DIM, DIM)\n",
"ax.imshow(sample, cmap='gray', interpolation='none')\n",
"ax.set_title('Cluttered MNIST', fontsize=20)\n",
"ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Building the model\n",
"We use a model where the localization network is a two-layer convolutional network which operates directly on the image input. The output from the localization network is a 6-dimensional vector specifying the parameters of the affine transformation. \n",
"\n",
"The localization network feeds into the transformer layer, which applies the transformation to the image input. In our setup the transformer layer downsamples the input by a factor of 3. \n",
"\n",
"Finally, two convolutional layers and two fully connected layers calculate the output probabilities. \n",
"\n",
"**The model**\n",
"\n",
"\n",
"    Input -> localization_network -> TransformerLayer -> output_network -> predictions\n",
"       |                                |\n",
"       >--------------------------------^"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def build_model(input_width, input_height, output_dim,\n",
"                batch_size=BATCH_SIZE):\n",
"    \"\"\"Build the spatial transformer classifier.\n",
"\n",
"    Args:\n",
"        input_width: width of the input images.\n",
"        input_height: height of the input images.\n",
"        output_dim: number of units in the softmax output layer.\n",
"        batch_size: not used by this function; kept so existing calls\n",
"            that pass it keep working.\n",
"\n",
"    Returns:\n",
"        (l_out, l_trans1): the softmax output layer and the transformer\n",
"        layer, the latter so its output can be inspected.\n",
"    \"\"\"\n",
"    ini = lasagne.init.HeUniform()\n",
"    # load_data() reports axis 2 as height and axis 3 as width, so the\n",
"    # input shape is (batch, channels, height, width).\n",
"    l_in = lasagne.layers.InputLayer(shape=(None, 1, input_height, input_width),)\n",
"\n",
"    # Localization network\n",
"    # Bias of the last localization layer: the flattened 2x3 matrix\n",
"    # [[1, 0, 0], [0, 1, 0]]. With W initialised to zero below, the layer\n",
"    # initially outputs exactly these six values.\n",
"    b = np.zeros((2, 3), dtype='float32')\n",
"    b[0, 0] = 1\n",
"    b[1, 1] = 1\n",
"    b = b.flatten()\n",
"    loc_l1 = pool(l_in, pool_size=(2, 2))\n",
"    loc_l2 = conv(\n",
"        loc_l1, num_filters=20, filter_size=(5, 5), W=ini)\n",
"    loc_l3 = pool(loc_l2, pool_size=(2, 2))\n",
"    loc_l4 = conv(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)\n",
"    loc_l5 = lasagne.layers.DenseLayer(\n",
"        loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))\n",
"    loc_out = lasagne.layers.DenseLayer(\n",
"        loc_l5, num_units=6, b=b, W=lasagne.init.Constant(0.0),\n",
"        nonlinearity=lasagne.nonlinearities.identity)\n",
"\n",
"    # Transformer network\n",
"    l_trans1 = lasagne.layers.TransformerLayer(l_in, loc_out, downsample_factor=3.0)\n",
"    print(\"Transformer network output shape:  {0}\".format(l_trans1.output_shape))\n",
"\n",
"    # Classification network\n",
"    class_l1 = conv(\n",
"        l_trans1,\n",
"        num_filters=32,\n",
"        filter_size=(3, 3),\n",
"        nonlinearity=lasagne.nonlinearities.rectify,\n",
"        W=ini,\n",
"    )\n",
"    class_l2 = pool(class_l1, pool_size=(2, 2))\n",
"    class_l3 = conv(\n",
"        class_l2,\n",
"        num_filters=32,\n",
"        filter_size=(3, 3),\n",
"        nonlinearity=lasagne.nonlinearities.rectify,\n",
"        W=ini,\n",
"    )\n",
"    class_l4 = pool(class_l3, pool_size=(2, 2))\n",
"    class_l5 = lasagne.layers.DenseLayer(\n",
"        class_l4,\n",
"        num_units=256,\n",
"        nonlinearity=lasagne.nonlinearities.rectify,\n",
"        W=ini,\n",
"    )\n",
"\n",
"    l_out = lasagne.layers.DenseLayer(\n",
"        class_l5,\n",
"        num_units=output_dim,\n",
"        nonlinearity=lasagne.nonlinearities.softmax,\n",
"        W=ini,\n",
"    )\n",
"\n",
"    return l_out, l_trans1\n",
"\n",
"model, l_transform = build_model(DIM, DIM, NUM_CLASSES)\n",
"model_params = lasagne.layers.get_all_params(model, trainable=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X = T.tensor4()\n",
"y = T.ivector()\n",
"\n",
"# training output\n",
"output_train = lasagne.layers.get_output(model, X, deterministic=False)\n",
"\n",
"# evaluation output. Also includes output of transform for plotting\n",
"output_eval, transform_eval = lasagne.layers.get_output([model, l_transform], X, deterministic=True)\n",
"\n",
"sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))\n",
"cost = T.mean(T.nnet.categorical_crossentropy(output_train, y))\n",
"updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)\n",
"\n",
"train = theano.function([X, y], [cost, output_train], updates=updates)\n",
"eval = theano.function([X], [output_eval, transform_eval])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def train_epoch(X, y):\n",
"    \"\"\"Train on (X, y) for one epoch in mini-batches of BATCH_SIZE.\n",
"\n",
"    Args:\n",
"        X: array of input images, indexed by sample along the first axis.\n",
"        y: array of integer class labels, one per sample.\n",
"\n",
"    Returns:\n",
"        (mean of the per-batch costs, fraction of correctly classified samples).\n",
"    \"\"\"\n",
"    num_samples = X.shape[0]\n",
"    costs = []\n",
"    correct = 0\n",
"    # Contiguous slices are used instead of index lists: slicing avoids the\n",
"    # copy made by fancy indexing and clips the last batch automatically.\n",
"    for start in range(0, num_samples, BATCH_SIZE):\n",
"        stop = start + BATCH_SIZE\n",
"        X_batch = X[start:stop]\n",
"        y_batch = y[start:stop]\n",
"        cost_batch, output_train = train(X_batch, y_batch)\n",
"        costs.append(cost_batch)\n",
"        preds = np.argmax(output_train, axis=-1)\n",
"        correct += np.sum(y_batch == preds)\n",
"\n",
"    return np.mean(costs), correct / float(num_samples)\n",
"\n",
"\n",
"def eval_epoch(X, y):\n",
"    \"\"\"Evaluate the model on all of (X, y) in a single call.\n",
"\n",
"    Returns:\n",
"        (accuracy, transformer layer output for X).\n",
"    \"\"\"\n",
"    output_eval, transform_eval = eval(X)\n",
"    preds = np.argmax(output_eval, axis=-1)\n",
"    acc = np.mean(preds == y)\n",
"    return acc, transform_eval"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"valid_accs, train_accs, test_accs = [], [], []\n",
"try:\n",
"    for n in range(NUM_EPOCHS):\n",
"        train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])\n",
"        valid_acc, valid_transform = eval_epoch(data['X_valid'], data['y_valid'])\n",
"        test_acc, test_transform = eval_epoch(data['X_test'], data['y_test'])\n",
"        valid_accs.append(valid_acc)\n",
"        test_accs.append(test_acc)\n",
"        train_accs.append(train_acc)\n",
"\n",
"        # Multiply the learning rate by 0.7 every 20 epochs.\n",
"        if (n+1) % 20 == 0:\n",
"            new_lr = sh_lr.get_value() * 0.7\n",
"            print(\"New LR: {0}\".format(new_lr))\n",
"            sh_lr.set_value(lasagne.utils.floatX(new_lr))\n",
"\n",
"        print(\"Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}\".format(\n",
"            n, train_cost, train_acc, valid_acc, test_acc))\n",
"except KeyboardInterrupt:\n",
"    # Deliberate: interrupting the kernel stops training early while keeping\n",
"    # the accuracies and transforms collected so far for the plots below.\n",
"    pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plot results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Error (1 - accuracy) per epoch for the training and validation sets.\n",
"fig, ax = plt.subplots(figsize=(9, 9))\n",
"for accs, label in ((train_accs, 'Training Error'), (valid_accs, 'Validation Error')):\n",
"    ax.plot(1 - np.array(accs), label=label)\n",
"ax.legend(fontsize=20)\n",
"ax.set_xlabel('Epoch', fontsize=20)\n",
"ax.set_ylabel('Error', fontsize=20)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# First three test images (left) next to the transformer layer output (right).\n",
"fig, axes = plt.subplots(3, 2, figsize=(7, 14))\n",
"side = DIM // 3\n",
"for i in range(3):\n",
"    panels = (\n",
"        (data['X_test'][i].reshape(DIM, DIM), 'Original 60x60'),\n",
"        (test_transform[i].reshape(side, side), 'Transformed 20x20'),\n",
"    )\n",
"    for ax, (image, title) in zip(axes[i], panels):\n",
"        ax.imshow(image, cmap='gray', interpolation='none')\n",
"        if i == 0:\n",
"            ax.set_title(title, fontsize=20)\n",
"        ax.axis('off')\n",
"plt.tight_layout()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# References\n",
"[1] Jaderberg, Max, et al. \"Spatial Transformer Networks.\" arXiv preprint arXiv:1506.02025 (2015).\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 1da91a4

Please sign in to comment.