# [`XLA.jl`](https://github.com/JuliaTPU/XLA.jl): ResNet on TPUs - Introduction

In this notebook, we will step through the fundamental infrastructure necessary to load a [ResNet50](https://arxiv.org/abs/1512.03385) model, JIT-compile it for the [TPU](https://en.wikipedia.org/wiki/Tensor_processing_unit), and feed it some data to get classifications out.  Once you are comfortable with this material, you may wish to move on to [simple model training on the TPU](2_ResNet_Training.ipynb), followed by [distributed TPU training](3_ResNet_DistributedTraining.ipynb).

## Overview of `XLA.jl` workflow

We will define the model in plain Julia using the [`Flux.jl`](https://github.com/FluxML/Flux.jl) framework, which provides the ResNet50 model computation.  The model definition lives in the file [`resnet50.jl`](resnet50.jl); note, however, that in the near future it will instead be sourced from the Metalhead.jl repository of general computer vision models defined in `Flux.jl`/Julia.

We will define a simple set of mappings to convert a standard Julia model to be TPU-runnable.  There are a number of restrictions within the current XLA.jl compiler that must be adhered to for compilation to succeed:

* All arrays and scalars must be of type `XRTArray` and must have an element type of `Float32`, including literals.  This unfortunately means that code such as `1 ./ x` must be transformed to `XRTArray(1f0) ./ x`.  We intend to make this an automatic process in the future, but for the time being we manually define the appropriate helper functions, such as `softmax()`, so that they use `XRTArray`s properly; these live in [`model_utils.jl`](model_utils.jl).

* All arrays are immutable, meaning that the definition of some layers such as Batch Normalization in `Flux.jl` must be adapted.  We have created two separate versions of the `BatchNorm` layer in this repository, one meant to be used on the [TPU](tpu_batch_norm.jl) (which we will be using here) and one on the [CPU](cpu_batch_norm.jl), for testing and verification.
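
Both restrictions can be illustrated in plain Julia, without `XLA.jl` or a TPU.  The sketch below uses ordinary `Array`s rather than `XRTArray`s, and `update_mean` is a hypothetical helper written for this illustration (it is not taken from `tpu_batch_norm.jl`).  It shows how a `Float64` literal silently changes the element type, and how a running statistic can be updated by returning a new array instead of mutating the old one.

```julia
# 1. Literal element types: a Float64 literal promotes the whole result.
x = Float32[1, 2, 4]
y64 = 1.0 ./ x    # element type becomes Float64
y32 = 1f0 ./ x    # element type stays Float32

println(eltype(y64), " vs. ", eltype(y32))

# 2. Immutable-style update: build a new array instead of writing in place.
# `update_mean` is a hypothetical helper for this sketch only.
update_mean(running, batch, momentum) =
    (1f0 - momentum) .* running .+ momentum .* batch

running = Float32[0, 0]
new_running = update_mean(running, Float32[1, 2], 0.1f0)

println(running, " -> ", new_running)
```

The original `running` array is left untouched, which is the property an immutable array type requires of layer code.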

We will now load in the necessary packages, instantiating an environment local to these notebooks, and construct the model:

In [1]:
# Load package versions that are known to work with TPUs, check that Julia version is a known compatible one
if Base.GIT_VERSION_INFO.commit != "f1dffc5c8b6b7f960b5e30835631b4caf4434b04"
    @warn("Only the very latest Julia version on the `kf/tpu3` branch is supported!")
end

import Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h

In [2]:
# Load in packages and model definition
using TensorFlow, XLA, Flux, Printf
include("resnet50.jl")
include("tpu_batch_norm.jl")

model = resnet50();
println("=> Initialized ResNet50 model with $(sum(prod(size(p)) for p in params(model))) learnable parameters")

=> Initialized ResNet50 model with 25583464 learnable parameters


At this point, we have a `Flux.jl` ResNet50 model.  Next, we need to alter the model so that it can be compiled by `XLA.jl`.  We will do so by defining a set of mapping functions that take advantage of multiple dispatch to recursively walk the model structure, converting normal arrays to `XRTArray`s, coercing scalar values to `Float32`, and converting `BatchNorm` layers to `TPUBatchNorm` objects:

In [3]:
# Convert scalars to single-element XRTArrays with eltype Float32:
map_to_tpu(x::Real) = XRTArray(convert(Float32, x))

# Convert arrays to XRTArrays with eltype Float32
map_to_tpu(x::AbstractArray) = XRTArray(Float32.(x))

# Strip off the TrackedArray coating to get at the data underneath
map_to_tpu(x::TrackedArray) = map_to_tpu(Flux.data(x))

# Turn Chain objects into ImmutableChain objects which store the computation within their type signature
map_to_tpu(x::Chain) = ImmutableChain(tuple(map(map_to_tpu, x.layers)...))

# Convert BatchNorm layers into TPUBatchNorm layers, passing all children straight through,
# except for the "active" child, which is not used by the TPUBatchNorm
map_to_tpu(x::BatchNorm) = TPUBatchNorm(map(map_to_tpu, Flux.children(x))[1:end-1]...)

# For all other objects, just map the children through `map_to_tpu`.
map_to_tpu(x) = Flux.mapchildren(map_to_tpu, x)


# Convert our model to the TPU-compatible version
tpu_model = map_to_tpu(model)
println("=> Mapped model to TPU-specific construction")

=> Mapped model to TPU-specific construction
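
The same recursive, dispatch-driven walk can be tried without Flux or a TPU.  The following is a minimal sketch under stated assumptions: `ToyDense`, `ToyChain`, and `to_f32` are hypothetical stand-ins invented for this example, and the conversion stops at plain `Float32` arrays instead of `XRTArray`s.

```julia
# Hypothetical stand-in types for this sketch; not Flux.jl or XLA.jl types.
struct ToyDense{W,B}
    W::W
    b::B
end

struct ToyChain{T<:Tuple}
    layers::T
end

# One method per kind of node; dispatch picks the right one at each level.
to_f32(x::Real) = convert(Float32, x)
to_f32(x::AbstractArray) = Float32.(x)
to_f32(d::ToyDense) = ToyDense(to_f32(d.W), to_f32(d.b))
to_f32(c::ToyChain) = ToyChain(map(to_f32, c.layers))  # `map` over a tuple returns a tuple

# Build a Float64 model and convert every leaf to Float32.
model64 = ToyChain((ToyDense(rand(4, 3), zeros(4)), ToyDense(rand(2, 4), zeros(2))))
model32 = to_f32(model64)

println(eltype(model64.layers[1].W), " -> ", eltype(model32.layers[1].W))
```

Because the layers are held in a tuple, the converted chain's type records the type of every layer, which is the idea behind storing the computation in the type signature.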


## Compiling the model

At this point, we are ready to compile the model.  To do so, we must first connect to a TPU or an `xrt_server` binary running on a host.  Here we connect to a TPU listening on port 8470 and assign the resulting `Session` object to the special global variable `sess`.  Once connected, we can use the `@tpu_compile` macro to compile our model down to an executable handle, which can then be invoked to run the computation on an input `x`.

Compilation can take quite a while.  On the GCE machine this notebook was run on, the first compilation took nearly 40 seconds.

In [4]:
tpu_ip = "10.240.7.4"
println("Connecting to TPU on $(tpu_ip)")

# NOTE: If you are connecting to an actual TPU, use `TPUSession`.  If you are
# connecting to an `xrt_server`, use `Session()`.
sess = TPUSession("$(tpu_ip):8470")
#sess = Session(Graph(); target="grpc://$(tpu_ip):8470")

# Generate random input tensor; a batch of two images with spatial dimensions 224x224 and 3 color channels.
x = randn(Float32, 224, 224, 3, 2)

# Compile the model
t_start = time()
compilation_handle = @tpu_compile tpu_model(XRTArray(x));
t_end = time()

println(@sprintf("=> Compiled model in %.1f seconds", t_end - t_start))

Connecting to TPU on 10.240.7.4
=> Compiled model in 38.9 seconds


2019-02-11 05:06:22.864611: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:349] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.
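
The `time()` bookkeeping in the cell above can also be expressed with Julia's built-in `@timed` macro, which returns the result of an expression together with the elapsed time in seconds.  In this sketch `fake_compile` is a hypothetical stand-in that only sleeps, since `@tpu_compile` needs a live TPU session.

```julia
using Printf

# Hypothetical stand-in for the `@tpu_compile` call; it only sleeps briefly.
fake_compile() = (sleep(0.2); :compilation_handle)

# `@timed` yields the value first and the elapsed seconds second.
handle, seconds = @timed fake_compile()

@printf("=> Compiled model in %.1f seconds\n", seconds)
```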


Now that the model is compiled, we can run it using `TensorFlow.jl`'s `run()` method.  We pass the compilation handle first, then a structure containing the weights of the model, and finally the tensor to be pushed through the model.

In [5]:
# Run the actual computation
y_hat = run(compilation_handle,
    # Transfer model weights
    XRTRemoteStruct(sess, tpu_model),
    # Transfer `x`
    XRTArray(sess, x)
)

# Convert the output (which is an XRTArray) back to a normal array:
y_hat = convert(Array, y_hat)

1000×2 Array{Float32,2}:
 -9.20049e-6   -1.2162e-5  
 -2.57845e-5   -2.47399e-5 
  2.04847e-5    1.98748e-5 
 -2.54263e-5   -1.99875e-5 
 -2.60119e-5   -1.95756e-5 
  1.82703e-5    1.93357e-5 
  3.11616e-6    3.44445e-6 
 -6.66996e-6   -9.32722e-6 
 -6.14329e-5   -5.8134e-5  
  4.69457e-5    4.64068e-5 
 -4.44638e-5   -4.12743e-5 
  2.96213e-5    2.88296e-5 
  3.7362e-6     5.74233e-6 
  ⋮                        
  4.58536e-6    1.08529e-5 
  4.7623e-5     4.91411e-5 
  8.99033e-6    1.02057e-5 
  0.000103902   0.000102969
 -2.28912e-6   -2.34087e-6 
 -8.8403e-5    -8.85922e-5 
 -6.12964e-5   -5.65298e-5 
 -1.80922e-5   -1.55818e-5 
  1.09936e-5    8.16519e-6 
 -3.59232e-5   -3.59743e-5 
 -3.13594e-5   -2.93179e-5 
 -1.48641e-5   -1.30751e-5 

error in running finalizer: ErrorException("type TPUSession has no field ptr")
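
To turn such an output into classifications, one option is to take the highest-scoring row in each column.  The sketch below assumes, as the `1000×2` shape suggests, that each column holds the per-class scores for one image in the batch; it uses a small made-up matrix rather than the real `y_hat`, and `top1`/`top2` are names chosen for this example.

```julia
# Made-up scores: 3 classes (rows) × 2 images (columns).
scores = Float32[0.1 0.7;
                 0.9 0.2;
                 0.3 0.5]

# Index of the best-scoring class for each image.
top1 = [argmax(scores[:, j]) for j in 1:size(scores, 2)]

# Indices of the two best-scoring classes for each image, best first.
top2 = [sortperm(scores[:, j]; rev=true)[1:2] for j in 1:size(scores, 2)]

println("top-1: ", top1)
println("top-2: ", top2)
```

Applied to the converted `y_hat`, the same comprehension would return one class index between 1 and 1000 per image.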


And just like that, you have successfully translated, compiled, and run your first model on the TPU.  Congratulations!  You should feel very proud of yourself.  Next up, we will [learn to do some training](2_ResNet_Training.ipynb).