In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# For the GCN layers
# pip install torch-geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse, to_undirected


# -------------------------------------------------------------------------
# 1) TCN Building Blocks
# -------------------------------------------------------------------------
class TemporalBlock(nn.Module):
    """
    A single TCN residual block with two dilated convolutions and a residual connection.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        dilation,
        padding,
        dropout=0.2
    ):
        super(TemporalBlock, self).__init__()
        # First conv
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second conv
        self.conv2 = nn.Conv1d(
            out_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # Residual connection (1x1 if in/out channels differ)
        self.downsample = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        x: (batch_size, in_channels, seq_len)
        """
        out = self.conv1(x)
        out = self.relu1(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.relu2(out)
        out = self.dropout2(out)

        # Residual
        res = x if self.downsample is None else self.downsample(x)
        out = out + res
        out = self.relu(out)
        return out


class TemporalConvNet(nn.Module):
    """
    A TCN that stacks multiple TemporalBlock layers with exponentially increasing dilations.
    """
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        num_inputs:  number of input channels to the TCN (e.g., 1 if each EEG channel is fed individually)
        num_channels: list with output channels for each TCN layer, e.g. [8, 16]
        kernel_size:  kernel size for all TCN layers
        dropout:      dropout probability
        """
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            # For TCN, padding = (kernel_size - 1) * dilation
            # so that the receptive field grows exponentially
            padding = (kernel_size - 1) * dilation_size

            block = TemporalBlock(
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                dilation=dilation_size,
                padding=padding,
                dropout=dropout,
            )
            layers.append(block)

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: (batch_size, num_inputs, seq_len)
        returns: (batch_size, out_channels, seq_len)  # out_channels is num_channels[-1]
        """
        return self.network(x)


# -------------------------------------------------------------------------
# 2) Hybrid TCN + GCN for EEG
# -------------------------------------------------------------------------
class EEGEpilepsyNet(nn.Module):
    """
    A hybrid TCN + GCN model to extract epileptic features from EEG data.

    Input shape:
      x: (batch_size, 21, 6300)
      edge_index: (2, E) graph connectivity for the 21 channels

    Steps:
      1) For each of the 21 channels:
         - Pass its time-series (6300 samples) through a TCN
         - Obtain a channel-wise feature vector
      2) Stack these feature vectors => (batch_size, 21, feature_dim)
      3) Pass that to a GCN (via edge_index) to learn spatial/functional relationships
      4) Pool or reduce over nodes -> final graph-level embedding
      5) Classify => fully-connected layer
    """

    def __init__(
        self,
        adjacency: torch.Tensor = None,
        in_channels: int = 1,
        tcn_channels: list = [8, 16],
        hidden_dim: int = 32,
        num_classes: int = 2,
        kernel_size: int = 3,
        dropout: float = 0.2
    ):
        """
        Args:
            adjacency: (21 x 21) adjacency matrix for channels, can be None if you're passing edge_index directly.
            in_channels: input channels to TCN (1 if single EEG channel).
            tcn_channels: list, e.g. [8, 16], specifying hidden layer sizes in TCN.
            hidden_dim: hidden dimension in GCN layers.
            num_classes: number of output classes for final classification.
            kernel_size: kernel size for TCN.
            dropout: dropout rate in TCN blocks.
        """
        super(EEGEpilepsyNet, self).__init__()
        self.num_eeg_channels = 21

        # Convert adjacency to edge_index if provided
        if adjacency is not None:
            # Force adjacency to be symmetric
            adjacency = (adjacency + adjacency.t()) / 2
            adjacency[adjacency > 0] = 1
            edge_index, _ = dense_to_sparse(adjacency)
            # Make sure edges are undirected in PyG
            edge_index = to_undirected(edge_index)
        else:
            edge_index = None
        self.register_buffer("edge_index", edge_index)

        # Define a TCN to process each channel's time-series: shape (batch_size, 1, 6300)
        self.tcn = TemporalConvNet(
            num_inputs=in_channels,
            num_channels=tcn_channels,
            kernel_size=kernel_size,
            dropout=dropout
        )
        # We'll do an adaptive pool to get a fixed-size feature vector
        self.pool = nn.AdaptiveAvgPool1d(1)

        # The TCN output channels is the last element in tcn_channels
        self.cnn_features = tcn_channels[-1]

        # GCN layers
        self.gcn1 = GCNConv(self.cnn_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # Final classification
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor = None):
        """
        x: (batch_size, 21, 6300)  -> 21 channels, each with 6300 time points
        edge_index: (2, E) or None. If None, we use self.edge_index
        Returns: (batch_size, num_classes)
        """
        batch_size = x.size(0)

        if edge_index is None:
            edge_index = self.edge_index
            if edge_index is None:
                raise ValueError(
                    "No edge_index found. Please provide adjacency or edge_index."
                )

        # 1) TCN feature extraction per channel
        channel_features = []
        for ch in range(self.num_eeg_channels):
            # x[:, ch, :] -> (batch_size, 6300)
            single_channel = x[:, ch, :].unsqueeze(1)  # (batch_size, 1, 6300)

            # Pass through TCN
            out = self.tcn(single_channel)  # (batch_size, tcn_channels[-1], seq_len)
            # Pool over time dimension
            out = self.pool(out)            # (batch_size, tcn_channels[-1], 1)
            out = out.squeeze(-1)          # (batch_size, tcn_channels[-1])

            channel_features.append(out)

        # Stack channel-wise features => (batch_size, 21, cnn_features)
        channel_features = torch.stack(channel_features, dim=1)

        # 2) GCN processing (one graph per sample in batch)
        graph_embeddings = []
        for b in range(batch_size):
            node_feats = channel_features[b]  # (21, cnn_features)

            # GCN layers
            g = self.gcn1(node_feats, edge_index)  # (21, hidden_dim)
            g = F.relu(g)
            g = self.gcn2(g, edge_index)           # (21, hidden_dim)
            g = F.relu(g)

            # 3) Pool over nodes => single embedding per EEG recording
            g = g.mean(dim=0)  # (hidden_dim,)

            graph_embeddings.append(g)

        # (batch_size, hidden_dim)
        graph_embeddings = torch.stack(graph_embeddings, dim=0)

        # 4) Classification
        logits = self.fc(graph_embeddings)  # (batch_size, num_classes)
        return logits


# -------------------------------------------------------------------------
# 3) Example Usage
# -------------------------------------------------------------------------
if __name__ == "__main__":
    # Suppose we have a random adjacency for 21 channels (21 x 21).
    # In practice, build from sensor layout or correlation metrics, etc.
    adjacency = torch.randint(0, 2, (21, 21)).float()

    # Create the TCN+GCN model
    model = EEGEpilepsyNet(adjacency=adjacency, num_classes=2)

    # Simulate a batch of EEG data: batch_size=4, 21 channels, 6300 samples
    x = torch.randn(4, 21, 6300)

    # Forward pass
    logits = model(x)  # (4, 2)
    print("Output logits shape:", logits.shape)


RuntimeError: The size of tensor a (6304) must match the size of tensor b (6300) at non-singleton dimension 2