# Transient Model
This model simulates the sulphate and pyrite concentrations and isotopic compositions in a sediment column. The column first evolves through a long period of stasis; a large pile of sediment is then dumped on top and the subsequent evolution is monitored. Requires the `py-pde` package.

### Imports and Definitions

In [1]:
import numpy as np
import pandas as pd
import ipywidgets as wdg
from IPython.display import display
from plotly_default import go, graph_config, sel_trace
from plotly.subplots import make_subplots
import pde
import random

In [2]:
# Conversion between isotope ratios and delta notation

VCDT = 0.0441626  # 34S/32S ratio of the Vienna Canyon Diablo Troilite standard


def delta_to_R(delta, std=VCDT):
    '''Convert a d34S value (per mil) into the corresponding 34S/32S ratio.

    `std` is the isotope ratio of the reference standard (VCDT by default).
    '''
    fractional_offset = delta/1000
    return std * (fractional_offset + 1)


def R_to_delta(R, std=VCDT):
    '''Convert a 34S/32S ratio into a d34S value (per mil) relative to `std`.

    `std` is the isotope ratio of the reference standard (VCDT by default).
    '''
    relative_ratio = R/std
    return 1000*(relative_ratio - 1)

def wt_to_M(wt, Mr=55.845, density=1.7):
    '''Convert a weight percent into a molar concentration.

    wt      : weight percent of the substance in the sediment
    Mr      : molar mass of the substance (defaults to that of iron)
    density : sediment density in g/cm³ (defaults to 1.7)
    '''
    # Mass of substance per dm³ of sediment:
    #   1000 cm³/dm³ * density * wt% / 100  =  10 * density * wt%
    grams_per_dm3 = 10. * density * wt
    return grams_per_dm3 / Mr

def M_to_wt(M, Mr=55.845, density=1.7):
    '''Convert a molar concentration into a weight percent.

    M       : molar concentration of the substance in the sediment
    Mr      : molar mass of the substance (defaults to that of iron)
    density : sediment density in g/cm³ (defaults to 1.7)
    '''
    grams_per_dm3 = M * Mr
    # 10 * density is the mass (in g) of 1 wt% of a dm³ of sediment
    return grams_per_dm3 / (10. * density)

def ratio_to_concs(R, C):
    '''Split a total concentration into its two components.

    Takes a ratio of two concentrations R = a/b
    and a total concentration C = a + b
    and returns an (a, b) tuple.

    Written as b = C/(1+R), a = R*b so that R = 0 (none of species `a`)
    gives (0, C) instead of dividing by zero.
    '''
    b = C/(1 + R)
    a = R * b
    return (a, b)

# The number of seconds in 1 thousand years (1 ka), taking a year as 365.24 days
s_in_ka = 1000 * 365.24 * 24 * 60 * 60

### Modelling Pyrite Formation as a Function of Depth and Time

$$
\frac{\partial [^{34}\mathrm{SO}_4]}{\partial t} = D \nabla^2 [^{34}\mathrm{SO}_4] - 2 \frac{\partial [^{34}\mathrm{FeS}_2]}{\partial t}
$$
$$
\frac{\partial [^{32}\mathrm{SO}_4]}{\partial t} = D \nabla^2 [^{32}\mathrm{SO}_4] - 2 \frac{\partial [^{32}\mathrm{FeS}_2]}{\partial t}
$$
$$
\frac{\partial [^{34}\mathrm{FeS}_2]}{\partial t} = \frac{1}{2} k_{MSR} \frac{[^{34}\mathrm{SO}_4]^2 + [^{32}\mathrm{SO}_4][^{34}\mathrm{SO}_4]}{[^{32}\mathrm{SO}_4]/\alpha_{MSR} + [^{34}\mathrm{SO}_4]}
$$
$$
\frac{\partial [^{32}\mathrm{FeS}_2]}{\partial t} = \frac{1}{2} k_{MSR} \frac{[^{32}\mathrm{SO}_4]^2 + [^{32}\mathrm{SO}_4][^{34}\mathrm{SO}_4]}{\alpha_{MSR}[^{34}\mathrm{SO}_4] + [^{32}\mathrm{SO}_4]}
$$

In [3]:
# Define globally-used variables
# These are module-level defaults shared by the functions below; several of
# them are reassigned when the model is run.

# Create a 1D grid
# Default grid of 20 points spanning [-1, 0]; set_up_pdes() replaces it with a
# grid of the requested resolution.
# n_points = 60
# max_depth = 1
grid = pde.CartesianGrid([[-1,0]], [20], periodic=[False])
# Coordinates along the single axis; run_model() overwrites this with the
# coordinates multiplied by max_z.
zs = grid.coordinate_arrays[0]

# Time steps
# max_time = 10 # ka
# delta_time = 1e-3 # ka: timestep for calculations
# Time represented by one step of the time slider; recomputed by run_model()
# and add_sediment() from max_t.
time_slider_resolution = 1e-2 # ka

# Set up memory to store PDE results
# (replaced by a fresh MemoryStorage at the start of every solve)
state_storage = pde.MemoryStorage()

# Dictionary for processed results
# Filled by calculate_delta() with the keys 'SO4', 'py', 'd34S_SO4', 'd34S_py'.
processed_results = {}

In [4]:
# Define the PDEs for the general diagenetic equation

class DiageneticEqs(pde.PDEBase):
    """General Diagenetic Equations modelling sulphate to pyrite conversion.

    The state handed to `evolution_rate` holds four fields, in this order:
    34S-sulphate, 32S-sulphate, 34S-pyrite and 32S-pyrite concentrations.
    Sulphate diffuses, is advected with the sedimentation velocity and is
    consumed by pyrite formation; pyrite is advected and produced.
    """

    def __init__(self, SO4_0, d34S_SO4_0, epsilon_MSR, D_prime, k_prime, v_prime):
        """Store the coefficients and derive boundary values.

        SO4_0       : total sulphate concentration imposed at the upper boundary
        d34S_SO4_0  : d34S of that sulphate
        epsilon_MSR : offset subtracted from d34S_SO4_0 to obtain the isotope
                      ratio used for the MSR fractionation factor
        D_prime     : coefficient of the diffusion (laplacian) term
        k_prime     : rate constant of pyrite production
        v_prime     : coefficient of the advection (gradient) term
        """
        # Initialise the py-pde base class before adding our own attributes
        super().__init__()

        self.D_prime = D_prime
        self.k_prime = k_prime
        self.v_prime = v_prime

        # Split the boundary sulphate into its 34S and 32S concentrations
        R_SO4_0 = delta_to_R(d34S_SO4_0)
        self.SO4_34_0 , self.SO4_32_0 = ratio_to_concs(R_SO4_0, SO4_0)

        self.a_MSR = delta_to_R(d34S_SO4_0 - epsilon_MSR)/R_SO4_0 # MSR fractionation factor

        # Define boundary conditions: [lower, upper]
        # Zero derivative at the lower boundary, fixed value at the upper one
        self.SO4_32_lapbcs = [[{'derivative': 0}, {'value': self.SO4_32_0}]]
        self.SO4_34_lapbcs = [[{'derivative': 0}, {'value': self.SO4_34_0}]]

        self.SO4_32_gradbcs = [[{'derivative': 0}, {'value': self.SO4_32_0}]]
        self.SO4_34_gradbcs = [[{'derivative': 0}, {'value': self.SO4_34_0}]]
        self.py_gradbcs = [[{'derivative': 0}, {'value': 0}]]

    def evolution_rate(self, state, t=0):
        """Return the time derivatives of the four fields, in state order."""
        SO4_34, SO4_32, py_34, py_32 = state

        # Attempt to smooth out discontinuities when adding sediment with non-zero background sedimentation
        if t == 0:
            v_sed = 0
        else:
            v_sed = self.v_prime

        # Calculate pyrite production rate for each isotope
        py_34_prod = 0.5 * self.k_prime * (SO4_34**2 + SO4_32*SO4_34)/((1/self.a_MSR)*SO4_32 + SO4_34)
        py_32_prod = 0.5 * self.k_prime * (SO4_32**2 + SO4_32*SO4_34)/(self.a_MSR*SO4_34 + SO4_32)

        # Pyrite: advection plus production
        py_34_t = + v_sed*py_34.gradient(self.py_gradbcs) + py_34_prod
        py_32_t = + v_sed*py_32.gradient(self.py_gradbcs) + py_32_prod

        # Sulphate: diffusion plus advection, minus two sulphate per pyrite formed
        SO4_34_t = self.D_prime * SO4_34.laplace(self.SO4_34_lapbcs) + v_sed*SO4_34.gradient(self.SO4_34_gradbcs) - 2 * py_34_prod
        SO4_32_t = self.D_prime * SO4_32.laplace(self.SO4_32_lapbcs) + v_sed*SO4_32.gradient(self.SO4_32_gradbcs) - 2 * py_32_prod

        return pde.FieldCollection([SO4_34_t, SO4_32_t, py_34_t, py_32_t])
        
        

In [5]:
def set_up_pdes(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR, v, max_t, dt, max_z, dz):
    '''Build the grid, the equations and the initial fields for a model run.

    Replaces the module-level `grid` with one of max_z/dz points on [-1, 0],
    scales the coefficients by max_t and max_z, and returns `(fields, eqs)`:
    a FieldCollection of [34SO4, 32SO4, 34py, 32py] initialised to the
    boundary sulphate concentration and zero pyrite, and the DiageneticEqs
    instance. `dt` is accepted so the widget values can be passed with `**`
    but is not used here.

    Raises ValueError if max_z/dz rounds to fewer than one grid point.
    '''
    global grid

    # Round rather than truncate: a quotient such as 0.3/0.1 evaluates to
    # 2.9999999999999996 and would otherwise lose a grid point.
    n_points = int(round(max_z/dz))
    if n_points < 1:
        raise ValueError(f'max_z/dz must give at least one grid point (got max_z={max_z}, dz={dz})')

    grid = pde.CartesianGrid([[-1,0]], [n_points], periodic=[False])

    # Calculate scaled parameters
    D_prime = D_SO4*max_t/max_z**2
    v_prime = v*max_t/max_z
    k_prime = k_MSR*max_t

    eqs = DiageneticEqs(SO4_0, d34S_SO4_0, epsilon_MSR, D_prime, k_prime, v_prime)

    # Set initial field values for sulphate
    SO4_34_field = pde.ScalarField(grid, [eqs.SO4_34_0]*n_points)
    SO4_32_field = pde.ScalarField(grid, [eqs.SO4_32_0]*n_points)

    # For pyrite conc.
    py_34_field = pde.ScalarField(grid, [0]*n_points)
    py_32_field = pde.ScalarField(grid, [0]*n_points)

    fields = pde.FieldCollection([SO4_34_field, SO4_32_field, py_34_field, py_32_field])

    return fields, eqs

In [6]:
def create_sliders(file='./Controls/Transient_params_v1.csv', prefix='w', n_cols=2):
    '''Build the parameter widgets described in a CSV file.

    The CSV needs the columns var_name, value, start, stop, step and
    display_name; each row becomes one FloatSlider. Four FloatText boxes for
    max time, timestep, max depth and resolution are appended.

    file   : path of the CSV file describing the sliders
    prefix : unused; kept so existing calls keep working
    n_cols : number of sliders placed side by side in each row

    Returns `(w_dict, box)`: a dict mapping each variable name to its widget,
    and a VBox containing all the widgets.
    '''
    w_df = pd.read_csv(file)
    w_dict = {}

    # Store each widget straight in the dictionary. Writing to vars() inside a
    # function and reading it back is not reliable: from Python 3.13 every
    # call returns an independent snapshot.
    for _, row in w_df.iterrows():
        w_dict[row['var_name']] = wdg.FloatSlider(value=row['value'],
                                                  min=row['start'],
                                                  max=row['stop'],
                                                  step=row['step'],
                                                  description=row['display_name'],
                                                  layout=wdg.Layout(width='100%')
                                                  )

    # Arrange widgets into HBoxes and VBoxes
    w_list = list(w_dict.values())
    box_list = [wdg.HBox(w_list[i:i+n_cols]) for i in range(0, len(w_list), n_cols)]

    # Create boxes for max time, depth and resolution
    w_max_t = wdg.FloatText(value=4., description='Max time (ka)', style={'description_width': 'initial'})
    w_dt = wdg.FloatText(value=4e-3, description='Timestep (ka)', style={'description_width': 'initial'})
    w_max_z = wdg.FloatText(value=1., description='Max depth (m)', style={'description_width': 'initial'})
    w_dz = wdg.FloatText(value=0.05, description='Resolution (m)', style={'description_width': 'initial'})

    box_list.append(wdg.HBox([w_max_t, w_dt, w_max_z, w_dz]))

    w_dict['max_t'] = w_max_t
    w_dict['dt'] = w_dt
    w_dict['max_z'] = w_max_z
    w_dict['dz'] = w_dz

    return w_dict, wdg.VBox(box_list)

In [7]:
def calculate_delta(storage):
    '''Turn the stored isotope-specific fields into totals and d34S values.

    `storage` is a sequence of states; each state holds, in order, the 34SO4,
    32SO4, 34py and 32py fields. Returns a dict of 2D arrays indexed
    [time, depth] with the keys 'SO4', 'py', 'd34S_SO4' and 'd34S_py'.
    '''
    shape = (len(storage), len(storage[0][0].data))

    # One output array per quantity
    total_SO4 = np.empty(shape)
    total_py = np.empty(shape)
    delta_SO4 = np.empty(shape)
    delta_py = np.empty(shape)

    for step, state in enumerate(storage):
        so4_34, so4_32 = state[0].data, state[1].data
        pyr_34, pyr_32 = state[2].data, state[3].data

        total_SO4[step] = so4_34 + so4_32
        total_py[step] = pyr_34 + pyr_32

        # Divisions by zero (e.g. where no pyrite exists yet) are allowed to
        # produce nan/inf silently rather than warn or raise
        with np.errstate(all='ignore'):
            delta_SO4[step] = R_to_delta(so4_34/so4_32)
            delta_py[step] = R_to_delta(pyr_34/pyr_32)

    return {'SO4': total_SO4, 'py': total_py, 'd34S_SO4': delta_SO4, 'd34S_py': delta_py}

In [8]:
def create_depth_plot():
    '''Create the two-panel depth profile figure with empty traces.

    The left panel carries the sulphate and pyrite concentration traces, the
    right panel their d34S traces; both panels share the depth axis.
    '''
    subplots = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.01)
    fig = go.FigureWidget(subplots)

    # (legend name, colour, subplot column), in the order the traces are added
    trace_specs = (
        ('[SO<sub>4</sub>]', 'RoyalBlue', 1),
        ('[FeS<sub>2</sub>]', 'MediumSeaGreen', 1),
        ('δ<sup>34</sup>S<sub>SO4</sub>', 'Navy', 2),
        ('δ<sup>34</sup>S<sub>py</sub>', 'DarkGreen', 2),
    )
    for label, colour, column in trace_specs:
        fig.add_trace(go.Scatter(name=label, marker_color=colour, mode='lines'), row=1, col=column)

    fig.update_layout(
        yaxis_title='Depth, <i>z</i> (m)',
        xaxis_rangemode='tozero',
        hovermode='y',
        margin=dict(t=60),
        width=1000, height=600
        )

    # One x-axis title per panel
    axis_titles = (
        (1, 'Concentration (mol dm<sup>−3</sup>)'),
        (2, 'δ<sup>34</sup>S (‰ VCDT)'),
    )
    for column, title in axis_titles:
        fig.update_xaxes(title_text=title, row=1, col=column)

    return fig

def create_cross_plot():
    '''Create the pyrite content vs. d34S cross plot with one empty marker trace.

    Marker colours are tied to a shared colour axis labelled as depth.
    '''
    points = go.Scatter(mode='markers',
                        marker_line={'color': 'Black', 'width': 1},
                        marker_coloraxis='coloraxis')

    depth_colours = dict(colorscale='viridis', colorbar_ticks='outside',
                         colorbar_thickness=15, colorbar_title='<i>z</i> (m)')

    fig = go.FigureWidget()
    fig.add_trace(points)

    # (xaxis_rangemode='tozero' is deliberately left disabled)
    fig.update_layout(xaxis_title='[Fe<sub>py</sub>] (wt%)',
                      yaxis_title='δ<sup>34</sup>S<sub>py</sub> (‰ VCDT)',
                      coloraxis=depth_colours,
                      margin_t=60,
                      height=400, width=500)

    return fig

def pde_plot(time):
    '''Redraw both figures for the stored state at slider index `time`.

    Reads the module-level `processed_results`, `zs` and
    `time_slider_resolution`, and updates `time_out`, `depth_plot` and
    `cross_plot`. Does nothing until a model run has produced results.
    '''
    # The slider can be moved before the model has been run
    if not processed_results:
        return

    # Never index past the last stored state
    time = min(time, len(processed_results['SO4']) - 1)

    kas = time*time_slider_resolution
    time_out.value=f'Current Stasis Time: <b>{kas:.2f}</b> ka'

    # Trace name in the depth plot -> key of the series it displays
    depth_traces = {
        '[SO<sub>4</sub>]': 'SO4',
        '[FeS<sub>2</sub>]': 'py',
        'δ<sup>34</sup>S<sub>SO4</sub>': 'd34S_SO4',
        'δ<sup>34</sup>S<sub>py</sub>': 'd34S_py',
    }

    with depth_plot.batch_update():
        for trace_name, key in depth_traces.items():
            trace = next(depth_plot.select_traces({'name': trace_name}))
            trace.x = processed_results[key][time]
            trace.y = zs

    with cross_plot.batch_update():
        tr = next(cross_plot.select_traces())
        tr.x = M_to_wt(processed_results['py'][time])
        tr.y = processed_results['d34S_py'][time]
        tr.text = [f'z={z:.3f}' for z in zs]
        tr.marker.color = zs
        
        
def run_model(button=None):
    '''Solve the model from its initial state using the current widget values.

    Rebuilds the equations and grid, resets the result storage, adjusts the
    time sliders to the new time range, solves, post-processes the results and
    redraws the plots. Progress messages go to `model_out`. `button` is the
    widget passed by the click handler and is not used.
    '''
    global state_storage, time_slider_resolution, processed_results, zs

    model_out.clear_output()
    with model_out:
        print('Processing...  ', end='')
        pde_args = {par: slider.value for par, slider in w_dict.items()}
        max_t = pde_args['max_t']

        fields, eqs = set_up_pdes(**pde_args)

        # Reset the storage
        state_storage = pde.MemoryStorage()

        # Calculate timesteps (the equations are solved on a time range of 1)
        delta_time = pde_args['dt']/max_t
        time_slider_resolution = 10**round(np.log10(max_t) - 3) # So one step on slider will be a power of 10 years

        # Set upper limit of time slider. The slider limits are integers, so
        # round explicitly instead of letting a quotient such as 699.9999...
        # be truncated on assignment.
        n_steps = int(round(max_t/time_slider_resolution))
        time_slider.max = play.max = time_return.max = n_steps
        tracker_time = time_slider_resolution/max_t

        # Solve the equations
        eqs.solve(fields, t_range=1, dt=delta_time, tracker=state_storage.tracker(tracker_time))

        # calculate total concentrations and d34S
        processed_results = calculate_delta(state_storage)

        # calculate z values
        zs = grid.coordinate_arrays[0] * pde_args['max_z']

        # Refresh the graph
        depth_plot.update_layout(yaxis_range=[-pde_args['max_z'], 0])
        cross_plot.update_layout(coloraxis_cmin=-pde_args['max_z'], coloraxis_cmax=0)
        pde_plot(time_slider.value)
        print('Done!', end='')
        

def add_sediment(button=None):
    '''Dump a layer of homogenised sediment on top of the column and re-solve.

    Takes the stored state at the current time-slider position, shifts it down
    by the thickness chosen on `sed_slider`, fills the top with fresh sulphate
    and with pyrite averaged from the reworked part of the old column
    (`rew_slider`), then solves again and redraws the plots at the
    `time_return` position. Progress messages go to `sed_out`. `button` is the
    widget passed by the click handler and is not used.
    '''
    global state_storage, time_slider_resolution, processed_results

    sed_out.clear_output()
    with sed_out:
        print('Processing... ', end='')

        # There is no state to bury until the model has been run once
        if len(state_storage) == 0:
            print('Please run the model first!', end='')
            return

        SO4_34, SO4_32, py_34, py_32 = state_storage[time_slider.value].data

        N = grid.shape[0]
        dz = w_dict['dz'].value
        reworked = rew_slider.value

        # Number of grid points of new sediment to add. Rounded so that e.g.
        # 0.35/0.05 = 6.999... counts as 7, and capped at the column height.
        thickness = min(int(round(sed_slider.value/dz)), N)
        del_from_top = int(thickness*reworked) # number of points to delete from top of old distribution
        del_from_bottom = thickness - del_from_top # number of points to delete from bottom of old distribution

        # Average the composition of reworked pyrite.
        # With nothing reworked the slice [-0:] would be the whole column, so
        # that case is handled explicitly: the new layer carries no pyrite.
        if del_from_top > 0:
            py_34_0 = py_34[-del_from_top :].mean() * reworked
            py_32_0 = py_32[-del_from_top :].mean() * reworked
        else:
            py_34_0 = py_32_0 = 0.0

        # Set up a new set of pdes for the next run
        pde_args = {par: slider.value for par, slider in w_dict.items()}
        max_t = pde_args['max_t']
        _ , eqs = set_up_pdes(**pde_args)

        # Add new arrays with values of homogenised sediment
        new_SO4_34 = np.full(N, eqs.SO4_34_0)
        new_SO4_32 = np.full(N, eqs.SO4_32_0)
        new_py_34 = np.full(N, py_34_0)
        new_py_32 = np.full(N, py_32_0)

        # Add in the old sediment in the appropriate place
        kept = slice(del_from_bottom, N - del_from_top)
        new_SO4_34[0 : N - thickness] = SO4_34[kept]
        new_SO4_32[0 : N - thickness] = SO4_32[kept]
        new_py_34[0 : N - thickness] = py_34[kept]
        new_py_32[0 : N - thickness] = py_32[kept]

        new_fields = pde.FieldCollection([
            pde.ScalarField(grid, new_SO4_34),
            pde.ScalarField(grid, new_SO4_32),
            pde.ScalarField(grid, new_py_34),
            pde.ScalarField(grid, new_py_32),
        ])

        # Solve the pdes again
        # Reset the storage
        state_storage = pde.MemoryStorage()

        # Calculate timesteps (the equations are solved on a time range of 1)
        delta_time = pde_args['dt']/max_t
        time_slider_resolution = 10**round(np.log10(max_t) - 3) # So one step on slider will be a power of 10 years

        # Set upper limit of time slider; rounded because the limits are integers
        n_steps = int(round(max_t/time_slider_resolution))
        time_slider.max = play.max = time_return.max = n_steps
        tracker_time = time_slider_resolution/max_t

        # Solve the equations
        eqs.solve(new_fields, t_range=1, dt=delta_time, tracker=state_storage.tracker(tracker_time))

        # calculate total concentrations and d34S
        processed_results = calculate_delta(state_storage)

        # Refresh the graph
        time_slider.value = time_return.value
        pde_plot(time_slider.value)

        print('Done!', end='')

In [9]:
# Build the user interface: parameter widgets, figures, time controls and the
# buttons that run the model / add sediment.
w_dict, sliders = create_sliders()

# Initial limits for the time and thickness controls, taken from the widget defaults
max_time = w_dict['max_t'].value
max_depth = w_dict['max_z'].value

# Empty figures; pde_plot() fills in the data
depth_plot = create_depth_plot()
cross_plot = create_cross_plot()


# Time controls: one slider step represents time_slider_resolution ka.
# The play control's value is linked to the time slider's value.
play = wdg.Play(value=0, min=0, max=int(max_time/time_slider_resolution), step=1, interval=400)
time_slider = wdg.IntSlider(value=0, min=0, max=int(max_time/time_slider_resolution), step=1, layout=wdg.Layout(width='100%'))
wdg.jslink((play, 'value'), (time_slider, 'value'))

# Label showing the currently displayed time (written by pde_plot)
time_out = wdg.HTML(value='Please run the model first!', layout=wdg.Layout(width='30%'))

# "Run Model" button and the output area for its progress messages
run_button = wdg.Button(description='Run Model')
model_out = wdg.Output()

# Controls read by add_sediment(): thickness of the new layer, proportion of it
# that is reworked old sediment, and the slider position to show afterwards
sed_slider = wdg.FloatSlider(value=0.5, min=0, max=max_depth, step=0.01, description='Thickness (m)', layout=wdg.Layout(width='100%'), style={'description_width': 'initial'})
rew_slider = wdg.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='Proportion Reworked', layout=wdg.Layout(width='100%'), style={'description_width': 'initial'})
time_return = wdg.IntSlider(value=0, min=0, max=int(max_time/time_slider_resolution), step=1, description='Return time to', layout=wdg.Layout(width='100%'), style={'description_width': 'initial'})

# "Add Sediment" button and the output area for its progress messages
sed_button = wdg.Button(description='Add Sediment')
sed_out = wdg.Output()


# Connect the buttons to their handlers
run_button.on_click(run_model)
sed_button.on_click(add_sediment)

# Bind pde_plot to the time slider (presumably so that moving the slider
# redraws the plots; the returned widget itself is not displayed)
wdg.interactive(pde_plot, time=time_slider)


# Lay everything out; as the last expression of the cell this is what gets displayed
wdg.VBox([
    sliders, 
    wdg.HBox([run_button, model_out]),
    wdg.HBox([play, time_slider, time_out]),
    wdg.HBox([depth_plot, cross_plot]),
    wdg.HBox([sed_slider, rew_slider, time_return]),
    wdg.HBox([sed_button, sed_out])
])

VBox(children=(VBox(children=(HBox(children=(FloatSlider(value=0.0105, description='[SO<sub>4</sub>]<sub>0</su…

In [10]:
# for _ in range(10):
#     sed_slider.value = random.uniform(0, 0.5)
#     rew_slider.value = random.uniform(0.25, 0.75)
#     time_return.value = random.uniform(0,400)
    # add_sediment()

In [None]:
# depth_plot.write_image('./images/transient_depth_2.pdf')

In [None]:
# cross_plot.write_image('./images/transient_cross_2.pdf')