## This notebook demonstrates

1. The convergence of the product-of-feature-matrix (POFM) kernel to the ground-truth diffusion kernel as `max_expansion` increases.

In [21]:
import tensorflow as tf
import numpy as np
import gpflow
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import networkx as nx
from gpflow.utilities import print_summary
import tensorflow_probability as tfp
import seaborn as sns


In [22]:
import sys
import os

# Make the repository root (two directories above the notebook's working
# directory) importable so the `efficient_graph_gp` package imported below resolves.
project_root = os.path.abspath("../..")
# Guard against duplicates: re-running this cell would otherwise append the
# same path to sys.path again on every execution.
if project_root not in sys.path:
    sys.path.append(project_root)

In [23]:
from efficient_graph_gp.graph_kernels import diffusion_kernel, feature_matrix_kernel, generate_noisy_samples
from efficient_graph_gp.gpflow_kernels import GraphDiffusionKernel
from utils import plot_network_graph, plot_gp_fit

In [24]:
# Parameters
np.random.seed(0)  # fix NumPy's global RNG state for this run
graph_type = 'random'  # one of: 'line', 'random'
num_nodes = 10

In [25]:
# Build the adjacency matrix of the graph selected by `graph_type`.
if graph_type == 'line':
    # Path graph: node i is linked to i-1 and i+1; there is no wrap-around edge.
    adjacency_matrix = np.eye(num_nodes, k=1) + np.eye(num_nodes, k=-1)
elif graph_type == 'random':
    probability = 0.1  # Probability of edge creation
    # networkx draws from its own RNG rather than NumPy's global one, so
    # np.random.seed alone does not make this graph reproducible; seed it explicitly.
    G = nx.erdos_renyi_graph(num_nodes, probability, seed=0, directed=False)  # undirected graph
    adjacency_matrix = nx.to_numpy_array(G)  # Convert to adjacency matrix
else:
    raise ValueError(f"Unknown graph_type {graph_type!r}; expected 'line' or 'random'.")

In [26]:
# Reference diffusion kernel (beta=2.0) that the POFM approximation is compared against.
ground_truth = diffusion_kernel(
    adj_matrix=adjacency_matrix,
    beta=2.0,
)

In [27]:
# POFM approximation of the same diffusion kernel (beta=2.0), truncated at max_expansion=10.
pof_matrix = feature_matrix_kernel(
    adj_matrix=adjacency_matrix,
    max_expansion=10,
    kernel_type='diffusion',
    kernel_hyperparameters={'beta': 2.0},
)
# Quantify how far the truncated expansion is from the ground-truth kernel.
# NOTE(review): assumes both results convert to same-shape NumPy arrays via np.asarray.
max_abs_error = np.max(np.abs(np.asarray(pof_matrix) - np.asarray(ground_truth)))
max_abs_error

In [29]:
# Plotting function
def plot_heatmaps(beta_sample, max_expansion_sample, adj_matrix=None):
    """Plot the ground-truth diffusion kernel next to its POFM approximation.

    Args:
        beta_sample: Diffusion hyperparameter ``beta`` passed to both kernels.
        max_expansion_sample: ``max_expansion`` passed to ``feature_matrix_kernel``.
        adj_matrix: Adjacency matrix of the graph. Defaults to the notebook-level
            ``adjacency_matrix``, so the function still works as a widget callback
            that only supplies the first two arguments.

    Both heatmaps share a single colour range so equal colours mean equal values
    in both panels. The POFM panel title also reports the maximum absolute
    deviation from the ground truth.
    """
    if adj_matrix is None:
        adj_matrix = adjacency_matrix

    ground_truth = diffusion_kernel(adj_matrix=adj_matrix, beta=beta_sample)
    pof_matrix = feature_matrix_kernel(
        adj_matrix=adj_matrix,
        max_expansion=max_expansion_sample,
        kernel_type='diffusion',
        kernel_hyperparameters={'beta': beta_sample},
    )

    # NOTE(review): assumes both kernels return array-likes that np.asarray can
    # convert to same-shape arrays (e.g. NumPy arrays or eager tensors).
    exact = np.asarray(ground_truth)
    approx = np.asarray(pof_matrix)

    # Shared colour scale: with independent colour bars, the same colour could
    # mean different values in each panel and hide the approximation error.
    vmin = min(exact.min(), approx.min())
    vmax = max(exact.max(), approx.max())
    max_abs_error = np.max(np.abs(approx - exact))

    fig, (ax_exact, ax_approx) = plt.subplots(1, 2, figsize=(12, 6))

    # Heatmap for Ground Truth
    sns.heatmap(exact, annot=True, cmap='viridis', cbar=True,
                vmin=vmin, vmax=vmax, ax=ax_exact)
    ax_exact.set(title=f"Ground Truth (Beta={beta_sample})",
                 xlabel="Nodes", ylabel="Nodes")

    # Heatmap for Product of Feature Matrix
    sns.heatmap(approx, annot=True, cmap='viridis', cbar=True,
                vmin=vmin, vmax=vmax, ax=ax_approx)
    ax_approx.set(
        title=(f"POF Matrix (Max Expansion={max_expansion_sample}, "
               f"max |error|={max_abs_error:.2e})"),
        xlabel="Nodes",
        ylabel="Nodes",
    )

    fig.tight_layout()
    plt.show()

# Interactive widgets: recompute both kernels for the chosen beta / expansion depth.
# The sliders start at the same settings as the static cells above (beta=2.0, max_expansion=10).
# continuous_update=False re-renders only when a slider is released, instead of
# recomputing both kernels for every intermediate value while dragging.
# description_width='initial' stops the longer "Max Expansion:" label from being truncated.
beta_slider = widgets.FloatSlider(
    value=2.0, min=0.1, max=10.0, step=0.1, description='Beta:',
    continuous_update=False, style={'description_width': 'initial'},
)
max_expansion_slider = widgets.IntSlider(
    value=10, min=1, max=20, step=1, description='Max Expansion:',
    continuous_update=False, style={'description_width': 'initial'},
)
ui = widgets.VBox([beta_slider, max_expansion_slider])
out = widgets.interactive_output(plot_heatmaps, {
    'beta_sample': beta_slider,
    'max_expansion_sample': max_expansion_slider
})
display(ui, out)

VBox(children=(FloatSlider(value=3.0, description='Beta:', max=10.0, min=0.1), IntSlider(value=10, description…

Output()