In [1]:
import os
from _utils import *
import pandas as pd
from gudhi import SimplexTree
import pickle as pkl
import matplotlib.pyplot as plt
from joblib import Parallel, delayed

In [2]:
# Directory holding the Kaggle "champs-scalar-coupling" files. Set the
# CHAMPS_DATA_DIR environment variable instead of editing this absolute path.
DATA_DIR = os.environ.get("CHAMPS_DATA_DIR",
                          "/Users/anibal/Downloads/champs-scalar-coupling")
df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))

In [3]:
# Preview the first five rows of the training table.
df.head(n=5)

Unnamed: 0,id,molecule_name,atom_index_0,atom_index_1,type,scalar_coupling_constant
0,0,dsgdb9nsd_000001,1,0,1JHC,84.8076
1,1,dsgdb9nsd_000001,1,2,2JHH,-11.257
2,2,dsgdb9nsd_000001,1,3,2JHH,-11.2548
3,3,dsgdb9nsd_000001,1,4,2JHH,-11.2543
4,4,dsgdb9nsd_000001,2,0,1JHC,84.8074


### Getting and filtering the interesting molecules

In [4]:
def get_interesting_molecules(df, path="interesting_filtered_molecules.pkl"):
    """Pickle the filtrations of molecules whose barcode has more than two bars.

    Each contiguous run of rows sharing a ``molecule_name`` is treated as one
    molecule. Every row of that run contributes the edge
    ``[atom_index_1, atom_index_0]`` to a gudhi ``SimplexTree`` with filtration
    value ``-scalar_coupling_constant``. The complex is then expanded up to
    dimension 3, its filtration is made non-decreasing, and persistence is
    computed with Z/2 coefficients. Molecules whose barcode contains more than
    two intervals are kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``molecule_name``, ``atom_index_0``,
        ``atom_index_1`` and ``scalar_coupling_constant``. Rows belonging to
        the same molecule are expected to be adjacent.
    path : str
        Pickle file the result is written to.

    Writes
    ------
    A dict ``molecule_name -> tuple of simplices`` (each simplex a tuple of
    vertex ids, in gudhi filtration order) pickled to ``path``.
    """
    names = df['molecule_name']
    # Label each contiguous run of identical names so that every run (and in
    # particular every row of it, including the last) forms one molecule.
    run_ids = (names != names.shift()).cumsum()

    interesting_molecules = {}
    for _, rows in df.groupby(run_ids, sort=False):
        st = SimplexTree()
        for atom_0, atom_1, coupling in zip(rows['atom_index_0'],
                                            rows['atom_index_1'],
                                            rows['scalar_coupling_constant']):
            st.insert([atom_1, atom_0], -coupling)

        st.expansion(3)
        st.make_filtration_non_decreasing()
        barcode = st.persistence(homology_coeff_field=2)
        if len(barcode) > 2:
            interesting_molecules[rows['molecule_name'].iloc[0]] = \
                tuple(tuple(simplex) for simplex, _ in st.get_filtration())

    with open(path, "wb") as out_file:
        pkl.dump(interesting_molecules, out_file)

In [5]:
#get_interesting_molecules(df)

### Main

In [6]:
# Reload the molecules written by get_interesting_molecules. This is the same
# file that save_steenrod_curves reads, so the count shown here matches what
# gets processed.
with open("interesting_filtered_molecules.pkl", "rb") as molecules_file:
    interesting_filtered_molecules = pkl.load(molecules_file)

len(interesting_filtered_molecules)

83358

In [7]:
def _steenrod_curves_for_filtration(filtration, max_k=2):
    """Compute Steenrod curves for k = 1..max_k plus sanity checks for one filtration.

    Returns ``(steenrod_curves, checks)``. ``steenrod_curves`` maps k to the
    value returned by ``get_steenrod_curve``. ``checks`` maps a description to
    whatever the corresponding ``check_*`` helper from ``_utils`` returned.
    """
    checks = {}
    steenrod_curves = {}

    coboundary = get_coboundary(filtration)
    checks['coboundary square is zero'] = \
        check_square_zero(coboundary)

    reduced, triangular = get_reduced_triangular(coboundary)
    checks['reduced = coboundary*triangular'] = \
        check_factorization(coboundary, reduced, triangular)

    barcode = get_barcode(reduced, filtration)
    checks['H_* barcode = shifted H^* barcode'] = \
        check_duality(filtration)
    checks['homology barcode = gudhi barcode'] = \
        check_against_gudhi(filtration)

    betti_curves = get_betti_curves(barcode, filtration)

    coho_reps = get_coho_reps(barcode, reduced, triangular, filtration)
    checks['coho_reps are cycles'] = \
        check_representatives(coboundary, coho_reps)

    for k in range(1, max_k + 1):
        st_reps = get_steenrod_reps(k, coho_reps, filtration)
        # Validate the Steenrod representatives themselves, not coho_reps.
        checks[f'{k}-st_reps are cycles'] = \
            check_representatives(coboundary, st_reps)

        steenrod_curves[k] = get_steenrod_curve(barcode, st_reps,
                                                filtration, reduced)
        checks[f'{k}-st_curve <= sum of bettis'] = \
            check_steenrod_lq_betti(steenrod_curves[k], betti_curves, filtration)

    return steenrod_curves, checks


def save_steenrod_curves(molecules=None, max_k=2, out_path="steenrod_curves.pkl"):
    """Compute one Steenrod curve per molecule filtration and pickle the result.

    Parameters
    ----------
    molecules : dict or None
        ``molecule_name -> filtration``, in the format pickled by
        ``get_interesting_molecules``. When None, the dict is loaded from
        ``"interesting_filtered_molecules.pkl"``. Pass a subset to keep a run
        short, because results are only written after the whole loop finishes.
    max_k : int
        Steenrod representatives and curves are computed for k = 1..max_k.
        Only the curve for ``k == max_k`` is stored per molecule.
    out_path : str
        Pickle file the ``{molecule_name: tuple(curve)}`` dict is written to.

    Raises
    ------
    ValueError
        If ``max_k < 1``, because no curve would be computed.
    """
    if max_k < 1:
        raise ValueError(f"max_k must be >= 1, got {max_k}")

    if molecules is None:
        with open("interesting_filtered_molecules.pkl", "rb") as molecules_file:
            molecules = pkl.load(molecules_file)

    curves = {}
    for name, filtration in molecules.items():
        # NOTE(review): the sanity checks are computed but never inspected.
        # Consider asserting on or logging failures once the return type of
        # the check_* helpers is confirmed.
        steenrod_curves, _checks = _steenrod_curves_for_filtration(filtration, max_k)
        curves[name] = tuple(steenrod_curves[max_k])

    with open(out_path, "wb") as curves_file:
        pkl.dump(curves, curves_file)

In [8]:
# Processes every molecule in interesting_filtered_molecules.pkl. The result
# pickle is written only after the whole loop, so interrupting this cell (see
# the KeyboardInterrupt in the output) leaves no steenrod_curves.pkl behind.
save_steenrod_curves()

{(37.0, inf, 3)}
{(33.0, inf, 3)}
{(61.0, inf, 3), (60.0, inf, 3), (59.0, inf, 3)}
{(79.0, inf, 3), (78.0, inf, 3), (77.0, inf, 3)}
{(60.0, inf, 2), (61.0, inf, 2), (59.0, inf, 2)}
{(102.0, inf, 3), (101.0, inf, 3), (100.0, inf, 3), (99.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (72.0, inf, 3)}
{(65.0, inf, 3), (67.0, inf, 3), (66.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (69.0, inf, 3)}
{(60.0, inf, 3), (59.0, inf, 3)}
{(52.0, inf, 3), (51.0, inf, 3)}
{(129.0, inf, 3), (127.0, inf, 3), (128.0, inf, 3), (130.0, inf, 3)}
{(162.0, inf, 3), (151.0, inf, 3), (153.0, inf, 3), (163.0, inf, 3), (152.0, inf, 3), (154.0, inf, 3), (164.0, inf, 3)}
{(143.0, inf, 3), (141.0, inf, 3), (142.0, inf, 3), (140.0, inf, 3)}
{(71.0, inf, 3), (73.0, inf, 3), (72.0, inf, 3)}
{(78.0, inf, 3), (77.0, inf, 3), (76.0, inf, 3)}
{(71.0, inf, 3), (73.0, inf, 3), (72.0, inf, 3)}
{(83.0, inf, 3), (82.0, inf, 3), (81.0, inf, 3)}
{(75.0, inf, 3), (74.0, inf, 3)}
{(114.0, inf, 3), (123.0, inf, 3), (113.0, inf, 3),

{(134.0, inf, 3), (137.0, inf, 3), (135.0, inf, 3), (136.0, inf, 3)}
{(103.0, inf, 3), (102.0, inf, 3), (104.0, inf, 3)}
{(125.0, inf, 3), (127.0, inf, 3), (126.0, inf, 3)}
{(129.0, inf, 3), (128.0, inf, 3), (130.0, inf, 3)}
{(92.0, inf, 3), (91.0, inf, 3), (90.0, inf, 3), (89.0, inf, 3)}
{(85.0, inf, 3), (87.0, inf, 3), (86.0, inf, 3), (88.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (72.0, inf, 3)}
{(65.0, inf, 3), (67.0, inf, 3), (66.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (72.0, inf, 3)}
{(147.0, inf, 3), (148.0, inf, 3), (150.0, inf, 3), (149.0, inf, 3)}
{(109.0, inf, 3), (111.0, inf, 3), (110.0, inf, 3)}
{(92.0, inf, 3), (91.0, inf, 3), (90.0, inf, 3)}
{(95.0, inf, 3), (96.0, inf, 3), (93.0, inf, 3), (94.0, inf, 3)}
{(76.0, inf, 3), (75.0, inf, 3), (74.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (72.0, inf, 3)}
{(83.0, inf, 3), (82.0, inf, 3), (81.0, inf, 3)}
{(119.0, inf, 3), (122.0, inf, 3), (121.0, inf, 3), (120.0, inf, 3)}
{(87.0, inf, 3), (86.0, inf, 3), (85.0, inf, 3)

{(113.0, inf, 3), (110.0, inf, 3), (112.0, inf, 3), (111.0, inf, 3)}
{(108.0, inf, 3), (107.0, inf, 3), (106.0, inf, 3), (105.0, inf, 3)}
{(71.0, inf, 3), (73.0, inf, 3), (72.0, inf, 3)}
{(68.0, inf, 3), (67.0, inf, 3), (66.0, inf, 3)}
{(162.0, inf, 3), (161.0, inf, 3), (151.0, inf, 3), (159.0, inf, 3), (148.0, inf, 3), (160.0, inf, 3), (150.0, inf, 3), (149.0, inf, 3)}
{(114.0, inf, 3), (113.0, inf, 3), (112.0, inf, 3), (115.0, inf, 3)}
{(79.0, inf, 3), (81.0, inf, 3), (80.0, inf, 3)}
{(109.0, inf, 3), (108.0, inf, 3), (107.0, inf, 3), (106.0, inf, 3)}
{(119.0, inf, 3), (118.0, inf, 3), (117.0, inf, 3), (116.0, inf, 3)}
{(71.0, inf, 3), (73.0, inf, 3), (72.0, inf, 3)}
{(96.0, inf, 3), (95.0, inf, 3), (94.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3)}
{(109.0, inf, 3), (108.0, inf, 3), (107.0, inf, 3)}
{(129.0, inf, 3), (119.0, inf, 3), (127.0, inf, 3), (120.0, inf, 3), (128.0, inf, 3), (118.0, inf, 3)}
{(113.0, inf, 3), (120.0, inf, 3), (122.0, inf, 3), (112.0, inf, 3), (121.0, inf, 3),

{(125.0, inf, 3), (127.0, inf, 3), (128.0, inf, 3), (126.0, inf, 3)}
{(170.0, inf, 3), (166.0, inf, 3), (169.0, inf, 3), (167.0, inf, 3), (168.0, inf, 3)}
{(124.0, inf, 3), (123.0, inf, 3), (122.0, inf, 3), (121.0, inf, 3)}
{(176.0, inf, 3), (175.0, inf, 3), (178.0, inf, 3), (177.0, inf, 3), (174.0, inf, 3)}
{(108.0, inf, 3), (107.0, inf, 3), (106.0, inf, 3), (105.0, inf, 3)}
{(108.0, inf, 3), (107.0, inf, 3), (106.0, inf, 3), (105.0, inf, 3)}
{(149.0, inf, 3), (151.0, inf, 3), (152.0, inf, 3), (150.0, inf, 3), (153.0, inf, 3)}
{(155.0, inf, 3), (151.0, inf, 3), (152.0, inf, 3), (154.0, inf, 3), (153.0, inf, 3)}
{(123.0, inf, 3), (119.0, inf, 3), (120.0, inf, 3), (122.0, inf, 3), (121.0, inf, 3)}
{(79.0, inf, 3), (81.0, inf, 3), (80.0, inf, 3)}
{(162.0, inf, 3), (161.0, inf, 3), (163.0, inf, 3), (160.0, inf, 3)}
{(146.0, inf, 3), (145.0, inf, 3), (144.0, inf, 3)}
{(112.0, inf, 3), (111.0, inf, 3), (110.0, inf, 3)}
{(192.0, inf, 3), (190.0, inf, 3), (194.0, inf, 3), (193.0, inf, 3), (19

{(84.0, inf, 3), (83.0, inf, 3), (82.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (69.0, inf, 3)}
{(71.0, inf, 3), (73.0, inf, 3), (72.0, inf, 3)}
{(125.0, inf, 3), (127.0, inf, 3), (126.0, inf, 3)}
{(114.0, inf, 3), (113.0, inf, 3), (115.0, inf, 3)}
{(91.0, inf, 3), (90.0, inf, 3), (89.0, inf, 3)}
{(82.0, inf, 3), (81.0, inf, 3), (80.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (69.0, inf, 3)}
{(79.0, inf, 3), (81.0, inf, 3), (80.0, inf, 3)}
{(73.0, inf, 3), (75.0, inf, 3), (74.0, inf, 3)}
{(64.0, inf, 3), (65.0, inf, 3), (63.0, inf, 3)}
{(94.0, inf, 3), (93.0, inf, 3), (92.0, inf, 3)}
{(87.0, inf, 3), (86.0, inf, 3), (85.0, inf, 3)}
{(79.0, inf, 3), (81.0, inf, 3), (80.0, inf, 3)}
{(16.0, inf, 1), (21.0, inf, 1), (20.0, inf, 1), (19.0, inf, 1), (15.0, inf, 1)}
{(15.0, inf, 1), (14.0, inf, 1)}
{(13.0, inf, 1), (12.0, inf, 1)}
{(109.0, inf, 3), (108.0, inf, 3), (110.0, inf, 3)}
{(123.0, inf, 3), (122.0, inf, 3), (121.0, inf, 3)}
{(101.0, inf, 3), (100.0, inf, 3), (99.0, inf, 3)}
{(103.

{(101.0, inf, 3), (100.0, inf, 3), (99.0, inf, 3)}
{(77.0, inf, 3), (76.0, inf, 3), (75.0, inf, 3)}
{(71.0, inf, 3), (70.0, inf, 3), (69.0, inf, 3)}
{(77.0, inf, 3), (76.0, inf, 3), (75.0, inf, 3)}
{(79.0, inf, 3), (78.0, inf, 3), (77.0, inf, 3)}
{(79.0, inf, 3), (81.0, inf, 3), (80.0, inf, 3)}
{(141.0, inf, 3), (139.0, inf, 3), (142.0, inf, 3), (140.0, inf, 3)}
{(162.0, inf, 3), (161.0, inf, 3), (163.0, inf, 3)}
{(206.0, inf, 3), (204.0, inf, 3), (207.0, inf, 3), (205.0, inf, 3)}
{(199.0, inf, 3), (200.0, inf, 3), (202.0, inf, 3), (201.0, inf, 3)}
{(172.0, inf, 3), (175.0, inf, 3), (173.0, inf, 3), (174.0, inf, 3)}
{(137.0, inf, 3), (135.0, inf, 3), (136.0, inf, 3)}
{(129.0, inf, 3), (128.0, inf, 3), (130.0, inf, 3)}
{(176.0, inf, 3), (175.0, inf, 3), (178.0, inf, 3), (177.0, inf, 3)}
{(157.0, inf, 3), (155.0, inf, 3), (158.0, inf, 3), (156.0, inf, 3)}
{(137.0, inf, 3), (136.0, inf, 3), (138.0, inf, 3)}
{(157.0, inf, 3), (155.0, inf, 3), (156.0, inf, 3)}
{(212.0, inf, 3), (211.0, inf,

KeyboardInterrupt: 

In [9]:
# Load the curves pickled by save_steenrod_curves. The context manager closes
# the file handle.
with open("steenrod_curves.pkl", "rb") as curves_file:
    curves = pkl.load(curves_file)

FileNotFoundError: [Errno 2] No such file or directory: 'steenrod_curves.pkl'

In [None]:
# Keep only molecules whose stored curve has a strictly positive total.
super_interesting_filtered_molecules = dict(
    (name, curve)
    for name, curve in curves.items()
    if sum(curve) > 0
)

In [None]:
# Persist the filtered curves. The context manager closes the file handle.
with open("super_interesting_filtered_molecules.pkl", "wb") as out_file:
    pkl.dump(super_interesting_filtered_molecules, out_file)

In [None]:
# One figure per molecule: the stored curve plotted against its index, with
# the curve's total shown in the legend.
for name, curve in super_interesting_filtered_molecules.items():
    fig, ax = plt.subplots()
    ax.plot(range(len(curve)), curve, label=f'length={sum(curve)}')
    ax.set(xlabel='index', ylabel='Steenrod curve value')
    ax.legend()
    fig.suptitle(name)
    plt.show()
    # Release each figure after display; there is one per molecule.
    plt.close(fig)