# BFS-Based Subgraph Extraction for Graph Classification

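This notebook demonstrates breadth-first-search (BFS) subgraph extraction combined with weakly supervised graph classification using Graph Attention Networks (GAT) on the MSRC_21 benchmark from TUDataset. Each graph is decomposed into BFS subgraphs, a GAT is trained on those subgraphs, and each test graph is then classified from the top-k subgraphs selected by the model's attention.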

In [None]:
# Import standard libraries
import torch
import torch_geometric.datasets as datasets
from torch_geometric.loader import DataLoader
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import warnings

# Import custom modules
from utils import create_bfs_subgraphs, evaluate_with_attention
from models import GAT, train_model
from datasets import SubgraphDataset, create_dataset_splits
from visualization import plot_confusion_matrix, evaluate_and_visualize_top_k

warnings.filterwarnings('ignore')

## Dataset Loading and Setup

In [None]:
# Load the MSRC_21 graph-classification benchmark via torch_geometric's TUDataset wrapper;
# its files are kept under DATASET_PATH.
# `datasets` here is the torch_geometric.datasets alias from the imports cell (the later
# `from datasets import ...` line does not rebind that name).
DATASET_PATH = 'dataset'
dataset = datasets.TUDataset(root=DATASET_PATH, name="MSRC_21")

# Quick sanity summary of what was loaded.
dataset_summary = {
    "Dataset": dataset,
    "Number of graphs": len(dataset),
    "Number of classes": dataset.num_classes,
    "Number of node features": dataset.num_node_features,
}
for field, value in dataset_summary.items():
    print(f"{field}: {value}")

## BFS Subgraph Creation and Dataset Preparation

In [None]:
# Hold out 20% of the graphs for testing; the fixed random_state makes the split reproducible.
train_graphs, test_graphs = create_dataset_splits(dataset, test_size=0.2, random_state=21)

# Keyword arguments forwarded to create_bfs_subgraphs (here and again at evaluation time).
# NOTE(review): meanings (max BFS depth, minimum nodes/edges per kept subgraph) are inferred
# from the names — confirm in utils.create_bfs_subgraphs.
bfs_params = dict(depth_limit=8, min_nodes=10, min_edges=8)

# Decompose every graph of each split into BFS subgraphs (train first, then test).
train_subgraph_dataset, test_subgraph_dataset = (
    SubgraphDataset(split_graphs, create_bfs_subgraphs, **bfs_params)
    for split_graphs in (train_graphs, test_graphs)
)

print(f"Training subgraphs: {len(train_subgraph_dataset.data_list)}")
print(f"Testing subgraphs: {len(test_subgraph_dataset.data_list)}")

# Mini-batch loaders over the subgraph lists; only the training loader is shuffled.
# NOTE(review): later cells never iterate test_loader — evaluation works from `test_graphs`
# directly — so it only serves the batch count printed below.
BATCH_SIZE = 32
train_loader = DataLoader(train_subgraph_dataset.data_list, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_subgraph_dataset.data_list, batch_size=BATCH_SIZE, shuffle=False)

print(f"Training batches: {len(train_loader)}")
print(f"Testing batches: {len(test_loader)}")

## Model Setup and Training

In [None]:
# Use the GPU when one is available; the model is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

# Seed torch's global RNG before building the model. This makes weight initialisation, dropout
# masks and train_loader's shuffling reproducible; the shuffle order is drawn when the training
# cell iterates the loader, i.e. after this seed.
# SEED reuses the value given to create_dataset_splits.
# NOTE(review): if create_bfs_subgraphs or the visualization helpers rely on Python `random` or
# NumPy RNGs, those must be seeded separately. Bit-exact CUDA runs may additionally need
# torch.use_deterministic_algorithms(True).
SEED = 21
torch.manual_seed(SEED)

# Graph Attention Network: input width comes from the dataset's node features and output width
# from its number of classes.
model = GAT(
    num_features=dataset.num_node_features,
    num_classes=dataset.num_classes,
    hidden_channels=64,
    heads=8,
    dropout=0.6
).to(device)

print(f"Model: {model}")

# Optimizer and loss used by train_model in the training cell.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Train for a fixed number of epochs, recording each epoch's loss/accuracy for the curves below.
# NOTE: `model` and `optimizer` are created in the setup cell, so re-running only this cell keeps
# training the same weights instead of starting over — re-run the setup cell for a fresh run.
num_epochs = 20
train_losses, train_accuracies = [], []

print("Starting training...")
for epoch_num in range(1, num_epochs + 1):
    # presumably the accuracy is measured on the training subgraphs — confirm in models.train_model
    epoch_loss, epoch_acc = train_model(model, train_loader, optimizer, device, criterion)
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    print(f'Epoch {epoch_num}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

print("Training completed!")

## Evaluation with Attention-based Subgraph Selection

In [None]:
# Classify each held-out graph from its top-k BFS subgraphs ranked by the model's attention.
# The helper receives the raw test graphs together with create_bfs_subgraphs and bfs_params
# (presumably to build the subgraphs itself with the same settings used for training).
TOP_K = 4  # number of highest-attention subgraphs used per test graph
print("Evaluating with attention-based subgraph selection...")

# Named test_accuracy so it does not overwrite the per-epoch `accuracy` from training.
test_accuracy, predictions, true_labels = evaluate_with_attention(
    model=model,
    dataset=test_graphs,
    subgraph_func=create_bfs_subgraphs,
    k=TOP_K,
    device=device,
    **bfs_params
)

print(f'Final Test Accuracy: {test_accuracy:.4f}')

# One display name per class id 0..num_classes-1.
class_ids = list(range(dataset.num_classes))
class_names = [f'Class {i}' for i in class_ids]

# Pass `labels` explicitly. Without it, sklearn infers the label set from the classes that
# actually occur in true_labels/predictions. It then raises a ValueError when that set is smaller
# than target_names, e.g. a class absent from the 20% test split and never predicted.
# zero_division=0 states the score for such empty classes explicitly instead of relying on the
# global warning filter.
print("\nClassification Report:")
print(classification_report(true_labels, predictions, labels=class_ids,
                            target_names=class_names, zero_division=0))

## Results Visualization

In [None]:
# Confusion matrix of the attention-based test predictions, labelled with class_names.
plot_confusion_matrix(
    true_labels,
    predictions,
    class_names,
    title="BFS-based Subgraph Classification Results",
)

In [None]:
# Training curves: accuracy (left) and loss (right) per epoch, from the lists filled by the
# training loop.
# Each x-axis is derived from the recorded values rather than num_epochs. If training was
# interrupted part-way, x and y would otherwise have different lengths and plot() would raise.
panels = (
    (train_accuracies, 'b-', 'Training Accuracy', 'Accuracy'),
    (train_losses, 'r-', 'Training Loss', 'Loss'),
)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (values, fmt, title, ylabel) in zip(axes, panels):
    ax.plot(range(1, len(values) + 1), values, fmt, label=title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()

## Visualize Top-K Subgraph Selection

In [None]:
# Show, for a sample test graph, which BFS subgraphs the model's attention ranks highest.
# k matches the value used in the evaluation cell (4), so the visualised selection — and any
# prediction this helper derives from it — reflects how the reported test accuracy was computed.
# random_seed presumably chooses which test graph is shown — confirm in visualization.
print("Visualizing top-k subgraph selection...")

evaluate_and_visualize_top_k(
    model=model,
    dataset=test_graphs,
    subgraph_func=create_bfs_subgraphs,
    device=device,
    k=4,
    random_seed=42,
    **bfs_params
)