In [2]:
from __future__ import print_function

%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings('ignore')

from ripser import lower_star_img
from ripser import Rips
vr = Rips()
from gtda.homology import VietorisRipsPersistence

import persim
import diagram2vec

from scipy.ndimage import gaussian_filter

from sklearn.datasets import make_circles
from sklearn.manifold import MDS

from gtda.diagrams import PersistenceEntropy, PersistenceImage, BettiCurve

import pickle
from tqdm import tqdm

import torch
from torch.nn import Linear
from torch.nn.functional import relu

from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from sklearn.model_selection import cross_val_score

Rips(maxdim=1, thresh=inf, coeff=2, do_cocycles=False, n_perm = None, verbose=True)


In [3]:
W = 300
sigma1 = 4
sigma2 = 2
t = 0.01

def generate(N, S, W=300, sigma1=4, sigma2=2, t=0.01, bins=64):

    z = np.zeros((N, S, 2))
    for n in range(N):
        z[n, 0] = np.random.uniform(0, W, size=(2))
        for s in range(S-1):
            d_1 = np.random.normal(0, sigma1)
            d_2 = np.random.normal(0, sigma1)
            z[n, s+1, 0] = (z[n, s, 0] + d_1) % W
            z[n, s+1, 1] = (z[n, s, 1] + d_2) % W

    z_r = z.reshape(N*S, 2)
    H, _, _ = np.histogram2d(z_r[:,0], z_r[:,1], bins=bins)
    
    G = gaussian_filter(H, sigma2)
    G[G < t] = 0
    
    return G

In [3]:
count = 50000
classes_count = 2

images = np.zeros((classes_count * count, 64, 64))

# class A
N = 100
S = 30

for n in tqdm(range(count)):
    images[n] = generate(N, S)
    
# class B
N = 250
S = 10

for n in tqdm(range(count)):
    images[n+count] = generate(N, S)

100%|██████████| 50000/50000 [02:45<00:00, 301.53it/s]
100%|██████████| 50000/50000 [02:23<00:00, 348.96it/s]


In [4]:
class PorusDataset(Dataset):
    def __init__(self, diagrams, labels):
        self.diagrams = diagrams
        self.labels = labels
        
    def __len__(self):
        return len(self.diagrams)
    
    def __getitem__(self, idx):
        return self.diagrams[idx], self.labels[idx]

In [5]:
# baseline
# cedt
# cedt x ...
# conv
# multiconv
# dir

In [6]:
import torch
import torchvision
import gudhi as gd
from gudhi.wasserstein import wasserstein_distance

from scipy.io import loadmat
from sklearn.model_selection import train_test_split
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
from scipy.ndimage import distance_transform_edt

from ripser import lower_star_img
from ripser import Rips

import persim
import diagram2vec

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

import gudhi as gd
from gudhi.wasserstein import wasserstein_distance as dist_w

from IPython.display import clear_output
from scipy.ndimage import maximum_filter


POT (Python Optimal Transport) package is not installed. Try to run $ conda install -c conda-forge pot ; or $ pip install POT


In [38]:
def diagram(image, device, sublevel=True):
    # get height and square image
    h = int(np.sqrt(image.shape[0]))
    image_sq = image.reshape((h,h))

    # create complex
    cmplx = gd.CubicalComplex(dimensions=(h, h), top_dimensional_cells=(sublevel*image))

    # get pairs of critical simplices
    cmplx.compute_persistence()
    critical_pairs = cmplx.cofaces_of_persistence_pairs()
    
    # get essential critical pixel
    bpx0_essential = critical_pairs[1][0][0] // h, critical_pairs[1][0][0] % h

    # get critical pixels corresponding to critical simplices
    try:
        bpx0 = [critical_pairs[0][0][i][0] for i in range(len(critical_pairs[0][0]))]
        dpx0 = [critical_pairs[0][0][i][1] for i in range(len(critical_pairs[0][0]))]
    except IndexError:
        bpx0 = []
        dpx0 = []
        
    try:
        bpx1 = [critical_pairs[0][1][i][0] for i in range(len(critical_pairs[0][1]))]
        dpx1 = [critical_pairs[0][1][i][1] for i in range(len(critical_pairs[0][1]))]
    except IndexError:
        bpx1 = []
        dpx1 = []
    

    flat_image = image_sq.flatten()
    pd0_essential = torch.tensor([[image_sq[bpx0_essential], torch.max(image)]])

    if (len(bpx0)!=0):
        pdb0 = flat_image[bpx0][:, None]
        pdd0 = flat_image[dpx0][:, None]
        pd0 = torch.Tensor(torch.hstack([pdb0, pdd0]))
        pd0 = torch.vstack([pd0, pd0_essential.to(device)])
    else:
        pd0 = pd0_essential

    if (len(bpx1)!=0):
        pdb1 = flat_image[bpx1][:, None]
        pdd1 = flat_image[dpx1][:, None]
        pd1 = torch.Tensor(torch.hstack([pdb1, pdd1]))
    else:
        pd1 = torch.zeros((1, 2))
    
    return pd0, pd1


def process_by_direction(img, alpha):
    X = (math.cos(alpha) - (np.arange(0, img.shape[0]) - (img.shape[0] / 2 - 0.5)) / (img.shape[0] * math.sqrt(2))) * math.cos(alpha) / 2
    Y = (math.sin(alpha) - (np.arange(0, img.shape[1]) - (img.shape[1] / 2 - 0.5)) / (img.shape[1] * math.sqrt(2))) * math.sin(alpha) / 2
    direction_filter = X.reshape(-1, 1) + Y.reshape(1, -1)
    return np.maximum(direction_filter, img)


def process_image(img, filter_params, device):
    w = int(np.sqrt(img.flatten().shape[0]))
    imgs = [process_by_direction(img.reshape(w, w), alpha) for alpha in filter_params]
    diagrams = []
    for i, img in enumerate(imgs):
        res = diagram(torch.Tensor(img.flatten()), device=device)
        for j in range(len(res)):
            if not res[j].shape[0]:
                diagrams.append(torch.zeros(0, 4))
            else:
                diagrams.append(torch.concatenate([res[j], torch.Tensor([[j, filter_params[i]] for _ in range(res[j].shape[0])])], axis=1))

    diagrams = torch.concatenate(diagrams)
    return diagrams


def process_by_conv(img, conv, device):
    w = int(np.sqrt(img.flatten().shape[0]))
    img = conv(torch.Tensor(img).reshape(1, w, w)).detach()
    diagrams = []
    for i in range(img.shape[0]):
        res = diagram(img[i].flatten(), device=device)
        for j in range(len(res)):
            if not res[j].shape[0]:
                diagrams.append(torch.zeros(0, 4))
            else:
                diagrams.append(torch.concatenate([res[j], torch.Tensor([[j, i] for _ in range(res[j].shape[0])])], axis=1))
    diagrams = torch.concatenate(diagrams)
    return diagrams


def process_baseline(img, device):
    diagrams = []
    res = diagram(torch.Tensor(img.flatten()), device=device)
    for j in range(len(res)):
        if not res[j].shape[0]:
            diagrams.append(torch.zeros(0, 4))
        else:
            diagrams.append(torch.concatenate([res[j], torch.Tensor([[j, 1] for _ in range(res[j].shape[0])])], axis=1))
    diagrams = torch.concatenate(diagrams)
    return diagrams


def process_cedt(img, device):
    img /= img.max()
    edt = torch.Tensor(distance_transform_edt(img > 0.5))
    cedt = edt * (img > 0.5) - edt * (img <= 0.5)
    return process_baseline(cedt, device=device)


def process_cedt_thickening(img, window_size, device):
    img /= img.max()
    img = maximum_filter(img, size=window_size)
    edt = torch.Tensor(distance_transform_edt(img > 0.5))
    cedt = edt * (img > 0.5) - edt * (img <= 0.5)
    return process_baseline(cedt, device=device)


In [39]:
N = 250
S = 228
image = generate(N, S, W=300, sigma1=4, sigma2=2, t=0.01, bins=64)

In [40]:
t = image[:63, :63]

In [41]:
diagram(torch.Tensor(t).flatten(), device="cpu")[0].shape


torch.Size([30, 2])

In [43]:
# dir
print(process_image(torch.Tensor(t).flatten(), [0, np.pi / 2], device="cpu").shape)

# conv
conv = nn.Conv2d(1, 4, kernel_size=3)
print(process_by_conv(torch.Tensor(t).flatten(), conv, device="cpu").shape)

# baseline
print(process_baseline(torch.Tensor(t).flatten(), device="cpu").shape)

# cedt
print(process_cedt(torch.Tensor(t).flatten(), device="cpu").shape)

# thickening cedt
print(process_cedt_thickening(torch.Tensor(t).flatten(), 3, device="cpu").shape)


torch.Size([118, 4])
torch.Size([247, 4])
torch.Size([59, 4])
torch.Size([25, 4])
torch.Size([18, 4])
