In [5]:
import json, copy, cv2
import networkx as nx
import sys,glob
import matplotlib.image as mpimg
import numpy as np
import gudhi as gd
from gudhi.representations import vector_methods
import ripser
import persim
from TDA_filtrations import level_set_flooding, save_BD, image_to_pointcloud
from custom_functions import *
import multiprocessing as mp
import skimage.measure

In [6]:
###### To select a dataset to analyze, uncomment all code under the dataset's name

### STARE Expert #1
dataset = "KAGGLE"
# nefi_output_folder = "../Data/Dataset_1/NEFI_graphs/*/"
image_folder = "/DATA/kavi/SA-UNET/CHASE_Pretrained_VSI/"
write_folder = "/DATA/kavi/TDA_yt/"

#  nums = np.arange(1, 2412)
import os
nums=os.listdir(image_folder)
repeat_files = os.listdir(write_folder)
for i in range(len(nums)):
  nums[i] = nums[i].split('.')[0]

nums = [_ for _ in nums if _ + '_VR_persistence_0.npy' not in repeat_files]

print(len(nums))
print(nums)

# data_name = "DS1_"
# file_name = "_left"
# nefi_outputs = glob.glob(f"{nefi_output_folder}*.txt")

35
['13144_left', '12844_right', '11790_left', '16863_right', '10120_right', '12857_right', '11874_left', '13678_right', '18804_right', '17029_left', '10126_left', '1452_right', '15387_right', '1514_right', '10239_left', '11547_left', '19213_left', '11966_left', '13125_right', '11929_left', '11726_right', '15981_right', '1887_right', '19494_right', '15810_right', '17048_left', '12227_left', '17111_right', '14842_right', '18353_right', '11013_left', '14380_right', '10819_right', '1660_right', '11203_right']


## VR filtration

In [11]:
#Define weighting for persistence images for ripser
def weight_ramp(x):
    
    if np.any(np.isinf(x)):
        weight = 1.0
    else:
        weight = x[1]/185
    
    return weight

def VR_filtration(num):
    
    if "all" in dataset:
        num_str = num
    else:
        num_str = f"{str(num).zfill(4)}"
    
    #load in image
    # image_loc = f"{image_folder}{file_name}{num_str}.png"
    image_loc = f"{image_folder}{num_str}.jpeg"
    image = mpimg.imread(image_loc)
    
    if dataset == "HRF":
        #downsample for HRF to ease computation
        image = skimage.measure.block_reduce(image,(3,3),np.max)
    
    #saving
    filename_header = write_folder+num_str+"_VR"
    
    #convert image to pointcloud
    try:
        pointcloud = image_to_pointcloud(image[:,:,0])
    except:
        pointcloud = image_to_pointcloud(image)    
    pointcloud = np.array(pointcloud)


    #initialize averaged PIs for \each descriptor vector
    im0_ripser_ramp = np.zeros((2500,))
    im1_ripser_ramp = np.zeros((2500,))
    im0_ripser_ones = np.zeros((2500,))
    im1_ripser_ones = np.zeros((2500,))
    
    np.random.seed(10)
    
    #shuffle pointcloud
    np.random.shuffle(pointcloud)

    #Run VR on subsampled pointcloud
    dgms = ripser.ripser(pointcloud, n_perm = 2000)['dgms']
    
    #Save the persistence diagram
    save_BD(dgms, filename = f"{filename_header}_persistence_{0}")
    

print(f"Computing VR filtration for {dataset}")
# pool = mp.Pool(mp.cpu_count())
# results = pool.map(VR_filtration, nums)
# pool.close()

for i, num in enumerate(nums):
    print(f'{i}: {num}')
    VR_filtration(num)

Computing VR filtration for KAGGLE
0: 13144_left
1: 12844_right
2: 11790_left
3: 16863_right
4: 10120_right
5: 12857_right
6: 11874_left
7: 13678_right
8: 18804_right
9: 17029_left
10: 10126_left
11: 1452_right
12: 15387_right
13: 1514_right
14: 10239_left
15: 11547_left
16: 19213_left
17: 11966_left
18: 13125_right
19: 11929_left
20: 11726_right
21: 15981_right
22: 1887_right
23: 19494_right
24: 15810_right
25: 17048_left
26: 12227_left
27: 17111_right
28: 14842_right
29: 18353_right
30: 11013_left
31: 14380_right
32: 10819_right
33: 1660_right
34: 11203_right


## Compute the Radial inward and Radial outward filtrations


In [None]:
def radial_filtrations(num):
    
    if "all" in dataset:
        num_str = num
    else:
        num_str = f"{str(num).zfill(4)}"
        
    if dataset == "HRF":
        max_rad = 3000
    else:
        max_rad = 700
    
    #find nefi output file
    nefi_output = [s for s in nefi_outputs if num_str in s]
    #ensure there is only one location in this list
    assert len(nefi_output)==1
    #read in graph G
    graph_in = nx.read_multiline_adjlist(nefi_output[0],delimiter='|')

    #compute both radial inward and radial outward filtrations
    for direction in ['inward','outward']:
    
        filename_header = write_folder+data_name+file_name+num_str+"_"+direction
    
        diag = radius_filtration(graph_in,max_rad=max_rad,filename_save = filename_header+"_persistence",direction=direction)
    
        b0,b1,r = betti_curve(diag,r0=0,r1=40,filename_save = filename_header+"_Betti")
        
        PI_o, PI_r = Persist_im(diag=diag, inf_val = 40,sigma = 1.0, filename_save = [filename_header+"_PIO",
                                                                                      filename_header+"_PIR"])

print(f"Computing radial filtrations for {dataset}")          
pool = mp.Pool(mp.cpu_count())
results = pool.map(radial_filtrations, nums)
pool.close()

## Compute the Flooding filtration

In [None]:
def flood_filtration(num):

    if dataset == "all":
        num_str = num
    else:
        num_str = num #.split('_')[0]
       # print('num',num)
       # num_str = f"{str(num).zfill(4)}"
      #  print('num_str', num_str)
    
    #load in image
    image_loc = f"{image_folder}{num}.jpeg"
    # print('file name', file_name)
    # print('image loc', image_loc)
    image = mpimg.imread(image_loc)
    print(image.shape)
    
    if dataset == "HRF":
        #downsample to ease computation
        image = skimage.measure.block_reduce(image,(3,3),np.max)
    

    filename_header = write_folder+data_name+num_str+"_flooding"
    print('111')
    
    try:
       # print(image[448:672,448:672,0].shape)
        diag = level_set_flooding(image[:,:, 0],iter_num=35,steps=2,filename = filename_header+"_persistence")

    except:
        diag = level_set_flooding(image, iter_num=35,steps=2,filename = filename_header+"_persistence")

   # b0,b1,r = betti_curve(diag,r0=0,r1=35,filename_save = filename_header+"_Betti")

    
   PI_o, PI_r = Persist_im(diag=diag, sigma = 1.0, inf_val = 35,filename_save = [filename_header+"_PIO",filename_header+"_PIR"])
   print(type(PI_o))

    print('working on:' ,filename_header)

print(f"Computing flooding filtration for {dataset}")      
# pool = mp.Pool(mp.cpu_count())
# results = pool.map(flood_filtration, nums)
# pool.close()
for i in range(len(nums)):
  flood_filtration(nums[i])
