In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyG imports for graph layers
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse, to_undirected

###############################################################################
# 1) TCN Building Blocks
###############################################################################
class TemporalBlock(nn.Module):
    """
    One residual TCN block: two dilated Conv1d layers, each followed by ReLU
    and dropout, plus a skip connection and a final ReLU.

    Both convolutions pad each side of the sequence by ``padding``. The block
    output is cut back to the input's time length by keeping the *leading*
    frames. With ``stride=1`` and ``padding = (kernel_size - 1) * dilation``,
    this makes the block causal (output frame t only sees input frames <= t)
    and length-preserving. When ``in_channels != out_channels``, a 1x1 Conv1d
    projects the skip path to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

        # Sub-modules are created in a fixed order so that parameter
        # initialisation under a given RNG seed stays reproducible.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, **conv_kwargs)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        needs_projection = in_channels != out_channels
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if needs_projection else None
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Args:
          x: (batch_size, in_channels, seq_len)
        Returns:
          (batch_size, out_channels, T), where T is the shortest of seq_len,
          the conv-branch length and the skip-path length.
        """
        branch = self.dropout1(self.relu1(self.conv1(x)))
        branch = self.dropout2(self.relu2(self.conv2(branch)))

        skip = x if self.downsample is None else self.downsample(x)

        # The padded convolutions produce extra trailing frames; keep only the
        # common leading prefix of both paths so they can be summed.
        keep = min(branch.size(2), x.size(2), skip.size(2))
        return self.relu(branch[:, :, :keep] + skip[:, :, :keep])


class TemporalConvNet(nn.Module):
    """
    Stack of TemporalBlock layers whose dilation doubles at every level
    (1, 2, 4, ...), so the receptive field grows exponentially with depth.

    Every level uses stride 1 and padding ``(kernel_size - 1) * dilation``.
    TemporalBlock keeps only the leading ``seq_len`` frames of its output, so
    each level is causal and length-preserving for any kernel_size (the
    default is 2).
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        Args:
          num_inputs:   channels of the input signal (e.g. 1 for a single EEG channel)
          num_channels: sequence of output channels, one entry per block, e.g. [8, 16]
          kernel_size:  convolution kernel width shared by all blocks
          dropout:      dropout probability used inside every block
        """
        super().__init__()
        # widths[i] -> widths[i + 1] is the channel mapping of block i.
        widths = [num_inputs, *num_channels]
        self.network = nn.Sequential(*[
            TemporalBlock(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel_size,
                stride=1,
                dilation=2 ** level,
                padding=(kernel_size - 1) * 2 ** level,
                dropout=dropout,
            )
            for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
        ])

    def forward(self, x):
        """
        Args:
          x: (batch_size, num_inputs, seq_len)
        Returns:
          (batch_size, num_channels[-1], seq_len)
        """
        features = self.network(x)
        return features


###############################################################################
# 2) Hybrid TCN + GCN for EEG
###############################################################################
class EEGEpilepsyNet(nn.Module):
    """
    Hybrid TCN + GCN classifier for multi-channel EEG.

    Pipeline:
      1. Every EEG channel is fed independently, as a 1-channel signal, through
         one shared TCN, then average-pooled over time into a feature vector.
      2. The 21 channel vectors are the nodes of a graph (edges from
         ``adjacency`` or ``edge_index``), processed by two GCNConv layers
         with ReLU.
      3. Node embeddings are mean-pooled per sample and mapped to logits by a
         linear layer.

    Input shape:
      x: (batch_size, 21, seq_len)  e.g. seq_len=6300
    Output shape:
      logits: (batch_size, num_classes)
    """

    def __init__(
        self,
        adjacency=None,         # (21 x 21) adjacency matrix for EEG channels (tensor / array-like)
        in_channels=1,          # TCN input channels; forward feeds 1-channel signals, so keep at 1
        tcn_channels=(8, 16),   # TCN hidden channels (tuple: avoids a shared mutable default)
        hidden_dim=32,          # GCN hidden dim
        num_classes=2,          # classification classes
        kernel_size=2,
        dropout=0.2
    ):
        super().__init__()
        self.num_eeg_channels = 21

        # Convert adjacency -> edge_index for PyG if provided.
        if adjacency is not None:
            edge_index, _ = dense_to_sparse(self._binarize_adjacency(adjacency))
            edge_index = to_undirected(edge_index)
        else:
            edge_index = None
        # Registered as a buffer so it follows the model across .to(device).
        self.register_buffer("edge_index", edge_index)

        # Shared TCN, applied to every channel as a (N, 1, seq_len) signal.
        self.tcn = TemporalConvNet(
            num_inputs=in_channels,
            num_channels=tcn_channels,
            kernel_size=kernel_size,
            dropout=dropout
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.cnn_features = tcn_channels[-1]  # last channel dimension from TCN

        # GCN layers
        self.gcn1 = GCNConv(self.cnn_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # Final classifier
        self.fc = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _binarize_adjacency(adjacency):
        """Return a symmetric 0/1 float mask: 1 where (A + A^T) / 2 > 0.

        Only strictly positive symmetric entries become edges. Zero and
        negative entries are dropped, so e.g. negative correlations are not
        silently turned into edges by ``dense_to_sparse``.

        Raises:
          ValueError: if ``adjacency`` is not a square 2-D matrix.
        """
        adj = torch.as_tensor(adjacency)
        if adj.dim() != 2 or adj.size(0) != adj.size(1):
            raise ValueError(
                f"adjacency must be a square 2-D matrix, got shape {tuple(adj.shape)}."
            )
        if not adj.is_floating_point():
            adj = adj.float()
        symmetric = (adj + adj.transpose(0, 1)) / 2
        return (symmetric > 0).float()

    @staticmethod
    def _batch_edge_index(edge_index, num_graphs, num_nodes):
        """Replicate a (2, E) ``edge_index`` for ``num_graphs`` disjoint graph copies.

        Copy i has its node ids shifted by ``i * num_nodes``. This matches the
        row order of a (num_graphs * num_nodes, F) node-feature matrix.
        Returns a (2, num_graphs * E) tensor.
        """
        offsets = torch.arange(
            num_graphs, device=edge_index.device, dtype=edge_index.dtype
        ) * num_nodes
        return (edge_index.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1)

    def forward(self, x, edge_index=None):
        """
        Args:
          x:          (batch_size, 21, seq_len) EEG signals
          edge_index: (2, E) edges over the 21 channel nodes, or None to use
                      the one built from ``adjacency`` at construction time
        Returns:
          (batch_size, num_classes) logits
        Raises:
          ValueError: if no graph is available, or if ``x`` does not have
                      shape (batch_size, 21, seq_len).
        """
        if edge_index is None:
            edge_index = self.edge_index
            if edge_index is None:
                raise ValueError("No adjacency or edge_index provided.")

        if x.dim() != 3 or x.size(1) != self.num_eeg_channels:
            raise ValueError(
                f"Expected x of shape (batch_size, {self.num_eeg_channels}, seq_len), "
                f"got {tuple(x.shape)}."
            )
        bsz, num_nodes, seq_len = x.shape

        # 1) Shared TCN over all channels in one call: fold channels into the
        #    batch dimension. Row b * num_nodes + c is channel c of sample b.
        per_channel = x.reshape(bsz * num_nodes, 1, seq_len)
        node_feats = self.pool(self.tcn(per_channel)).squeeze(-1)  # (bsz*21, tcn_channels[-1])

        # 2) GCN over the disjoint union of the bsz sample graphs. The graphs
        #    share no edges, so degree normalisation is the same as running
        #    each sample separately.
        batched_edges = self._batch_edge_index(edge_index, bsz, num_nodes)
        g = F.relu(self.gcn1(node_feats, batched_edges))  # (bsz*21, hidden_dim)
        g = F.relu(self.gcn2(g, batched_edges))           # (bsz*21, hidden_dim)

        # Mean pooling over each sample's nodes => (bsz, hidden_dim)
        graph_embs = g.reshape(bsz, num_nodes, -1).mean(dim=1)
        return self.fc(graph_embs)


###############################################################################
# 3) Example Usage
###############################################################################
if __name__ == "__main__":
    # Fixed seed so the random adjacency, weights and input are reproducible.
    torch.manual_seed(0)

    # Use the GPU when available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random 0/1 adjacency for 21 channels (the model symmetrises it).
    adjacency = torch.randint(0, 2, (21, 21), device=device).float()

    # Create the model on the selected device. The TCN blocks trim their
    # outputs back to the input length, so kernel_size does not change seq_len.
    model = EEGEpilepsyNet(
        adjacency=adjacency,
        num_classes=2,
        kernel_size=2,
        tcn_channels=[8, 16]
    ).to(device)
    # Inference demo: disable dropout.
    model.eval()

    # Dummy EEG input: (batch_size=4, 21 channels, 6300 timesteps)
    x = torch.randn(4, 21, 6300, device=device)

    # Forward pass without building an autograd graph.
    with torch.no_grad():
        logits = model(x)  # => (4, 2)
    print("Output logits shape:", logits.shape)


Output logits shape: torch.Size([4, 2])
