In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import rcParams
import warnings
warnings.filterwarnings('ignore')

import re
import jax.numpy as jnp
from jax import grad, jit, partial, jacrev
import ticktack
from ticktack import fitting_tinygp as fitting
from tqdm import tqdm
import scipy

# Default figure size for every plot in this notebook.
rcParams.update({'figure.figsize': (8.0, 4.0)})

In [2]:
# Configuration for the sinusoidal injection-recovery run.
MODEL_NAME = 'Guttler14'                   # presaved carbon box model to load
DATA_FILE = 'inject_recovery_sine.csv'     # relative to the notebook's working directory

cbm = ticktack.load_presaved_model(MODEL_NAME, production_rate_units='atoms/cm^2/s')
cf = fitting.CarbonFitter(cbm)
cf.load_data(DATA_FILE)
# Control-point production model; interp='gp' presumably selects Gaussian-process
# interpolation between control points (module is fitting_tinygp) — confirm.
cf.prepare_function(use_control_points=True, interp='gp')

In [3]:
@jit
def sine(t, start=None, mean=1.87, rel_amplitude=0.18, period=11, phase=jnp.pi / 2):
    """Sinusoidal production-rate signal, held constant before ``start``.

    Evaluates ``mean + rel_amplitude * mean * sin(2*pi/period * t + phase)``.
    For ``t < start``, it returns the value that expression takes at ``start``
    instead. The defaults reproduce the originally hard-coded signal.

    Parameters
    ----------
    t : array_like
        Time(s) at which to evaluate. Presumably years, matching ``cf.start``
        (confirm).
    start : float, optional
        Time before which the signal is held flat. ``None`` means use
        ``cf.start`` from the notebook-level fitter. It is read when the
        function is traced, so the cached compiled version does not see later
        changes to ``cf.start``.
    mean : float
        Mean level of the signal.
    rel_amplitude : float
        Amplitude as a fraction of ``mean``.
    period : float
        Period, in the units of ``t``.
    phase : float
        Phase offset in radians. The default ``pi/2`` puts a maximum at t = 0.

    Returns
    -------
    jax.Array
        Signal values with the broadcast shape of ``t``.
    """
    if start is None:
        start = cf.start

    def _signal(x):
        # Same operation order as the original expression, so values match exactly.
        return mean + rel_amplitude * mean * jnp.sin(2 * jnp.pi / period * x + phase)

    # jnp.where replaces the mask arithmetic `v*m + v0*(1-m)`; the values are
    # identical, and the intent (select, not blend) is explicit.
    return jnp.where(t >= start, _signal(t), _signal(start))

In [4]:
# Control-point parameter vector at which the likelihoods and gradient below are evaluated.
# NOTE(review): provenance is not shown here. The vector is hard-coded, presumably from
# an earlier fit. The values swing between ~1.54 and ~2.20 with a ~11-element period,
# which looks like the injected sine sampled at the control points. Confirm.
params = jnp.array([2.00967368, 1.83722005, 1.64131659, 1.53621338, 1.54956525,
       1.66283525, 1.82696321, 2.00391815, 2.14344128, 2.19709936,
       2.14684926, 2.00827427, 1.82531541, 1.65642762, 1.55562115,
       1.55525348, 1.65562488, 1.82481902, 2.00895719, 2.14935574,
       2.20119142, 2.14776001, 2.00581124, 1.82028481, 1.65001027,
       1.54909134, 1.54961388, 1.65149079, 1.82251794, 2.00851632,
       2.15057973, 2.20372406, 2.15117623, 2.00966834, 1.82418043,
       1.6536497 , 1.55222482, 1.55207972, 1.65320969, 1.82348198,
       2.00878203, 2.15020536, 2.20280941, 2.14981545, 2.00801343,
       1.82240925, 1.65185187, 1.55046371, 1.5504303 , 1.65178338,
       1.82232604, 2.00791393, 2.14964303, 2.20254582, 2.14988553,
       2.00837896, 1.82298691, 1.65263095, 1.55141962, 1.55149937,
       1.65286832, 1.82335548, 2.00883447, 2.15039692, 2.20306118,
       2.15007592, 2.00821152, 1.82244066, 1.65167066, 1.55005743,
       1.54980055, 1.65094448, 1.82131178, 2.00682247, 2.14862099,
       2.20176739, 2.1494566 , 2.00843379, 1.82361733, 1.65381228,
       1.55300543, 1.55320536, 1.65424806, 1.82381264, 2.00770239,
       2.147114  , 2.19746144, 2.14260138, 2.00020468, 1.81641772,
       1.65117851, 1.55888841, 1.57012058, 1.68018942, 1.84881622,
       2.00447377, 2.07130876, 2.05570146])

In [5]:
# Inspect the fitter's current offset; the next cell overrides it.
cf.offset

DeviceArray(18.66458541, dtype=float64)

In [6]:
# NOTE(review): zero the fitter's offset before the likelihood and gradient
# evaluations below. The previous cell displays it as ~18.66. This mutates `cf`
# in place and the reason is not shown; confirm these evaluations are meant to
# run with offset 0.
cf.offset = 0
cf.offset

0

In [7]:
# Evaluate each objective at `params` and print it under its label.
# The recorded output shows the first value equals the sum of the other two.
for label, objective in (
    ("total likelihood: ", cf.gp_likelihood),
    ("gp likelihood: ", cf.gp_neg_log_likelihood),
    ("chi2: ", cf.loss_chi2),
):
    print(label, objective(params))

total likelihood:  58.70041500024627
gp likelihood:  58.68090153408154
chi2:  0.019513466164721688


In [8]:
%%time
jacrev(cf.gp_likelihood)(params)

CPU times: user 2min 42s, sys: 5min 24s, total: 8min 7s
Wall time: 1min 10s


DeviceArray([-0.86029809, -0.10922936, -0.22054723, -0.29326931,
             -0.27633415, -0.20715559, -0.15855908, -0.08260692,
             -0.02556043,  0.00207064, -0.00817252, -0.05209339,
             -0.11467799, -0.17374297, -0.20879373, -0.20671075,
             -0.16714601, -0.10261586, -0.03443842,  0.01417858,
              0.02572187, -0.00554549, -0.07182303, -0.15335132,
             -0.22533212, -0.26479236, -0.25872197, -0.20787171,
             -0.12648843, -0.0385261 ,  0.03055298,  0.06090601,
              0.04517063, -0.00995191, -0.08538025, -0.1560441 ,
             -0.19884434, -0.19988519, -0.15905385, -0.08970396,
             -0.01459562,  0.0413668 ,  0.05945354,  0.03265323,
             -0.03170948, -0.11413205, -0.18973704, -0.23517436,
             -0.23692243, -0.19474167, -0.12253916, -0.04336755,
              0.01764172,  0.0411054 ,  0.01997701, -0.03900043,
             -0.11676348, -0.18840537, -0.23107636, -0.23126269,
             -0.18905052,