In [8]:
import warnings
# Install a blanket "ignore" filter: no category, message or module is given,
# so every warning raised after this point is silenced.
# NOTE(review): this presumably also hides library warnings that would flag
# degenerate results further down (e.g. a class that is never predicted) —
# consider narrowing the filter to the specific categories that are noisy.
warnings.filterwarnings("ignore")

In [9]:
import torch
import numpy as np
from torch_geometric.datasets import TUDataset
import rephine_extension
from gtda.diagrams import PersistenceImage, PersistenceLandscape, BettiCurve
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
from sklearn.preprocessing import FunctionTransformer

In [10]:
def pad_diagrams(diagrams, pad_value=0.0):
    """Pad persistence diagrams to a common length and stack them.

    Each diagram must be a 2-D array with 3 columns:
    [birth, death, homology_dimension].

    Parameters
    ----------
    diagrams : iterable of np.ndarray
        Diagrams of shape (n_points_i, 3). Any iterable is accepted; it is
        materialised once so generators work too.
    pad_value : float, default 0.0
        Value written into every cell of the padding rows. Note that it is
        written into all three columns, including the homology-dimension one.

    Returns
    -------
    np.ndarray
        Array of shape (n_diagrams, max_points, 3). An empty input yields an
        empty array of shape (0, 0, 3) instead of raising.
    """
    diagrams = list(diagrams)
    if not diagrams:
        # max() over an empty sequence would raise an unhelpful ValueError.
        return np.empty((0, 0, 3))

    max_points = max(diag.shape[0] for diag in diagrams)
    padded_diagrams = []
    for diag in diagrams:
        missing = max_points - diag.shape[0]
        if missing > 0:
            # Append `missing` filler rows so every diagram has max_points rows.
            filler = np.full((missing, 3), pad_value)
            diag = np.vstack([diag, filler])
        padded_diagrams.append(diag)
    return np.stack(padded_diagrams)
# Build one persistence diagram per MUTAG graph using rephine_extension.
# The dataset is loaded through torch_geometric with data/TUDataset as root.
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

diagrams = []  # one (n_points, 3) array per kept graph
labels = []  # label of each kept graph, appended in the same order as `diagrams`

for idx, g in enumerate(dataset):
    # Cast the edge index to int64 before handing it to the extension.
    g.edge_index = g.edge_index.long()
    n_nodes = g.num_nodes
    # Number of columns of edge_index.
    # NOTE(review): if undirected edges are stored in both directions this
    # counts each edge twice — confirm that is what compute_rephine expects.
    n_edges = g.edge_index.shape[1]

    # Graphs without nodes or without edges are skipped entirely, so neither
    # a diagram nor a label is recorded for them.
    if n_nodes == 0 or n_edges == 0:
        continue

    # [0, n] offset pairs — presumably the batch boundaries for a batch that
    # contains just this one graph; TODO confirm against rephine_extension.
    vertex_slices = torch.tensor([0, n_nodes], dtype=torch.long)
    edge_slices = torch.tensor([0, n_edges], dtype=torch.long)

    # Vertex filtration values: a (1, n_nodes) tensor of zeros.
    filtered_v = torch.zeros((1, n_nodes), dtype=torch.float32)
    # Edge filtration values: edge i gets 0.1 * i, so the value depends only
    # on the edge's position in edge_index.
    # NOTE(review): looks like a placeholder filtration — confirm intent.
    filtered_e_vals = [0.1 * i for i in range(n_edges)]
    filtered_e = torch.tensor([filtered_e_vals], dtype=torch.float32)

    # Edges are passed as the transposed, contiguous edge_index.
    pers_indices = rephine_extension.compute_rephine(
        filtered_v,
        filtered_e,
        g.edge_index.T.contiguous(),
        vertex_slices,
        edge_slices
    )

    # Slice [0, :, 0] of the result is used below as indices into filtered_e.
    # Assumed to hold, per vertex, the index of the edge that ends its
    # component — TODO confirm against rephine_extension.
    pers = pers_indices[0, :, 0]
    # Only non-negative entries are kept; negative ones presumably mark
    # features without a pairing edge.
    valid = pers >= 0
    if valid.sum() == 0:
        # No usable pair: record an empty diagram with the expected 3 columns.
        diag = np.empty((0, 3))
    else:
        # Births come from the vertex filtration (all zeros here).
        birth = filtered_v[0][valid].numpy()
        # Deaths are the edge filtration values looked up at the indices in pers.
        death = filtered_e[0][pers[valid].long()].numpy()
        diag_2d = np.stack([birth, death], axis=1)
        # Third column: every point is tagged with homology dimension 0.
        homology_dim = np.zeros((diag_2d.shape[0], 1))
        diag = np.concatenate([diag_2d, homology_dim], axis=1)

    diagrams.append(diag)
    labels.append(g.y.item())

In [11]:
# Pad every diagram to a common number of points and stack into one 3-D array.
X_all = pad_diagrams(diagrams, pad_value=0.0)
print("X_all.shape:", X_all.shape)
y_all = np.array(labels)

# Hold out 30% of the samples. Splitting the index positions (rather than the
# arrays) keeps X and y aligned and leaves the chosen indices available.
indices = np.arange(len(X_all))
train_idx, test_idx = train_test_split(indices, test_size=0.3, random_state=42)

X_train, X_test = X_all[train_idx], X_all[test_idx]
y_train, y_test = y_all[train_idx], y_all[test_idx]

def _flatten_samples(x):
    """Collapse every axis after the first so each sample becomes one feature row."""
    return x.reshape(x.shape[0], -1)


def avaliar_kernels_topologicos(X_train, X_test, y_train, y_test):
    """Fit and score one RBF SVM per topological vectorisation of the diagrams.

    For each of PersistenceImage, PersistenceLandscape and BettiCurve, a
    pipeline (vectoriser -> flatten -> SVC) is fitted on the training data,
    used to predict the test data, and its accuracy and classification report
    are printed.

    Parameters
    ----------
    X_train, X_test : array-like
        Persistence diagrams handed to the vectorisers as-is.
    y_train, y_test : array-like
        Class labels matching X_train / X_test.

    Returns
    -------
    dict
        Maps each vectoriser's display name to a dict with keys
        "acuracia" (test accuracy), "modelo" (the fitted pipeline) and
        "y_pred" (predictions on X_test).
    """
    kernels = {
        "Persistence Image": PersistenceImage(),
        "Persistence Landscape": PersistenceLandscape(),
        "Betti Curve": BettiCurve()
    }

    resultados = {}
    for nome, kernel in kernels.items():
        print(f"\n=== Avaliando Kernel: {nome} ===")
        pipe = Pipeline([
            # Turns diagrams into a representation (possibly 3-D or 4-D).
            ("topo", kernel),
            # Flattens to 2-D. A fresh transformer is built per pipeline so the
            # returned models do not share a step, and a named function is used
            # instead of a lambda so the fitted pipeline stays picklable.
            ("flatten", FunctionTransformer(_flatten_samples)),
            ("svm", SVC(kernel="rbf"))
        ])
        pipe.fit(X_train, y_train)
        y_pred = pipe.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        print(f"Acurácia: {acc:.4f}")
        print(classification_report(y_test, y_pred, digits=4))
        resultados[nome] = {"acuracia": acc, "modelo": pipe, "y_pred": y_pred}
    return resultados

# Run the comparison on the hold-out split; the result keeps, per vectoriser,
# the accuracy, the fitted pipeline and the test-set predictions.
resultados = avaliar_kernels_topologicos(
    X_train=X_train,
    X_test=X_test,
    y_train=y_train,
    y_test=y_test,
)

X_all.shape: (188, 28, 3)

=== Avaliando Kernel: Persistence Image ===
Acurácia: 0.6842
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000        18
           1     0.6842    1.0000    0.8125        39

    accuracy                         0.6842        57
   macro avg     0.3421    0.5000    0.4063        57
weighted avg     0.4681    0.6842    0.5559        57


=== Avaliando Kernel: Persistence Landscape ===
Acurácia: 0.9123
              precision    recall  f1-score   support

           0     0.8421    0.8889    0.8649        18
           1     0.9474    0.9231    0.9351        39

    accuracy                         0.9123        57
   macro avg     0.8947    0.9060    0.9000        57
weighted avg     0.9141    0.9123    0.9129        57


=== Avaliando Kernel: Betti Curve ===
Acurácia: 0.8772
              precision    recall  f1-score   support

           0     0.8235    0.7778    0.8000        18
           1     0.9000    