# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialAttention(nn.Module):
    """Gate a feature map with a learned sigmoid mask, one weight per time step.

    A pointwise (kernel size 1) convolution reduces the channel axis to a
    single score at every time step. The sigmoid of that score is broadcast
    across all channels and multiplied into the input.
    """

    def __init__(self, in_channels):
        """Build the gate for inputs carrying ``in_channels`` channels."""
        super().__init__()
        # Pointwise conv: combines channels only, never neighbouring steps.
        self.conv = nn.Conv1d(in_channels, out_channels=1, kernel_size=1)

    def forward(self, x):
        """Return ``x`` multiplied by its attention mask.

        ``x`` is laid out as (batch, channels, time); the result has the
        same shape.
        """
        scores = self.conv(x)          # (batch, 1, time)
        gate = torch.sigmoid(scores)   # values in (0, 1)
        return x * gate

class LocalSpikeDetectionBranch(nn.Module):
    """One temporal scale of feature extraction.

    Pipeline: Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(4, stride 4).
    The convolution is unpadded, so the time axis shrinks by
    ``kernel_size - 1`` before being downsampled by the pooling layer.
    """

    def __init__(self, kernel_size, in_channels, out_channels):
        """Create the branch.

        Args:
            kernel_size: Temporal width of the convolution, in samples.
            in_channels: Number of channels in the input signal.
            out_channels: Number of feature maps produced.
        """
        super().__init__()
        # No padding ('valid' convolution): every output sample is computed
        # from real input only, which avoids boundary effects.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            bias=True,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.pool = nn.MaxPool1d(kernel_size=4, stride=4)

    def forward(self, x):
        """Map (batch, in_channels, time) to (batch, out_channels, reduced_time)."""
        convolved = self.conv(x)
        activated = F.relu(self.bn(convolved))
        return self.pool(activated)

class EpilepsyDetectionModel(nn.Module):
    """Multi-scale 1-D CNN classifier for fixed-length multichannel windows.

    Three parallel ``LocalSpikeDetectionBranch`` modules, whose kernels span
    roughly 70 ms, 150 ms and 200 ms at the given sampling rate, process the
    same input. Their outputs are cropped to a common length, concatenated
    along the channel axis, mixed by a 1x1 convolution, gated by
    ``SpatialAttention``, flattened, and classified by a three-layer MLP
    that returns log-probabilities.

    Args:
        num_channels: Number of channels in the input signal.
        sampling_rate: Sampling rate in Hz; determines the kernel sizes.
        num_classes: Number of output classes.
        input_length: Number of time samples per input window. The first
            fully connected layer is sized for exactly this length.
    """

    def __init__(
        self,
        num_channels=32,
        sampling_rate=300,  # Hz
        num_classes=2,
        input_length=1000,
    ):
        super().__init__()

        # Kernel widths in samples for ~70 ms, ~150 ms and ~200 ms events.
        self.kernel_sizes = [
            int(0.07 * sampling_rate),  # 70ms
            int(0.15 * sampling_rate),  # 150ms
            int(0.20 * sampling_rate),  # 200ms
        ]

        # Local Feature Detection Branches
        branch_channels = 128
        self.branches = nn.ModuleList([
            LocalSpikeDetectionBranch(
                kernel_size=k,
                in_channels=num_channels,
                out_channels=branch_channels,
            ) for k in self.kernel_sizes
        ])

        # Spatial Integration. These layers must exist before the
        # feature-size probe below, because the probe runs data through them.
        self.channel_conv = nn.Conv1d(
            branch_channels * len(self.branches), 256, kernel_size=1
        )
        self.spatial_attention = SpatialAttention(256)

        # Size of the flattened feature vector for one input window.
        self.input_length = input_length
        self.feature_size = self._calculate_feature_size(
            input_length, num_channels
        )

        # Classification Head
        self.fc1 = nn.Linear(self.feature_size, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-initialise conv/linear weights; reset norms and biases."""
        if isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

    @staticmethod
    def _center_crop(x, length):
        """Return the central ``length`` samples of ``x`` along the last axis."""
        start = (x.shape[-1] - length) // 2
        return x[..., start:start + length]

    def _extract_features(self, x):
        """Run the branches and spatial integration on (batch, channels, time)."""
        branch_outputs = [branch(x) for branch in self.branches]

        # Unpadded convolutions with different kernel sizes yield different
        # output lengths, which cannot be concatenated as-is. Crop every
        # branch to the shortest one, keeping the centre so the branches
        # stay approximately aligned in time.
        common_length = min(out.shape[-1] for out in branch_outputs)
        aligned = [
            self._center_crop(out, common_length) for out in branch_outputs
        ]

        x = torch.cat(aligned, dim=1)  # (batch, sum of branch channels, time)
        x = self.channel_conv(x)       # (batch, 256, time) in the full model
        return self.spatial_attention(x)

    def _calculate_feature_size(self, input_length, num_channels):
        """Return the flattened feature count for one window of ``input_length``.

        A zero-valued probe is pushed through the feature extractor in eval
        mode and without gradients, so that BatchNorm running statistics and
        the global RNG state are left untouched. The previous training mode
        is restored afterwards.
        """
        reference = next(self.parameters())
        probe = torch.zeros(
            1, num_channels, input_length,
            dtype=reference.dtype, device=reference.device,
        )

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                features = self._extract_features(probe)
        finally:
            self.train(was_training)

        return features.numel() // features.shape[0]

    def forward(self, x):
        """Classify a batch of windows.

        Args:
            x: Tensor of shape (batch, channels, time).

        Returns:
            Log-probabilities of shape (batch, num_classes).

        Raises:
            ValueError: If the input's time length yields a feature vector
                of a different size than the classifier was built for.
        """
        x = self._extract_features(x)

        # Flatten for classification
        x = x.flatten(start_dim=1)
        if x.shape[1] != self.feature_size:
            raise ValueError(
                f"flattened feature size {x.shape[1]} does not match the "
                f"{self.feature_size} features the classifier was built for; "
                "check the time length of the input"
            )

        # Classification
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=1)

# Training utilities
class EpilepsyDataset(torch.utils.data.Dataset):
    """In-memory dataset of (sample, label) pairs with optional augmentation.

    Args:
        data: Samples indexed along their first axis. May be a nested list,
            a NumPy array or a tensor of any numeric dtype; stored as
            float32.
        labels: Integer class labels, one per sample, in any of the same
            container types; stored as int64.
        transform: Optional callable applied to each sample when it is
            fetched.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None):
        # torch.as_tensor converts from lists, arrays and tensors of any
        # dtype. The legacy FloatTensor/LongTensor constructors reject
        # tensors of a different dtype and treat a bare int as a size.
        # NOTE: when the input already has the target dtype the result may
        # share memory with it.
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.int64)

        # Fail at construction rather than with an IndexError mid-epoch.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data and labels must have the same length, got "
                f"{len(self.data)} and {len(self.labels)}"
            )

        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at ``idx``, transformed if set."""
        x = self.data[idx]
        y = self.labels[idx]

        if self.transform is not None:
            x = self.transform(x)

        return x, y

class SpikeAugmentation:
    """Random augmentation applied to one sample at a time.

    Three steps run in order: additive Gaussian noise, a circular shift
    along the last axis, and a single random amplitude factor applied to
    the whole sample.
    """

    def __init__(self, sigma=0.01, time_shift=15, scale_range=(0.9, 1.1)):
        """Store the noise level, maximum shift and scaling interval."""
        self.sigma, self.time_shift, self.scale_range = (
            sigma,
            time_shift,
            scale_range,
        )

    def __call__(self, x):
        """Return an augmented copy of ``x``."""
        return self._rescale(self._shift(self._add_noise(x)))

    def _add_noise(self, x):
        """Add zero-mean Gaussian noise with standard deviation ``sigma``."""
        noise = torch.randn_like(x)
        return x + noise * self.sigma

    def _shift(self, x):
        """Roll ``x`` along its last axis by a random offset in [-time_shift, time_shift]."""
        if not self.time_shift > 0:
            return x
        # randint's upper bound is exclusive, hence the +1.
        offset = torch.randint(-self.time_shift, self.time_shift + 1, (1,))
        return torch.roll(x, shifts=offset.item(), dims=-1)

    def _rescale(self, x):
        """Multiply ``x`` by one factor drawn uniformly from ``scale_range``."""
        low = self.scale_range[0]
        high = self.scale_range[1]
        # A degenerate or inverted interval disables scaling.
        if not low < high:
            return x
        factor = torch.empty(1).uniform_(*self.scale_range)
        return x * factor

# Example usage:
def train_model(data=None, labels=None, num_epochs=100, batch_size=32):
    """Train an ``EpilepsyDetectionModel`` and return it.

    Args:
        data: Training samples. When omitted, the module-level name
            ``train_data`` is used and must be defined before the call.
        labels: Training labels. When omitted, the module-level name
            ``train_labels`` is used and must be defined before the call.
        num_epochs: Number of passes over the training set; also the period
            of the cosine learning-rate schedule.
        batch_size: Mini-batch size.

    Returns:
        The trained model, left in training mode on the selected device.
    """
    # Fall back to module-level data so a bare ``train_model()`` call keeps
    # working as before.
    if data is None:
        data = train_data
    if labels is None:
        labels = train_labels

    # Hyperparameters
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = EpilepsyDetectionModel().to(device)
    # The model emits log-probabilities, so NLLLoss is the matching criterion.
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs
    )

    # Create datasets and dataloaders
    transform = SpikeAugmentation()
    train_dataset = EpilepsyDataset(data, labels, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, targets)

            loss.backward()
            optimizer.step()

        # One scheduler step per epoch, after the optimizer updates.
        scheduler.step()

    return model