# Geochemistry

We can combine pandas' capabilities for working with tabular data with the rest of the Python ecosystem. The library `pyrolite`, developed by Morgan Williams (CSIRO), is very useful for working with geochemical data.

As with anything, examples will make it clearer.

In [None]:
# For data processing
import pandas as pd
import numpy as np
from pyrolite.geochem.norm import all_reference_compositions,get_reference_composition

# For visualization
import seaborn as sns
import pyrolite.plot
import matplotlib.pyplot as plt
import os


## Import and clean datasets
We're working with data from the Antrim Volcanic Plateau in the NT

In [None]:
# Raw spreadsheets are read straight from the workshop's GitHub repository.
Litho = pd.read_excel(r"https://raw.githubusercontent.com/pierosampaio/PythonWorkshop/refs/heads/main/GSA-WA%202025%20Data/AntrimData/Antrim%20DH%20geology.xls")
Maj = pd.read_excel(r"https://raw.githubusercontent.com/pierosampaio/PythonWorkshop/refs/heads/main/GSA-WA%202025%20Data/AntrimData/ANT1&2%20original%20majors.xls",
                     header = 1)
TE = pd.read_excel(r"https://raw.githubusercontent.com/pierosampaio/PythonWorkshop/refs/heads/main/GSA-WA%202025%20Data/AntrimData/Ant1&2%20original%20traces.xls")


def _interval_midpoint(intervals):
    """Return the mid-point of "from-to" interval strings, rounded to 2 decimals."""
    return intervals.str.split("-").apply(
        lambda x: (float(x[0]) + float(x[1])) / 2
    ).round(2)


# The trace-element table gets the same "ID" column name that the
# major-element table receives below.
TE = TE.rename({"ELEMENTS":"ID"}, axis = 1)

# Set aside the non-sample rows before the tables are trimmed.
# NOTE(review): the row positions are hard-coded; the variable names suggest
# detection limits, duplicate analyses and standards — confirm against the
# spreadsheets.
Maj_LOD = Maj.iloc[1,3:]
TE_LOD = TE.iloc[1,3:]
TE_Dup = TE.iloc[36:38,:]
TE_Standards = TE.iloc[40:43]

# .copy() makes the trimmed tables independent of the raw ones, so adding the
# "CentrePoint" column below does not write onto a slice
# (avoids pandas' SettingWithCopyWarning).
Maj = Maj.iloc[4:,:].copy()
TE = TE.iloc[5:34,:].copy()
Maj["CentrePoint"] = _interval_midpoint(Maj.Interval)
TE["CentrePoint"] = _interval_midpoint(TE.Interval)
Maj = Maj.rename({'AUSQUEST 22/11/02   251102':"ID"}, axis = 1)
TraceElements = TE.pyrochem.list_elements


 

In [None]:
# Single-row table of reference values, labelled "SY-4"; the numbers are
# assigned positionally to the TraceElements columns.
# NOTE(review): values and their order are hard-coded — confirm them against
# the reference certificate and the column order of TE.
Standards = pd.DataFrame(
    data=[
        [
            340., 122., 2.8, 12, 7, 18.2, 14.2, 2., 14., 10.6,
            4.3, 58., 2.1, 13., 57., 9., 15., 55., 12.7, 1191.,
            0.9, 2.6, 1.4, 2.3, 0.8, 8., 119., 14.8, 93., 517.,
        ]
    ],
    index=pd.Index(["SY-4"], name="ID"),
    columns=TraceElements,
)

In [None]:
def RSD(el1,el2):
    """Return the relative standard deviation (in %) of a pair of values.

    The pair's standard deviation (NumPy default, population form) is divided
    by the pair's mean and multiplied by 100.
    """
    pair = [el1, el2]
    spread = np.std(pair)
    centre = np.mean(pair)
    return spread / centre * 100

def check_duplicates(dataset, duplicates, columns, ID_column = "ID", plot = True):
    """Compare duplicate analyses with their original samples.

    For every row of ``duplicates`` the relative standard deviation
    (RSD, in %) between that row and the first ``dataset`` row with the same
    identifier is computed for each of ``columns``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Original analyses.
    duplicates : pandas.DataFrame
        Duplicate analyses; identifiers must also occur in ``dataset``.
    columns : sequence of str
        Columns (elements) to compare.
    ID_column : str
        Name of the identifier column present in both tables.
    plot : bool
        If true, draw one scatter series per duplicate with a dashed
        reference line at 5 % RSD.

    Returns
    -------
    None

    Raises
    ------
    IndexError
        If a duplicate identifier has no matching row in ``dataset``.
    """
    columns = list(columns)
    dup_ids = duplicates[ID_column].values
    rsd_table = np.empty(shape=(len(dup_ids), len(columns)))

    for i, idx in enumerate(dup_ids):
        data = dataset.loc[dataset[ID_column] == idx, columns]
        dups = duplicates.loc[duplicates[ID_column] == idx, columns]
        # Stack original and duplicate so the RSD of every column is obtained
        # in one vectorised step (std and mean across the pair, axis 0).
        pair = np.array([data.values[0], dups.values[0]], dtype=float)
        rsd_table[i] = np.std(pair, axis=0) / np.mean(pair, axis=0) * 100

    if plot:
        _, ax = plt.subplots(figsize=(10, 4))
        positions = np.arange(len(columns))

        for dup_id, rsd_row in zip(dup_ids, rsd_table):
            ax.scatter(positions, rsd_row, label=dup_id)

        ax.axhline(5, ls="--", color="black")

        # Label ticks with the columns that were actually compared rather
        # than with a slice of a global table.
        ax.set_xticks(positions, labels=columns)

        ax.set(xlabel="Element", ylabel="RSD (%)")
        ax.legend()

def check_standards(Meas_standards, standards, columns, ID_column = "ID"):
    """Plot measured against reference values for each standard.

    Parameters
    ----------
    Meas_standards : pandas.DataFrame
        Measured standards; identifiers are read from ``ID_column``.
    standards : pandas.DataFrame
        Reference values, indexed by standard identifier.
    columns : sequence of str
        Columns (elements) to compare.
    ID_column : str
        Identifier column of ``Meas_standards``.

    Raises
    ------
    IndexError
        If a reference standard has no matching row in ``Meas_standards``.
    """
    _, ax = plt.subplots()

    # Largest value seen over ALL standards; sets the extent of the 1:1 line.
    upper = None

    for idx in standards.index:
        measured = Meas_standards.loc[Meas_standards[ID_column] == idx, columns].values[0]
        reference = standards.loc[standards.index == idx, columns].values[0]

        ax.scatter(measured, reference, label=idx)

        pair_max = max(measured.max(), reference.max())
        upper = pair_max if upper is None else max(upper, pair_max)

    # Dotted 1:1 line from the origin; skipped when there is nothing to plot.
    if upper is not None:
        ax.plot([0, upper], [0, upper], color="black", ls=":")

    ax.set(xlabel="Measured", ylabel="Reference")
    ax.legend()



In [None]:
# Compare the duplicate rows (TE_Dup) with their original samples in TE for every trace element
check_duplicates(TE,TE_Dup,TraceElements)

In [None]:
# Plot the measured standard rows against the reference values held in Standards
check_standards(TE_Standards, Standards, TraceElements)

## Merging the major and trace element datasets

In [None]:
# Join majors and traces on the shared sample "ID". Only the ID and the element
# columns of TE are brought across, so TE's other columns (e.g. its own
# "CentrePoint") do not clash with those of Maj. merge() performs an inner join
# by default: samples missing from either table are dropped.
df = Maj.merge(TE[["ID",*TE.pyrochem.list_elements]], on = "ID")
df

In [None]:
# Summarise column names, non-null counts and dtypes of the merged table
df.info()

We can see that the element columns are all stored with the `object` dtype, which is essentially a catch-all for mixed data types. We want the data to be numeric. The problem is that values below the limit of detection (LOD) are recorded as text starting with "<" (e.g. "<0.5"), which Python cannot interpret as a number. There are multiple ways to overcome this.

In [None]:
# Identifying data < LOD
# Collect each distinct "<..." entry only once (dict.fromkeys de-duplicates
# while keeping first-seen order), so replace() is not handed the same value
# repeatedly.
below_lod = list(dict.fromkeys(
    x for x in df.values.flatten() if str(x).startswith("<")
))

df = df.replace(
    to_replace=below_lod,
    value=np.nan
) # we'll replace values < LOD for nans; that is an option. One could also choose to treat them as half the LOD, etc.
# It is important to note that any choice will introduce some sort of bias to the data

# We can then convert the column types to floating point numbers
df.pyrochem.compositional = df.pyrochem.compositional.astype("float")
df.pyrochem.compositional.info()



## Now we can start plotting our data
### Harker plot
$MgO$ vs. major elements

In [None]:
fig, ax = plt.subplots()

# MgO on the horizontal axis, Al2O3 on the vertical axis
ax.scatter(x=df["MgO"], y=df["Al2O3"], marker="o", color="blue")

# Axis titles are set together in a single call
ax.set(xlabel="$MgO$ (wt%)", ylabel="$Al_{2}O_{3}$ (wt%)")

plt.show()

In [None]:
def Harker(df, element, ax, **kwargs):
    """Draw a Harker-style scatter of ``element`` against MgO on ``ax``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing a "MgO" column and the ``element`` column.
    element : str
        Column plotted on the y-axis.
    ax : matplotlib Axes
        Axes to draw on.
    **kwargs
        Style options passed on unchanged to ``ax.scatter()``.
    """
    # No trailing commas here: they would wrap each Series in a 1-tuple.
    x = df["MgO"]
    y = df[element]

    ax.scatter(x, y, **kwargs)
    ax.set(xlabel="MgO", ylabel=element)


# Any extra keyword argument (here color) is handed by Harker to ax.scatter()
fig, ax = plt.subplots()
Harker(df,"Al2O3",ax, color="red") # color is a kwarg

In [None]:
# Eight major oxides, one Harker diagram each, laid out on a 2 x 4 grid
elements = [
    "Al2O3", "Fe2O3", "CaO", "SiO2",
    "TiO2", "Na2O", "K2O", "P2O5",
]

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(15, 6))

# ravel() walks the 2-D grid of Axes row by row, pairing each panel with an oxide
for ax, element in zip(axes.ravel(), elements):
    Harker(df, element, ax)

fig.tight_layout()

#### Optional questions
- How can we use a different element as x-axis?
- Is there a way to make the x-label appear only on the lower plots?

## Calculating new variables:
With pandas we can not only visualize tabular data, but also evaluate expressions and calculate new variables, among other things. For this example we will calculate total alkalis ($Na_{2}O + K_{2}O$) so that we can later use this variable for classification in a TAS diagram. We can achieve this in multiple ways.

In [None]:
# Three equivalent ways of computing total alkalis; the first two are left
# commented out as alternatives.
#df["Na2O + K2O"] = df["Na2O"] + df["K2O"]   (direct column arithmetic)
#df["Na2O + K2O"] = df.eval("Na2O + K2O")    (expression string)
# Row-wise apply: with axis=1 the lambda receives one row at a time
df["Na2O + K2O"] = df.apply(lambda x: x["K2O"]+x["Na2O"], axis=1)
df["Na2O + K2O"].head()

In [None]:
# Work on an independent copy of the rare-earth-element columns: the ratio
# columns added below then belong to REE alone, and pandas does not warn about
# assigning onto a subset of df (SettingWithCopyWarning).
REE = df.pyrochem.REE.copy()

REE["La/Sm"] = REE.eval("La/Sm")
# using the .eval() method to calculate a new ratio
# REE["La/Sm"] = REE["La"].values/REE["Sm"].values would also work


REE["Nd/10"] = REE["Nd"].apply(lambda x: x/10)
# using the .apply method and an anonymous function to calculate a new variable
# REE["Nd/10"] = REE.eval("Nd/10") also works
# REE["Nd/10"] = REE["Nd"].values/10 also works


REE["Gd/Yb"] = REE.eval("Gd/Yb")

Now we can use the TAS template provided by `pyrolite`

In [None]:
from pyrolite.plot.templates import TAS, pearceThNbYb, pearceTiNbYb

In [None]:
# TAS() hands back the axes holding the template, so the samples can be
# scattered straight onto it: SiO2 on x, total alkalis on y.
ax = TAS(linewidth = 0.5, add_labels = True)

ax.scatter(
    df["SiO2"],
    df["Na2O + K2O"]
)

plt.show()

In [None]:
## other templates
# Two stacked panels; each template function is given the Axes to draw on
fig, (ax1, ax2) = plt.subplots(2,1, figsize = (6,8))

pearceThNbYb(ax1)
pearceTiNbYb(ax2)

In [None]:
from pyrolite.util.classification import TAS as TAS_clf


In [None]:
# Classifier object for the TAS diagram (used again below to label samples)
cm = TAS_clf()

fig, ax = plt.subplots(1)
# Draw the classifier's fields behind the data (zorder=-1), then the samples
cm.add_to_axes(ax, alpha=0.5, linewidth=0.5, zorder=-1, add_labels=True)
df[["SiO2", "Na2O + K2O"]].pyroplot.scatter(ax=ax, axlabels=False)
plt.show()

In [None]:

### This is a technicality that is not strictly necessary

def remove_newlines(data):
    """Replace line breaks with spaces in the 'name' entry of every field.

    ``data`` maps keys to dictionaries that each hold a 'name' entry. Lists of
    names and single string names are cleaned in place; any other type is left
    untouched. Nothing is returned.
    """
    for entry in data.values():
        label = entry['name']
        if isinstance(label, list):
            # several names: clean each string of the list
            entry['name'] = [part.replace('\n', ' ') for part in label]
        elif isinstance(label, str):
            # a single string name (cases like 'nan')
            entry['name'] = label.replace('\n', ' ')

# Clean the field names of the TAS classifier in place
remove_newlines(cm.fields)

In [None]:
df["TAS"] = cm.predict(df)


def _first_rock_name(code):
    """Return the first name of a TAS field, or None if the code is unknown."""
    names = cm.fields.get(code, {}).get("name")
    if isinstance(names, (list, tuple)):
        return names[0] if names else None
    # A plain string name is returned whole (indexing it would keep only its
    # first character); a missing name stays None.
    return names


df["Rocknames"] = df["TAS"].apply(_first_rock_name)
df["Rocknames"].sample(10)  # randomly check 10 sample rocknames

In [None]:
fig, ax = plt.subplots(1)
# Classification fields in the background, samples on top coloured by TAS code
cm.add_to_axes(ax, alpha=0.5, linewidth=0.5, zorder=-1, add_labels=True)
df[["SiO2", "Na2O + K2O"]].pyroplot.scatter(ax=ax, c=df["TAS"]) # Just set a random color for each unique code

We can define our own legend and color scheme. `matplotlib` has a long list of [named colors](https://matplotlib.org/stable/gallery/color/named_colors.html), but you can also set colors as RGB or HEX values. There is also a wide variety of marker types, and it is possible to create custom markers.

In [None]:
# First let's check which lithologies are present, and how many samples each has
df["Rocknames"].value_counts()

In [None]:
# One named colour per rock name, paired in order of first appearance.
# NOTE(review): zip() stops at the shorter input, so only the first two rock
# names receive a colour — this assumes the classification yields exactly two
# names; verify against df["Rocknames"].value_counts().
Colors = [
    "forestgreen",
    "slategray"
]
Color_dict = dict(zip(df["Rocknames"].unique(),Colors))
Color_dict

## Filtering:
We can also use pandas to filter datasets quickly using logical operators.

In [None]:
fig, ax = plt.subplots()
ax.grid(alpha=0.5)

# One scatter call per rock name. The name selects both the rows to plot and,
# through Color_dict, the colour used for them.
for rock in df["Rocknames"].unique():
    subset = df.loc[df["Rocknames"] == rock]  # rows of this rock name only
    ax.scatter(
        subset["MgO"],
        subset["CaO"],
        label=rock,
        s=75,
        marker="D",
        color=Color_dict[rock],
    )

ax.set(xlabel="MgO", ylabel="CaO")

ax.legend(title="Rock name", frameon=True)

In [None]:
# We can also use numerical filters and combine different filters using & (and), and | (or)
# Each comparison is wrapped in parentheses before being combined with &
df_filtered = df.loc[(df["MgO"] >= 5) & (df["La"] >= 20)]
df_filtered

In [None]:
# SQL-type query: the condition is written as a string that refers to column names
df.query("MgO >= 5")

In [None]:
# For columns with string entries we can check if the string
# contains a certain sub-string.
# na=False counts missing rock names as "no match"; without it the mask would
# contain missing values and .loc would refuse to use it.
df.loc[df["Rocknames"].str.contains("Trachy", na=False)]

In [None]:
from matplotlib.cm import get_cmap
import matplotlib.colors as mcolors

### Ternary plots with pyrolite

In [None]:
# Total iron expressed as FeO: Fe2O3 is scaled by 2*M(FeO)/M(Fe2O3)
# = 2*71.844/159.688 ≈ 0.8998.
df["FeOt"] = df["Fe2O3"].values * 0.8998


fig,ax = plt.subplots()

# One ternary scatter per rock name; the three columns passed to
# pyroplot.scatter are the F, A and M components.
for rock in df["Rocknames"].unique():
  df.loc[df["Rocknames"]==rock,["FeOt","Na2O + K2O","MgO"]].pyroplot.scatter(
      color=Color_dict[rock],
      marker="D",
      s=75,
      alpha=0.7,
      edgecolor="black",
      ax=ax
  )

# Irvine & Baragar's classification paper was published in 1971
fig.suptitle("AFM (Irvine and Baragar, 1971)")

plt.tight_layout()
plt.show()

### Spidergrams

In [None]:
# pyrolite also helps with normalization:
# select the REE columns, normalise them to the "Chondrite_SM89" reference
# composition and draw all samples as a spidergram

df.pyrochem.REE.pyrochem.normalize_to("Chondrite_SM89").pyroplot.spider(alpha=0.3)

In [None]:
fig, ax = plt.subplots()

# Same normalised REE spidergram, drawn one rock name at a time on a shared
# Axes so that each group gets its colour from Color_dict
for rock in df["Rocknames"].unique():
  df.loc[df["Rocknames"]==rock].pyrochem.REE.pyrochem.normalize_to("Chondrite_SM89").pyroplot.spider(
      ax=ax,
      color=Color_dict[rock],
      marker="D",
      alpha = 0.5
  )

# Fixed y-range for the normalised values
ax.set_ylim(0.9,500)

In [None]:
# We can adopt the same workflow for other incompatible elements
# Oxides are converted to their elements first.
# NOTE(review): the factor 1e4 presumably converts wt% to ppm so the values
# match the units of the other trace elements — confirm.
df["Ti"] = df[["TiO2"]].pyrochem.convert_chemistry(to=["Ti"]).values*1e4
df["P"] = df[["P2O5"]].pyrochem.convert_chemistry(to=["P"]).values*1e4
df["K"] = df[["K2O"]].pyrochem.convert_chemistry(to=["K"]).values*1e4



# Elements in the left-to-right order in which they are passed to the spidergram
TraceElementList = [
    "Rb","Ba","Th","U","Nb","K","La","Ce","Pr",
    "Sr","P","Nd","Sm","Zr","Hf","Eu","Ti","Dy",
    "Y","Yb","Lu"
]

fig, ax = plt.subplots();

# One spidergram per rock name, normalised to the "NMORB_SM89" reference
# composition, with a reference line at unity
for rock in df["Rocknames"].unique():
  df.loc[df["Rocknames"]==rock,TraceElementList].pyrochem.normalize_to("NMORB_SM89").pyroplot.spider(
      ax=ax,
      color=Color_dict[rock],
      marker="D",
      unity_line=True,
      alpha = 0.5
  );

ax.set_ylim(0.1,300);

We also have the option to color according to a continuous variable

In [None]:
from matplotlib import colors, cm

# Spidergram whose lines are coloured by each sample's MgO content
ax = df.pyrochem.REE.pyrochem.normalize_to("Chondrite_SM89").pyroplot.spider(
    cmap='plasma_r',
    alpha=0.5,
    color=df["MgO"]
)

# The colorbar must use the SAME (reversed) colormap as the lines, otherwise
# its colours run opposite to the plotted data.
cmap = cm.plasma_r
norm = colors.Normalize(
    vmin=df["MgO"].min(),
    vmax=df["MgO"].max())

# Attach the colorbar to the figure that owns this Axes, not to a `fig`
# variable left over from an earlier cell.
ax.figure.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label="MgO (wt%)")

In [None]:
# Display the rock name assigned to each sample
df.Rocknames

### Pairplot for quickly exploring correlations

In [None]:
# Compositional column names; as this is not the last expression of the cell
# the result is not displayed
df.pyrochem.compositional.columns

# Variables to compare, plus the column used to colour the points
cols = ["MgO", "SO3", "Co", "Ni", "Cu", "Zn", "Rocknames"]

import seaborn as sns

# corner=True draws only the lower triangle of the grid of pairwise plots
sns.pairplot(df[cols], hue = "Rocknames", palette = Colors, corner = True)

As the drilling campaign was carried out for Ni-Cu exploration, we can evaluate the grades of these elements in each drill core.

In [None]:
# One panel per drill hole, all sharing the depth axis
locations = df.Location.unique()

# squeeze=False always returns a 2-D array of Axes, so the code below also
# works when there is a single drill hole (otherwise a bare Axes is returned
# and neither zip() nor axes[0] would work).
fig, axes = plt.subplots(1, len(locations), figsize = (6,10), sharey = True, squeeze = False)
axes = axes.ravel()

for ax, core in zip(axes, locations):
    # samples of this hole, ordered by the mid-point of their interval
    dff = df.loc[df.Location == core].sort_values("CentrePoint")

    ax.plot(dff.Ni, dff.CentrePoint, marker="o", label="Ni (ppm)")
    ax.plot(dff.Cu, dff.CentrePoint, marker="o", label="Cu (ppm)")
    ax.axvline(100, ls="--", label="grade threshold", color="black")

    ax.set_title(core)
    ax.set_xlabel("Ni or Cu (ppm)")

# Legend
axes[0].legend(bbox_to_anchor=(0.01, 0.1))

# The y-axis is shared, so inverting it once makes depth increase downwards
# in every panel
axes[0].invert_yaxis()
axes[0].set_ylabel("Depth (m)")




We can also use python to visualize the stratigraphic chart

In [None]:
# Lithology log of drill hole ANT002 only
ANT002 = Litho.loc[Litho.hole == "ANT002"]
# Create column for top and base
ANT002 = ANT002.rename({"From":"Top","To":"Base"}, axis = 1)
# Thickness of each logged interval
ANT002["Thickness"] = ANT002.eval("Base - Top").round(2)

# Lithology classes present in this hole
ANT002["summary geol"].unique()


In [None]:
import matplotlib.patches as mpatches

In [None]:

# Define colors for each unique lithology
unique_geol = ANT002["summary geol"].unique()
colors = [
    "sandybrown","lavender","navajowhite","darkseagreen","khaki",
    "powderblue","yellowgreen","slateblue"
]
# The modulo wraps around the colour list if there are more lithologies than colours
color_map = {geol: colors[i % len(colors)] for i, geol in enumerate(unique_geol)}

# Plot the stratigraphic column (left) next to the Ni profile (right)
fig, (ax,ax2) = plt.subplots(1,2,figsize=(5, 12), sharey = True)

for _, row in ANT002.iterrows():
    ax.fill_betweenx([row["Top"], row["Base"]], x1=0, x2=1, color=color_map[row["summary geol"]])
    # str() keeps the substring test working when the description is missing
    # (a missing cell is a float NaN, which does not support `in`)
    if row["summary geol"] == "basalt" and "Amyg" in str(row["Description"]):
        ax.axhline(row["Top"], color = "black", xmax = 0.7)
        ax.text(1.05, (row["Top"] + row["Base"]) / 2, "amg", va='top', fontsize=8)

# Invert y-axis to have the top at the top
ax.set_ylim(ANT002["Base"].max(), 0)
ax.set_xlim(0, 1.5)
ax.set_ylabel("Depth (m)")
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
# Add legend
legend_patches = [mpatches.Patch(color=color_map[geol], label=geol) for geol in unique_geol]
ax2.legend(handles=legend_patches, bbox_to_anchor=(1.5, 1), loc='upper left')
ax.set_title("ANT002")


# Ni profile of the same hole, ordered by the mid-point of each sampled interval
dff = df.loc[df.Location == "ANT002"].sort_values("CentrePoint")

ax2.plot(dff.Ni, dff.CentrePoint, marker="o", label="Ni (ppm)", color = "tab:red")

ax2.spines['left'].set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.tick_params(axis='y', which='both', bottom=False, top=False, labelbottom=False, labeltop = True)
ax2.set_title("Ni (ppm)")
ax2.grid(axis = "x")


plt.tight_layout()
plt.show()