# [`XLA.jl`](https://github.com/JuliaTPU/XLA.jl): Shakespeare LSTM

In this notebook, we will showcase using `XLA.jl` with LSTMs to learn the structure of Shakespearean English.

In [1]:
import Pkg

# This notebook is only known to work on one specific Julia build; warn (but
# keep going) when running on any other commit.
Base.GIT_VERSION_INFO.commit == "f1dffc5c8b6b7f960b5e30835631b4caf4434b04" ||
    @warn("Only the very latest Julia version on the `kf/tpu3` branch is supported!")

# Activate the project that lives next to this notebook and install the
# package versions it pins.
Pkg.activate(@__DIR__)
Pkg.instantiate()

[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h

In [18]:
using TensorFlow, XLA, Flux, Unrolled, Zygote, Printf

# First, let's download our dataset (skipped when a local copy already exists)
if !isfile("shakespeare_input.txt")
    download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", "shakespeare_input.txt")
end

# Read text in as a giant string, convert to array of characters
text = collect(String(read("shakespeare_input.txt")))

# Generate alphabet, which we will use as an embedding (along with special "stop" character '_')
alphabet = [unique(text)..., '_']
stop = Flux.onehot('_', alphabet)

# Embed text through alphabet as onehot indices.  The file is read and embedded
# exactly once; `text` ends up holding one alphabet index per character.
text = map(ch -> Flux.onehotidx(ch, alphabet), text)

# We will process 64 sequences of length 50 at a time.  Pad `text` up to a
# whole number of (seq_len * batch_size)-sized batches, plus one extra element
# because inputs are taken as `[1:end-1]` and targets as `[2:end]`.
batch_size = 64
seq_len = 50
num_batches = cld(length(text), seq_len*batch_size)
# Fixed: this previously referenced `nbatches`, which is not defined at this
# point in the notebook (and is later bound to an unrelated value).
padded_length = seq_len*batch_size*num_batches + 1

# Pad with symbol 68 and store as UInt32 indices.
# NOTE(review): 68 is presumably the index of the '_' stop symbol (the alphabet
# has 68 entries and '_' is appended last) -- confirm if the dataset changes.
text_padded = UInt32.(rpad(text, padded_length, 68))

println(" => Loaded $(length(text))-character dataset and encoded into $(length(alphabet))-symbol embedding")

 => Loaded 4573338-character dataset and encoded into 68-symbol embedding


In [23]:
using Flux: chunk, batchseq
using Base.Iterators: partition

# Sizes used for segmentation: alphabet size, characters per sequence, and the
# number of parallel chunks the text is split into.
N = length(alphabet)
seqlen = 50
nbatches = 64

# Inputs (Xs) and targets (Ys) go through the same pipeline; the targets are
# simply the text shifted by one character, so each position predicts the
# character that follows it.
Xs, Ys = map((text[1:end-1], text[2:end])) do stream
    collect(partition(batchseq(chunk(stream, nbatches), stop), seqlen))
end;

println(" => Segmented $(length(Xs)) batches containing $(length(Xs[1]))-length sequences with a batch size of $(size(Xs[1][1], 2))")

 => Segmented 1430 batches containing 50-length sequences with a batch size of 64


In [52]:
# Decode the first 50 entries of the first batch back into alphabet indices
map(idx -> Flux.onecold(Xs[1][idx]), 1:50)

50-element Array{Array{Int64,1},1}:
 [1, 10, 20, 35, 20, 25, 9, 20, 9, 4  …  13, 20, 9, 20, 9, 29, 10, 6, 4, 9]    
 [2, 5, 17, 20, 22, 42, 10, 21, 20, 42  …  22, 10, 27, 5, 20, 2, 19, 4, 9, 4]  
 [3, 4, 17, 19, 41, 12, 41, 6, 32, 12  …  5, 19, 12, 23, 3, 4, 9, 2, 10, 6]    
 [4, 6, 21, 20, 23, 52, 9, 5, 9, 49  …  6, 6, 12, 6, 5, 23, 3, 3, 5, 16]       
 [5, 10, 6, 25, 5, 10, 3, 15, 10, 23  …  14, 17, 1, 25, 44, 6, 37, 27, 6, 2]   
 [6, 15, 32, 24, 9, 29, 27, 6, 4, 20  …  15, 9, 38, 2, 12, 16, 4, 12, 16, 5]   
 [7, 5, 2, 6, 3, 9, 12, 25, 44, 5  …  3, 5, 28, 4, 28, 9, 6, 12, 2, 23]        
 [2, 6, 18, 23, 24, 4, 12, 9, 6, 6  …  6, 5, 47, 4, 29, 6, 5, 30, 5, 6]        
 [5, 4, 5, 20, 12, 4, 50, 6, 45, 29  …  5, 21, 39, 37, 20, 5, 15, 39, 23, 18]  
 [2, 17, 15, 32, 28, 6, 51, 5, 6, 15  …  23, 6, 52, 19, 18, 23, 10, 34, 6, 23] 
 [8, 2, 3, 9, 10, 21, 40, 23, 16, 32  …  9, 4, 30, 6, 26, 2, 41, 6, 21, 2]     
 [9, 3, 21, 6, 19, 15, 34, 20, 2, 9  …  4, 17, 11, 5, 24, 4, 22, 28, 15, 29]   
 [10

In [7]:
# This function runs the full forward pass of the two-layer LSTM over one batch
# and returns the summed loss together with the final recurrent states.
#
# Arguments:
#   a         - `Val{alphabet_size}`; only its type parameter is used, as the
#               size of the one-hot encoding
#   model     - tuple `(lstm1, lstm2, dense)`
#   batch_idx - index selecting the batch along the second dimension of `Xss`/`Yss`
#   h1, h2    - incoming states for `lstm1` and `lstm2`
#   Xss, Yss  - input and target index arrays, indexed as `[step, batch_idx, :]`
#
# Returns `(loss, (h1, h2))`.
@unroll function full_lstm(a::Val{alphabet_size}, model, batch_idx, h1, h2, Xss, Yss) where {alphabet_size}
    # Unpack model into its two LSTM layers and the final dense layer
    (lstm1, lstm2, dense) = model
    loss = XRTArray(0f0)
    # One iteration per entry along the first dimension of `Xss`
    @unroll for i = 1:size(Xss, 1)
        idx = XRTArray(i)
        # One-hot encode this step's input and target indices as Float32
        x = convert(XRTArray{Float32}, Flux.OneHotMatrix(alphabet_size, Xss[idx, batch_idx, :]))
        y = convert(XRTArray{Float32}, Flux.OneHotMatrix(alphabet_size, Yss[idx, batch_idx, :]))
        # Thread the states through both layers; `x` becomes each layer's output
        (h1, x) = lstm1(h1, x)
        (h2, x) = lstm2(h2, x)
        # Accumulate the per-step loss between the dense output and the target
        loss += logitcrossentropy(dense(x), y)
    end
    return loss, (h1, h2)
end

full_lstm_unrolled_expansion_ (generic function with 1 method)

In [8]:
# Load the optimizer helpers shared with the resnet example; this defines
# `unflatten_opt_state` (and presumably `update_params`, used by `train_lstm`
# below -- TODO confirm).
include("../resnet/ADAM_tpu.jl")

unflatten_opt_state (generic function with 1 method)

In [9]:
# Train the LSTM model for `nepochs` passes over `nbatches` batches and return
# the updated model.
#
# Arguments (the first five are `Val`s so the sizes are carried in the type):
#   Val{batch_size}, Val{seq_len}, Val{nbatches} - dimensions `text` is reshaped to
#   Val{alphabet_size}                           - size of the one-hot encoding
#   Val{nepochs}                                 - number of outer loop iterations
#   xrtic - model tuple `(lstm1, lstm2, dense)`
#   text  - flat array of symbol indices; its length must be
#           `seq_len * nbatches * batch_size + 1`, since inputs are
#           `text[1:end-1]` and targets are `text[2:end]`
#
# The loss of every batch is sent out through an outfeed.
#
# Fixed: the `where` clause used to declare `mb_size` while the signature
# dispatched on `Val{batch_size}`, so `batch_size` silently referred to a global
# instead of being a static parameter of the method.
function train_lstm(
    ::Val{batch_size}, ::Val{seq_len}, ::Val{alphabet_size}, ::Val{nepochs}, ::Val{nbatches},
    xrtic, text,
) where {batch_size, seq_len, alphabet_size, nepochs, nbatches}
    (lstm1, lstm2, dense) = xrtic
    h1 = Flux.hidden(lstm1)
    h2 = Flux.hidden(lstm2)

    # Optimizer hyperparameters handed to `update_params` (presumably the ADAM
    # learning rate and decay rates -- TODO confirm against ADAM_tpu.jl)
    η = XRTArray(0.01f0)
    β = (XRTArray(0.9f0), XRTArray(0.999f0))

    # One `(zero, zero, β)` optimizer-state entry per flattened model parameter,
    # restructured to mirror the shape of the model
    flat_model = XLA.flatten_tuple(xrtic)
    opt_state = Zygote.map(x -> (zero(x), zero(x), β), flat_model)
    opt_state = unflatten_opt_state(xrtic, opt_state)

    # Inputs and targets are the same text shifted by one symbol
    Xss = reshape(text[1:end-1], (seq_len, nbatches, batch_size))
    Yss = reshape(text[2:end],   (seq_len, nbatches, batch_size))

    i = XRTArray(0)

    # Get the batch-adjusted proper shape for h1 and h2 by running the computation forwards once
    (_, (a, b)) = full_lstm(Val(alphabet_size), xrtic, i, h1, h2, Xss, Yss)
    h1 = map(z->zeros(typeof(z)), a)
    h2 = map(z->zeros(typeof(z)), b)

    j = XRTArray(0)
    while j < XRTArray(nepochs)
        # NOTE(review): the batch index counts from 0 here, while `full_lstm`
        # counts its step index from 1 -- confirm the indexing base of XRTArray
        # indices.
        i = XRTArray(0)
        while i < XRTArray(nbatches)
            # Rebind everything the closure captures in a `let`, then run the
            # forward pass while recording the pullback `back`
            ((loss, (h1, h2)), back) = let h1=h1, h2=h2, Xss=Xss, Yss=Yss, xrtic=xrtic, i=i
                loss, back = Zygote._forward(
                    Zygote.Context{Nothing}(nothing),
                    xrtic -> full_lstm(Val(alphabet_size), xrtic, i, h1, h2, Xss, Yss),
                    xrtic,
                )
            end
            # Seed the pullback with 1 for the loss and `nothing` for the states
            updates = Zygote.tailmemaybe(back((1f0, nothing)))[1]

            # Cross-replica sum our model updates to mix-n-match across all tpus
            updates = XLA.unflatten_tuple(updates,
                XLA.HloCrossReplicaSum{typeof(+)}((), 0, "")(
                    +,
                    XLA.flatten_tuple(updates)...
                )
            )

            # Update parameters via our optimizer
            (xrtic, opt_state) = update_params(xrtic, updates, opt_state, η, β)

            # Outfeed the loss
            loss = reshape(loss, (1,))
            XLA.HloOutfeed()((loss,), XLA.HloAfterAll()())

            # Count up
            i += XRTArray(1)
        end
        j += XRTArray(1)
    end
    return xrtic
end

train_lstm (generic function with 1 method)

In [41]:


#nepochs = 6
#compld = @tpu_compile train_lstm(Val(mb_size), Val(seq_len), Val(N), Val(nepochs), Val(nbatches), map_to_tpu(model), XRTArray(text_padded))

4573338-element Array{Int64,1}:
  1
  2
  3
  4
  5
  6
  7
  2
  5
  2
  8
  9
 10
  ⋮
  5
  6
 43
 22
  5
  6
  5
  3
 22
  9
 27
 12

In [50]:
# Reshape the padded text (minus its final element) into a 3-d array and show
# the first-dimension slice at position (2, 1) of the remaining dimensions
Xss = reshape(text_padded[1:(end - 1)], seqlen, nbatches, mb_size)
Xss[:, 2, 1]

50-element Array{UInt32,1}:
 0x00000006
 0x00000019
 0x00000009
 0x00000006
 0x00000004
 0x00000011
 0x00000009
 0x00000014
 0x0000001a
 0x0000001b
 0x0000000c
 0x0000000c
 0x0000001c
          ⋮
 0x00000007
 0x00000002
 0x00000005
 0x00000002
 0x00000008
 0x00000009
 0x0000000a
 0x0000000b
 0x0000000c
 0x0000001f
 0x0000000f
 0x00000016