In [1]:
import torch
import random
import torch.nn as nn
import torch.nn.functional as F

# Seed Python's built-in `random` module (used by `random.choice` in the
# random-walk sampler) so that repeated fresh runs draw the same walks.
random.seed(2024)

In [2]:
from torch_geometric.datasets import Planetoid

# Load the Cora dataset from ./datasets and take its first graph object.
dataset = Planetoid(name="Cora", root="./datasets")
data = dataset[0]

# Node masks for the train / validation / test splits stored on the graph.
train_mask, val_mask, test_mask = data.train_mask, data.val_mask, data.test_mask

In [None]:

import random

def random_walk(sample, num_nodes, edge_index, walk_len):
    """
    Perform random walks on a graph.

    Every node is the starting point of ``sample`` walks. At each step the
    walk moves to an out-neighbour picked uniformly with Python's ``random``
    module. If the current node has no out-neighbours the walk stays on that
    node for the remaining steps, so the result only ever contains valid node
    ids (a placeholder such as -1 would be read as "last node" when the walks
    are later used to index a feature matrix).

    Parameters:
    sample (int): Number of random walks for each node.
    num_nodes (int): Number of nodes in the graph.
    edge_index (torch.tensor): Edge indices of the graph (2, num_edges).
    walk_len (int): Length of each random walk.

    Returns:
    torch.tensor: int32 random walks (walk_length, num_nodes * sample, 1).
        Column ``i`` starts at node ``i % num_nodes``.
    """
    # Convert the edge list to Python ints once instead of calling ``.item()``
    # on a tensor element for every edge and every walk step.
    src_nodes = edge_index[0].tolist()
    des_nodes = edge_index[1].tolist()

    out_neig = {node: [] for node in range(num_nodes)}
    for src, des in zip(src_nodes, des_nodes):
        out_neig[src].append(des)

    total_samples = num_nodes * sample
    # Every cell is written below, so no sentinel initialisation is needed.
    walks = torch.empty((walk_len, total_samples, 1), dtype=torch.int32)

    for i in range(total_samples):
        # Start nodes cycle 0..num_nodes-1, once per requested sample.
        current_node = i % num_nodes
        walks[0, i, 0] = current_node
        for j in range(1, walk_len):
            neig = out_neig[current_node]
            if neig:
                current_node = random.choice(neig)
            # Dead end: record the current node again rather than leaving a gap.
            walks[j, i, 0] = current_node

    return walks

def uniqueness(walks, total_samples):
    """
    Relabel each walk by order of first appearance of its nodes.

    Within a single walk, the first distinct node becomes 0, the next new
    node becomes 1, and so on; revisiting a node reuses its earlier label.

    Parameters:
    walks (torch.tensor): Random walks (walk_length, num_nodes * sample, 1).
    total_samples (int): Total number of samples (num_nodes * sample);
        only the first ``total_samples`` columns are filled in.

    Returns:
    torch.tensor: Unique walks (walk_length, num_nodes * sample, 1),
        same dtype as ``walks``.
    """
    relabelled = torch.empty_like(walks)

    for column in range(total_samples):
        first_seen = {}
        visited = walks[:, column, 0].tolist()
        for step, node in enumerate(visited):
            # A new node receives the next free label (= number seen so far).
            relabelled[step, column, 0] = first_seen.setdefault(node, len(first_seen))

    return relabelled

In [None]:

class RumLayer(nn.Module):
    def __init__(
        self,
        num_nodes,
        sample,
        x_input_dim,
        hidden_state_dim,
        walk_len,
        rnd_walk: callable = random_walk,
        uniq_walk=uniqueness,
        **kwargs,
    ):
        """
        Initialize the RumLayer.

        Parameters:
        num_nodes (int): Number of nodes.
        sample (int): Number of random walks for each node.
        x_input_dim (int): Input feature dimension for each node.
        hidden_state_dim (int): Hidden state dimension for the RNN.
        walk_len (int): Length of each random walk.
        rnd_walk (callable): Function to perform random walks; called as
            ``rnd_walk(sample, num_nodes, edge_index, walk_len)``.
        uniq_walk (callable): Function to calculate unique walks; called as
            ``uniq_walk(walks, total_samples)``.
        **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        # Encodes the (sin, cos) embedding of the relabelled walk.
        self.rnn_walk = nn.GRU(2, hidden_state_dim, bidirectional=True)
        # Encodes the node features gathered along the walk.
        self.rnn_x = nn.GRU(x_input_dim, hidden_state_dim)

        self.num_nodes = num_nodes
        self.sample = sample
        self.x_input_dim = x_input_dim
        self.hidden_state_dim = hidden_state_dim
        self.total_samples = sample * num_nodes
        self.random_walk = rnd_walk
        self.uniq_walks = uniq_walk
        self.walk_len = walk_len

    def forward(self, x, edge_index):
        """
        Forward pass for the RumLayer.

        Parameters:
        x (torch.tensor): Input features of shape (num_nodes, x_input_dim).
        edge_index (torch.tensor): Edge indices of the graph.

        Returns:
        torch.tensor: Output embeddings of shape (num_nodes, hidden_state_dim).
        """
        walks = self.random_walk(self.sample, self.num_nodes, edge_index, self.walk_len)

        # Index with int64 on the feature tensor's device: long indices are
        # accepted by every PyTorch version and this keeps the gather valid
        # when ``x`` does not live on the CPU.
        walk_index = walks.squeeze(-1).to(device=x.device, dtype=torch.long)
        x = x[walk_index]  # (walk_len, total_samples, x_input_dim)

        uniq_walks = self.uniq_walks(walks, self.total_samples)
        # Match the features' device/dtype so the GRU below receives
        # compatible inputs wherever the module has been moved.
        uniq_walks = uniq_walks.to(device=x.device, dtype=x.dtype)
        # Scale labels by the walk length and map them to an angle in [0, 2*pi).
        uniq_walks = uniq_walks / uniq_walks.size(0)
        uniq_walks = uniq_walks * torch.pi * 2

        uniq_walks_sin, uniq_walks_cos = torch.sin(uniq_walks), torch.cos(uniq_walks)
        uniq_walks = torch.cat([uniq_walks_sin, uniq_walks_cos], dim=-1)

        # Average the final states of the two GRU directions into a single
        # initial hidden state for the feature GRU.
        _, h_walk = self.rnn_walk(uniq_walks)
        h_walk = torch.mean(h_walk, dim=0, keepdim=True)

        _, h = self.rnn_x(x, h_walk)
        # Regroup the walks per sample and average them for every node.
        h = h.view(self.sample, self.num_nodes, self.hidden_state_dim)
        h = torch.mean(h, dim=0)

        return h

In [None]:


class RUMModel(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: int,
        depth: int,
        num_nodes: int,
        sample: int,
        walk_len: int,
        activation: callable = torch.nn.ELU(),
        dropout: float = 0.2,
    ):
        """
        Initialize the RUMModel.

        Note: construction calls ``torch.manual_seed(2024)``, which resets the
        global PyTorch RNG so that weight initialisation is repeatable.

        Parameters:
        in_features (int): Number of input features per node.
        out_features (int): Number of output classes.
        hidden_features (int): Hidden state dimension for the RNN.
        depth (int): Number of RumLayers in the model.
        num_nodes (int): Number of nodes in the graph.
        sample (int): Number of random walks for each node.
        walk_len (int): Length of each random walk.
        activation (callable): Activation function applied after the input
            projection.
        dropout (float): Dropout probability used on the input features and
            before the output projection (default 0.2).
        """
        super().__init__()
        torch.manual_seed(2024)
        self.layers = nn.ModuleList()
        self.fc_in = nn.Linear(in_features, hidden_features, bias=True)
        for _ in range(depth):
            self.layers.append(
                RumLayer(
                    x_input_dim=hidden_features,
                    hidden_state_dim=hidden_features,
                    num_nodes=num_nodes,
                    sample=sample,
                    walk_len=walk_len,
                )
            )
        self.fc_out = nn.Linear(hidden_features, out_features, bias=True)

        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.depth = depth
        self.activation = activation
        self.dropout = dropout

    def forward(self, x, edge_index):
        """
        Forward pass for the RUMModel.

        Parameters:
        x (torch.tensor): Input features of shape (num_nodes, in_features).
        edge_index (torch.tensor): Edge indices of the graph.

        Returns:
        torch.tensor: Log-probabilities of shape (num_nodes, out_features).
        """
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc_in(x)
        x = self.activation(x)
        for rum in self.layers:
            # Residual connection around every RUM layer.
            skip_connection = x
            x = rum(x, edge_index)
            x = x + skip_connection
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc_out(x)
        x = F.log_softmax(x, dim=-1)
        return x

In [None]:
def train(model, criterion, optimizer, data, train_mask):
    """
    Run a single optimisation step on the training nodes.

    The model is switched to training mode, evaluated on the full graph
    (``data.x``, ``data.edge_index``), and the loss is computed only on the
    entries selected by ``train_mask``.

    Returns:
    float: The training loss.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x, data.edge_index)
    targets = data.y
    batch_loss = criterion(logits[train_mask], targets[train_mask])

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

In [None]:
def test(model, data, train_mask, val_mask, test_mask):
    """
    Evaluate the model on the training, validation, and test sets.

    Returns:
    list: A list containing the accuracy for the training, validation, and test sets.
    """
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=-1)

        acc = []
        for mask in [train_mask, val_mask, test_mask]:
            correct = (pred[mask] == data.y[mask]).sum().item()
            total = mask.sum().item()
            acc.append(correct / total)

    return acc

In [None]:
import matplotlib.pyplot as plt

import json

# Hyper-parameter grid: walks per node x walk length.
samples_range = list(range(1, 9))
walk_len_range = list(range(1, 9))
results = dict()
criterion = nn.CrossEntropyLoss()

# Train one fresh model per (sample, walk_len) pair.
for sample in samples_range:
    for walk_len in walk_len_range:
        # Fresh model and optimizer for this grid point.
        model = RUMModel(
            in_features=dataset.num_node_features,
            out_features=dataset.num_classes,
            hidden_features=64,
            depth=1,
            num_nodes=data.num_nodes,
            sample=sample,
            walk_len=walk_len,
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)

        # Early-stopping bookkeeping: stop once validation accuracy has not
        # improved for `patience` consecutive epochs.
        best_val_acc, best_test_acc = 0, 0
        patience, counter = 25, 0

        # At most 100 epochs per configuration.
        for epoch in range(1, 101):
            train_loss = train(model, criterion, optimizer, data, train_mask)
            train_acc, val_acc, test_acc = test(model, data, train_mask, val_mask, test_mask)

            if val_acc > best_val_acc:
                # New best validation score: remember the matching test
                # accuracy and restart the patience counter.
                best_val_acc, best_test_acc, counter = val_acc, test_acc, 0
                continue

            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch} for sample {sample}, walk_len {walk_len}.")
                break

        # Record and report the test accuracy at the best validation epoch.
        results[(sample, walk_len)] = best_test_acc
        print(f"Best test accuracy for sample {sample}, walk_len {walk_len}: {best_test_acc:.4f}")

# JSON object keys must be strings, so the (sample, walk_len) tuples are stringified.
results_file = "test_accuracy_results.json"
json_serializable_results = dict((str(key), value) for key, value in results.items())

with open(results_file, "w") as f:
    f.write(json.dumps(json_serializable_results, indent=4))
print(f"Results saved to {results_file}")

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Colormap used to give every sample count its own line colour.
colors = mpl.colormaps["coolwarm"]

# One row of accuracies per sample count, ordered by walk length
# (grid points missing from `results` are plotted as 0).
test_accuracies = [
    [results.get((n_walks, length), 0) for length in walk_len_range]
    for n_walks in samples_range
]

# Draw one accuracy-vs-walk-length curve per sample count.
fig, ax = plt.subplots(figsize=(10, 6))
for i, sample in enumerate(samples_range):
    line_color = colors(i / len(samples_range))
    ax.plot(walk_len_range, test_accuracies[i], label=f"Sample: {sample}", color=line_color)

ax.set_xlabel("Walk Length")
ax.set_ylabel("Best Test Accuracy")
ax.set_title("Best Test Classification Accuracy of Cora with Varying Sample Size and Walk Length")
ax.legend()
ax.grid(True)

# Write the figure to disk before showing it.
fig.savefig("test_accuracy_plot_colored.png")

plt.show()

print("Line plot saved to test_accuracy_plot_colored.png")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Grid coordinates (walk length on X, sample count on Y) and the accuracy
# surface Z; grid points missing from `results` default to 0.
X, Y = np.meshgrid(walk_len_range, samples_range)
Z = np.array(
    [results.get((n_walks, length), 0) for n_walks in samples_range for length in walk_len_range]
).reshape(len(samples_range), len(walk_len_range))

# 3D axes on a single-panel figure.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1, projection="3d")

# Surface styling: coolwarm colours, thin black mesh edges, slight transparency.
surf = ax.plot_surface(X, Y, Z, cmap="coolwarm", edgecolor="k", linewidth=0.5, alpha=0.8)

# Colour bar keyed to the surface.
cbar = fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)
cbar.set_label("Test Accuracy", fontsize=12)

# Axis labels and title.
ax.set_xlabel("Walk Length", fontsize=14, labelpad=10)
ax.set_ylabel("Sample Size", fontsize=14, labelpad=10)
ax.set_zlabel("Test Accuracy", fontsize=14, labelpad=10)
ax.set_title("Best Test Accuracy with Varying Sample Size and Walk Length", fontsize=16, loc="center")

# Camera position chosen so the whole surface is visible.
ax.view_init(elev=30, azim=135)

# Save a high-resolution copy, then display.
plot_file = "test_accuracy_mesh_plot.png"
fig.savefig(plot_file, dpi=300, bbox_inches="tight")
print(f"Plot saved to {plot_file}")

plt.show()