In [None]:
# NOTE: the star import stays first so the explicit imports below win on any name clash.
from _utils import *

import os
import pickle as pkl

import matplotlib.pyplot as plt
import pandas as pd
from gudhi import SimplexTree
from joblib import Parallel, delayed

In [None]:
# Location of the training table. Set the CHAMPS_TRAIN_CSV environment variable to
# point at your own copy; the fallback is the path this notebook was authored with.
TRAIN_CSV_PATH = os.environ.get(
    'CHAMPS_TRAIN_CSV',
    '/Users/anibal/Downloads/champs-scalar-coupling/train.csv',
)
df = pd.read_csv(TRAIN_CSV_PATH)

In [None]:
# Preview the first rows of the loaded table as a quick sanity check.
df.head()

### Getting and filtering the interesting molecules

In [None]:
def get_interesting_molecules(df):
    """Find molecules with more than two persistence bars and pickle their filtrations.

    Rows of ``df`` are consumed in order and split into consecutive runs that
    share the same ``molecule_name``. For every run a gudhi ``SimplexTree`` is
    built by inserting the edge ``[atom_index_1, atom_index_0]`` with
    filtration value ``-scalar_coupling_constant``; the tree is then expanded
    up to dimension 3, its filtration is made non-decreasing, and persistence
    is computed with coefficient field 2. A molecule is kept when the
    resulting barcode has more than two bars.

    Every row of a run is inserted (including its last one), and the final
    run of the table is processed as well.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``molecule_name``, ``atom_index_0``,
        ``atom_index_1`` and ``scalar_coupling_constant``.

    Returns
    -------
    None
        The result, a dict mapping molecule name to a tuple of simplices
        (each a tuple of vertices) in filtration order, is written to
        ``interesting_filtered_molecules.pkl`` in the working directory.
    """
    interesting_molecules = {}

    def _record_if_interesting(name, st):
        # Finish one molecule's complex and keep it if its barcode is large enough.
        st.expansion(3)
        st.make_filtration_non_decreasing()
        barcode = st.persistence(homology_coeff_field=2)
        if len(barcode) > 2:
            interesting_molecules[name] = tuple(
                tuple(simplex) for simplex, _ in st.get_filtration()
            )

    # Iterate positionally over the needed columns so the result does not
    # depend on the frame having a default 0..n-1 index.
    rows = zip(df['molecule_name'],
               df['atom_index_1'],
               df['atom_index_0'],
               df['scalar_coupling_constant'])

    current_name = None
    st = None
    for name, atom_1, atom_0, coupling in rows:
        if st is None or name != current_name:
            # A new run starts: close the previous molecule before switching.
            if st is not None:
                _record_if_interesting(current_name, st)
            current_name = name
            st = SimplexTree()
        st.insert([atom_1, atom_0], -coupling)

    # The last molecule has no following row to trigger its finalisation.
    if st is not None:
        _record_if_interesting(current_name, st)

    with open("interesting_filtered_molecules.pkl", "wb") as fh:
        pkl.dump(interesting_molecules, fh)

### Main

In [None]:
# Scan the whole table; the result is written to interesting_filtered_molecules.pkl.
get_interesting_molecules(df)

In [None]:
def save_steenrod_curves():
    """Compute a Steenrod curve per stored molecule and pickle the results.

    Reads ``interesting_filtered_molecules.pkl`` (a dict mapping molecule name
    to a filtration) from the working directory. For each molecule it runs the
    coboundary / reduction / barcode / representative pipeline, then computes
    Steenrod representatives and curves for ``k = 1 .. max_k`` (``max_k`` is
    fixed to 2). The helper functions are not defined in this notebook and are
    presumably provided by the ``_utils`` star import.

    Returns
    -------
    None
        A dict mapping molecule name to a tuple (the curve computed for the
        largest ``k``) is written to ``steenrod_curves.pkl``.

    Notes
    -----
    The entries of ``checks`` are computed but never inspected or stored, so a
    failing consistency check goes unnoticed.
    NOTE(review): confirm whether they should be asserted or reported.
    """
    # pickle can execute arbitrary code on load; this file is expected to be
    # the one produced by get_interesting_molecules in this notebook.
    with open("interesting_filtered_molecules.pkl", "rb") as fh:
        interesting_filtered_molecules = pkl.load(fh)

    max_k = 2
    curves = {}
    for name, filtration in interesting_filtered_molecules.items():
        checks = {}
        steenrod_curves = {}

        coboundary = get_coboundary(filtration)
        checks['coboundary square is zero'] = \
            check_square_zero(coboundary)

        reduced, triangular = get_reduced_triangular(coboundary)
        checks['reduced = coboundary*triangular'] = \
            check_factorization(coboundary, reduced, triangular)

        barcode = get_barcode(reduced, filtration)
        checks['H_* barcode = shifted H^* barcode'] = \
            check_duality(filtration)
        checks['homology barcode = gudhi barcode'] = \
            check_against_gudhi(filtration)

        betti_curves = get_betti_curves(barcode, filtration)

        coho_reps = get_coho_reps(barcode, reduced,
                                  triangular, filtration)
        checks['coho_reps are cycles'] = \
            check_representatives(coboundary, coho_reps)

        st_reps = {}
        for k in range(1, max_k + 1):
            st_reps[k] = get_steenrod_reps(k, coho_reps, filtration)
            # Validate the Steenrod representatives themselves; checking
            # coho_reps again here would only repeat the check made above.
            checks[f'{k}-st_reps are cycles'] = \
                check_representatives(coboundary, st_reps[k])

            steenrod_curves[k] = get_steenrod_curve(barcode, st_reps[k],
                                                    filtration, reduced)
            checks[f'{k}-st_curve <= sum of bettis'] = \
                check_steenrod_lq_betti(steenrod_curves[k], betti_curves, filtration)

            # NOTE(review): this is overwritten on every k, so only the curve
            # for k == max_k ends up saved -- confirm this is intended.
            curves[name] = tuple(steenrod_curves[k])

    with open("steenrod_curves.pkl", "wb") as fh:
        pkl.dump(curves, fh)

In [None]:
# Reads interesting_filtered_molecules.pkl and writes steenrod_curves.pkl.
save_steenrod_curves()

In [None]:
# Reload the saved curves; the context manager closes the file handle promptly.
# pickle can execute arbitrary code on load, so only read files this notebook wrote.
with open("steenrod_curves.pkl", "rb") as fh:
    curves = pkl.load(fh)

In [None]:
# Keep only the molecules whose saved curve has a strictly positive sum.
super_interesting_filtered_molecules = dict(
    filter(lambda entry: sum(entry[1]) > 0, curves.items())
)

In [None]:
# Persist the filtered selection; the context manager flushes and closes the file.
with open("super_interesting_filtered_molecules.pkl", "wb") as fh:
    pkl.dump(super_interesting_filtered_molecules, fh)

In [None]:
# Draw one figure per selected molecule: the curve against its index, with the
# curve's sum shown in the legend and the molecule name as the figure title.
for name, curve in super_interesting_filtered_molecules.items():
    fig, ax = plt.subplots()
    ax.plot(range(len(curve)), curve, label=f'length={sum(curve)}')
    ax.legend()
    fig.suptitle(name)
    plt.show()