# Confirm TensorFlow can see the GPU

Simply select "GPU" in the Accelerator drop-down in Notebook Settings (either through the Edit menu or the command palette at cmd/ctrl-shift-P).

In [0]:
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


# Observe TensorFlow speedup on GPU relative to CPU

This example constructs a typical convolutional neural network layer over a
random image and manually places the resulting ops on either the CPU or the GPU
to compare execution speed.

In [0]:
import tensorflow as tf
import timeit

# See https://www.tensorflow.org/tutorials/using_gpu#allowing_gpu_memory_growth
config = tf.ConfigProto(log_device_placement=True)
config.gpu_options.allow_growth = True

with tf.device('/cpu:0'):
  random_image_cpu = tf.random_normal((100, 100, 100, 3))
  net_cpu = tf.layers.conv2d(random_image_cpu, 32, 7)
  net_cpu = tf.reduce_sum(net_cpu)

with tf.device('/gpu:0'):
  random_image_gpu = tf.random_normal((100, 100, 100, 3))
  net_gpu = tf.layers.conv2d(random_image_gpu, 32, 7)
  net_gpu = tf.reduce_sum(net_gpu)

sess = tf.Session(config=config)

# Test execution once to detect errors early.
try:
  sess.run(tf.global_variables_initializer())
except tf.errors.InvalidArgumentError:
  print(
      '\n\nThis error most likely means that this notebook is not '
      'configured to use a GPU.  Change this in Notebook Settings via the '
      'command palette (cmd/ctrl-shift-P) or the Edit menu.\n\n')
  raise

def cpu():
  sess.run(net_cpu)
  
def gpu():
  sess.run(net_gpu)
  
# Runs the op several times.
print('Time (s) to convolve 32x7x7x3 filter over random 100x100x100x3 images '
      '(batch x height x width x channel). Sum of ten runs.')
print('CPU (s):')
cpu_time = timeit.timeit('cpu()', number=10, setup="from __main__ import cpu")
print(cpu_time)
print('GPU (s):')
gpu_time = timeit.timeit('gpu()', number=10, setup="from __main__ import gpu")
print(gpu_time)
print('GPU speedup over CPU: {}x'.format(int(cpu_time/gpu_time)))

sess.close()

Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0

Time (s) to convolve 32x7x7x3 filter over random 100x100x100x3 images (batch x height x width x channel). Sum of ten runs.
CPU (s):
1.903530227000033
GPU (s):
1.4565913440000031
GPU speedup over CPU: 1x


# Benchmark TensorFlow GPU on real ANN

## Download models

In [0]:
models = {"mobilenet_v2_1.0_96": "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz",
          "resnet_v2_101_2017_04_14": "http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz"}

In [0]:
!wget https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz && \
    wget http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz

!ls -sh1

--2019-12-09 22:28:01--  https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.217.128, 2607:f8b0:400c:c03::80
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.217.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 78327491 (75M) [application/x-tar]
Saving to: ‘mobilenet_v2_1.0_96.tgz’


2019-12-09 22:28:01 (187 MB/s) - ‘mobilenet_v2_1.0_96.tgz’ saved [78327491/78327491]

--2019-12-09 22:28:01--  http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 172.217.204.128, 2607:f8b0:400c:c0b::80
Connecting to download.tensorflow.org (download.tensorflow.org)|172.217.204.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 499948384 (477M) [application/x-tar]
Saving to: ‘resnet_v2_101_2017_04_14.tar.gz’


2019-12-09 22:28:04 (218 MB/s) - ‘resnet_v2_101

In [0]:
!mkdir mobilenet_v2_1.0_96 && cd mobilenet_v2_1.0_96 && tar fvxz ../mobilenet_v2_1.0_96.tgz && cd .. && rm mobilenet_v2_1.0_96.tgz && \
mkdir resnet_v2_101_2017_04_14 && cd resnet_v2_101_2017_04_14 && tar fvxz ../resnet_v2_101_2017_04_14.tar.gz && cd .. && rm resnet_v2_101_2017_04_14.tar.gz

mkdir: cannot create directory ‘mobilenet_v2_1.0_96’: File exists


In [0]:
!ls -R

.:
mobilenet_v2_1.0_96	 resnet_v2_101_2017_04_14	  sample_data
mobilenet_v2_1.0_96.tgz  resnet_v2_101_2017_04_14.tar.gz

./mobilenet_v2_1.0_96:
mobilenet_v2_1.0_96.ckpt.data-00000-of-00001  mobilenet_v2_1.0_96_frozen.pb
mobilenet_v2_1.0_96.ckpt.index		      mobilenet_v2_1.0_96_info.txt
mobilenet_v2_1.0_96.ckpt.meta		      mobilenet_v2_1.0_96.tflite
mobilenet_v2_1.0_96_eval.pbtxt

./resnet_v2_101_2017_04_14:
eval.graph  resnet_v2_101.ckpt	train.graph

./sample_data:
anscombe.json		      mnist_test.csv
california_housing_test.csv   mnist_train_small.csv
california_housing_train.csv  README.md


## Sample data

In [0]:
!ls sample_data

anscombe.json		      mnist_test.csv
california_housing_test.csv   mnist_train_small.csv
california_housing_train.csv  README.md


### MNIST

In [0]:
from pandas import read_csv

mnist_train_small = read_csv("sample_data/mnist_train_small.csv", index_col=False)
mnist_train_small.shape

(19999, 785)

In [0]:
mnist_train_small.info()
mnist_train_small.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 19999 entries, 0 to 19998
Columns: 785 entries, 6 to 0.590
dtypes: int64(785)
memory usage: 119.8 MB


Unnamed: 0,6,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19,0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29,0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,...,0.551,0.552,0.553,0.554,0.555,0.556,0.557,0.558,0.559,0.560,0.561,0.562,0.563,0.564,0.565,0.566,0.567,0.568,0.569,0.570,0.571,0.572,0.573,0.574,0.575,0.576,0.577,0.578,0.579,0.580,0.581,0.582,0.583,0.584,0.585,0.586,0.587,0.588,0.589,0.590
count,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,...,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0
mean,4.470124,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0005,0.010801,0.010801,0.00045,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.012401,0.028451,0.058303,0.065503,0.127556,...,3.714036,2.627231,1.718486,1.020101,0.553228,0.247412,0.097755,0.020751,0.0014,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00035,0.010651,0.036052,0.088304,0.112306,0.158508,0.276914,0.40607,0.546827,0.572079,0.696235,0.671684,0.545927,0.366318,0.215011,0.087704,0.036502,0.013651,0.032602,0.006,0.0,0.0,0.0,0.0
std,2.892807,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.070712,1.527389,1.527389,0.063641,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.233369,1.986589,3.473328,3.100786,5.003077,...,26.815104,22.334578,18.262801,14.000786,10.463422,6.750766,4.079112,1.322117,0.197995,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.049499,1.095106,2.22082,3.900144,4.749952,5.406774,7.0053,8.719149,10.379141,10.254843,11.457391,11.297264,10.05733,8.255546,6.314821,3.921664,2.712527,0.950818,2.718102,0.600333,0.0,0.0,0.0,0.0
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,9.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,216.0,216.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,132.0,212.0,253.0,230.0,255.0,...,255.0,255.0,255.0,255.0,255.0,254.0,255.0,135.0,28.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,132.0,231.0,253.0,253.0,253.0,255.0,255.0,255.0,255.0,255.0,255.0,255.0,255.0,253.0,254.0,253.0,79.0,254.0,62.0,0.0,0.0,0.0,0.0


In [0]:
from pandas import Series

y_mnist_train_small = Series([sample[0] for sample in mnist_train_small.values])
y_mnist_train_small.describe()

count    19999.000000
mean         4.470124
std          2.892807
min          0.000000
25%          2.000000
50%          4.000000
75%          7.000000
max          9.000000
dtype: float64

In [0]:
n_samples=int(1e3)
last_y_mnist_train_small = y_mnist_train_small[:n_samples].copy()
last_y_mnist_train_small.describe()

count    1000.000000
mean        4.631000
std         2.934186
min         0.000000
25%         2.000000
50%         5.000000
75%         7.000000
max         9.000000
dtype: float64

In [0]:
n_classes = len(last_y_mnist_train_small.unique())
n_classes

10

In [0]:
from pandas import DataFrame

x_mnist_train_small = DataFrame([sample for sample in mnist_train_small.values[:,1:]])
x_mnist_train_small.info()
x_mnist_train_small.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 19999 entries, 0 to 19998
Columns: 784 entries, 0 to 783
dtypes: int64(784)
memory usage: 119.6 MB


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767,768,769,770,771,772,773,774,775,776,777,778,779,780,781,782,783
count,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,...,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0,19999.0
mean,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0005,0.010801,0.010801,0.00045,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.012401,0.028451,0.058303,0.065503,0.127556,0.19236,...,3.714036,2.627231,1.718486,1.020101,0.553228,0.247412,0.097755,0.020751,0.0014,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00035,0.010651,0.036052,0.088304,0.112306,0.158508,0.276914,0.40607,0.546827,0.572079,0.696235,0.671684,0.545927,0.366318,0.215011,0.087704,0.036502,0.013651,0.032602,0.006,0.0,0.0,0.0,0.0
std,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.070712,1.527389,1.527389,0.063641,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.233369,1.986589,3.473328,3.100786,5.003077,5.989394,...,26.815104,22.334578,18.262801,14.000786,10.463422,6.750766,4.079112,1.322117,0.197995,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.049499,1.095106,2.22082,3.900144,4.749952,5.406774,7.0053,8.719149,10.379141,10.254843,11.457391,11.297264,10.05733,8.255546,6.314821,3.921664,2.712527,0.950818,2.718102,0.600333,0.0,0.0,0.0,0.0
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.0,216.0,216.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,132.0,212.0,253.0,230.0,255.0,255.0,...,255.0,255.0,255.0,255.0,255.0,254.0,255.0,135.0,28.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.0,132.0,231.0,253.0,253.0,253.0,255.0,255.0,255.0,255.0,255.0,255.0,255.0,255.0,253.0,254.0,253.0,79.0,254.0,62.0,0.0,0.0,0.0,0.0


In [0]:
HEIGHT = 28
WIDTH = 28
CHANNELS = 1 # Grayscale

In [0]:
from numpy import reshape

x_mnist_train_small_matrices = reshape(x_mnist_train_small.values, (-1, HEIGHT, WIDTH, CHANNELS))
x_mnist_train_small_matrices.shape

(19999, 28, 28, 1)

Note: [ImageNet images](http://www.image-net.org/challenges/LSVRC/2012/nonpub-downloads) are commonly resized to 224x224

In [0]:
HEIGHT = 224
WIDTH = 224

from numpy import array, float32

from PIL import Image

x_mnist_train_small_scaled = array([array(Image.fromarray(arr[:,:,0]).resize((HEIGHT, WIDTH))).reshape((HEIGHT, WIDTH, CHANNELS)) for arr in x_mnist_train_small_matrices.astype(float32)])
x_mnist_train_small_scaled.shape

(19999, 224, 224, 1)

In [0]:
inputs = x_mnist_train_small_scaled
last_inputs = x_mnist_train_small_scaled[:n_samples].copy()
inputs.shape, last_inputs.shape

((19999, 224, 224, 1), (1000, 224, 224, 1))

In [0]:
GIGA = 1<<30

from sys import getsizeof

getsizeof(inputs) / GIGA, getsizeof(last_inputs) / GIGA

(3.738216534256935, 0.18692030012607574)

## ResNet v2

In [0]:
!wget https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py

--2019-12-09 22:45:32--  https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘resnet_v2.py’

resnet_v2.py            [<=>                 ]       0  --.-KB/s               resnet_v2.py            [ <=>                ] 167.52K  --.-KB/s    in 0.04s   

2019-12-09 22:45:32 (4.20 MB/s) - ‘resnet_v2.py’ saved [171539]



In [0]:
!ls

mobilenet_v2_1.0_96	 resnet_v2_101_2017_04_14	  resnet_v2.py
mobilenet_v2_1.0_96.tgz  resnet_v2_101_2017_04_14.tar.gz  sample_data


In [0]:
from tensorflow.contrib import slim

https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py#L29

Typical use:

In [0]:
from tensorflow.contrib.slim.nets import resnet_v2

### [ResNet-101](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py#L274) for image classification into `n_classes` \[1000 for ImageNet\] classes:

In [0]:
from tensorflow import AUTO_REUSE

In [0]:
# inputs has shape [batch, 224, 224, 3]
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
  net, end_points = resnet_v2.resnet_v2_101(last_inputs, num_classes=n_classes, is_training=False, reuse=AUTO_REUSE)
    
net

<tf.Tensor 'resnet_v2_101_1/logits/BiasAdd:0' shape=(1000, 1, 1, 10) dtype=float32>

In [0]:
logits = Series([net[sample,:,:,:].value_index for sample in range(n_samples)])
logits.describe()

count    1000.0
mean        0.0
std         0.0
min         0.0
25%         0.0
50%         0.0
75%         0.0
max         0.0
dtype: float64

In [0]:
from tensorflow import Session

init = tf.global_variables_initializer()

with Session() as sess:    
    sess.run(init)
    net.eval()

In [0]:
from tensorflow.losses import log_loss

log_loss(last_y_mnist_train_small, logits, scope='resnet_v2_101')

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


<tf.Tensor 'resnet_v2_101_3/value:0' shape=() dtype=float32>

### ResNet-101 for semantic segmentation into 21 classes:

In [0]:
# inputs has shape [batch, 513, 513, 3]
   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      net, end_points = resnet_v2.resnet_v2_101(last_inputs,
                                                21,
                                                is_training=False,
                                                global_pool=False,
                                                output_stride=16, reuse=AUTO_REUSE)