# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# In [2]:
class AttentionBlock(nn.Module):
    """Feature-wise gating: scale each input feature by a learned weight in (0, 1).

    A bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid) maps the input to
    per-feature gates of the same width as the input. The gates are then
    multiplied element-wise with the input.

    Args:
        input_dim: Size of the last dimension of the input.
        reduction: Bottleneck reduction factor. The hidden layer has
            ``max(1, input_dim // reduction)`` units. Defaults to 4.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, input_dim, reduction=4):
        super(AttentionBlock, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp to at least one unit. For input_dim < reduction, plain floor
        # division gives a zero-width layer. The gate then degenerates to
        # sigmoid(bias) and ignores the input entirely.
        bottleneck = max(1, input_dim // reduction)
        # Indices 0 and 2 hold the Linear layers (same layout as before, so
        # existing state_dict keys still load).
        self.attention = nn.Sequential(
            nn.Linear(input_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its gates; output has the same shape as ``x``."""
        attention_weights = self.attention(x)
        return x * attention_weights

class ResidualBlock(nn.Module):
    """Two-layer fully connected residual block with batch norm and dropout.

    Computes ``relu(x + bn2(fc2(dropout(relu(bn1(fc1(x)))))))``. The inner
    path projects back to ``input_dim``, so the skip connection needs no
    projection.

    Args:
        input_dim: Width of the input and output features.
        hidden_dim: Width of the inner layer.
        dropout: Dropout probability applied after the inner activation.
            Defaults to 0.3.

    Note:
        ``nn.BatchNorm1d`` requires more than one sample per batch in
        training mode.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.3):
        super(ResidualBlock, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the residual block; the output has the same shape as ``x``."""
        identity = x
        out = F.relu(self.bn1(self.fc1(x)))
        out = self.dropout(out)
        out = self.bn2(self.fc2(out))
        # The skip connection is added before the final activation.
        out += identity
        return F.relu(out)

class EEGClassifier(nn.Module):
    """Fully connected classifier with feature gating and residual blocks.

    Pipeline:
        Linear(input_size -> 512) + BatchNorm + ReLU
        -> AttentionBlock(512)
        -> 2 x ResidualBlock(512, hidden 256)
        -> AttentionBlock(512)
        -> Linear(512 -> 256) + BN + ReLU + Dropout
        -> Linear(256 -> 128) + BN + ReLU + Dropout
        -> Linear(128 -> num_classes)

    Args:
        input_size: Number of input features per sample.
        num_classes: Number of output classes. Defaults to 4.
        dropout: Dropout probability used in the classification head.
            Defaults to 0.4.

    Note:
        The input is assumed to be 2-D ``(batch, input_size)``. The Linear
        layers act on the last dimension, and the following ``BatchNorm1d``
        layers then treat dim 1 as channels. ``BatchNorm1d`` also needs
        more than one sample per batch in training mode.
    """

    def __init__(self, input_size, num_classes=4, *, dropout=0.4):
        super(EEGClassifier, self).__init__()

        # Initial dimension reduction
        self.fc_input = nn.Linear(input_size, 512)
        self.bn_input = nn.BatchNorm1d(512)

        # Attention mechanism
        self.attention1 = AttentionBlock(512)

        # Residual blocks
        self.res_block1 = ResidualBlock(512, 256)
        self.res_block2 = ResidualBlock(512, 256)

        # Second attention layer
        self.attention2 = AttentionBlock(512)

        # Final classification layers
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, num_classes)

        # A single Dropout module is reused after both hidden head layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)`` (no softmax applied)."""
        # Initial processing
        x = F.relu(self.bn_input(self.fc_input(x)))

        # First attention mechanism
        x = self.attention1(x)

        # Residual blocks
        x = self.res_block1(x)
        x = self.res_block2(x)

        # Second attention mechanism
        x = self.attention2(x)

        # Final classification
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)

        return x
