In [None]:
import torch_geometric
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
import torch_geometric.datasets as datasets
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool,GINConv
import random
from torch_geometric.data import Data, DataLoader, InMemoryDataset


import warnings
warnings.filterwarnings('ignore')

In [None]:
# Directory handed to TUDataset as its `root`; SubgraphDataset reuses it too.
DATASET_PATH = 'dataset'
# Load the MSRC_21 graph-classification benchmark through the TUDataset wrapper.
dataset = TUDataset(name="MSRC_21", root=DATASET_PATH)

In [None]:
MIN_NODES = 10  # Minimum number of nodes a subgraph must have
MIN_EDGES = 8  # Minimum number of edges a subgraph must have


def create_bfs_subgraphs(G, original_features, graph_label, depth_limit=8,
                         min_nodes=None, min_edges=None):
    """Split a graph into disjoint BFS-tree subgraphs rooted at random nodes.

    Repeatedly picks a random not-yet-visited node, runs a depth-limited BFS
    that only expands into nodes no earlier traversal has claimed, and turns
    each sufficiently large traversal into a ``Data`` object.

    Args:
        G: Graph object exposing ``nodes`` and ``neighbors(node)`` (a NetworkX
            graph in this notebook). Node ids are used to index
            ``original_features``.
        original_features: Node feature tensor of the full graph, indexed by
            node id along its first dimension.
        graph_label: Single-element tensor holding the parent graph's label;
            every subgraph inherits it as ``y``.
        depth_limit: Nodes at this depth are kept but not expanded further.
        min_nodes: Minimum node count for a traversal to be kept. Defaults to
            the module-level ``MIN_NODES``.
        min_edges: Minimum edge count for a traversal to be kept. Defaults to
            the module-level ``MIN_EDGES``.

    Returns:
        list of ``Data`` objects with ``x``, ``edge_index``, ``y`` and
        ``original_node_indices``. ``edge_index`` holds only the BFS tree
        edges, each stored once in parent -> child direction and re-indexed
        to local positions. Traversals that are too small, or that contain no
        edge at all, are dropped (their nodes still count as visited).
    """
    # Function-scope import so this cell does not depend on an edit to the import cell.
    from collections import deque

    # Resolve thresholds at call time so re-assigning the constants still takes effect.
    if min_nodes is None:
        min_nodes = MIN_NODES
    if min_edges is None:
        min_edges = MIN_EDGES

    visited = set()  # Nodes claimed by any traversal so far
    subgraphs = []  # One Data object per accepted traversal

    def bfs(start_node):
        """Depth-limited BFS; returns (nodes in discovery order, tree edges)."""
        order = [start_node]
        tree_edges = []
        queue = deque([(start_node, 0)])  # (node, depth)
        visited.add(start_node)

        while queue:
            node, depth = queue.popleft()
            if depth >= depth_limit:
                continue  # Keep the node but do not expand beyond the depth limit
            for neighbor in G.neighbors(node):
                if neighbor in visited:
                    continue
                visited.add(neighbor)
                order.append(neighbor)
                tree_edges.append((node, neighbor))
                queue.append((neighbor, depth + 1))

        return order, tree_edges

    all_nodes = list(G.nodes)

    # Keep starting traversals from random unvisited nodes until none remain.
    while len(visited) < len(all_nodes):
        unvisited_nodes = [node for node in all_nodes if node not in visited]
        if not unvisited_nodes:
            break

        start_node = random.choice(unvisited_nodes)
        bfs_nodes, bfs_edges = bfs(start_node)

        # Drop traversals below the size thresholds.
        if len(bfs_nodes) < min_nodes or len(bfs_edges) < min_edges:
            continue

        # An edgeless traversal cannot form a valid (2, E) edge_index. The former
        # `edge_index.size() == 0` test compared a torch.Size with an int and
        # could never be true, so check the edge list directly.
        if not bfs_edges:
            print("Skipping subgraph with empty edge_index")
            continue

        # Local index = position in discovery order; the same list drives the
        # feature lookup below, so rows of `x` and edge endpoints stay aligned.
        node_indices = {node: i for i, node in enumerate(bfs_nodes)}
        edge_index = torch.tensor(
            [(node_indices[u], node_indices[v]) for u, v in bfs_edges],
            dtype=torch.long,
        ).t().contiguous()

        subgraphs.append(
            Data(
                x=original_features[bfs_nodes],
                edge_index=edge_index,
                y=torch.tensor([graph_label.item()], dtype=torch.long),
                original_node_indices=torch.tensor(bfs_nodes, dtype=torch.long),
            )
        )

    return subgraphs


    



class SubgraphDataset(InMemoryDataset):
    """In-memory dataset of BFS subgraphs derived from an iterable of graphs.

    Each input graph is converted to a NetworkX graph from its ``edge_index``,
    split with ``create_bfs_subgraphs`` and the resulting subgraphs are stored
    both as a plain list (``data_list``) and in collated form
    (``data`` / ``slices``).

    Attributes:
        data_list: All generated subgraph ``Data`` objects, in input order.
        labels: The parent graph's ``y`` repeated once per generated subgraph.
    """

    def __init__(self, dataset):
        """Build the subgraphs for every graph in ``dataset``.

        Args:
            dataset: Iterable of graph objects exposing ``edge_index``, ``x``
                and ``y``.
        """
        super().__init__(root=DATASET_PATH)
        self.data_list = []
        self.labels = []

        for graph in dataset:
            G = nx.from_edgelist(graph.edge_index.t().tolist())
            # Node features are passed along so each subgraph can slice its own rows.
            subgraphs = create_bfs_subgraphs(G, graph.x, graph_label=graph.y)
            self.data_list.extend(subgraphs)
            self.labels.extend([graph.y] * len(subgraphs))  # One label per subgraph

        self.data, self.slices = self.collate(self.data_list)

    def get_labels(self):
        """Return the per-subgraph labels as a 1-D ``torch.long`` tensor.

        Labels are converted to Python ints first so the result has a fixed
        integer dtype even when no subgraph was generated (an empty list would
        otherwise become a float tensor).
        """
        return torch.tensor([int(label) for label in self.labels], dtype=torch.long)



# Split at the graph level before any subgraphs are generated, so train and
# test subgraphs always come from different parent graphs.
train_graphs, test_graphs = train_test_split(dataset, test_size=0.2, random_state=21)

# create_bfs_subgraphs picks its BFS start nodes with the global `random`
# module; seed it so the generated subgraphs are the same on every Run All.
random.seed(21)

train_subgraph_dataset = SubgraphDataset(train_graphs)
test_subgraph_dataset = SubgraphDataset(test_graphs)
print(len(train_subgraph_dataset),len(test_subgraph_dataset))


# The loaders batch the plain subgraph lists held by each dataset.
train_loader = DataLoader(train_subgraph_dataset.data_list, batch_size=32, shuffle=True)
test_loader = DataLoader(test_subgraph_dataset.data_list, batch_size=32, shuffle=False)
print(len(train_loader),len(test_loader))



In [None]:
class GAT(torch.nn.Module):
    """Two-layer graph attention network with mean pooling for graph classification.

    Args:
        hidden_channels: Output width of each attention head / of the second layer.
        heads: Number of attention heads in the first layer (outputs are
            concatenated, so the second layer sees ``hidden_channels * heads``).
        dropout: Dropout probability used inside both GATConv layers and on
            the activations between layers.
        in_channels: Node feature dimensionality. Defaults to
            ``dataset.num_node_features`` of the notebook-level ``dataset``.
        num_classes: Number of output classes. Defaults to
            ``dataset.num_classes`` of the notebook-level ``dataset``.
    """

    def __init__(self, hidden_channels, heads=8, dropout=0.6, in_channels=None, num_classes=None):
        super().__init__()
        # Fixes the global torch RNG so weight initialisation is repeatable.
        torch.manual_seed(42)

        # Fall back to the notebook-level dataset only when sizes are not given,
        # so the model can also be built for other data.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes

        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False, dropout=dropout)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, data):
        """Classify the graph(s) in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tuple ``(log_probs, attn_weights)``: per-graph log-softmax scores
            and whatever the second GATConv returns for
            ``return_attention_weights=True`` (per the PyG docs a tuple of
            edge index and attention coefficients — confirm for the installed
            version).
        """
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x, attn_weights = self.conv2(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Collapse node embeddings into one vector per graph in the batch.
        x = global_mean_pool(x, data.batch)

        x = self.lin(x)

        return F.log_softmax(x, dim=1), attn_weights


# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device : ',device)

# Build the attention model and place it on the selected device.
hidden_channels = 64
model = GAT(hidden_channels=hidden_channels)
model = model.to(device)
print(model)

# Adam with L2 regularisation via weight decay.
optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.005)


In [None]:

# Hyperparameters for the model that is actually trained below.
hidden_channels, heads, dropout = 64, 8, 0.6

# Pick the device first, then create the model directly on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GAT(hidden_channels=hidden_channels, heads=heads, dropout=dropout).to(device)

# Loss and optimiser for this (fresh) model instance.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def move_to_device(batch, device):
    """Return the result of ``batch.to(device)``."""
    return batch.to(device)

# Weakly supervised training: each subgraph is fitted against the label stored
# in its own `y`, which create_bfs_subgraphs copies from the parent graph.
model.train()
for epoch in range(20):  # Number of epochs
    total_loss, correct, total = 0, 0, 0

    for batch in train_loader:
        batch = move_to_device(batch, device)

        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        out, attn_weights = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()

        # Running statistics for this epoch.
        total_loss += loss.item()
        predicted = out.argmax(dim=1)
        correct += (predicted == batch.y).sum().item()
        total += batch.y.shape[0]

    # Loss is averaged over batches, accuracy over individual subgraphs.
    avg_loss = total_loss / len(train_loader)
    accuracy = correct / total
    print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import networkx as nx

def _plot_confusion_matrix(cm, names):
    """Draw ``cm`` as a heat map with one axis tick per entry of ``names``."""
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(names))
    plt.xticks(tick_marks, names, rotation=45)
    plt.yticks(tick_marks, names)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()


# Evaluation function to select top-k subgraphs based on attention weights
def evaluate_model_with_attention(model, dataset, k=4, target_names=None):
    """Evaluate graph-level accuracy by voting over the top-k attended subgraphs.

    For every graph, BFS subgraphs are generated with ``create_bfs_subgraphs``,
    each is scored by the model, and the ``k`` subgraphs with the highest mean
    attention value are averaged to produce the graph's prediction. Prints the
    number of correct predictions, the accuracy and a classification report,
    and shows a confusion matrix.

    Args:
        model: Callable returning ``(output, attn_weights)`` for a subgraph.
        dataset: Iterable of graphs exposing ``edge_index``, ``x`` and ``y``.
        k: Maximum number of subgraphs to aggregate per graph.
        target_names: Display names, one per class id ``0..len-1``. Defaults
            to the notebook-level ``class_names``.

    Notes:
        Subgraphs are moved to the notebook-level ``device``. Graphs that
        yield no usable subgraph are excluded from the metrics; their count is
        printed so the reported accuracy can be read in context.
    """
    if target_names is None:
        target_names = class_names
    # Pin the full label set: without it the report raises (and the matrix
    # ticks misalign) whenever some class is missing from the evaluated graphs.
    label_ids = list(range(len(target_names)))

    model.eval()
    all_predictions = []
    all_true_labels = []
    correct = 0
    total = 0
    skipped_graphs = 0

    with torch.no_grad():
        for graph in dataset:
            # Convert the graph to a NetworkX object and generate its BFS subgraphs.
            G = nx.from_edgelist(graph.edge_index.t().tolist())
            subgraphs = create_bfs_subgraphs(G, graph.x, graph_label=graph.y)

            subgraph_outputs = []
            subgraph_attention_scores = []

            for subgraph in subgraphs:
                # `edge_index.size() == 0` compared a torch.Size with an int and
                # was always False; count elements instead.
                if subgraph.edge_index.numel() == 0:
                    print("Skipping subgraph with empty edge_index")
                    continue

                subgraph = subgraph.to(device)

                try:
                    output, attn_weights = model(subgraph)
                except IndexError as e:
                    # Best effort: report the failing subgraph and move on.
                    print(f"Error processing subgraph: {e}")
                    continue

                # When the model returns a tuple/list, its last element is
                # taken as the attention coefficients.
                if isinstance(attn_weights, (tuple, list)):
                    attention_tensor = attn_weights[-1]
                else:
                    attention_tensor = attn_weights

                # One scalar attention score per subgraph.
                subgraph_attention_scores.append(attention_tensor.mean().item())
                subgraph_outputs.append(output)

            current_k = min(k, len(subgraph_outputs))
            if current_k == 0:
                skipped_graphs += 1
                continue

            # Shape (num_subgraphs, *output.shape), same as unsqueeze(0) + cat.
            stacked_outputs = torch.stack(subgraph_outputs, dim=0)
            attention_scores = torch.tensor(subgraph_attention_scores)

            # Keep the k subgraphs with the highest attention score.
            _, top_k_indices = attention_scores.topk(current_k, dim=0, largest=True, sorted=True)
            top_k_subgraphs = stacked_outputs[top_k_indices]

            # Mean-aggregate the selected outputs, then normalise and take the arg-max class.
            final_prediction = torch.softmax(top_k_subgraphs.mean(dim=0), dim=1)
            final_prediction_class = final_prediction.argmax(dim=1).item()
            true_label = graph.y.item()

            all_predictions.append(final_prediction_class)
            all_true_labels.append(true_label)

            if final_prediction_class == true_label:
                correct += 1
            total += 1

    accuracy = correct / total if total > 0 else 0
    print(correct)
    print(f'Accuracy: {accuracy:.4f}')
    if skipped_graphs:
        print(f'Graphs skipped (no usable subgraph): {skipped_graphs}')

    cm = confusion_matrix(all_true_labels, all_predictions, labels=label_ids)
    report = classification_report(
        all_true_labels,
        all_predictions,
        labels=label_ids,
        target_names=target_names,
    )

    _plot_confusion_matrix(cm, target_names)

    print(report)




# Human-readable names for the report and confusion-matrix ticks, one per class id.
class_names = ['Class ' + str(class_id) for class_id in range(dataset.num_classes)]
# Score the held-out graphs with the default top-k setting.
evaluate_model_with_attention(model, test_graphs)
