# `XTBTSScreener.jl`
## Screening Likely Transition States with Julia and Machine Learning
This Jupyter notebook demonstrates the use of machine learning to predict whether a partially-optimized initialization of a transition state (used in the study of chemical kinetics to predict rate constants) is _likely to converge_ and produce a valid transition state after further refinement with expensive Density Functional Theory (DFT) simulations.

In [1]:
using Lux, Random, Optimisers, Zygote

In [2]:
# Reproducibility: take the default RNG and reset it to a fixed seed (0).
# `Random.seed!` hands back the RNG it seeded, so it can be bound directly.
rng = Random.seed!(Random.default_rng(), 0)

TaskLocalRNG()

In [3]:
# Build the network.
# Layout: normalise the 128 inputs, widen to 256 with tanh, normalise again,
# then pass through a nested two-layer head (256 -> 1 with tanh, then 1 -> 10).
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    # Nested head; kept as its own Chain exactly as originally written.
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10),
    ),
)

Chain(
    layer_1 = BatchNorm(128, affine=true, track_stats=true),  [90m# 256 parameters[39m[90m, plus 257[39m
    layer_2 = Dense(128 => 256, tanh_fast),  [90m# 33_024 parameters[39m
    layer_3 = BatchNorm(256, affine=true, track_stats=true),  [90m# 512 parameters[39m[90m, plus 513[39m
    layer_4 = Dense(256 => 1, tanh_fast),  [90m# 257 parameters[39m
    layer_5 = Dense(1 => 10),           [90m# 20 parameters[39m
) [90m        # Total: [39m34_069 parameters,
[90m          #        plus [39m770 states, [90msummarysize [39m80 bytes.

In [4]:
# Initialise the model's parameters (`ps`) and states (`st`) from the seeded RNG.
# GPU transfer (`.|> gpu`) is left disabled here.
(ps, st) = Lux.setup(rng, model)

((layer_1 = (scale = Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_2 = (weight = Float32[-0.11034693 0.10973185 … 0.097955346 -0.009067461; -0.0111903995 0.07578978 … -0.03190492 0.08886787; … ; 0.01854451 -0.035003364 … -0.016294405 0.019076452; -0.09206565 -0.047390625 … -0.08859007 0.009517342], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (scale = Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.05381791 -0.103856824 … -0.050962884 0.020612676], bias = Float32[0.0;;]), layer_5 = (weight = Float32[-0.65478534; 0.61009777; … ; 0.41110995; 0.5493141;;], bias = Float32[0.0; 0.0; …

In [5]:
# Random Float32 placeholder input: 128 rows (the model's input width) by 2 columns.
# GPU transfer (`|> gpu`) is left disabled here.
x = rand(rng, Float32, (128, 2))

128×2 Matrix{Float32}:
 0.188564   0.4228
 0.683095   0.953174
 0.0598976  0.62799
 0.677622   0.564635
 0.0432115  0.228648
 0.645642   0.533853
 0.709369   0.0650043
 0.634036   0.0942084
 0.639628   0.828258
 0.559584   0.347723
 ⋮          
 0.870554   0.349935
 0.669238   0.635986
 0.504906   0.741774
 0.494614   0.238266
 0.951539   0.450495
 0.0595562  0.402075
 0.746626   0.212307
 0.884608   0.239166
 0.687504   0.82052

In [6]:
# Forward pass: returns the model output together with the updated layer states.
# `st` is overwritten so later cells see the refreshed state.
y, st = model(x, ps, st)

(Float32[0.63368976 -0.63368976; -0.59044194 0.59044194; … ; -0.39786503 0.39786503; -0.53161657 0.53161657], (layer_1 = (running_mean = Float32[0.03056821, 0.08181342, 0.03439437, 0.062112845, 0.013592961, 0.05897472, 0.038718663, 0.036412235, 0.07339431, 0.04536539  …  0.045826413, 0.061024453, 0.06526123, 0.06233399, 0.03664399, 0.0701017, 0.02308154, 0.047946673, 0.05618869, 0.07540121], running_var = Float32[0.9027433, 0.9036471, 0.9161364, 0.9006383, 0.90171933, 0.9006248, 0.9207603, 0.9145707, 0.90177906, 0.9022442  …  0.9199226, 0.9135522, 0.9000553, 0.90280527, 0.9032857, 0.9125522, 0.9058659, 0.9142748, 0.9208297, 0.9008846], training = Val{true}()), layer_2 = NamedTuple(), layer_3 = (running_mean = Float32[2.8312206f-8, 3.576279f-8, 2.9802322f-8, -1.7881394f-8, 5.3644182f-8, 0.0, -8.791685f-8, 8.940697f-9, -2.9802323f-9, -2.5331975f-8  …  2.0861625f-8, -2.3841858f-8, 8.34465f-8, 3.8556756f-8, -8.0093745f-9, 4.172325f-8, 1.3411045f-7, -1.5199184f-7, 7.897616f-8, -4.4703484f-8

In [7]:
# Gradients
## Use the pullback API so the forward pass also hands back the new state (`st_`)
(l, st_), pb = pullback(ps) do p
    return Lux.apply(model, x, p, st)
end
## Seed the pullback with ones for the output and `nothing` for the state;
## the first entry of the result is the gradient with respect to `ps`.
gs = first(pb((one.(l), nothing)))

(layer_1 = (scale = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_2 = (weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (scale = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], bias = Float32[0.008653244, -0.016698875, 0.023558283, 5.5808847f-5, -0.018153034, -0.023646543, -0.0073172706, 0.017697593, 0.011687537, 0.0076471795  …  -0.016878096, 0.009687504, -0.007905553, 0.018811973, 0.00070813333, -0.019242885, 0.001841197, 0.0038661156, -0.008194192, 0.0033142595]), layer_4 = (weight = Float32[-2.6453324f-9 7.155023f-10 … -1.3974826f-9 2.0005824f-9], bias = Float32[0.16078745;;]), layer_5 = (weight = Float32[0.0; 0.0; … ; 0.0; 0.0;;]

In [8]:
# Optimization
# Build the Adam optimiser state (learning rate 1e-4) for the parameter tree.
# `Optimisers.Adam` is the current spelling of the rule; the all-caps `ADAM`
# is a deprecated alias.
st_opt = Optimisers.setup(Optimisers.Adam(0.0001), ps)
# Apply one update step with the gradients `gs`; `update` returns the new
# optimiser state and the new parameters, in that order.
st_opt, ps = Optimisers.update(st_opt, ps, gs)

((layer_1 = (scale = [32mLeaf(Adam{Float64}(0.0001, (0.9, 0.999), 2.22045e-16), [39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.81, 0.998001))[32m)[39m, bias = [32mLeaf(Adam{Float64}(0.0001, (0.9, 0.999), 2.22045e-16), [39m(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.81, 0.998001))[32m)[39m), layer_2 = (weight = [32mLeaf(Adam{Float64}(0.0001, (0.9, 0.999), 2.22045e-16), [39m(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.81, 0.998001))[32m)[39m, bias = [32mLeaf(Adam{Float64}(0.0001