In [1]:
import torch
from torch_geometric.data import HeteroData, Data
from typing import List, Union

def convert_single_graph(homogeneous_graph: Data, source_node_idx: int = 0, add_source_self_loop: bool = True) -> HeteroData:
    """
    Convert a single homogeneous graph to a heterogeneous graph with two node types:
    - 'source': News source node
    - 'user': All other nodes
    
    And two edge types:
    - ('source', 'to', 'user'): Edges from source to users
    - ('user', 'to', 'user'): Edges between users
    - ('source', 'to', 'source'): Self-loop for source node (optional)
    
    Args:
        homogeneous_graph: A PyTorch Geometric Data object
        source_node_idx: Index of the source node in the graph, default is 0
        add_source_self_loop: Whether to add a self-loop to the source node, default is False
        
    Returns:
        A HeteroData object
    """
    hetero_graph = HeteroData()
    
    # Get total number of nodes
    num_nodes = homogeneous_graph.num_nodes
    
    # Extract features for source node
    source_features = homogeneous_graph.x[source_node_idx:source_node_idx+1]
    
    # Extract features for user nodes (all nodes except source)
    user_indices = torch.cat([
        torch.arange(0, source_node_idx), 
        torch.arange(source_node_idx + 1, num_nodes)
    ])
    user_features = homogeneous_graph.x[user_indices]
    
    # Add node features to the heterogeneous graph
    hetero_graph['source'].x = source_features
    hetero_graph['user'].x = user_features
    
    # Create a mapping from original node indices to new node indices
    node_mapping = {}
    node_mapping[source_node_idx] = ('source', 0)  # Source node maps to index 0 in 'source' type
    
    # Map all other nodes to 'user' type
    user_counter = 0
    for i in range(num_nodes):
        if i != source_node_idx:
            node_mapping[i] = ('user', user_counter)
            user_counter += 1
    
    # Process edges
    edge_index = homogeneous_graph.edge_index
    
    # Source-to-user edges and User-to-user edges
    source_to_user_edges = []
    user_to_user_edges = []
    
    for i in range(edge_index.shape[1]):
        src, dst = edge_index[0, i].item(), edge_index[1, i].item()
        
        src_type, src_idx = node_mapping[src]
        dst_type, dst_idx = node_mapping[dst]
        
        if src_type == 'source' and dst_type == 'user':
            # Source to user edge
            source_to_user_edges.append((src_idx, dst_idx))
        elif src_type == 'user' and dst_type == 'user':
            # User to user edge
            user_to_user_edges.append((src_idx, dst_idx))
        # We ignore user-to-source edges as mentioned in the requirements
    
    # Add edges to the heterogeneous graph
    if source_to_user_edges:
        src_indices, dst_indices = zip(*source_to_user_edges)
        hetero_graph['source', 'to', 'user'].edge_index = torch.tensor(
            [src_indices, dst_indices], dtype=torch.long
        )
        
    
    if user_to_user_edges:
        src_indices, dst_indices = zip(*user_to_user_edges)
        hetero_graph['user', 'to', 'user'].edge_index = torch.tensor(
            [src_indices, dst_indices], dtype=torch.long
        )
    
    # Add self-loop to source node if requested
    if add_source_self_loop:
        hetero_graph['source', 'to', 'source'].edge_index = torch.tensor(
            [[0], [0]], dtype=torch.long
        )
    
    # Copy graph-level targets if they exist
    if hasattr(homogeneous_graph, 'y'):
        hetero_graph['source'].y = homogeneous_graph.y
    
    return hetero_graph

def convert_to_heterogeneous(homogeneous_dataset, source_node_idx=0, add_source_self_loop=True):
    """
    Convert a homogeneous UPFD dataset to a heterogeneous dataset.
    
    Args:
        homogeneous_dataset: A PyTorch Geometric UPFD dataset
        source_node_idx: Index of the source node in each graph, default is 0
        add_source_self_loop: Whether to add a self-loop to the source node, default is False
        
    Returns:
        A list of HeteroData objects
    """
    # Simply apply convert_single_graph to each graph in the dataset
    hetero_dataset = [
        convert_single_graph(graph, source_node_idx, add_source_self_loop) 
        for graph in homogeneous_dataset
    ]
    
    return hetero_dataset

def get_edge_type(edge_index, source_indices=[0]):
    """
    Generate edge type tensor based on source node indices.
    
    This function creates a tensor of edge types by assigning different types to edges
    based on whether the source node is in the specified list of source indices.
    
    Args:
        edge_index (torch.Tensor): The edge index tensor of shape [2, num_edges]
            where edge_index[0] contains source nodes and edge_index[1] contains 
            target nodes.
        source_indices (list, optional): List of node indices to be considered as 
            source nodes. Edges originating from these nodes will be assigned type 0,
            while all other edges will be assigned type 1. Defaults to [0].
    
    Returns:
        torch.Tensor: A tensor of shape [num_edges] containing the edge types.
            Edges from nodes in source_indices have type 0, others have type 1.
    
    Example:
        >>> edge_index = torch.tensor([[0, 1, 2, 0], [1, 2, 3, 3]])
        >>> edge_type = get_edge_type(edge_index, source_indices=[0])
        >>> print(edge_type)
        tensor([0, 1, 1, 0])
    """
    edge_type = []
    for src, tgt in edge_index.t().tolist():
        if src in source_indices:
            edge_type.append(0)
        else:
            edge_type.append(1)
    return torch.tensor(edge_type)

from torch_geometric.data import Batch
def to_hetero_batch(batch, add_source_self_loop=True):
    data_list = batch.to_data_list()
    data_list = convert_to_heterogeneous(data_list, 0, add_source_self_loop)
    for data in data_list:
        print(data)
    
    batch = Batch.from_data_list(data_list)
    
    return batch

In [2]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, ModuleList, ReLU, Sequential
from typing import Callable, Dict, List, Optional, Tuple, Union, Final

from torch_geometric.nn import GATConv, GATv2Conv, HANConv, RGCNConv, global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.data import HeteroData, Batch


#######################
# GAT (Graph Attention Network) Models
#######################

class GAT(BasicGNN):
    """
    Base GAT model that outputs node embeddings.
    """
    supports_edge_weight: Final[bool] = False
    supports_edge_attr: Final[bool] = True
    
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        dropout: float = 0.0,
        v2: bool = False,
        heads: int = 8,
        concat: bool = True,
        **kwargs,
    ):
        # Store these attributes before the parent class constructor
        self.v2 = v2
        self.heads = heads
        self.concat = concat
        self.out_dim = out_channels
        
        super().__init__(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            dropout=dropout,
            **kwargs,
        )
      
    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        
        # Use the stored attributes
        v2 = kwargs.pop('v2', self.v2)
        heads = kwargs.pop('heads', self.heads)
        concat = kwargs.pop('concat', self.concat)

        # Do not use concatenation in case the layer `GATConv` layer maps to
        # the desired output channels (out_channels != None and jk != None):
        if getattr(self, '_is_conv_to_out', False):
            concat = False

        if concat and out_channels % heads != 0:
            raise ValueError(f"Ensure that the number of output channels of "
                             f"'GATConv' (got '{out_channels}') is divisible "
                             f"by the number of heads (got '{heads}')")

        if concat:
            out_channels = out_channels // heads

        Conv = GATConv if not v2 else GATv2Conv
        return Conv(in_channels, out_channels, heads=heads, concat=concat,
                    dropout=self.dropout.p, **kwargs)


class GATForGraphClassification(torch.nn.Module):
    """
    Graph classification model based on Graph Attention Networks.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_classes: int,
        num_layers: int,
        dropout: float = 0.0,
        pooling: str = 'mean',
        v2: bool = False,
        heads: int = 8,
        concat: bool = True,
        **kwargs,
    ):
        super().__init__()
        
        # Create the base GAT model for node embeddings
        self.gat = GAT(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,
            num_layers=num_layers,
            dropout=dropout,
            v2=v2,
            heads=heads,
            concat=concat,
            **kwargs,
        )
        
        # Set up the pooling function
        if pooling == 'add':
            self.pool = global_add_pool
        elif pooling == 'mean':
            self.pool = global_mean_pool
        elif pooling == 'max':
            self.pool = global_max_pool
        else:
            raise ValueError(f"Pooling type {pooling} not supported.")
        
        # Classification layer
        self.classifier = Linear(hidden_channels, num_classes)
        self.dropout = dropout
        
        # Store the embedding dimension for ensemble methods
        self.output_dim = hidden_channels
      
    def forward(self, x, edge_index, batch=None, edge_attr=None):
        """
        Forward pass for graph classification.
        """
        # Get graph embeddings
        embeddings = self.get_embedding(x, edge_index, batch, edge_attr)
        
        # Apply final classification layer
        x = F.dropout(embeddings, p=self.dropout, training=self.training)
        x = self.classifier(x)
        
        return x
    
    def get_embedding(self, x, edge_index, batch=None, edge_attr=None):
        """
        Get graph-level embeddings for use in classification or ensemble methods.
        """
        # Get node embeddings from the base GAT model
        x = self.gat(x, edge_index, edge_attr=edge_attr)
        
        # Pool node features to graph-level representation
        if batch is None:
            # If no batch is provided, assume a single graph
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        # Apply pooling to get graph-level representation
        x = self.pool(x, batch)
        
        return x


#######################
# HAN (Heterogeneous Graph Attention Network) Models
#######################

class HAN(torch.nn.Module):
    """
    Base Heterogeneous Graph Attention Network (HAN) model that outputs node embeddings.
    """
    def __init__(self, 
                 in_channels: Union[int, Dict[str, int]],
                 hidden_channels: int,
                 out_channels: int, 
                 heads: int = 8, 
                 metadata: Optional[Tuple] = None, 
                 dropout: float = 0.6,
                 num_layers: int = 1):
        super().__init__()
        self.num_layers = num_layers
        self.out_dim = out_channels
        
        # HANConv does not support multiple layers natively
        self.han_conv = HANConv(
            in_channels=in_channels, 
            out_channels=hidden_channels, 
            heads=heads,
            dropout=dropout, 
            metadata=metadata
        )
        
        # For multi-layer networks, we create additional transformation layers
        self.transforms = ModuleList()
        if num_layers > 1:
            # First transformation after HANConv
            self.transforms.append(torch.nn.Linear(hidden_channels, hidden_channels))
            
            # Additional transformation layers if requested
            for _ in range(num_layers - 2):
                self.transforms.append(torch.nn.Linear(hidden_channels, hidden_channels))
        
        # Final transformation to output dimension
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """
        Forward pass that returns node embeddings for each node type.
        """
        # Get node embeddings from HANConv
        x = self.han_conv(x_dict, edge_index_dict)
        
        # Apply additional transformation layers if specified
        if self.num_layers > 1:
            for i, transform in enumerate(self.transforms):
                # Apply transformation to each node type's embeddings
                for node_type in x.keys():
                    if x[node_type] is not None:
                        x[node_type] = transform(x[node_type])
                        x[node_type] = F.relu(x[node_type])
                        x[node_type] = F.dropout(x[node_type], p=0.5, training=self.training)
        
        # Apply final linear transformation to each node type's embeddings
        for node_type in x.keys():
            if x[node_type] is not None:
                x[node_type] = self.lin(x[node_type])
        
        return x




#######################
# RGCN (Relational Graph Convolutional Network) Models
#######################

class RGCN(torch.nn.Module):
    """
    Base Relational Graph Convolutional Network (RGCN) model that outputs node embeddings.
    """
    def __init__(self, 
                 in_channels: int, 
                 hidden_channels: int, 
                 out_channels: int,
                 num_relations: int, 
                 num_bases: Optional[int] = None,
                 num_layers: int = 2,
                 dropout: float = 0.5):
        super().__init__()
        
        self.num_layers = num_layers
        self.dropout = dropout
        self.out_dim = out_channels
        
        # Create RGCN layers
        self.convs = ModuleList()
        
        # First layer
        self.convs.append(
            RGCNConv(
                in_channels=in_channels,
                out_channels=hidden_channels,
                num_relations=num_relations,
                num_bases=num_bases
            )
        )
        
        # Middle layers (if any)
        for _ in range(num_layers - 2):
            self.convs.append(
                RGCNConv(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    num_relations=num_relations,
                    num_bases=num_bases
                )
            )
        
        # Last layer
        if num_layers > 1:
            self.convs.append(
                RGCNConv(
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    num_relations=num_relations,
                    num_bases=num_bases
                )
            )
    
    def forward(self, x, edge_index, edge_type):
        """
        Forward pass that returns node embeddings.
        """
        # Apply RGCN layers
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_type)
            if i < len(self.convs) - 1:  # Apply activation to all but the last layer
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
                
        return x


class RGCNForGraphClassification(torch.nn.Module):
    """
    Graph classification model based on Relational Graph Convolutional Networks.
    """
    def __init__(self, 
                 in_channels: int, 
                 hidden_channels: int, 
                 num_classes: int,
                 num_relations: int, 
                 num_bases: Optional[int] = None,
                 num_layers: int = 2,
                 dropout: float = 0.5,
                 pooling: str = 'mean'):
        super().__init__()
        
        # Create the base RGCN model for node embeddings
        self.rgcn = RGCN(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,  # Use same dimension for simplicity
            num_relations=num_relations,
            num_bases=num_bases,
            num_layers=num_layers,
            dropout=dropout
        )
        
        # Set up the pooling function
        if pooling == 'add':
            self.pool = global_add_pool
        elif pooling == 'mean':
            self.pool = global_mean_pool
        elif pooling == 'max':
            self.pool = global_max_pool
        else:
            raise ValueError(f"Pooling type {pooling} not supported.")
        
        # Classification layer
        self.classifier = Linear(hidden_channels, num_classes)
        
        # Dropout
        self.dropout = dropout
        
        # Store the embedding dimension for ensemble methods
        self.output_dim = hidden_channels
    
    def forward(self, x, edge_index, edge_type, batch):
        """
        Forward pass for graph classification.
        """
        # Get graph embeddings
        embeddings = self.get_embedding(x, edge_index, edge_type, batch)
        
        # Apply final classifier
        x = F.dropout(embeddings, p=self.dropout, training=self.training)
        x = self.classifier(x)
        
        return x
    
    def get_embedding(self, x, edge_index, edge_type, batch):
        """
        Get graph-level embeddings for use in classification or ensemble methods.
        """
        # Get node embeddings from the base RGCN model
        x = self.rgcn(x, edge_index, edge_type)
        
        # Global pooling (from node-level to graph-level representation)
        x = self.pool(x, batch)
        
        return x


#######################
# Ensemble Models
#######################

class EnsembleGraphClassifier(torch.nn.Module):
    """
    Ensemble model that combines multiple graph neural networks for graph classification.
    """
    def __init__(self, 
                 models: List[torch.nn.Module],
                 ensemble_method: str = 'voting',
                 num_classes: int = 2,
                 hidden_dim: int = 64,
                 dropout: float = 0.5):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        self.ensemble_method = ensemble_method
        self.num_classes = num_classes
        
        if ensemble_method == 'concat':
            # Calculate total embedding dimension from all models
            total_dim = sum(model.output_dim for model in models)
            self.classifier = torch.nn.Linear(total_dim, num_classes)
            
        elif ensemble_method == 'transform':
            # Fixed-size hidden layer regardless of number of models
            total_dim = sum(model.output_dim for model in models)
            self.transform = torch.nn.Linear(total_dim, hidden_dim)
            self.classifier = torch.nn.Linear(hidden_dim, num_classes)
            
        self.dropout = dropout
    
    def forward(self, data):
        """
        Forward pass for ensemble graph classification.
        """
        if self.ensemble_method == 'voting':
            # Get logits from each model
            all_logits = []
            for model in self.models:
                if hasattr(model, 'forward_data'):
                    # Use specialized data handling if available
                    logits = model.forward_data(data)
                else:
                    # Extract appropriate inputs based on model type
                    if isinstance(model, GATForGraphClassification):
                        logits = model(data.x, data.edge_index, data.batch, getattr(data, 'edge_attr', None))
                    elif isinstance(model, HANForGraphClassification):
                        if not isinstance(data, HeteroData):
                            data_list = Batch.to_data_list(data)
                            data_list = convert_to_heterogeneous(data_list)
                            heter_data = Batch.from_data_list(data_list)
                        else:
                            heter_data = data
                        heter_data.to(data.x.device)
                        logits = model(heter_data.x_dict, heter_data.edge_index_dict)
                    elif isinstance(model, RGCNForGraphClassification):
                        # Generate edge_type if not present
                        if not hasattr(data, 'edge_type'):
                            edge_type = get_edge_type(data.edge_index, 
                                    source_indices=data.ptr[:-1].tolist() if hasattr(data, 'ptr') else [0])
                        else:
                            edge_type = data.edge_type
                        logits = model(data.x, data.edge_index, edge_type, data.batch)
                    else:
                        raise TypeError(f"Unsupported model type: {type(model)}")
                
                all_logits.append(logits)
            
            # For training, we need to return logits, not predictions
            # Average the logits from all models
            if self.training:
                all_logits_stacked = torch.stack(all_logits, dim=0)
                return torch.mean(all_logits_stacked, dim=0)
            else:
                # For evaluation/inference, we can do voting on the predicted classes
                all_preds = [torch.argmax(logit, dim=1) for logit in all_logits]
                all_preds = torch.stack(all_preds, dim=0)
                # Get the most common prediction (mode) for each sample
                final_preds_values, _ = torch.mode(all_preds, dim=0)
                
                # Convert predictions back to one-hot format for consistency
                batch_size = final_preds_values.size(0)
                final_logits = torch.zeros(batch_size, self.num_classes, device=data.x.device)
                for i in range(batch_size):
                    final_logits[i, final_preds_values[i]] = 1.0
                
                return final_logits
            
        elif self.ensemble_method == 'average':
            # Average logits from all models
            all_logits = []
            for model in self.models:
                if hasattr(model, 'forward_data'):
                    # Use specialized data handling if available
                    logits = model.forward_data(data)
                else:
                    # Extract appropriate inputs based on model type
                    if isinstance(model, GATForGraphClassification):
                        logits = model(data.x, data.edge_index, data.batch, getattr(data, 'edge_attr', None))
                    elif isinstance(model, HANForGraphClassification):
                        if not isinstance(data, HeteroData):
                            data_list = Batch.to_data_list(data)
                            data_list = convert_to_heterogeneous(data_list)
                            heter_data = Batch.from_data_list(data_list)
                        else:
                            heter_data = data
                        heter_data.to(data.x.device)
                        logits = model(heter_data.x_dict, heter_data.edge_index_dict)
                    elif isinstance(model, RGCNForGraphClassification):
                        # Generate edge_type if not present
                        if not hasattr(data, 'edge_type'):
                            edge_type = get_edge_type(data.edge_index, 
                                    source_indices=data.ptr[:-1].tolist() if hasattr(data, 'ptr') else [0])
                        else:
                            edge_type = data.edge_type
                        logits = model(data.x, data.edge_index, edge_type, data.batch)
                    else:
                        raise TypeError(f"Unsupported model type: {type(model)}")
                
                all_logits.append(logits)
            
            # Stack and average logits
            all_logits = torch.stack(all_logits, dim=0)
            avg_logits = torch.mean(all_logits, dim=0)
            return avg_logits
            
        elif self.ensemble_method == 'concat' or self.ensemble_method == 'transform':
            # Get embeddings from each model
            all_embeddings = []
            for model in self.models:
                if hasattr(model, 'get_embedding_data'):
                    # Use specialized data handling if available
                    embed = model.get_embedding_data(data)
                else:
                    # Extract appropriate inputs based on model type
                    if isinstance(model, GATForGraphClassification):
                        embed = model.get_embedding(data.x, data.edge_index, data.batch, getattr(data, 'edge_attr', None))
                    elif isinstance(model, HANForGraphClassification):
                        if not isinstance(data, HeteroData):
                            data_list = Batch.to_data_list(data)
                            data_list = convert_to_heterogeneous(data_list)
                            heter_data = Batch.from_data_list(data_list)
                        else:
                            heter_data = data
                        heter_data.to(data.x.device)
                        logits = model(heter_data.x_dict, heter_data.edge_index_dict)
                    elif isinstance(model, RGCNForGraphClassification):
                        # Generate edge_type if not present
                        if not hasattr(data, 'edge_type'):
                            edge_type = get_edge_type(data.edge_index, 
                                    source_indices=data.ptr[:-1].tolist() if hasattr(data, 'ptr') else [0])
                        else:
                            edge_type = data.edge_type
                        embed = model.get_embedding(data.x, data.edge_index, edge_type, data.batch)
                    else:
                        raise TypeError(f"Unsupported model type: {type(model)}")
                
                all_embeddings.append(embed)
            
            # Concatenate embeddings
            combined = torch.cat(all_embeddings, dim=1)
            
            if self.ensemble_method == 'transform':
                combined = F.relu(self.transform(combined))
                combined = F.dropout(combined, p=self.dropout, training=self.training)
            
            # Apply classifier
            combined = F.dropout(combined, p=self.dropout, training=self.training)
            logits = self.classifier(combined)
            return logits

In [31]:
class HANForGraphClassification(torch.nn.Module):
    """
    Graph classification model based on Heterogeneous Graph Attention Networks.
    """
    def __init__(self, 
                 in_channels: Union[int, Dict[str, int]],
                 hidden_channels: int,
                 out_channels: int, 
                 num_classes: int = 2,
                 heads: int = 8, 
                 metadata: Optional[Tuple] = None, 
                 dropout: float = 0.6,
                 num_layers: int = 1):
        super().__init__()
        
        # Create the base HAN model for node embeddings
        # The issue is here - you initialize 'self.han', but try to use 'self.han_conv' in forward
        # Either rename this to self.han_conv or fix the forward method
        self.han_conv = HANConv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            heads=heads,
            dropout=dropout,
            metadata=metadata
        )
        
        # Linear layer for dimensionality reduction after pooling
        # Will be initialized during forward pass once we know the input dimension
        self.lin = None
        self.out_channels = out_channels
        
        # Classification layer
        self.classifier = torch.nn.Linear(out_channels, num_classes)
        self.dropout = dropout
        
        # Store the embedding dimension for ensemble methods
        self.output_dim = out_channels
    
    def forward(self, x_dict, edge_index_dict, batch=None):
        """
        Forward pass for heterogeneous graph classification.
        """
        # Get node embeddings from HANConv (this line needs to match your initialization)
        x = self.get_embedding(x_dict, edge_index_dict, batch)
        
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.classifier(x)
        
        return x
    
    def get_embedding(self, x_dict, edge_index_dict, batch=None):
        """
        Get graph-level embeddings for use in classification or ensemble methods.
        """
        # Get node embeddings from the base HAN model
        node_embeddings_dict = self.han_conv(x_dict, edge_index_dict)
        
        # Average pooling for each node type
        pooled_embeddings = {}
        if batch is None:
            # If no batch is provided, assume a single graph
            batch = {node_type: torch.zeros(embeddings.size(0), dtype=torch.long, device=embeddings.device)
                     for node_type, embeddings in node_embeddings_dict.items() if embeddings is not None}
        for node_type, embeddings in node_embeddings_dict.items():
            if embeddings is not None:
                # Average pooling for nodes of the same type
                pooled = global_mean_pool(embeddings, batch[node_type])
                pooled_embeddings.setdefault(node_type, []).append(pooled)
        
        if not pooled_embeddings:
            raise ValueError("No node embeddings were produced by the model")
        
        embeddings_by_batch = []
        for source, user in zip(pooled_embeddings['source'], pooled_embeddings['user']):
            embeddings_by_batch += (source + user) / 2
        # Concatenate all pooled embeddings from different node types
        for i,  embed in enumerate(embeddings_by_batch):
            print(f"embed at {i}: {embed.shape}")
        x = torch.stack(embeddings_by_batch, dim=0)
        print(f"Concatenated pooled embeddings shape: {x.shape}")
        # Initialize the linear layer if not done yet
        if self.lin is None:
            lin_input_dim = x.size(1)
            self.lin = torch.nn.Linear(lin_input_dim, self.out_channels).to(x.device)
        
        # Apply linear layer
        x = self.lin(x)
        x = F.relu(x)
        
        return x


In [41]:
from torch_geometric.datasets import UPFD
# different feature types can be selected: content(profile + spacy; dim: 310), profile(dim: 10), spacy(dim: 300)
# splits: train, test, val
# name: politifact, gossipcop
dataset = UPFD('data/upfd', name="politifact", feature='bert', split="train")

from torch_geometric.loader import DataLoader
# Create a DataLoader for the dataset
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [5]:
batch = next(iter(train_loader))
# print(batch)


In [6]:
data_list = Batch.to_data_list(batch)
hetero_dataset = [
    convert_single_graph(graph) 
    for graph in data_list
]

for data in hetero_dataset:
    print(data)

HeteroData(
  source={
    x=[1, 768],
    y=[1],
  },
  user={ x=[84, 768] },
  (source, to, user)={ edge_index=[2, 51] },
  (user, to, user)={ edge_index=[2, 33] },
  (source, to, source)={ edge_index=[2, 1] }
)
HeteroData(
  source={
    x=[1, 768],
    y=[1],
  },
  user={ x=[24, 768] },
  (source, to, user)={ edge_index=[2, 23] },
  (user, to, user)={ edge_index=[2, 1] },
  (source, to, source)={ edge_index=[2, 1] }
)


In [42]:
def to_hetero_batch(batch, add_source_self_loop=True):
    data_list = batch.to_data_list()
    data_list = convert_to_heterogeneous(data_list, 0, add_source_self_loop)
    batch = Batch.from_data_list(data_list)
    batch.batch = {
        'source': batch['source'].batch,
        'user': batch['user'].batch
    }
    
    return batch
hetero_batch = to_hetero_batch(batch)
hetero_batch.batch

{'source': tensor([0, 1]),
 'user': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

In [25]:
print(hetero_batch['user'].batch.shape)

torch.Size([490])


In [43]:
han_cls = HANForGraphClassification(
  in_channels=dataset.num_features,
  hidden_channels=64,
  out_channels=64,
  num_classes=2,
  heads=8,
  dropout=0.5,
  metadata=(['source', 'user'], [('source', 'to', 'user'), ('user', 'to', 'user'), ('source', 'to', 'source')]))


# print(hetero_batch)
# print(f"user batch shape: {hetero_batch['user'].batch.shape}")
# print(f"source batch shape: {hetero_batch['source'].batch.shape}")
out = han_cls(hetero_batch.x_dict, hetero_batch.edge_index_dict, batch=hetero_batch.batch)
print(out.shape)
# embed = han_cls.get_embedding(hetero_batch.x_dict, hetero_batch.edge_index_dict, batch={'user': hetero_batch['user'].batch, 'source': hetero_batch['source'].batch})

embed at 0: torch.Size([64])
embed at 1: torch.Size([64])
Concatenated pooled embeddings shape: torch.Size([2, 64])
torch.Size([2, 2])
