# Fatbox for Analogue Modelling - Fault extraction from PIV-derived elevation, batch - tuto 2/6

This example describes how to extract a 2-D fault network from an analogue model simulating orthogonal continental rifting, as done in Tutorial B1, but for all timesteps in our dataset.

The analog models are based on the study of Molnar et al. (2017)

### Load packages
To run the toolbox, we will need a couple of packages including the toolbox itself. So let's install it:

In [1]:
from google.colab import drive
drive.mount('/content/drive')
!pwd

Mounted at /content/drive
/content


In [2]:
!pip install earthpy
!pip install cv-algorithms
!pip install vtk

Collecting earthpy
  Downloading earthpy-0.9.4-py3-none-any.whl.metadata (9.2 kB)
Collecting rasterio (from earthpy)
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio->earthpy)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio->earthpy)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio->earthpy)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading earthpy-0.9.4-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m34.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7

Now we can load the python packages that we need:

In [3]:
import numpy as np
import networkx as nx
import cv2
import pickle

import copy

import math
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable

from skimage import data, io, filters, measure, feature, color, morphology
from skimage.morphology import skeletonize
from skimage.util import invert, img_as_ubyte

from scipy.spatial import distance_matrix
from scipy import interpolate
from scipy.interpolate import griddata
from scipy.signal import fftconvolve, savgol_filter
from scipy import ndimage as ndi

from numpy import array

from ipywidgets import Layout, interactive, widgets
from tqdm import tqdm


import earthpy as et
import earthpy.spatial as es
import earthpy.plot as ep

from pathlib import Path #The gestion of Path with pathlib allows for universal use.
import os
path_folder=Path('/content/drive/MyDrive/Fatbox')
path_modules=path_folder/'modules'
#print(path_folder) #make sure path_folder = '/Fatbox/modules'
os.chdir(path_modules) # make modules as working directory
#type pwd in console and make sure it is '/Fatbox/modules'

import preprocessing
import metrics
import plots
import utils
import structural_analysis
import edits


save_plots=False
loc_plots=Path(path_folder/'tutorials'/'analog'/'plots'/'B2')
my_dpi=100

if not (path_folder/'tutorials'/'analog'/'plots'/'B2').is_dir():
    (path_folder/'tutorials'/'analog'/'plots'/'B2').mkdir(parents=True,exist_ok=True)


### Load data
In this example, we have a dataset consisting of 16 DEM numpy arrays obtained from cleaning and cropping the data exported from the PIV software

In [4]:
path_input=Path(path_folder/'tutorials'/'analog'/'data_analog')

dems = [] #all dem arrays will be stored here

for n in range(1,17): #because there are 16 arrays, dem1.npy through dem16.npy
  data = np.load((path_input/str('dem' + str(n) + '.npy')))
  data = data[:,:-1] #added this line because the canny edge detection algorithm was generating an unreal 'fault' in the right hand side of the arrays
  dems.append(data)



## 1. Fault Extraction

First let's plot our raw data to see what are we looking at, for all timesteps:

In [5]:
#First we define a few variables to calculate a hillshade for our DEMs
az = 300 #azimuth of the light for the hillshade
alt = 1 #altitute (degrees) of the light for the hillshade

def f(time):
    plt.figure(figsize=(6,10)) #set up plot
    mydem = dems[time] #get the dem values for each timestep
    hillshade = es.hillshade(mydem, azimuth=az, altitude=alt) #create hillshade for better visualisation

    plt.title('DEM - Raw')
    plt.imshow(mydem, cmap='gist_earth', vmin=-8, vmax=2) #try 'cubehelix' as an alternative
    plt.colorbar()
    plt.imshow(hillshade, cmap='Greys', alpha=0.3)
    plt.show()

interactive_plot = interactive(f, time=widgets.IntSlider(min=0, max=15, step=1, layout=Layout(width='700px')))
output = interactive_plot.children[-1]
output.layout.width = '800px'
interactive_plot

interactive(children=(IntSlider(value=0, description='time', layout=Layout(width='700px'), max=15), Output(lay…

Before moving on to the image analysis and fault detection, we can apply a Gaussian blur filter to smooth the surface a little bit in order to get better results (especially considering that PIV data is usually exported with certain irregularities) by reducing its noise:

In [6]:
def gaussian_blur(in_array, size):
    # expand in_array to fit edge of kernel
    padded_array = np.pad(in_array, size, 'symmetric')
    # build kernel
    x, y = np.mgrid[-size:size + 1, -size:size + 1]
    g = np.exp(-(x**2 / float(size) + y**2 / float(size)))
    g = (g / g.sum()).astype(in_array.dtype)
    # do the Gaussian blur
    return fftconvolve(padded_array, g, mode='valid')

In [7]:
blurfact = 12 #the higher this number is, the more we will smooth the DEM

blurdems = [] #create a list where blurred dems will be stored

for n in range(0,16): #because there are 16 arrays
  blurdem = gaussian_blur(dems[n], blurfact)
  blurdems.append(blurdem)

We plot the smoothed DEMs to see if hillshade effect works:

In [8]:
def f(time):
    plt.figure(figsize=(6,10)) #set up plot
    mydem = blurdems[time] #get the dem values for each timestep
    hillshade = es.hillshade(mydem, azimuth=az, altitude=alt) #create hillshade for better visualisation, same variables as before

    plt.title('DEM - Blur')
    plt.imshow(mydem, cmap='gist_earth', vmin=-8, vmax=2) #try 'cubehelix' as an alternative?
    plt.colorbar()
    plt.imshow(hillshade, cmap='Greys', alpha=0.4)
    plt.show()

interactive_plot = interactive(f, time=widgets.IntSlider(min=0, max=15, step=1, layout=Layout(width='700px')))
output = interactive_plot.children[-1]
output.layout.width = '800px'
interactive_plot

interactive(children=(IntSlider(value=0, description='time', layout=Layout(width='700px'), max=15), Output(lay…

Next step is to apply a canny edge detection, which basically consists of detecting a change in colors in an image (in this case, that change will be equivalent to a fault):

In [9]:
#These parameters can and should be modified depending on the dataset.
#You can manually vary them and compare the results!
sigma = 2.7 #lower number, more faults detected but more noise too
lowth = 0.15
highth = 0.2

def f(time):

    rawdem = dems[time] #get the raw dem for each timestep
    blurdem = blurdems[time] #get the blurred dem for each timestep
    rawhillshade = es.hillshade(rawdem, azimuth=az, altitude=alt) #create hillshade for better visualisation
    hillshade = es.hillshade(blurdem, azimuth=az, altitude=alt) #create hillshade for better visualisation
    faults = feature.canny(blurdem, sigma=sigma, low_threshold=lowth, high_threshold=highth)
    faultsraw = feature.canny(rawdem, sigma=sigma, low_threshold=lowth, high_threshold=highth)

    fig, axs = plt.subplots(1, 4, figsize=(18, 12), sharex=True, sharey=True) # set up plot

    axs[0].imshow(rawdem, cmap='gist_earth', vmin=-8, vmax=2) #plot raw DEM
    axs[0].set_title('DEM - Original', fontsize=12)
    axs[0].imshow(rawhillshade, cmap='Greys', alpha=0.4)

    axs[1].imshow(faultsraw, cmap='Greys') #plot what we detect from applying canny edge detection to raw DEM
    axs[1].set_title('Detected edges', fontsize=12)

    axs[2].imshow(blurdem,cmap='gist_earth', vmin=-8, vmax=2) #plot smoothed DEM
    axs[2].set_title('DEM - Smoothed', fontsize=12)
    axs[2].imshow(hillshade, cmap='Greys', alpha=0.4)

    axs[3].imshow(faults, cmap='Greys') #plot what we detect from applying canny edge detection to smoothed DEM
    axs[3].set_title('Detected edges', fontsize=12)

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()


interactive_plot = interactive(f, time=widgets.IntSlider(min=0, max=15, step=1, layout=Layout(width='700px')))
output = interactive_plot.children[-1]
output.layout.width = '800px'
interactive_plot

interactive(children=(IntSlider(value=0, description='time', layout=Layout(width='700px'), max=15), Output(lay…

______
From the figure we can easily notice what is the effect of applying the blur filter to our DEM data.

While technically it involves slightly modifying the elevation data, the result after applying the canny edge detection shows that faults are identified in the exact same position as for the original DEM.

We therefore keep this step as it help us clean the data before converting the lines into a graph using Fatbox.

Note that until now, the canny edge detection is only performed within function f, which was defined to plot, compare and understand how it works. In the next series of steps, we will incorporate the canny edge detection into a for loop so that we can store the soon-to-be faults into a list.
______

## Converting 'image' of faults to network (nodes and edges)

Another important point is that while the fault traces look like if they were 1-pixel thick lines, they are not.
We must perform a skeletonization to make sure we are reducing everything to nodes (1 pixel) in the next step.

**Following cells do the following:**

1) Create an empty list where the graphs resulting from each skeletonised image will be stored

2) Loop through all timesteps to do the following actions in one single loop:

Skeletonize >> lines into points >> then points are turned into nodes of graph G

3) Plot graphs for each timestep to understand what is going on


In [10]:
# First we need to define a set of auxiliary functions to calculate strike difference,
# which is used as a parameter to clean the network

# function 1
def mystrike(x1, y1, x2, y2):
  if (x2-x1)<0:
    strike = math.degrees(math.atan2((x2-x1),(y2-y1))) + 360
  else:
    strike = math.degrees(math.atan2((x2-x1),(y2-y1)))

  #Scale to [0, 180]
  if strike<=180:
    return strike
  else:
    return strike - 180

#function 2
def calculate_mystrike(G, non):
    """ Compute strike of fault network

    Parameters
    ----------
    G : nx.graph
        Graph containing edges
    non: int
        Number of neighbors

    Returns
    -------
    G : nx.graph
        Graph containing edges with 'strike' attribute
    """

    # Assertions
    assert isinstance(G, nx.Graph), 'G is not a NetworkX graph'

    for node in tqdm(G, desc='Calculate mystrike'):


        neighbors = nx.single_source_shortest_path_length(G, node, cutoff=non)


        neighbors = sorted(neighbors.items())

        first = neighbors[0][0]
        last = neighbors[-1][0]

        # print(node)
        # print(neighbors)
        # print(first, last)



        x1 = G.nodes[first]['pos'][0]
        y1 = G.nodes[first]['pos'][1]

        x2 = G.nodes[last]['pos'][0]
        y2 = G.nodes[last]['pos'][1]


        G.nodes[node]['strike'] = mystrike(x1, y1, x2, y2)


    for edge in G.edges:
        G.edges[edge]['strike'] = (G.nodes[edge[0]]['strike'] + G.nodes[edge[0]]['strike'])/2

    return G

#function 3
def calculate_diff_strike(G, non):
    """ Compute strike difference between nodes of fault network

    Parameters
    ----------
    G : nx.graph
        Graph containing edges
    non: int
        Number of neighbors

    Returns
    -------
    G : nx.graph
        Graph containing nodes with 'strike' attribute
    """


    # Assertions
    assert isinstance(G, nx.Graph), 'G is not a NetworkX graph'


    for node in G:

        neighbors = nx.single_source_shortest_path_length(G, node, cutoff=non)
        #print(neighbors)
        #print(len(neighbors))
        strikes = [G.nodes[node]['strike'] for node in neighbors.keys()]
        #print(strikes)
        #print(len(strikes))
        if len(neighbors) > 1:
          G.nodes[node]['strike_diff'] = np.max(np.diff(strikes))
        else:
          G.nodes[node]['strike_diff'] = np.nan


    return G

# function 4

def remove_nodes_between(G, attribute, low, high):
    """ Remove node with attribute between two values

    Parameters
    ----------
    G : nx.graph
        Graph
    attribute : str
        Attribute
    low : float
        lower value
    high : float
        higher value

    Returns
    -------
    G : nx.graph
        Graph
    """
    removals = []
    # Assertions
    assert isinstance(G, nx.Graph), "G is not a NetworkX graph"

    # Calculation
    for node in G.nodes:
        if G.nodes[node][attribute] >= low and G.nodes[node][attribute] <= high:
            removals.append(node)

    G.remove_nodes_from(removals)

    return G

Now we run an **'all-in-one'** loop in which we will convert all the original DEMs to relatively clean fault networks:

In [11]:
G_dem = [] #all graphs derived from DEM arrays will be stored here

#define factors and parameters again if they weren't defined earlier

#blurfact = 7 #the higher this number is, the more blurry the image will be
#sigma = 3 #2.5 original value, worked fine, don't erase
#lowth = 0.15 #0.15 original value, worked fine, don't erase
#highth = 0.2 #0.2 original value, worked fine, don't erase

non1 = 2 #number of neighbour edges taken into account to calculate edge strike
non2 = 12 #number of neighbour edges taken into account to calculate edge strike difference
f_thr = 45 #fault length threshold

for n in range(0,16):

  print('')
  print('Timestep ' + str(n+1) + '/16')
  print('Experiment time: ' + str((n+3)*30) + ' minutes') #Start printing a little summary for each timestep

  blurdem = gaussian_blur(dems[n], blurfact) #grab dem data and apply blur filter
  faults = feature.canny(blurdem, sigma=sigma, low_threshold=lowth, high_threshold=highth) #run canny edge detection
  faultsx = img_as_ubyte(faults) #convert faults arrays to another format, otherwise the skeleton fx won't work
  skeleton = preprocessing.skeleton_guo_hall(faultsx) #skeletonize the image we created after applying the canny edge detection
  ret, markers = cv2.connectedComponents(skeleton) #get connected components from the skeletonised image
  points = preprocessing.array_to_points(skeleton) #convert lines to points

  G = nx.Graph() #create empty graph where nodes obtained from the skeletonized image will be stored

  node = 0
  for comp in range(1,ret):
    points = np.transpose(np.vstack((np.where(markers==comp))))

    for point in points:
        G.add_node(node)
        G.nodes[node]['pos'] = (point[1], point[0])
        G.nodes[node]['component'] = comp
        node += 1

  for comp in range(1,ret):

    points = [G.nodes[node]['pos'] for node in G if G.nodes[node]['component']==comp]
    nodes  = [node for node in G if G.nodes[node]['component']==comp]

    dm = distance_matrix(points, points)

    for o in range(len(points)):

        for m in range(len(points)):
            if dm[o,m]<1.5 and o != m:
                G.add_edge(nodes[o],nodes[m])

  for node in list(G.nodes()):
    if G.degree(node) == 3 or G.degree(node) == 4 or G.degree(node) == 5: #remove nodes with 3, 4, 5 edges
        edges = list(G.edges(node))
        G.add_edge(edges[0][1], edges[1][1])
        G.remove_node(node)

  G = calculate_mystrike(G, non1) #we calculate the strike for each node/edge
  G = calculate_diff_strike(G, non2) #we calculate the difference in strike between one edge and the neighbour edge at distance=non2 (number of neighbours)
  G = remove_nodes_between(G, attribute='strike_diff', low=60, high=120) #remove all nodes that have a strike difference between 'low' and 'high'

  G = edits.label_components(G)

  for node in G:
    G.nodes[node]['fault'] = G.nodes[node]['component'] #make component ID and fault ID match

  def get_fault_labels(G):
    labels=set()
    for nodex in G:
        labels.add(G.nodes[nodex]['fault'])
    return sorted(list(labels))

  fault_labels = metrics.get_fault_labels(G)

  f_len = metrics.calculate_fault_lengths(G, mode='get') #get fault lengths for cleaning them by length in the following step
  f_id_clean = np.argwhere(f_len<f_thr) #get faults with a length smaller than the threshold

  #print a little summary of each timestep to see how things are working out
  print("Total number of faults: " + str(len(f_len)))
  print("Faults shorter than " + str(f_thr) + " : " + str(len(f_id_clean)))
  print("Remaining faults: " + str(len(f_len)-len(f_id_clean)))

  for e in range(len(f_id_clean)): #number of times we are going to loop this, which is the number of faults that meet our condition
    fid = f_id_clean.item(e) #get a temporary variable, fid = fault id
    for node in G.copy():
      if G.nodes[node]['fault'] == fid: #if the fault id's match
        G.remove_node(node) #remove all nodes associated to that fault id

  G = edits.label_components(G)

  G_dem.append(G) #after finishing this loop, adds the new graph to the G_dem list

def f(time):
    H = G_dem[time] #grab graph corresponding to each timestep
    H = edits.label_components(H) #first we create the labels of the components in our graph
    mydem = blurdems[time]
    hillshade = es.hillshade(mydem, azimuth=az, altitude=alt) #create hillshade for better visualisation

    fig, axs = plt.subplots(1, 1, figsize=(6,10)) #set up plot

    axs.set_title('Clean network + DEM Hillshade')
    plt.imshow(hillshade, cmap='Greys', alpha=0.2)
    plt.imshow(mydem, cmap='gist_earth', vmin=-8, vmax=2, alpha=0.2)
    axs.set_ylim([1000,0])
    axs.set_xlim([0,400])
    plots.plot_components(H, node_size=1, ax=axs, label=True) #then we plot them
    plt.show()

interactive_plot = interactive(f, time=widgets.IntSlider(min=0, max=15, step=1, layout=Layout(width='700px')))
output = interactive_plot.children[-1]
output.layout.width = '800px'
interactive_plot


Timestep 1/16
Experiment time: 90 minutes


Calculate mystrike: 100%|██████████| 2406/2406 [00:00<00:00, 98042.31it/s]


Total number of faults: 183
Faults shorter than 45 : 176
Remaining faults: 7

Timestep 2/16
Experiment time: 120 minutes


Calculate mystrike: 100%|██████████| 2530/2530 [00:00<00:00, 61601.09it/s]


Total number of faults: 187
Faults shorter than 45 : 182
Remaining faults: 5

Timestep 3/16
Experiment time: 150 minutes


Calculate mystrike: 100%|██████████| 2501/2501 [00:00<00:00, 98164.48it/s]


Total number of faults: 190
Faults shorter than 45 : 185
Remaining faults: 5

Timestep 4/16
Experiment time: 180 minutes


Calculate mystrike: 100%|██████████| 2600/2600 [00:00<00:00, 109332.89it/s]


Total number of faults: 198
Faults shorter than 45 : 194
Remaining faults: 4

Timestep 5/16
Experiment time: 210 minutes


Calculate mystrike: 100%|██████████| 3087/3087 [00:00<00:00, 102656.98it/s]


Total number of faults: 230
Faults shorter than 45 : 226
Remaining faults: 4

Timestep 6/16
Experiment time: 240 minutes


Calculate mystrike: 100%|██████████| 3791/3791 [00:00<00:00, 93281.12it/s]


Total number of faults: 232
Faults shorter than 45 : 216
Remaining faults: 16

Timestep 7/16
Experiment time: 270 minutes


Calculate mystrike: 100%|██████████| 4995/4995 [00:00<00:00, 84157.00it/s]


Total number of faults: 283
Faults shorter than 45 : 260
Remaining faults: 23

Timestep 8/16
Experiment time: 300 minutes


Calculate mystrike: 100%|██████████| 5774/5774 [00:00<00:00, 106819.06it/s]


Total number of faults: 251
Faults shorter than 45 : 220
Remaining faults: 31

Timestep 9/16
Experiment time: 330 minutes


Calculate mystrike: 100%|██████████| 6455/6455 [00:00<00:00, 105947.02it/s]


Total number of faults: 225
Faults shorter than 45 : 200
Remaining faults: 25

Timestep 10/16
Experiment time: 360 minutes


Calculate mystrike: 100%|██████████| 7336/7336 [00:00<00:00, 109912.35it/s]


Total number of faults: 217
Faults shorter than 45 : 196
Remaining faults: 21

Timestep 11/16
Experiment time: 390 minutes


Calculate mystrike: 100%|██████████| 7963/7963 [00:00<00:00, 104255.02it/s]


Total number of faults: 224
Faults shorter than 45 : 197
Remaining faults: 27

Timestep 12/16
Experiment time: 420 minutes


Calculate mystrike: 100%|██████████| 8673/8673 [00:00<00:00, 59505.21it/s]


Total number of faults: 271
Faults shorter than 45 : 248
Remaining faults: 23

Timestep 13/16
Experiment time: 450 minutes


Calculate mystrike: 100%|██████████| 9145/9145 [00:00<00:00, 111666.25it/s]


Total number of faults: 246
Faults shorter than 45 : 218
Remaining faults: 28

Timestep 14/16
Experiment time: 480 minutes


Calculate mystrike: 100%|██████████| 9555/9555 [00:00<00:00, 62367.74it/s]


Total number of faults: 263
Faults shorter than 45 : 239
Remaining faults: 24

Timestep 15/16
Experiment time: 510 minutes


Calculate mystrike: 100%|██████████| 10093/10093 [00:00<00:00, 63226.78it/s]


Total number of faults: 255
Faults shorter than 45 : 226
Remaining faults: 29

Timestep 16/16
Experiment time: 540 minutes


Calculate mystrike: 100%|██████████| 10640/10640 [00:00<00:00, 87534.90it/s]


Total number of faults: 310
Faults shorter than 45 : 283
Remaining faults: 27


interactive(children=(IntSlider(value=0, description='time', layout=Layout(width='700px'), max=15), Output(lay…

**And here we can see a difference with the previous raw plots!**

Faults are broken down into smaller segments first, and then the fault length threshold filter is applied, ending up in a cleaner network.
Some minor tweaks will improve the issue with certain faults that are not exactly being detected, especially at early stages.

.

________

**But there is still some work to do...**

If we pay attention to the fault IDs, we will realise that they are not correlated within timesteps. Let's solve this.

## 2. Correlating between timesteps

To correlate faults across time steps, we want to check if a fault from the time step 0 is within a fault from time step 1 and vice versa. This allows us to correlate faults even if they merge or split up.

*(See [Tutorial B1] for an intro on how to do it with this dataset)*

To correlate faults across time steps, we want to check how similar each fault from time step 0 is to each fault from time step 1. This will allow us to correlate fault even if they merge or split up between time steps. To do this, we first calculate the similarity and then correlate faults if their similarity is above a certain threshold:

In [12]:
def get_nodes(G):
    labels = metrics.get_fault_labels(G)
    point_set=[]
    for label in labels:
        G_fault = metrics.get_fault(G, label)
        points = []
        for node in G_fault:
            points.append(G_fault.nodes[node]['pos'])
        point_set.append(points)
    return point_set


def compute_similarity(set_A, set_B):
      distances = np.zeros((len(set_A), len(set_B)))
      for n, pt_0 in enumerate(set_A):
          for m, pt_1 in enumerate(set_B):
              distances[n,m] = math.sqrt((pt_0[0]-pt_1[0])**2 + (pt_0[1]-pt_1[1])**2)
      return np.mean(np.min(distances, axis=1))


def correlation_slow(G_0, G_1, R):
    # A function which labels the faults in G_1 according to G_0 using the
    # minimum radius R

    # Get labels and nodes
    fault_labels_0 = metrics.get_fault_labels(G_0)
    fault_labels_1 = metrics.get_fault_labels(G_1)

    nodes_0 = get_nodes(G_0)
    nodes_1 = get_nodes(G_1)

    # Compute similarities
    smf = np.zeros((len(fault_labels_0), len(fault_labels_1)))
    smb = np.zeros((len(fault_labels_1), len(fault_labels_0)))

    for n in tqdm(range(len(fault_labels_0)), desc='   Compute similarities'):
        for m in range(len(fault_labels_1)):
            smf[n,m] = compute_similarity(nodes_0[n], nodes_1[m])
            smb[m,n] = compute_similarity(nodes_1[m], nodes_0[n])

    # Determine correlations
    correlations = set()
    for n in tqdm(range(len(fault_labels_0)), desc='   Find correlations'):
        for m in range(len(fault_labels_1)):
            if smf[n,m] < R:
                correlations.add((fault_labels_0[n], fault_labels_1[m]))
            if smb[m,n] < R:
                correlations.add((fault_labels_0[n], fault_labels_1[m]))

    return correlations, smf, smb

In [None]:
def G_to_pts(G):
    labels = metrics.get_fault_labels(G)
    point_set=[]
    for label in labels:
        G_fault = metrics.get_fault(G, label)
        points = []
        for node in G_fault:
            points.append(G_fault.nodes[node]['pos'])
        point_set.append(points)
    return point_set

def is_A_in_B(set_A, set_B, R):
      distances = np.zeros((len(set_A), len(set_B)))
      for n, pt_0 in enumerate(set_A):
          for m, pt_1 in enumerate(set_B):
              distances[n,m] = math.sqrt((pt_0[0]-pt_1[0])**2 + (pt_0[1]-pt_1[1])**2)
      if np.mean(np.min(distances, axis=1)) > R:
          return False
      else:
          return True

And let's relabel the faults in time step 1 to match time step 0:

In [13]:
def relabel(G_1, correlations):

    # A function, which relabels G_1 using the correlations
    for node in G_1:
        G_1.nodes[node]['correlated']=0

    lengths = [metrics.total_length(metrics.get_fault(G_0, correlation[0]), calculate=True) for correlation in correlations]
    lengths, correlations = zip(*sorted(zip(lengths, correlations)))


    for node in G_1:
        for correlation in correlations:
            if G_1.nodes[node]['component'] == correlation[1]:
                G_1.nodes[node]['fault'] = correlation[0]
                G_1.nodes[node]['correlated'] = 1

    max_comp = max(get_fault_labels(G_1))

    G_1_sub = nx.subgraph(G_1, [node for node in G_1 if G_1.nodes[node]['correlated']==0])
    for label, cc in enumerate(sorted(nx.connected_components(G_1_sub))):
        for n in cc:
            G_1.nodes[n]['fault'] = label+max_comp+1

    return G_1

We apply the fault correlation for all timesteps in our dataset:

In [14]:
R = 12 #Modify this value and check best results

for n, time in enumerate(range(0,len(G_dem)-1)): #loop through all timesteps, stop one before the end because it would create a size error
  print('Timesteps ' + str(time+1) + '<->' + str(time+2) + '/16')
  G_0 = G_dem[time]
  G_1 = G_dem[time+1]

  if n == 0: # Only do this for the first correlation
    for node in G_0:
      G_0.nodes[node]['fault'] = G_0.nodes[node]['component']

  for node in G_1:
    G_1.nodes[node]['fault'] = G_1.nodes[node]['component']

  correlations, smf, smb = correlation_slow(G_0, G_1, R=R)
  G_1 = relabel(G_1, correlations)

  G_dem[time+1] = G_1

Timesteps 1<->2/16


   Compute similarities: 100%|██████████| 7/7 [00:00<00:00, 58.86it/s]
   Find correlations: 100%|██████████| 7/7 [00:00<00:00, 10298.19it/s]


Timesteps 2<->3/16


   Compute similarities: 100%|██████████| 5/5 [00:00<00:00, 68.63it/s]
   Find correlations: 100%|██████████| 5/5 [00:00<00:00, 20702.39it/s]


Timesteps 3<->4/16


   Compute similarities: 100%|██████████| 5/5 [00:00<00:00, 67.99it/s]
   Find correlations: 100%|██████████| 5/5 [00:00<00:00, 48210.39it/s]


Timesteps 4<->5/16


   Compute similarities: 100%|██████████| 4/4 [00:00<00:00, 80.71it/s]
   Find correlations: 100%|██████████| 4/4 [00:00<00:00, 32140.26it/s]


Timesteps 5<->6/16


   Compute similarities: 100%|██████████| 4/4 [00:00<00:00, 13.38it/s]
   Find correlations: 100%|██████████| 4/4 [00:00<00:00, 8797.70it/s]


Timesteps 6<->7/16


   Compute similarities: 100%|██████████| 16/16 [00:04<00:00,  3.69it/s]
   Find correlations: 100%|██████████| 16/16 [00:00<00:00, 25682.69it/s]


Timesteps 7<->8/16


   Compute similarities: 100%|██████████| 23/23 [00:06<00:00,  3.44it/s]
   Find correlations: 100%|██████████| 23/23 [00:00<00:00, 31038.93it/s]


Timesteps 8<->9/16


   Compute similarities: 100%|██████████| 29/29 [00:19<00:00,  1.50it/s]
   Find correlations: 100%|██████████| 29/29 [00:00<00:00, 34812.48it/s]


Timesteps 9<->10/16


   Compute similarities: 100%|██████████| 25/25 [00:26<00:00,  1.06s/it]
   Find correlations: 100%|██████████| 25/25 [00:00<00:00, 36033.54it/s]


Timesteps 10<->11/16


   Compute similarities: 100%|██████████| 20/20 [00:33<00:00,  1.69s/it]
   Find correlations: 100%|██████████| 20/20 [00:00<00:00, 33989.50it/s]


Timesteps 11<->12/16


   Compute similarities: 100%|██████████| 26/26 [00:40<00:00,  1.55s/it]
   Find correlations: 100%|██████████| 26/26 [00:00<00:00, 37295.45it/s]


Timesteps 12<->13/16


   Compute similarities: 100%|██████████| 21/21 [00:45<00:00,  2.18s/it]
   Find correlations: 100%|██████████| 21/21 [00:00<00:00, 27430.83it/s]


Timesteps 13<->14/16


   Compute similarities: 100%|██████████| 27/27 [00:56<00:00,  2.09s/it]
   Find correlations: 100%|██████████| 27/27 [00:00<00:00, 47522.54it/s]


Timesteps 14<->15/16


   Compute similarities: 100%|██████████| 22/22 [01:05<00:00,  2.97s/it]
   Find correlations: 100%|██████████| 22/22 [00:00<00:00, 32129.07it/s]


Timesteps 15<->16/16


   Compute similarities: 100%|██████████| 27/27 [01:08<00:00,  2.53s/it]
   Find correlations: 100%|██████████| 27/27 [00:00<00:00, 33229.52it/s]


If we plot all the timesteps again, we should see matching fault IDs over time:

In [15]:
def f(time):
    G = G_dem[time] #grab graph corresponding to each timestep
    mydem = blurdems[time]
    hillshade = es.hillshade(mydem, azimuth=az, altitude=alt) #create hillshade for better visualisation

    fig, axs = plt.subplots(1,1, figsize=(6,10)) #set up plot

    axs.set_title('DEM + re-labeled components')
    axs.imshow(mydem, cmap='gist_earth', vmin=-8, vmax=2, alpha=0.4) #plot raw DEM
    plt.imshow(hillshade, cmap='Greys', alpha=0.2)
    axs.set_ylim([1000,0])
    axs.set_xlim([0,400])
    plots.plot_faults(G, node_size=1, ax=axs, label=True)
    plt.show()

interactive_plot = interactive(f, time=widgets.IntSlider(min=0, max=15, step=1, layout=Layout(width='700px')))
output = interactive_plot.children[-1]
output.layout.width = '800px'
interactive_plot

interactive(children=(IntSlider(value=0, description='time', layout=Layout(width='700px'), max=15), Output(lay…

**NICE!**


.
____________

.

This detailed tutorial of how to batch convert elevation data into a fault network comes to an end here. But we can still do a lot of things with our dataset and Fatbox functions.

.

See [Tutorial B3] to learn how to calculate and plot fault displacement vs length over time

In [None]:
# #store fault network for using in Tutorial B5

# with open('dem_graphs.pkl', 'wb') as f:
#     pickle.dump(G_dem, f) #, protocol=pickle.HIGHEST_PROTOCOL)

# with open('dems.pkl', 'wb') as p:
#     pickle.dump(blurdems, p) #, protocol=pickle.HIGHEST_PROTOCOL)