# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv  # PyTorch Geometric for GCN layers

class CNNGCNModel(nn.Module):
    def __init__(
        self,
        num_classes,
        label_graph_adj,
        cnn_output_dim=256,
        gcn_hidden_dim=128,
        label_feature_dim=None,
        image_size=32,
    ):
        """
        Combines CNN and GCN for multi-label classification.

        A small CNN encodes each image into a feature vector; a two-layer GCN over the
        label graph produces one embedding per label; the logit for (image, label) is
        the dot product of the two.

        Parameters:
        - num_classes (int): Number of output labels/classes (nodes of the label graph)
        - label_graph_adj (torch.Tensor): Label graph, given either as a dense square
          adjacency matrix [num_classes, num_classes] (every non-zero entry becomes a
          weighted edge), a torch sparse adjacency matrix, or an edge list of shape
          [2, num_edges]
        - cnn_output_dim (int): Dimensionality of CNN output features
        - gcn_hidden_dim (int): Dimensionality of GCN hidden layer
        - label_feature_dim (int, optional): Width of the per-label input features passed
          to forward(); defaults to num_classes (e.g. one-hot label features)
        - image_size (int or (int, int)): Input image side length (square images) or an
          (H, W) pair; defaults to 32. Used to size the CNN's linear layer.

        Raises:
        - ValueError: if image_size is smaller than 4 in either dimension, or if
          label_graph_adj is neither a square matrix nor a [2, num_edges] edge list
        """
        super().__init__()

        # Accept a single side length (square images) or an explicit (H, W) pair.
        if isinstance(image_size, (tuple, list)):
            height, width = image_size
        else:
            height = width = image_size
        if height < 4 or width < 4:
            raise ValueError(f"image_size must be at least 4x4, got {image_size!r}")
        # Each 2x2 / stride-2 max-pool halves the spatial size (floored), so the two of
        # them leave 128 maps of (H // 4) x (W // 4) -- 8 x 8 for 32x32 inputs.
        cnn_flat_dim = 128 * (height // 4) * (width // 4)

        if label_feature_dim is None:
            label_feature_dim = num_classes

        # CNN Component (Example: simple CNN for feature extraction)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # Input channels = 3 (RGB images)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(cnn_flat_dim, cnn_output_dim),
            nn.ReLU(),
        )

        # GCN Component: per-label input features -> cnn_output_dim-wide label embeddings.
        self.gcn1 = GCNConv(label_feature_dim, gcn_hidden_dim)
        self.gcn2 = GCNConv(gcn_hidden_dim, cnn_output_dim)

        # NOTE: not used by forward() (logits come from the CNN/GCN dot product). Kept so
        # the parameter set and state_dict keys stay the same for existing checkpoints.
        self.classifier = nn.Linear(cnn_output_dim, num_classes)

        # Label graph. A non-persistent buffer follows model.to(device) / .cuda() like a
        # parameter, but is not written to state_dict.
        label_graph_adj = torch.as_tensor(label_graph_adj)
        self._to_edge_index(label_graph_adj)  # fail fast on an unusable shape
        self.register_buffer("label_graph_adj", label_graph_adj, persistent=False)

    @staticmethod
    def _to_edge_index(adj):
        """
        Convert a label-graph tensor into the (edge_index, edge_weight) pair GCNConv takes.

        Parameters:
        - adj (torch.Tensor): dense square adjacency [N, N], a torch sparse adjacency
          matrix, or an edge list of shape [2, E]

        Returns:
        - edge_index (torch.Tensor): integer tensor [2, E]; column k is (source, target)
        - edge_weight (torch.Tensor or None): [E] values of the non-zero adjacency
          entries (same dtype as adj), or None when adj already is an edge list

        Raises:
        - ValueError: if adj is not 2-D, or is neither square nor of shape [2, E]
        """
        if adj.layout != torch.strided:
            # Sparse (COO/CSR) input: label graphs are small, so densifying is cheap.
            adj = adj.to_dense()
        if adj.dim() != 2:
            raise ValueError(
                f"label graph must be a 2-D tensor, got shape {tuple(adj.shape)}"
            )
        rows, cols = adj.shape
        if rows == cols:
            # Dense adjacency: each non-zero adj[i, j] becomes an edge column (i, j)
            # weighted by adj[i, j]. A square [2, 2] input is treated as an adjacency.
            edge_index = adj.nonzero(as_tuple=False).t().contiguous()
            return edge_index, adj[edge_index[0], edge_index[1]]
        if rows == 2:
            # Already an edge list; GCNConv requires integer indices.
            return adj.long(), None
        raise ValueError(
            "label graph must be a square adjacency matrix or a [2, num_edges] edge "
            f"list, got shape {tuple(adj.shape)}"
        )

    def forward(self, x, label_features):
        """
        Forward pass of the CNN-GCN model.

        Parameters:
        - x (torch.Tensor): Input image tensor [batch_size, 3, H, W]; H and W must match
          the image_size given at construction
        - label_features (torch.Tensor): Initial label features
          [num_classes, label_feature_dim], on the same device as the model

        Returns:
        - logits (torch.Tensor): Unbounded label scores [batch_size, num_classes]
        """
        cnn_features = self.cnn(x)  # [batch_size, cnn_output_dim], >= 0 (trailing ReLU)

        # GCNConv consumes an edge list, not a dense adjacency matrix.
        edge_index, edge_weight = self._to_edge_index(self.label_graph_adj)
        if edge_weight is not None:
            # Match the feature dtype (covers bool/int adjacencies and half precision).
            edge_weight = edge_weight.to(label_features.dtype)

        label_embeddings = F.relu(self.gcn1(label_features, edge_index, edge_weight))
        # No activation after the last GCN layer: cnn_features are already >= 0, so a
        # ReLU here would force every logit >= 0 (sigmoid >= 0.5), making it impossible
        # to predict any label as absent.
        label_embeddings = self.gcn2(label_embeddings, edge_index, edge_weight)
        # label_embeddings: [num_classes, cnn_output_dim]

        # One score per (image, label): dot product of image features and label embedding.
        return cnn_features @ label_embeddings.t()  # [batch_size, num_classes]


# Example usage: one forward pass on random images, then report the logits' shape.
if __name__ == "__main__":
    # Demo hyper-parameters.
    num_classes, cnn_output_dim, gcn_hidden_dim = 10, 256, 128
    batch_size, img_size = 32, 32  # square RGB images, H = W = img_size

    # Random images plus a trivial label graph: the identity matrix connects each label
    # only to itself, and its one-hot rows double as the initial label embeddings.
    images = torch.randn(batch_size, 3, img_size, img_size)
    label_graph_adj = torch.eye(num_classes)
    label_features = torch.eye(num_classes)

    model = CNNGCNModel(num_classes, label_graph_adj, cnn_output_dim, gcn_hidden_dim)
    logits = model(images, label_features)
    print("Logits shape:", logits.shape)  # expected: [batch_size, num_classes]
