In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import Linear, HeteroConv, SAGEConv

class HeteroGNN(torch.nn.Module):
    """Heterogeneous GNN over 'host' and 'flow' nodes that produces per-flow-node logits.

    Each node type is first projected to ``dim_h`` by its own ``Linear`` layer, then
    refined by ``num_layers`` ``HeteroConv`` layers that pass messages along
    ('host', 'to', 'flow') and ('flow', 'to', 'host') (summed per destination type),
    and finally the flow embeddings are mapped to ``dim_out`` logits.

    Args:
        in_channels_host (int): feature width of host nodes.
        in_channels_flow (int): feature width of flow nodes.
        dim_h (int): hidden width used by the projections and every conv layer.
        dim_out (int): number of output logits per flow node.
        num_layers (int): number of HeteroConv layers.
    """

    def __init__(self, in_channels_host, in_channels_flow, dim_h, dim_out, num_layers):
        super().__init__()

        # Input widths per node type, kept as attributes for introspection.
        self.in_channels_host = in_channels_host
        self.in_channels_flow = in_channels_flow

        # Initial linear projections bring both node types to the shared width dim_h.
        self.host_proj = Linear(in_channels_host, dim_h)
        self.flow_proj = Linear(in_channels_flow, dim_h)

        # SAGEConv handles bipartite (src, dst) inputs via (-1, -1) lazy sizes and
        # never inserts self-loops, so it must NOT be given `add_self_loops`: that
        # keyword is forwarded to MessagePassing.__init__ and rejected with TypeError.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('host', 'to', 'flow'): SAGEConv((-1, -1), dim_h),
                ('flow', 'to', 'host'): SAGEConv((-1, -1), dim_h),
            }, aggr='sum')
            self.convs.append(conv)

        # Final classification layer applied to flow-node embeddings only.
        self.lin = Linear(dim_h, dim_out)

    def forward(self, x_dict, edge_index_dict):
        """Return logits for the 'flow' nodes (one row per flow node, ``dim_out`` columns).

        Args:
            x_dict (dict): node type -> feature tensor; must contain 'host' and 'flow'.
            edge_index_dict (dict): edge type -> edge_index for the edge types above.
        """
        # Shallow-copy so the caller's dict is not overwritten with projected
        # features (otherwise re-using it in a later call would project it twice).
        x_dict = dict(x_dict)
        x_dict['host'] = self.host_proj(x_dict['host'])
        x_dict['flow'] = self.flow_proj(x_dict['flow'])

        # Message passing; a LeakyReLU follows every conv layer.
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}

        return self.lin(x_dict['flow'])

def get_model(train_graphs, dim_h=64, num_layers=3):
    """Build a HeteroGNN sized from the feature widths and labels of ``train_graphs``.

    Args:
        train_graphs (list): non-empty list of heterogeneous graphs exposing
            ``graph['host'].x``, ``graph['flow'].x`` and ``graph['flow'].y``.
        dim_h (int): hidden dimension (default 64).
        num_layers (int): number of graph convolution layers (default 3).

    Returns:
        HeteroGNN whose output width covers every flow label seen in ``train_graphs``.

    Raises:
        ValueError: if ``train_graphs`` is empty or contains no flow labels.
    """
    if len(train_graphs) == 0:
        raise ValueError("train_graphs must contain at least one graph")

    # Input widths come from the first graph; all graphs are assumed to share them.
    first_graph = train_graphs[0]
    in_channels_host = first_graph['host'].x.size(1)
    in_channels_flow = first_graph['flow'].x.size(1)

    # CrossEntropyLoss requires the logit count to exceed every target index.
    # Counting distinct labels in only the first graph undercounts whenever that
    # graph lacks a class or the labels are not contiguous, so use the largest
    # label across all training graphs instead.
    label_maxima = [
        int(graph['flow'].y.max())
        for graph in train_graphs
        if graph['flow'].y.numel() > 0
    ]
    if not label_maxima:
        raise ValueError("train_graphs contain no flow labels")
    num_classes = max(label_maxima) + 1

    model = HeteroGNN(
        in_channels_host=in_channels_host,
        in_channels_flow=in_channels_flow,
        dim_h=dim_h,
        dim_out=num_classes,
        num_layers=num_layers,
    )

    return model

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, global_mean_pool
from torch_geometric.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def _forward_batch(model, batch, model_type):
    """Call ``model`` on one collated batch with the argument list ``model_type`` needs.

    'edge_flow' models additionally receive ``batch.edge_attr``; any other model
    type is called as ``model(x, edge_index, batch)``.
    """
    if model_type == 'edge_flow':
        return model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    return model(batch.x, batch.edge_index, batch.batch)


def _run_epoch(model, loader, criterion, device, model_type, optimizer=None, collect=False):
    """Make one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Args:
        model: network to run.
        loader: iterable of collated batches exposing ``.to``, ``.y`` and the
            attributes read by ``_forward_batch``.
        criterion: loss called as ``criterion(logits, targets)``; assumed to return
            the mean over the batch (CrossEntropyLoss default reduction).
        device: device each batch is moved to.
        model_type (str): selects the forward signature (see ``_forward_batch``).
        optimizer: when not None, gradients are computed and a step is taken per batch.
        collect (bool): when True, also return the per-target predictions and labels.

    Returns:
        tuple: (mean loss per target, accuracy, predictions, labels). The last two
        are empty lists unless ``collect`` is True.

    Raises:
        ValueError: if ``loader`` yields no targets (averages would divide by zero).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum, correct, seen = 0.0, 0, 0
    preds_out, labels_out = [], []
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)
            if training:
                optimizer.zero_grad()

            out = _forward_batch(model, batch, model_type)
            loss = criterion(out, batch.y)

            if training:
                loss.backward()
                optimizer.step()

            preds = out.argmax(dim=1)
            n_targets = batch.y.size(0)
            # Weight each batch-mean loss by its batch size so a short final
            # batch does not skew the epoch average.
            loss_sum += loss.item() * n_targets
            correct += (preds == batch.y).sum().item()
            seen += n_targets

            if collect:
                preds_out.extend(preds.cpu().numpy())
                labels_out.extend(batch.y.cpu().numpy())

    if seen == 0:
        raise ValueError("loader yielded no labelled examples")
    return loss_sum / seen, correct / seen, preds_out, labels_out


def train_and_evaluate_graph_classification(
    train_graphs, 
    test_graphs, 
    model_type='edge_flow', 
    epochs=100, 
    patience=10, 
    learning_rate=0.001,
    batch_size=32,
    val_graphs=None,
):
    """
    Train a graph classifier with LR scheduling and early stopping, then evaluate it.

    Args:
        train_graphs (list): Non-empty list of training graphs, each exposing
            ``x``, ``edge_index``, ``y`` (and ``edge_attr`` for 'edge_flow').
        test_graphs (list): Non-empty list of test graphs used for the final report.
        model_type (str): 'edge_flow' (EdgeFlowGNN) or 'node_flow' (NodeFlowGNN).
        epochs (int): Maximum training epochs.
        patience (int): Epochs without validation-loss improvement before stopping.
        learning_rate (float): Initial learning rate.
        batch_size (int): Batch size for all loaders.
        val_graphs (list, optional): Graphs used for LR scheduling, early stopping
            and checkpoint selection. Defaults to ``test_graphs`` (the original
            behavior), which makes the reported test metrics optimistically biased.

    Returns:
        dict: 'best_val_loss', 'classification_report', 'confusion_matrix' and the
        per-epoch 'train_losses', 'val_losses', 'train_accuracies', 'val_accuracies'.

    Raises:
        ValueError: on an unknown ``model_type`` or an empty graph list.
    """
    if model_type not in ('edge_flow', 'node_flow'):
        raise ValueError(
            f"model_type must be 'edge_flow' or 'node_flow', got {model_type!r}")
    if len(train_graphs) == 0 or len(test_graphs) == 0:
        raise ValueError("train_graphs and test_graphs must both be non-empty")
    if val_graphs is None:
        # NOTE: selecting the checkpoint on the test set leaks test information
        # into model selection; pass a separate val_graphs split to avoid this.
        val_graphs = test_graphs
    elif len(val_graphs) == 0:
        raise ValueError("val_graphs must be non-empty when provided")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

    # With graph-level labels (one y per graph, as the accuracy bookkeeping
    # assumes), counting unique labels of the first graph always gives 1.
    # CrossEntropyLoss needs more logits than the largest target index, so size
    # the head from the largest label in any split; otherwise a class absent
    # from training would crash evaluation.
    label_graphs = [*train_graphs, *val_graphs, *test_graphs]
    num_classes = int(torch.cat([g.y.view(-1) for g in label_graphs]).max()) + 1
    num_features = train_graphs[0].x.shape[1]

    model_cls = EdgeFlowGNN if model_type == 'edge_flow' else NodeFlowGNN
    model = model_cls(num_features, num_classes).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    # `verbose` is deprecated in recent PyTorch; LR reductions are printed below instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    criterion = torch.nn.CrossEntropyLoss()

    checkpoint_path = f'best_{model_type}_model.pt'
    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    # Main training loop
    for epoch in range(epochs):
        train_loss, train_acc, _, _ = _run_epoch(
            model, train_loader, criterion, device, model_type, optimizer=optimizer)
        val_loss, val_acc, _, _ = _run_epoch(
            model, val_loader, criterion, device, model_type)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Keep an in-memory copy so the final evaluation never depends on a
            # (possibly stale) checkpoint file left by an earlier run.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, checkpoint_path)
        else:
            patience_counter += 1

        print(f'Epoch {epoch+1}: '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        if lr_after < lr_before:
            print(f'Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}')

        if patience_counter >= patience:
            print("Early stopping triggered")
            break

    # Restore the best weights; if validation loss never improved (e.g. NaN
    # losses or epochs == 0) there is no checkpoint, so evaluate the final weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; evaluating final weights")

    _, _, all_preds, all_labels = _run_epoch(
        model, test_loader, criterion, device, model_type, collect=True)

    # Classification Report
    class_report = classification_report(all_labels, all_preds)
    print("Classification Report:\n", class_report)

    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'{model_type.replace("_", " ").title()} Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.savefig(f'{model_type}_confusion_matrix.png')

    # Visualization of training metrics
    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss Curves')
    plt.xlabel('Epoch')
    plt.legend()

    plt.subplot(1,2,2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.title('Accuracy Curves')
    plt.xlabel('Epoch')
    plt.legend()

    plt.tight_layout()
    plt.savefig(f'{model_type}_training_metrics.png')

    return {
        'best_val_loss': best_val_loss,
        'classification_report': class_report,
        'confusion_matrix': cm,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies
    }

# Example usage function
def run_graph_classification_experiments(train_edge_graphs, train_node_graphs, 
                                         test_edge_graphs, test_node_graphs,
                                         epochs=50, patience=10):
    """
    Run classification experiments for both edge and node flow graphs.

    Args:
        train_edge_graphs (list): Training edge flow graphs
        train_node_graphs (list): Training node flow graphs
        test_edge_graphs (list): Test edge flow graphs
        test_node_graphs (list): Test node flow graphs
        epochs (int): Maximum training epochs per experiment (default 50).
        patience (int): Early stopping patience per experiment (default 10).

    Returns:
        dict: {'edge_flow_results': ..., 'node_flow_results': ...}, each value being
        the dict returned by ``train_and_evaluate_graph_classification``.
    """
    # (printed header, model type, train split, test split), run in this order.
    experiments = (
        ("Edge Flow Graph Classification:", 'edge_flow', train_edge_graphs, test_edge_graphs),
        ("\nNode Flow Graph Classification:", 'node_flow', train_node_graphs, test_node_graphs),
    )

    results = {}
    for header, model_type, train_graphs, test_graphs in experiments:
        print(header)
        results[f'{model_type}_results'] = train_and_evaluate_graph_classification(
            train_graphs,
            test_graphs,
            model_type=model_type,
            epochs=epochs,
            patience=patience,
        )
    return results

# Run both experiments end to end. The four graph lists passed below are not
# created in this cell: load the saved train/test graphs in an earlier cell so
# that train_edge_flow, train_node_flow, test_edge_flow and test_node_flow exist.
results = run_graph_classification_experiments(
    train_edge_graphs=train_edge_flow,
    train_node_graphs=train_node_flow,
    test_edge_graphs=test_edge_flow,
    test_node_graphs=test_node_flow,
)