# Geochemistry

In this notebook we will interrogate tables of geochemical data using the Python packages `pandas` and `pyrolite`. Pandas makes handling tabular data in Python straightforward and, because it is built on NumPy, calculations are fast. Most packages that work with tabular data are built to work with pandas, so mastering it is fundamental to data analysis in Python. `Pyrolite`, developed by Morgan Williams (CSIRO), provides built-in functions that simplify geochemical data manipulation. While pyrolite provides some functionality specific to geochemical data (e.g. geochemical plot templates), it is not the only library useful for geological plotting and compositional data. As we go, this notebook will demonstrate several packages for achieving plotting tasks.

- You can read the pandas docs here: https://pandas.pydata.org/

- You can read the pyrolite docs here: https://pyrolite.readthedocs.io/en/main/


---



In [None]:
# This cell is only necessary if we're running in Google Colab
# mpltern and plotly are included because they are imported later in the notebook
!pip install pandas numpy pyrolite seaborn matplotlib mpltern plotly

First, we must import the packages we are going to use.

In [None]:
# For data processing
import pandas as pd
import numpy as np
from pyrolite.geochem.norm import all_reference_compositions,get_reference_composition

# For visualization
import seaborn as sns
import pyrolite.plot
import matplotlib.pyplot as plt



## Example 1: Geochem QA/QC using standards
In this example, we will:
1. Load assay data measured from the standard SY-4
2. Load the reference data for standard SY-4
3. Create an XY plot of measured versus reference values for SY-4
4. Optional: Try it yourself using the data for SY-2 

Let's start by loading the data into pandas.

In [None]:
#Load the measured trace element values for the SY-4 standard, from the drillhole trace element csv file
measured_SY4 = pd.read_csv('Data_for_exercises/DH 1&2 trace elements.csv',header = 4,skiprows=range(5, 35), nrows=1)
measured_SY4

In [None]:
#Load the reference values for the standard, from the SY-4.csv file.
standard_SY4 = pd.read_csv('Data_for_exercises/SY-4.csv')
standard_SY4

The `standard_SY4` dataframe (SY-4 reference values) contains more columns (chemical species) than the `measured_SY4` dataframe. Let's select only the columns that are common to both dataframes, using the `intersection` method of the pandas column index.

In [None]:
common_cols = standard_SY4.columns.intersection(measured_SY4.columns)
print(common_cols)
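
To see what `intersection` does in isolation, here is a minimal sketch on two toy dataframes. The column names and numbers are made up for illustration and are not taken from the SY-4 files.

```python
import pandas as pd

# Toy stand-ins for the reference and measured tables (all values are made up)
reference = pd.DataFrame({"Ba": [10.0], "Sr": [20.0], "Zr": [30.0], "Y": [40.0]})
measured = pd.DataFrame({"Sample": ["STD"], "Sr": [21.0], "Ba": [9.0], "Zr": [31.0]})

# Index.intersection keeps only the labels that occur in both column indexes
common = reference.columns.intersection(measured.columns)
print(list(common))

# Columns that exist on one side only are left out
only_in_reference = reference.columns.difference(measured.columns)
print(list(only_in_reference))
```

Because the result is itself a pandas `Index`, it can be passed straight to `dataframe[common]` to select the shared columns.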


In [None]:
standard_SY4_common = standard_SY4[common_cols]
standard_SY4_common

In [None]:
measured_SY4_common = measured_SY4[common_cols]
measured_SY4_common

Let's transpose our data so that our rows become columns. We'll create a function so that we can easily reuse the code to repeat the procedure for future dataframes.

In [None]:
def transpose_data(dataframe, data_column_labels):
    '''
    dataframe (df): dataframe to transpose rows to columns
    data_column_labels (list): User-defined column labels, comprising two column labels: 1) chemical species 2) value
    '''
    dataframe_transposed = dataframe.T
    dataframe_transposed.reset_index(inplace=True)
    dataframe_transposed.columns = data_column_labels
    
    return dataframe_transposed


In [None]:
standard_SY4_data_labels = ['chemical_species','reference_values']
standard_SY4_transposed = transpose_data(standard_SY4_common,standard_SY4_data_labels)

measured_SY4_data_labels = ['chemical_species','measured_values']
measured_SY4_transposed = transpose_data(measured_SY4_common, measured_SY4_data_labels)

#TIP:when typing a function call, place cursor inside the brackets and type shift+tab to see the docstring notes

standard_SY4_transposed.head()
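
The transposition step can be hard to picture, so here is a small sketch of the same idea on a one-row toy table. The numbers are made up and the name `tall` is only used for this illustration.

```python
import pandas as pd

# One-row "wide" table, similar in shape to a single standard read from a csv (made-up numbers)
wide = pd.DataFrame({"Ba": [10.0], "Sr": [20.0], "Zr": [30.0]})

# .T turns the column labels into the row index;
# reset_index() then moves those labels into an ordinary column
tall = wide.T.reset_index()
tall.columns = ["chemical_species", "value"]

print(tall)
```

Each chemical species now has its own row, which is the shape needed to merge two tables on a shared `chemical_species` key.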

In [None]:
#Join the two dataframes together. This will make it very easy to create a plot because we can simply reference columns from the merged dataframe.
# An inner join will join those rows that have matching keys. In this case, we'll use 'chemical_species' as the key.
df_merged = pd.merge(standard_SY4_transposed, measured_SY4_transposed, on='chemical_species', how='inner')
df_merged.sort_values(by='reference_values', inplace=True)
df_merged.head()
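
An inner join silently discards rows whose key appears in only one table. The sketch below, on made-up toy data, shows the inner join next to an outer join with `indicator=True`, which is one way to check what was discarded.

```python
import pandas as pd

# Toy tables with partially overlapping keys (made-up values)
reference = pd.DataFrame(
    {"chemical_species": ["Ba", "Sr", "Zr"], "reference_values": [10.0, 20.0, 30.0]}
)
measured = pd.DataFrame(
    {"chemical_species": ["Sr", "Zr", "Y"], "measured_values": [21.0, 29.0, 5.0]}
)

# Inner join: only keys present in both tables survive
inner = pd.merge(reference, measured, on="chemical_species", how="inner")

# Outer join with indicator=True adds a '_merge' column saying where each row came from
outer = pd.merge(reference, measured, on="chemical_species", how="outer", indicator=True)
dropped = outer.loc[outer["_merge"] != "both", "chemical_species"].tolist()

print(inner)
print("Dropped by the inner join:", dropped)
```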

Let's make a function that will plot our measured versus reference values. 

In [None]:
def make_scatter_plot(dataframe, x_column_name, y_column_name, data_column_name):
    ax = sns.scatterplot(data = dataframe,x = x_column_name, y = y_column_name) #plot the data
    plt.xscale('log') #let's use a log scale so that we can see all our data points clearly
    plt.yscale('log')
    plt.xlabel(x_column_name) #'Reference values'
    plt.ylabel(y_column_name) #'Measured values'
    plt.title(f'{y_column_name} vs {x_column_name}')
    
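    # Label each point. We iterate over the row values rather than index labels,
    # so the loop also works after the dataframe has been sorted or filtered.
    for label, x, y in zip(dataframe[data_column_name], dataframe[x_column_name], dataframe[y_column_name]):
        ax.annotate(label, (x, y), textcoords="offset points", xytext=(0,20), ha='left')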
    plt.show()

    

In [None]:
%matplotlib inline
make_scatter_plot(
    df_merged, 
    'reference_values',
    'measured_values',
    'chemical_species'
)
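
A plot gives a visual check; a numeric check can sit alongside it. The sketch below computes the percentage difference between measured and reference values on made-up numbers. The 10 % tolerance is a hypothetical threshold chosen for illustration, not a recommended acceptance limit.

```python
import pandas as pd

# Made-up reference and measured values for three species
qc = pd.DataFrame(
    {
        "chemical_species": ["Ba", "Sr", "Zr"],
        "reference_values": [100.0, 200.0, 50.0],
        "measured_values": [105.0, 190.0, 60.0],
    }
)

# Relative difference in percent: positive means the measurement is too high
qc["relative_difference_pct"] = (
    100 * (qc["measured_values"] - qc["reference_values"]) / qc["reference_values"]
)

tolerance_pct = 10.0  # hypothetical acceptance threshold
qc["within_tolerance"] = qc["relative_difference_pct"].abs() <= tolerance_pct

print(qc)
```

The same two lines could be applied to a merged dataframe that has `reference_values` and `measured_values` columns.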


<font color="green"><b>TRY IT YOURSELF:</b> 
<ol style="color: green;">
<li>Read in the data for standard SY-2 from SY-2.csv in the Data_for_exercises folder.</li>
<li>Read in the data for measured SY-2 from the trace element csv in the Data_for_exercises folder.</li>
<li>Find the columns that are common to SY-2.csv and the measured trace element data. Select only the common columns from each dataframe.</li>
<li>Transpose both dataframes using the transpose_data function.</li>
<li>Merge the transposed dataframes.</li>
<li>Plot the measured versus reference values using the make_scatter_plot function.</li>
</ol>
</font>

In [None]:
# write your code here

In [None]:
# Expand to see how it's done

#import the data into pandas
measured_SY2 = pd.read_csv('Data_for_exercises/DH 1&2 trace elements.csv',header = 4,skiprows=range(5, 34), nrows=1)
standard_SY2 = pd.read_csv('Data_for_exercises/SY-2.csv')

#find the columns in common and filter each dataframe for the common columns
common_cols = standard_SY2.columns.intersection(measured_SY2.columns)
standard_SY2_common = standard_SY2[common_cols]
measured_SY2_common = measured_SY2[common_cols]

#transpose the two dataframes using the transpose_data function
standard_SY2_transposed = transpose_data(standard_SY2_common,['chemical_species','reference_values'])
measured_SY2_transposed = transpose_data(measured_SY2_common,['chemical_species', 'measured_values'])

#merge the two dataframes
df_merged_SY2 = pd.merge(standard_SY2_transposed, measured_SY2_transposed, on='chemical_species', how='inner')
df_merged_SY2.sort_values(by='reference_values', inplace=True)

#Create a scatter plot using the make_scatter_plot function
make_scatter_plot(df_merged_SY2,
                  'reference_values',
                  'measured_values',
                  'chemical_species'
                 )

## Example 2: Plot chemistry versus depth
In this example we will:
1.  Load the major element data
2.  Create a new pandas column, to hold interval mid-point.
3.  Plot chemistry versus depth (sampling interval midpoint)
4.  See how matplotlib axes may be inverted
5.  See how to share a common axis between two plots

Our depth plot will look like this:


![depth_plots.png](attachment:b376e5e2-a6bd-4173-ae29-7fbb9d2fdbb8.png)

In [None]:
#Load the major element data

major_elements = pd.read_csv('Data_for_exercises/DH1&2 major elements.csv', header = 4)
major_elements.head()


In [None]:
#Create a new pandas column that will hold the calculated sampling interval midpoint

def calculate_midpoint(row):
    from_depth = float(row['Interval'].split('-')[0])
    to_depth = float(row['Interval'].split('-')[1])
    midpoint = from_depth + (to_depth - from_depth)/2
    return midpoint

major_elements['depth'] = major_elements.apply(calculate_midpoint, axis=1)
major_elements.head()
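
Applying a Python function row by row is easy to read but can be slow on large tables. As an alternative, here is a vectorised sketch using the pandas string methods. It assumes that every interval is written as `from-to` with non-negative depths; the intervals are made up.

```python
import pandas as pd

# Made-up sampling intervals in metres, written as 'from-to'
intervals = pd.DataFrame({"Interval": ["0-2", "2-4.5", "10-12"]})

# Split every string once on '-' into two columns, then convert both to floats
bounds = intervals["Interval"].str.split("-", expand=True).astype(float)
bounds.columns = ["from_depth", "to_depth"]

# The midpoint is the mean of the two bounds
intervals["depth"] = (bounds["from_depth"] + bounds["to_depth"]) / 2

print(intervals)
```

A negative depth such as `-5-0` would break the split, which is why the format assumption matters.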

In [None]:
#Create a matplotlib plot that contains two subplots, placed side-by-side

def plot_major_element_versus_depth(dataframe, element1, element2):
    fig, ax = plt.subplots(1, 2, figsize=(8,6), sharey=True)

    for drillhole in dataframe['Location'].unique():
        drillhole_data = dataframe[dataframe['Location']==drillhole]
        
        ax[0].plot(drillhole_data[element1], drillhole_data['depth'], label = drillhole)
        ax[0].set_xlabel(f'{element1} (wt%)')
        ax[0].set_ylabel('Depth (m)')
        

        ax[1].plot(drillhole_data[element2], drillhole_data['depth'], label = drillhole)
        ax[1].set_xlabel(f'{element2} (wt%)')
        ax[1].legend(title='Drillhole')
    ax[0].invert_yaxis()

    plt.tight_layout()
    plt.show()

        
plot_major_element_versus_depth(major_elements, 'Fe2O3', 'MgO')

 <font color="green"><b>TRY IT YOURSELF:</b> Call the plot_major_element_versus_depth function to plot Al2O3 and K2O with depth. </font>  

In [None]:
# write your code here

In [None]:
# expand to see how it's done

plot_major_element_versus_depth(major_elements, 'Al2O3', 'K2O')

Let's change the aesthetics of this plot.

In [None]:
def plot_major_element_versus_depth(dataframe, element1, element2):
    fig, ax = plt.subplots(1, 2, figsize=(12, 7), sharey=True)

    # Define colors and markers cycles
    colours = plt.cm.tab20.colors  # Matplotlib has several colour maps. This one has 20 colours. You can read more here: https://matplotlib.org/stable/users/explain/colors/colormaps.html
    markers = ['o', 's'] # matplotlib has different marker styles. You can read more here: https://matplotlib.org/stable/api/markers_api.html

    drillholes = dataframe['Location'].unique()
    num_colors = len(colours)
    num_markers = len(markers)

    for idx, drillhole in enumerate(drillholes):
        drillhole_data = dataframe[dataframe['Location'] == drillhole]

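        colour = colours[idx % num_colors] # wrap around if there are more drillholes than colours
        marker = markers[idx % num_markers] # wrap around if there are more drillholes than markers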

        ax[0].plot(drillhole_data[element1], drillhole_data['depth'],label=drillhole, color=colour, marker=marker, linestyle='-', markersize=5)
        ax[1].plot(drillhole_data[element2], drillhole_data['depth'],label=drillhole, color=colour, marker=marker, linestyle='-', markersize=5)

    # Labels and titles
    ax[0].set_xlabel(f'{element1} (wt%)', fontsize=12)
    ax[1].set_xlabel(f'{element2} (wt%)', fontsize=12)
    ax[0].set_ylabel('Depth (m)', fontsize=12)

    ax[0].set_title(f'{element1} vs Depth', fontsize=14, fontweight='bold')
    ax[1].set_title(f'{element2} vs Depth', fontsize=14, fontweight='bold')

    for axis in ax:
        axis.grid(True, linestyle='--', alpha=0.6)
        axis.minorticks_on()
        axis.tick_params(axis='both', which='major', labelsize=10)
        axis.tick_params(axis='both', which='minor', length=4)

    ax[1].legend(title='Drillhole', fontsize=10, title_fontsize=12, loc='upper right',frameon=True, facecolor='white')
    ax[0].invert_yaxis()
    plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave space on right for legend
    plt.show()

plot_major_element_versus_depth(major_elements, 'Fe2O3', 'MgO')
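
A list of markers or colours can be shorter than the number of drillholes being plotted. The sketch below shows two common ways to reuse styles when that happens. The drillhole names and style lists are hypothetical.

```python
from itertools import cycle

markers = ["o", "s"]
colours = ["tab:blue", "tab:orange", "tab:green"]
drillholes = ["DH1", "DH2", "DH3", "DH4", "DH5"]  # hypothetical names

# Option 1: modulo indexing wraps the position back to the start of each list
styles_modulo = {
    name: (colours[i % len(colours)], markers[i % len(markers)])
    for i, name in enumerate(drillholes)
}

# Option 2: itertools.cycle repeats a list forever; zip stops when drillholes runs out
styles_cycle = dict(zip(drillholes, zip(cycle(colours), cycle(markers))))

for name, (colour, marker) in styles_modulo.items():
    print(name, colour, marker)
```

Both options give the same assignment; the second avoids manual index arithmetic.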

## Example 3: Creating geochemical plots - oxide-oxide plot, TAS diagram, ternary diagram, spidergram

We'll use a dataset of mafic extrusive igneous rocks in Western Australia, downloaded from the Ausgeochem online portal.

![ausgeochemsamples.png](attachment:4f3cd647-9de1-43a1-ba52-61399f60722c.png)

In [None]:
df = pd.read_csv('Data_for_exercises/Rock_sample_geochem.csv')
df.head()

We can get rid of unwanted columns using the 'drop' method. 

In [None]:
df.drop(columns = ['location_type', 'location_comment'], inplace=True)
df.head()

### Example 3.1: Data preparation
- Filter the data to extract major oxide columns.
- Get an overview  of the data: how many rows, missing values etc.
- Drop rows with missing values
- Replace zeros
- Execute closure of oxides so that compositional data sums to 100

We can select data by row and column position using the .iloc indexer. (iloc stands for integer-location)

In [None]:
df.iloc[0:3, [42, 47, 51]]

<font color="green"><b>TRY IT YOURSELF:</b> Instead of counting the column headers to get the index of the column of interest, we could get the indices of interest programmatically. Below, we extract the column indexes programmatically. Use these variables to filter out the SiO2, K2O and Na2O data.</font>  


In [None]:
index_of_SiO2 = df.columns.get_loc('SiO2')
index_of_Na2O = df.columns.get_loc('Na2O')
index_of_K2O = df.columns.get_loc('K2O')

print(f'index of SiO2: {index_of_SiO2}')
print(f'index of Na2O: {index_of_Na2O}')
print(f'index of K2O: {index_of_K2O}')

#Type your code here. Hint: use the .iloc method



In [None]:
# Expand to see how it's done

df.iloc[0:3, [index_of_SiO2, index_of_Na2O, index_of_K2O]]

We can also filter by column and row name using the .loc method. (loc stands for location)

In [None]:
element_oxides = ['SiO2', 'Al2O3', 'Fe2O3', 'CaO', 'MgO', 'Na2O', 'K2O', 'MnO', 'TiO2', 'P2O5', 'Cr2O3']
oxides = df.loc[:,element_oxides]
oxides.head()
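
The difference between `.iloc` and `.loc` is easiest to see side by side. In this sketch on a made-up table, both selections return the same cells; note that position slices exclude the end, while label slices include it.

```python
import pandas as pd

# Made-up table with string row labels
toy = pd.DataFrame(
    {"SiO2": [50.1, 48.7, 52.3], "Na2O": [2.1, 2.5, 3.0], "K2O": [0.4, 0.6, 1.1]},
    index=["s1", "s2", "s3"],
)

# Position-based: rows 0 and 1 (the end of the slice is excluded), columns 0 and 2
by_position = toy.iloc[0:2, [0, 2]]

# Label-based: rows 's1' to 's2' (the end label is included), columns by name
by_label = toy.loc["s1":"s2", ["SiO2", "K2O"]]

print(by_position)
print(by_label)
```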

Pyrolite has some built-in methods for quick data filtering of geochem data. For example, you can easily filter out the oxide and REE columns.

In [None]:
oxides_pyrochem = df.pyrochem.oxides
oxides_pyrochem.head()

 <font color="green"><b>TRY IT YOURSELF:</b> Select rare earth elements using df.pyrochem.REE. Print out the results using the .head() method</font>  

In [None]:
# Write your code here

In [None]:
# Expand to see how it's done

REE_pyrochem = df.pyrochem.REE
REE_pyrochem.head()

You can quickly assess the numerical data to understand data max, min, average etc., and missing data using the .describe() and .info() methods. Let's assess the oxide data.

In [None]:
oxides.describe()

In [None]:
oxides.info()

### What do we do with zero values? When values are below detection limit...

Some columns contain zero values. Zero values can be problematic in calculations and statistical analysis! There are many ways to handle zeros. 
- Delete all rows or columns with zero values. This can cause us to lose a lot of data, though!
- We can replace zeros with a small value, or some fraction of detection limit.
- We can impute the missing values.

In [None]:
#Let's replace zero values with very small values. Similarly, we could replace the values with e.g. half detection limit
oxides_replaced = oxides.replace(0, 1e-10)
oxides_replaced.describe()
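
Replacing zeros with half the detection limit needs a limit per column. Here is a minimal sketch of that variant; the detection limits and oxide values are hypothetical and should be replaced with those reported by the laboratory.

```python
import pandas as pd

# Made-up oxide values, where 0 stands for "below detection limit"
oxides_toy = pd.DataFrame({"MnO": [0.15, 0.0, 0.20], "P2O5": [0.0, 0.10, 0.0]})

# Hypothetical detection limits in wt%, one per column
detection_limits = pd.Series({"MnO": 0.01, "P2O5": 0.02})

# Replace the zeros column by column with half of that column's detection limit
oxides_half_dl = oxides_toy.copy()
for column, limit in detection_limits.items():
    oxides_half_dl[column] = oxides_half_dl[column].replace(0, limit / 2)

print(oxides_half_dl)
```

Working on a copy keeps the original table available for comparison.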

In [None]:
#If we're replacing the zero values, we'll need to ensure that the compositional data still sum to 100. This means we need to renormalise the data.
oxides_renormalised = pyrolite.comp.codata.renormalise(oxides_replaced, scale=100)

#We'll create a 'sum' column, just to demonstrate that the data sum to 100
oxides_renormalised['sum'] = oxides_renormalised.sum(axis=1, skipna=False)

#Visualise the data
oxides_renormalised.head()


In [None]:
#remove the now-unwanted sum column
oxides_renormalised.drop(columns = 'sum', inplace=True)
oxides_renormalised.head()
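
Renormalising to 100 (closure) means dividing each row by its own total and rescaling. The plain-pandas sketch below illustrates the idea on made-up values; it is not pyrolite's implementation.

```python
import pandas as pd

# Made-up compositions whose rows sum to 70 and 72.5 rather than 100
raw = pd.DataFrame({"SiO2": [49.0, 51.0], "Al2O3": [14.0, 15.5], "MgO": [7.0, 6.0]})

# Divide every row by its own total (axis=0 aligns the totals with the rows), then scale to 100
row_totals = raw.sum(axis=1)
closed = raw.div(row_totals, axis=0) * 100

print(closed)
print(closed.sum(axis=1))
```

Closure changes the absolute values but preserves the ratios between components within a row.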

### Example 3.2: Creating a scatterplot of whole-rock oxides

Let's create a scatterplot of CaO versus MgO

In [None]:
fig, axis = plt.subplots(1,1, figsize = (10,4))

#Plot oxides
oxide_scatter_plot = axis.scatter(
    oxides_renormalised["MgO"], # x-variable
    oxides_renormalised["CaO"], # y-variable
    marker="o",
    c=oxides_renormalised["SiO2"],
    s=60,
    cmap='rainbow',
    edgecolors='black',
    alpha = 0.6
)

colour_bar = plt.colorbar(oxide_scatter_plot, ax=axis)
colour_bar.set_label('SiO2')

axis.set_xlabel("$MgO$ (wt%)")
axis.set_ylabel("$CaO$ (wt%)")

 <font color="green"><b>TRY IT YOURSELF:</b> Create a scatterplot of TiO2 (y) versus Fe2O3 (x). Colour by MgO using cmap='gnuplot2'</font>  

In [None]:
# Write your code here


In [None]:
# Expand to see how it's done

fig, axis = plt.subplots(1,1, figsize = (10,4))

oxide_scatter_plot = axis.scatter(
    oxides_renormalised["Fe2O3"], # x-variable
    oxides_renormalised["TiO2"], # y-variable
    marker="o",
    c=oxides_renormalised["MgO"],
    s=60,
    cmap='gnuplot2',
    edgecolors='black',
    alpha = 0.6
)

colour_bar = plt.colorbar(oxide_scatter_plot, ax=axis)
colour_bar.set_label('MgO')

axis.set_xlabel("$Fe_{2}O_{3}$ (wt%)")
axis.set_ylabel("$TiO_{2}$ (wt%)");

### Example 3.3: Classifying data using a TAS diagram


Let's pause here for a moment and discuss what we're about to do:
- We're going to use the renormalised oxide data to plot our data points on a TAS diagram.
- Using the TAS diagram, we'll be able to **classify** our datapoints into volcanic rock types
- We will filter out the volcanic rock types, based on this classification. The filtered data will be plotted on a spidergram
- Spidergrams plot the REE content. This means we will need both the **oxide and the REE data**! However, in the previous exercise, we extracted the oxide data only...
- Let's fix this by creating a dataset that includes both the oxide and REE data, with zeros replaced, and the oxides renormalised.

This is how we'll do it:

In [None]:
#Firstly, let's take our original dataframe, and replace all the zeros with small values, just like we did before. 
#The difference is that we will do this to the full dataset, not only the oxide subset

df_no_zeros = df.replace(0, 1e-10)
df_no_zeros.describe() #No columns have a min of zero anymore!

In [None]:
#Next, we want to make sure our oxides are normalised to 100, so that we can plot the data on a TAS diagram.
#Pyrolite actually allows us to normalise a subset of our columns, if we provide the column names

df_normalised_oxides = pyrolite.comp.codata.renormalise(
    df = df_no_zeros, 
    components = df.pyrochem.list_oxides, # this is another pyrolite function, to extract oxide columns from the dataframe
    scale=100
)

# So how do we know that worked? Let's take a look:
oxides = df_normalised_oxides.pyrochem.oxides.copy() #isolate the oxide columns; .copy() gives an independent dataframe, so adding a column doesn't raise a SettingWithCopyWarning
oxides['sum'] = oxides.sum(axis=1) #manually create a sum column, where we total the oxide values
oxides.head() # Visualise the result. Oxides sum to 100!


In [None]:
#TAS diagrams plot alkalis versus silica. Let's create a new column to contain total Na2O + K2O
df_normalised_oxides["Na2O + K2O"] = df_normalised_oxides["Na2O"] + df_normalised_oxides["K2O"]

#Other ways to execute the same calculation:
df_normalised_oxides["Na2O + K2O"] = df_normalised_oxides.eval("Na2O + K2O")
df_normalised_oxides["Na2O + K2O"] = df_normalised_oxides.apply(lambda x: x["K2O"]+x["Na2O"], axis=1)

df_normalised_oxides.head()

Pyrolite provides templates for a number of classical geochemistry diagrams. Let's use pyrolite to make a TAS plot.

In [None]:
from pyrolite.plot.templates import TAS # Middlemost(1994)
from matplotlib.patches import Patch

#Let's colour the data according to the sample material description in the dataframe
rocknames = df_normalised_oxides['sample_material'].unique()
named_colours = ["blue", "orange", "green", "red", "purple"] #these are matplotlib named colours. There are 5 of them, as there are 5 unique rock names in sample_material
color_to_rock=zip(rocknames, named_colours) #zip returns an iterator of tuples. 
color_to_rock_dictionary = dict(color_to_rock) #this converts the zip to a dictionary, with rock name as key and colour as the value

#Now we can use the TAS template provided by pyrolite
ax = TAS(
    linewidth = 0.5, 
    add_labels = True, 
    which_labels = 'volcanic'
)

ax.scatter(
    df_normalised_oxides["SiO2"],
    df_normalised_oxides["Na2O + K2O"],
    c= df_normalised_oxides["sample_material"].map(color_to_rock_dictionary) #we will use the color_to_rock dictionary to MAP the colours to the datapoints in the dataframe, using the rock name as a key
)

legend_elements = [
    Patch(facecolor=color, label=rock)
    for rock, color in color_to_rock_dictionary.items()
]

ax.legend(handles=legend_elements, title='Sample material')

plt.show()   


We can see that the rock descriptions don't always agree with the TAS classification. Pyrolite also provides a TAS classifier, which assigns each data point to the TAS field it falls into. This means we can store the TAS classification of every data point in a dataframe column.

In [None]:
from pyrolite.util.classification import TAS as TAS_clf

TAS_classifier = TAS_clf() #instantiate an instance of the TAS classifier
df_normalised_oxides["TAS"] = TAS_classifier.predict(df_normalised_oxides) # assign the classification to a new column called 'TAS'
df_normalised_oxides["Rocknames"] = df_normalised_oxides.TAS.apply(lambda x: TAS_classifier.fields.get(x, {"name": None})["name"][0])
df_normalised_oxides.head()

In [None]:
#Let's plot the classified data and colour by the new column containing the TAS classification
ax = TAS(linewidth = 0.5, add_labels = True)
sns.scatterplot(
    x=df_normalised_oxides["SiO2"], 
    y=df_normalised_oxides["Na2O + K2O"], 
    ax=ax, 
    hue = df_normalised_oxides['Rocknames'], 
    s = 70, 
    alpha = 0.5, 
    marker = 'o'
)
plt.show()  

Some of the rock samples do not fall in a TAS field and are thus classified as N/A. Let's remove them and replot the TAS diagram.

In [None]:
df_classified = df_normalised_oxides[df_normalised_oxides['Rocknames'] !='N/A'].copy() # .copy() lets us add columns to df_classified later without a SettingWithCopyWarning

ax = TAS(linewidth = 0.5, add_labels = True)
sns.scatterplot(
    x=df_classified["SiO2"], 
    y=df_classified["Na2O + K2O"], 
    ax=ax, 
    hue = df_classified['Rocknames'], 
    s = 70, 
    alpha = 0.5, 
    marker = 'o'
)
plt.show()  


We can define our own legend and color scheme on the TAS diagram. Matplotlib has a long list of [named colors](https://matplotlib.org/stable/gallery/color/named_colors.html), but you can also set colors as RGB values or HEX values. There is also a wide variety of marker types, as well as it being possible to create custom markers.

In [None]:
#Let's create our own colour-to-rocktype mapping (like a legend)
named_colours = [
    "blue", "orange", "green", "red", "purple",
    "brown", "pink", "gray", "olive", "cyan",
    "gold", "teal", "navy", "goldenrod", 'darkseagreen'
]

unique_rocktypes = df_classified["Rocknames"].unique()
colour_map = dict(zip(unique_rocktypes, named_colours))
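
One thing to watch with `dict(zip(...))` is that `zip` stops at the shorter input, so any category beyond the end of the colour list gets no colour. The sketch below uses made-up rock names and a deliberately short palette to show how to detect that.

```python
import pandas as pd

# Made-up classification column and a palette that is one colour short
rocknames = pd.Series(["Basalt", "Andesite", "Basalt", "Dacite"])
palette = ["blue", "orange"]

# zip pairs items until the shorter input runs out, so 'Dacite' gets no entry
colour_lookup = dict(zip(rocknames.unique(), palette))

# .map looks every value up in the dictionary; missing keys become NaN
mapped = rocknames.map(colour_lookup)

# Categories that ended up without a colour
unmapped = rocknames[mapped.isna()].unique().tolist()

print(colour_lookup)
print("No colour assigned to:", unmapped)
```

Checking `unmapped` before plotting avoids a confusing error from the plotting call further down the line.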

In [None]:
# Now that we have the mapping, we can create a legend
# Usually legends are fairly simple. In this case, because we're using a pyrolite function that calls matplotlib, it's a bit harder.
# We will therefore manually create the legend. In most cases, this isn't required. But it's useful to know it can be done!

from matplotlib.patches import Patch #patches are 2D shapes. This will allow us to draw rectangles filled with colour, to make a legend

ax = TAS(linewidth = 0.5, add_labels = True) # get the TAS plot template

ax.scatter( #plot the data on the template
    x=df_classified["SiO2"], 
    y=df_classified["Na2O + K2O"], 
    c = df_classified["Rocknames"].map(colour_map), 
    s = 50, 
    alpha = 0.5, 
    marker = 'o'
)

legend_elements = [
    Patch(facecolor=colour, label=rocktype)
    for rocktype, colour in colour_map.items()
]

ax.legend(handles=legend_elements, title="TAS Rocktype")
plt.show()  

### Example 3.4: Ternary plots with pyrolite

Pyrolite has functionality to produce ternary scatter plots. Simply provide the three fields to plot, and it will understand that you're creating a ternary diagram.

In [None]:
from matplotlib.lines import Line2D

df_classified.loc[:,"FeOt"] = df_classified.loc[:,"Fe2O3"] * 0.8998

fig,ax = plt.subplots()

ax = df_classified[["FeOt","Na2O + K2O", "MgO"]].pyroplot.scatter(
    marker="o",
    s=50,
    alpha=0.5,
    ax=ax,
    c = df_classified["Rocknames"].map(colour_map),
  )
ax.grid()
#ax.grid(axis="t", linestyle="--") #
#ax.set_ternary_lim(0.0, 0.5, 0, 1, 0.0, 1)  # top axis min  # top axis max  # left axis min  #left axis max  # right axis min  # right axis max
plt.tight_layout()
plt.show()
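
The factor 0.8998 used to convert Fe2O3 to FeOt comes from molar masses. The worked calculation below assumes the Fe2O3 column reports total iron as Fe2O3, and uses rounded standard atomic masses.

```python
# Standard atomic masses in g/mol, rounded to three decimals
FE = 55.845
O = 15.999

molar_mass_FeO = FE + O            # one Fe and one O
molar_mass_Fe2O3 = 2 * FE + 3 * O  # two Fe and three O

# One mole of Fe2O3 holds two moles of Fe, which is equivalent to two moles of FeO
fe2o3_to_feo = 2 * molar_mass_FeO / molar_mass_Fe2O3

print(round(fe2o3_to_feo, 4))  # → 0.8998
```

Multiplying a wt% Fe2O3 value by this factor gives the equivalent wt% of iron expressed as FeO.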

In Python, there is often more than one way to achieve an objective! Pyrolite is not the only way to plot a ternary diagram. Let's plot the ternary diagram using some alternative packages:  
- 'mpltern': In fact, pyrolite uses mpltern, which is built on matplotlib.
- 'plotly': Plotly creates interactive plots, and can also make ternary diagrams  



In [None]:
#Ternary plot with mpltern (this is what pyrolite uses, under the hood)

import matplotlib.pyplot as plt
import mpltern
ax = plt.subplot(projection="ternary", ternary_sum=100.0)

ax.scatter(
    df_classified["FeOt"],
    df_classified["Na2O + K2O"], 
    df_classified["MgO"],
    marker="o",
    s=50,
    alpha=0.5,
    c = df_classified["Rocknames"].map(colour_map),
  )
ax.set_tlabel("FeOt")
ax.set_llabel("Na2O + K2O")
ax.set_rlabel("MgO")

ax.grid()

plt.show()


In [None]:
# Interactive ternary plot using plotly
import plotly.express as px

fig = px.scatter_ternary(
    df_classified, 
    a="FeOt", 
    b="Na2O + K2O", 
    c="MgO",
    color='Rocknames',
    color_discrete_map = colour_map #I've added colour_map to set the colouring to match the previous examples. Without this, random colours would be assigned.
)

fig.update_traces(marker=dict(size=12, opacity=0.5)) # Here I'm setting the marker size to match that of the previous examples.

fig.show()

 <font color="green"><b>TRY IT YOURSELF:</b> Plot a ternary diagram of any 3 oxides, using pyrolite, mpltern, or plotly. </font>  

In [None]:
# Write your code here

### Example 3.5: Spidergrams with pyrolite

Let's plot a spidergram of the basalts only. We'll use the 'Subalkalic Basalt' classification assigned from the TAS diagram.  
We'll normalise the REE using pyrolite's built-in datasets. In this case, we'll use the Chondrite_SM89 model of chondrite from Sun & McDonough (1989).

In [None]:
#Filter out the subalkalic basalts
basalt = df_classified[df_classified['Rocknames'].isin(['Subalkalic \nBasalt'])]

#Subset to get REE's only (optional)
basalt_REE = basalt.pyrochem.REE
basalt_REE_normalised = basalt_REE.pyrochem.normalize_to("Chondrite_SM89")
basalt_REE_normalised.head()
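
Normalising to a reference composition is, at its core, an element-by-element division. The sketch below shows that arithmetic in plain pandas. The reference values are hypothetical round numbers chosen to make the result easy to check; they are NOT the published chondrite values.

```python
import pandas as pd

# Made-up REE concentrations in ppm
ree = pd.DataFrame({"La": [5.0, 10.0], "Ce": [12.0, 24.0], "Nd": [10.0, 15.0]})

# Hypothetical reference composition in ppm (not the published chondrite values)
reference = pd.Series({"La": 0.25, "Ce": 0.6, "Nd": 0.5})

# axis=1 aligns the reference values with the dataframe columns by element name
normalised = ree.div(reference, axis=1)

print(normalised)
```

The result is dimensionless: a value of 20 means the sample holds twenty times the reference concentration of that element.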

We can immediately see that there is at least one row where all values are NaN. Let's remove rows where all values are NaN.

In [None]:
basalt_REE_normalised.dropna(axis = 0, how='all', inplace = True)
basalt_REE_normalised.head()
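
The `how` argument decides how strict `dropna` is. A small sketch on made-up data shows the difference between `how='all'` and `how='any'`.

```python
import pandas as pd

nan = float("nan")

# Row 0 is complete, row 1 is entirely missing, row 2 is partly missing
toy = pd.DataFrame({"La": [1.0, nan, 3.0], "Ce": [2.0, nan, nan]})

# how='all' drops a row only when every value in it is missing
all_dropped = toy.dropna(axis=0, how="all")

# how='any' drops a row as soon as a single value is missing
any_dropped = toy.dropna(axis=0, how="any")

print(all_dropped)
print(any_dropped)
```

With `how='all'` the partly measured sample in row 2 is kept, which is usually what is wanted before plotting a spidergram.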

We can also have a quick look at how many null values are in our columns

In [None]:
basalt_REE_normalised.info()

Time to plot the spidergram

In [None]:
fig, ax = plt.subplots()

basalt_REE_normalised.pyroplot.spider(ax=ax)

Let's plot the data again, but this time, we'll colour for MgO.  
This means we need to join the normalised REE dataframe to the oxide data we previously had.  
We can use pandas.merge to merge the two dataframes, using the dataframe index as the key.

In [None]:
Basalt_REE_and_oxide = basalt_REE_normalised.merge(df_classified.pyrochem.oxides, right_index=True, left_index=True, how='inner')
Basalt_REE_and_oxide.head()

In [None]:
from matplotlib import colors, cm

ax = Basalt_REE_and_oxide.pyroplot.spider(
    cmap='plasma_r',
    alpha=0.5,
    color=Basalt_REE_and_oxide["MgO"]
)

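cmap = plt.get_cmap('plasma_r') # use the same colour map as the spidergram, so the colour bar matches the lines
norm = colors.Normalize(
    vmin=Basalt_REE_and_oxide["MgO"].min(),
    vmax=Basalt_REE_and_oxide["MgO"].max())

# Attach the colour bar to the figure that holds the spidergram axes
ax.figure.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label="MgO (wt%)")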

 <font color="green"><b>TRY IT YOURSELF:</b> Plot a spidergram of the Basaltic Andesite samples and colour by K2O </font>  

In [None]:
# Write your code here

In [None]:
# Expand to see how it's done

# 1) filter the dataframe for the andesite samples. 
andesite = df_classified[df_classified['Rocknames'].isin(['Basaltic Andesite'])]

# 2) Select the REE data
andesite_REE = andesite.pyrochem.REE

# 3) Normalise the REE data
andesite_REE_normalised = andesite_REE.pyrochem.normalize_to("Chondrite_SM89")

# 4) Merge the K2O data onto the normalised REE data, using index as key
andesite_REE_and_oxide = andesite_REE_normalised.merge(df_classified[['K2O']], right_index=True, left_index=True, how='inner')

# 5) Plot the spidergram
from matplotlib import colors, cm

ax = andesite_REE_and_oxide.pyroplot.spider(
    cmap='plasma_r',
    alpha=0.5,
    color=andesite_REE_and_oxide["K2O"]
)

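cmap = plt.get_cmap('plasma_r') # use the same colour map as the spidergram, so the colour bar matches the lines
norm = colors.Normalize(
    vmin=andesite_REE_and_oxide["K2O"].min(), # take the colour limits from the andesite data being plotted
    vmax=andesite_REE_and_oxide["K2O"].max())

# Attach the colour bar to the figure that holds the spidergram axes
ax.figure.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label="K2O (wt%)")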

