# Transient Model
This model aims to simulate sulphate and pyrite concentrations, together with their sulphur isotopic compositions (δ³⁴S), in sediment after a long period of stasis, after which a large pile of sediment is dumped on top and the subsequent evolution is monitored. Requires the `py-pde` package (plus `ipywidgets` and `plotly` for the interactive display).

### Imports and Definitions

In [1]:
import numpy as np
import pandas as pd
import ipywidgets as wdg
from IPython.display import display
from plotly_default import go, graph_config, sel_trace
from plotly.subplots import make_subplots
import pde

In [2]:
# Conversion between isotope ratios and delta notation

# 34S/32S ratio of the Vienna Cañon Diablo Troilite (V-CDT) reference standard;
# used as the default standard by the delta <-> ratio conversions below
VCDT = 0.0441626

def delta_to_R(delta, std=VCDT):
    '''Convert a d34S value (permil relative to ``std``) into a 34S/32S ratio.'''
    fraction = delta / 1000
    return std * (1 + fraction)

def R_to_delta(R, std=VCDT):
    '''Convert a 34S/32S ratio into d34S (permil relative to ``std``).'''
    relative = R / std
    return 1000 * (relative - 1)

def wt_to_M(wt, Mr=55.845, density=1.7):
    '''Convert a weight percent into a molar concentration (mol/dm³).

    wt      -- weight percent of the substance in the sediment
    Mr      -- molar mass of the substance in g/mol (default: iron, 55.845)
    density -- bulk sediment density in g/cm³ (default 1.7)
    '''
    # wt% -> g per dm³ of sediment:
    # 1000 cm³/dm³ * density g/cm³ * wt / 100 = 10 * density * wt
    grams_per_dm3 = 10. * density * wt
    return grams_per_dm3 / Mr

def M_to_wt(M, Mr=55.845, density=1.7):
    '''Convert a molar concentration (mol/dm³) back into a weight percent.

    Inverse of ``wt_to_M`` for the same molar mass ``Mr`` (g/mol, default iron)
    and sediment density (g/cm³, default 1.7).
    '''
    grams_per_dm3 = M * Mr
    return grams_per_dm3 / (10. * density)

def ratio_to_concs(R, C):
    '''Split a total concentration into two parts with a given ratio.

    Takes a ratio of two concentrations R = a/b and a total concentration
    C = a + b, and returns the tuple (a, b) = (C*R/(1+R), C/(1+R)).

    b is computed first and a is taken as the remainder C - b. This keeps
    a + b equal to C (to rounding) and makes the limits R = 0 -> (0, C) and
    R = inf -> (C, 0) work for plain floats without dividing by zero.
    Works element-wise for NumPy arrays as well.
    '''
    b = C / (1 + R)
    return (C - b, b)

# Seconds in one thousand years (1 ka), taking a year as 365.24 days
s_in_ka = 1000 * 365.24 * (24 * 60 * 60)

### Modelling Pyrite Formation as a Function of Depth and Time

$$
\frac{\partial [\mathrm{SO}_4]}{\partial t} = D \nabla^2 [\mathrm{SO}_4] - k_{MSR} [\mathrm{SO}_4]
$$
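$$
\frac{\partial [\mathrm{FeS}_2]}{\partial t} = \frac{1}{2} k_{MSR} [\mathrm{SO}_4]
$$
The factor of $\frac{1}{2}$ accounts for the two sulphur atoms in each FeS$_2$ unit. Both equations are solved separately for $^{34}$S and $^{32}$S, with the reduction rate partitioned between the two isotopes according to the MSR fractionation factor.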

In [3]:
# One-dimensional, non-periodic Cartesian grid of n_points cells spanning
# z = -1.5 to 0 (the depth plot labels z in metres)
n_points = 60
grid = pde.CartesianGrid(bounds=[[-1.5, 0]], shape=[n_points], periodic=[False])

# In-memory storage for PDE results
state_storage = pde.MemoryStorage()

In [4]:
# Define the PDEs for the general diagenetic equation

class DiageneticEqs(pde.PDEBase):
    """General diagenetic equations modelling sulphate to pyrite conversion.

    Each sulphur isotope is tracked separately; the state handled by
    ``evolution_rate`` is ordered (SO4_34, SO4_32, py_34, py_32).

    Sulphate diffuses with coefficient ``D_SO4`` and is consumed by microbial
    sulphate reduction (MSR) at a total rate ``k_MSR * ([SO4_34] + [SO4_32])``.
    All reduced sulphur is fixed as pyrite. FeS2 holds two S atoms, so pyrite
    accumulates at half the sulphur reduction rate. The sulphide produced has
    a 34S/32S ratio ``a_MSR`` times that of the local sulphate.

    Args:
        SO4_0: total sulphate concentration fixed at the upper boundary.
        d34S_SO4_0: d34S (permil VCDT) of that boundary sulphate.
        epsilon_MSR: MSR fractionation (permil) used to derive ``a_MSR``.
        D_SO4: sulphate diffusion coefficient.
        k_MSR: first-order MSR rate constant.
    """

    def __init__(self, SO4_0=0.0105, d34S_SO4_0=23, epsilon_MSR=54, D_SO4=0.2, k_MSR=10):
        # Let PDEBase initialise its own state (noise, caches, ...)
        super().__init__()
        self.D_SO4 = D_SO4
        self.k_MSR = k_MSR

        # Split the boundary sulphate into its 34S and 32S concentrations
        R_SO4_0 = delta_to_R(d34S_SO4_0)
        self.SO4_34_0, self.SO4_32_0 = ratio_to_concs(R_SO4_0, SO4_0)

        # MSR fractionation factor a = R_sulphide / R_sulphate, with the
        # sulphide's delta taken as epsilon_MSR below the sulphate's
        self.a_MSR = delta_to_R(d34S_SO4_0 - epsilon_MSR) / R_SO4_0

        # Boundary conditions per isotope, [lower, upper]: zero gradient at the
        # base, fixed boundary concentration at the top
        self.SO4_32_bcs = [[{'derivative': 0}, {'value': self.SO4_32_0}]]
        self.SO4_34_bcs = [[{'derivative': 0}, {'value': self.SO4_34_0}]]

    def evolution_rate(self, state, t=0):
        """Return the time derivative of (SO4_34, SO4_32, py_34, py_32)."""
        SO4_34, SO4_32, _, _ = state
        a = self.a_MSR

        # Partition the total reduction rate k*[SO4] between the isotopes so that
        # r34 / r32 = a * [SO4_34] / [SO4_32] and r34 + r32 = k * [SO4]:
        #   r34 = k*[SO4] * a*[SO4_34] / ([SO4_32] + a*[SO4_34])
        #   r32 = k*[SO4] *   [SO4_32] / ([SO4_32] + a*[SO4_34])
        # Both rates must share this denominator; otherwise the sulphide ratio
        # is not a*R_SO4 and sulphur is not conserved.
        SO4_total = SO4_34 + SO4_32
        denom = SO4_32 + a * SO4_34

        # Pyrite (FeS2) forms at half the sulphur reduction rate
        py_34_t = 0.5 * self.k_MSR * SO4_total * a * SO4_34 / denom
        py_32_t = 0.5 * self.k_MSR * SO4_total * SO4_32 / denom

        # Diffusion minus the sulphate reduced (two S per FeS2 formed)
        SO4_34_t = self.D_SO4 * SO4_34.laplace(self.SO4_34_bcs) - 2 * py_34_t
        SO4_32_t = self.D_SO4 * SO4_32.laplace(self.SO4_32_bcs) - 2 * py_32_t

        return pde.FieldCollection([SO4_34_t, SO4_32_t, py_34_t, py_32_t])
        
        

In [5]:
def set_up_pdes(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR):
    """Build the diagenetic equations and their initial state.

    Sulphate starts uniform at the boundary concentration of each isotope;
    pyrite starts at zero everywhere.

    Returns:
        (fields, eqs): a FieldCollection ordered (SO4_34, SO4_32, py_34, py_32)
        on the module-level grid, and the DiageneticEqs instance.
    """
    eqs = DiageneticEqs(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR)

    def uniform(value):
        # Scalar field holding the same value at every grid point
        return pde.ScalarField(grid, [value] * n_points)

    fields = pde.FieldCollection([
        uniform(eqs.SO4_34_0),  # 34S sulphate
        uniform(eqs.SO4_32_0),  # 32S sulphate
        uniform(0),             # 34S pyrite
        uniform(0),             # 32S pyrite
    ])
    return fields, eqs

In [6]:
def solve_pdes(SO4_0=0.0105, d34S_SO4_0=23, epsilon_MSR=54, D_SO4=0.2, k_MSR=10):
    """Integrate the diagenetic model from t = 0 to t = 1 with step 1e-3.

    Returns:
        A new pde.MemoryStorage holding a snapshot of the state every 0.01
        time units.
    """
    initial_state, equations = set_up_pdes(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR)

    # A fresh storage per run, so frames from earlier runs are not mixed in
    frames = pde.MemoryStorage()
    equations.solve(initial_state, t_range=1, dt=1e-3, tracker=frames.tracker(1e-2))
    return frames

In [7]:
def create_sliders(file='./Controls/Transient_params_v1.csv', prefix='w', n_cols=2):
    """Build one FloatSlider per row of a parameter CSV and lay them out in a grid.

    Each CSV row must provide the columns ``var_name``, ``value``, ``start``,
    ``stop``, ``step`` and ``display_name``.

    Args:
        file: path of the parameter CSV.
        prefix: unused; kept so existing calls that pass it keep working.
        n_cols: number of sliders per row of the layout.

    Returns:
        (w_dict, box): ``w_dict`` maps each ``var_name`` to its slider, in CSV
        row order; ``box`` is a VBox of HBoxes holding ``n_cols`` sliders each.
    """
    w_df = pd.read_csv(file)
    w_dict = {}

    # Store sliders directly in the dict. Writing into vars()/locals() inside a
    # function is not reliable: since Python 3.13 (PEP 667) each call returns a
    # fresh snapshot, so reading the value back would raise KeyError.
    for _, row in w_df.iterrows():
        w_dict[row.var_name] = wdg.FloatSlider(
            value=row.value,
            min=row.start,
            max=row.stop,
            step=row.step,
            description=row.display_name,
            layout=wdg.Layout(width='100%'),
        )

    # Arrange widgets into rows of n_cols (HBoxes) stacked in a VBox
    w_list = list(w_dict.values())
    box_list = [wdg.HBox(w_list[i:i + n_cols]) for i in range(0, len(w_list), n_cols)]

    return w_dict, wdg.VBox(box_list)

In [62]:
def calculate_delta(storage):
    """Convert stored isotope-resolved fields into totals and d34S profiles.

    Args:
        storage: the stored frames, each a field collection ordered
            (SO4_34, SO4_32, py_34, py_32).

    Returns:
        dict with keys 'SO4', 'py', 'd34S_SO4', 'd34S_py', each an array of
        shape (n_times, n_spaces). d34S is NaN wherever the isotope ratio is
        undefined or non-finite (e.g. before any pyrite has formed, 0/0).
        Empty storage gives (0, 0) arrays.
    """
    if len(storage) == 0:
        return {key: np.empty((0, 0)) for key in ('SO4', 'py', 'd34S_SO4', 'd34S_py')}

    # Copy each frame as it is read, giving an (n_times, 4, n_spaces) array
    data = np.stack([
        np.array([fieldcol[i].data for i in range(4)], dtype=float)
        for fieldcol in storage
    ])
    SO4_34, SO4_32, py_34, py_32 = (data[:, i] for i in range(4))

    # NumPy only warns (it does not raise) on 0/0 or x/0 by default, so an
    # except FloatingPointError would never fire. Silence the warnings here
    # and mask the non-finite results explicitly instead.
    with np.errstate(divide='ignore', invalid='ignore'):
        d34S_SO4 = R_to_delta(SO4_34 / SO4_32)
        d34S_py = R_to_delta(py_34 / py_32)
    d34S_SO4[~np.isfinite(d34S_SO4)] = np.nan
    d34S_py[~np.isfinite(d34S_py)] = np.nan

    return {'SO4': SO4_34 + SO4_32, 'py': py_34 + py_32, 'd34S_SO4': d34S_SO4, 'd34S_py': d34S_py}

In [71]:
def create_depth_plot():
    """Create an empty two-panel FigureWidget for the depth profiles.

    The left panel holds the [SO4] and [FeS2] traces and the right panel their
    d34S traces. All traces use the grid's coordinate array as y, share the
    y axis, and have no x data yet.
    """
    depth = grid.coordinate_arrays[0]

    fig = go.FigureWidget(make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.01))

    # (trace name, colour, subplot column); the order here fixes the trace order
    trace_specs = (
        ('[SO<sub>4</sub>]', 'RoyalBlue', 1),
        ('[FeS<sub>2</sub>]', 'MediumSeaGreen', 1),
        ('δ<sup>34</sup>S<sub>SO4</sub>', 'Navy', 2),
        ('δ<sup>34</sup>S<sub>py</sub>', 'DarkGreen', 2),
    )
    for name, colour, column in trace_specs:
        fig.add_trace(go.Scatter(y=depth, name=name, marker_color=colour, mode='lines'), row=1, col=column)

    fig.update_layout(
        yaxis=dict(title='Depth, <i>z</i> / m'),
        xaxis_rangemode='tozero',
        hovermode='y',
        margin=dict(t=60),
        width=1000,
        height=600,
    )

    axis_titles = {
        1: 'Concentration / mol dm<sup>−3</sup>',
        2: 'δ<sup>34</sup>S (‰ VCDT)',
    }
    for column, title in axis_titles.items():
        fig.update_xaxes(title_text=title, row=1, col=column)

    return fig

def pde_plot(time):
    """Show one stored time frame in ``depth_plot``.

    Reads the module-level ``processed_results`` dict (keys 'SO4', 'py',
    'd34S_SO4', 'd34S_py') and updates the matching traces of ``depth_plot``
    in a single batch.

    Args:
        time: index of the stored frame to show. It is clamped to the
            available frames, and nothing is drawn while no results exist.
    """
    # Nothing to draw before the model has been run. This happens when the time
    # slider is first wired up, or when it is moved before clicking "Run Model".
    if not processed_results or len(processed_results['SO4']) == 0:
        return
    # The slider range is fixed but the number of stored frames is not
    time = max(0, min(time, len(processed_results['SO4']) - 1))

    trace_keys = {
        '[SO<sub>4</sub>]': 'SO4',
        '[FeS<sub>2</sub>]': 'py',
        'δ<sup>34</sup>S<sub>SO4</sub>': 'd34S_SO4',
        'δ<sup>34</sup>S<sub>py</sub>': 'd34S_py',
    }
    with depth_plot.batch_update():
        for trace_name, key in trace_keys.items():
            trace = next(depth_plot.select_traces({'name': trace_name}))
            trace.x = processed_results[key][time]

In [72]:
w_dict, sliders = create_sliders()

depth_plot = create_depth_plot()

# Animation controls: the Play widget drives the time slider via a front-end link
play = wdg.Play(value=0, min=0, max=100, step=1, interval=200)
time_slider = wdg.IntSlider(value=0, min=0, max=100, step=1, layout=wdg.Layout(width='100%'))
wdg.jslink((play, 'value'), (time_slider, 'value'))


run_button = wdg.Button(description='Run Model')
out = wdg.Output()

# Populated by button_click; empty until the model has been run once
processed_results = {}

def button_click(b):
    """Run the model with the current slider values and refresh the depth plot."""
    global state_storage, processed_results
    out.clear_output()
    with out:
        print('Processing...  ', end='')
        pde_args = {par: slider.value for par, slider in w_dict.items()}
        state_storage = solve_pdes(**pde_args)

        # calculate total concentrations and d34S
        processed_results = calculate_delta(state_storage)

        # Keep the time controls within the number of stored frames
        last_frame = max(len(processed_results['SO4']) - 1, 0)
        play.max = last_frame
        time_slider.max = last_frame

        # Refresh the graph
        pde_plot(time_slider.value)
        print('Done!', end='')

run_button.on_click(button_click)

In [73]:
# Redraw the depth plot whenever the time slider changes. The returned
# `interactive` widget is never displayed; it is created only so that its
# observer on time_slider calls pde_plot.
wdg.interactive(pde_plot, time=time_slider)

# Assemble the full interface: parameter sliders, run button with its status
# output, animation controls, and the depth plot. As the cell's last
# expression, this VBox is what the notebook displays.
wdg.VBox([
    sliders, 
    wdg.HBox([run_button, out]),
    wdg.HBox([play, time_slider]),
    depth_plot
])

VBox(children=(VBox(children=(HBox(children=(FloatSlider(value=0.0105, description='[SO<sub>4</sub>]<sub>0</su…