In [1]:
import os
import numpy as np

# import matplotlib
# import matplotlib.pyplot as plt
# import matplotlib.tri 
# import mpl_toolkits.mplot3d
# Commented out: my conda environment does not recognise this import (TODO: investigate why).


from rdkit import Chem
from rdkit.Chem import Draw
import rdkit.Chem.AllChem as Chem
import rdkit.Chem.AllChem as AllChem

from gtda.homology import VietorisRipsPersistence

import structures as st # My own module

In [2]:
def persistence_diagrams(coords, homology_dimensions=None, n_jobs=1, collapse_edges=True):
    """Compute the Vietoris-Rips persistence diagram of a single point cloud.

    Parameters
    ----------
    coords : array-like, at least 2-D
        The point cloud. A leading sample axis is added (``coords[None, :, :]``)
        because ``VietorisRipsPersistence`` works on a batch of point clouds.
        Presumably shaped ``(n_points, n_dimensions)`` -- TODO confirm what
        ``structures.Structure`` yields.
    homology_dimensions : iterable of int or None, optional, default: ``None``
        Homology dimensions to track. ``None`` means ``(0, 1, 2)``: connected
        components, loops and voids.
    n_jobs : int or None, optional, default: ``1``
        Forwarded to ``VietorisRipsPersistence``.
    collapse_edges : bool, optional, default: ``True``
        Forwarded to ``VietorisRipsPersistence``; collapsing edges speeds up
        the H2 persistence calculation.

    Returns
    -------
    coords : the input, returned unchanged
    diagrams_basic : ndarray
        Output of ``fit_transform`` for the one-sample batch, i.e. shape
        ``(1, n_features, 3)`` with rows ``(birth, death, homology dimension)``.
    """
    if homology_dimensions is None:
        homology_dimensions = (0, 1, 2)

    persistence = VietorisRipsPersistence(
        metric="euclidean",
        homology_dimensions=list(homology_dimensions),
        n_jobs=n_jobs,
        collapse_edges=collapse_edges,
    )

    # Wrap the single point cloud into a batch of one sample.
    reshaped_coords = coords[None, :, :]
    diagrams_basic = persistence.fit_transform(reshaped_coords)
    return coords, diagrams_basic

"""Persistent-homology–related plotting functions and classes."""
# License: GNU AGPLv3

import numpy as np
import plotly.graph_objs as gobj


def _finite_extent(diagram_no_dims):
    """Return ``(min_val, max_val)`` of the birth/death values, ignoring infinities.

    ``+inf`` entries are left out of the maximum and ``-inf`` entries out of the
    minimum. An empty array yields the dummy extent ``(0, 1)`` so that an empty
    diagram still gets a usable plot range.
    """
    if not diagram_no_dims.size:
        return 0, 1
    max_val = np.max(np.where(np.isposinf(diagram_no_dims), -np.inf, diagram_no_dims))
    min_val = np.min(np.where(np.isneginf(diagram_no_dims), np.inf, diagram_no_dims))
    return min_val, max_val


def _hovertext(subdiagram):
    """Return one hover label per row of ``subdiagram``, in row order.

    Each label is the row's ``(birth, death)`` pair, followed by
    ``", multiplicity: k"`` when the identical row occurs ``k > 1`` times.
    """
    if not len(subdiagram):
        return []
    unique, inverse, counts = np.unique(
        subdiagram, axis=0, return_inverse=True, return_counts=True
    )
    # Format each scalar with str() rather than repr(): since NumPy 2 the repr
    # of a NumPy scalar is "np.float64(0.5)", which leaked into the labels when
    # formatting a tuple of NumPy scalars.
    unique_labels = [
        "(" + ", ".join(map(str, row[:2])) + ")"
        + (f", multiplicity: {count}" if count > 1 else "")
        for row, count in zip(unique, counts)
    ]
    # Some NumPy 2.0 releases return a non-1-D ``inverse`` when ``axis`` is
    # given; flattening yields plain row indices on every version.
    return [unique_labels[index] for index in np.ravel(inverse)]


def _axis_layout(title, side, value_range, **extra):
    """Return the plotly axis settings shared by the Birth and Death axes.

    The axis is linear with the fixed range ``value_range`` (autorange off),
    outside ticks and a thin black axis line; ``extra`` adds or overrides keys.
    """
    axis = {
        "title": title,
        "side": side,
        "type": "linear",
        "range": list(value_range),
        "autorange": False,
        "ticks": "outside",
        "showline": True,
        "zeroline": True,
        "linewidth": 1,
        "linecolor": "black",
        "mirror": False,
        "showexponent": "all",
        "exponentformat": "e",
    }
    axis.update(extra)
    return axis


def plot_diagram(diagram, homology_dimensions=None, plotly_params=None):
    """Plot a single persistence diagram.

    Parameters
    ----------
    diagram : ndarray of shape (n_points, 3)
        The persistence diagram to plot, where the third dimension along axis 1
        contains homology dimensions, and the first two contain (birth, death)
        pairs to be used as coordinates in the two-dimensional plot.

    homology_dimensions : list of int or None, optional, default: ``None``
        Homology dimensions which will appear on the plot. If ``None``, all
        homology dimensions which appear in `diagram` will be plotted.

    plotly_params : dict or None, optional, default: ``None``
        Custom parameters to configure the plotly figure. Allowed keys are
        ``"traces"`` and ``"layout"``, and the corresponding values should be
        dictionaries containing keyword arguments as would be fed to the
        :meth:`update_traces` and :meth:`update_layout` methods of
        :class:`plotly.graph_objects.Figure`.

    Returns
    -------
    fig : :class:`plotly.graph_objects.Figure` object
        Figure representing the persistence diagram.

    Notes
    -----
    Points with ``birth == death`` are not drawn. Points dying at ``+inf`` are
    drawn on a dashed horizontal line, labelled with the infinity symbol,
    placed slightly above the largest finite value.

    """
    # TODO: increase the marker size (can also be passed through
    # ``plotly_params["traces"]``, e.g. ``{"marker_size": 10}``).
    if homology_dimensions is None:
        homology_dimensions = np.unique(diagram[:, 2])

    # Drop points on the diagonal (birth == death): they have zero persistence.
    diagram = diagram[diagram[:, 0] != diagram[:, 1]]
    # (birth, death) columns only, without the homology dimension.
    diagram_no_dims = diagram[:, :2]

    min_val, max_val = _finite_extent(diagram_no_dims)
    parameter_range = max_val - min_val
    if parameter_range == 0:
        # All finite values coincide (e.g. only points of the form (b, inf)):
        # use a unit range so the axes and the infinity line do not collapse
        # onto a single value.
        parameter_range = 1
    extra_space_factor = 0.02

    # Points dying at +inf are drawn on a dedicated horizontal line a little
    # above the largest finite value; leave extra room for it.
    has_posinfinite_death = np.any(np.isposinf(diagram_no_dims[:, 1]))
    if has_posinfinite_death:
        posinfinity_val = max_val + 0.1 * parameter_range
        extra_space_factor += 0.1

    # Pad the plotted range on both sides of the data.
    extra_space = extra_space_factor * parameter_range
    min_val_display = min_val - extra_space
    max_val_display = max_val + extra_space

    # Dashed diagonal: the birth == death reference line.
    fig = gobj.Figure()
    fig.add_trace(gobj.Scatter(
        x=[min_val_display, max_val_display],
        y=[min_val_display, max_val_display],
        mode="lines",
        line={"dash": "dash", "width": 1, "color": "black"},
        showlegend=False,
        hoverinfo="none"
    ))

    for dim in homology_dimensions:
        name = f"H{int(dim)}" if dim != np.inf else "Any homology dimension"
        subdiagram = diagram[diagram[:, 2] == dim]
        # Labels are built from the original (possibly infinite) deaths,
        # before y is remapped below.
        hovertext = _hovertext(subdiagram)
        # subdiagram is a copy (boolean indexing), so this does not touch the
        # caller's array.
        y = subdiagram[:, 1]
        if has_posinfinite_death:
            y[np.isposinf(y)] = posinfinity_val
        fig.add_trace(gobj.Scatter(
            # NOTE(review): the "text" mode has no visible effect because no
            # ``text`` is passed -- add labels or switch to "markers".
            x=subdiagram[:, 0], y=y, mode="markers + text", textposition="top center",
            hoverinfo="text", hovertext=hovertext, name=name
        ))

    fig.update_layout(
        width=500,
        height=500,
        xaxis1=_axis_layout("Birth", "bottom", [min_val_display, max_val_display]),
        yaxis1=_axis_layout(
            "Death", "left", [min_val_display, max_val_display],
            scaleanchor="x", scaleratio=1,
        ),
        plot_bgcolor="white"
    )

    # Add a horizontal dashed line for points with infinite death
    if has_posinfinite_death:
        fig.add_trace(gobj.Scatter(
            x=[min_val_display, max_val_display],
            y=[posinfinity_val, posinfinity_val],
            mode="lines",
            line={"dash": "dash", "width": 0.5, "color": "black"},
            showlegend=True,
            name=u"\u221E",
            hoverinfo="none"
        ))

    # Update traces and layout according to user input
    if plotly_params:
        fig.update_traces(plotly_params.get("traces", None))
        fig.update_layout(plotly_params.get("layout", None))

    return fig

In [3]:
# Simple cubic structure: build it, show it, then plot its persistence diagram.
# NOTE(review): Structure's positional arguments look like (repeats x3,
# lattice lengths x3, angles in degrees x3, three flags) -- confirm in structures.py.
simple_cubic = st.Structure(
    2, 2, 2,
    1.0, 1.0, 1.0,
    90, 90, 90,
    False, False, False,
)
st.plotting_struct(simple_cubic)

coords, diagrams_basic = persistence_diagrams(coords=simple_cubic)

# Index 0 selects the first (and only) sample of the batch.
figure = plot_diagram(diagram=diagrams_basic[0])
figure.show()

In [4]:
from gtda.diagrams import BettiCurve

# Turn the diagrams currently held in `diagrams_basic` (the simple-cubic ones
# when the cells run in order) into Betti curves.
BC = BettiCurve()

# One curve per homology dimension present in the diagrams (three rows here).
X_betti_curves = BC.fit_transform(diagrams_basic)

print(X_betti_curves)

# Last expression of the cell, so the returned figure is displayed.
BC.plot(X_betti_curves)

[[[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7
   7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7
   7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 0]
  [5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
   5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
   5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 0]
  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
   0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
   0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]


In [5]:
# Body-cubic variant: same Structure arguments as the simple cubic except that
# the last flag is switched on.
# NOTE(review): Structure's positional arguments look like (repeats x3,
# lattice lengths x3, angles in degrees x3, three flags) -- confirm in structures.py.
body_cubic = st.Structure(
    2, 2, 2,
    1.0, 1.0, 1.0,
    90, 90, 90,
    False, False, True,
)
st.plotting_struct(body_cubic)
coords, diagrams_basic = persistence_diagrams(coords=body_cubic)

# Index 0 selects the first (and only) sample of the batch.
figure = plot_diagram(diagram=diagrams_basic[0])
figure.show()

In [6]:
# Face-cubic variant: same Structure arguments as the simple cubic except that
# the first two flags are switched on.
# NOTE(review): Structure's positional arguments look like (repeats x3,
# lattice lengths x3, angles in degrees x3, three flags) -- confirm in structures.py.
face_cubic = st.Structure(
    2, 2, 2,
    1.0, 1.0, 1.0,
    90, 90, 90,
    True, True, False,
)
st.plotting_struct(face_cubic)
coords, diagrams_basic = persistence_diagrams(coords=face_cubic)

# Index 0 selects the first (and only) sample of the batch.
figure = plot_diagram(diagram=diagrams_basic[0])
figure.show()

In [7]:
# Hexagonal variant: same Structure arguments as the simple cubic except that
# the ninth argument is 120 instead of 90.
# NOTE(review): Structure's positional arguments look like (repeats x3,
# lattice lengths x3, angles in degrees x3, three flags) -- confirm in structures.py.
hexagonal = st.Structure(
    2, 2, 2,
    1.0, 1.0, 1.0,
    90, 90, 120,
    False, False, False,
)
st.plotting_struct(hexagonal)
coords, diagrams_basic = persistence_diagrams(coords=hexagonal)

# Index 0 selects the first (and only) sample of the batch.
figure = plot_diagram(diagram=diagrams_basic[0])
figure.show()

In [8]:
# Raw diagrams of the most recently computed structure (the hexagonal one when
# the cells run in order); each row is (birth, death, homology dimension).
diagrams_basic

array([[[0.        , 1.        , 0.        ],
        [0.        , 1.        , 0.        ],
        [0.        , 1.        , 0.        ],
        [0.        , 1.        , 0.        ],
        [0.        , 1.        , 0.        ],
        [0.        , 1.        , 0.        ],
        [0.        , 1.11803401, 0.        ],
        [1.11803401, 1.5       , 1.        ],
        [1.        , 1.41421354, 1.        ],
        [1.        , 1.41421354, 1.        ],
        [0.        , 0.        , 2.        ]]])