In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.join('../'))))

import tqdm
import numpy
import torch
import wandb
import pandas
import joblib
import itertools
import scipy.stats
import torchvision
import gtda.images
import gtda.diagrams
import gtda.homology
import sklearn.pipeline
import sklearn.ensemble
import sklearn.metrics
import tqdm.contrib.itertools
import sklearn.decomposition

import lib.topology

In [2]:
# Threshold handed to gtda.images.Binarizer in make_filtrations / make_point_clouds.
# NOTE(review): the identifier is misspelt ("THESHOLD"); renaming it requires
# updating every use in the notebook at the same time.
BINARIZATION_THESHOLD = 0.4

# Direction vectors for HeightFiltration: four diagonals followed by four
# axis-aligned directions.
height_filtration_directions = [
    [ -1, -1 ], [ 1, 1 ], [ 1, -1 ], [ -1, 1 ],
    [ 0, -1 ], [ 0, 1 ], [ -1, 0 ], [ 1, 0 ]
]


# RadialFiltration is evaluated on a 3x3 grid of centers and for three metrics.
# NOTE(review): 7/14/21 presumably are pixel coordinates inside 28x28 MNIST images — confirm.
radial_filtration_centers = list(itertools.product([ 7, 14, 21 ], [ 7, 14, 21 ]))
radial_filtration_metrics = [ "euclidean", "manhattan", "cosine" ]

# DensityFiltration is evaluated for every (metric, radius) combination.
density_filtration_metrics = [ "euclidean" , "manhattan", "cosine" ]
density_filtration_radiuses = [ 1, 5, 15 ]

# Each entry is a [ transformer_class, constructor_kwargs ] pair that
# make_filtrations instantiates as `entry[0](**entry[1])`.
# Total: 8 height + 27 radial + 3 morphological/distance + 9 density = 47 filtrations.
# The order of this list determines the order of the filtered images downstream.
FILTRATIONS = [
    *[ [ gtda.images.HeightFiltration, { 'direction': numpy.array(direction), 'n_jobs': -1 } ] for direction in height_filtration_directions ],
    *[
        [ gtda.images.RadialFiltration, { 'center': numpy.array(center), 'metric': metric, 'n_jobs': -1 } ]
        for center in radial_filtration_centers
        for metric in radial_filtration_metrics
    ],
    [ gtda.images.DilationFiltration, { 'n_jobs': -1 } ],
    [ gtda.images.ErosionFiltration, { 'n_jobs': -1 } ],
    [ gtda.images.SignedDistanceFiltration, { 'n_jobs': -1 } ],
    *[
        [ gtda.images.DensityFiltration, { 'radius': radius, 'metric': metric, 'n_jobs': -1 } ]
        for metric in density_filtration_metrics
        for radius in density_filtration_radiuses
    ]
]

In [3]:
# Load both MNIST splits from the local 'mnist' directory, downloading them if needed.
train = torchvision.datasets.MNIST('mnist', download = True, train = True)
test = torchvision.datasets.MNIST('mnist', download = True, train = False)

# Every dataset item is an (image, label) pair; split the pairs into two arrays per split.
train_images = numpy.array([ image for image, _ in train ])
train_labels = numpy.array([ label for _, label in train ])

test_images = numpy.array([ image for image, _ in test ])
test_labels = numpy.array([ label for _, label in test ])

In [4]:
def make_filtrations(images: numpy.ndarray):
    """Binarize `images` and run every configured filtration on the binarized stack.

    Returns a list whose first two entries are the raw images and the binarized
    images, followed by one filtered image stack per entry of `FILTRATIONS`
    (in the same order).
    """
    binarizer = gtda.images.Binarizer(threshold = BINARIZATION_THESHOLD)
    images_bin = binarizer.fit_transform(images)

    layers = [ images, images_bin ]
    # Each FILTRATIONS entry is a [ transformer_class, constructor_kwargs ] pair.
    for filtration_cls, filtration_kwargs in tqdm.tqdm(FILTRATIONS, desc = 'filtrations'):
        layers.append(filtration_cls(**filtration_kwargs).fit_transform(images_bin))
    return layers


def make_point_clouds(images: numpy.ndarray, intensity_threshold: float = 50):
    """Build two point-cloud representations for every image.

    Parameters
    ----------
    images:
        Stack of 2D grayscale images, indexed as (image, row, column).
    intensity_threshold:
        Pixels strictly below this value are dropped from the intensity-weighted
        point cloud. Defaults to 50, the value that used to be hard-coded.

    Returns
    -------
    A two-element list:
    * the output of `gtda.images.ImageToPointCloud` applied to the binarized images;
    * a list with one `(n_points, 3)` array per image whose rows are
      `[ i, j, intensity ]` for every retained pixel.
    """
    def make_point_cloud(image, threshold):
        # `~(image < threshold)` (rather than `image >= threshold`) keeps exactly the
        # pixels the former explicit loop kept, including NaN pixels.
        # numpy.nonzero walks the image in row-major order, so the point order is
        # the same as in the nested `for i / for j` loops it replaces.
        rows, cols = numpy.nonzero(~(image < threshold))
        # column_stack yields shape (0, 3) for an image without retained pixels,
        # instead of the shapeless empty array the list-based version produced.
        return numpy.column_stack([ rows, cols, image[rows, cols] ])

    def _make_point_clouds(imgs, threshold):
        # Flip the row axis and swap rows/columns before extracting coordinates.
        imgs = numpy.swapaxes(numpy.flip(imgs, axis = 1), 1, 2)
        return [ make_point_cloud(image, threshold) for image in tqdm.tqdm(imgs, desc = 'point_clouds') ]

    images_bin = gtda.images.Binarizer(threshold = BINARIZATION_THESHOLD).fit_transform(images)
    point_cloud = gtda.images.ImageToPointCloud().fit_transform(images_bin)
    return [ point_cloud, _make_point_clouds(images, intensity_threshold) ]

In [5]:
def make_filtration_diagrams(images: numpy.ndarray):
    """Compute cubical persistence diagrams (dimensions 0 and 1) for all filtrations.

    The per-filtration stacks returned by `make_filtrations` are interleaved
    image by image: all layers of the first image come first, then all layers
    of the second image, and so on. The result therefore has one diagram per
    (image, layer) pair.
    """
    layers = make_filtrations(images)
    # zip(*layers) groups the layers belonging to one image; flatten those groups.
    all_filtrations = [ layer for per_image in zip(*layers) for layer in per_image ]

    print('Making filtration diagrams')
    cubical = gtda.homology.CubicalPersistence(homology_dimensions = [ 0, 1 ], n_jobs = -1)
    return cubical.fit_transform(all_filtrations)

def make_point_cloud_diagrams(images: numpy.ndarray):
    """Compute Vietoris-Rips persistence diagrams (dimensions 0-2) for all point clouds.

    The representations returned by `make_point_clouds` are interleaved image by
    image, so the clouds belonging to one image are adjacent in the input
    handed to the persistence transformer.
    """
    representations = make_point_clouds(images)
    # zip(*representations) groups the clouds of one image; flatten those groups.
    all_point_clouds = [ cloud for per_image in zip(*representations) for cloud in per_image ]

    print('Making point cloud diagrams')
    rips = gtda.homology.VietorisRipsPersistence(homology_dimensions = [ 0, 1, 2 ], n_jobs = -1)
    return rips.fit_transform(all_point_clouds)

In [6]:
# Persistence diagrams for the first 4000 training images; every later cell that
# reads `diagrams` depends on this one. The bare `diagrams.shape` displays the shape.
diagrams = make_filtration_diagrams(train_images[:4000])
diagrams.shape

filtrations: 100%|██████████| 47/47 [00:15<00:00,  3.03it/s]


Making filtration diagrams


(196000, 71, 3)

### lifetime_features

In [312]:
def calc_stats_bulk(data: numpy.ndarray):
    """Row-wise summary statistics of a 2D (optionally masked) array.

    For every row the columns of the result are, in order: max, mean, std,
    sum, median and Euclidean norm. Masked entries are ignored, and any
    statistic that ends up masked is reported as 0.
    """
    reducers = (numpy.ma.max, numpy.ma.mean, numpy.ma.std, numpy.ma.sum, numpy.ma.median)
    columns = [ reduce(data, axis = 1, keepdims = True) for reduce in reducers ]

    # Euclidean norm of each row: square root of the sum of squares.
    squared_total = numpy.ma.sum(data ** 2, axis = 1, keepdims = True)
    columns.append(numpy.ma.sqrt(squared_total))

    stacked = numpy.ma.hstack(columns)
    return stacked.filled(0)

def calc_lifetime_features(diagrams: numpy.ndarray, eps: float = 0.0) -> numpy.ndarray:
    """Vectorised lifetime / midpoint statistics for a batch of persistence diagrams.

    Parameters
    ----------
    diagrams:
        Array indexed as (diagram, point, 3) whose last axis holds
        (birth, death, homology dimension).
    eps:
        Points whose lifetime `death - birth` is strictly below `eps` are ignored.

    Returns
    -------
    numpy.ndarray with one row per diagram. The columns are the six
    `calc_stats_bulk` statistics of the lifetimes for all points and for each
    dimension `0 .. max(dim)`, followed by the same blocks for the midpoints
    `(birth + death) / 2`. (The previous annotation claimed a DataFrame, but a
    plain array has always been returned.)
    """
    birth, death, dim = diagrams[:, :, 0], diagrams[:, :, 1], diagrams[:, :, 2]
    bd2 = (birth + death) / 2.0
    life = death - birth

    short_lived = (life < eps)
    bd2 = numpy.ma.array(bd2, mask = short_lived)
    life = numpy.ma.array(life, mask = short_lived)

    bd2_features = [ calc_stats_bulk(bd2) ]
    life_features = [ calc_stats_bulk(life) ]
    for d in range(int(numpy.max(dim)) + 1):
        other_dimension = (dim != d)
        # numpy.ma.array combines the new mask with the mask already present on its
        # input (keep_mask defaults to True), so the `eps` filter stays in effect.
        dim_bd2 = numpy.ma.array(bd2, mask = other_dimension)
        dim_life = numpy.ma.array(life, mask = other_dimension)
        bd2_features.append(calc_stats_bulk(dim_bd2))
        life_features.append(calc_stats_bulk(dim_life))

    return numpy.hstack([ *life_features, *bd2_features ])

# Smoke test of the vectorised lifetime features on the notebook's `diagrams`;
# the bare `q.shape` displays the (rows, columns) shape of the result.
q = pandas.DataFrame(calc_lifetime_features(diagrams, eps = 0.1))
q.shape

(245000, 36)

### betti_features

In [347]:
def calc_stats(data: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
    """Summarise a 1D vector as a one-row DataFrame.

    The columns are named "<prefix> max", "<prefix> mean", "<prefix> std",
    "<prefix> sum", "<prefix> median" and "<prefix> norm-2". An empty vector is
    treated as the single value 0, and NaN statistics are replaced by 0.
    """
    assert len(data.shape) == 1
    if data.shape == (0,):
        data = numpy.array([ 0 ])

    # Insertion order of this dict defines the column order of the result.
    summary = {
        "max": numpy.max(data),
        "mean": numpy.mean(data),
        "std": numpy.std(data),
        "sum": numpy.sum(data),
        "median": numpy.median(data),
        "norm-2": numpy.linalg.norm(data, ord = 2),
    }

    values = numpy.nan_to_num(numpy.array(list(summary.values())))
    columns = [ f"{prefix} {name}" for name in summary ]
    return pandas.DataFrame([ values ], columns = columns)

def calc_batch_stats(data: numpy.ndarray, homology_dimensions: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
    """Per-sample `calc_stats` summaries, computed in parallel with joblib.

    Every element of `data` is one sample whose rows are paired, in order, with
    `homology_dimensions`. The result has one row per sample and, for each
    dimension, the `calc_stats` columns prefixed with "<prefix> dim-<dimension>".
    """
    def process_batch(batch: numpy.ndarray) -> pandas.DataFrame:
        per_dimension = [
            calc_stats(vec, prefix = f'{prefix} dim-{dim}')
            for dim, vec in zip(homology_dimensions, batch)
        ]
        return pandas.concat(per_dimension, axis = 1)

    # return_as = 'generator' lets tqdm advance as individual samples finish.
    runner = joblib.Parallel(return_as = 'generator', n_jobs = -1)
    rows = runner(joblib.delayed(process_batch)(batch) for batch in data)
    rows = tqdm.tqdm(rows, total = len(data), desc = prefix)
    return pandas.concat(rows, axis = 0)
    
def calc_betti_features(diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
    """Reference implementation: statistics of the derivative of the Betti curves.

    Betti curves are sampled on 100 bins, differentiated with
    `gtda.curves.Derivative`, and summarised per sample and per homology
    dimension by `calc_batch_stats`. Columns are prefixed with "<prefix> betti".
    """
    # Imported locally because the notebook's first import cell does not import
    # `gtda.curves`; without this the function only works if a later cell has
    # already been executed in the same kernel.
    import gtda.curves
    import gtda.diagrams

    betti_curve = gtda.diagrams.BettiCurve(n_bins = 100, n_jobs = -1)
    betti_derivative = gtda.curves.Derivative()
    betti_curves = betti_curve.fit_transform(diagrams)
    betti_curves = betti_derivative.fit_transform(betti_curves)
    return calc_batch_stats(betti_curves, betti_curve.homology_dimensions_, f'{prefix} betti')
    
# Reference (per-sample, DataFrame-based) Betti features; kept in `q` so they can be
# compared with the vectorised implementation defined in the next cell.
q = calc_betti_features(diagrams)
q.shape

 betti: 100%|██████████| 245000/245000 [00:34<00:00, 7190.65it/s]


(245000, 12)

In [352]:
def calc_betti_features(diagrams: numpy.ndarray) -> numpy.ndarray:
    """Vectorised statistics of the derivative of the Betti curves.

    Returns a numpy.ndarray with one row per diagram and six `calc_stats_bulk`
    columns per homology dimension (the previous annotation claimed a
    DataFrame, but a plain array has always been returned).
    """
    # Imported locally because the notebook's first import cell does not import
    # `gtda.curves`; without this the function relies on hidden kernel state.
    import gtda.curves
    import gtda.diagrams

    betti_curve = gtda.diagrams.BettiCurve(n_bins = 100, n_jobs = -1)
    betti_curves = betti_curve.fit_transform(diagrams)

    betti_derivative = gtda.curves.Derivative()
    betti_curves = betti_derivative.fit_transform(betti_curves)

    # Axis 1 is indexed by the *position* of a homology dimension, not by its value.
    # Indexing with the values themselves only worked when the dimensions are
    # exactly 0, 1, ..., k.
    return numpy.hstack([
        calc_stats_bulk(betti_curves[:, index, :])
        for index in range(len(betti_curve.homology_dimensions_))
    ])
    
# Vectorised Betti features, wrapped in a DataFrame so they can be compared with `q`.
e = pandas.DataFrame(calc_betti_features(diagrams))
e.shape

(245000, 12)

In [349]:
# Equivalence check: the vectorised features (`e`) must match the reference (`q`) to 1e-8.
# NOTE(review): this cell's execution count (349) is lower than that of the cell
# defining `e` (352); re-run the notebook in order to confirm the result.
numpy.abs(q.to_numpy() - e.to_numpy()).max() < 1e-8

True

### landscape_features

In [353]:
def calc_landscape_features(diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
    """Reference implementation: statistics of single-layer persistence landscapes.

    Landscapes are sampled on 100 bins and summarised per sample and per
    homology dimension by `calc_batch_stats`; columns are prefixed with
    "<prefix> landscape".
    """
    transformer = gtda.diagrams.PersistenceLandscape(n_bins = 100, n_layers = 1, n_jobs = -1)
    sampled = transformer.fit_transform(diagrams)
    column_prefix = f'{prefix} landscape'
    return calc_batch_stats(sampled, transformer.homology_dimensions_, column_prefix)

# Reference (per-sample, DataFrame-based) landscape features; kept in `q` so they can
# be compared with the vectorised implementation defined in the next cell.
q = calc_landscape_features(diagrams)
q.shape

 landscape: 100%|██████████| 245000/245000 [00:36<00:00, 6673.54it/s]


(245000, 12)

In [357]:
def calc_landscape_features(diagrams: numpy.ndarray) -> numpy.ndarray:
    """Vectorised statistics of single-layer persistence landscapes.

    Returns a numpy.ndarray with one row per diagram and six `calc_stats_bulk`
    columns per homology dimension (the previous annotation claimed a
    DataFrame, but a plain array has always been returned).
    """
    persistence_landscape = gtda.diagrams.PersistenceLandscape(n_layers = 1, n_bins = 100, n_jobs = -1)
    landscape = persistence_landscape.fit_transform(diagrams)

    # With n_layers = 1, axis 1 holds one curve per homology dimension, addressed by
    # *position*. Indexing with the dimension values themselves only worked when
    # the dimensions are exactly 0, 1, ..., k.
    return numpy.hstack([
        calc_stats_bulk(landscape[:, index, :])
        for index in range(len(persistence_landscape.homology_dimensions_))
    ])
    
# Vectorised landscape features, wrapped in a DataFrame so they can be compared with `q`.
e = pandas.DataFrame(calc_landscape_features(diagrams))
e.shape

(245000, 12)

In [358]:
# Equivalence check: the vectorised landscape features (`e`) must match the reference (`q`) to 1e-8.
numpy.abs(q.to_numpy() - e.to_numpy()).max() < 1e-8

True

### FeatureCalculator

In [19]:
import os
import random

import tqdm
import numpy
import pandas
import joblib
import scipy.stats
import gtda.curves
import gtda.diagrams

def determine_filtering_epsilon(diagrams: numpy.ndarray, percentile: int):
    """Pick a lifetime cut-off as a percentile of the non-zero lifetimes.

    Parameters
    ----------
    diagrams:
        Array indexed as (diagram, point, column) with birth in column 0 and
        death in column 1.
    percentile:
        Percentile (0-100) of the non-zero lifetimes to return.

    Returns
    -------
    The requested percentile, or 0.0 when no point has a non-zero lifetime.
    In that case there is nothing to filter, and `numpy.percentile` cannot
    produce a usable value from an empty array.
    """
    life = (diagrams[:, :, 1] - diagrams[:, :, 0]).flatten()
    nonzero_life = life[life != 0]
    if nonzero_life.size == 0:
        return 0.0
    return numpy.percentile(nonzero_life, percentile)

def apply_filtering(diagrams: numpy.ndarray, eps: float):
    """Run `gtda.diagrams.Filtering` with `epsilon = eps` on `diagrams` and return its output."""
    return gtda.diagrams.Filtering(epsilon = eps).fit_transform(diagrams)


def set_random_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG, and export PYTHONHASHSEED."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, numpy.random.seed):
        seeder(seed)


# Amplitude configurations evaluated by FeatureCalculator.calc_amplitude_features.
# Every entry has an "id" (used in feature names), the gtda "metric" name and its
# "metric_params"; an 'n_bins' of -1 is a placeholder replaced by the calculator.
AMPLITUDE_METRICS = [
    { "id": "bottleneck", "metric": "bottleneck", "metric_params": { } },

    *[
        { "id": f"wasserstein-{p}", "metric": "wasserstein", "metric_params": { "p": p } }
        for p in (1, 2)
    ],

    *[
        { "id": f"betti-{p}", "metric": "betti", "metric_params": { "p": p, 'n_bins': -1 } }
        for p in (1, 2)
    ],

    *[
        { "id": f"landscape-{p}-{n_layers}", "metric": "landscape", "metric_params": { "p": p, "n_layers": n_layers, 'n_bins': -1 } }
        for p in (1, 2)
        for n_layers in (1, 2)
    ],

    *[
        { "id": f"silhouette-{p}-{power}", "metric": "silhouette", "metric_params": { "p": p, "power": power, 'n_bins': -1 } }
        for p in (1, 2)
        for power in (1, 2)
    ],

    # "heat" and "persistence_image" share the same (p, sigma) grid.
    *[
        { "id": f"{kind}-{p}-{sigma}", "metric": kind, "metric_params": { "p": p, "sigma": sigma, 'n_bins': -1 } }
        for kind in ("heat", "persistence_image")
        for p in (1, 2)
        for sigma in (1.6, 3.2)
    ]
]

class FeatureCalculator:
    """DataFrame-based reference implementation of the diagram feature extractor.

    Every `calc_*_features` method takes a batch of persistence diagrams
    (indexed as (diagram, point, 3) with birth, death and homology dimension in
    the last axis) and returns a DataFrame with one row per diagram and
    human-readable column names that start with the given `prefix`.

    Parameters
    ----------
    n_jobs:
        Parallelism passed to the gtda transformers and to joblib.
    verbose:
        When true, progress messages and tqdm bars are shown.
    random_state:
        Seed applied through `set_random_seed` at the start of `calc_features`.
    filtering_percentile:
        Percentile of the non-zero lifetimes used as the filtering cut-off.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        verbose: bool = True,
        random_state: int = 42,
        filtering_percentile: int = 10
    ):
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.random_state = random_state
        self.filtering_percentile = filtering_percentile


    def calc_stats(self, data: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Summarise a 1D vector as a one-row DataFrame.

        Columns are "<prefix> max", "<prefix> mean", "<prefix> std",
        "<prefix> sum", "<prefix> median" and "<prefix> norm-2". An empty
        vector is treated as the single value 0; NaN statistics become 0.
        """
        assert len(data.shape) == 1
        if data.shape == (0,): data = numpy.array([ 0 ])

        stats = numpy.array([
            numpy.max(data), numpy.mean(data), numpy.std(data), numpy.sum(data),
            numpy.median(data),
            numpy.linalg.norm(data, ord = 2)
        ])
        # The earlier, longer name list was dead code: it was overwritten immediately.
        names = [ "max", "mean", "std", "sum", "median", "norm-2" ]

        return pandas.DataFrame([ numpy.nan_to_num(stats) ], columns = [ f"{prefix} {name}" for name in names ])

    def calc_batch_stats(self, data: numpy.ndarray, homology_dimensions: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Apply `calc_stats` to every (sample, homology dimension) vector in parallel.

        Rows of each sample are paired, in order, with `homology_dimensions`;
        columns are prefixed with "<prefix> dim-<dimension>".
        """
        def process_batch(batch: numpy.ndarray) -> pandas.DataFrame:
            features = [
                self.calc_stats(vec, prefix = f'{prefix} dim-{dim}')
                for dim, vec in zip(homology_dimensions, batch)
            ]
            return pandas.concat(features, axis = 1)

        # return_as = 'generator' lets tqdm advance as individual samples finish.
        features = joblib.Parallel(return_as = 'generator', n_jobs = self.n_jobs)(
            joblib.delayed(process_batch)(batch) for batch in data
        )
        if self.verbose:
            features = tqdm.tqdm(features, total = len(data), desc = prefix)
        return pandas.concat(features, axis = 0)


    def calc_betti_features(self, diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Statistics of the derivative of the Betti curves (100 bins)."""
        if self.verbose:
            print('Calculating Betti features')
        betti_curve = gtda.diagrams.BettiCurve(n_bins = 100, n_jobs = self.n_jobs)
        betti_derivative = gtda.curves.Derivative()
        betti_curves = betti_curve.fit_transform(diagrams)
        betti_curves = betti_derivative.fit_transform(betti_curves)
        return self.calc_batch_stats(betti_curves, betti_curve.homology_dimensions_, f'{prefix} betti')

    def calc_landscape_features(self, diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Statistics of single-layer persistence landscapes (100 bins)."""
        if self.verbose:
            print('Calculating landscape features')
        persistence_landscape = gtda.diagrams.PersistenceLandscape(n_layers = 1, n_bins = 100, n_jobs = self.n_jobs)
        landscape = persistence_landscape.fit_transform(diagrams)
        return self.calc_batch_stats(landscape, persistence_landscape.homology_dimensions_, f'{prefix} landscape')

    def calc_silhouette_features(
        self,
        diagrams: numpy.ndarray,
        prefix: str = "",
        powers: "int | tuple[int, ...] | list[int]" = (1, 2)
    ) -> pandas.DataFrame:
        """Statistics of silhouettes (100 bins) for one power or a collection of powers.

        A single integer computes the features for that power; any other
        iterable computes them for every power and concatenates the columns.
        The default is an immutable tuple with the same values as before.
        """
        if isinstance(powers, (int, numpy.integer)):
            silhouette = gtda.diagrams.Silhouette(power = powers, n_bins = 100, n_jobs = self.n_jobs)
            silhouettes = silhouette.fit_transform(diagrams)
            return self.calc_batch_stats(silhouettes, silhouette.homology_dimensions_, f'{prefix} silhouette-{powers}')

        if self.verbose:
            print('Calculating silhouette features')
        features = [ self.calc_silhouette_features(diagrams, prefix, power) for power in powers ]
        return pandas.concat(features, axis = 1)


    def calc_entropy_features(self, diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Normalised persistence entropy, one column per homology dimension."""
        if self.verbose:
            print('Calculating entropy features')
        entropy = gtda.diagrams.PersistenceEntropy(normalize = True, nan_fill_value = 0, n_jobs = self.n_jobs)
        features = entropy.fit_transform(diagrams)
        names = [ f'{prefix} entropy dim-{dim}' for dim in entropy.homology_dimensions_ ]
        return pandas.DataFrame(features, columns = names)

    def calc_number_of_points_features(self, diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Output of gtda's NumberOfPoints, one column per homology dimension."""
        if self.verbose:
            print('Calculating number of points features')
        number_of_points = gtda.diagrams.NumberOfPoints(n_jobs = self.n_jobs)
        features = number_of_points.fit_transform(diagrams)
        names = [ f'{prefix} numberofpoints dim-{dim}' for dim in number_of_points.homology_dimensions_ ]
        return pandas.DataFrame(features, columns = names)

    def calc_amplitude_features(self, diagrams: numpy.ndarray, prefix: str = "", **metric) -> pandas.DataFrame:
        """Amplitude features for one metric, or for all of `AMPLITUDE_METRICS`.

        Called without keyword arguments, every configuration in
        `AMPLITUDE_METRICS` is evaluated and the columns are concatenated.
        Called with the keys of one configuration (`id`, `metric`,
        `metric_params`), the per-dimension amplitudes plus their L1 and L2
        norms across dimensions are returned.
        """
        if len(metric) == 0:
            specs = AMPLITUDE_METRICS
            if self.verbose:
                print('Calculating amplitude features')
                specs = tqdm.tqdm(specs, desc = f'{prefix} amplitudes')
            # `spec` instead of `metric`, so the loop no longer shadows the keyword arguments.
            features = [ self.calc_amplitude_features(diagrams, prefix, **spec) for spec in specs ]
            return pandas.concat(features, axis = 1)

        # Work on a copy so the shared AMPLITUDE_METRICS entries are never modified.
        metric_params = metric['metric_params'].copy()
        if metric_params.get('n_bins', None) == -1:
            metric_params['n_bins'] = 100
        amplitude = gtda.diagrams.Amplitude(metric = metric['metric'], metric_params = metric_params, n_jobs = self.n_jobs)
        features = amplitude.fit_transform(diagrams)
        return pandas.concat([
            pandas.DataFrame(features, columns = [ f'{prefix} amplitude-{metric["id"]} dim-{dim}' for dim in amplitude.homology_dimensions_ ]),
            pandas.DataFrame(numpy.linalg.norm(features, axis = 1, ord = 1).reshape(-1, 1), columns = [ f'{prefix} amplitude-{metric["id"]} norm-1' ]),
            pandas.DataFrame(numpy.linalg.norm(features, axis = 1, ord = 2).reshape(-1, 1), columns = [ f'{prefix} amplitude-{metric["id"]} norm-2' ])
        ], axis = 1)

    def calc_lifetime_features(self, diagrams: numpy.ndarray, prefix: str = "", eps: float = 0.0) -> pandas.DataFrame:
        """Statistics of lifetimes and birth/death midpoints.

        A 3D input is treated as a batch: each diagram is processed through
        joblib and the one-row results are stacked. A 2D input is a single
        diagram. Points with a lifetime strictly below `eps` are discarded;
        statistics are reported for all remaining points and per homology
        dimension present in the (unfiltered) diagram.
        """
        if len(diagrams.shape) == 3:
            if self.verbose:
                print('Calculating lifetime features')
            features = joblib.Parallel(return_as = 'generator', n_jobs = self.n_jobs)(
                joblib.delayed(self.calc_lifetime_features)(diag, prefix, eps)
                for diag in diagrams
            )
            if self.verbose:
                features = tqdm.tqdm(features, total = len(diagrams), desc = f'{prefix} lifetime')
            return pandas.concat(features, axis = 0)

        birth, death, dim = diagrams[:, 0], diagrams[:, 1], diagrams[:, 2]

        # Compute the lifetime filter once and apply it to all three columns.
        keep = (death - birth) >= eps
        birth, death, dim = birth[keep], death[keep], dim[keep]
        bd2 = (birth + death) / 2.0
        life = death - birth

        bd2_features = [ self.calc_stats(bd2, f'{prefix} bd2 all') ]
        life_features = [ self.calc_stats(life, f'{prefix} life all') ]
        # Dimensions come from the unfiltered diagram, so a dimension whose points
        # were all filtered out still yields (all-zero) columns.
        for d in numpy.unique(diagrams[:, 2]).astype(int):
            bd2_features.append(self.calc_stats(bd2[dim == d], f'{prefix} bd2 dim-{d}'))
            life_features.append(self.calc_stats(life[dim == d], f'{prefix} life dim-{d}'))
        return pandas.concat([ *life_features, *bd2_features ], axis = 1)


    def calc_features(self, diagrams: numpy.ndarray, prefix: str = "") -> pandas.DataFrame:
        """Filter the diagrams and concatenate every feature family column-wise.

        The cut-off is the `filtering_percentile`-th percentile of the non-zero
        lifetimes; the same value is passed to the lifetime features as `eps`.
        """
        set_random_seed(self.random_state)

        eps = determine_filtering_epsilon(diagrams, self.filtering_percentile)
        diagrams = apply_filtering(diagrams, eps)
        if self.verbose:
            print('Filtered diagrams:', diagrams.shape)
        # Indices are reset so the column-wise concat aligns rows by position.
        return pandas.concat([
           self.calc_betti_features           (diagrams, prefix     ).reset_index(drop = True),
           self.calc_landscape_features       (diagrams, prefix     ).reset_index(drop = True),
           self.calc_silhouette_features      (diagrams, prefix     ).reset_index(drop = True),
           self.calc_entropy_features         (diagrams, prefix     ).reset_index(drop = True),
           self.calc_number_of_points_features(diagrams, prefix     ).reset_index(drop = True),
           self.calc_amplitude_features       (diagrams, prefix     ).reset_index(drop = True),
           self.calc_lifetime_features        (diagrams, prefix, eps).reset_index(drop = True)
        ], axis = 1)


# Reference feature table from the DataFrame-based FeatureCalculator; kept in
# `target_features` so the vectorised implementation can be checked against it.
target_features = FeatureCalculator(n_jobs = -1).calc_features(diagrams)
target_features
# Recorded wall time: 21.3s

Filtered diagrams: (4900, 40, 3)
Calculating Betti features


 betti: 100%|██████████| 4900/4900 [00:54<00:00, 89.92it/s]


Calculating landscape features


 landscape: 100%|██████████| 4900/4900 [00:26<00:00, 186.10it/s]


Calculating silhouette features


 silhouette-1: 100%|██████████| 4900/4900 [00:26<00:00, 181.57it/s]
 silhouette-2: 100%|██████████| 4900/4900 [00:26<00:00, 182.22it/s]


Calculating entropy features
Calculating number of points features
Calculating amplitude features


 amplitudes: 100%|██████████| 21/21 [00:05<00:00,  4.05it/s]


Calculating lifetime features


 lifetime: 100%|██████████| 4900/4900 [00:53<00:00, 92.05it/s] 


Unnamed: 0,betti dim-0 max,betti dim-0 mean,betti dim-0 std,betti dim-0 sum,betti dim-0 median,betti dim-0 norm-2,betti dim-1 max,betti dim-1 mean,betti dim-1 std,betti dim-1 sum,...,bd2 dim-0 std,bd2 dim-0 sum,bd2 dim-0 median,bd2 dim-0 norm-2,bd2 dim-1 max,bd2 dim-1 mean,bd2 dim-1 std,bd2 dim-1 sum,bd2 dim-1 median,bd2 dim-1 norm-2
0,0.0,0.0,0.000000,0.0,0.0,0.000000,1.0,0.0,0.651339,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,252.5,216.75,42.278590,1300.5,230.25,540.932759
1,0.0,0.0,0.000000,0.0,0.0,0.000000,1.0,0.0,0.142134,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,0.5,0.50,0.000000,0.5,0.50,0.500000
2,2.0,0.0,0.284268,0.0,0.0,2.828427,0.0,0.0,0.000000,0.0,...,0.000000e+00,41.012193,20.506097,29.000000,0.0,0.00,0.000000,0.0,0.00,0.000000
3,1.0,0.0,0.142134,0.0,0.0,1.414214,0.0,0.0,0.000000,0.0,...,0.000000e+00,21.920310,21.920310,21.920310,0.0,0.00,0.000000,0.0,0.00,0.000000
4,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,...,2.512148e-15,38.890873,19.445436,27.500000,0.0,0.00,0.000000,0.0,0.00,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4895,0.0,0.0,0.000000,0.0,0.0,0.000000,1.0,0.0,0.142134,0.0,...,0.000000e+00,0.500000,0.500000,0.500000,24.5,19.50,5.000000,39.0,19.50,28.469282
4896,2.0,0.0,0.284268,0.0,0.0,2.828427,1.0,0.0,0.142134,0.0,...,1.545603e+00,41.000000,14.500000,23.822258,43.0,43.00,0.000000,43.0,43.00,43.000000
4897,0.0,0.0,0.000000,0.0,0.0,0.000000,2.0,0.0,0.376051,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,8.0,6.90,1.319091,34.5,7.50,15.708278
4898,0.0,0.0,0.000000,0.0,0.0,0.000000,2.0,0.0,0.376051,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,8.0,6.90,1.319091,34.5,7.50,15.708278


In [7]:
import os
import random
import typing

import tqdm
import numpy
import pandas
import joblib
import scipy.stats
import gtda.curves
import gtda.diagrams

def determine_filtering_epsilon(diagrams: numpy.ndarray, percentile: int):
    """Pick a lifetime cut-off as a percentile of the non-zero lifetimes.

    Parameters
    ----------
    diagrams:
        Array indexed as (diagram, point, column) with birth in column 0 and
        death in column 1.
    percentile:
        Percentile (0-100) of the non-zero lifetimes to return.

    Returns
    -------
    The requested percentile, or 0.0 when no point has a non-zero lifetime.
    In that case there is nothing to filter, and `numpy.percentile` cannot
    produce a usable value from an empty array.
    """
    life = (diagrams[:, :, 1] - diagrams[:, :, 0]).flatten()
    nonzero_life = life[life != 0]
    if nonzero_life.size == 0:
        return 0.0
    return numpy.percentile(nonzero_life, percentile)

def apply_filtering(diagrams: numpy.ndarray, eps: float):
    """Run `gtda.diagrams.Filtering` with `epsilon = eps` on `diagrams` and return its output."""
    return gtda.diagrams.Filtering(epsilon = eps).fit_transform(diagrams)


def set_random_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG, and export PYTHONHASHSEED."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, numpy.random.seed):
        seeder(seed)


# Amplitude configurations evaluated by FeatureCalculator.calc_amplitude_features.
# Every entry has an "id", the gtda "metric" name and its "metric_params"; an
# 'n_bins' of -1 is a placeholder replaced by the calculator.
AMPLITUDE_METRICS = [
    { "id": "bottleneck", "metric": "bottleneck", "metric_params": { } },

    *[
        { "id": f"wasserstein-{p}", "metric": "wasserstein", "metric_params": { "p": p } }
        for p in (1, 2)
    ],

    *[
        { "id": f"betti-{p}", "metric": "betti", "metric_params": { "p": p, 'n_bins': -1 } }
        for p in (1, 2)
    ],

    *[
        { "id": f"landscape-{p}-{n_layers}", "metric": "landscape", "metric_params": { "p": p, "n_layers": n_layers, 'n_bins': -1 } }
        for p in (1, 2)
        for n_layers in (1, 2)
    ],

    *[
        { "id": f"silhouette-{p}-{power}", "metric": "silhouette", "metric_params": { "p": p, "power": power, 'n_bins': -1 } }
        for p in (1, 2)
        for power in (1, 2)
    ],

    # "heat" and "persistence_image" share the same (p, sigma) grid.
    *[
        { "id": f"{kind}-{p}-{sigma}", "metric": kind, "metric_params": { "p": p, "sigma": sigma, 'n_bins': -1 } }
        for kind in ("heat", "persistence_image")
        for p in (1, 2)
        for sigma in (1.6, 3.2)
    ]
]

class FeatureCalculator:
    """Vectorised feature extractor for batches of persistence diagrams.

    Every `calc_*_features` method takes diagrams indexed as (diagram, point, 3)
    with birth, death and homology dimension in the last axis, and returns a
    plain numpy.ndarray with one row per diagram. Column names are not
    produced; the column order is fixed by `calc_features`.

    Parameters
    ----------
    n_jobs:
        Parallelism passed to the gtda transformers.
    verbose:
        When true, progress messages and the amplitude tqdm bar are shown.
    random_state:
        Seed applied through `set_random_seed` at the start of `calc_features`.
    filtering_percentile:
        Percentile of the non-zero lifetimes used as the filtering cut-off.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        verbose: bool = True,
        random_state: int = 42,
        filtering_percentile: int = 10
    ):
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.random_state = random_state
        self.filtering_percentile = filtering_percentile


    def calc_stats_bulk(self, data: numpy.ndarray) -> numpy.ndarray:
        """Row-wise max, mean, std, sum, median and Euclidean norm of a 2D array.

        Masked entries are ignored; statistics that end up masked (for example
        for a fully masked row) are reported as 0.
        """
        return numpy.ma.hstack([
            numpy.ma.max(data, axis = 1, keepdims = True),
            numpy.ma.mean(data, axis = 1, keepdims = True),
            numpy.ma.std(data, axis = 1, keepdims = True),
            numpy.ma.sum(data, axis = 1, keepdims = True),
            numpy.ma.median(data, axis = 1, keepdims = True),
            numpy.ma.sqrt(numpy.ma.sum(data ** 2, axis = 1, keepdims = True))
        ]).filled(0)

    def calc_batch_stats(self, data: numpy.ndarray, homology_dimensions: typing.Iterable[int]) -> numpy.ndarray:
        """Concatenate `calc_stats_bulk` over the homology-dimension axis of `data`.

        `data` is indexed as (sample, homology-dimension position, bin). The
        slices are addressed by *position*: using the dimension values
        themselves as indices is only correct when they are exactly 0, 1, ..., k,
        and reads the wrong slice or raises IndexError otherwise.
        """
        return numpy.hstack([
            self.calc_stats_bulk(data[:, index, :])
            for index, _ in enumerate(homology_dimensions)
        ])


    def calc_betti_features(self, diagrams: numpy.ndarray) -> numpy.ndarray:
        """Statistics of the derivative of the Betti curves (100 bins)."""
        if self.verbose:
            print('Calculating Betti features')

        betti_curve = gtda.diagrams.BettiCurve(n_bins = 100, n_jobs = self.n_jobs)
        betti_curves = betti_curve.fit_transform(diagrams)

        betti_derivative = gtda.curves.Derivative()
        betti_curves = betti_derivative.fit_transform(betti_curves)

        return self.calc_batch_stats(betti_curves, betti_curve.homology_dimensions_)

    def calc_landscape_features(self, diagrams: numpy.ndarray) -> numpy.ndarray:
        """Statistics of single-layer persistence landscapes (100 bins)."""
        if self.verbose:
            print('Calculating landscape features')
        persistence_landscape = gtda.diagrams.PersistenceLandscape(n_layers = 1, n_bins = 100, n_jobs = self.n_jobs)
        landscape = persistence_landscape.fit_transform(diagrams)
        return self.calc_batch_stats(landscape, persistence_landscape.homology_dimensions_)

    def calc_silhouette_features(
        self,
        diagrams: numpy.ndarray,
        powers: typing.Union[int, typing.Sequence[int]] = (1, 2)
    ) -> numpy.ndarray:
        """Statistics of silhouettes (100 bins) for one power or a sequence of powers.

        A single integer computes the features for that power; a sequence
        computes them for every power and concatenates the columns. The default
        is an immutable tuple with the same values as the former list default.
        """
        if isinstance(powers, int):
            silhouette = gtda.diagrams.Silhouette(power = powers, n_bins = 100, n_jobs = self.n_jobs)
            silhouettes = silhouette.fit_transform(diagrams)
            return self.calc_batch_stats(silhouettes, silhouette.homology_dimensions_)

        if self.verbose:
            print('Calculating silhouette features')
        return numpy.hstack([ self.calc_silhouette_features(diagrams, power) for power in powers ])

    def calc_entropy_features(self, diagrams: numpy.ndarray) -> numpy.ndarray:
        """Normalised persistence entropy as returned by gtda's PersistenceEntropy."""
        if self.verbose:
            print('Calculating entropy features')
        entropy = gtda.diagrams.PersistenceEntropy(normalize = True, nan_fill_value = 0, n_jobs = self.n_jobs)
        return entropy.fit_transform(diagrams)

    def calc_number_of_points_features(self, diagrams: numpy.ndarray) -> numpy.ndarray:
        """Output of gtda's NumberOfPoints transformer."""
        if self.verbose:
            print('Calculating number of points features')
        number_of_points = gtda.diagrams.NumberOfPoints(n_jobs = self.n_jobs)
        return number_of_points.fit_transform(diagrams)

    def calc_amplitude_features(self, diagrams: numpy.ndarray, **metric) -> numpy.ndarray:
        """Amplitude features for one metric, or for all of `AMPLITUDE_METRICS`.

        Called without keyword arguments, every configuration in
        `AMPLITUDE_METRICS` is evaluated and the columns are concatenated.
        Called with the keys of one configuration (`id`, `metric`,
        `metric_params`), the per-dimension amplitudes are returned followed by
        their L1 and L2 norms across dimensions.
        """
        if len(metric) == 0:
            specs = tqdm.tqdm(AMPLITUDE_METRICS, desc = 'amplitudes') if self.verbose else AMPLITUDE_METRICS
            return numpy.hstack([ self.calc_amplitude_features(diagrams, **spec) for spec in specs ])

        # Work on a copy so the shared AMPLITUDE_METRICS entries are never modified.
        metric_params = metric['metric_params'].copy()
        if metric_params.get('n_bins', None) == -1:
            metric_params['n_bins'] = 100

        amplitude = gtda.diagrams.Amplitude(metric = metric['metric'], metric_params = metric_params, n_jobs = self.n_jobs)
        features = amplitude.fit_transform(diagrams)
        return numpy.hstack([
            features,
            numpy.linalg.norm(features, axis = 1, ord = 1).reshape(-1, 1),
            numpy.linalg.norm(features, axis = 1, ord = 2).reshape(-1, 1),
        ])

    def calc_lifetime_features(self, diagrams: numpy.ndarray, eps: float = 0.0) -> numpy.ndarray:
        """Statistics of lifetimes and birth/death midpoints.

        Points with a lifetime strictly below `eps` are ignored. The columns are
        the six `calc_stats_bulk` statistics of the lifetimes for all points and
        for each dimension `0 .. max(dim)`, followed by the same blocks for the
        midpoints `(birth + death) / 2`.
        """
        if self.verbose:
            print('Calculating lifetime features')
        birth, death, dim = diagrams[:, :, 0], diagrams[:, :, 1], diagrams[:, :, 2]
        bd2 = (birth + death) / 2.0
        life = death - birth

        short_lived = (life < eps)
        bd2 = numpy.ma.array(bd2, mask = short_lived)
        life = numpy.ma.array(life, mask = short_lived)

        bd2_features = [ self.calc_stats_bulk(bd2) ]
        life_features = [ self.calc_stats_bulk(life) ]
        for d in range(int(numpy.max(dim)) + 1):
            other_dimension = (dim != d)
            # numpy.ma.array combines the new mask with the mask already present on
            # its input (keep_mask defaults to True), so the `eps` filter stays in effect.
            dim_bd2 = numpy.ma.array(bd2, mask = other_dimension)
            dim_life = numpy.ma.array(life, mask = other_dimension)
            bd2_features.append(self.calc_stats_bulk(dim_bd2))
            life_features.append(self.calc_stats_bulk(dim_life))

        return numpy.hstack([ *life_features, *bd2_features ])

    def calc_features(self, diagrams: numpy.ndarray) -> numpy.ndarray:
        """Filter the diagrams and concatenate every feature family column-wise.

        The cut-off is the `filtering_percentile`-th percentile of the non-zero
        lifetimes; the same value is passed to the lifetime features as `eps`.
        Column order: Betti, landscape, silhouette, entropy, number of points,
        amplitude, lifetime.
        """
        set_random_seed(self.random_state)

        eps = determine_filtering_epsilon(diagrams, self.filtering_percentile)
        diagrams = apply_filtering(diagrams, eps)
        if self.verbose:
            print('Filtered diagrams:', diagrams.shape)
        return numpy.hstack([
           self.calc_betti_features           (diagrams     ),
           self.calc_landscape_features       (diagrams     ),
           self.calc_silhouette_features      (diagrams     ),
           self.calc_entropy_features         (diagrams     ),
           self.calc_number_of_points_features(diagrams     ),
           self.calc_amplitude_features       (diagrams     ),
           self.calc_lifetime_features        (diagrams, eps)
        ])


# Feature table from the vectorised FeatureCalculator, wrapped in a DataFrame so it
# can be compared with `target_features`.
my_features = pandas.DataFrame(FeatureCalculator(n_jobs = -1).calc_features(diagrams))
my_features
# Recorded wall time: 2.3s

Filtered diagrams: (196000, 71, 3)
Calculating Betti features
Calculating landscape features
Calculating silhouette features
Calculating entropy features
Calculating number of points features


amplitudes: 100%|██████████| 21/21 [02:59<00:00,  8.53s/it]


Calculating lifetime features


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,162,163,164,165,166,167,168,169,170,171
0,0.0,0.0,0.000000,0.0,0.0,0.000000,1.0,0.0,0.651339,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,252.5,216.75,42.27859,1300.5,230.25,540.932759
1,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,0.5,0.50,0.00000,0.5,0.50,0.500000
2,2.0,0.0,0.284268,0.0,0.0,2.828427,0.0,0.0,0.000000,0.0,...,0.000000e+00,41.012193,20.506097,29.000000,0.0,0.00,0.00000,0.0,0.00,0.000000
3,1.0,0.0,0.142134,0.0,0.0,1.414214,0.0,0.0,0.000000,0.0,...,0.000000e+00,21.920310,21.920310,21.920310,0.0,0.00,0.00000,0.0,0.00,0.000000
4,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,0.0,0.000000,0.0,...,2.512148e-15,38.890873,19.445436,27.500000,0.0,0.00,0.00000,0.0,0.00,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195995,0.0,0.0,0.000000,0.0,0.0,0.000000,1.0,0.0,0.142134,0.0,...,0.000000e+00,0.000000,0.000000,0.000000,22.0,22.00,0.00000,22.0,22.00,22.000000
195996,2.0,0.0,0.348155,0.0,0.0,3.464102,1.0,0.0,0.142134,0.0,...,3.198307e+00,127.500000,22.250000,52.637914,73.0,73.00,0.00000,73.0,73.00,73.000000
195997,1.0,0.0,0.142134,0.0,0.0,1.414214,4.0,0.0,0.651339,0.0,...,0.000000e+00,2.500000,2.500000,2.500000,7.5,6.60,0.80000,66.0,6.50,21.023796
195998,1.0,0.0,0.142134,0.0,0.0,1.414214,4.0,0.0,0.651339,0.0,...,0.000000e+00,2.500000,2.500000,2.500000,7.5,6.60,0.80000,66.0,6.50,21.023796


In [20]:
# Equivalence check: the vectorised features must match the DataFrame-based reference to 1e-8.
# NOTE(review): both tables must have been computed from the same `diagrams`; the
# recorded outputs above show different row counts, so re-run the notebook in order to confirm.
numpy.abs(target_features.to_numpy() - my_features.to_numpy()).max() < 1e-8

True