# `XTBTSScreener.jl` - Screening Likely Transition States with Julia and Machine Learning
This Jupyter notebook demonstrates the use of machine learning to predict if a partially-optimized initialization of a transition state, used in the study of chemical kinetics to predict rate constants, is _like to converge"_ and produze a valid transition state or not after further simulation with expensive Density Functional Theory simulations.

## Load the Data
The input data is saved in a CSV file, load it using `CSV.jl` and then partition the data into training and testing sets using `MLUtils.jl`.

In [1]:
using MLUtils, CSV

In [2]:
function get_dataloaders()
    csv_reader = CSV.File("data/roo_co2_full_data_augmented.csv")
    n_samples = 16517
    x_data = Array{Float32}(undef, 60, 6, n_samples)
    labels = Float32[]
    iter = 1
    println("Progress:")
    for row in csv_reader[1:n_samples]
        # print some updates as we go
        if mod(iter, div(n_samples, 60)) == 0
            println(" - row $iter of $n_samples")
            flush(stdout)
        end
        
        # get if it converged or not
        if parse(Bool, "$(row.converged)")
            push!(labels, 1.0f0)
        else
            push!(labels, 0.0f0)
        end
        
        # array for descriptors for this transition state
        m = Array{Float32}(undef, 60, 6)
        
        # pull out the augmented descriptors
        split_gibbs = split("$(row.gibbs)")
        split_steps = split("$(row.steps)")
        split_e0_zpe = split("$(row.e0_zpe)")
        split_descriptors = [split_gibbs,split_steps,split_e0_zpe]
        for i in 1:3
            for j in 1:3
                temp = String(split_descriptors[i][j])
                temp = replace(temp,"]"=>"")
                temp = replace(temp,"["=>"")
                temp = replace(temp,","=>"")
                m[i,j] = parse(Float32, temp)
            end
            m[i,4] = Float32(0.0)
            m[i,5] = Float32(0.0)
            m[i,6] = Float32(0.0)
        end

        # get the final coordinates of the atoms
        split_array = split("$(row.std_xyz)")
        n_atoms = Int(length(split_array)/6)
        row_counter = 4
        column_counter = 1
        for value in split_array
            temp = String(value)
            temp = replace(temp,"]"=>"")
            temp = replace(temp,"["=>"")
            temp = replace(temp,","=>"")
            m[row_counter, column_counter] = parse(Float32, temp)
            column_counter += 1
            if column_counter > 6
                column_counter = 1
                row_counter += 1
            end
        end

        # zero-padding
        for i in n_atoms+1:60
            m[i, 1:6] = [0,0,0,0,0,0]
        end
        x_data[1:60, 1:6, iter] = m
        iter += 1
    end
    println("loading done, partitioning data.")
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=2^4, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=2^4, shuffle=false))
end

get_dataloaders (generic function with 1 method)

For ease of debugging and as a reference, the original tutorial dataloading function is included below.

In [3]:
function get_tutorial_dataloaders()
    dataset_size=1000
    sequence_length=50
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(d[1][:, (sequence_length + 1):end], :, sequence_length,
                                     1) for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

get_tutorial_dataloaders (generic function with 1 method)

## Configure the Neural Network
Following from the tutorial in the [Lux documentation](https://lux.csail.mit.edu/stable/examples/generated/beginner/SimpleRNN/main/) we write a series of functions that will create our NN.

In [4]:
using Lux, Random, Optimisers, Zygote, NNlib, Statistics

In [5]:
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 42)

TaskLocalRNG()

In [6]:
struct StateClassifier{L, C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

In [7]:
function StateClassifier(in_dims, hidden_dims, out_dims)
    return StateClassifier(LSTMCell(in_dims => hidden_dims),
                            Dense(hidden_dims => out_dims, sigmoid))
end

StateClassifier

In [8]:
function (s::StateClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple,
                               st::NamedTuple) where {T}
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end

In [9]:
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), y_pred, st
end

matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

accuracy (generic function with 1 method)

In [10]:
function create_optimiser(ps)
    opt = Optimisers.ADAM(0.0001f0)
    return Optimisers.setup(opt, ps)
end

create_optimiser (generic function with 1 method)

## Train the NN
Actual training and evaluation steps.

Load the data from the file and parition it:

In [11]:
(train_loader, val_loader) = get_dataloaders()

Progress:
 - row 275 of 16517
 - row 550 of 16517
 - row 825 of 16517
 - row 1100 of 16517
 - row 1375 of 16517
 - row 1650 of 16517
 - row 1925 of 16517
 - row 2200 of 16517
 - row 2475 of 16517
 - row 2750 of 16517
 - row 3025 of 16517
 - row 3300 of 16517
 - row 3575 of 16517
 - row 3850 of 16517
 - row 4125 of 16517
 - row 4400 of 16517
 - row 4675 of 16517
 - row 4950 of 16517
 - row 5225 of 16517
 - row 5500 of 16517
 - row 5775 of 16517
 - row 6050 of 16517
 - row 6325 of 16517
 - row 6600 of 16517
 - row 6875 of 16517
 - row 7150 of 16517
 - row 7425 of 16517
 - row 7700 of 16517
 - row 7975 of 16517
 - row 8250 of 16517
 - row 8525 of 16517
 - row 8800 of 16517
 - row 9075 of 16517
 - row 9350 of 16517
 - row 9625 of 16517
 - row 9900 of 16517
 - row 10175 of 16517
 - row 10450 of 16517
 - row 10725 of 16517
 - row 11000 of 16517
 - row 11275 of 16517
 - row 11550 of 16517
 - row 11825 of 16517
 - row 12100 of 16517
 - row 12375 of 16517
 - row 12650 of 16517
 - row 12925 of 1

(DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, shuffle=true, batchsize=16), DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, batchsize=16))

Create the model and optimizer:

In [15]:
model = StateClassifier(60, 6, 1)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
opt_state = create_optimiser(ps)

(lstm_cell = (weight_i = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, weight_h = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))[32m)[39m), classifier = (weight = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))[32m)[39m, bias = [32mLeaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), [39m(Float32[0.0;;],

Actual model training and validation:

In [16]:
loss_vector = Float64[]
accuracy_vector = Float64[]
for epoch in 1:500
    # Train the model
    epoch_loss = Float64[]
    for (x, y) in train_loader
        (loss, y_pred, st), back = pullback(p -> compute_loss(x, y, model, p, st), ps)
        gs = back((one(loss), nothing, nothing))[1]
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
        push!(epoch_loss, loss)
    end
    avg_loss = mean(epoch_loss)
    println("Epoch # $epoch:\n - loss of $avg_loss")
    push!(loss_vector, avg_loss)

    # Validate the model
    epoch_accuracy = Float64[]
    st_ = Lux.testmode(st)
    for (x, y) in val_loader
        (loss, y_pred, st_) = compute_loss(x, y, model, ps, st_)
        acc = accuracy(y_pred, y)
        push!(epoch_accuracy, acc)
    end
    avg_accuracy = mean(epoch_accuracy)
    println(" - accuracy of $avg_accuracy")
    push!(accuracy_vector, avg_accuracy)
end

Epoch # 1:
 - loss of 0.5355283794405963
 - accuracy of 0.8051673567977916
Epoch # 2:
 - loss of 0.49769447758324786
 - accuracy of 0.8063750862663907
Epoch # 3:
 - loss of 0.4852402705978828
 - accuracy of 0.8063750862663907
Epoch # 4:
 - loss of 0.4809976236186651
 - accuracy of 0.8063750862663907
Epoch # 5:
 - loss of 0.4787262171759444
 - accuracy of 0.8063750862663907
Epoch # 6:
 - loss of 0.4771813260116242
 - accuracy of 0.8063750862663907
Epoch # 7:
 - loss of 0.47593226783500747
 - accuracy of 0.8063750862663907
Epoch # 8:
 - loss of 0.47485157035597875
 - accuracy of 0.8063750862663907
Epoch # 9:
 - loss of 0.4740213891243242
 - accuracy of 0.8063750862663907
Epoch # 10:
 - loss of 0.47327926251631386
 - accuracy of 0.8063750862663907
Epoch # 11:
 - loss of 0.47255177218602296
 - accuracy of 0.8063750862663907
Epoch # 12:
 - loss of 0.47185062431250013
 - accuracy of 0.8063750862663907
Epoch # 13:
 - loss of 0.471256309267828
 - accuracy of 0.8063750862663907
Epoch # 14:
 - l

Epoch # 109:
 - loss of 0.43890100551575206
 - accuracy of 0.8012422360248448
Epoch # 110:
 - loss of 0.4387239759864588
 - accuracy of 0.8018461007591443
Epoch # 111:
 - loss of 0.4384171469520426
 - accuracy of 0.8015441683919945
Epoch # 112:
 - loss of 0.43826299876912744
 - accuracy of 0.8015441683919945
Epoch # 113:
 - loss of 0.4380062814017185
 - accuracy of 0.802148033126294
Epoch # 114:
 - loss of 0.4377482482942484
 - accuracy of 0.8015441683919945
Epoch # 115:
 - loss of 0.43752987230011686
 - accuracy of 0.8018461007591443
Epoch # 116:
 - loss of 0.4372815247670213
 - accuracy of 0.8015441683919945
Epoch # 117:
 - loss of 0.4371494181075339
 - accuracy of 0.800940303657695
Epoch # 118:
 - loss of 0.4369227043364296
 - accuracy of 0.8012422360248448
Epoch # 119:
 - loss of 0.43664717344120685
 - accuracy of 0.8012422360248448
Epoch # 120:
 - loss of 0.4364677699729259
 - accuracy of 0.800940303657695
Epoch # 121:
 - loss of 0.4362483411377914
 - accuracy of 0.800336438923395

Epoch # 215:
 - loss of 0.42200955520350186
 - accuracy of 0.7936939268461007
Epoch # 216:
 - loss of 0.4219149679901669
 - accuracy of 0.793391994478951
Epoch # 217:
 - loss of 0.421816881538592
 - accuracy of 0.7936939268461007
Epoch # 218:
 - loss of 0.4216969364906772
 - accuracy of 0.7942977915804003
Epoch # 219:
 - loss of 0.42151981192589094
 - accuracy of 0.7939958592132506
Epoch # 220:
 - loss of 0.42146280635211425
 - accuracy of 0.7936939268461007
Epoch # 221:
 - loss of 0.42128632823794576
 - accuracy of 0.793391994478951
Epoch # 222:
 - loss of 0.42116200653317476
 - accuracy of 0.7930900621118012
Epoch # 223:
 - loss of 0.4211178751387261
 - accuracy of 0.7927881297446515
Epoch # 224:
 - loss of 0.4208431341313277
 - accuracy of 0.7952035886818496
Epoch # 225:
 - loss of 0.4208107980162122
 - accuracy of 0.7936939268461007
Epoch # 226:
 - loss of 0.4206660367710827
 - accuracy of 0.7945997239475501
Epoch # 227:
 - loss of 0.42048448722890736
 - accuracy of 0.7930900621118

Epoch # 322:
 - loss of 0.4101717824580883
 - accuracy of 0.7820048309178744
Epoch # 323:
 - loss of 0.4100589167245508
 - accuracy of 0.7817028985507246
Epoch # 324:
 - loss of 0.4099903860136181
 - accuracy of 0.7844202898550725
Epoch # 325:
 - loss of 0.4099628142741921
 - accuracy of 0.7831262939958592
Epoch # 326:
 - loss of 0.40983789353732913
 - accuracy of 0.7832125603864735
Epoch # 327:
 - loss of 0.4097617305183815
 - accuracy of 0.7832125603864735
Epoch # 328:
 - loss of 0.40975075546137935
 - accuracy of 0.7829106280193237
Epoch # 329:
 - loss of 0.4093991652419723
 - accuracy of 0.7846359558316081
Epoch # 330:
 - loss of 0.40950446021145537
 - accuracy of 0.7835144927536232
Epoch # 331:
 - loss of 0.4092808796079338
 - accuracy of 0.783816425120773
Epoch # 332:
 - loss of 0.4091763100237304
 - accuracy of 0.7852398205659076
Epoch # 333:
 - loss of 0.4091917511445558
 - accuracy of 0.7844202898550725
Epoch # 334:
 - loss of 0.4090565484080274
 - accuracy of 0.78381642512077

Epoch # 429:
 - loss of 0.40098126804496703
 - accuracy of 0.7756642512077294
Epoch # 430:
 - loss of 0.40081194072925724
 - accuracy of 0.779891304347826
Epoch # 431:
 - loss of 0.40070543972184525
 - accuracy of 0.7756642512077294
Epoch # 432:
 - loss of 0.40061917698087474
 - accuracy of 0.777475845410628
Epoch # 433:
 - loss of 0.4005830048721004
 - accuracy of 0.777475845410628
Epoch # 434:
 - loss of 0.40054120255079456
 - accuracy of 0.7747584541062802
Epoch # 435:
 - loss of 0.4003960821263438
 - accuracy of 0.7771739130434783
Epoch # 436:
 - loss of 0.40034654958412663
 - accuracy of 0.7768719806763285
Epoch # 437:
 - loss of 0.4002171857846274
 - accuracy of 0.7762681159420289
Epoch # 438:
 - loss of 0.40016849570824214
 - accuracy of 0.7753623188405797
Epoch # 439:
 - loss of 0.4001307138410954
 - accuracy of 0.7762681159420289
Epoch # 440:
 - loss of 0.40010201387089334
 - accuracy of 0.7768719806763285
Epoch # 441:
 - loss of 0.39989795368931486
 - accuracy of 0.7723429951

In [14]:
using Plots
plot(loss_vector, label="loss", legend=:bottom, color=:red, rightmargin = 1.5Plots.cm, bottommargin = 0.5Plots.cm, box = :on, fmt = :png)
plot!(twinx(), accuracy_vector, label="accuracy", legend=:bottomleft, xlabel="epoch", rightmargin = 1.5Plots.cm, bottommargin = 0.5Plots.cm, box = :on)
using Dates
timestamp = now()
savefig("result-$timestamp.png")