In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# For the GCN layers
# pip install torch-geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse


class EEGEpilepsyNet(nn.Module):
    """
    A hybrid 1D CNN + GCN model to extract epileptic features from EEG data.

    Input shape:
      x: (batch_size, num_channels=21, num_samples=6300)
      edge_index: (2, E) - graph connectivity between the EEG channels

    Steps:
      1) Every channel's 1D time series is passed through the same
         (weight-shared) CNN layers and globally average-pooled over time,
         giving one feature vector per channel
      2) The per-channel vectors form a node feature matrix of shape
         (num_channels, cnn_features) for each sample
      3) A two-layer GCN (using the provided edge_index) mixes information
         between connected channels
      4) Node embeddings are averaged to get one graph-level embedding
      5) A fully connected layer produces the class logits

    Arguments:
      - adjacency (torch.Tensor): dense (N x N) adjacency matrix or None.
        When given it is converted to an edge_index once and stored as a
        buffer; when None an edge_index must be passed to every forward call.
        NOTE: only the sparsity pattern is used -- the edge weights returned
        by dense_to_sparse are discarded.
      - in_channels (int): Number of input channels of the first Conv1d.
        forward() feeds each EEG channel as a single-channel signal, so this
        has to stay 1 unless forward() is adapted.
      - cnn_features (int): Output feature dimension from the CNN for each channel
      - hidden_dim (int): Hidden dimension for GCN
      - num_classes (int): Number of output classes
      - kernel_size (int): Kernel size for CNN. Padding is kernel_size // 2,
        which keeps the temporal length unchanged for odd kernel sizes.
      - num_eeg_channels (int): Number of EEG channels (graph nodes) that
        forward() consumes. Defaults to 21.
    """

    def __init__(
        self,
        adjacency: torch.Tensor = None,
        in_channels: int = 1,
        cnn_features: int = 16,
        hidden_dim: int = 32,
        num_classes: int = 2,
        kernel_size: int = 3,
        num_eeg_channels: int = 21,
    ):
        super().__init__()

        self.num_eeg_channels = num_eeg_channels

        # If an adjacency is provided, convert it to edge_index once here.
        # Otherwise edge_index must be supplied on every forward pass.
        if adjacency is not None:
            edge_index, _ = dense_to_sparse(adjacency)
        else:
            edge_index = None
        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer("edge_index", edge_index)

        # Shared 1D CNN applied to every channel. The padding is derived from
        # the kernel size (instead of a fixed 1) so that the sequence length
        # is preserved for any odd kernel size, not only for kernel_size=3.
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels, 8, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv1d(8, cnn_features, kernel_size=kernel_size, padding=padding)
        # Global average pooling over time -> one value per feature map.
        self.pool = nn.AdaptiveAvgPool1d(1)

        # Two-layer GCN
        self.gcn1 = GCNConv(cnn_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # Final classification layer
        self.fc = nn.Linear(hidden_dim, num_classes)

    def _extract_channel_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the shared CNN over all EEG channels in a single batched call.

        Args:
            x: tensor of shape (batch_size, >= num_eeg_channels, num_samples)
        Returns:
            tensor of shape (batch_size, num_eeg_channels, cnn_features)
        """
        batch_size, _, num_samples = x.shape
        num_nodes = self.num_eeg_channels

        # Only the first num_eeg_channels channels are used. Folding the
        # channel axis into the batch axis lets one Conv1d call process every
        # channel with the same weights (row index = b * num_nodes + channel).
        signals = x[:, :num_nodes, :].reshape(batch_size * num_nodes, 1, num_samples)

        out = F.relu(self.conv1(signals))   # -> (B * N, 8, T')
        out = F.relu(self.conv2(out))       # -> (B * N, cnn_features, T'')
        out = self.pool(out).squeeze(-1)    # -> (B * N, cnn_features)

        return out.reshape(batch_size, num_nodes, out.size(-1))

    @staticmethod
    def _batch_edge_index(
        edge_index: torch.Tensor, batch_size: int, num_nodes: int
    ) -> torch.Tensor:
        """
        Replicate one graph batch_size times as a single disconnected graph.

        Copy b of the graph has its node indices shifted by b * num_nodes, so
        the result can be used with node features flattened to
        (batch_size * num_nodes, feature_dim).

        Args:
            edge_index: (2, E) integer tensor with entries in [0, num_nodes)
            batch_size: number of graph copies
            num_nodes: number of nodes per graph
        Returns:
            (2, batch_size * E) tensor
        Raises:
            ValueError: if edge_index is not (2, E) or references a node
                outside [0, num_nodes).
        """
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError(
                f"edge_index must have shape (2, E), got {tuple(edge_index.shape)}."
            )
        # An out-of-range index would silently connect different samples of
        # the batch once the offsets are applied, so reject it explicitly.
        if edge_index.numel() > 0 and (
            int(edge_index.min()) < 0 or int(edge_index.max()) >= num_nodes
        ):
            raise ValueError(
                f"edge_index refers to nodes outside the range [0, {num_nodes})."
            )

        num_edges = edge_index.size(1)
        offsets = (
            torch.arange(batch_size, device=edge_index.device, dtype=edge_index.dtype)
            * num_nodes
        )
        # (2, 1, E) + (1, B, 1) -> (2, B, E). The target shape is spelled out
        # because reshape(2, -1) is ambiguous for a graph without edges.
        shifted = edge_index.unsqueeze(1) + offsets.view(1, batch_size, 1)
        return shifted.reshape(2, batch_size * num_edges)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor = None):
        """
        Forward pass.

        Args:
            x: EEG input of shape (batch_size, 21, 6300); more generally
               (batch_size, >= num_eeg_channels, num_samples)
            edge_index: (2, E) - optional if not stored in self.edge_index
        Returns:
            logits of shape (batch_size, num_classes)
        Raises:
            ValueError: if no edge_index is available, if x is not 3-D or has
                fewer than num_eeg_channels channels, or if edge_index is
                malformed.
        """
        # If the model has a built-in edge_index, fall back on it if not provided
        if edge_index is None:
            edge_index = self.edge_index
            if edge_index is None:
                raise ValueError(
                    "No edge_index found. Please provide adjacency or edge_index."
                )

        if x.dim() != 3:
            raise ValueError(
                "Expected x of shape (batch_size, num_channels, num_samples), "
                f"got {tuple(x.shape)}."
            )
        num_nodes = self.num_eeg_channels
        if x.size(1) < num_nodes:
            raise ValueError(
                f"Expected at least {num_nodes} EEG channels, got {x.size(1)}."
            )
        batch_size = x.size(0)

        # 1) CNN feature extraction per channel -> (batch_size, N, cnn_features)
        channel_features = self._extract_channel_features(x)

        # 2) GCN for graph-based feature extraction.
        # PyG expects node features as (num_nodes, num_features), so the whole
        # batch is treated as one large graph made of batch_size disconnected
        # copies of the channel graph instead of looping over the samples.
        batched_edges = self._batch_edge_index(edge_index, batch_size, num_nodes)
        nodes = channel_features.reshape(
            batch_size * num_nodes, channel_features.size(-1)
        )

        g = F.relu(self.gcn1(nodes, batched_edges))  # -> (B * N, hidden_dim)
        g = F.relu(self.gcn2(g, batched_edges))      # -> (B * N, hidden_dim)

        # 3) Pool the node embeddings of each sample (mean over its N nodes)
        graph_embeddings = g.reshape(batch_size, num_nodes, g.size(-1)).mean(dim=1)

        # 4) Classification
        logits = self.fc(graph_embeddings)  # -> (batch_size, num_classes)
        return logits


if __name__ == "__main__":
    # -------------------------------------------------------------------------
    # Smoke test: push one random batch through the network
    # -------------------------------------------------------------------------
    n_channels = 21

    # Random 0/1 connectivity between the channels. A real adjacency would be
    # derived from the sensor layout, correlation metrics, etc.
    random_links = torch.randint(0, 2, (n_channels, n_channels)).float()
    # Symmetrise and binarise: an edge exists if either direction was drawn.
    adjacency = ((random_links + random_links.t()) / 2 > 0).float()

    # Build the model with the graph baked in.
    net = EEGEpilepsyNet(adjacency=adjacency, num_classes=2)

    # Fake EEG batch: 4 recordings, 21 channels, 6300 samples per channel.
    eeg_batch = torch.randn(4, n_channels, 6300)

    # Expected output shape: (4, 2)
    logits = net(eeg_batch)
    print("Output logits shape:", logits.shape)
