# `XTBTSScreener.jl` - Screening Likely Transition States with Julia and Machine Learning
This Jupyter notebook demonstrates the use of machine learning to predict if a partially-optimized initialization of a transition state, used in the study of chemical kinetics to predict rate constants, is _like to converge"_ and produze a valid transition state or not after further simulation with expensive Density Functional Theory simulations.

## Load the Data
The input data is saved in a CSV file, load it using `CSV.jl` and then partition the data into training and testing sets using `MLUtils.jl`.

In [1]:
using MLUtils, CSV

In [2]:
function get_dataloaders()
    csv_reader = CSV.File("data/roo_co2_full_data.csv")
    n_samples = 16517
    x_data = Array{Float32}(undef, 55, 6, n_samples)
    labels = Float32[]
    iter = 1
    println("Progress:")
    for row in csv_reader[1:n_samples]
        # print some updates as we go
        if mod(iter, div(n_samples, 25)) == 0
            println(" - row $iter of $n_samples")
            flush(stdout)
        end
        
        # get if it converged or not
        if parse(Bool, "$(row.converged)")
            push!(labels, 1.0f0)
        else
            push!(labels, 0.0f0)
        end

        # get the final coordinates of the atoms
        split_array = split("$(row.std_xyz)")
        n_atoms = Int(length(split_array)/6)
        m = Array{Float32}(undef, 55, 6)
        row_counter = 1
        column_counter = 1
        for value in split_array
            temp = String(value)
            temp = replace(temp,"]"=>"")
            temp = replace(temp,"["=>"")
            temp = replace(temp,","=>"")
            m[row_counter, column_counter] = parse(Float32, temp)
            column_counter += 1
            if column_counter > 6
                column_counter = 1
                row_counter += 1
            end
        end

        # zero-padding
        for i in n_atoms+1:55
            m[i, 1:6] = [0,0,0,0,0,0]
        end
        x_data[1:55, 1:6, iter] = m
        iter += 1
    end
    println("loading done, partitioning data.")
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=2^4, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=2^4, shuffle=false))
end

get_dataloaders (generic function with 1 method)

For ease of debuggin and as a reference, the original tutorial dataloading function is included below.

In [3]:
function get_tutorial_dataloaders()
    dataset_size=1000
    sequence_length=50
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(d[1][:, (sequence_length + 1):end], :, sequence_length,
                                     1) for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

get_tutorial_dataloaders (generic function with 1 method)

## Configure the Neural Network
Following from the tutorial in the [Lux documentation](https://lux.csail.mit.edu/stable/examples/generated/beginner/SimpleRNN/main/) we write a series of functions that will create our NN.

In [4]:
using Lux, Random, Optimisers, Zygote, NNlib, Statistics

In [5]:
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 42)

TaskLocalRNG()

In [6]:
struct StateClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

In [7]:
function StateClassifier(in_dims, hidden_dims, out_dims)
    return StateClassifier(LSTMCell(in_dims => hidden_dims),
                            Dense(hidden_dims => out_dims, sigmoid))
end

StateClassifier

In [8]:
function (s::StateClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple,
                               st::NamedTuple) where {T}
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

In [9]:
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), y_pred, st
end

matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

accuracy (generic function with 1 method)

In [10]:
function create_optimiser(ps)
    opt = Optimisers.ADAM(0.0001f0)
    return Optimisers.setup(opt, ps)
end

create_optimiser (generic function with 1 method)

## Train the NN
Actual training and evaluation steps.

Load the data from the file and parition it:

In [11]:
(train_loader, val_loader) = get_dataloaders()

Progress:
 - row 660 of 16517
 - row 1320 of 16517
 - row 1980 of 16517
 - row 2640 of 16517
 - row 3300 of 16517
 - row 3960 of 16517
 - row 4620 of 16517
 - row 5280 of 16517
 - row 5940 of 16517
 - row 6600 of 16517
 - row 7260 of 16517
 - row 7920 of 16517
 - row 8580 of 16517
 - row 9240 of 16517
 - row 9900 of 16517
 - row 10560 of 16517
 - row 11220 of 16517
 - row 11880 of 16517
 - row 12540 of 16517
 - row 13200 of 16517
 - row 13860 of 16517
 - row 14520 of 16517
 - row 15180 of 16517
 - row 15840 of 16517
 - row 16500 of 16517
loading done, partitioning data.


(DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, shuffle=true, batchsize=16), DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, batchsize=16))

Create the model and optimizer:

In [12]:
model = StateClassifier(55, 6, 1)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
opt_state = create_optimiser(ps)

(lstm_cell = (weight_i = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, weight_h = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))[32m)[39m), classifier = (weight = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0;;],

Actual model training and validation:

In [15]:
loss_vector = Float64[]
accuracy_vector = Float64[]
for epoch in 1:500
    # Train the model
    epoch_loss = Float64[]
    for (x, y) in train_loader
        (loss, y_pred, st), back = pullback(p -> compute_loss(x, y, model, p, st), ps)
        gs = back((one(loss), nothing, nothing))[1]
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
        push!(epoch_loss, loss)
    end
    avg_loss = mean(epoch_loss)
    println("Epoch # $epoch:\n - loss of $avg_loss")
    push!(loss_vector, avg_loss)

    # Validate the model
    epoch_accuracy = Float64[]
    st_ = Lux.testmode(st)
    for (x, y) in val_loader
        (loss, y_pred, st_) = compute_loss(x, y, model, ps, st_)
        acc = accuracy(y_pred, y)
        push!(epoch_accuracy, acc)
    end
    avg_accuracy = mean(epoch_accuracy)
    println(" - accuracy of $avg_accuracy")
    push!(accuracy_vector, avg_accuracy)
end

Epoch # 1:
 - loss of 0.5211532165005478
 - accuracy of 0.8057712215320911
Epoch # 2:
 - loss of 0.49489579337128137
 - accuracy of 0.8066770186335404
Epoch # 3:
 - loss of 0.4861325481252578
 - accuracy of 0.8066770186335404
Epoch # 4:
 - loss of 0.48275915587424656
 - accuracy of 0.8066770186335404
Epoch # 5:
 - loss of 0.48101673636035247
 - accuracy of 0.8066770186335404
Epoch # 6:
 - loss of 0.47967421896665496
 - accuracy of 0.8066770186335404
Epoch # 7:
 - loss of 0.4784496566573875
 - accuracy of 0.8066770186335404
Epoch # 8:
 - loss of 0.4773907147474208
 - accuracy of 0.8066770186335404
Epoch # 9:
 - loss of 0.4764636383543003
 - accuracy of 0.8066770186335404
Epoch # 10:
 - loss of 0.4755229091420589
 - accuracy of 0.8066770186335404
Epoch # 11:
 - loss of 0.4746417457521972
 - accuracy of 0.8066770186335404
Epoch # 12:
 - loss of 0.4738314503185974
 - accuracy of 0.8066770186335404
Epoch # 13:
 - loss of 0.4730674123828983
 - accuracy of 0.8066770186335404
Epoch # 14:
 - lo

Epoch # 109:
 - loss of 0.4373639938004369
 - accuracy of 0.8006383712905453
Epoch # 110:
 - loss of 0.4370436503869858
 - accuracy of 0.8006383712905453
Epoch # 111:
 - loss of 0.4367962864748502
 - accuracy of 0.8006383712905453
Epoch # 112:
 - loss of 0.43652178890119164
 - accuracy of 0.8003364389233955
Epoch # 113:
 - loss of 0.4362448503152799
 - accuracy of 0.8000345065562458
Epoch # 114:
 - loss of 0.43593391263744735
 - accuracy of 0.8003364389233955
Epoch # 115:
 - loss of 0.4356836666855916
 - accuracy of 0.8006383712905453
Epoch # 116:
 - loss of 0.43544008770447956
 - accuracy of 0.8006383712905453
Epoch # 117:
 - loss of 0.43521992562758144
 - accuracy of 0.8006383712905453
Epoch # 118:
 - loss of 0.43491322370403906
 - accuracy of 0.8006383712905453
Epoch # 119:
 - loss of 0.43472839754805437
 - accuracy of 0.8003364389233955
Epoch # 120:
 - loss of 0.4344160493201696
 - accuracy of 0.800940303657695
Epoch # 121:
 - loss of 0.43425438630667496
 - accuracy of 0.8006383712

Epoch # 215:
 - loss of 0.4166417224806244
 - accuracy of 0.7891649413388544
Epoch # 216:
 - loss of 0.4164211796934899
 - accuracy of 0.7879572118702554
Epoch # 217:
 - loss of 0.41637068799275173
 - accuracy of 0.7876552795031057
Epoch # 218:
 - loss of 0.4160949053292413
 - accuracy of 0.7897688060731539
Epoch # 219:
 - loss of 0.41605679141515395
 - accuracy of 0.7885610766045549
Epoch # 220:
 - loss of 0.41582299029278696
 - accuracy of 0.7873533471359558
Epoch # 221:
 - loss of 0.4157549711823752
 - accuracy of 0.7888630089717047
Epoch # 222:
 - loss of 0.41566831549092875
 - accuracy of 0.7879572118702554
Epoch # 223:
 - loss of 0.4154517865281994
 - accuracy of 0.7906746031746031
Epoch # 224:
 - loss of 0.41527546446130004
 - accuracy of 0.7867494824016563
Epoch # 225:
 - loss of 0.41521252229918004
 - accuracy of 0.7894668737060042
Epoch # 226:
 - loss of 0.4150763784172171
 - accuracy of 0.7867494824016563
Epoch # 227:
 - loss of 0.41487452054067037
 - accuracy of 0.787353347

Epoch # 322:
 - loss of 0.40387605199225013
 - accuracy of 0.7752760524499656
Epoch # 323:
 - loss of 0.40376224222685464
 - accuracy of 0.7795031055900621
Epoch # 324:
 - loss of 0.4038120326168889
 - accuracy of 0.774672187715666
Epoch # 325:
 - loss of 0.40355072330151287
 - accuracy of 0.7743702553485162
Epoch # 326:
 - loss of 0.4034786410900352
 - accuracy of 0.774672187715666
Epoch # 327:
 - loss of 0.4034137670627229
 - accuracy of 0.7776915113871635
Epoch # 328:
 - loss of 0.4031844324732231
 - accuracy of 0.7776915113871635
Epoch # 329:
 - loss of 0.403312124564486
 - accuracy of 0.7755779848171153
Epoch # 330:
 - loss of 0.4030350019690777
 - accuracy of 0.7764837819185646
Epoch # 331:
 - loss of 0.40291855901876605
 - accuracy of 0.7795031055900621
Epoch # 332:
 - loss of 0.4029647286340919
 - accuracy of 0.7767857142857143
Epoch # 333:
 - loss of 0.40271264200998563
 - accuracy of 0.7776915113871635
Epoch # 334:
 - loss of 0.4026538949757454
 - accuracy of 0.77708764665286

Epoch # 429:
 - loss of 0.3950811293851088
 - accuracy of 0.7713509316770186
Epoch # 430:
 - loss of 0.3950887733380385
 - accuracy of 0.7705314009661836
Epoch # 431:
 - loss of 0.39489469835061136
 - accuracy of 0.7728605935127675
Epoch # 432:
 - loss of 0.39487209282905655
 - accuracy of 0.7728605935127675
Epoch # 433:
 - loss of 0.39470819007722574
 - accuracy of 0.7708333333333334
Epoch # 434:
 - loss of 0.3947570572447113
 - accuracy of 0.7713509316770186
Epoch # 435:
 - loss of 0.3946276632640466
 - accuracy of 0.7708333333333334
Epoch # 436:
 - loss of 0.39459511007184556
 - accuracy of 0.773464458247067
Epoch # 437:
 - loss of 0.3946144129837396
 - accuracy of 0.7716528640441684
Epoch # 438:
 - loss of 0.3943759357831813
 - accuracy of 0.7716528640441684
Epoch # 439:
 - loss of 0.39447606772581256
 - accuracy of 0.7725586611456177
Epoch # 440:
 - loss of 0.39439161571946907
 - accuracy of 0.7698412698412699
Epoch # 441:
 - loss of 0.39421909417930007
 - accuracy of 0.7734644582

In [16]:
using Plots
plot(loss_vector, label="loss", legend=:bottom, color=:red, rightmargin = 1.5Plots.cm, bottommargin = 0.5Plots.cm, box = :on, fmt = :png)
plot!(twinx(), accuracy_vector, label="accuracy", legend=:bottomleft, xlabel="epoch", rightmargin = 1.5Plots.cm, bottommargin = 0.5Plots.cm, box = :on)
using Dates
timestamp = now()
savefig("result-$timestamp.png")