# [`XLA.jl`](https://github.com/JuliaTPU/XLA.jl): Shakespeare LSTM

In this notebook, we will showcase using `XLA.jl` with LSTMs to learn the structure of Shakespearean English.

In [1]:
## Warn when the running Julia build is not the single commit this notebook is known to
## work with, then activate and instantiate the project that lives next to this file so
## that the package versions known to work with TPUs are the ones that get loaded.
let known_good_commit = "0424938442a907a35089254d2bd14b731c2008ec"
    Base.GIT_VERSION_INFO.commit == known_good_commit ||
        @warn("Only the very latest Julia version on the `kf/tpu3` branch is supported!")
end

import Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`

In [2]:
using TensorFlow, XLA, Flux, Unrolled, Zygote, Printf, Statistics
include("tpu_optimizers.jl")

# Fetch the corpus on first use; later runs reuse the local copy.
let corpus_path = "shakespeare_input.txt",
    corpus_url = "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt"
    isfile(corpus_path) || download(corpus_url, corpus_path)
end

# Load the whole corpus as one string and split it into individual characters
text = collect(read("shakespeare_input.txt", String))

# The alphabet is every distinct character of the corpus plus a special "stop"
# character '_'; a symbol's position in this sorted list is its one-hot index.
alphabet = sort(vcat(unique(text), '_'))
stop = UInt32(Flux.onehotidx('_', alphabet))

# Re-encode the text as UInt32 one-hot indices into `alphabet`
text = map(text) do ch
    UInt32(Flux.onehotidx(ch, alphabet))
end

println(" => Loaded $(length(text))-character dataset and encoded into $(length(alphabet))-symbol embedding")

┌ Info: Recompiling stale cache file /home/sabae/.julia/compiled/v1.1/TensorFlow/IhIhf.ji for TensorFlow [1d978283-2c37-5f34-9a8e-e9c0ece82495]
└ @ Base loading.jl:1184
└ @ TensorFlow ~/.julia/packages/TensorFlow/eu9qM/src/TensorFlow.jl:3
┌ Info: Recompiling stale cache file /home/sabae/.julia/compiled/v1.1/XLA/bZBiw.ji for XLA [1ae4bca4-de81-11e8-0eca-6d3e4e7c4181]
└ @ Base loading.jl:1184
│ - If you have XLA checked out for development and have
│   added Random as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with XLA
┌ Info: Recompiling stale cache file /home/sabae/.julia/compiled/v1.1/Unrolled/BnVLg.ji for Unrolled [9602ed7d-8fef-5bc8-8597-8f21381861e8]
└ @ Base loading.jl:1184


 => Loaded 4573338-character dataset and encoded into 68-symbol embedding


In [3]:
# Each training step consumes 64 sequences of 50 characters, so `text` is laid out
# as a tensor of shape (seq_len, batch_size, num_batches).  The text is first
# extended with `stop` characters up to the next whole batch, plus one extra
# element because the targets are the inputs shifted by one position.
batch_size = 64
seq_len = 50
num_batches = cld(length(text) - 1, seq_len*batch_size)
padded_length = seq_len*batch_size*num_batches + 1
text = vcat(text, fill(stop, padded_length - length(text)))

# `Ys` is `Xs` moved one character ahead: the entry of `Ys` at any position is the
# character that follows the entry of `Xs` at that same position.
Xs = reshape(text[1:end-1], seq_len, batch_size, num_batches)
Ys = reshape(text[2:end], seq_len, batch_size, num_batches)

println(" => Segmented into $(num_batches) batches of size $(batch_size) with $(seq_len)-element sequences")

 => Segmented into 1430 batches of size 64 with 50-element sequences


In [4]:
# Extract the LSTM state vectors from a model.
# For a bare LSTM cell this is whatever `Flux.hidden` returns for that cell.
get_model_state(m::Flux.LSTMCell) = Flux.hidden(m)
# A `Recur` wrapper delegates to the cell it wraps.
get_model_state(m::Flux.Recur) = get_model_state(m.cell)
# Fallback for a whole model: returns a 2-tuple holding the state of the first two
# entries of `model.layers`.  Any further layers are not consulted.
function get_model_state(model)
    return tuple(
        get_model_state(model.layers[1]),
        get_model_state(model.layers[2]),
    )
end

# Update LSTM state vectors within a model.  Nothing is modified in place: every
# method constructs and returns a new object carrying the requested state.
# For a bare cell, rebuild it from its existing `Wi`, `Wh` and `b`, splatting `state`
# in as the remaining constructor arguments.
set_model_state(m::Flux.LSTMCell, state) = Flux.LSTMCell(m.Wi, m.Wh, m.b, state...)
# For a `Recur` wrapper, wrap the updated cell in a fresh `Recur`.
set_model_state(m::Flux.Recur, state) = Flux.Recur(set_model_state(m.cell, state))
# Fallback for a whole model: rebuild `typeof(model)` from its first two layers with
# `state[1]` / `state[2]` applied, passing the third layer through unchanged.
function set_model_state(model, state)
    return typeof(model)(
        set_model_state(model.layers[1], state[1]),
        set_model_state(model.layers[2], state[2]),
        model.layers[3],
    )
end

# Give `model` recurrent state shaped to match the input `x`.  The model passed in is
# left alone; the result is a new model produced by `set_model_state`.
function initialize_state(model, x)
    # Swap every component of a layer's state for zeros shaped like that component's
    # first column (this disregards any existing batch dimension).
    blank(state) = Zygote.map(part -> zero(part[:, 1]), state)
    state1, state2 = map(blank, get_model_state(model))

    # Feed `x` through both LSTM cells once.  Each call hands back the cell's new state
    # together with its output; this pass is what broadcasts the zeroed states up to
    # the dimensions implied by `x`.
    state1, out1 = model.layers[1].cell(state1, x)
    state2, _ = model.layers[2].cell(state2, out1)

    return set_model_state(model, (state1, state2))
end


# Two stacked 128-unit LSTM layers followed by a dense layer mapping back onto the alphabet
model = let n_symbols = length(alphabet), n_hidden = 128
    Chain(
        LSTM(n_symbols, n_hidden),
        LSTM(n_hidden, n_hidden),
        Dense(n_hidden, n_symbols)
    )
end

# Size the recurrent state for a full batch, then convert the model with `map_to_tpu`
model = initialize_state(model, zeros(Float32, length(alphabet), batch_size))
tpu_model = map_to_tpu(model);

In [5]:
# Advance the model by a single time step.
#   model: object whose `layers` unpacks into (lstm1, lstm2, dense)
#   state: 2-tuple of recurrent state, one entry per LSTM layer
#   x:     input for this time step
# Returns `(y_hat, (h1, h2))`: the dense layer's output and the updated state tuple.
function single_lstm_run(model, state, x)
    # Unpack model into separate layers
    lstm1, lstm2, dense = model.layers

    # Unpack state for our LSTM layers
    h1, h2 = state
    
    # Push `x` through, updating our state.  Each LSTM layer is called with its
    # state and the current input, and hands back (new state, output).
    h1, x = lstm1(h1, x)
    h2, x = lstm2(h2, x)
    y_hat = dense(x)

    # Return y_hat and our state
    return y_hat, (h1, h2)
end


# Helper function to convert a batch of text at a particular time point `t` into a OneHotMatrix
# (constructed with `alphabet_size`).  For `XRTArray` input, that OneHotMatrix is then densified
# into an XRTArray{Float32}, which we can apply a cross-entropy loss upon.
function densify(::Val{alphabet_size}, x::XRTArray, t) where {alphabet_size}
    # The time index is itself wrapped in an XRTArray before being used to index `x`
    return convert(XRTArray{Float32}, Flux.OneHotMatrix(alphabet_size, x[XRTArray(t), :]))
end
# Fallback for any other `x`: index row `t` directly and return the OneHotMatrix as-is,
# without converting it to a dense array.
function densify(::Val{alphabet_size}, x, t) where {alphabet_size}
    return Flux.OneHotMatrix(alphabet_size, x[t, :])
end

# This function runs the full lstm.  It's not very easy to return a concatenated `y`
# because all XLA.jl code is immutable, so we can't do e.g. y[i] = ...
# Luckily, for training, we don't have to, we just accumulate into `loss`.
#
#   unused::Val{alphabet_size}: only the type parameter is used (forwarded to `densify`)
#   model:                      model whose recurrent state is read at the start and
#                               replaced (in a new model) at the end
#   x_batch, y_batch:           inputs / targets; the first dimension is iterated as time
# Returns `(loss, model)`, where `loss` is the SUM of the per-time-step losses.
@unroll function full_lstm(unused::Val{alphabet_size}, model, x_batch::XRTArray, y_batch::XRTArray) where {alphabet_size}
    # Get current LSTM state
    state = get_model_state(model)
    
    # Accumulate loss into here
    loss = XRTArray(0f0)

    # Iterate over time
    @unroll for time_idx = 1:size(x_batch, 1)
        # Create dense representations of the one-hot encoded text at this point in time, across an entire batch
        x = densify(Val(alphabet_size), x_batch, time_idx)
        
        # Push x through our model to get y_hat (and new recurrent state values)
        y_hat, state = single_lstm_run(model, state, x)
        
        # Accumulate loss (softmax is applied to `y_hat` here, before `crossentropy`)
        loss += crossentropy(softmax(y_hat), densify(Val(alphabet_size), y_batch, time_idx))
    end
    
    # Store the final recurrent state in a new model, which is what gets returned
    model = set_model_state(model, state)
    
    # Return loss and updated model
    return loss, model
end

full_lstm_unrolled_expansion_ (generic function with 1 method)

In [6]:
# Train `model` for `num_epochs` passes over the batches stored along the third
# dimension of `Xs` / `Ys`, visiting them in a freshly shuffled order every epoch.
#   ::Val{alphabet_size}: alphabet size, forwarded to `full_lstm`
#   ::Val{num_epochs}:    number of epochs to run
#   model:                model to train
#   Xs, Ys:               inputs / targets, one batch per index of dimension 3
#   η:                    forwarded to `TPU_ADAM` (presumably the learning rate — confirm)
# After each epoch the buffer of per-batch losses is sent out through an outfeed.
# Returns the trained model.
function train_lstm(::Val{alphabet_size}, ::Val{num_epochs}, model, Xs, Ys, η) where {alphabet_size, num_epochs}
    # Create optimizer
    opt = TPU_ADAM(model, η, (XRTArray(0.9f0), XRTArray(0.999f0)))
    
    # We will report loss once every epoch, store it here in the meantime
    # (one Float32 slot per batch):
    loss_buffer = zero(XRTArray{Float32, (size(Xs, 3),), 1})

    # Iterate over epochs
    epoch_idx = XRTArray(1)
    while epoch_idx <= XRTArray(num_epochs)
        # Iterate over batches within a single epoch
        batch_idx = XRTArray(1)
        
        # Draw a new visiting order for the batches of this epoch
        batch_permutation = XLA.shuffle(XRTArray(1:size(Xs, 3)))
        while batch_idx <= XRTArray(size(Xs, 3))
            # Calculate forward pass of model, and compile backward pass stored in `back()`.
            # Use `let` block to work around Julia inference limitations
            (loss, model), back = let model=model,
                             x_batch=Xs[:, :, batch_permutation[batch_idx]],
                             y_batch=Ys[:, :, batch_permutation[batch_idx]]
                Zygote._forward(
                    Zygote.Context{Nothing}(nothing),
                    model -> full_lstm(Val(alphabet_size), model, x_batch, y_batch),
                    model,
                )
            end
            
            # Invoke `back()` with sensitivity `1f0` on the `loss`
            Δ_model = Zygote.tailmemaybe(back(1f0))[1]

            # Cross-replica sum our model updates so that they are combined across all tpus.
            # NOTE(review): this is a sum; no division by the number of replicas is
            # visible here, so the result is not an average — confirm this is intended.
            Δ_model = XLA.unflatten_tuple(Δ_model,
               XLA.HloCrossReplicaSum{typeof(+)}((), 0, "")(
                   +,
                   XLA.flatten_tuple(Δ_model)...
               )
            )

            # Update parameters via our optimizer
            opt, model = update!(opt, model, Δ_model)

            # Store loss over an epoch into `loss_buffer`.  The slot is `batch_idx`, i.e. the
            # position in this epoch's visiting order, not the shuffled batch number.
            loss_buffer = Base.setindex(loss_buffer, loss, batch_idx)

            # Increment batch_idx
            batch_idx += XRTArray(1)
        end
        
        # Once per epoch, output our training loss for the entire epoch
        XLA.HloOutfeed()((loss_buffer,), XLA.HloAfterAll()())
        
        # Increment epoch_idx
        epoch_idx += XRTArray(1)
    end
    
    # Return the trained model (note that this gets returned from each of our TPUs, but we
    # only pay attention to the model returned from the first node, since they are all
    # identical thanks to the cross-replica sum above in the training loop)
    return model
end

train_lstm (generic function with 1 method)

In [7]:
# Address of the TPU worker to connect to
tpu_ip = "10.240.25.3"
println("Connecting to TPU on $(tpu_ip)")

# NOTE: If you are connecting to an actual TPU, use `TPUSession`.  If you are
# connecting to an `xrt_server`, use `Session()`.
sess = TPUSession("$(tpu_ip):8470")

# Training hyperparameters
num_epochs = 10
η = 0.001f0

# Compile the model (timed with wall-clock `time()`).
# NOTE(review): the last argument here is `XRTArray(0.01f0)` while `η` above is
# 0.001f0; presumably only the argument's type matters for compilation and the real
# value is the one passed to `run_on_devices` below — confirm.
t_start = time()
all_tpus = all_tpu_devices(sess)
compilation_handle = @tpu_compile devices=all_tpus train_lstm(Val(length(alphabet)), Val(num_epochs), tpu_model, XRTArray(Xs), XRTArray(Ys), XRTArray(0.01f0));
t_end = time()

println(@sprintf("=> Compiled training loop in %.1f seconds", t_end - t_start))

# Launch the compiled training loop; the result is kept in `loop_task` and is
# `fetch`ed later in the notebook, so this timing covers the launch only.
t_start = time()
loop_task = XLA.run_on_devices(compilation_handle, tpu_model, Xs, Ys, η)
t_end = time()

println(@sprintf("=> Launched training loop on %d TPUs in %.1f seconds", length(all_tpus), t_end - t_start))

Connecting to TPU on 10.240.25.3


2019-02-25 15:56:25.790070: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:349] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.
└ @ Main /home/sabae/.julia/dev/XLA/src/compiler_interface.jl:117


AssertionError: AssertionError: is_header && (bb_to_outline == bbs_to_outline[end] && idx == block.stmts[end])

In [8]:
loop_task

UndefVarError: UndefVarError: loop_task not defined

In [None]:
# Make one outfeed op per TPU device
outfeed_ops = [XLA.make_outfeed_on(sess,
    # On this device
    tpu_device,
    
    # Which will output this type (a 1-tuple holding a Float32 vector with one entry per batch)
    Tuple{XRTArray{Float32, (num_batches,), 1},}
) for tpu_device in all_tpu_devices(sess)]

# All losses received so far, appended epoch after epoch
losses = Float64[]
for epoch_idx in 1:num_epochs
    # Run the outfeed ops and combine their results with `mean`.
    # NOTE(review): the original comment here read "Get loss from TPU 1", but every
    # outfeed op is run, not just the first — confirm which behavior is intended.
    epoch_loss = mean(run(sess, outfeed_ops))
    append!(losses, epoch_loss)

    # Print it out as we go, showing the average loss to (hopefully) watch it decrease
    println("[$epoch_idx] epoch avg. loss: $(mean(epoch_loss))")
end

In [None]:
using Plots
# x-axis: position of each recorded loss, rescaled from batch count to epochs
l_idxs = collect(eachindex(losses)) ./ num_batches
Plots.plot(l_idxs, losses, xlabel="Epochs", ylabel="Loss", legend=nothing)

In [None]:
# Wait for the training task to finish and take its result
ret = fetch(loop_task)
# Convert the first returned element to the first type parameter of its own type.
# NOTE(review): presumably this unwraps a remote handle into the model it refers to —
# confirm against what `XLA.run_on_devices` returns.
trained_model = convert(typeof(ret[1]).parameters[1], ret[1]);

# Convert all XRTArray values to just normal arrays:
trained_model = map_to_cpu(trained_model)

# Resize the internal state vectors to deal with a single batch at a time
# (currently disabled)
#trained_model = initialize_state(trained_model, Flux.onehot('a', alphabet))

In [None]:
z = model.layers[3]
#z.Wi .- convert(Array, trained_model.layers[2].Wi)

In [None]:
[alphabet[x] for x in Xs[:, 2, 1]]

In [None]:
trained_model

In [None]:
# Greedy text generation: start from a random character and repeatedly feed the most
# likely next character back into the model, printing 200 characters after the seed.
Flux.reset!(trained_model)
x = rand(alphabet)
print(x)
for idx in 1:200
    y_hat = softmax(trained_model(Flux.onehot(x, alphabet)))
    # The model's final layer has one output per alphabet symbol, so the symbols run
    # along the FIRST dimension of `y_hat`.  Take the argmax over the first column;
    # this works whether `y_hat` is a vector or has an extra batch dimension.
    # (Indexing `[1, :]` would instead look only at the first symbol's entry.)
    x = alphabet[argmax(y_hat[:, 1])]
    print(x)
end

In [32]:
using StatsBase

# Generate `len` characters of text from model `m`.
#   m:        model; it is passed through `cpu` and reset before use
#   alphabet: collection of symbols; the seed character is drawn from it uniformly
#   len:      number of characters to produce
#   temp:     sampling temperature.  The model output is divided by `temp` before the
#             softmax, so values below 1 sharpen the distribution and values above 1
#             flatten it; the default of 1 leaves it unchanged.
# Returns the generated text as a `String`.
#
# NOTE(review): the sample output recorded in this notebook reads as English once every
# character is moved back one position in `alphabet`, which suggests an off-by-one
# between the one-hot indexing used during TPU training and `Flux.onehot` as used
# here — confirm before relying on the generated text.
function sample(m, alphabet, len; temp = 1)
    m = cpu(m)
    Flux.reset!(m)
    buf = IOBuffer()
    c = rand(alphabet)
    for i = 1:len
        write(buf, c)
        logits = m(Flux.onehot(c, alphabet))
        # Draw the next character with probability given by the tempered softmax
        c = wsample(alphabet, softmax(logits ./ temp))
    end
    return String(take!(buf))
end

sample(trained_model, alphabet, 200)

".Csvuvt-!J!ibwf!mjlf!mjlf!ijt dpmpvsu-!xfmm!nfu;!gps!epjoh!Dsfttje!xjuipvu!evuz- vomfbnofe!uif!dpvsbhf!pg!njof?!pof!dbmmt!cz!xiptf!dpvousz!ibvout Nblftu!obnf!uif!ljohepn!mjlf!ijt!mpwfmz!upohvf3  JSPT!"

In [66]:
# Recompute the summed loss of the first batch on the CPU with the trained model.
alphabet_size = length(alphabet)
# Re-shape the recurrent state for a 64-column input (64 is hard-coded here; it matches
# the `batch_size` set earlier in this notebook).  The input values are random.
trained_model = initialize_state(trained_model, randn(Float32, alphabet_size, 64))
x_batch = Xs[:, :, 1]
y_batch = Ys[:, :, 1]

# Accumulate loss into here
loss = 0.0

Flux.reset!(trained_model)

# Iterate over time
for time_idx = 1:size(x_batch, 1)
    # Create dense representations of the one-hot encoded text at this point in time, across an entire batch
    x = Float32.(densify(Val(alphabet_size), Int64.(x_batch), time_idx))
    y = Float32.(densify(Val(alphabet_size), Int64.(y_batch), time_idx))

    # Push x through our model to get y_hat (and new recurrent state values).
    # The model is called directly, so it carries its own state between steps.
    #y_hat, state = single_lstm_run(trained_model, state, x)
    y_hat = trained_model(x)

    # Accumulate loss (summed over time steps, computed from the raw model output)
    loss += Flux.logitcrossentropy(y_hat, y)
end

loss

592.3437714576721

In [14]:
sort(alphabet)

68-element Array{Char,1}:
 '\n'
 ' ' 
 '!' 
 '$' 
 '&' 
 '\''
 ',' 
 '-' 
 '.' 
 '3' 
 ':' 
 ';' 
 '?' 
 ⋮   
 'o' 
 'p' 
 'q' 
 'r' 
 's' 
 't' 
 'u' 
 'v' 
 'w' 
 'x' 
 'y' 
 'z' 

In [15]:
# Write the IR returned by `XLA.code_typed_xla` for the `train_lstm` call signature
# to ir.txt, with verbose line-table output enabled.
open("ir.txt", "w") do io
    signature = Tuple{
        typeof(train_lstm),
        typeof(Val(length(alphabet))),
        typeof(Val(num_epochs)),
        typeof(tpu_model),
        typeof(XRTArray(Xs)),
        typeof(XRTArray(Ys)),
        typeof(XRTArray(0.01f0))
    }
    typed_ir = XLA.code_typed_xla(signature)[1]
    Base.IRShow.show_ir(io, typed_ir; verbose_linetable=true)
end