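This notebook shows how to compute eXNN's quantitative indices for the CIFAR10 dataset: we take the first 256 training images as a point cloud, compute its Vietoris–Rips persistence diagram, and derive summary statistics (interval lengths, entropy, births, deaths, …) from it.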

In [35]:
import os

# Work from two levels above the kernel's starting directory, presumably so the
# `eXNN` import below resolves and `./.cache` is created at the repository root.
# The target is resolved only once: a bare os.chdir('../..') would climb two more
# levels every time this cell is re-run.
if 'REPO_ROOT' not in globals():
    REPO_ROOT = os.path.abspath(os.path.join('..', '..'))
os.chdir(REPO_ROOT)

In [36]:
import torch
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from eXNN.InnerNeuralTopology import quantitative_indices as qi
from gtda.homology import VietorisRipsPersistence

In [37]:
# Per-image preprocessing: convert to a float tensor, then standardize each channel.
# NOTE: these are the standard ImageNet channel statistics (as used by torchvision's
# pretrained models), not statistics computed on CIFAR10; kept as-is so results match.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [38]:
# Load both CIFAR10 splits from ./.cache (downloading on first use), applying the
# preprocessing `tfm` to every image. Training split first, then test split.
CACHE_DIR = './.cache'
train_ds, test_ds = (
    CIFAR10(root=CACHE_DIR, train=is_train, download=True, transform=tfm)
    for is_train in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./.cache/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./.cache/cifar-10-python.tar.gz to ./.cache
Files already downloaded and verified


In [39]:
# Mini-batch loaders over the two splits. Only the training loader is shuffled;
# the evaluation loader keeps a deterministic order.
# NOTE(review): neither loader is used in the cells shown below.
BATCH_SIZE = 128
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [40]:
# Build a point cloud from the first N_SAMPLES training images: each preprocessed
# image tensor is flattened into one row of `train_data`; labels go to `train_targets`.
N_SAMPLES = 256
first_samples = [train_ds[idx] for idx in range(N_SAMPLES)]
train_data = torch.stack([image.view(-1) for image, _ in first_samples], dim=0)
train_targets = torch.tensor([label for _, label in first_samples])

In [41]:
# Vietoris–Rips persistent homology of the image point cloud in dimensions 0, 1 and 2.
# fit_transform takes a collection of point clouds, so `train_data` is wrapped in a
# one-element list; `diag` holds the resulting persistence diagram(s).
VR = VietorisRipsPersistence(homology_dimensions=[0, 1, 2])
point_clouds = [train_data]
diag = VR.fit_transform(point_clouds)

In [42]:
# Interval lengths of the persistence diagram. In the saved output the leading values
# equal the corresponding `deaths` (whose births are 0), consistent with
# death - birth per interval — confirm in eXNN's quantitative_indices.
lengths = qi.compute_length(diag)
lengths

[24.416135787963867,
 30.209699630737305,
 30.791522979736328,
 31.37993621826172,
 32.0010986328125,
 33.28579330444336,
 33.411617279052734,
 34.021934509277344,
 34.056697845458984,
 34.104881286621094,
 34.343692779541016,
 35.0025634765625,
 35.13337707519531,
 35.27815246582031,
 35.6840705871582,
 36.08675003051758,
 36.5817756652832,
 37.078163146972656,
 37.111454010009766,
 37.15925979614258,
 37.249916076660156,
 37.29237365722656,
 37.521728515625,
 37.90203857421875,
 38.235862731933594,
 38.62734603881836,
 38.68802261352539,
 38.71816635131836,
 38.779117584228516,
 39.02437973022461,
 39.17852020263672,
 39.280128479003906,
 39.39548110961914,
 39.441280364990234,
 39.472023010253906,
 39.552799224853516,
 39.796409606933594,
 39.80344009399414,
 39.80546188354492,
 39.81997299194336,
 39.97366714477539,
 40.374942779541016,
 40.45478820800781,
 40.460105895996094,
 40.56178665161133,
 40.74037170410156,
 40.76991271972656,
 40.78469467163086,
 41.107295989990234,
 41.4

In [43]:
# Longest interval length (presumably max(lengths) — confirm in eXNN).
longest_interval = qi.compute_longest_interval(lengths)
longest_interval

76.79570007324219

In [44]:
# Mean of the interval lengths (presumably the arithmetic mean over all intervals).
length_mean = qi.compute_length_mean(lengths)
length_mean

39.442246466326566

In [45]:
# Median of the interval lengths (presumably over all homology dimensions together).
length_median = qi.compute_length_median(lengths)
length_median

45.99239540100098

In [46]:
# Standard deviation of the interval lengths
# (population vs. sample definition is not visible here — confirm in eXNN).
length_stdev = qi.compute_length_stdev(lengths)
length_stdev

21.706520091252777

In [47]:
# Total of all interval lengths; reused below as the normalizer passed to the
# entropy and normalized-entropy indices.
length_sum = qi.compute_length_sum(lengths)
length_sum

12858.172348022461

In [48]:
# Ratio index computed from `lengths`; the exact definition lives in eXNN's
# quantitative_indices — TODO confirm which intervals are being compared.
two_to_one_ratio = qi.compute_two_to_one_ratio(lengths)
two_to_one_ratio

3.0817804769235613

In [49]:
# Ratio index computed from `lengths`, companion to two_to_one_ratio; the exact
# definition lives in eXNN's quantitative_indices — TODO confirm.
three_to_one_ratio = qi.compute_three_to_one_ratio(lengths)
three_to_one_ratio

2.9910391203497015

In [50]:
# Entropy of the interval lengths given their total `length_sum`
# (presumably persistence entropy with p_i = l_i / length_sum — confirm in eXNN).
entropy = qi.compute_entropy(lengths, length_sum)
entropy

5.556769799760856

In [51]:
# Normalized entropy. The saved outputs satisfy
# normed_entropy ~= entropy / ln(length_sum)  (5.5568 / ln(12858.17) ~= 0.5873).
normed_entropy = qi.compute_normed_entropy(entropy, length_sum)
normed_entropy

0.5872886819332289

In [52]:
# Birth values of the intervals in `diag`; the saved output begins with a long run
# of zeros (presumably the dimension-0 intervals).
births = qi.compute_births(diag)
births

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.  

In [53]:
# Death values of the intervals in `diag`; in the saved output the leading entries
# match `lengths` above, since the corresponding births are 0.
deaths = qi.compute_deaths(diag)
deaths

array([24.41613579, 30.20969963, 30.79152298, 31.37993622, 32.00109863,
       33.2857933 , 33.41161728, 34.02193451, 34.05669785, 34.10488129,
       34.34369278, 35.00256348, 35.13337708, 35.27815247, 35.68407059,
       36.08675003, 36.58177567, 37.07816315, 37.11145401, 37.1592598 ,
       37.24991608, 37.29237366, 37.52172852, 37.90203857, 38.23586273,
       38.62734604, 38.68802261, 38.71816635, 38.77911758, 39.02437973,
       39.1785202 , 39.28012848, 39.39548111, 39.44128036, 39.47202301,
       39.55279922, 39.79640961, 39.80344009, 39.80546188, 39.81997299,
       39.97366714, 40.37494278, 40.45478821, 40.4601059 , 40.56178665,
       40.7403717 , 40.76991272, 40.78469467, 41.10729599, 41.47017288,
       41.83136749, 41.91716003, 42.1076622 , 42.26543808, 42.36532974,
       42.42207718, 42.68595123, 42.82120514, 42.86104965, 42.90550232,
       43.17183304, 43.25719452, 43.26025009, 43.37685394, 43.43218231,
       43.49118805, 43.50811768, 43.51413727, 43.59746933, 43.75

In [57]:
# Birth statistic for homology dimension 0.
# NOTE(review): dimensions 1 and 2 were left commented out in an earlier version of
# this cell; confirm qi.compute_birth supports them before extending to range(3).
birth_dim0 = qi.compute_birth(diag, dim=0)
print(birth_dim0)

0.0


In [55]:
# Per-dimension death statistic from qi.compute_death for dimensions 0, 1 and 2,
# printed in that order.
# NOTE(review): the saved output shows 0.0 for dims 0 and 2 — confirm that is the
# intended meaning of compute_death for those dimensions.
death_dim0, death_dim1, death_dim2 = (qi.compute_death(diag, dim=d) for d in (0, 1, 2))
for death_value in (death_dim0, death_dim1, death_dim2):
    print(death_value)

0.0
30.209699630737305
0.0


In [29]:
# Signal-to-noise index computed from `births` and `deaths` (definition in eXNN).
# NOTE(review): this cell was executed as In [29], before the In [35]+ cells above,
# so its saved output likely comes from an earlier run — re-run top to bottom.
snr = qi.compute_snr(births, deaths)
snr

-4.358496521031304

In [30]:
# Mean of the birth values.
# NOTE(review): saved output looks stale (executed as In [29]–[33], before the cells
# above): 50.6 exceeds the saved death_mean 11.2 although displayed births are 0.
birth_mean = qi.compute_births_mean(births)
birth_mean

50.59675119552144

In [31]:
# Standard deviation of the birth values.
# NOTE(review): executed before the In [35]+ cells above; saved output is likely stale.
birth_stdev = qi.compute_births_stdev(births)
birth_stdev

9.049507387697483

In [32]:
# Mean of the death values.
# NOTE(review): saved output looks stale (executed before the cells above): 11.2 is
# below every death value displayed earlier (24+).
death_mean = qi.compute_deaths_mean(deaths)
death_mean

11.154504729194874

In [33]:
# Standard deviation of the death values.
# NOTE(review): executed before the In [35]+ cells above; saved output is likely stale.
death_stdev = qi.compute_deaths_stdev(deaths)
death_stdev

21.41260576251542