# [`XLA.jl`](https://github.com/JuliaTPU/XLA.jl): ResNet on TPUs - Training

In this notebook, we build on [the previous notebook](1_ResNet_Intro.ipynb) by introducing the training loop on TPUs for the ResNet-50 computer vision model.  We use the same model, this time embedding its forward pass within a loop that streams batches of data in through an XLA InFeed and streams training loss out through an XLA OutFeed.  We essentially build an autonomous program that accepts tensors fed in over the network, calculates the forward and backward passes, updates the model, and finally returns the fully trained model.  This means that once the TPU code is running, it runs until the model is fully trained; the only interaction required with the TPU is feeding in data.

For simplicity and speed, we use a preprocessed ImageNet dataset stored on disk, in which each image has been transformed into a `224x224x3` array of `UInt8` values and each label into a `UInt16`.  The image preprocessing steps can be found in [this notebook](PreprocessedImagenet.ipynb), for the terminally curious.  The images and labels are transferred to the TPU, where they are unpacked and normalized into `Float32` tensors, ready to be run through the Flux model.

In [1]:
# Load package versions that are known to work with TPUs, check that Julia version is a known compatible one
if Base.GIT_VERSION_INFO.commit != "f1dffc5c8b6b7f960b5e30835631b4caf4434b04"
    @warn("Only the very latest Julia version on the `kf/tpu3` branch is supported!")
end

import Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h

In [3]:
# Load in packages and model definition
using TensorFlow, XLA, Flux, Printf
include("resnet50.jl")
include("tpu_batch_norm.jl")
include("preprocessing_utils.jl")
include("model_utils.jl")

model = resnet50();
println("=> Initialized ResNet50 model with $(sum(prod(size(p)) for p in params(model))) learnable parameters")

# Convert our model to the TPU-compatible version
tpu_model = map_to_tpu(model)
println("=> Mapped model to TPU-specific construction")

=> Initialized ResNet50 model with 25583464 learnable parameters
=> Mapped model to TPU-specific construction


## The Training Loop

We will define here a training loop that will run as a program on the TPU, taking batches in through an infeed, calculating forward and backward passes, updating weights, and outputting training loss through an outfeed.  We make use of [`Zygote.jl`](https://github.com/FluxML/Zygote.jl) to automatically differentiate the model, generating a function that calculates the backward pass, and apply the resulting updates using a hand-written `SGD` implementation.

First, getting data onto the device.  TPUs do not support `UInt8` arrays at the time of writing, so we pack a single pixel's (R, G, B) values into a `UInt32` and transfer tensors of `UInt32`s across the wire to the TPU.  See [`preprocessing_utils.jl`](preprocessing_utils.jl) for more on that.  We define a method called `get_minibatch_data()` that reads from an infeed, converts the pixel-packed values to `Float32` tensors ready for pushing through the model, and expands the provided labels into onehot matrices.  Note that the method takes a `Val{batch_size}` because it must be completely statically inferrable (including the size of all tensors).  We pass in `batch_size` as a value type parameter to make it easy to compile models for different batch sizes, whereas the spatial resolution (`224x224` in this case) is hardcoded, as it is much less likely to change.  The same treatment could be given to those values to create a more general infeed function.
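To make the pixel-packing idea concrete, here is a minimal CPU-only sketch.  The helper names `pack_rgb` and `unpack_rgb`, the bit layout (red in the highest of the three bytes), and the scaling to `[0, 1]` are all assumptions made for illustration; the actual helpers in `preprocessing_utils.jl` may differ in each of these respects.

```julia
# Hypothetical helpers for illustration; the notebook's real versions live in
# preprocessing_utils.jl and may use a different bit layout or normalization.

# Pack one pixel's red, green and blue bytes into the low 24 bits of a UInt32.
pack_rgb(r::UInt8, g::UInt8, b::UInt8) =
    (UInt32(r) << 16) | (UInt32(g) << 8) | UInt32(b)

# Recover the three channels as Float32 values scaled to [0, 1].
function unpack_rgb(p::UInt32)
    r = Float32((p >> 16) & 0xff) / 255f0
    g = Float32((p >> 8) & 0xff) / 255f0
    b = Float32(p & 0xff) / 255f0
    return (r, g, b)
end

# Pack a small H×W×3 UInt8 image into an H×W UInt32 array (one word per pixel).
img = rand(UInt8, 4, 4, 3)
packed = pack_rgb.(img[:, :, 1], img[:, :, 2], img[:, :, 3])

# Unpack a single pixel back into its three normalized channels.
r, g, b = unpack_rgb(packed[1, 1])
```

Packing three bytes into one word means the infeed carries `224*224*batch_size` elements per batch rather than three times that many, at the cost of one unused byte per pixel.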

In [4]:
function get_minibatch_data(::Val{batch_size}) where {batch_size}
    # Construct HloInfeed object that will expect to receive a tuple
    # of two arrays, one for `x` and one for `y`.  Note that incorrect sizes
    # here will cause...unexpected results, so do your best not to do that.
    infeed = XLA.HloInfeed(Tuple{
        XRTArray{UInt32, (224*224*batch_size,), 1},
        XRTArray{UInt32, (batch_size,), 1},
    })
    # Read in from the infeed
    (x, y), _ = infeed(XLA.HloAfterAll()())

    # We feed one-dimensional vectors, so reshape back to one packed pixel
    # per (row, column, sample) entry.
    x = reshape(x, (224, 224, batch_size))

    # Do pixel unpacking/channel normalization.
    x = unpack_pixels(x)

    # Convert labels to onehot representation
    y = make_onehot(y)
    
    # Return our data!
    return x, y
end

get_minibatch_data (generic function with 1 method)
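The role of `Val{batch_size}` in `get_minibatch_data` can be seen in isolation with a small sketch.  The function `feed_shape` below is hypothetical and only illustrates the mechanism: because the batch size arrives as a type parameter, it is a compile-time constant inside the method, so any size computed from it is known statically.

```julia
# Hypothetical function: the batch size arrives as a type parameter, so it is
# a compile-time constant inside the method body.
function feed_shape(::Val{batch_size}) where {batch_size}
    # Length of the flattened, pixel-packed image buffer for one batch.
    return (224 * 224 * batch_size,)
end

# Each distinct batch size gets its own specialized method instance.
shape_128 = feed_shape(Val(128))
shape_256 = feed_shape(Val(256))
```

For a batch size of 128 this gives a length of 6422528, which is the `XRTArray{UInt32,(6422528,),1}` size that appears in the typed IR printed during compilation.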

Next up, optimizer definition.  We hand-craft a simple SGD implementation here; for a more advanced optimizer see the [`ADAM_tpu.jl`](ADAM_tpu.jl) file, used in the next tutorial on distributed TPU training.  `ADAM` is slightly more complex as it must track gradient statistics for each weight in the model, complicating the update step.  In this example, we simply define an `SGD` type for dispatch purposes, then define a recursive update rule that will walk the model weights and gradients, updating as it goes and returning a new model:

In [5]:
struct SGD
    # Learning rate; the only parameter this optimizer needs to keep track of
    η
end

# Simplest update step in existence.
update!(model::AbstractArray, Δ::AbstractArray, η) = model .- Δ .* η

# If this leaf node had no updates calculated for it, then skip out early.
update!(model, Δ::Nothing, η) = model

function update!(model, Δ, η)
    # Base condition; if we have reached a leaf node return the inputs unchanged.
    # Note that if `model` is an XRTArray, we will hit the override above that actually
    # updates the model rather than this generic update!(), same for if Δ is `nothing`.
    if nfields(model) == 0
        return model
    end
    
    # Recursively pass the fields of this model through the update machinery.  We use
    # this strange ntuple() do-block because we cannot perform any kind of mutation
    # (such as push!()'ing onto a list), so we adopt a more functional style of
    # programming.
    new_fields = ntuple(Val(nfields(model))) do i
        update!(getfield(model, i), getfield(Δ, i), η)
    end
    
    # Return something of the same type as `model`, but with the new fields
    if isa(model, Tuple)
        return tuple(new_fields...)
    else
        return typeof(model)(new_fields...)
    end
end

# Main entry point for this optimizer's update steps
update!(opt::SGD, model::ImmutableChain, Δ) = update!(model.layers, Δ.layers, opt.η)

update! (generic function with 4 methods)
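To see the recursion in isolation, the following CPU-only sketch applies the same three-case walk to plain `Array`s.  The types `ToyDense` and `ToyChain` and the function `walk_update` are hypothetical stand-ins, not part of the notebook's code, and the gradient is written by hand as nested `NamedTuple`s on the assumption that it mirrors the model's field layout.

```julia
# Hypothetical layer types, standing in for the notebook's TPU layers.
struct ToyDense{W, B}
    W::W
    b::B
end

struct ToyChain{T <: Tuple}
    layers::T
end

# Same three-case recursion as the update!() methods, renamed to avoid a clash.
walk_update(w::AbstractArray, Δ::AbstractArray, η) = w .- Δ .* η
walk_update(w, ::Nothing, η) = w
function walk_update(node, Δ, η)
    # Leaf with no fields: nothing to update.
    nfields(node) == 0 && return node
    # Rebuild every field without mutating anything.
    new_fields = ntuple(i -> walk_update(getfield(node, i), getfield(Δ, i), η), nfields(node))
    return node isa Tuple ? new_fields : typeof(node)(new_fields...)
end

model = ToyChain((ToyDense(ones(2, 2), zeros(2)), ToyDense(ones(1, 2), zeros(1))))

# Hand-written gradient; `nothing` marks a layer that received no gradient.
Δ = (layers = ((W = fill(0.5, 2, 2), b = ones(2)), nothing),)

new_model = ToyChain(walk_update(model.layers, Δ.layers, 0.1))
```

The original `model` is left untouched: every call returns a freshly constructed value, which is the property that matters when mutation is unavailable.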

Finally, the full training loop.  Now that we have the above pieces, this is conceptually very simple.  We first initialize our optimizer object (not much to do there for SGD, but once we start using ADAM, this will become a little more involved), then we enter the minibatch-pushing loop.  Each iteration infeeds a new batch of data, pushes it through the model to calculate the loss, then backpropagates that loss to calculate a set of updates for the model.  We then apply those updates to the model and output the training loss for this minibatch back to the controlling host.  Once we have processed `nbatches` batches of training data, we return the trained model.

In [6]:
# Define our training loop
function train_loop(::Val{batch_size}, model, nbatches, η) where {batch_size}
    # Initialize optimizer.  SGD only stores the learning rate; a stateful optimizer
    # such as ADAM would allocate its per-weight statistics here.
    opt = SGD(η)

    # Run until nbatches is zero
    while nbatches > XRTArray(0)
        # Get next minibatch of data
        mb_data = get_minibatch_data(Val(batch_size))

        # Let block to fend off the inference demons
        loss, back = let x = mb_data[1], y = mb_data[2]
            # Calculate forward pass to get loss, and compile backwards pass
            # to get the updates to our model weights.
            Zygote._forward(
                Zygote.Context{Nothing}(nothing),
                model -> logitcrossentropy(model(x), y),
                model,
            )
        end

        # Evaluate the backwards pass.  The first value returned by `back` is the
        # sensitivity of the closure itself (which captures `x` and `y`); we discard
        # it via tailmemaybe() and keep the gradient with respect to `model`.
        Δ_model = Zygote.tailmemaybe(back(1f0))[1]

        # Update parameters via our optimizer
        model = update!(opt, model, Δ_model)

        # Outfeed the loss
        #loss = reshape(loss, (1,))
        XLA.HloOutfeed()((loss,), XLA.HloAfterAll()())

        # Count down the batches
        nbatches -= XRTArray(1)
    end
    
    # At the end of all things, return the trained model
    return model
end

train_loop (generic function with 1 method)
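The loss that `train_loop` sends to the outfeed is a logit cross-entropy.  As a rough illustration of what that quantity is, here is a plain-Julia sketch that assumes the usual definition: a log-softmax over the class dimension followed by a mean over the batch.  The names `logsoftmax_cols` and `logit_xent` are hypothetical, and this is not Flux's implementation.

```julia
# Numerically stable log-softmax over each column (one column per sample).
function logsoftmax_cols(logits::AbstractMatrix)
    shifted = logits .- maximum(logits; dims=1)
    return shifted .- log.(sum(exp.(shifted); dims=1))
end

# Mean cross-entropy between onehot targets `y` and raw scores `logits`.
logit_xent(logits, y) = -sum(y .* logsoftmax_cols(logits)) / size(y, 2)

logits = Float32[2 0; 0 2; 0 0]   # 3 classes × 2 samples
y      = Float32[1 0; 0 1; 0 0]   # onehot targets, one column per sample
loss = logit_xent(logits, y)
```

Working on raw logits rather than on softmax probabilities avoids taking the logarithm of values that have underflowed to zero.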

In [7]:
# This works
function debug2(::Val{batch_size}, model, nbatches, η) where {batch_size}
    opt = SGD(η)

    mb_data = get_minibatch_data(Val(batch_size))
    
    loss, back = let x = mb_data[1], y = mb_data[2]
        Zygote._forward(
            Zygote.Context{Nothing}(nothing),
            model -> logitcrossentropy(model(x), y),
            model,
        )
    end

    Δ_model = Zygote.tailmemaybe(back(1f0))[1]
    model = update!(opt, model, Δ_model)
    
    XLA.HloOutfeed()((loss,), XLA.HloAfterAll()())
    return model
end

debug2 (generic function with 1 method)

## Running the training loop

Now that we've got all that code written up, let's actually run the training loop.  First, we compile it.  Again, this can take a _very_ long time (on the GCE instance this notebook was run on, this took over 60 seconds), so be patient.

In [8]:
tpu_ip = "10.240.7.4"
println("Connecting to TPU on $(tpu_ip)")

# NOTE: If you are connecting to an actual TPU, use `TPUSession`.  If you are
# connecting to an `xrt_server`, use `Session()`.
sess = TPUSession("$(tpu_ip):8470")
#sess = Session(Graph(); target="grpc://$(tpu_ip):8470")


batch_size = 128
num_batches = 1000
η = 0.001

x = randn(Float32, 224, 224, 3, 1)
y = rand(Float32, 1000, 1)

# Compile the model
t_start = time()
compilation_handle = @tpu_compile debug2(Val(batch_size), tpu_model, XRTArray(num_batches), XRTArray(η))
#compilation_handle = @tpu_compile train_loop(Val(batch_size), tpu_model, XRTArray(num_batches), XRTArray(η));
t_end = time()

println(@sprintf("=> Compiled training loop in %.1f seconds", t_end - t_start))

Connecting to TPU on 10.240.7.4
stmt = :((Core.apply_type)(%8972, Tuple{Conv{typeof(identity),XRTArray{Float32,(7, 7, 3, 64),4},XRTArray{Float32,(64,),1},(2, 2),(3, 3),(1, 1)},getfield(Main, Symbol("##65#68")),ResidualBlock{Tuple{ConvNorm{Conv{typeof(identity),XRTArray{Float32,(1, 1, 64, 64),4},XRTArray{Float32,(64,),1},(1, 1),(0, 0),(1, 1)},TPUBatchNorm{typeof(identity),XRTArray{Float32,(64,),1},XRTArray{Float32,(64,),1}}},ConvNorm{Conv{typeof(identity),XRTArray{Float32,(3, 3, 64, 64),4},XRTArray{Float32,(64,),1},(1, 1),(1, 1),(1, 1)},TPUBatchNorm{typeof(identity),XRTArray{Float32,(64,),1},XRTArray{Float32,(64,),1}}},ConvNorm{Conv{typeof(identity),XRTArray{Float32,(1, 1, 64, 256),4},XRTArray{Float32,(256,),1},(1, 1),(0, 0),(1, 1)},TPUBatchNorm{typeof(identity),XRTArray{Float32,(256,),1},XRTArray{Float32,(256,),1}}}},ImmutableChain{Tuple{Conv{typeof(identity),XRTArray{Float32,(1, 1, 64, 256),4},XRTArray{Float32,(256,),1},(1, 1),(0, 0),(1, 1)},TPUBatchNorm{typeof(identity),XRTArray{Floa

[90m[-16G│╻                                        get_minibatch_data[1G[39m[90m5  [39m1 ─ %1    = invoke $(QuoteNode(XLA.HloAfterAll()))()[36m::XLA.HloToken[39m
[90m[-16G││                                       [1G[39m[90m   [39m│   %2    = invoke XLA.HloInfeed{Tuple{XRTArray{UInt32,(6422528,),1},XRTArray{UInt32,(128,),1}}}(Tuple{XRTArray{UInt32,(6422528,),1},XRTArray{UInt32,(128,),1}})(%1::XLA.HloToken)[36m::Tuple{Tuple{XRTArray{UInt32,(6422528,),1},XRTArray{UInt32,(128,),1}},XLA.HloToken}[39m
[90m[-16G││╻                                        indexed_iterate[1G[39m[90m   [39m│   %3    = (Base.getfield)(%2, 1)[36m::Tuple{XRTArray{UInt32,(6422528,),1},XRTArray{UInt32,(128,),1}}[39m
[90m[-16G│││╻                                        indexed_iterate[1G[39m[90m   [39m│   %4    = (Base.getfield)(%3, 1)[36m::XRTArray{UInt32,(6422528,),1}[39m
[90m[-16G││╻                                        indexed_iterate[1G[39m[90m   [39m│   %5    = (Base.getfie




2019-02-11 04:53:22.898273: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:349] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.


ErrorException: Unrecognized expr

In [None]:
ret = run(compilation_handle, XRTRemoteStruct(sess, tpu_model))
ret = convert(typeof(ret).parameters[1], ret);

In [None]:
typeof(ret)

In [None]:
Base.IRShow.show_ir(stdout, XLA.code_typed_xla(Tuple{typeof(debug2), typeof(Val(batch_size)), typeof(tpu_model), typeof(XRTArray(num_batches)), typeof(XRTArray(η))})[1]; verbose_linetable=true)


In [11]:
XLA.explain_suboptimal_inference(Tuple{typeof(debug2), typeof(Val(batch_size)), typeof(tpu_model), typeof(XRTArray(num_batches)), typeof(XRTArray(η))})