In [1]:
import random
import itertools
import functools
import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix


import gudhi as gd
import torch
torch.set_default_dtype(torch.float64) 
import torch.nn as nn

def rand(shape, low, high):
    """Tensor of random numbers, uniformly distributed on [low, high]."""
    return torch.rand(shape) * (high - low) + low

def count_parameters(model):
    """Return the number of trainable parameters of a model (the total number of scalars)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [2]:
# I WANT TO COMPARE CLOUDS WITH 1K POINTS WITH THE SAME STARTING POINT
def generate_orbits(n_points_per_orbit = 1000, params = [2.5, 3.5, 4.0, 4.1, 4.3], same_init_point = True):
    # create point clouds 
    ORBITS = np.zeros([len(params), n_points_per_orbit, 2])
    xcur_0, ycur_0 = np.random.rand(), np.random.rand() # not necesary to save the first one
    for id_pc, param in enumerate(params): # id_point_cloud
        if same_init_point:
            xcur, ycur = xcur_0, ycur_0 # not necesary to save the first one
        else:
            xcur, ycur =np.random.rand(), np.random.rand()
        for id_pt in range(n_points_per_orbit): # id_point
            xcur = (xcur + param * ycur * (1. - ycur)) % 1
            ycur = (ycur + param * xcur * (1. - xcur)) % 1
            ORBITS[id_pc, id_pt, :] = [xcur, ycur]
    return ORBITS

# function from [len(params), n_points, 2] to alpha-complex and persistence diagram
# to create PDs we need to: points -> skeleton(ac) -> simplex(st) -> persistence(pers)
# for each element of the dataset we'll have len(params) PDs to be compared
def extract_PD(cloud, id_class):
    """extract a dict 

    Args:
        cloud (_type_): array [1000,2] composing th ewhole point cloud
        id_class (_type_): index about the class of membership

    Returns:
        dict: with keys ['persist_0','persist_1','id_class']
    """
    # for every point cloud we create a dictionary storing the label and its persistence
    # usage of dictionary to store each other possible data linked to the point clous
    ac = gd.AlphaComplex(points=cloud)
    st = ac.create_simplex_tree()
    pers = st.persistence()
    #? EXTENDED PERSISTENCE?
    #! TRANSPOSE TO HAVE THEN [BATCH SIZE, 2, NUM POINTS]
    pers_0 = np.array(st.persistence_intervals_in_dimension(0)).transpose()
    pers_1 = np.array(st.persistence_intervals_in_dimension(1)).transpose()
    pers_dict = {
        'cloud': cloud, #* UNCOMMENT THE LINE IF YOU WANT TO VISUALIZE POINT CLOUDS LATER
        # 'skeleton': ac, # no more used
        # 'complex': st, # used for bottleneck distance
        'persist_0': pers_0[:,:-1], # removing the last barcode, the one with inf
        'persist_1': pers_1, # here we should never have inf, since [0,1]^2 is compact/bounded  
        # 'persist': pers, # actual PD
        'id_class': id_class # label for classification
    }
    return pers_dict

# 
def gaussian_transformation(pd):
    # Applied in the model after preproc size
    # I'm embedding the pair [birth, death] in R^15 
    embs = np.apply_along_axis(gamma_p, axis=1, arr=pd)
    return torch.tensor(embs)

def gamma_p(p):
    # params of gaussian_transformation
    ts = torch.tensor([[0., 0.], \
          [0.25, 0.], [0.25, 0.25], \
          [0.5, 0.], [0.5, 0.25], [0.5, 0.5], \
          [0.75, 0.], [0.75, 0.25], [0.75, 0.5], [0.75, 0.75], \
          [1., 0.], [1., 0.25], [1., 0.5], [1., 0.75], [1., 1.]])
    sigma = 0.2
    # single point computaions for gaussian transformation
    squared_distances = torch.pow(ts - p, 2).sum(dim=1)
    emb = -squared_distances/(2*sigma**2)
    emb = torch.exp(emb)
    return emb

def preproc_prom(tens, prom):
    diffs = tens[1] - tens[0]
    sorted_diffs = torch.argsort(diffs, descending=True)
    sorted_tens = tens[:,sorted_diffs]
    return sorted_tens[:,:prom]

In [3]:
points = torch.rand(6, 2)
points = torch.tensor([[0.,0.], [1.,0.]])
gaussian_transformation(points)#.shape

tensor([[1.0000e+00, 4.5783e-01, 2.0961e-01, 4.3937e-02, 2.0116e-02, 1.9305e-03,
         8.8383e-04, 4.0465e-04, 3.8833e-05, 7.8115e-07, 3.7267e-06, 1.7062e-06,
         1.6374e-07, 3.2937e-09, 1.3888e-11],
        [3.7267e-06, 8.8383e-04, 4.0465e-04, 4.3937e-02, 2.0116e-02, 1.9305e-03,
         4.5783e-01, 2.0961e-01, 2.0116e-02, 4.0465e-04, 1.0000e+00, 4.5783e-01,
         4.3937e-02, 8.8383e-04, 3.7267e-06]])

In [4]:
### FULL DATA GENERATION 
# (~2 mins) ---------------------------------------

# hyper params
n_points = 1000
params = [2.5, 3.5, 4.0, 4.1, 4.3]
same_init_point = True
n_seq_per_dataset = [700, 300] # I want [i, len(params), n_points, 2]

batch_size = 128
extended_pers = False
k_pd_preproc = 500

# init list fo persistence diagrams
pds_train = []
pds_test = []

# TRAIN 
for i in tqdm(range(n_seq_per_dataset[0]), desc='Create TRAIN Point Clouds'):
    ORBS = generate_orbits(n_points, params, same_init_point) # CREATE THE 5 POINT CLOUDS
    for j in range(ORBS.shape[0]):
        ij_pers = extract_PD(ORBS[j,:,:], j) # EXTRACT PDs
        pds_train.append(ij_pers) # STORE IN THE LIST pds_train

train_batched_data = [] # BATCHING DATA FOR THE NN
batching = len(pds_train)//batch_size
for i in range(batching):
    train_batched_data.append(pds_train[i*batch_size:(i+1)*batch_size])
# check if we have to add the last batch
if batching*batch_size != len(pds_train):
    train_batched_data.append(pds_train[batching*batch_size:])
print(f'{len(train_batched_data) = }\n')
    

# TEST
for i in tqdm(range(n_seq_per_dataset[1]), desc='Create TEST Point Clouds'):
    ORBS = generate_orbits(n_points, params, same_init_point) # CREATE THE 5 POINT CLOUDS
    for j in range(ORBS.shape[0]):
        ij_pers = extract_PD(ORBS[j,:,:], j) # EXTRACT PDs
        pds_test.append(ij_pers) # STORE IN THE LIST pds_test

test_batched_data = [] # BATCHING DATA FOR THE NN
batching = len(pds_test)//batch_size
for i in range(batching):
    test_batched_data.append(pds_test[i*batch_size:(i+1)*batch_size])
# check if we have to add the last batch
if batching*batch_size != len(pds_test):
    test_batched_data.append(pds_test[batching*batch_size:])
print(f'{len(test_batched_data) = }')

Create TRAIN Point Clouds: 100%|██████████| 700/700 [00:37<00:00, 18.66it/s]


len(train_batched_data) = 28



Create TEST Point Clouds: 100%|██████████| 300/300 [00:16<00:00, 17.95it/s]

len(test_batched_data) = 12





In [5]:
class PersImage_KTH(nn.Module):
    def __init__(self, hidden_size:int = 10, alpha_0:bool = True, alpha_1:bool = True, prom:int = 500, top_k:int = 5, using_len_p1:bool = False):
        super().__init__()
        self.a0 = alpha_0
        self.a1 = alpha_1
        self.prom = prom
        self.top_k = top_k
        self.num_classes = 5
        self.using_len_p1 = using_len_p1

        self.ds_0_a = nn.Linear(15,25)
        self.relu_0 = torch.nn.ReLU()
        self.ds_0_b = nn.Linear(25,hidden_size)
        # self.ds_0_c = DeepSetLayer(10,5)

        self.ds_1_a = nn.Linear(15,25)
        self.relu_1 = torch.nn.ReLU()
        self.ds_1_b = nn.Linear(25,hidden_size)
        # self.ds_1_c = DeepSetLayer(10,5)

        if using_len_p1:
            self.linear_dim_H1 = nn.Linear(1,self.top_k)
            self.linear_labels = nn.Linear(self.top_k*(hidden_size*2+1), self.num_classes)
        else:
            self.linear_labels = nn.Linear(self.top_k*(hidden_size*2), self.num_classes)
        
        self.name = f'{hidden_size}_{alpha_0}_{alpha_1}_{prom}_{top_k}_{using_len_p1}' # list of params


    def forward(self, batch_pers_0, batch_pers_1):
        labels = False
        for p0, p1 in zip(batch_pers_0, batch_pers_1):# one PD at the time due to different cardinality betweeen different H1 barcodes
            # I want to select the self.prom longest barcodes
            # print(f'{p0.max(), p0.min() = }')
            p0 = torch.tensor(p0)*100 #? rescaling
            # print(f'{p0.max(), p0.min() = }')
            p0 = preproc_prom(p0, self.prom)
            p0 = torch.transpose(p0, 0, 1)
            p0 = gaussian_transformation(p0)
            p0 = self.ds_0_a(p0)
            p0 = self.ds_0_b(p0)
            p0, _ = torch.topk(p0, self.top_k, dim=0)

            # same for p1
            # print(f'{p1.max(), p1.min() = }')
            p1 = torch.tensor(p1)*100 # not always with same len of p0
            # print(f'{p1.max(), p1.min() = }')
            p1_shape = p1.shape[1] # number of elemnts in H1 persistence
            if p1_shape<self.top_k:
                dim_to_add = self.top_k - p1_shape
                aux_zeros = torch.zeros(p1.size(0), dim_to_add)
                p1 = torch.cat((p1, aux_zeros), dim=1)

            p1 = preproc_prom(p1, self.prom)
            p1 = torch.transpose(p1, 0, 1)
            p1 = gaussian_transformation(p1)
            p1 = self.ds_1_a(p1)
            p1 = self.ds_1_b(p1)
            try:
                p1, _ = torch.topk(p1, self.top_k, dim=0)
            except RuntimeError:
                raise ValueError('')
            
            if self.using_len_p1 == True:
                emb_len = self.linear_dim_H1(torch.tensor([p1_shape]))
                concat = torch.cat((p0.view(-1), p1.view(-1), emb_len.view(-1)))
            else:
                concat = torch.cat((p0.view(-1), p1.view(-1)))

            labs = self.linear_labels(concat).unsqueeze(0)

            if isinstance(labels,bool):
                labels = labs
            else:
                labels = torch.cat((labels, labs), dim = 0)
                
        return labels

In [6]:
def test_model(model, test_data, if_plot):
    # Set the model to evaluation mode
    model.eval()

    target_labs = np.array([])
    pred_labs = np.array([])
    correct = 0
    total = 0

    # Disable gradient calculation
    with torch.no_grad():
        # for batch in tqdm(test_data):
        for batch in test_data:
            batch_in_pd0 = [sample['persist_0'] for sample in batch] # get tersor of persistence
            batch_in_pd1 = [sample['persist_1'] for sample in batch] # get tersor of persistence
            batch_target = torch.tensor([sample['id_class'] for sample in batch]) # get target labels

            # Forward pass
            outputs = model(batch_in_pd0, batch_in_pd1)
            # Get predicted labels
            _, predicted = torch.max(outputs.data, 1)
            # Total number of labels
            total += batch_target.size(0)
            # Total correct predictions
            correct += (predicted == batch_target).sum().item()
            target_labs = np.append(target_labs, batch_target)
            pred_labs = np.append(pred_labs, predicted)

    # Calculate accuracy
    accuracy = 100 * correct / total
    # print('Accuracy on the test set: {:.2f}%'.format(accuracy))

    if if_plot:
        cm = confusion_matrix(np.array(target_labs), np.array(pred_labs))
        classes = [1,2,3,4,5]
        # Plot confusion matrix
        plt.figure(figsize=(6, 4))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.show()
        
    return accuracy

In [8]:
from torch.optim.lr_scheduler import StepLR

model = PersImage_KTH(hidden_size = 25)

# Learning rate and loss function.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# scheduler = StepLR(optimizer, step_size=70, gamma=0.7)
loss_function = nn.CrossEntropyLoss()
epochs = 100
recorded_loss = torch.zeros(epochs)

best_acc = 0.0 

for epoch in range(epochs):
    model.train()

    total_loss = 0.0
    for batch in tqdm(train_batched_data):
    # for batch in train_batched_data:

        batch_in_pd0 = [sample['persist_0'] for sample in batch] # get tersor of persistence

        batch_in_pd1 = [sample['persist_1'] for sample in batch] # get tersor of persistence

        batch_target = torch.tensor([sample['id_class'] for sample in batch]) # get target labels
        # print(f'\n{len(batch_in_pd0) = }\n{len(batch_in_pd1) = }\n{batch_target.shape = }')

        optimizer.zero_grad()
        result = model(batch_in_pd0, batch_in_pd1)
        loss = loss_function(result, batch_target)
        total_loss += loss
        loss.backward()
        optimizer.step()
        # scheduler.step()
    
    recorded_loss[epoch] = total_loss
    print(f"Epoch {epoch+1}/{epochs}, loss {total_loss}")

    test_acc = test_model(model, test_batched_data, if_plot=False)
    if best_acc < test_acc:
        if test_acc > 70.0:
            torch.save(model.state_dict(), f'./Perslay_models/try_acc_{test_acc}.pth')
        print(f'--> IMPROVEMENT from {best_acc} to {test_acc}')
        best_acc = test_acc
    else:
        print(f'> Test Accuracy = {test_acc} [best = {best_acc}]')

fig, ax = plt.subplots()
ax.plot(recorded_loss)
#ax.set_ylim([0, 1])
plt.show()
print(f"Final loss is {recorded_loss[-1]}")

100%|██████████| 28/28 [04:01<00:00,  8.61s/it]


Epoch 1/100, loss 33.020670508921015
--> IMPROVEMENT from 0.0 to 69.93333333333334


100%|██████████| 28/28 [03:57<00:00,  8.48s/it]


Epoch 2/100, loss 20.68373385773978
--> IMPROVEMENT from 69.93333333333334 to 71.53333333333333


100%|██████████| 28/28 [04:02<00:00,  8.66s/it]


Epoch 3/100, loss 18.348203870119885
--> IMPROVEMENT from 71.53333333333333 to 73.33333333333333


100%|██████████| 28/28 [03:54<00:00,  8.37s/it]


Epoch 4/100, loss 17.603771232340065
--> IMPROVEMENT from 73.33333333333333 to 74.26666666666667


100%|██████████| 28/28 [03:55<00:00,  8.40s/it]


Epoch 5/100, loss 17.290659854285096
--> IMPROVEMENT from 74.26666666666667 to 74.86666666666666


100%|██████████| 28/28 [03:55<00:00,  8.43s/it]


Epoch 6/100, loss 16.912011134119332
> Test Accuracy = 74.86666666666666 [best = 74.86666666666666]


100%|██████████| 28/28 [03:56<00:00,  8.44s/it]


Epoch 7/100, loss 16.342555233053886
--> IMPROVEMENT from 74.86666666666666 to 76.73333333333333


100%|██████████| 28/28 [03:56<00:00,  8.46s/it]


Epoch 8/100, loss 15.799397456784622
--> IMPROVEMENT from 76.73333333333333 to 77.66666666666667


100%|██████████| 28/28 [03:58<00:00,  8.52s/it]


Epoch 9/100, loss 15.509009952675711
--> IMPROVEMENT from 77.66666666666667 to 78.4


100%|██████████| 28/28 [03:58<00:00,  8.54s/it]


Epoch 10/100, loss 15.19938412157462
--> IMPROVEMENT from 78.4 to 79.0


100%|██████████| 28/28 [10:27<00:00, 22.41s/it] 


Epoch 11/100, loss 15.008400173156614
--> IMPROVEMENT from 79.0 to 79.26666666666667


  0%|          | 0/28 [00:00<?, ?it/s]

: 

In [None]:
train_acc = test_model(model, train_batched_data, if_plot=True); print(f'{train_acc = }')
test_acc = test_model(model, test_batched_data, if_plot=True); print(f'{test_acc = }')