# Flux JL

## Flux Model Zoo Examples

# 4. Char RNN

**FluxML contributors**

**Source:** https://github.com/FluxML/model-zoo/blob/master/text/char-rnn/char-rnn.jl

In this notebook we build a "hello world" example of a sequential neural network: a character-level recurrent network (Char RNN).

In [1]:
using Flux
using Flux: onehot, onecold, chunk, batchseq, throttle, logitcrossentropy
using StatsBase: wsample
using Statistics, Random
using Base.Iterators: partition
using Parameters: @with_kw
using ProgressMeter: @showprogress
using Logging: with_logger
using CUDA
import BSON

#### Utility Functions

In [2]:
"""
    num_params(model)

Return the total number of scalar entries across all arrays in `Flux.params(model)`.

A model without any parameter arrays yields `0` rather than raising on the empty
reduction.
"""
num_params(model) = sum(length, Flux.params(model); init=0)
"""
    round4(x)

Round `x` to four digits after the decimal point.
"""
function round4(x)
    return round(x; digits=4)
end
;

### Char RNN

We'll start with the model

In [3]:
"""
    CharRNN(N; hidden_dimsize=128)

Build the character-level model as a `Chain` of two stacked `LSTM` layers followed by
a `Dense` layer.

# Arguments
- `N`: input and output width (the training code in this file passes the alphabet size).

# Keywords
- `hidden_dimsize`: width of both LSTM layers.
"""
function CharRNN(N; hidden_dimsize=128)
    H = hidden_dimsize
    # Layers are built in the same order they appear in the chain.
    encoder = LSTM(N, H)
    recurrent = LSTM(H, H)
    decoder = Dense(H, N)
    return Chain(encoder, recurrent, decoder)
end

CharRNN (generic function with 1 method)

### Dataloader

We'll use Shakespeare text data

In [4]:
"""
    get_data(nbatches, seqlen)

Load the Shakespeare corpus (downloading it into the working directory on first use)
and turn it into training sequences.

Every character is one-hot encoded against the alphabet. The encoded text is split into
`nbatches` chunks with `chunk`, the chunks are combined with `batchseq` (padded with the
one-hot stop character), and the result is cut into windows of `seqlen` steps. Targets
are produced the same way from the text shifted by one character.

# Returns
A tuple `(dataloader, evalseqs, alphabet, nsamples)`:
- `dataloader`: `zip(Xs, Ys)` over the input and target windows.
- `evalseqs`: a single `(Xs[i], Ys[i])` pair used for evaluation; `i` is 5, or the last
  available window when fewer than 5 exist. It is drawn from the training windows.
- `alphabet`: unique characters of the text, plus the stop character `'_'`.
- `nsamples`: number of characters in the text.

# Throws
- `ArgumentError` if no window could be produced from the text.
"""
function get_data(nbatches, seqlen)
    stopchar = '_'
    datafile = "shakespeare_input.txt"
    url = "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt"

    # Get the data if not downloaded already
    isfile(datafile) || download(url, datafile)

    text = collect(String(read(datafile)))

    # Construct an alphabet of the unique characters. The stop character is appended
    # only when the text does not already contain it, so the alphabet never holds a
    # duplicate (which would create an unused output class).
    alphabet = unique(text)
    stopchar in alphabet || push!(alphabet, stopchar)

    text_data = map(ch -> onehot(ch, alphabet), text)
    stopseq = onehot(stopchar, alphabet)

    # Partition the data as sequence of batches, which are then collected as array of batches
    # NOTE(review): inputs and shifted targets are chunked independently; this presumably
    # relies on both calls yielding the same chunk size - confirm against `Flux.chunk`.
    Xs = collect(partition(batchseq(chunk(text_data, nbatches), stopseq), seqlen))
    Ys = collect(partition(batchseq(chunk(text_data[2:end], nbatches), stopseq), seqlen))

    nsamples = length(text)

    # Use the 5th window for evaluation, falling back to the last available window on
    # short inputs instead of failing with a BoundsError.
    evalidx = min(5, length(Xs), length(Ys))
    evalidx >= 1 || throw(ArgumentError(
        "no sequences could be built for nbatches=$nbatches, seqlen=$seqlen"))
    evalseqs = Xs[evalidx], Ys[evalidx]
    dataloader = zip(Xs, Ys)

    return dataloader, evalseqs, alphabet, nsamples
end

get_data (generic function with 1 method)

### Loss Function

In [5]:
"""
    loss(ŷs, ys)

Total loss of a sequence: `logitcrossentropy` is applied element-wise to the paired
entries of `ŷs` (model outputs) and `ys` (targets) and the results are summed.
"""
function loss(ŷs, ys)
    stepwise = logitcrossentropy.(ŷs, ys)
    return sum(stepwise)
end

loss (generic function with 1 method)

In [16]:
"""
    evaluate(evalseqs, model, device)

Compute loss and accuracy of `model` on one evaluation sequence.

# Arguments
- `evalseqs`: pair `(xs, ys)` of inputs and targets; each is iterated step by step.
- `model`: the recurrent model, applied to every element of `xs` in order.
- `device`: function that moves data to the compute device (e.g. `cpu` or `gpu`).

# Returns
A named tuple `(loss, accuracy)`, both rounded to four digits; `accuracy` is a
percentage of positions where `onecold` of the prediction matches the target.

The model's hidden state is reset before the forward pass, so the metrics do not depend
on whatever sequence the model processed last, and again afterwards, so the evaluation
sequence does not leak into subsequent training steps.
"""
function evaluate(evalseqs, model, device)
    xtest, ytest = device(evalseqs[1]), device(evalseqs[2])

    # Evaluate from a clean recurrent state.
    Flux.reset!(model)
    ŷtest = model.(xtest)

    l = round4(loss(ŷtest, ytest))

    ncorrect = 0
    ntotal = 0
    for (ŷ, y) in zip(ŷtest, ytest)
        # Compare predicted and true class indices on the CPU.
        ncorrect += count(onecold(cpu(ŷ)) .== onecold(cpu(y)))
        # The last dimension of the output is counted as the number of predictions.
        ntotal += size(ŷ)[end]
    end
    acc = round4(100 * (ncorrect / ntotal))

    # Discard the state produced by the evaluation sequence.
    Flux.reset!(model)

    return (loss=l, accuracy=acc)
end

evaluate (generic function with 1 method)

### Training Loop

In [17]:
"""
    train(; kws...)

Train the Char RNN on the Shakespeare data and return `(model, alphabet)`.

Keyword arguments are forwarded to `Args` (seed, device choice, learning rate, number
of epochs, sequence/batch sizes, reporting and checkpoint intervals, save path).

Each epoch iterates over the full data loader, updating the parameters with ADAM.
Evaluation metrics are printed before training and then every `infotime` epochs;
the model is saved to `joinpath(savepath, "model.bson")` every `checktime` epochs.
Setting `infotime` or `checktime` to `0` disables the respective periodic action.

The returned model stays on the device it was trained on.
"""
function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## Data
    dataloader, evalseqs, alphabet, nsamples = get_data(args.nbatches, args.seqlen)
    nchars = length(alphabet)
    @info "Shakespeare dataset: $nsamples samples and $nchars unique chars in alphabet"

    ## Model
    model = CharRNN(nchars) |> device
    @info "Char RNN LSTM model: $(num_params(model)) trainable params"

    ## Optimiser
    θ = Flux.params(model)
    optimiser = ADAM(args.η)

    ## Epoch logging
    function report(epoch)
        # Named `metrics` rather than `eval` to avoid shadowing `Base.eval`.
        metrics = evaluate(evalseqs, model, device)
        println("Epoch: $epoch   Eval: $(metrics)")
    end

    ## Training Loop
    @info "Training started ..."
    report(0)
    for epoch in 1:args.epochs
        # Every epoch restarts at the beginning of the text, so the hidden state left
        # over from the previous epoch (or from evaluation) must not be carried over.
        Flux.reset!(model)

        @showprogress for (xs, ys) in dataloader
            xs, ys = xs |> device, ys |> device
            ∂loss = Flux.gradient(θ) do
                ŷs = model.(xs)
                loss(ŷs, ys)
            end

            Flux.Optimise.update!(optimiser, θ, ∂loss)
        end

        ## Printing and logging
        # Guard against infotime == 0 (would raise DivideError), mirroring checktime.
        args.infotime > 0 && epoch % args.infotime == 0 && report(epoch)
        if args.checktime > 0 && epoch % args.checktime == 0
            !ispath(args.savepath) && mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            # Move a copy of the model to the CPU before serialization; BSON.@save
            # stores the variables under their names (`model`, `epoch`).
            let model = cpu(model)
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end

    return model, alphabet
end

train (generic function with 1 method)

### Programme Parameters

In [8]:
# Run configuration for `train`. Every field has a default and can be overridden by
# keyword when constructing, e.g. `Args(epochs = 10)`.
@with_kw mutable struct Args
    # Random seed; set to a value > 0 for reproducibility
    seed::Int = 0
    # If true, use CUDA (when available)
    use_cuda::Bool = false
    # Learning rate
    η::Float64 = 1e-2
    # Number of epochs
    epochs::Int = 5
    # Length of the batched sequences
    seqlen::Int = 50
    # Number of batches the text is divided into
    nbatches::Int = 50
    # Report every `infotime` epochs
    infotime::Int = 1
    # Save the model every `checktime` epochs. Set to 0 for no checkpoints
    checktime::Int = 1
    # Results path
    savepath::String = "runs/char_rnn"
end

Args

In [18]:
# Train with the default `Args`; `train` returns the model and the alphabet it was built on.
lstm, alphabet = train()

┌ Info: Training on CPU
└ @ Main In[17]:11
┌ Info: Shakespeare dataset: 4573338 samples and 68 unique chars in alphabet
└ @ Main In[17]:17
┌ Info: Char RNN LSTM model: 241732 trainable params
└ @ Main In[17]:21
┌ Info: Training started ...
└ @ Main In[17]:34


Epoch: 0   Eval: (loss = 211.0863f0, accuracy = 1.0)


Progress: 100%|█████████████████████████████████████████| Time: 0:13:19


Epoch: 1   Eval: (loss = 114.1595f0, accuracy = 35.84)


┌ Info: Model saved in "runs/char_rnn/model.bson"
└ @ Main In[17]:55
Progress: 100%|█████████████████████████████████████████| Time: 0:13:06


Epoch: 2   Eval: (loss = 110.1661f0, accuracy = 38.28)


┌ Info: Model saved in "runs/char_rnn/model.bson"
└ @ Main In[17]:55
Progress: 100%|█████████████████████████████████████████| Time: 0:13:24


Epoch: 3   Eval: (loss = 109.7231f0, accuracy = 37.48)


┌ Info: Model saved in "runs/char_rnn/model.bson"
└ @ Main In[17]:55


(Chain(Recur(LSTMCell(68, 128)), Recur(LSTMCell(128, 128)), Dense(128, 68)), ['F', 'i', 'r', 's', 't', ' ', 'C', 'z', 'e', 'n'  …  'K', 'Q', '&', 'Z', 'X', '3', '$', '[', ']', '_'])