In [1]:
# NOTE: the star import stays first so that the explicit imports below take
# precedence over any same-named objects `_utils` happens to export.
from _utils import *

import os
import pickle as pkl

import matplotlib.pyplot as plt
import pandas as pd
from gudhi import SimplexTree
from joblib import Parallel, delayed

In [2]:
# Path to train.csv of the champs-scalar-coupling data. Set the
# CHAMPS_TRAIN_CSV environment variable to point at a local copy; the fallback
# is the path the notebook was originally run with.
TRAIN_CSV_PATH = os.environ.get(
    'CHAMPS_TRAIN_CSV',
    '/Users/anibal/Downloads/champs-scalar-coupling/train.csv',
)
df = pd.read_csv(TRAIN_CSV_PATH)

In [3]:
# Preview the first five rows to confirm which columns were loaded.
df.head(n=5)

Unnamed: 0,id,molecule_name,atom_index_0,atom_index_1,type,scalar_coupling_constant
0,0,dsgdb9nsd_000001,1,0,1JHC,84.8076
1,1,dsgdb9nsd_000001,1,2,2JHH,-11.257
2,2,dsgdb9nsd_000001,1,3,2JHH,-11.2548
3,3,dsgdb9nsd_000001,1,4,2JHH,-11.2543
4,4,dsgdb9nsd_000001,2,0,1JHC,84.8074


### Getting and filtering the interesting molecules

In [None]:
def get_interesting_molecules(df):
    """Select molecules by their persistence barcode and pickle their filtrations.

    For every molecule in ``df`` a ``SimplexTree`` is built with one edge per
    row, ``[atom_index_1, atom_index_0]``, filtered by the row's
    ``scalar_coupling_constant``. The tree is expanded up to dimension 3 and
    its persistence is computed with coefficients in Z/2. A molecule is kept
    when the barcode has more than three pairs and the first reported pair has
    dimension greater than 1.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``molecule_name``, ``atom_index_0``,
        ``atom_index_1`` and ``scalar_coupling_constant``.

    Returns
    -------
    None
        The result, a dict mapping molecule name to a tuple of simplices
        (each a tuple of vertices) in filtration order, is written to
        ``interesting_filtered_molecules.pkl`` in the working directory.

    Notes
    -----
    NOTE(review): the coupling constant is used as the filtration value with
    its original sign here, while ``filter_interesting_molegules`` negates it.
    Confirm which convention is intended.
    """
    interesting_molecules = {}

    # Grouping by name guarantees that every row of every molecule is used,
    # including the last row of each molecule and the final molecule of the
    # frame. sort=False keeps molecules in order of first appearance.
    for name, molecule_df in df.groupby('molecule_name', sort=False):
        st = SimplexTree()
        edges = zip(molecule_df['atom_index_1'],
                    molecule_df['atom_index_0'],
                    molecule_df['scalar_coupling_constant'])
        for atom_1, atom_0, coupling in edges:
            st.insert([atom_1, atom_0], coupling)

        st.expansion(3)
        barcode = st.persistence(homology_coeff_field=2)

        # Each barcode entry is (dimension, (birth, death)).
        if len(barcode) > 3 and barcode[0][0] > 1:
            interesting_molecules[name] = tuple(
                tuple(pair[0]) for pair in st.get_filtration()
            )

    # Context manager so the file is flushed and closed deterministically.
    with open("interesting_filtered_molecules.pkl", "wb") as f:
        pkl.dump(interesting_molecules, f)

In [None]:
def filter_interesting_molegules():
    """Build and pickle the filtration of every molecule listed in a text file.

    Molecule names are read from ``interesting_molecules.txt`` (one name per
    line; surrounding whitespace is stripped and blank lines are ignored).
    For each name, the matching rows of the module-level ``df`` are turned
    into a ``SimplexTree`` with one edge ``[atom_index_1, atom_index_0]`` per
    row, filtered by the *negated* ``scalar_coupling_constant``, and the tree
    is expanded up to dimension 3.

    Returns
    -------
    None
        A dict mapping molecule name to a tuple of simplices (each a tuple of
        vertices) in filtration order is written to
        ``interesting_filtered_molecules.pkl`` in the working directory. A
        name with no rows in ``df`` maps to an empty tuple.
    """
    with open('interesting_molecules.txt') as f:
        # strip() rather than slicing off the last character: the final line
        # of the file may not end with a newline.
        interesting_names = [line.strip() for line in f]
    interesting_names = [name for name in interesting_names if name]

    interesting_filtered_molecules = {}
    for name in interesting_names:
        molecule_df = df[df['molecule_name'] == name]
        st = SimplexTree()
        # Iterate the columns directly; this keeps the integer dtype of the
        # atom indices instead of relying on per-row Series conversion.
        edges = zip(molecule_df['atom_index_1'],
                    molecule_df['atom_index_0'],
                    molecule_df['scalar_coupling_constant'])
        for atom_1, atom_0, coupling in edges:
            st.insert([atom_1, atom_0], -coupling)

        st.expansion(3)
        interesting_filtered_molecules[name] = tuple(
            tuple(pair[0]) for pair in st.get_filtration()
        )

    # Context manager so the file is flushed and closed deterministically.
    with open("interesting_filtered_molecules.pkl", "wb") as f:
        pkl.dump(interesting_filtered_molecules, f)

### Main

In [None]:
# Scan the loaded table; the function returns nothing and writes its result
# to interesting_filtered_molecules.pkl in the working directory.
get_interesting_molecules(df)

In [None]:
def save_steenrod_curves():
    """Compute one Steenrod curve per stored filtration and pickle them all.

    Reads the ``{name: filtration}`` dict from
    ``interesting_filtered_molecules.pkl``, runs each filtration through the
    coboundary -> reduction -> barcode -> cohomology representatives ->
    Steenrod representatives -> curve pipeline, and writes the resulting
    ``{name: tuple(curve)}`` dict to ``steenrod_curves.pkl``.

    The pipeline helpers (``get_coboundary``, ``get_reduced_triangular``,
    ``get_barcode``, ``get_coho_reps``, ``get_steenrod_reps``,
    ``get_steenrod_curve``) are not defined in this notebook; presumably they
    come from the ``_utils`` star import -- verify there for their contracts.

    Returns
    -------
    None
    """
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # load pickles produced by this notebook.
    with open("interesting_filtered_molecules.pkl", "rb") as f:
        interesting_filtered_molecules = pkl.load(f)

    curves = {}
    for name, filtration in interesting_filtered_molecules.items():

        coboundary = get_coboundary(filtration)
        reduced, triangular = get_reduced_triangular(coboundary)

        barcode = get_barcode(reduced, filtration)
        coho_reps = get_coho_reps(barcode, reduced, triangular, filtration)
        # First argument is presumably the index k of the Steenrod square
        # Sq^k -- TODO confirm against _utils.
        steenrod_reps = get_steenrod_reps(1, coho_reps, filtration)

        curves[name] = tuple(get_steenrod_curve(barcode, steenrod_reps, filtration, reduced))

    # Context manager so the file is flushed and closed deterministically.
    with open("steenrod_curves.pkl", "wb") as f:
        pkl.dump(curves, f)

In [None]:
save_steenrod_curves()

# Reload the curves that were just written; the context manager closes the file.
with open("steenrod_curves.pkl", "rb") as f:
    steenrod_curves = pkl.load(f)
# Keep only the molecules whose curve sums to a positive value.
super_interesting_filtered_molecules = {
    name: curve
    for name, curve in steenrod_curves.items()
    if sum(curve) > 0
}

len(super_interesting_filtered_molecules)