In [None]:
import pandas as pd
from gudhi import SimplexTree

In [None]:
# Load the scalar-coupling table. Later cells read its columns
# 'molecule_name', 'atom_index_0', 'atom_index_1' and
# 'scalar_coupling_constant', and index rows by position 0..len(df)-1.
# NOTE(review): this is an absolute path on one specific machine; the
# notebook cannot be re-run elsewhere without editing it.
df = pd.read_csv('/Users/anibal/Downloads/champs-scalar-coupling/df.csv') 

In [None]:
# Preview the first rows of the loaded table.
df.head()

In [None]:
# Number of rows in the loaded table.
len(df)

### Getting the interesting molecules

In [None]:
# Build one simplicial complex per molecule and keep the interesting ones.
#
# Every row of `df` contributes the edge between its two atom indices, with
# filtration value minus the scalar coupling constant, so strongly coupled
# pairs enter the filtration first.
#
# Rows are grouped by molecule name, which guarantees that every row of a
# molecule (including its last one) is inserted and that the final molecule
# of the table is evaluated as well. sort=False keeps the molecules in the
# order in which they appear in `df`.
interesting_molecules = {}
for molecule_name, rows in df.groupby('molecule_name', sort=False):
    st = SimplexTree()
    for atom_1, atom_0, coupling in zip(rows['atom_index_1'],
                                        rows['atom_index_0'],
                                        rows['scalar_coupling_constant']):
        st.insert([atom_1, atom_0], -coupling)

    # Fill in the higher simplices spanned by the inserted edges (up to
    # dimension 3) before computing persistence over Z/2.
    st.expansion(3)
    barcode = st.persistence(homology_coeff_field=2)

    # gudhi lists bars as (dimension, (birth, death)); keep the molecule when
    # there are more than two bars and the first listed one has dimension > 1.
    if len(barcode) > 2 and barcode[0][0] > 1:
        interesting_molecules[molecule_name] = st

In [None]:
# Number of molecules currently stored in interesting_molecules.
len(interesting_molecules)

In [None]:
# Sanity check on the first molecule of the table: compare gudhi's barcode
# with the one produced by this notebook's own reduction helpers
# (get_coboundary, get_reduced_triangular, get_barcode). The cell defining
# those helpers must have been run before this one.
#
# This cell deliberately does not touch `interesting_molecules`, so the
# selection computed earlier stays available to the main loop.
st = SimplexTree()
first_molecule = None
for name, atom_1, atom_0, coupling in zip(df['molecule_name'],
                                          df['atom_index_1'],
                                          df['atom_index_0'],
                                          df['scalar_coupling_constant']):
    if first_molecule is None:
        first_molecule = name
    elif name != first_molecule:
        # reached the next molecule: the first one is complete
        break
    st.insert([atom_1, atom_0], -coupling)

st.expansion(3)

# Replace every filtration value by the simplex's position in the filtration
# order, so births and deaths are simplex indices like in get_barcode.
# The filtration is materialised with list() first because
# assign_filtration modifies the tree that get_filtration() iterates over.
for index, (simplex, value) in enumerate(list(st.get_filtration())):
    st.assign_filtration(simplex, index)

# Same coefficient field (Z/2) as the mod-2 reduction below.
gudhi_barcode = st.persistence(homology_coeff_field=2)
print(gudhi_barcode)

filtration = tuple(tuple(simplex) for simplex, value in st.get_filtration())
coboundary = get_coboundary(filtration)
reduced, triangular = get_reduced_triangular(coboundary)

barcode = get_barcode(reduced, filtration)
barcode


In [None]:
# Hand-written filtration used to check the notebook's helpers against gudhi:
# each simplex is inserted with its position in the tuple as filtration value.
# NOTE(review): the tuple has 6 vertices, 15 edges and 10 triangles, so it is
# presumably a triangulation of the real projective plane (hence `rp`) -- confirm.
rp = ((1,), (2,), (1, 2), (3,), (1, 3), (2, 3), (4,), (1, 4), (2, 4), (1, 2, 4), (3, 4), (2, 3, 4), (5,), (1, 5), (2, 5), (3, 5), (1, 3, 5), (2, 3, 5), (4, 5), (1, 4, 5), (6,), (1, 6), (2, 6), (1, 2, 6), (3, 6), (1, 3, 6), (4, 6), (3, 4, 6), (5, 6), (2, 5, 6), (4, 5, 6))
st = SimplexTree()
for idx, spx in enumerate(rp):
    st.insert(spx, idx)
    
# gudhi's barcode over Z/2, restricted to bars whose interval bar[1] is longer
# than 1 or whose first entry (the dimension in gudhi's output) equals 2.
barcode1 = [bar for bar in st.persistence(homology_coeff_field=2) if bar[1][1]-bar[1][0]>1 or bar[0]==2]

# The same complex run through this notebook's own helpers; the cell defining
# get_coboundary, get_reduced_triangular and get_barcode must already have run.
filtration = [tuple(pair[0]) for pair in st.get_filtration()]

coboundary = get_coboundary(filtration)
reduced, triangular = get_reduced_triangular(coboundary)

barcode = get_barcode(reduced, filtration)

# Show both barcodes so they can be compared by eye.
print(barcode1)
barcode


### Main

In [None]:
import matplotlib.pyplot as plt

# For every selected molecule, compute the curve returned by
# steenrod_curve(1, 1, ...) and plot it when its total exceeds 1.
# The helper functions used here must already be defined in the kernel.
for molecule, st in interesting_molecules.items():
    # Keep only the simplices, in filtration order, and drop the filtration
    # values. The values were minus the coupling constant, so strongly
    # coupled pairs appear first.
    # NOTE: steenrod_curve reads this `filtration` as a global (it is not one
    # of its parameters), so the name must not be changed here.
    filtration = tuple([tuple(pair[0]) for pair in st.get_filtration()])
    
    #print(filtration)
    coboundary = get_coboundary(filtration)
    reduced, triangular = get_reduced_triangular(coboundary)

    barcode = get_barcode(reduced, filtration)
    #print(len(barcode) == len(st.persistence()))
    coho_reps = get_coho_reps(barcode, reduced, triangular)
    dimension_masks = get_dimension_masks(barcode)
    #print(dimension_masks)
    
    # k=1, d=1; len(filtration) is passed as number_simplices.
    curve = steenrod_curve(1,1,barcode, coho_reps,
                           len(filtration),reduced, dimension_masks)
    
    # Only plot molecules whose curve sums to more than 1.
    if sum(curve) > 1:
        # The legend shows the molecule name without its first 10 characters
        # (presumably a common prefix -- confirm) followed by the curve total.
        plot = plt.plot(range(len(filtration)), 
                        curve, 
                        label=f'{molecule[10:]} {sum(curve)}')
        plt.legend()
        plt.show()

In [None]:
def steenrod_curve(k,d,barcode,coho_reps,number_simplices,reduced,dimension_masks):
    """Return, for every filtration index, the number of nonzero reduced
    STSQ images of the dimension-d representatives born so far.

    NOTE: a second definition of ``steenrod_curve`` lives in the helper cell
    further down the notebook; whichever cell ran last is the one in effect.
    NOTE: ``STSQ`` is called with the notebook-level variable ``filtration``,
    which is not a parameter; it must describe the same complex as the other
    arguments.

    Parameters
    ----------
    k : int
        Passed to ``STSQ`` as its first argument.
    d : int
        Dimension of the bars to use (key of ``dimension_masks``).
    barcode : list of (birth, death, dimension) triples, ordered by birth
    coho_reps : numpy.ndarray
        One column per bar of ``barcode``.
    number_simplices : int
        Length of the returned curve.
    reduced : numpy.ndarray
        Reduced matrix; its first j columns are used at index j.
    dimension_masks : dict
        Maps a dimension to the list of bar indices of that dimension.

    Returns
    -------
    list of int, of length ``number_simplices``
    """
    # No bar in dimension d: the curve is identically zero.
    if d not in dimension_masks:
        return [0]*number_simplices

    reps_d = coho_reps[:,dimension_masks[d]].T
    st_reps = np.empty((number_simplices, len(reps_d)))
    for col, rep in enumerate(reps_d):
        st_reps[:, col] = STSQ(k,rep,filtration).reshape(-1,)

    # Births of the dimension-d bars, closed off by the end of the filtration.
    births = [triple[0] for triple in barcode if triple[2] == d] + [number_simplices]

    # Nothing is born before the first birth.
    st_curve = [0]*births[0]
    steenrod_matrix = st_reps
    for i, b in enumerate(births[:-1]):
        for j in range(b, births[i+1]):
            # Between the i-th and (i+1)-th birth only the first i+1 columns
            # are alive; they are reduced in place against the first j
            # columns of `reduced`, so the work done at step j carries over.
            steenrod_matrix[:,:i+1] = reduce_matrix(reduced[:,:j],steenrod_matrix[:,:i+1])
            st_curve.append(get_rank(steenrod_matrix[:,:i+1]))

    return st_curve

In [None]:
from itertools import combinations

import numpy as np

def get_coboundary(filtration):
    """Return the mod-2 coboundary matrix with respect to the canonical basis
    defined by the filtration.

    The boundary matrix is built first (entry [i, j] is True when simplex i
    is a codimension-one face of simplex j), then flipped along both axes and
    transposed. Row/column r of the result therefore corresponds to the
    simplex at position ``len(filtration) - 1 - r`` of the filtration.

    Parameters
    ----------
    filtration : sequence of simplices
        Each simplex is a sequence of vertices (tuples or lists), listed in
        filtration order.

    Returns
    -------
    numpy.ndarray of bool, shape (n, n)
    """
    # Simplices are used as dictionary keys, so make them hashable.
    simplices = [tuple(spx) for spx in filtration]
    n = len(simplices)
    position = {spx: idx for idx, spx in enumerate(simplices)}

    # Builtin bool works on every NumPy version, unlike the np.bool alias.
    boundary = np.zeros((n, n), dtype=bool)
    for idx, spx in enumerate(simplices):
        faces = [spx[:j] + spx[j+1:] for j in range(len(spx))]
        # All-or-nothing: a simplex with a face that is not in the filtration
        # gets an empty boundary column. This covers vertices, whose only
        # "face" is the empty tuple.
        if all(face in position for face in faces):
            boundary[[position[face] for face in faces], idx] = True

    coboundary = np.flip(boundary, axis=(0, 1)).transpose()

    return coboundary

def pivot(column):
    '''Locate the pivot of a column, i.e. its last nonzero row.

    Parameters
    ----------
    column : numpy.ndarray
        Only the row indices (first axis) of the nonzero entries are used.

    Returns
    -------
    integer or None
        The largest row index holding a nonzero entry, or None when the
        column has no nonzero entry.
    '''
    try:
        nonzero_rows = column.nonzero()[0]
        # .max() raises ValueError on an empty array, i.e. for a zero column
        return nonzero_rows.max()
    except ValueError:
        return None

def get_reduced_triangular(matrix):
    '''Reduce a mod-2 matrix by left-to-right column additions: R = MV.

    Column j is repeatedly replaced by its sum (XOR) with an earlier column
    that has the same pivot, until it is zero or no earlier column shares its
    pivot. The same operations are recorded in ``triangular``.

    Parameters
    ----------
    matrix : numpy.ndarray
        Boolean (mod-2) matrix M; it is copied, not modified.

    Returns
    -------
    reduced : numpy.ndarray
        The reduced matrix R.
    triangular : numpy.ndarray of bool
        The matrix V of the column operations; it starts as the identity and
        only earlier columns are ever added to later ones.
    '''
    n = matrix.shape[1]
    reduced = np.array(matrix)
    # Builtin bool works on every NumPy version, unlike the np.bool alias.
    triangular = np.eye(n, dtype=bool)
    for j in range(n):
        i = j
        # Scan the earlier columns from j-1 down to 0.
        while i > 0:
            i -= 1
            if not np.any(reduced[:,j]):
                break
            if pivot(reduced[:,i]) == pivot(reduced[:,j]):
                reduced[:,j] = np.logical_xor(reduced[:,i], reduced[:,j])
                triangular[:,j] = np.logical_xor(triangular[:,i], triangular[:,j])
                # The pivot of column j moved: rescan from column j-1.
                i = j

    return reduced, triangular

def get_barcode(reduced, filtration):
    '''Read the barcode off a reduced matrix, ordered by first coordinate.

    Every nonzero column j with pivot i gives the bar (i, j); every index
    that is neither a pivot nor a nonzero column, and whose column is zero,
    gives the infinite bar (i, inf). Indices refer to the reversed
    filtration. Bars with ``death - birth <= 1`` are discarded.

    Parameters
    ----------
    reduced : numpy.ndarray
        Reduced matrix with one column per simplex.
    filtration : sequence of simplices (reversible, with a length)

    Returns
    -------
    list of (birth, death, dimension) triples, sorted
    '''
    n = len(filtration)
    # dimensions[r] is the dimension of the simplex behind row/column r
    dimensions = [len(spx)-1 for spx in reversed(filtration)]
    triples = []
    paired = set()  # set membership keeps the second pass linear
    for j in range(n):
        if np.any(reduced[:,j]):
            i = pivot(reduced[:,j])
            triples.append((i,j,dimensions[i]))
            paired.update((i, j))

    for i in range(n):
        if i not in paired and not np.any(reduced[:,i]):
            triples.append((i,np.inf,dimensions[i]))

    barcode = sorted(bar for bar in triples if bar[1]-bar[0]>1)

    return barcode

def get_coho_reps(barcode, reduced, triangular):
    """Collect one representative column per bar of the barcode.

    A finite bar (birth, death, ...) contributes column ``death`` of
    ``reduced``; an infinite bar contributes column ``birth`` of
    ``triangular``.

    Returns
    -------
    numpy.ndarray
        The representatives stacked as columns, in barcode order.
    """
    representatives = []
    for bar in barcode:
        death = bar[1]
        if death < np.inf:
            representatives.append(reduced[:, death])
        elif death == np.inf:
            representatives.append(triangular[:, bar[0]])
    stacked = np.array(representatives)
    return np.transpose(stacked)

def get_dimension_masks(barcode):
    """Group the positions of the bars by dimension.

    Parameters
    ----------
    barcode : iterable of triples whose third entry is the dimension

    Returns
    -------
    dict
        Maps each dimension to the list of indices (positions in
        ``barcode``) of the bars of that dimension, in increasing order.
    """
    dimension_masks = {}
    for idx, triple in enumerate(barcode):
        # setdefault creates the list on first sight of a dimension, without
        # a catch-all except that would also hide unrelated errors
        dimension_masks.setdefault(triple[2], []).append(idx)

    return dimension_masks

def vector_to_cochain(vector, filtration):
    """Return the set of simplices selected by the nonzero rows of ``vector``.

    Row r stands for the simplex at position r counted from the end of
    ``filtration`` (row 0 is the last simplex).
    """
    nonzero_rows = vector.nonzero()[0]
    return {filtration[-(row + 1)] for row in nonzero_rows}

def cochain_to_vector(cochain, filtration):
    """Return the column vector marking the simplices of ``cochain``.

    The simplex at position p of ``filtration`` is marked in row
    ``len(filtration) - p - 1``, i.e. rows follow the reversed filtration.

    Parameters
    ----------
    cochain : iterable of simplices, each of which must occur in ``filtration``
    filtration : sequence of simplices (must support ``.index``)

    Returns
    -------
    numpy.ndarray of bool, shape (len(filtration), 1)

    Raises
    ------
    ValueError
        If a simplex of ``cochain`` is not in ``filtration``.
    """
    n = len(filtration)
    nonzero_rows = [n - filtration.index(spx) - 1 for spx in cochain]
    # Builtin bool works on every NumPy version, unlike the np.bool alias.
    vector = np.zeros(shape=(n, 1), dtype=bool)
    vector[nonzero_rows] = True
    return vector

def STSQ(k, vector, filtration):
    """Apply the pairwise construction below to a mod-2 cochain vector
    (presumably a Steenrod-square representative -- confirm).

    For every unordered pair of simplices (a, b) in the support of
    ``vector`` whose union has ``len(a) + k`` vertices and is itself a
    simplex of ``filtration``, the union is toggled in the result when the
    parities assigned to the vertices of a - b are all 0 and those of b - a
    are all 1, or the other way round. The parity of a vertex is its position
    in the sorted union plus its position in the sorted symmetric
    difference, mod 2.

    Returns
    -------
    The vector produced by ``cochain_to_vector`` for the toggled simplices.
    """
    # from vector to cochain
    cocycle = vector_to_cochain(vector, filtration)

    # bulk of the algorithm
    toggled = set()
    for first, second in combinations(cocycle, 2):
        a, b = set(first), set(second)
        union = a | b
        if len(union) != len(a) + k:
            continue
        merged = tuple(sorted(union))
        if merged not in filtration:
            continue

        only_a, only_b = a - b, b - a
        differing = sorted(only_a | only_b)
        parity = {v: (merged.index(v) + differing.index(v)) % 2
                  for v in differing}
        parities_a = {parity[v] for v in only_a}
        parities_b = {parity[w] for w in only_b}

        a_even_b_odd = parities_a == {0} and parities_b == {1}
        a_odd_b_even = parities_a == {1} and parities_b == {0}
        if a_even_b_odd or a_odd_b_even:
            # symmetric difference: adding a simplex twice cancels it mod 2
            toggled ^= {merged}

    # cochain to vector
    return cochain_to_vector(toggled, filtration)

def get_pivots(matrix):
    """Return the list of pivots of the columns of ``matrix``, left to right.

    Each entry is whatever ``pivot`` returns for that column.
    """
    return [pivot(matrix[:, col]) for col in range(matrix.shape[1])]

def reduce_vector(reduced, vector):
    """Reduce a column vector against the columns of ``reduced``, mod 2.

    The columns of ``reduced`` are scanned from the last to the first. When
    a column has the same pivot as the vector, it is added (XOR) to the
    vector and the scan restarts from the last column, since the vector's
    new pivot may match a column that was already passed. The scan stops
    when the vector is zero or no column shares its pivot.

    Parameters
    ----------
    reduced : numpy.ndarray, shape (n, m)
    vector : numpy.ndarray, shape (n, 1)
        Modified in place.

    Returns
    -------
    numpy.ndarray
        The same ``vector`` object, after reduction.
    """
    num_col = reduced.shape[1]
    i = -1
    while i >= -num_col and np.any(vector):
        if pivot(reduced[:,i]) == pivot(vector):
            vector[:,0] = np.logical_xor(reduced[:,i], vector[:,0])
            # Restart at the last column; it must not be skipped.
            i = -1
        else:
            i -= 1
    return vector

def reduce_matrix(reduced, matrix):
    """Reduce the columns of ``matrix`` one after the other.

    Each column is reduced with ``reduce_vector`` against ``reduced``
    extended by the already reduced columns of ``matrix``.

    Parameters
    ----------
    reduced : numpy.ndarray, shape (n, m)
        Not modified; extended copies are used internally.
    matrix : numpy.ndarray, shape (n, c)
        Its columns are handed to ``reduce_vector`` as (n, 1) slices, so they
        are modified in place if ``reduce_vector`` writes into its argument.

    Returns
    -------
    numpy.ndarray, shape (n, c)
        The reduced columns, in the original order.
    """
    num_vector = matrix.shape[1]

    for i in range(num_vector):
        reduced_vector = reduce_vector(reduced, matrix[:, i:i+1])
        reduced = np.concatenate([reduced, reduced_vector], axis=1)
    # Explicit start index: a negative slice `-num_vector:` would return
    # every column when num_vector is 0.
    return reduced[:, reduced.shape[1] - num_vector:]

def get_rank(matrix):
    """Return the number of columns of ``matrix`` with a nonzero entry.

    This equals the rank only when the nonzero columns are linearly
    independent, e.g. after the columns have been reduced against each other.

    Parameters
    ----------
    matrix : numpy.ndarray, shape (n, c)

    Returns
    -------
    int
        0 for a matrix without columns.
    """
    # Counting on the boolean column mask avoids accumulating small integer
    # scalars (which can overflow) and also handles zero-column input.
    rank = int(np.any(matrix, axis=0).sum())
    return rank

def steenrod_curve(k,d,barcode,coho_reps,number_simplices,reduced,dimension_masks):
    """Return, for every filtration index, the number of nonzero reduced
    STSQ images of the dimension-d representatives born so far.

    NOTE: ``steenrod_curve`` is also defined in an earlier cell of this
    notebook; whichever cell ran last is the one in effect.
    NOTE: ``STSQ`` is called with the notebook-level variable ``filtration``,
    which is not a parameter; it must describe the same complex as the other
    arguments.

    Parameters
    ----------
    k : int
        Passed to ``STSQ`` as its first argument.
    d : int
        Dimension of the bars to use (key of ``dimension_masks``).
    barcode : list of (birth, death, dimension) triples, ordered by birth
    coho_reps : numpy.ndarray
        One column per bar of ``barcode``.
    number_simplices : int
        Length of the returned curve.
    reduced : numpy.ndarray
        Reduced matrix; its first j columns are used at index j.
    dimension_masks : dict
        Maps a dimension to the list of bar indices of that dimension.

    Returns
    -------
    list of int, of length ``number_simplices``
    """
    # No bar in dimension d: the curve is identically zero.
    if d not in dimension_masks:
        return [0]*number_simplices

    reps_d = coho_reps[:,dimension_masks[d]].T
    st_reps = np.empty((number_simplices, len(reps_d)))
    for col, rep in enumerate(reps_d):
        st_reps[:, col] = STSQ(k,rep,filtration).reshape(-1,)

    # Births of the dimension-d bars, closed off by the end of the filtration.
    births = [triple[0] for triple in barcode if triple[2] == d] + [number_simplices]

    # Nothing is born before the first birth.
    st_curve = [0]*births[0]
    steenrod_matrix = st_reps
    for i, b in enumerate(births[:-1]):
        for j in range(b, births[i+1]):
            # Between the i-th and (i+1)-th birth only the first i+1 columns
            # are alive; they are reduced in place against the first j
            # columns of `reduced`, so the work done at step j carries over.
            steenrod_matrix[:,:i+1] = reduce_matrix(reduced[:,:j],steenrod_matrix[:,:i+1])
            st_curve.append(get_rank(steenrod_matrix[:,:i+1]))

    return st_curve