# In [1]:
import torch
import torch.nn as nn
import math

class ConvEmbed(nn.Module):
    """
    Turn raw multichannel EEG into one embedding vector per time step.

    A single ``nn.Conv1d`` mixes the ``in_channels`` input channels into
    ``embed_dim`` features. The result is rearranged to a time-major layout,
    and ``nn.LayerNorm`` then normalizes each time step's feature vector.

    Args:
        in_channels: Number of input channels (EEG electrodes).
        embed_dim: Size of the produced embedding per time step.
        kernel_size, stride, padding: Forwarded to ``nn.Conv1d``. With the
            defults (3, 1, 1) the output has as many time steps as the input.
    """
    def __init__(self, in_channels=21, embed_dim=64, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # The convolution mixes channels (and neighbouring samples) into features.
        self.conv = nn.Conv1d(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Map ``(batch, in_channels, time)`` to ``(batch, time_out, embed_dim)``.

        ``time_out`` is the Conv1d output length, which equals ``time`` for the
        default kernel_size/stride/padding.
        """
        # Conv1d yields (batch, embed_dim, time_out). Swap the last two axes so
        # LayerNorm normalizes over the feature axis.
        features = self.conv(x).transpose(1, 2)
        return self.norm(features)

class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding as described in "Attention Is All You Need".

    Precomputes a ``(1, max_len, d_model)`` table. Even feature indices hold
    ``sin(pos * 10000^(-2i/d_model))`` and odd indices hold the matching
    cosine. ``forward`` adds the first ``time`` rows of the table to its input.

    Args:
        d_model: Feature dimension of the inputs. Odd values are supported.
        max_len: Longest sequence length (dim 1 of the input) the table covers.
    """
    def __init__(self, d_model, max_len=1000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd column than there are frequencies,
        # so drop the last frequency to make the shapes match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Register the table as a buffer so Module.to()/.cuda() move it with the
        # model, instead of copying it to the device on every forward.
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # still load with strict=True.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, d_model)

    def forward(self, x):
        """
        Add positional encodings to ``x``.

        Args:
            x: Tensor of shape ``(batch, time, d_model)``.

        Returns:
            Tensor of the same shape with the encodings added.

        Raises:
            ValueError: if ``time`` exceeds the ``max_len`` given at construction.
        """
        time_steps = x.size(1)
        max_len = self.pe.size(1)
        if time_steps > max_len:
            raise ValueError(
                f"sequence length {time_steps} exceeds PositionalEncoding max_len {max_len}"
            )
        # .to() is a no-op once the module lives on x's device. It is kept so
        # callers that did not move the module still work.
        return x + self.pe[:, :time_steps, :].to(x.device)

class AttentionPooling(nn.Module):
    """
    Collapse a sequence into a single vector using a learned-query softmax attention.

    Each time step is scored by its (unscaled) dot product with the learnable
    ``query`` vector. The scores are softmax-normalized over the time axis and
    used as weights for a weighted average of the time steps.

    Args:
        d_model: Feature dimension of the inputs. It is also the length of the
            query, which is initialized from a standard normal.
    """
    def __init__(self, d_model):
        super().__init__()
        self.query = nn.Parameter(torch.randn(d_model))

    def forward(self, x):
        """Pool ``x`` of shape ``(batch, time, d_model)`` to ``(batch, d_model)``."""
        # View the query as a (1, d_model, 1) column so matmul broadcasts it over the batch.
        query_col = self.query.view(1, -1, 1)
        logits = (x @ query_col).squeeze(-1)          # (batch, time)
        weights = logits.softmax(dim=1)               # sums to 1 along time
        return (weights.unsqueeze(1) @ x).squeeze(1)  # (batch, d_model)

class EEGTransformerModel(nn.Module):
    """
    Transformer-encoder classifier for multichannel EEG windows.

    Pipeline: ConvEmbed -> sinusoidal PositionalEncoding -> stacked
    TransformerEncoderLayers -> AttentionPooling -> LayerNorm/Dropout/Linear head.

    Args:
        in_channels: Number of EEG input channels.
        seq_len: Length of the positional-encoding table. Embedded sequences
            longer than this cannot be position-encoded.
        embed_dim: Model width. It must be divisible by ``num_heads``, as
            required by ``nn.TransformerEncoderLayer``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Size of the output logit vector.
        dropout: Dropout rate used in the encoder layers and in the head.
    """
    def __init__(self,
                 in_channels=21,
                 seq_len=400,      # length of time sequence
                 embed_dim=64,
                 num_heads=4,
                 num_layers=2,
                 num_classes=3,
                 dropout=0.1):
        super().__init__()

        # Submodules are created in a fixed order so that, for a given seed,
        # random initialization and state_dict keys stay stable.
        self.embedding = ConvEmbed(in_channels, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim, max_len=seq_len)

        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)

        # Learned attention pooling summarizes the whole sequence rather than
        # keeping only one token.
        self.attn_pool = AttentionPooling(embed_dim)

        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        """
        Classify a batch of EEG windows.

        Shape flow:
        (batch, channels, time)
        -> embedding: (batch, time, embed_dim)
        -> positional encoding and transformer encoder: shape unchanged
        -> attention pooling: (batch, embed_dim)
        -> classifier: (batch, num_classes)

        Returns raw logits; no softmax is applied.
        """
        pipeline = (
            self.embedding,
            self.pos_enc,
            self.transformer_encoder,
            self.attn_pool,
            self.classifier,
        )
        for stage in pipeline:
            x = stage(x)
        return x

# Example usage: a shape smoke test on random data.
if __name__ == "__main__":
    # A batch of EEG data: batch=16, channels=21, time=400.
    batch_size = 16
    channels = 21
    seq_len = 400  # e.g., 2 seconds at 200 Hz
    num_classes = 3

    model = EEGTransformerModel(in_channels=channels, seq_len=seq_len, num_classes=num_classes)
    # Inference only: eval() disables dropout, and no_grad() skips building
    # the autograd graph that this demo never uses.
    model.eval()
    dummy_input = torch.randn(batch_size, channels, seq_len)  # (B, C, T)
    with torch.no_grad():
        output = model(dummy_input)
    print("Output shape:", output.shape)  # Expect (16, 3)


# Output shape: torch.Size([16, 3])
