# Implemented By:
- Dimitri KACHLER
- Mirette MOAWAD
- Nader KHALIL

# Utility Functions


In [None]:
import time

# Convert string to number
def str_to_num(s):
    """Parse ``s`` as an ``int`` when possible, otherwise as a ``float``.

    Args:
        s: Text holding a number (surrounding whitespace is accepted by
            both ``int`` and ``float``).

    Returns:
        An ``int`` if ``int(s)`` succeeds, else the result of ``float(s)``.

    Raises:
        ValueError: If ``s`` is neither a valid int nor a valid float literal.
    """
    try:
        return int(s)
    except ValueError:
        # Only a failed integer parse should fall back to float; anything
        # else (e.g. KeyboardInterrupt) must propagate untouched.
        return float(s)
  
# Sort array in place by element i, breaking ties with element j
# O(n log n)
def sortByAxis(a, i, j):
    """Sort the list ``a`` in place, ordering items by ``(item[i], item[j])``.

    Args:
        a: List of indexable items (lists/tuples).
        i: Index of the primary sort key inside each item.
        j: Index of the secondary (tie-breaking) sort key inside each item.

    Returns:
        None; ``a`` itself is reordered.
    """
    def axis_key(row):
        return row[i], row[j]

    a.sort(key=axis_key)

# Print array elements, one per line
# O(n)
def printArray(a):
    """Print every element of the iterable ``a`` on its own line, in order."""
    for element in a:
        print(element)

# O(len(col1))
def addCols(col1, col2):
    """Add column ``col1`` to column ``col2`` over Z/2, in place.

    Columns are stored sparsely as sets of row indices holding a 1, so
    addition modulo 2 is the symmetric difference: indices present in both
    columns cancel, indices present in only one survive.

    Args:
        col1: Set of row indices of the column being added (left unchanged,
            unless it is the same object as ``col2``).
        col2: Set of row indices of the column receiving the sum; mutated.

    Returns:
        ``col2`` itself, after the update.
    """
    # Unlike a manual toggle loop, this is also safe when col1 is col2
    # (a column plus itself is the zero column).
    col2.symmetric_difference_update(col1)
    return col2

## Set File Path

In [None]:
# Path of the filtration file to analyse (relative to the working directory)
file_path = './Filtration Files/Moebius_Band.txt'

## Reading Input

In [None]:
# Parses input from a file, extracting filtration, dimension, and vertices data,
# and returns it in a structured format.
# Input Format: filtration dimension vertex_1 vertex_2 ... vertex_n
# O(n)

def read_input(filename):
    """Read a filtration file into a list of simplices.

    Each non-blank line is split on whitespace into
    ``filtration dimension vertex_1 ... vertex_n``. Blank lines are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A tuple ``(sample_input, max_dim, n)`` where ``sample_input`` is a
        list of ``[filtration, dim, vertices]`` entries in file order
        (``vertices`` is a sorted list of ints), ``max_dim`` is the largest
        dimension read (``-inf`` when the file holds no simplices), and ``n``
        is ``len(sample_input)``.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a non-blank line has fewer than two fields.
        ValueError: If a field is not numeric.
    """
    sample_input = []
    max_dim = -float('inf')
    # The context manager closes the file even if a line fails to parse.
    with open(filename, "r") as f:
        for line in f:
            # split() with no argument drops the newline and tolerates
            # repeated or trailing whitespace between fields.
            fields = line.split()
            if not fields:
                continue
            filtration = str_to_num(fields[0])
            dim = str_to_num(fields[1])
            vertices = sorted(map(int, fields[2:]))

            if dim > max_dim:
                max_dim = dim
            sample_input.append([filtration, dim, vertices])
    return sample_input, max_dim, len(sample_input)


# sample_input: list of [filtration, dim, vertices]; max_dim: largest dim read; N: number of simplices
sample_input, max_dim, N = read_input(file_path)

## Sort by Filtration Value & Constructing Complex Dictionary

In [None]:
# Sort by filtration value in increasing order; ties are broken by dimension,
# so among simplices sharing a filtration value the lower-dimensional ones
# receive the smaller ids.
sortByAxis(sample_input, 0, 1)

# key = tuple of vertices
# value = id (position of the simplex in the sorted order)
complex_dict = dict()

# complex[r] = list of (id, vertices) for all simplices of dimension exactly r
# NOTE: this name shadows the builtin `complex`; it is kept because later
# cells refer to it.
complex = [[] for _ in range(max_dim + 1)]

# key = id
# value = filtration value
values_dict = dict()

# The loop variable is deliberately not called `id`, which would shadow the
# builtin for the rest of the session.
for simplex_id, (value, dim, vertices) in enumerate(sample_input):
    complex_dict[tuple(vertices)] = simplex_id
    complex[dim].append((simplex_id, vertices))
    values_dict[simplex_id] = value


## Constructing the Boundary Matrix

In [None]:
# Build the boundary matrix over Z/2, one sparse column per simplex of
# dimension > 0. Each column costs one dictionary lookup per face.

# key = Simplex ID
# value = set of subsimplices composing boundary
# boundary_dict[i] = {j, k, l} is equivalent to
# B[j, i] = 1, B[k, i] = 1, B[l, i] = 1
boundary_dict = dict()

# key = max value in boundary list
# value = Simplex ID
# pivot_dict[i] = j is equivalent to
# low(j) = i
#
# This is intentionally left empty here. Pivots must be registered while the
# columns are being reduced, in column order, so that a column is only ever
# reduced by columns that precede it. Registering the pivots of the
# *unreduced* columns up front lets an earlier column be reduced by a later
# one, which produces wrong birth/death pairs.
pivot_dict = dict()

# For each dimension > 0 (columns of different dimensions never share rows,
# so the order between dimensions does not matter; within a dimension the
# columns are inserted by increasing id).
for r in range(max_dim, 0, -1):
    # For each r-simplex
    for simplex_id, simplex_vertices in complex[r]:
        # Faces are obtained by dropping one vertex at a time:
        # [v0, v1, ..., v_k ^, ..., vr]
        # A KeyError here means a face of this simplex is missing from the
        # input, i.e. the input is not a simplicial complex.
        boundary_dict[simplex_id] = {
            complex_dict[tuple(simplex_vertices[:k] + simplex_vertices[k + 1:])]
            for k in range(r + 1)
        }

## Reduction to Row Echelon Form

In [None]:
# O(n^2)
def reduce_column(simplex_id, pivot):
    """Reduce column ``simplex_id`` against the columns recorded in pivot_dict.

    While the current pivot row is already owned by another column, that
    column is added (mod 2) into this one, which is updated in place inside
    ``boundary_dict``.

    Args:
        simplex_id: Key in ``boundary_dict`` of the column to reduce.
        pivot: Largest row index currently present in that column.

    Returns:
        The pivot at which the reduction stopped. If the column became empty,
        or the pivot row is owned by this very column, this is a pivot that is
        already a key of ``pivot_dict``; otherwise it is a row not yet owned.
    """
    while True:
        if pivot not in pivot_dict:
            return pivot
        owner = pivot_dict[pivot]
        # The column already owns this pivot row: nothing to reduce.
        if owner == simplex_id:
            return pivot
        # Add column `owner` to column `simplex_id`
        boundary_dict[simplex_id] = addCols(boundary_dict[owner], boundary_dict[simplex_id])
        reduced = boundary_dict[simplex_id]
        # Column is all zeros after the addition: return the last pivot as is
        if len(reduced) == 0:
            return pivot
        pivot = max(reduced)

In [None]:
# O(n^3)
def row_echelon_form():
    """Reduce every non-empty column of ``boundary_dict``, in dictionary order.

    Each column is reduced with ``reduce_column``; the pivot it ends on is
    recorded in ``pivot_dict`` unless that row already has an owner.
    Both module-level dictionaries are updated in place; nothing is returned.
    """
    for simplex_id, column in boundary_dict.items():
        if column:
            final_pivot = reduce_column(simplex_id, max(column))
            # Insert only when this pivot row has no owner yet
            pivot_dict.setdefault(final_pivot, simplex_id)
# Time the reduction with perf_counter(): unlike time.time(), it is monotonic
# and meant for measuring short elapsed intervals.
tic = time.perf_counter()
row_echelon_form()
toc = time.perf_counter()
print('Time =', toc - tic)

## Extracting Intervals

In [None]:
# Extract persistent homology intervals from pivot_dict and values_dict
def extract_intervals(N, sample_input, pivot_dict, values_dict):
    """Build the list of ``(dim, start, end)`` intervals.

    Args:
        N: Number of simplices; ids are ``0 .. N-1``.
        sample_input: List of entries whose element ``[1]`` is the dimension
            of the simplex with that id.
        pivot_dict: Maps a pivot id to the id of the column paired with it.
        values_dict: Maps a simplex id to its filtration value.

    Returns:
        A list of tuples. First, one ``(dim, start, end)`` per entry of
        ``pivot_dict`` (in its iteration order), where ``dim`` and ``start``
        come from the pivot simplex and ``end`` from the paired column. Then,
        one ``(dim, value, inf)`` for every id in ``range(N)`` that appears in
        no pair, in increasing id order.
    """
    # Paired intervals
    finite = [
        (sample_input[birth][1], values_dict[birth], values_dict[death])
        for birth, death in pivot_dict.items()
    ]
    # Every id taking part in a pair, either as pivot or as column
    matched = set(pivot_dict.keys()).union(pivot_dict.values())
    # Unpaired intervals
    infinite = [
        (sample_input[idx][1], values_dict[idx], float('inf'))
        for idx in range(N)
        if idx not in matched
    ]
    return finite + infinite
# Compute all (dim, start, end) intervals, then order them by dimension and,
# within a dimension, by start value.
intervals = extract_intervals(N, sample_input, pivot_dict, values_dict)
sortByAxis(intervals, 0, 1)

## Write Output

In [None]:
def write_output(intervals, filename):
    """Write the intervals to ``filename``, one per line.

    Each interval is written as its elements converted with ``str`` and
    joined by single spaces, followed by a newline. An existing file is
    overwritten.

    Args:
        intervals: Iterable of sequences (e.g. ``(dim, start, end)`` tuples).
        filename: Path of the file to create or overwrite.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    # The context manager guarantees the file is closed even if a write fails.
    with open(filename, 'w') as f:
        for interval in intervals:
            f.write(' '.join(map(str, interval)))
            f.write('\n')
# Save the intervals, one per line, to output.txt (path relative to the working directory)
write_output(intervals, 'output.txt')

## Draw Barcode

In [None]:
import matplotlib.pyplot as plt
# Draw the barcode: one horizontal bar per interval, stacked in list order.
# Each interval is a (dim, start, end) tuple; end is inf for unpaired simplices.

# Colour per homology dimension; dimensions beyond the list fall back to purple
colors = ['blue', 'green', 'red', 'purple']

# Largest finite coordinate among all starts and ends; infinite bars are
# drawn up to maxi + 1 so that they extend past every finite bar.
maxi = -float('inf')
for _dim, start, end in intervals:
    if end != float('inf') and end > maxi:
        maxi = end
    if start > maxi:
        maxi = start

for row, (dim, start, end) in enumerate(intervals):
    is_paired = end != float('inf')
    right = end if is_paired else maxi + 1
    # `<` rather than `<=`: dim == len(colors) is already out of range
    color = colors[dim] if dim < len(colors) else 'purple'
    label = 'paired' if is_paired else 'unpaired'
    plt.plot([start, right], [row, row], color = color, label = label)

# NOTE(review): a log axis cannot represent values <= 0, so bars starting at
# filtration value 0 may be drawn truncated or not at all - confirm this is
# the intended rendering.
plt.xscale('log') # log scale
plt.title('Blue: H0, Green: H1, Red: H2')
plt.show()