In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import numpy as np
from pyscf import dft, scf, gto, cc
from pyscfad import dft as dft_ad
from pyscfad import gto as gto_ad
from functools import partial
import equinox as eqx
import optax
jax.config.update("jax_enable_x64", True) #Enables 64 bit precision


from xcquinox import net
from xcquinox.net import load_xcquinox_model
from xcquinox.loss import compute_loss_mae
from xcquinox.train import Pretrainer, Optimizer
from xcquinox.utils import gen_grid_s, PBE_Fx, PBE_Fc, calculate_stats, lda_x, pw92c_unpolarized

# Get references

In [None]:
w27systems = {
    'OHmH2O4c4' : { 'atoms' : ['H', 'H', 'H', 'H', 'O', 'H', 'O', 'O', 'O', 'O', 'H', 'H', 'H', 'H'], 'coords' : [[-0.0267918, 1.5818251, 0.2834249], [1.5818251, 0.0267918, 0.2834249], [0.0267918, -1.5818251, 0.2834249], [-1.5818251, -0.0267918, 0.2834249], [0.0, 0.0, 0.9578999], [0.0, 0.0, 1.9177035], [0.0314199, 2.4185664, -0.2529905], [2.4185664, -0.0314199, -0.2529905], [-0.0314199, -2.4185664, -0.2529905], [-2.4185664, 0.0314199, -0.2529905], [0.843048, 2.2785225, -0.7493353], [2.2785225, -0.843048, -0.7493353], [-0.843048, -2.2785225, -0.7493353], [-2.2785225, 0.843048, -0.7493353]], 'charge' : -1, 'spin' : 0 },
    'H2O8d2d' : { 'atoms' : ['O', 'O', 'O', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'O', 'O', 'O', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H'], 'coords' : [[-1.4645552, -1.4645552, 1.3497067], [-1.4645552, 1.4645552, -1.3497067], [1.4645552, -1.4645552, -1.3497067], [1.4645552, 1.4645552, 1.3497067], [-1.5346654, -1.5346654, 0.3622554], [-1.5346654, 1.5346654, -0.3622554], [1.5346654, -1.5346654, -0.3622554], [1.5346654, 1.5346654, 0.3622554], [-2.0896636, -2.0896636, 1.7251112], [-2.0896636, 2.0896636, -1.7251112], [2.0896636, -2.0896636, -1.7251112], [2.0896636, 2.0896636, 1.7251112], [-1.4052886, -1.4052886, -1.3390082], [-1.4052886, 1.4052886, 1.3390082], [1.4052886, -1.4052886, 1.3390082], [1.4052886, 1.4052886, -1.3390082], [-1.5526158, -0.4620645, -1.5268989], [-1.5526158, 0.4620645, 1.5268989], [1.5526158, -0.4620645, 1.5268989], [-0.4620645, -1.5526158, -1.5268989], [0.4620645, 1.5526158, -1.5268989], [0.4620645, -1.5526158, 1.5268989], [-0.4620645, 1.5526158, 1.5268989], [1.5526158, 0.4620645, -1.5268989]], 'charge' : 0, 'spin' : 0 },
    'H2O20es' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[1.4449507, -2.3192362, 1.3852172], [0.9523409, -1.4842602, 1.5288696], [2.3694569, -2.1139495, 1.6116655], [1.4198689, -2.4626094, -1.4007113], [1.4412349, -2.4945724, -0.4194091], [0.7771617, -3.1468464, -1.6575718], [0.0068371, 0.0172196, 1.4027157], [0.3893324, 0.9065634, 1.5625403], [-0.9609503, 0.0503272, 1.5609821], [-2.7367283, -0.0717073, 1.3588421], [-2.7881947, -0.0946107, 0.3779034], [-2.9785985, -0.9717563, 1.6413949], [4.0454641, 1.4627919, -1.4612681], [4.7881071, 1.7639854, -1.9912934], [4.0333238, 0.4731882, -1.5313495], [-1.0008037, 4.2249605, 1.4982864], [-1.152189, 5.0082615, 2.0335617], [-1.8328395, 3.6856129, 1.5416327], [-3.2272404, 2.7492225, -1.5453255], [-3.8412573, 3.2117349, -2.121788], [-2.3703373, 3.2485931, -1.5888707], [0.061546, -0.0608668, -1.3504788], [0.5404095, -0.8892136, -1.5688996], [0.0525024, -0.0517312, -0.3661063], [1.3505765, 2.3624375, -1.4255501], [2.2943427, 2.2180386, -1.6165285], [0.9255618, 1.4877943, -1.5595341], [4.0797114, -1.2705898, 1.5101742], [4.8184302, -1.5307198, 2.0667554], [4.0088205, -0.2833084, 1.5792807], [-0.9118416, 4.114684, -1.3968277], [-0.8699572, 4.3289684, -0.4493559], [-0.0897473, 3.6305502, -1.5859742], [-2.6850195, -0.0424816, -1.4274401], [-2.9831436, 0.8560006, -1.6562177], [-1.7132338, -0.0438086, -1.5635337], [-3.1295491, -2.8665515, -1.3299376], [-3.1140901, -1.9191429, -1.5480976], [-3.3032413, -2.9025862, -0.3736924], [1.224434, 2.4780105, 1.3595377], [0.5812433, 3.1657978, 1.6076146], [1.272843, 2.5161003, 0.3793006], [-0.812813, -4.2331318, -1.5239325], [-0.9466607, -5.0213678, -2.0566534], [-1.663369, -3.7232425, -1.5543692], [-3.2498861, 
2.7544587, 1.3479831], [-3.1924627, 1.8120206, 1.5790401], [-3.4347595, 2.7666931, 0.3928574], [4.0373837, -1.2313563, -1.3812425], [3.2335259, -1.7195098, -1.6240819], [4.1604036, -1.4012254, -0.4312943], [-3.094015, -2.885118, 1.5661544], [-2.2180185, -3.3510614, 1.5943976], [-3.6898242, -3.3867826, 2.1285951], [-0.7332536, -4.1659349, 1.3697764], [0.0709231, -3.6538397, 1.5619151], [-0.6879895, -4.3597798, 0.4176959], [3.9098072, 1.4164687, 1.4257218], [3.0660482, 1.8495739, 1.636771], [4.0554225, 1.5928403, 0.4801529]], 'charge' : 0, 'spin' : 0 },
    'H2O20fs' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[-0.4674337, -1.4156432, 2.3424832], [-0.5103881, -0.5327785, 2.7383509], [-1.154371, -1.4135419, 1.6390602], [-1.1804953, -4.411489, -2.0217879], [-1.6005951, -5.0804995, -2.5684061], [-1.6296589, -4.4443491, -1.1306922], [-0.3890822, -4.3416186, 2.3649217], [-0.4122318, -3.4254374, 2.6780297], [0.495447, -4.4516812, 1.9766411], [-2.370827, -1.4010748, 0.3317388], [-2.7687053, -0.5190604, 0.3764393], [-1.9008133, -1.4312776, -0.5314561], [-1.0412215, -1.5773429, -2.0898648], [-0.0740952, -1.5478577, -1.9182976], [-1.2210828, -2.5028893, -2.3334269], [1.6403031, -4.3519848, -1.6030968], [0.6907637, -4.4791178, -1.78646], [1.8234428, -3.4332488, -1.8548591], [-2.332464, -4.3191108, 0.3710663], [-1.6749165, -4.4161499, 1.0945673], [-2.663344, -3.4130595, 0.4587135], [1.6787401, -1.4676579, -1.5895498], [1.8042491, -1.4683808, -0.6135222], [1.9691284, -0.5921861, -1.8871045], [2.0177788, -1.5700244, 1.1457336], [1.1343262, -1.5367783, 1.5764315], [2.3187166, -2.4899023, 1.2541941], [2.1824622, -4.3930741, 1.0311736], [2.8554348, -5.0463468, 1.2384494], [2.0031126, -4.4627308, 0.0573709], [-0.4152401, 4.3842505, 2.2876138], [-0.5679834, 5.06537, 2.9479304], [-1.1127701, 4.487654, 1.6038264], [-1.1046614, 1.2760042, -2.0786808], [-1.2737935, 0.3860585, -2.4278927], [-1.5421398, 1.2973941, -1.1984656], [-0.3343417, 1.482161, 2.3157837], [0.5348035, 1.4101401, 1.8606636], [-0.3797162, 2.3986009, 2.6294913], [-2.2749913, 4.3768946, 0.233996], [-1.7988517, 4.3462894, -0.6427551], [-2.9812564, 5.0213482, 0.1371479], [-0.9933823, 4.1431644, -2.0843627], [-0.0298852, 4.2701065, -2.0022747], [-1.1087883, 3.2216193, -2.3688039], [1.6344386, 
1.3954288, -1.6918288], [0.6638061, 1.3608461, -1.8412994], [1.8936812, 2.3076806, -1.9040298], [-2.3215084, 1.4785163, 0.3929266], [-2.6250091, 2.4001087, 0.3878313], [-1.6192612, 1.4525318, 1.0812056], [1.8145415, 4.2491098, -1.649936], [2.4020909, 4.8900295, -2.0585139], [1.9809374, 4.28431, -0.6721155], [2.1692949, 4.1320169, 1.0124794], [2.3810198, 3.2011114, 1.193355], [1.3372868, 4.3008339, 1.4866278], [2.0958108, 1.2695995, 1.0227267], [2.4183264, 0.3768526, 1.2284758], [1.9353623, 1.2702628, 0.0520364]], 'charge' : 0, 'spin' : 0 },
    'H3OpH2O' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'H'], 'coords' : [[-1.1997281, 0.0290201, 0.115747], [-1.6314241, -0.7333199, 0.5284065], [-1.6940017, 0.2956236, -0.6724225], [1.1997281, -0.0290201, 0.115747], [0.0, 0.0, 0.056538], [1.6940017, -0.2956236, -0.6724225], [1.6314241, 0.7333199, 0.5284065]], 'charge' : 1, 'spin' : 0 },
    'OHmH2O4cs' : { 'atoms' : ['O', 'O', 'H', 'H', 'O', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'H', 'H'], 'coords' : [[0.7941268, -0.1580495, 2.0660393], [-0.026599, -1.4686255, 0.0], [0.8535437, 1.317568, 0.7632695], [-0.0492092, 0.0556879, 2.4738562], [0.7060196, 1.9061093, 0.0], [0.8535437, 1.317568, -0.7632695], [0.7941268, -0.1580495, -2.0660393], [-0.0492092, 0.0556879, -2.4738562], [0.5253029, -0.7667811, -1.2898742], [-2.0377301, 0.3905888, 0.0], [-1.3911533, 1.1099929, 0.0], [-1.4545847, -0.4079553, 0.0], [0.5253029, -0.7667811, 1.2898742], [-0.0434809, -2.4269609, 0.0]], 'charge' : -1, 'spin' : 0 },
    'H3Op' : { 'atoms' : ['O', 'H', 'H', 'H'], 'coords' : [[0.0, 0.0, -0.2080105], [0.4697566, 0.8136422, 0.0693368], [0.4697566, -0.8136422, 0.0693368], [-0.9395131, 0.0, 0.0693368]], 'charge' : 1, 'spin' : 0 },
    'OHmH2O2' : { 'atoms' : ['H', 'O', 'H', 'H', 'O', 'H', 'O', 'H'], 'coords' : [[-1.4579587, 0.0040843, 0.0244594], [-2.4291204, 0.0733055, -0.3263151], [0.0, 0.0, 1.4813683], [-2.4564986, 0.9567436, -0.699417], [0.0, 0.0, 0.5211771], [1.4579587, -0.0040843, 0.0244594], [2.4291204, -0.0733055, -0.3263151], [2.4564986, -0.9567436, -0.699417]], 'charge' : -1, 'spin' : 0 },
    'H2O2' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[-1.3305065, 0.9990741, 0.0], [-1.0635455, 1.9210871, 0.0], [-0.5051326, 0.4919421, 0.0], [1.0052144, -0.7512682, 0.0], [0.9469851, -1.3304176, -0.7655752], [0.9469851, -1.3304176, 0.7655752]], 'charge' : 0, 'spin' : 0 },
    'H2O6c' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[-2.126243, 0.9197498, 1.8264958], [-2.1881222, 1.4620146, 2.6158476], [-1.5986358, 1.4265182, 1.1737716], [-0.6242004, -1.3533166, 1.1757103], [-1.1988054, -0.6794287, 1.5877759], [0.2729833, -0.9963075, 1.2645231], [1.8287624, 0.1754817, 0.6422639], [2.0392499, -0.2029455, -0.2401564], [2.6632512, 0.2979341, 1.1013749], [-0.8422995, -0.8438539, -1.4490439], [-1.5454762, -1.316182, -1.9016515], [-0.8613423, -1.1366364, -0.4986366], [-0.4537151, 1.8328264, -0.1915392], [-0.7281657, 1.1638105, -0.8364546], [0.4013688, 1.5054304, 0.1303423], [1.8922043, -0.9099854, -1.8706972], [2.1537551, -0.3720142, -2.6222492], [0.9154307, -0.9730955, -1.9076768]], 'charge' : 0, 'spin' : 0 },
    'H2O20' : { 'atoms' : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H'], 'coords' : [[1.5972952, 3.6342929, 0.1165628], [3.4570031, 1.5457807, 0.8715044], [-0.57424, 3.488956, 1.7956375], [0.6295601, 2.9030991, -2.4803288], [2.562887, 0.2380423, 2.9758015], [-0.0061254, 1.3030904, 3.6250687], [3.6694758, -0.1920224, -1.2507983], [-2.6588176, 2.7244506, 0.3711345], [1.9365179, 0.5910602, -3.3888087], [-1.9819526, 2.4899264, -2.3389977], [2.0356643, -2.4018164, 2.1787274], [-1.8845792, -0.6687213, 3.2997576], [2.7932149, -2.7984187, -0.3020458], [-3.7135037, 0.2090327, 1.2262852], [0.1062725, -1.293281, -3.6485737], [-2.3812958, -0.2589936, -3.0530727], [-0.7279685, -2.9024551, 2.5515924], [0.5941514, -3.461153, -1.784129], [-3.4834248, -1.6651538, -0.76021], [-1.6647964, -3.4265588, -0.0388993], [0.8366572, 3.6498, 0.7406949], [1.9869746, 4.5126507, 0.1363602], [2.8531582, 2.2552477, 0.5991012], [3.5564427, 0.9528663, 0.092497], [-0.860346, 4.2574733, 2.2971151], [-1.3676155, 3.1980612, 1.2597142], [1.0058172, 3.1354291, -1.6140348], [1.1176189, 2.1179309, -2.7977873], [3.2218852, 0.3339645, 3.6687595], [2.917311, 0.736237, 2.1825883], [0.8868802, 0.970391, 3.4227382], [-0.1583775, 2.0479994, 3.0212474], [3.3730831, -1.0735747, -0.971529], [3.1050897, 0.053996, -2.005135], [-2.4656562, 2.6610551, -0.5848248], [-3.0302259, 1.8642237, 0.6349499], [1.2448937, -0.1226649, -3.5032296], [2.3685947, 0.68444, -4.2421884], [-2.3988056, 3.1036186, -2.9499344], [-0.9990854, 2.6509609, -2.4091562], [1.089413, -2.5673533, 2.3485271], [2.2199554, -1.4937765, 2.4901053], [-2.5067142, -0.3437261, 2.6291954], [-1.2285821, 0.0523333, 3.4406625], [3.5521661, -3.3866115, -0.2562954], [2.5019631, -2.6479144, 0.6464344], [-4.6222885, 
0.2638838, 1.5345724], [-3.6996256, -0.472578, 0.5153993], [0.251188, -2.0333038, -3.037317], [-0.7962964, -0.9523111, -3.4593251], [-2.2861386, 0.6788768, -2.8115617], [-2.7753977, -0.6983972, -2.2822631], [-0.9190239, -3.559496, 3.2264945], [-1.1717496, -2.0542602, 2.8528846], [0.7227644, -4.3091097, -2.2179802], [1.4210061, -3.2760638, -1.2803666], [-4.2486672, -2.1741634, -1.0420466], [-2.790722, -2.3358781, -0.4884945], [-0.8844205, -3.4528625, -0.6209771], [-1.3284624, -3.2865515, 0.8681973]], 'charge' : 0, 'spin' : 0 },
    'OHmH2O' : { 'atoms' : ['H', 'H', 'O', 'O', 'H'], 'coords' : [[-0.0913883, -0.2448128, 0.0130826], [1.5272733, 0.3383121, 0.6525543], [-1.2208378, -0.2356182, 0.0613147], [1.2462132, -0.220611, -0.0771381], [-1.4612604, 0.3627299, -0.6498135]], 'charge' : -1, 'spin' : 0 },
    'H2O6b' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[0.1144839, 1.5112333, 0.843513], [0.9588057, 1.471558, 0.3295002], [0.261291, 2.1241091, 1.5685011], [-0.0281704, -1.4370823, 0.9851182], [-0.0383127, -0.486232, 1.1718938], [-0.8642499, -1.5983938, 0.5055884], [-2.3537981, 1.3280181, -0.4618208], [-1.4871423, 1.5019404, -0.0453109], [-2.3472534, 1.794165, -1.3012567], [-2.4676711, -1.4576027, -0.4185461], [-2.5397542, -0.4850372, -0.4964169], [-3.26029, -1.745642, 0.0408062], [2.4264495, 1.14816, -0.4958637], [2.4838349, 0.1618456, -0.5091494], [2.5261922, 1.4338488, -1.4071863], [2.2762914, -1.5490049, -0.4262672], [1.4225637, -1.6349539, 0.0649769], [2.9167296, -2.0809295, 0.0519201]], 'charge' : 0, 'spin' : 0 },
    'H3OpH2O3' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'H'], 'coords' : [[0.4488552, 2.4506609, -0.112838], [0.5365645, 3.1082364, 0.5858367], [0.9587367, 2.7678608, -0.865725], [-2.3467622, -0.8366105, -0.112838], [-2.8764061, -0.55364, -0.865725], [-2.9600939, -1.0894397, 0.5858367], [1.897907, -1.6140505, -0.112838], [2.4235294, -2.0187966, 0.5858367], [1.9176694, -2.2142207, -0.865725], [0.0, 0.0, 0.4980372], [0.7206738, -0.6586452, 0.2267139], [0.2100666, 0.9534444, 0.2267139], [-0.9307403, -0.2947992, 0.2267139]], 'charge' : 1, 'spin' : 0 },
    'H2O5' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[-1.1268024, -1.5421147, 1.3115375], [-0.7661903, -2.0208888, 2.0619128], [-1.0425473, -0.5875811, 1.5375504], [0.0803781, -2.0408214, -1.0866271], [-0.5297873, -2.5676517, -1.6076619], [-0.3615744, -1.8901965, -0.2198678], [0.8653564, 2.1577269, -0.0051264], [0.9896838, 1.5118344, -0.7375001], [1.7477174, 2.4024394, 0.2841736], [-0.7980209, 1.1055424, 1.8833964], [-0.1925383, 1.5158527, 1.2242825], [-1.5800472, 1.6620636, 1.9082724], [1.1675335, 0.3103577, -1.9923084], [0.7679385, 0.5364348, -2.8354387], [0.7789005, -0.5529975, -1.7265952]], 'charge' : 0, 'spin' : 0 },
    'OHmH2O6' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H'], 'coords' : [[0.2309966, -1.7501603, -1.6137318], [0.3547704, -1.4268652, -2.5095198], [-0.4863599, -1.1511869, -1.2247679], [-0.8187277, -1.8096257, 1.5257334], [-0.3247737, -2.4623293, 1.0212858], [-1.1694889, -1.1908734, 0.8176224], [-0.9356309, 2.0631218, 1.0559169], [-1.2826168, 1.254229, 0.5973567], [-0.4402953, 2.4787494, 0.3372993], [1.0346387, 0.3062362, 2.1888664], [0.4723616, -0.4881286, 2.1220177], [0.4232868, 1.0226881, 1.9283641], [0.2656542, 1.7374371, -1.6823817], [1.0715576, 1.3171856, -1.3358043], [-0.4373017, 1.1167995, -1.3960602], [2.2703738, -0.0203258, -0.4001794], [1.9574443, 0.1070627, 0.5173217], [1.69198, -0.7185752, -0.7539803], [-1.4854286, -0.1445915, -0.4478202], [-2.3924405, -0.2408477, -0.7475386]], 'charge' : -1, 'spin' : 0 },
    'H2O3' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[1.1847029, 1.1150792, -0.0344641], [0.4939088, 0.9563767, 0.6340089], [2.0242676, 1.0811246, 0.4301417], [-1.1469443, 0.0697649, 1.1470196], [-1.2798308, -0.5232169, 1.8902833], [-1.0641398, -0.4956693, 0.356925], [-0.1633508, -1.0289346, -1.2401808], [0.4914771, -0.3248733, -1.0784838], [-0.5400907, -0.8496512, -2.1052499]], 'charge' : 0, 'spin' : 0 },
    'H2O8s4' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[1.9905911, -0.1084003, -1.4652291], [1.3302928, -0.8467868, -1.5304519], [2.7035395, -0.3103452, -2.07624], [-0.1084003, -1.9905911, 1.4652291], [-0.3103452, -2.7035395, 2.07624], [-0.8467868, -1.3302928, 1.5304519], [-1.9387532, -0.0255756, 1.4016092], [-2.1724666, 0.0263608, 0.4583695], [-1.398824, 0.7670441, 1.5644647], [0.0255756, -1.9387532, -1.4016092], [-0.0263608, -2.1724666, -0.4583695], [-0.7670441, -1.398824, -1.5644647], [-0.0255756, 1.9387532, -1.4016092], [0.7670441, 1.398824, -1.5644647], [0.0263608, 2.1724666, -0.4583695], [0.1084003, 1.9905911, 1.4652291], [0.8467868, 1.3302928, 1.5304519], [0.3103452, 2.7035395, 2.07624], [-1.9905911, 0.1084003, -1.4652291], [-1.3302928, 0.8467868, -1.5304519], [-2.7035395, 0.3103452, -2.07624], [1.9387532, 0.0255756, 1.4016092], [2.1724666, -0.0263608, 0.4583695], [1.398824, -0.7670441, 1.5644647]], 'charge' : 0, 'spin' : 0 },
    'H2O' : { 'atoms' : ['O', 'H', 'H'], 'coords' : [[0.0, 0.0, -0.3893611], [0.7629844, 0.0, 0.1946806], [-0.7629844, 0.0, 0.1946806]], 'charge' : 0, 'spin' : 0 },
    'H3OpH2O2' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'H'], 'coords' : [[0.4277117, 0.1671946, 2.0994891], [0.9102225, -0.5137364, 2.580987], [0.2026631, 0.8688028, 2.7211374], [0.4277117, 0.1671946, -2.0994891], [0.9102225, -0.5137364, -2.580987], [0.2026631, 0.8688028, -2.7211374], [-0.8926666, -0.107072, 0.0], [-0.3409152, -0.0421313, 0.8806907], [-0.3409152, -0.0421313, -0.8806907], [-1.5066975, -0.8531875, 0.0]], 'charge' : 1, 'spin' : 0 },
    'OHm' : { 'atoms' : ['O', 'H'], 'coords' : [[0.0, 0.0, -0.4822198], [0.0, 0.0, 0.4822198]], 'charge' : -1, 'spin' : 0 },
    'H3OpH2O6OHm' : { 'atoms' : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H'], 'coords' : [[-1.2924523, 1.9431423, 1.2226497], [0.0, 0.0, 2.191018], [2.3290367, 0.1477254, 1.2226497], [-1.0365844, -2.0908677, 1.2226497], [-0.1173738, -2.3666396, -1.1811466], [0.0, 0.0, -2.12815], [2.1082569, 1.0816711, -1.1811466], [-1.9908831, 1.2849685, -1.1811466], [-1.6121332, 1.7108041, 0.2902261], [-2.0329921, 2.3268088, 1.6977145], [-2.7471836, 0.6969876, -1.2446979], [-0.9985795, -2.9240272, 1.6977145], [-0.6755332, -2.2515503, 0.2902261], [-1.2036468, 0.7746315, -1.6091], [-0.5575433, 0.790631, 1.8329416], [-0.4059349, -0.8781622, 1.8329416], [1.9772008, 2.030637, -1.2446979], [-0.0690272, -1.4297045, -1.6091], [0.9634782, 0.0875312, 1.8329416], [0.7699828, -2.7276246, -1.2446979], [0.0, 0.0, -3.0886303], [1.272674, 0.6550729, -1.6091], [3.0315716, 0.5972183, 1.6977145], [2.2876664, 0.5407463, 0.2902261]], 'charge' : 0, 'spin' : 0 },
    'H3OpH2O62d' : { 'atoms' : ['O', 'H', 'O', 'H', 'O', 'O', 'H', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'O', 'O', 'H', 'H', 'H', 'H', 'H', 'H'], 'coords' : [[1.8434133, -0.9596003, -1.4867718], [1.2656736, -0.7391231, -2.2349178], [1.0793305, -0.5271854, 1.0282008], [1.4019284, -0.7017831, 0.1069993], [0.0, 0.0, -3.5118996], [-1.8434133, 0.9596003, -1.4867718], [-1.2656736, 0.7391231, -2.2349178], [-1.0793305, 0.5271854, 1.0282008], [-1.4019284, 0.7017831, 0.1069993], [0.0, 0.0, 1.033447], [-2.366676, 1.7244726, -1.7418206], [2.366676, -1.7244726, -1.7418206], [-1.7700161, 0.0335496, 1.5315311], [1.7700161, -0.0335496, 1.5315311], [-2.9914118, -0.724108, 2.4424819], [2.9914118, 0.724108, 2.4424819], [0.4334218, 0.6365549, -4.0940392], [-0.4334218, -0.6365549, -4.0940392], [-3.1797916, -1.6653136, 2.5011409], [-3.4352688, -0.3059062, 3.1864217], [3.4352688, 0.3059062, 3.1864217], [3.1797916, 1.6653136, 2.5011409]], 'charge' : 1, 'spin' : 0 },
    'H2O6c2' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[1.8630162, 1.9699095, 0.1190794], [0.9056974, 2.1787084, 0.023876], [2.1523045, 2.4204188, 0.916064], [0.7744836, -2.5983741, 0.1190794], [1.0199919, -3.0741598, 0.916064], [1.4339681, -1.8737112, 0.023876], [2.6374998, -0.6284646, -0.1190794], [2.3396656, 0.3049972, -0.023876], [3.1722964, -0.653741, -0.916064], [-1.8630162, -1.9699095, -0.1190794], [-2.1523045, -2.4204188, -0.916064], [-0.9056974, -2.1787084, -0.023876], [-0.7744836, 2.5983741, -0.1190794], [-1.0199919, 3.0741598, -0.916064], [-1.4339681, 1.8737112, -0.023876], [-2.6374998, 0.6284646, 0.1190794], [-3.1722964, 0.653741, 0.916064], [-2.3396656, -0.3049972, 0.023876]], 'charge' : 0, 'spin' : 0 },
    'H2O4' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[-1.4564924, 1.2872093, -0.0073689], [-0.492967, 1.4796927, -0.0098606], [-1.8194199, 1.7493731, 0.7521436], [-1.2872093, -1.4564924, 0.0073689], [-1.4796927, -0.492967, 0.0098606], [-1.7493731, -1.8194199, -0.7521436], [1.4564924, -1.2872093, -0.0073689], [0.492967, -1.4796927, -0.0098606], [1.8194199, -1.7493731, 0.7521436], [1.2872093, 1.4564924, 0.0073689], [1.4796927, 0.492967, 0.0098606], [1.7493731, 1.8194199, -0.7521436]], 'charge' : 0, 'spin' : 0 },
    'H3OpH2O63d' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'H'], 'coords' : [[-0.8017039, -2.1786143, 1.1619204], [-0.9501754, -2.0789803, 0.2031405], [-1.4859153, -2.7513027, 1.5209001], [1.6520557, 0.0477159, -1.5591326], [2.3313221, 0.019446, -2.2418055], [1.0863703, -0.7331799, -1.6915753], [-0.7847047, -1.4545801, -1.5591326], [-1.1488203, -2.0287072, -2.2418055], [-1.1781375, -0.5742343, -1.6915753], [2.2875872, 0.3950112, 1.1619204], [2.2755374, 0.2166141, 0.2031405], [3.1256557, 0.088811, 1.5209001], [-1.4858833, 1.7836031, 1.1619204], [-1.6397404, 2.6624917, 1.5209001], [-1.325362, 1.8623661, 0.2031405], [-0.867351, 1.4068643, -1.5591326], [-1.1825018, 2.0092611, -2.2418055], [0.0917672, 1.3074142, -1.6915753], [0.0, 0.0, 2.2211451], [-0.361831, -0.8848668, 1.8661706], [0.9472326, 0.1290786, 1.8661706], [-0.5854016, 0.7557882, 1.8661706]], 'charge' : 1, 'spin' : 0 },
    'H2O6' : { 'atoms' : ['O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H', 'O', 'H', 'H'], 'coords' : [[1.428258, -0.4003671, 1.4106584], [0.4962869, -0.7311723, 1.5266043], [1.9346614, -0.7027163, 2.168276], [0.7372918, 1.9893172, -0.0695132], [1.0637087, 1.4457976, 0.6652451], [1.0989464, 1.538598, -0.8448107], [1.6255509, -0.5260892, -1.3633321], [2.3546166, -0.9528444, -1.820267], [1.7656466, -0.6667515, -0.4068662], [-1.9174627, 1.2454847, -0.0758847], [-1.0195282, 1.6453398, -0.0861304], [-2.5435506, 1.9684754, -0.1615432], [-1.2037871, -1.3111993, -1.377411], [-0.2717808, -1.1292498, -1.5755794], [-1.6090141, -0.4363975, -1.2823355], [-1.1315777, -1.1588986, 1.4243135], [-1.1917048, -1.4941891, 0.5062038], [-1.6165614, -0.3231376, 1.3623725]], 'charge' : 0, 'spin' : 0 },
    'OHmH2O5' : { 'atoms' : ['H', 'H', 'H', 'O', 'O', 'O', 'O', 'H', 'H', 'H', 'H', 'O', 'H', 'O', 'H', 'H', 'H'], 'coords' : [[1.5012774, -0.3881987, -0.0687443], [0.069545, -2.003145, -0.0936974], [-1.2974869, -0.3758535, 0.320736], [0.0939922, 1.9749663, -1.1702309], [1.9370204, -0.2434579, -0.9535726], [-0.1555232, -2.513446, -0.9107403], [-1.9616924, -0.0600197, -0.3289319], [0.7757186, 1.2824186, -1.2650483], [1.5614874, -0.972249, -1.4627696], [-0.9487013, -2.0487143, -1.2036386], [-1.4718232, 0.6336937, -0.7997962], [0.3330052, -0.7184359, 1.1400861], [0.5549234, -1.1890187, 1.9465594], [-0.1124113, 1.8190878, 1.7342148], [0.0845176, 2.1340983, -0.2091616], [-1.0736008, 1.82259, 1.7483154], [0.1097518, 0.845684, 1.5764199]], 'charge' : -1, 'spin' : 0 },
    'OHmH2O3' : { 'atoms' : ['H', 'O', 'H', 'H', 'H', 'H', 'O', 'O', 'H', 'O', 'H'], 'coords' : [[1.8305203, -1.8968462, -0.6939628], [2.2022879, -1.1519545, -0.2166142], [1.3805632, -0.6848564, 0.1503098], [0.0, 0.0, 1.6205929], [-0.0971786, 1.538031, 0.1503098], [0.7274569, 2.5337002, -0.6939628], [-0.1035221, 2.4832145, -0.2166142], [0.0, 0.0, 0.6602086], [-1.2833846, -0.8531746, 0.1503098], [-2.0987658, -1.33126, -0.2166142], [-2.5579772, -0.636854, -0.6939628]], 'charge' : -1, 'spin' : 0 },
    'H2O20fc' : { 'atoms' : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H'], 'coords' : [[-0.0040675, 1.9370846, -5.6837488], [-2.0081019, -0.1385428, -5.7351775], [-0.1385428, 2.0081019, 5.7351775], [-1.9370846, -0.0040675, 5.6837488], [-1.9737684, 0.0227542, 2.8081295], [1.9737684, -0.0227542, 2.8081295], [-1.9687741, 0.0215608, -2.8939543], [1.9687741, -0.0215608, -2.8939543], [0.0040675, -1.9370846, -5.6837488], [2.0081019, 0.1385428, -5.7351775], [0.1385428, -2.0081019, 5.7351775], [1.9370846, 0.0040675, 5.6837488], [1.9809313, 0.0891778, -0.0443893], [0.0227542, 1.9737684, -2.8081295], [0.0215608, 1.9687741, 2.8939543], [-0.0215608, -1.9687741, 2.8939543], [-0.0891778, 1.9809313, 0.0443893], [0.0891778, -1.9809313, 0.0443893], [-0.0227542, -1.9737684, -2.8081295], [-1.9809313, -0.0891778, -0.0443893], [-0.795323, 1.3838243, -5.803811], [-0.0529878, 2.253347, -4.7672247], [-1.3302275, -0.8605673, -5.7890316], [-2.6672294, -0.3202269, -6.4099495], [-0.3202269, 2.6672294, 6.4099495], [-0.8605673, 1.3302275, 5.7890316], [-1.3838243, -0.795323, 5.803811], [-2.253347, -0.0529878, 4.7672247], [-2.3675338, -0.0049056, 1.9217175], [-1.3512621, 0.7814913, 2.7945978], [2.3675338, 0.0049056, 1.9217175], [1.3512621, -0.7814913, 2.7945978], [-2.293172, -0.0316149, -3.8093685], [-1.3496505, 0.782936, -2.8870264], [1.3496505, -0.782936, -2.8870264], [2.293172, 0.0316149, -3.8093685], [0.0529878, -2.253347, -4.7672247], [0.795323, -1.3838243, -5.803811], [1.3302275, 0.8605673, -5.7890316], [2.6672294, 0.3202269, -6.4099495], [0.8605673, -1.3302275, 5.7890316], [0.3202269, -2.6672294, 6.4099495], [1.3838243, 0.795323, 5.803811], [2.253347, 0.0529878, 4.7672247], [1.3357225, 0.8279582, -0.0422844], [2.3528107, 0.0783222, 
-0.9411357], [-0.0049056, 2.3675338, -1.9217175], [0.7814913, 1.3512621, -2.7945978], [0.782936, 1.3496505, 2.8870264], [-0.0316149, 2.293172, 3.8093685], [0.0316149, -2.293172, 3.8093685], [-0.782936, -1.3496505, 2.8870264], [-0.8279582, 1.3357225, 0.0422844], [-0.0783222, 2.3528107, 0.9411357], [0.0783222, -2.3528107, 0.9411357], [0.8279582, -1.3357225, 0.0422844], [0.0049056, -2.3675338, -1.9217175], [-0.7814913, -1.3512621, -2.7945978], [-2.3528107, -0.0783222, -0.9411357], [-1.3357225, -0.8279582, -0.0422844]], 'charge' : 0, 'spin' : 0 },
}

In [None]:
# Keys of `w27systems` that are processed in this notebook: the neutral clusters
# 'H2O', 'H2O2', ..., 'H2O6' followed by the two ions.
relevant_systems = ['H2O'] + ['H2O{}'.format(n) for n in range(2, 7)] + ['OHm', 'H3Op']
# System name -> reference energy; every entry starts at 0 and is overwritten once
# the corresponding calculation has been run.
cc_ens = dict.fromkeys(relevant_systems, 0)

## Calculate CCSD energies for the relevant systems
These energies serve as the reference values that the networks are trained against.

In [None]:
# Fill `cc_ens` with the CCSD total energy of every system in `relevant_systems`.
# NOTE(review): no basis set is passed to gto.M, so PySCF falls back to its default
# basis -- confirm that this is the intended level of theory for the references.
for sys_name in relevant_systems:
    print(sys_name)
    entry = w27systems[sys_name]
    if len(entry['atoms']) != len(entry['coords']):
        raise ValueError('{}: {} atoms but {} coordinate triples'.format(
            sys_name, len(entry['atoms']), len(entry['coords'])))
    # One "<element> <x> <y> <z>" line per atom.
    atstr = ''.join(
        '{} {} {} {}\n'.format(at, *xyz)
        for at, xyz in zip(entry['atoms'], entry['coords'])
    )
    mol = gto.M(atom=atstr, charge=entry['charge'], spin=entry['spin'])
    mol.build()
    mol.max_memory = 32000
    # Restricted Hartree-Fock reference wavefunction for the coupled-cluster step.
    print('HF calculation...')
    mf_ccc = scf.RHF(mol)
    mf_ccc.kernel()
    print('CCSD calculation...')
    this_cc = cc.CCSD(mf_ccc)
    this_cc.kernel()
    # Total CCSD energy of the system.
    cc_en = this_cc.e_tot
    cc_ens[sys_name] = cc_en

## Select the two training systems
For the training, we use the water monomer $H_2O$, the water dimer $(H_2O)_2$ (stored under the key `H2O2`), and the reaction energy between them.

In [None]:
# Reference value for the reaction energy 2*E('H2O') - E('H2O2'), expressed in the same
# units as the total energies stored in `cc_ens`.
# TODO: record where this number comes from.
RXN_ENERGY_REF = 0.007926693227091634

# Differentiable (pyscfad) molecules used for training, and their references.
# `rxnrefs` ends up one entry longer than `rxnmols`: one total energy per molecule,
# followed by the reaction-energy reference.
rxnmols = []
rxnrefs = []
for sys_name in ['H2O', 'H2O2']:
    print(sys_name)
    entry = w27systems[sys_name]
    if len(entry['atoms']) != len(entry['coords']):
        raise ValueError('{}: {} atoms but {} coordinate triples'.format(
            sys_name, len(entry['atoms']), len(entry['coords'])))
    # One "<element> <x> <y> <z>" line per atom.
    atstr = ''.join(
        '{} {} {} {}\n'.format(at, *xyz)
        for at, xyz in zip(entry['atoms'], entry['coords'])
    )
    mol = gto_ad.Mole(atom=atstr, charge=entry['charge'], spin=entry['spin'])
    mol.build()
    mol.max_memory = 32000
    rxnmols.append(mol)
    rxnrefs.append(cc_ens[sys_name])
rxnrefs.append(RXN_ENERGY_REF)

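# rxnrefs contains the reference total energies of H2O and of the water dimer (H2O)2 (key 'H2O2'), followed by the reference energy of the reaction 2 H2O -> (H2O)2, evaluated as 2*E(H2O) - E((H2O)2)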

# Train our networks

## Loss functions

In [None]:
def _rks_total_energies(model, mols):
    """Return the RKS total energy of each molecule, using `model` as the XC functional.

    Every molecule gets its own `dft_ad.RKS` object with grid level 1 and a custom
    'GGA' functional evaluated by `eval_xc_gga_j2` with `xcmodel=model`.

    Args:
        model: XC model forwarded to `eval_xc_gga_j2` as `xcmodel`.
        mols: iterable of molecules accepted by `dft_ad.RKS`.

    Returns:
        List of `mf.e_tot` values, in the same order as `mols`.
    """
    custom_eval_xc = partial(eval_xc_gga_j2, xcmodel=model)
    energies = []
    for mol in mols:
        mf = dft_ad.RKS(mol)
        mf.grids.level = 1
        mf.define_xc_(custom_eval_xc, 'GGA')
        mf.kernel()
        energies.append(mf.e_tot)
    return energies


def _reaction_loss(model, mols, refs, rxn_weight=1, include_energy=True):
    """Shared implementation of the three reaction-energy losses below.

    The reaction energy is `2*E(mols[0]) - E(mols[1])`, so the ordering of `mols`
    matters (in this notebook: H2O first, then the 'H2O2' entry).

    Args:
        model: XC model being optimised.
        mols: the two molecules of the reaction, in the order described above.
        refs: `refs[i]` is the reference total energy of `mols[i]`; `refs[-1]` is the
            reference reaction energy. In this notebook the total-energy references
            are the CCSD energies collected in `cc_ens`.
        rxn_weight: factor applied to the absolute reaction-energy error.
        include_energy: when True, the mean absolute total-energy error over `mols`
            is added to the weighted reaction term.

    Returns:
        Scalar loss.
    """
    energies = _rks_total_energies(model, mols)
    rxn_h = 2*energies[0] - energies[1]
    total_loss = rxn_weight*abs(rxn_h - refs[-1])
    if include_energy:
        energy_loss = sum(abs(en - refs[idx]) for idx, en in enumerate(energies))
        total_loss = total_loss + energy_loss/len(mols)
    # Add an axis and index it away again so the returned value is a scalar array.
    return total_loss[..., jnp.newaxis][0]


@eqx.filter_value_and_grad
def opt_loss_erxn(model, mols, refs):
    """Reaction-energy error plus mean absolute total-energy error.

    Decorated with `eqx.filter_value_and_grad`, so a call yields both the loss and
    its gradient with respect to `model`. See `_reaction_loss` for the arguments.
    """
    return _reaction_loss(model, mols, refs, rxn_weight=1, include_energy=True)


@eqx.filter_value_and_grad
def opt_loss_erxn_10(model, mols, refs):
    """Same as `opt_loss_erxn`, but with the reaction-energy error weighted by 10.

    Decorated with `eqx.filter_value_and_grad`, so a call yields both the loss and
    its gradient with respect to `model`. See `_reaction_loss` for the arguments.
    """
    return _reaction_loss(model, mols, refs, rxn_weight=10, include_energy=True)


@eqx.filter_value_and_grad
def opt_loss_rxn(model, mols, refs):
    """Absolute reaction-energy error only; total energies are not penalised.

    Decorated with `eqx.filter_value_and_grad`, so a call yields both the loss and
    its gradient with respect to `model`. See `_reaction_loss` for the arguments.
    """
    return _reaction_loss(model, mols, refs, rxn_weight=1, include_energy=False)

def eval_xc_gga_j2(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None,
                   xcmodel=None):
    """Evaluate a model-based GGA functional in the (exc, vxc, fxc, kxc) layout.

    Only `rho` and `xcmodel` are read; the remaining parameters are accepted so the
    signature matches what the SCF driver passes, and are otherwise ignored (first
    and second derivatives are always computed, whatever `deriv` is).

    Args:
        rho: density array. `rho[:4]` must unpack into either four components
            (density and the three gradient components) or two components
            (density and a single gradient component). Each component is assumed
            to be 1-D over grid points -- TODO confirm against the caller.
        xcmodel: callable mapping one `(rho, sigma)` pair to a scalar; it is
            vmapped over grid points and differentiated with respect to its input.

    Returns:
        Tuple `(exc, vxc, fxc, kxc)`:
        exc -- model output divided by the density;
        vxc -- `(d/drho, d/dsigma, None, None)` of the model output;
        fxc -- second derivatives `(rho-rho, rho-sigma, sigma-sigma)` followed by
               seven `None` placeholders;
        kxc -- always `None`.

    Raises:
        ValueError: if `xcmodel` is None, or if `rho[:4]` has neither four nor two
            components.
    """
    if xcmodel is None:
        raise ValueError('eval_xc_gga_j2 requires an xcmodel')

    # Only the unpacking is guarded: any other failure should surface unchanged.
    try:
        rho0, dx, dy, dz = rho[:4]
    except ValueError:
        rho0, drho = rho[:4]
        sigma = jnp.array(drho**2)
    else:
        sigma = jnp.array(dx**2+dy**2+dz**2)
    rho0 = jnp.array(rho0)

    # One (rho, sigma) row per grid point.
    rhosig = jnp.stack([rho0, sigma], axis=1)

    # The model output is divided by the density to obtain exc.
    exc = jnp.array(jax.vmap(xcmodel)(rhosig))/rho0

    # First derivatives with respect to rho (column 0) and sigma (column 1).
    vrho_f = eqx.filter_grad(xcmodel)
    vrhosigma = jnp.array(jax.vmap(vrho_f)(rhosig))
    vxc = (vrhosigma[:, 0], vrhosigma[:, 1], None, None)

    # Second derivatives: a 2x2 Hessian in (rho, sigma) per grid point.
    v2_f = jax.hessian(xcmodel)
    v2 = jnp.array(jax.vmap(v2_f)(rhosig))
    # Order: v2rho2, v2rhosigma, v2sigma2, v2lapl2, vtau2, v2rholapl, v2rhotau,
    #        v2lapltau, v2sigmalapl, v2sigmatau
    fxc = (v2[:, 0, 0], v2[:, 0, 1], v2[:, 1, 1],
           None, None, None, None, None, None, None)

    # Third-order derivatives are not provided.
    kxc = None

    return exc, vxc, fxc, kxc

class RXCModel(eqx.Module):
    """Exchange-correlation model built from an exchange and a correlation network.

    `xnet` scales `lda_x(rho)` and `cnet` scales `pw92c_unpolarized(rho)`; the sum is
    multiplied by `rho`.
    """
    # Network whose output multiplies lda_x(rho).
    xnet: eqx.Module
    # Network whose output multiplies pw92c_unpolarized(rho).
    cnet: eqx.Module

    def __init__(self, xnet, cnet):
        self.cnet = cnet
        self.xnet = xnet

    def __call__(self, inputs):
        """Evaluate the model for a single `(rho, sigma)` input.

        Args:
            inputs: array whose first two entries are rho and sigma (assumed 1-D --
                TODO confirm).

        Returns:
            `rho` times the first element of the flattened combination
            `lda_x(rho)*xnet(...) + pw92c_unpolarized(rho)*cnet(...)`. The caller
            divides by rho where the per-particle quantity is needed.
        """
        rho, sigma = inputs[0], inputs[1]  # sigma is read but not used directly below
        # A trailing axis is appended and each network is vmapped over the leading
        # axis, so the networks are applied to every entry of `inputs` individually.
        # NOTE(review): confirm that evaluating rho and sigma separately is intended.
        per_entry = inputs[..., jnp.newaxis]
        x_term = lda_x(rho)*jax.vmap(self.xnet)(per_entry)
        c_term = pw92c_unpolarized(rho)*jax.vmap(self.cnet)(per_entry)
        combined = x_term + c_term
        return rho*combined.flatten()[0]

## Load the pretrained networks and perform the training

In [None]:
# Locations of the pretrained exchange (Fx) and correlation (Fc) networks,
# relative to the working directory of the notebook.
path_to_xnet = 'pretrained_sara/GGA_FxNet_G_d3_n16_s42_v_8000'
path_to_cnet = 'pretrained_sara/GGA_FcNet_G_d3_n16_s42_v_8000'
# Load both networks (exchange first) and wrap them in a single XC model.
xnet, cnet = (load_xcquinox_model(path) for path in (path_to_xnet, path_to_cnet))
xc_model = RXCModel(xnet=xnet, cnet=cnet)


In [None]:
# Optimizer settings
OPT_INIT_LR = 1e-2
OPT_END_LR = 1e-5
OPTSTEPS = 100
OPTDECAYBEGIN = 30
# OPTDECAYRATE = 0.95
scheduler = optax.linear_schedule(
    init_value = OPT_INIT_LR,
    transition_steps = OPTSTEPS-OPTDECAYBEGIN,
    transition_begin = OPTDECAYBEGIN,
    end_value = OPT_END_LR,
)

In [None]:
# The model and the optax optimizers are defined here so that this cell runs on a
# fresh kernel (Restart & Run All) using only objects created earlier in the notebook.
# NOTE(review): Adam driven by `scheduler` is an assumption -- replace it with the
# intended optimizer if a different one was used.
test_model = xc_model
or_opt = optax.adam(learning_rate=scheduler)
or10_opt = optax.adam(learning_rate=scheduler)
oro_opt = optax.adam(learning_rate=scheduler)

# One Optimizer per loss function; all three start from the same model, molecules
# and references and run for OPTSTEPS steps.
# Reaction-energy error + mean total-energy error.
opt_rxn_snmxc = Optimizer(model=test_model, optim=or_opt, mols=rxnmols, refs=rxnrefs,
                          loss=opt_loss_erxn, print_every=1, steps=OPTSTEPS)
# Reaction-energy error weighted by 10 + mean total-energy error.
opt_rxn_snmxc_10 = Optimizer(model=test_model, optim=or10_opt, mols=rxnmols, refs=rxnrefs,
                             loss=opt_loss_erxn_10, print_every=1, steps=OPTSTEPS)
# Reaction-energy error only.
opt_rxnonly_snmxc = Optimizer(model=test_model, optim=oro_opt, mols=rxnmols, refs=rxnrefs,
                              loss=opt_loss_rxn, print_every=1, steps=OPTSTEPS)