# Sliding Window Subgraph Extraction for Graph Classification

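This notebook compares two Graph Attention Network (GAT) approaches to graph classification on the MSRC_21 dataset:

- a baseline trained on full graphs;
- a weakly supervised model trained on overlapping, fixed-size subgraphs extracted with a sliding window. This model classifies each test graph from its top-k subgraphs, ranked by attention.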

In [None]:
# Import standard / third-party libraries
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric.datasets as datasets
from sklearn.metrics import classification_report
from torch_geometric.loader import DataLoader

# Import custom (project-local) modules.
# NOTE: `from datasets import ...` resolves the top-level module named
# `datasets` (the project's datasets.py when it is first on sys.path). It does
# NOT read from the `torch_geometric.datasets` alias bound above; the alias and
# the module only share a name, and this import does not rebind `datasets`.
from utils import create_sliding_window_subgraphs, evaluate_with_attention
from models import GAT, train_model, test_model
from datasets import SubgraphDataset, create_dataset_splits
# plot_training_curves is currently unused below (curves are drawn inline).
from visualization import plot_confusion_matrix, plot_training_curves, evaluate_and_visualize_top_k

# Silences every warning for the rest of the notebook (including library
# deprecation and metric warnings); narrow this filter if you need them.
warnings.filterwarnings('ignore')

# Seed every RNG the notebook depends on (model initialisation, dropout,
# DataLoader shuffling) so a Restart & Run All reproduces the reported numbers.
# torch.manual_seed also seeds CUDA devices; some GPU scatter kernels used by
# graph layers may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Dataset Loading and Analysis

In [None]:
# Load the MSRC_21 graph-classification benchmark from the TU collection,
# cached under DATASET_PATH.
DATASET_PATH = 'dataset'
dataset = datasets.TUDataset(name="MSRC_21", root=DATASET_PATH)

print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of node features: {dataset.num_node_features}")

# Gather per-graph label, node count and edge count in a single pass.
# NOTE(review): PyG's `num_edges` counts columns of edge_index, so for
# undirected graphs stored in both directions each edge is counted twice;
# confirm before quoting the average as an undirected edge count.
labels = []
total_nodes = 0
total_edges = 0
for graph in dataset:
    labels.append(graph.y.item())
    total_nodes += graph.num_nodes
    total_edges += graph.num_edges

num_graphs = len(dataset)
avg_nodes = total_nodes / num_graphs
avg_edges = total_edges / num_graphs

print("\nDataset Statistics:")
print(f"Unique labels: {set(labels)}")
print(f"Average nodes per graph: {avg_nodes:.2f}")
print(f"Average edges per graph: {avg_edges:.2f}")

## Standard Graph Classification (Baseline)

In [None]:
# Split the graphs once (fixed random_state). Both the full-graph baseline and
# the subgraph datasets built later are derived from this same partition.
train_dataset, test_dataset = create_dataset_splits(dataset, test_size=0.2, random_state=42)

print(f"Training graphs: {len(train_dataset)}")
print(f"Testing graphs: {len(test_dataset)}")

# Whole-graph loaders for the baseline. Only the training side is shuffled, so
# evaluation iterates the test graphs in a stable order.
train_loader_full = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader_full = DataLoader(test_dataset, shuffle=False, batch_size=32)

In [None]:
# Train the baseline GAT on whole graphs, recording per-epoch loss and
# train/test accuracy for the comparison plots further down.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

baseline_model = GAT(
    num_features=dataset.num_node_features,
    num_classes=dataset.num_classes,
    hidden_channels=64,
    heads=8,
    dropout=0.6,
)
baseline_model = baseline_model.to(device)
baseline_optimizer = torch.optim.Adam(baseline_model.parameters(), lr=0.005, weight_decay=5e-4)

print("Training baseline model on full graphs...")
num_epochs = 50
baseline_losses = []
baseline_train_acc = []
baseline_test_acc = []

for epoch in range(num_epochs):
    avg_loss, train_acc = train_model(baseline_model, train_loader_full, baseline_optimizer, device)
    # Test accuracy is tracked every epoch for monitoring only. The headline
    # number is the final epoch's value, with no model selection on the test set.
    test_acc, _, _ = test_model(baseline_model, test_loader_full, device)

    baseline_losses.append(avg_loss)
    baseline_train_acc.append(train_acc)
    baseline_test_acc.append(test_acc)

    # Report epochs 1, 11, 21, ... plus the final epoch.
    is_report_epoch = (epoch % 10 == 0) or (epoch + 1 == num_epochs)
    if is_report_epoch:
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f"\nBaseline Test Accuracy: {baseline_test_acc[-1]:.4f}")

## Sliding Window Subgraph Approach

In [None]:
# Sliding-window extraction settings, reused unchanged at evaluation time.
# The step is smaller than the window, so consecutive windows overlap.
# NOTE(review): units (nodes?) are defined by
# utils.create_sliding_window_subgraphs; confirm there.
window_params = dict(window_size=62, step_size=5)
print(f"Sliding window parameters: {window_params}")

# Build one subgraph dataset per split (train first, then test).
train_subgraph_dataset, test_subgraph_dataset = (
    SubgraphDataset(split, create_sliding_window_subgraphs, **window_params)
    for split in (train_dataset, test_dataset)
)

print(f"Training subgraphs: {len(train_subgraph_dataset.data_list)}")
print(f"Testing subgraphs: {len(test_subgraph_dataset.data_list)}")

# Batch the extracted subgraphs; only the training loader is shuffled.
train_loader_sub = DataLoader(train_subgraph_dataset.data_list, shuffle=True, batch_size=32)
test_loader_sub = DataLoader(test_subgraph_dataset.data_list, shuffle=False, batch_size=32)

print(f"Training batches: {len(train_loader_sub)}")
print(f"Testing batches: {len(test_loader_sub)}")

## Subgraph Model Training

In [None]:
# A fresh GAT with the same architecture as the baseline, to be trained on
# the sliding-window subgraphs.
subgraph_model = GAT(
    num_features=dataset.num_node_features,
    num_classes=dataset.num_classes,
    hidden_channels=64,
    heads=8,
    dropout=0.6,
)
subgraph_model = subgraph_model.to(device)

criterion = torch.nn.CrossEntropyLoss()
# NOTE(review): the optimizer settings differ from the baseline's
# (lr=0.01 with no weight decay here, vs lr=0.005 with weight_decay=5e-4).
# Accuracy differences therefore are not attributable to the subgraph approach
# alone; confirm this is intentional.
subgraph_optimizer = torch.optim.Adam(subgraph_model.parameters(), lr=0.01)

print("Training subgraph model...")

In [None]:
# Train the subgraph model. Only training-set metrics are recorded here; test
# performance is measured afterwards with attention-based top-k selection.
num_epochs_sub = 50
subgraph_losses, subgraph_accuracies = [], []

for epoch in range(num_epochs_sub):
    avg_loss, accuracy = train_model(subgraph_model, train_loader_sub, subgraph_optimizer, device, criterion)
    subgraph_losses.append(avg_loss)
    subgraph_accuracies.append(accuracy)

    # Report epochs 1, 11, 21, ... plus the final epoch.
    is_report_epoch = (epoch % 10 == 0) or (epoch + 1 == num_epochs_sub)
    if is_report_epoch:
        print(f'Epoch {epoch + 1}/{num_epochs_sub}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

print("Subgraph model training completed!")

## Evaluation with Attention-based Top-K Selection

In [None]:
# Evaluate the subgraph model on the held-out graphs. Each test graph is
# re-split with the same window_params and predicted from its top-k (k=3)
# subgraphs as selected by attention. The result is compared with the
# baseline's final-epoch test accuracy.
print("Evaluating with attention-based top-k subgraph selection...")

# Named topk_accuracy so it does not overwrite the training loop's `accuracy`.
topk_accuracy, predictions, true_labels = evaluate_with_attention(
    model=subgraph_model,
    dataset=test_dataset,
    subgraph_func=create_sliding_window_subgraphs,
    k=3,
    device=device,
    **window_params
)

baseline_final_acc = baseline_test_acc[-1]
print(f'Subgraph-based Test Accuracy (Top-K): {topk_accuracy:.4f}')
print(f'Baseline Test Accuracy (Full Graph): {baseline_final_acc:.4f}')
# Signed difference: a negative value means the subgraph model did worse.
print(f'Accuracy difference (subgraph - baseline): {topk_accuracy - baseline_final_acc:+.4f}')

# Generate class names
class_labels = list(range(dataset.num_classes))
class_names = [f'Class {i}' for i in class_labels]

# Pass `labels` explicitly. Without it, classification_report infers the class
# set from true_labels/predictions only. It then raises ValueError whenever a
# class is absent from the test split, because target_names would be longer
# than the inferred set. zero_division=0 reports 0 (instead of warning) for
# classes that are never predicted.
print("\nSubgraph-based Classification Report:")
print(classification_report(
    true_labels,
    predictions,
    labels=class_labels,
    target_names=class_names,
    zero_division=0,
))

## Results Comparison and Visualization

In [None]:
# Per-class breakdown of the attention-based top-k predictions computed above.
plot_confusion_matrix(
    true_labels,
    predictions,
    class_names,
    title="Sliding Window Subgraph Classification Results",
)

In [None]:
# Compare training curves side by side: baseline train/test accuracy,
# subgraph training accuracy, and both models' training losses.
fig, (ax_baseline, ax_subgraph, ax_loss) = plt.subplots(1, 3, figsize=(15, 5))

# Panel 1: baseline accuracy per epoch
ax_baseline.plot(range(1, len(baseline_train_acc) + 1), baseline_train_acc, 'b-', label='Train Accuracy')
ax_baseline.plot(range(1, len(baseline_test_acc) + 1), baseline_test_acc, 'r-', label='Test Accuracy')
ax_baseline.set(xlabel='Epochs', ylabel='Accuracy', title='Baseline (Full Graph) Accuracy')

# Panel 2: subgraph model training accuracy (no per-epoch test tracking)
ax_subgraph.plot(range(1, len(subgraph_accuracies) + 1), subgraph_accuracies, 'g-', label='Training Accuracy')
ax_subgraph.set(xlabel='Epochs', ylabel='Accuracy', title='Subgraph Training Accuracy')

# Panel 3: training losses. These are computed over different training sets
# (full graphs vs subgraphs), so compare trends rather than absolute values.
ax_loss.plot(range(1, len(baseline_losses) + 1), baseline_losses, 'b-', label='Baseline Loss')
ax_loss.plot(range(1, len(subgraph_losses) + 1), subgraph_losses, 'g-', label='Subgraph Loss')
ax_loss.set(xlabel='Epochs', ylabel='Loss', title='Training Loss Comparison')

for ax in (ax_baseline, ax_subgraph, ax_loss):
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()

## Visualize Top-K Subgraph Selection

In [None]:
# Show which k=3 subgraphs the attention ranks highest for a sample test
# graph. random_seed presumably fixes which graph is sampled; confirm in
# visualization.evaluate_and_visualize_top_k.
print("Visualizing top-k subgraph selection...")

evaluate_and_visualize_top_k(
    k=3,
    random_seed=42,
    model=subgraph_model,
    dataset=test_dataset,
    subgraph_func=create_sliding_window_subgraphs,
    device=device,
    **window_params
)

## Summary

This notebook demonstrated:

1. **Baseline Performance**: Standard GAT model on full graphs
2. **Sliding Window Approach**: Fixed-size subgraph extraction with overlapping windows
3. **Attention-based Selection**: Using GAT attention weights to identify most relevant subgraphs
4. **Performance Comparison**: Comparing subgraph-based vs full graph approaches
5. **Visualization**: Showing which subgraphs contribute most to final predictions

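The sliding window approach provides a systematic way to create subgraphs of a consistent, fixed window size, which can benefit batch processing and model stability.

The comparison with the baseline is not fully controlled, for two reasons:

- The two models use different optimizer settings. The subgraph model uses learning rate 0.01 with no weight decay; the baseline uses 0.005 with weight decay 5e-4.
- The baseline figure is its final-epoch test accuracy.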