# `XTBTSScreener.jl` - Screening Likely Transition States with Julia and Machine Learning
This Jupyter notebook demonstrates the use of machine learning to predict if a partially-optimized initialization of a transition state, used in the study of chemical kinetics to predict rate constants, is _like to converge"_ and produze a valid transition state or not after further simulation with expensive Density Functional Theory simulations.

## Load the Data
The input data is saved in a CSV file, load it using `CSV.jl` and then partition the data into training and testing sets using `MLUtils.jl`.

In [1]:
using MLUtils, CSV

In [2]:
function get_dataloaders()
    csv_reader = CSV.File("data/roo_co2_full_data.csv")
    n_samples = 16517
    x_data = Array{Float32}(undef, 55, 6, n_samples)
    labels = Float32[]
    iter = 1
    println("Progress:")
    for row in csv_reader[1:n_samples]
        # print some updates as we go
        if mod(iter, div(n_samples, 25)) == 0
            println(" - row $iter of $n_samples")
            flush(stdout)
        end
        
        # get if it converged or not
        if parse(Bool, "$(row.converged)")
            push!(labels, 1.0f0)
        else
            push!(labels, 0.0f0)
        end

        # get the final coordinates of the atoms
        split_array = split("$(row.std_xyz)")
        n_atoms = Int(length(split_array)/6)
        m = Array{Float32}(undef, 55, 6)
        row_counter = 1
        column_counter = 1
        for value in split_array
            temp = String(value)
            temp = replace(temp,"]"=>"")
            temp = replace(temp,"["=>"")
            temp = replace(temp,","=>"")
            m[row_counter, column_counter] = parse(Float32, temp)
            column_counter += 1
            if column_counter > 6
                column_counter = 1
                row_counter += 1
            end
        end

        # zero-padding
        for i in n_atoms+1:55
            m[i, 1:6] = [0,0,0,0,0,0]
        end
        x_data[1:55, 1:6, iter] = m
        iter += 1
    end
    println("loading done, partitioning data.")
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=2^6, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=2^6, shuffle=false))
end

get_dataloaders (generic function with 1 method)

For ease of debuggin and as a reference, the original tutorial dataloading function is included below.

In [3]:
function get_tutorial_dataloaders()
    dataset_size=1000
    sequence_length=50
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(d[1][:, (sequence_length + 1):end], :, sequence_length,
                                     1) for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

get_tutorial_dataloaders (generic function with 1 method)

## Configure the Neural Network
Following from the tutorial in the [Lux documentation](https://lux.csail.mit.edu/stable/examples/generated/beginner/SimpleRNN/main/) we write a series of functions that will create our NN.

In [4]:
using Lux, Random, Optimisers, Zygote, NNlib, Statistics

In [5]:
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 42)

TaskLocalRNG()

In [6]:
struct StateClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

In [7]:
function StateClassifier(in_dims, hidden_dims, out_dims)
    return StateClassifier(LSTMCell(in_dims => hidden_dims),
                            Dense(hidden_dims => out_dims, sigmoid))
end

StateClassifier

In [8]:
function (s::StateClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple,
                               st::NamedTuple) where {T}
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

In [9]:
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), y_pred, st
end

matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

accuracy (generic function with 1 method)

In [10]:
function create_optimiser(ps)
    opt = Optimisers.ADAM(0.01f0)
    return Optimisers.setup(opt, ps)
end

create_optimiser (generic function with 1 method)

## Train the NN
Actual training and evaluation steps.

Load the data from the file and parition it:

In [11]:
(train_loader, val_loader) = get_dataloaders()

Progress:
 - row 660 of 16517
 - row 1320 of 16517
 - row 1980 of 16517
 - row 2640 of 16517
 - row 3300 of 16517
 - row 3960 of 16517
 - row 4620 of 16517
 - row 5280 of 16517
 - row 5940 of 16517
 - row 6600 of 16517
 - row 7260 of 16517
 - row 7920 of 16517
 - row 8580 of 16517
 - row 9240 of 16517
 - row 9900 of 16517
 - row 10560 of 16517
 - row 11220 of 16517
 - row 11880 of 16517
 - row 12540 of 16517
 - row 13200 of 16517
 - row 13860 of 16517
 - row 14520 of 16517
 - row 15180 of 16517
 - row 15840 of 16517
 - row 16500 of 16517
loading done, partitioning data.


(DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, shuffle=true, batchsize=64), DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, batchsize=64))

Create the model and optimizer:

In [12]:
model = StateClassifier(55, 6, 1)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
opt_state = create_optimiser(ps)

(lstm_cell = (weight_i = [32mLeaf(Adam{Float32}(0.01, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, weight_h = [32mLeaf(Adam{Float32}(0.01, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam{Float32}(0.01, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))[32m)[39m), classifier = (weight = [32mLeaf(Adam{Float32}(0.01, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam{Float32}(0.01, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0;;], Float32[0

Actual model training and validation:

In [13]:
loss_vector = Float64[]
accuracy_vector = Float64[]
for epoch in 1:500
    # Train the model
    epoch_loss = Float64[]
    for (x, y) in train_loader
        (loss, y_pred, st), back = pullback(p -> compute_loss(x, y, model, p, st), ps)
        gs = back((one(loss), nothing, nothing))[1]
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
        push!(epoch_loss, loss)
    end
    avg_loss = mean(epoch_loss)
    println("Epoch # $epoch:\n - loss of $avg_loss")
    push!(loss_vector, avg_loss)

    # Validate the model
    epoch_accuracy = Float64[]
    st_ = Lux.testmode(st)
    for (x, y) in val_loader
        (loss, y_pred, st_) = compute_loss(x, y, model, ps, st_)
        acc = accuracy(y_pred, y)
        push!(epoch_accuracy, acc)
    end
    avg_accuracy = mean(epoch_accuracy)
    println(" - accuracy of $avg_accuracy")
    push!(accuracy_vector, avg_accuracy)
end

Epoch # 1:
 - loss of 0.48685051076078184
 - accuracy of 0.8065366124260355
Epoch # 2:
 - loss of 0.47336750502747615
 - accuracy of 0.8062361316568047
Epoch # 3:
 - loss of 0.4666447519968097
 - accuracy of 0.8062361316568047
Epoch # 4:
 - loss of 0.4629714838836504
 - accuracy of 0.8056351701183432
Epoch # 5:
 - loss of 0.4570619589176731
 - accuracy of 0.8047337278106509
Epoch # 6:
 - loss of 0.4524433393985177
 - accuracy of 0.8050342085798816
Epoch # 7:
 - loss of 0.44906442968741705
 - accuracy of 0.8047337278106509
Epoch # 8:
 - loss of 0.4447438512447376
 - accuracy of 0.8008274778106509
Epoch # 9:
 - loss of 0.44139255691265716
 - accuracy of 0.7978226701183432
Epoch # 10:
 - loss of 0.43807190483894903
 - accuracy of 0.7969212278106509
Epoch # 11:
 - loss of 0.43688222996278664
 - accuracy of 0.7961276503944773
Epoch # 12:
 - loss of 0.4335442839037393
 - accuracy of 0.7982310157790926
Epoch # 13:
 - loss of 0.43123862579248956
 - accuracy of 0.7963202662721893
Epoch # 14:
 -

Epoch # 109:
 - loss of 0.3725320664436921
 - accuracy of 0.7571730152859961
Epoch # 110:
 - loss of 0.3741567327780424
 - accuracy of 0.7599852071005917
Epoch # 111:
 - loss of 0.3719594785268756
 - accuracy of 0.7585906681459567
Epoch # 112:
 - loss of 0.37043100271535956
 - accuracy of 0.7565720537475346
Epoch # 113:
 - loss of 0.3695730718149655
 - accuracy of 0.759877342209073
Epoch # 114:
 - loss of 0.36716788200940487
 - accuracy of 0.7616802268244576
Epoch # 115:
 - loss of 0.36516805986563367
 - accuracy of 0.7545765532544378
Epoch # 116:
 - loss of 0.3661393193519058
 - accuracy of 0.7532667652859961
Epoch # 117:
 - loss of 0.364970616168446
 - accuracy of 0.7495762450690335
Epoch # 118:
 - loss of 0.3643284175850919
 - accuracy of 0.7503929363905326
Epoch # 119:
 - loss of 0.3641648907304386
 - accuracy of 0.7564873027613412
Epoch # 120:
 - loss of 0.36703292867123793
 - accuracy of 0.7564873027613412
Epoch # 121:
 - loss of 0.36584849461265234
 - accuracy of 0.7560789571005

Epoch # 215:
 - loss of 0.35881887412301583
 - accuracy of 0.7547691691321499
Epoch # 216:
 - loss of 0.35899857593619305
 - accuracy of 0.7543608234714003
Epoch # 217:
 - loss of 0.363528528507205
 - accuracy of 0.7602856878698224
Epoch # 218:
 - loss of 0.3602865310250849
 - accuracy of 0.757365631163708
Epoch # 219:
 - loss of 0.355107432042343
 - accuracy of 0.7572577662721893
Epoch # 220:
 - loss of 0.3578649893906957
 - accuracy of 0.7514638806706114
Epoch # 221:
 - loss of 0.36102575366047845
 - accuracy of 0.7517643614398423
Epoch # 222:
 - loss of 0.35945990883209855
 - accuracy of 0.7548770340236686
Epoch # 223:
 - loss of 0.3549792971012097
 - accuracy of 0.7450459196252465
Epoch # 224:
 - loss of 0.353390345826817
 - accuracy of 0.7596847263313609
Epoch # 225:
 - loss of 0.3530448145336575
 - accuracy of 0.7531589003944773
Epoch # 226:
 - loss of 0.357308744304422
 - accuracy of 0.7579665927021696
Epoch # 227:
 - loss of 0.35656438335992285
 - accuracy of 0.7497688609467456

Epoch # 322:
 - loss of 0.34881644008528206
 - accuracy of 0.734744822485207
Epoch # 323:
 - loss of 0.350515260644581
 - accuracy of 0.7367634368836291
Epoch # 324:
 - loss of 0.3542271384582427
 - accuracy of 0.7480738412228797
Epoch # 325:
 - loss of 0.35274624162250096
 - accuracy of 0.7519800912228797
Epoch # 326:
 - loss of 0.3534018762301708
 - accuracy of 0.7562715729783037
Epoch # 327:
 - loss of 0.3505738818127176
 - accuracy of 0.7479659763313609
Epoch # 328:
 - loss of 0.3529860390269238
 - accuracy of 0.741162783530572
Epoch # 329:
 - loss of 0.3525622739020177
 - accuracy of 0.7550696499013807
Epoch # 330:
 - loss of 0.35265788327956543
 - accuracy of 0.7534824950690335
Epoch # 331:
 - loss of 0.3507796442451108
 - accuracy of 0.7485669378698224
Epoch # 332:
 - loss of 0.34783248279405676
 - accuracy of 0.7498767258382644
Epoch # 333:
 - loss of 0.3470951910468115
 - accuracy of 0.7496609960552268
Epoch # 334:
 - loss of 0.34312817101605275
 - accuracy of 0.74395186143984

Epoch # 429:
 - loss of 0.34733988222292655
 - accuracy of 0.7462709566074951
Epoch # 430:
 - loss of 0.34711334771580166
 - accuracy of 0.7467640532544378
Epoch # 431:
 - loss of 0.3474951664318785
 - accuracy of 0.7427499383629191
Epoch # 432:
 - loss of 0.3497470035645121
 - accuracy of 0.7463788214990138
Epoch # 433:
 - loss of 0.34929013266655556
 - accuracy of 0.7474728796844181
Epoch # 434:
 - loss of 0.34822204781039323
 - accuracy of 0.7414632642998028
Epoch # 435:
 - loss of 0.3484659682293445
 - accuracy of 0.7401534763313609
Epoch # 436:
 - loss of 0.3446296434039655
 - accuracy of 0.746463572485207
Epoch # 437:
 - loss of 0.3427773933623724
 - accuracy of 0.7495762450690335
Epoch # 438:
 - loss of 0.3446231961682223
 - accuracy of 0.7466561883629191
Epoch # 439:
 - loss of 0.34550452707470325
 - accuracy of 0.7455621301775147
Epoch # 440:
 - loss of 0.34176704532282365
 - accuracy of 0.746078340729783
Epoch # 441:
 - loss of 0.3447835071795229
 - accuracy of 0.741655880177

In [15]:
using Plots
plot(loss_vector, label="loss", legend=:bottom, color=:red, rightmargin = 1.5Plots.cm, bottommargin = 0.5Plots.cm, box = :on, fmt = :png)
plot!(twinx(), accuracy_vector, label="accuracy", legend=:outertopright, xlabel="epoch", rightmargin = 1.5Plots.cm, bottommargin = 0.5Plots.cm, box = :on)
using Dates
timestamp = now()
savefig("result-$timestamp.png")