# Transient Model
This model aims to simulate sulphate and pyrite concentrations, and their sulphur isotopic compositions, in a sediment column. The column first evolves through a long period of stasis; a large pile of sediment is then deposited on top, and the subsequent evolution is tracked through time. Requires the `py-pde` package.

### Imports and Definitions

In [1]:
import numpy as np
import pandas as pd
import ipywidgets as wdg
from IPython.display import display
from plotly_default import go, graph_config, sel_trace
from plotly.subplots import make_subplots
import pde

### Modelling Pyrite Formation as a Function of Depth and Time

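$$
\frac{\partial [\mathrm{SO}_4]}{\partial t} = D_{\mathrm{SO}_4} \nabla^2 [\mathrm{SO}_4] - k_{\mathrm{MSR}} [\mathrm{SO}_4]
$$
$$
\frac{\partial [\mathrm{FeS}_2]}{\partial t} = \frac{1}{2} k_{\mathrm{MSR}} [\mathrm{SO}_4]
$$

The factor $\frac{1}{2}$ accounts for the two sulphur atoms in each FeS$_2$ unit. Sulphate is held at $[\mathrm{SO}_4]_0$ at the sediment surface ($z = 0$), with a zero-flux condition at the base of the column ($z = -1.5$ m).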

In [2]:
# Discretise the sediment column into n_points cells spanning z = -1.5 ... 0,
# with no periodic wrap-around at either end.
n_points = 60
grid = pde.CartesianGrid(bounds=[(-1.5, 0)], shape=[n_points], periodic=[False])

# Module-level snapshot store that the plotting code reads from.
state_storage = pde.MemoryStorage()

In [3]:
# Define the PDEs for the general diagenetic equation

class DiageneticEqs(pde.PDEBase):
    """General diagenetic equations modelling sulphate-to-pyrite conversion.

    The state is a two-field collection ``[SO4, py]`` on a 1D depth grid:

        d[SO4]/dt = D_SO4 * laplace([SO4]) - k_MSR * [SO4]
        d[py]/dt  = 0.5 * k_MSR * [SO4]

    Pyrite does not diffuse. The factor 0.5 converts moles of reduced sulphate
    (one S each) into moles of FeS2 (two S each).

    Parameters
    ----------
    SO4_0 : float
        Sulphate concentration held fixed at the upper boundary of the grid.
    d34S_SO4_0 : float
        delta-34S of the boundary sulphate. Stored for the planned isotope
        model; not used by ``evolution_rate`` yet.
    epsilon_MSR : float
        Isotope fractionation of sulphate reduction. Stored but not used yet.
    D_SO4 : float
        Sulphate diffusion coefficient.
    k_MSR : float
        First-order rate constant for sulphate consumption.
    """

    def __init__(self, SO4_0=0.0105, d34S_SO4_0=23, epsilon_MSR=54, D_SO4=0.2, k_MSR=10):
        # Initialise PDEBase's own instance state first; solve() relies on it.
        super().__init__()

        self.SO4_0 = SO4_0
        self.d34S_SO4_0 = d34S_SO4_0
        self.epsilon_MSR = epsilon_MSR
        self.D_SO4 = D_SO4
        self.k_MSR = k_MSR

        # Boundary conditions for the single grid axis, given as [lower, upper]:
        # no flux at the bottom of the column, fixed concentration at the top.
        self.SO4_bcs = [[{'derivative': 0}, {'value': self.SO4_0}]]

    def evolution_rate(self, state, t=0):
        """Return the time derivative of ``[SO4, py]`` for the given state."""
        SO4, py = state

        SO4_t = self.D_SO4 * SO4.laplace(self.SO4_bcs) - self.k_MSR * SO4
        # Two sulphate ions are reduced for every FeS2 formed
        py_t = 0.5 * self.k_MSR * SO4

        return pde.FieldCollection([SO4_t, py_t])

    # TODO: numba-compiled right-hand side for speed. The earlier attempt was
    # broken, most likely because it returned a FieldCollection from inside the
    # jitted function; py-pde expects `pde_rhs` to return a plain NumPy array
    # shaped like `state_data`. Corrected sketch, left disabled until tested:
    #
    # def _make_pde_rhs_numba(self, state):
    #     D_SO4 = self.D_SO4
    #     k_MSR = self.k_MSR
    #     laplace_SO4 = state.grid.make_operator("laplace", bc=self.SO4_bcs)
    #
    #     @pde.tools.numba.jit
    #     def pde_rhs(state_data, t=0):
    #         SO4 = state_data[0]
    #         rate = np.empty_like(state_data)
    #         rate[0] = D_SO4 * laplace_SO4(SO4) - k_MSR * SO4
    #         rate[1] = 0.5 * k_MSR * SO4
    #         return rate
    #
    #     return pde_rhs
        

In [4]:
def set_up_pdes(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR):
    """Build the initial state and the equation object for one model run.

    The sulphate field starts uniform at ``SO4_0`` and the pyrite field starts
    at zero everywhere on the module-level ``grid``. Boundary conditions live
    on the returned ``DiageneticEqs`` instance.

    Returns
    -------
    tuple
        ``(fields, eqs)``: the initial ``[SO4, pyrite]`` FieldCollection and
        the configured ``DiageneticEqs``.
    """
    eqs = DiageneticEqs(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR)

    initial_SO4 = pde.ScalarField(grid, np.full(n_points, SO4_0))
    initial_py = pde.ScalarField(grid, np.full(n_points, 0))

    return pde.FieldCollection([initial_SO4, initial_py]), eqs

In [5]:
def solve_pdes(SO4_0=0.0105, d34S_SO4_0=23, epsilon_MSR=54, D_SO4=0.2, k_MSR=10,
               *, t_range=1, dt=1e-3, interval=1e-2):
    """Run the transient model and return the stored snapshots.

    Parameters
    ----------
    SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR
        Model parameters, forwarded to ``set_up_pdes``.
    t_range : float, keyword-only
        Total simulated time.
    dt : float, keyword-only
        Solver time step.
    interval : float, keyword-only
        Time between stored snapshots.

    Returns
    -------
    pde.MemoryStorage
        One snapshot of ``[SO4, py]`` per ``interval``, starting at t = 0.

    Notes
    -----
    The defaults give about 100 snapshots, matching the notebook's 0-100 time
    slider. NOTE(review): with explicit time stepping, diffusion is stable
    only while D_SO4 * dt / dz**2 <= 1/2 (dz = grid spacing); reduce ``dt``
    if large ``D_SO4`` values produce oscillations. Confirm which solver
    py-pde selects here.
    """
    fields, eqs = set_up_pdes(SO4_0, d34S_SO4_0, epsilon_MSR, D_SO4, k_MSR)

    # Fresh storage for this run; the caller decides whether to publish it
    storage = pde.MemoryStorage()
    eqs.solve(fields, t_range=t_range, dt=dt, tracker=storage.tracker(interval))

    return storage

In [6]:
def create_sliders(file='./Controls/Transient_params_v1.csv', prefix='w', n_cols=2):
    """Build FloatSliders from a CSV control file and lay them out in a grid.

    The CSV must provide the columns ``var_name``, ``value``, ``start``,
    ``stop``, ``step`` and ``display_name``, one row per slider.

    Parameters
    ----------
    file : str
        Path to the CSV control file.
    prefix : str
        Kept for backward compatibility; no longer used. It only named
        throwaway local variables.
    n_cols : int
        Number of sliders per row.

    Returns
    -------
    tuple
        ``(w_dict, box)``. ``w_dict`` maps each ``var_name`` to its slider in
        CSV row order. ``box`` is a VBox of HBox rows with ``n_cols``
        sliders each.
    """
    w_df = pd.read_csv(file)

    # Store widgets directly. Writing through vars() inside a function is
    # unreliable: from Python 3.13 each call returns a fresh snapshot, so the
    # value could not be read back.
    w_dict = {}
    for row in w_df.itertuples(index=False):
        w_dict[row.var_name] = wdg.FloatSlider(
            value=row.value,
            min=row.start,
            max=row.stop,
            step=row.step,
            description=row.display_name,
            layout=wdg.Layout(width='100%'),
        )

    # Arrange widgets into rows (HBoxes) of n_cols, stacked in a VBox
    w_list = list(w_dict.values())
    box_list = [wdg.HBox(w_list[i:i + n_cols]) for i in range(0, len(w_list), n_cols)]

    return w_dict, wdg.VBox(box_list)

In [7]:
def create_depth_plot():
    """Create the empty depth-profile figure.

    Concentrations go in the left panel and delta-34S in the right panel, on a
    shared depth axis taken from the grid coordinates. Only the two
    concentration traces are created; their x-data is left empty.
    """
    depths = grid.coordinate_arrays[0]

    fig = go.FigureWidget(
        make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.01)
    )

    # Concentration profiles in the left panel
    concentration_traces = (
        ('[SO<sub>4</sub>]', 'RoyalBlue'),
        ('[FeS<sub>2</sub>]', 'MediumSeaGreen'),
    )
    for label, colour in concentration_traces:
        fig.add_trace(
            go.Scatter(y=depths, name=label, marker_color=colour, mode='lines'),
            row=1,
            col=1,
        )
    # TODO: add delta-34S profiles of sulphate (Navy) and pyrite (DarkGreen)
    # to the right panel once the isotope model exists.

    fig.update_layout(
        yaxis=dict(title='Depth, <i>z</i> / m'),
        xaxis_rangemode='tozero',
        hovermode='y',
        margin=dict(t=60),
        width=1000,
        height=600,
    )

    fig.update_xaxes(title_text='Concentration / mol dm<sup>−3</sup>', row=1, col=1)
    fig.update_xaxes(title_text='δ<sup>34</sup>S (‰)', row=1, col=2)

    return fig

def pde_plot(time):
    """Show the stored solution at snapshot index ``time`` in ``depth_plot``.

    ``time`` indexes the module-level ``state_storage`` (one entry per tracker
    snapshot); it is not a model time. Does nothing if no solution is stored
    yet. Indices past the last stored snapshot show the last one.
    """
    n_frames = len(state_storage)
    if n_frames == 0:
        # Nothing solved yet, e.g. the slider moved before "Run Model"
        return
    t_slice = state_storage[min(time, n_frames - 1)]

    with depth_plot.batch_update():
        SO4_tr = next(depth_plot.select_traces({'name': '[SO<sub>4</sub>]'}))
        SO4_tr.x = t_slice[0].data

        py_tr = next(depth_plot.select_traces({'name': '[FeS<sub>2</sub>]'}))
        py_tr.x = t_slice[1].data

In [10]:
# Parameter sliders, defined by the CSV control file
w_dict, sliders = create_sliders()

# Empty depth-profile figure
depth_plot = create_depth_plot()

# Snapshot selector: the Play widget steps the slider through saved frames
play = wdg.Play(value=0, min=0, max=100, step=1, interval=200)
time_slider = wdg.IntSlider(
    value=0,
    min=0,
    max=100,
    step=1,
    layout=wdg.Layout(width='100%'),
)
wdg.jslink((play, 'value'), (time_slider, 'value'))

# Run button plus an output area for its progress messages
run_button = wdg.Button(description='Run Model')
out = wdg.Output()

def button_click(b):
    """Solve the model with the current slider values and redraw the plot.

    Registered as the ``run_button`` click handler. Progress messages are
    printed into the ``out`` widget. The result replaces the module-level
    ``state_storage``.
    """
    global state_storage

    out.clear_output()
    with out:
        print('Processing...  ', end='')
        pde_args = {par: slider.value for par, slider in w_dict.items()}
        state_storage = solve_pdes(**pde_args)

        # Keep the time controls within the snapshots actually stored
        last_frame = max(len(state_storage) - 1, 0)
        time_slider.max = last_frame
        play.max = last_frame

        # Refresh the graph
        pde_plot(time_slider.value)
        print('Done!', end='')


run_button.on_click(button_click)

In [11]:
# Redraw the depth profiles whenever the time slider value changes
wdg.interactive(pde_plot, time=time_slider)

# Assemble the dashboard; as the cell's last expression it is displayed
wdg.VBox(
    [
        sliders,
        wdg.HBox([run_button, out]),
        wdg.HBox([play, time_slider]),
        depth_plot,
    ]
)

VBox(children=(VBox(children=(HBox(children=(FloatSlider(value=0.0105, description='[SO<sub>4</sub>]<sub>0</su…