# Trace Generation and Modeling

To test our DeepValidate approach, we generate a dataset of test traces from a chain of relatively simple arithmetic functions operating on a series of randomized inputs. Given the generated program traces, we train an LSTM classifier to predict whether a run will produce a valid output or result in an error.

The trace generation is performed by `output_trace.jl`, which reproduces much of the functionality of `varextract.jl` with some important differences. Rather than sending trace information to `stdout`, we direct the traces to a file, `traces.dat`. This raw output is then processed into a CSV of traces (minus the error dumps we want to predict) and a CSV of binary (0, 1) labels indicating whether the run resulted in an error.

(Note that this is not possible within an IJulia notebook due to restrictions on [task switching in staged functions](https://github.com/JuliaLang/julia/issues/18568), which prevent the trace outputs from being written to a file recursively. It works fine from the command line.)

In [None]:
function Cassette.overdub(ctx::TraceCtx,
                          f,
                          args...)
    open("traces.dat", "a") do file
        write(file, string(f))
        write(file, string(args))
    end
    
    # if we are supposed to descend, we call Cassette.recurse
    if Cassette.canrecurse(ctx, f, args...)
        subtrace = (Any[],Any[])
        push!(ctx.metadata[1], (f, args) => subtrace)
        newctx = Cassette.similarcontext(ctx, metadata = subtrace)
        retval = Cassette.recurse(newctx, f, args...)
        # push!(ctx.metadata[2], subtrace[2])
    else
        retval = Cassette.fallback(ctx, f, args...)
        push!(ctx.metadata[1], :t)
        push!(ctx.metadata[2], retval)
    end
    @info "returning"
    @show retval
    return retval
end

We then modify our `@testset` so that it creates the `traces.dat` file and then loops through a large number of randomized runs of our arithmetic tests. Error conditions happen most often when our inputs are sufficiently close to zero, so a Normal(0,2) distribution gives us a good range of values and generates a reasonable percentage of "bad" traces on which to train. Empirically, the share of "bad" traces generated is about 15-17%.

In [None]:
@testset "TraceExtract" begin
    g(x) = begin
        y = add(x.*x, -x)
        z = 1
        v = y .- z
        s = sum(v)
        return s
    end
    h(x) = begin
        z = g(x)
        zed = sqrt(z)
        return zed
    end

    open("traces.dat", "w") do f
        write(f, "")
    end

    seeds = rand(Normal(0,2),30000,3)
    
    for i=1:size(seeds,1)
        ctx = TraceCtx(pass=ExtractPass, metadata = (Any[], Any[]))
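        try
            result = Cassette.overdub(ctx, h, seeds[i,:])
        catch err
            # Only a DomainError (e.g. sqrt of a negative number) marks a "bad" trace;
            # anything else is unexpected and is rethrown.
            err isa DomainError || rethrow()
            dump(ctx.metadata)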
        finally
            open("traces.dat", "a") do f
                write(f, "\n")
            end
        end
        if i%1000 == 0
            @info string(i)
        end
    end
end
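
As a rough cross-check of the share of "bad" runs quoted above, the same arithmetic can be run without Cassette. This is only a sketch under stated assumptions: `g_plain`, `h_plain`, `draw_input`, and `bad_fraction` are hypothetical names introduced here, `add` is assumed to be ordinary elementwise addition, and `2 .* randn(3)` stands in for a row of `rand(Normal(0,2), 30000, 3)`.

```julia
# Plain-Julia stand-ins for the traced functions (no Cassette involved).
# Assumption: `add` in the test set is ordinary elementwise addition.
g_plain(x) = sum(x .* x .- x .- 1)
h_plain(x) = sqrt(g_plain(x))

# One input vector: three samples with mean 0 and standard deviation 2.
draw_input() = 2 .* randn(3)

# Fraction of runs in which `sqrt` receives a negative argument.
function bad_fraction(n_runs)
    n_bad = 0
    for _ in 1:n_runs
        try
            h_plain(draw_input())
        catch err
            err isa DomainError || rethrow()
            n_bad += 1
        end
    end
    return n_bad / n_runs
end

frac = bad_fraction(10_000)
println("share of bad runs: ", frac)
```

A run is "bad" exactly when the summed value passed to `sqrt` is negative, which is why inputs near zero dominate the error cases.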


After generating our raw traces, a small amount of pre-processing is required before attempting to model around them. First, we classify our "good" and "bad" traces based on whether they have resulted in an error. 

We then need to strip out the actual error dump information from our "bad" traces, as this would too easily give away the prediction game. All traces end just before they would error, allowing the validation model to predict the next outcome.

In [None]:
text = split(String(read("traces.dat")), "\n");
Ys = Int.(occursin.(Ref(r"(Base[\S(?!\))]+error)"i), text));

text = split.(text, Ref(r"(Base[\S(?!\))]+error)"i));
text = [t[1] for t in text];

sum(Ys)
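
To make the labeling and stripping step concrete, here is a small sketch that applies the same regular expression to two invented trace lines. The strings in `toy_lines` are made up for illustration and are not real output from `traces.dat`; the names `pattern`, `toy_labels`, and `toy_stripped` are likewise hypothetical.

```julia
# Same pattern as in the cell above: "Base", then non-whitespace characters,
# then "error", matched case-insensitively.
pattern = r"(Base[\S(?!\))]+error)"i

# Two invented trace lines: one clean run and one that ends in an error call.
toy_lines = [
    "h([2.0, 3.0])sqrt(4.0,)",
    "h([0.1, 0.2])sqrt(-2.1,)Base.Math.throw_complex_domainerror(:sqrt, -2.1)",
]

# Label 1 if the line mentions an error function, 0 otherwise.
toy_labels = Int.(occursin.(Ref(pattern), toy_lines))

# Keep only the part of each line that precedes the error call.
toy_stripped = first.(split.(toy_lines, Ref(pattern)))

println(toy_labels)
println(toy_stripped)
```

A line without a match is returned unchanged by `split`, so "good" traces pass through untouched while "bad" traces are cut off just before the error call.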

Finally, we save our traces and our labels as CSV files for easy ingestion by our model.

In [None]:
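using DelimitedFiles

# The last element is the empty string left after the file's trailing newline, so we drop it.
writedlm( "traces.csv",  text[1:end-1], ',')
writedlm( "y_results.csv",  Ys[1:end-1], ',')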

## Validation Classifier Model
For our modeling, we use [Flux.jl](https://github.com/FluxML/Flux.jl) and train an LSTM encoder/decoder classifier on our traces.

In [None]:
using DelimitedFiles
using Statistics: mean
using Flux
using Flux: onehot, throttle, crossentropy, onehotbatch, params, shuffle
using MLDataPattern: stratifiedobs
using Base.Iterators: partition

include("../../src/validation/utils.jl")


In [None]:
#
# Set up inputs for model
#

# Read lines from traces.csv into arrays of characters
# Convert to onehot matrices

cd(@__DIR__)

text, alphabet, N = get_data("traces.csv")
stop = onehot('\n', alphabet);


In [None]:
# Partition into subsequences to input to our model

seq_len = 50

Xs = [collect(partition(t,seq_len)) for t in text];
Ys = readdlm("y_results.csv");

dataset = [(onehotbatch(x, alphabet, '\n'), onehot(Ys[i], unique(Ys)))
           for i in 1:length(Ys) for x in Xs[i]] |> shuffle

Ys = last.(dataset)

# Pad sequences to equal lengths

Xs = [hcat(x,repeat(stop,1,seq_len-size(x)[2])) for x in first.(dataset)]
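
The partition-and-pad step can be illustrated on plain characters, before any one-hot encoding. This is a minimal sketch, assuming a short chunk length for readability; `seq_len_demo`, `pad_char`, `chunks`, and `padded` are hypothetical names, and the input string is invented.

```julia
using Base.Iterators: partition

seq_len_demo = 4      # short chunk length for illustration (the model uses 50)
pad_char = '\n'       # same stop character used for padding above

# Split a toy trace into chunks of at most `seq_len_demo` characters.
chars = collect("abcdefghij")
chunks = [collect(c) for c in partition(chars, seq_len_demo)]

# Pad the final, shorter chunk with the stop character so all chunks have equal length.
padded = [vcat(c, fill(pad_char, seq_len_demo - length(c))) for c in chunks]

for p in padded
    println(repr(String(p)))
end
```

Every chunk inherits the label of the trace it came from, which is why a single trace contributes several training examples.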


In [None]:
# There are 972,290 items in our data. We use a train:test split of 90:10, stratified to ensure we have 
# the same share of "bad" and "good" traces in our train and test sets.

(Xtrain, Ytrain), (Xtest, Ytest) = stratifiedobs((Xs, Ys), p=0.9)

train = [(Xtrain[i], Ytrain[i]) for i in 1:length(Ytrain)];
test = [(Xtest[i], Ytest[i]) for i in 1:length(Ytest)];
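
For intuition about what a stratified split does, here is a plain-Julia sketch that splits each class separately so both sides keep the same class proportions. It is not the `stratifiedobs` implementation; `stratified_split` is a hypothetical helper and the toy labels are invented.

```julia
using Random

# Hypothetical helper: split indices class by class so that roughly a share `p`
# of every class lands in the training set.
function stratified_split(labels::AbstractVector, p::Real, rng::AbstractRNG)
    train_idx, test_idx = Int[], Int[]
    for class in unique(labels)
        idx = shuffle(rng, findall(==(class), labels))
        n_train = round(Int, p * length(idx))
        append!(train_idx, idx[1:n_train])
        append!(test_idx, idx[n_train+1:end])
    end
    return train_idx, test_idx
end

# Toy labels: 80 "good" (0) and 20 "bad" (1) examples.
toy_labels = vcat(fill(0, 80), fill(1, 20))
train_idx, test_idx = stratified_split(toy_labels, 0.9, MersenneTwister(42))

println("train bad share: ", sum(toy_labels[train_idx]) / length(train_idx))
println("test bad share:  ", sum(toy_labels[test_idx]) / length(test_idx))
```

Splitting per class matters here because "bad" traces are a minority; a purely random split could leave the test set with a noticeably different error rate.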


In [None]:
# We set up our model architecture

scanner = Chain(Dense(length(alphabet), seq_len, σ), LSTM(seq_len, seq_len))
encoder = Dense(seq_len, 2)

function model(x)
  state = scanner.([x])[end]
  Flux.reset!(scanner)
  softmax(encoder(state))
end

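# Loss and accuracy for a single (input, label) pair
loss(tup) = crossentropy(model(tup[1]), tup[2])
accuracy(tup) = mean(argmax(model(tup[1])) .== argmax(tup[2]))

opt = ADAM(0.01)
# `model` is a plain function, so the trainable parameters are collected from its layers
ps = params(scanner, encoder)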
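
The last two steps of the classifier, a softmax over two scores followed by a cross-entropy loss against a one-hot label, can be worked through by hand. The functions below are hand-written illustrations with hypothetical names (`softmax_sketch`, `crossentropy_sketch`), not Flux's implementations.

```julia
# Softmax: turn raw scores into probabilities that sum to 1.
function softmax_sketch(z)
    e = exp.(z .- maximum(z))   # subtract the maximum for numerical stability
    return e ./ sum(e)
end

# Cross-entropy between predicted probabilities and a one-hot label.
crossentropy_sketch(yhat, y) = -sum(y .* log.(yhat))

scores = [0.0, 0.0]             # the model is undecided between the two classes
label = [0.0, 1.0]              # one-hot label for the second class

probs = softmax_sketch(scores)
l = crossentropy_sketch(probs, label)

println(probs)
println(l)
```

With equal scores the model assigns probability 0.5 to each class, so the loss is `log(2)`; training pushes the probability of the correct class toward 1 and the loss toward 0.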

In [None]:
# Finally, we set up our callbacks for reporting on training progress.

testacc() = mean(accuracy(t) for t in test)
testloss() = mean(loss(t) for t in test)

evalcb = () -> @show testloss(), testacc()


In [None]:
# Now, train!

Flux.train!(loss, ps, train, opt, cb = throttle(evalcb, 10))
