In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score
from scipy import signal
import os
import warnings
warnings.filterwarnings('ignore')

# Seed every RNG used below (weight init, augmentation, mock data) so that
# Restart & Run All reproduces the same results, up to nondeterministic GPU kernels.
SEED = 42
torch.manual_seed(SEED)  # also seeds all CUDA devices
np.random.seed(SEED)

# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

# Configuration
class Config:
    """Hyper-parameters shared by every cell of this notebook.

    All values are plain class attributes; ``config`` (created right below)
    is the single instance the rest of the notebook reads from.
    """

    # --- Data ---------------------------------------------------------------
    sampling_rate = 256                  # Hz
    window_size = 10 * sampling_rate     # samples per window (10 s)
    stride = sampling_rate               # hop between windows in samples (1 s)
    num_channels = 18                    # typical CHB-MIT channel count

    # --- Model architecture ---------------------------------------------------
    encoder_dim = 64
    ltc_neurons1 = 64
    ltc_neurons2 = 48
    attention_heads = 2
    attention_dim = 32

    # --- Optimisation ---------------------------------------------------------
    batch_size = 32
    learning_rate = 1e-3
    weight_decay = 1e-4
    epochs = 100
    dropout_rate = 0.1

    # --- LTC time-constant range, in seconds ----------------------------------
    min_time_constant = 0.01             # 10 ms
    max_time_constant = 10.0             # 10 s

    # --- Loss weights ---------------------------------------------------------
    lambda_pred = 0.7
    mu_reg = 1e-5
    gamma_lead = 2.0

    # --- Labelling ------------------------------------------------------------
    preictal_window = 5 * 60             # seconds before onset counted as pre-ictal


config = Config()

Using device: cpu


In [7]:
class SpectralFrontend(nn.Module):
    """Learnable per-channel filterbank front-end.

    Each input channel gets its own ``Conv1d -> BatchNorm1d -> GELU`` stack that
    maps 1 input channel to ``out_channels`` learned filter responses while
    downsampling in time by ``stride``. The per-channel outputs are concatenated
    along the feature axis, so the output has ``num_channels * out_channels``
    features.

    Args:
        num_channels: number of input channels; ``forward`` rejects inputs with
            a different channel count.
        out_channels: learned filters per input channel.
        kernel_size: filter length in samples (default 128).
        stride: temporal downsampling factor (default 32).
        padding: zero padding on each side; defaults to ``kernel_size // 2``.
    """
    def __init__(self, num_channels, out_channels=16, kernel_size=128, stride=32, padding=None):
        super(SpectralFrontend, self).__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = kernel_size // 2 if padding is None else padding

        # One independent filterbank per channel (a depthwise conv with channel
        # multiplier ``out_channels``). It is kept as a ModuleList so the
        # state_dict keys stay ``filterbanks.<i>.<j>.*``.
        self.filterbanks = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(1, out_channels, kernel_size=kernel_size, stride=stride,
                          padding=self.padding, bias=False),
                nn.BatchNorm1d(out_channels),
                nn.GELU(),
            )
            for _ in range(num_channels)
        ])

    def forward(self, x):
        """Apply every channel's filterbank and concatenate the results.

        Args:
            x: tensor of shape (batch, num_channels, time).

        Returns:
            Tensor of shape (batch, num_channels * out_channels, time'), where
            time' = (time + 2 * padding - kernel_size) // stride + 1.

        Raises:
            ValueError: if ``x`` is not 3-D or has the wrong channel count.
                Without this check, extra channels would hit an IndexError and
                missing channels would silently shrink the output.
        """
        if x.dim() != 3 or x.size(1) != self.num_channels:
            raise ValueError(
                f"SpectralFrontend expects input of shape (batch, {self.num_channels}, time); "
                f"got {tuple(x.shape)}"
            )
        return torch.cat(
            [bank(x[:, i:i + 1, :]) for i, bank in enumerate(self.filterbanks)],
            dim=1,
        )

# Test spectral frontend
def test_spectral_frontend():
    """Smoke-test SpectralFrontend on `device`: print input/output shapes and return the module."""
    module = SpectralFrontend(config.num_channels).to(device)
    dummy = torch.randn(2, config.num_channels, config.window_size).to(device)
    result = module(dummy)
    print(f"Spectral frontend: input {dummy.shape} -> output {result.shape}")
    return module

frontend = test_spectral_frontend()

Spectral frontend: input torch.Size([2, 18, 2560]) -> output torch.Size([2, 288, 81])


In [8]:
class LTCLayer(nn.Module):
    """Liquid Time-Constant style continuous-time recurrent layer.

    Each hidden neuron has a learnable time constant tau in
    [min_time_constant, max_time_constant] seconds. The state is advanced with

        alpha = exp(-dt / tau)
        h_t   = LayerNorm(alpha * h_{t-1}
                          + (1 - alpha) * tanh(W_x x_t + W_h h_{t-1} + b))

    The recurrent matrix W_h is sparse. A random ``sparsity`` fraction of its
    entries is zero at init and is held at zero during training by multiplying
    with the fixed binary buffer ``W_h_mask``.

    Args:
        input_dim: feature size of the input sequence.
        hidden_dim: number of neurons.
        sparsity: fraction of recurrent connections that are disabled.
        min_time_constant, max_time_constant: tau range in seconds. When None,
            ``config.min_time_constant`` / ``config.max_time_constant`` are read
            at call time.
        dt: integration step in seconds between consecutive inputs. When None,
            ``1 / (config.sampling_rate / 32)`` is used, i.e. the frame period
            after the stride-32 spectral front-end.
    """
    def __init__(self, input_dim, hidden_dim, sparsity=0.2,
                 min_time_constant=None, max_time_constant=None, dt=None):
        super(LTCLayer, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.min_time_constant = min_time_constant
        self.max_time_constant = max_time_constant
        self.dt = dt

        # Unconstrained per-neuron parameter. Despite the name it is not a log:
        # get_time_constants() maps it into the tau range through a sigmoid.
        self.log_tau = nn.Parameter(torch.randn(hidden_dim) * 0.1)

        # Input weights
        self.W_x = nn.Linear(input_dim, hidden_dim, bias=False)

        # Sparse recurrent weights plus their fixed connectivity mask. Without
        # the mask, the optimizer would update the zero entries on the first
        # step and the layer would no longer be sparse.
        self.W_h = self._create_sparse_weights(hidden_dim, hidden_dim, sparsity)
        self.register_buffer('W_h_mask', (self.W_h.detach() != 0).to(self.W_h.dtype))

        # Bias
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

        # Layer normalization applied to the state after every step
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def _create_sparse_weights(self, in_dim, out_dim, sparsity):
        """Return an (out_dim, in_dim) Parameter in which int(in*out*(1-sparsity))
        randomly chosen entries are N(0, 0.1^2) and the rest are zero."""
        weight = torch.zeros(out_dim, in_dim)
        num_nonzero = int(in_dim * out_dim * (1 - sparsity))
        indices = torch.randperm(in_dim * out_dim)[:num_nonzero]
        weight.view(-1)[indices] = torch.randn(num_nonzero) * 0.1
        return nn.Parameter(weight)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints saved before the mask buffer existed lack 'W_h_mask'.
        # Derive it from the stored weights so they still load with strict=True.
        # load_state_dict() passes a copy, so adding the key here is safe.
        mask_key = prefix + 'W_h_mask'
        weight_key = prefix + 'W_h'
        if mask_key not in state_dict and weight_key in state_dict:
            stored = state_dict[weight_key]
            state_dict[mask_key] = (stored != 0).to(stored.dtype)
        super(LTCLayer, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def get_time_constants(self):
        """Return the per-neuron time constants in seconds, shape (hidden_dim,)."""
        tau_min = config.min_time_constant if self.min_time_constant is None else self.min_time_constant
        tau_max = config.max_time_constant if self.max_time_constant is None else self.max_time_constant
        return torch.sigmoid(self.log_tau) * (tau_max - tau_min) + tau_min

    def forward(self, x, initial_state=None):
        """Run the layer over a whole sequence.

        Args:
            x: input of shape (batch, seq_len, input_dim).
            initial_state: optional (batch, hidden_dim) starting state; zeros
                when omitted.

        Returns:
            (hidden_sequence, time_constant_sequence), both of shape
            (batch, seq_len, hidden_dim). The second is the per-neuron tau
            broadcast over batch and time; it does not depend on ``x``.
        """
        batch_size, seq_len, _ = x.shape

        if initial_state is None:
            h = x.new_zeros(batch_size, self.hidden_dim)
        else:
            h = initial_state

        tau = self.get_time_constants()  # (hidden_dim,)
        # Default step accounts for the stride-32 downsampling in the front-end.
        dt = 1.0 / (config.sampling_rate / 32) if self.dt is None else self.dt
        alpha = torch.exp(-dt / tau)  # (hidden_dim,)

        # Loop-invariant work hoisted out of the time loop: the input projection
        # for all steps in one matmul, and the masked recurrent matrix.
        input_proj = self.W_x(x)  # (batch, seq_len, hidden_dim)
        recurrent_weight = (self.W_h * self.W_h_mask).t()

        hidden_states = []
        for t in range(seq_len):
            drive = input_proj[:, t, :] + h @ recurrent_weight + self.bias
            h = alpha * h + (1 - alpha) * torch.tanh(drive)
            h = self.layer_norm(h)
            hidden_states.append(h)

        if hidden_states:
            hidden_sequence = torch.stack(hidden_states, dim=1)
        else:
            # torch.stack/cat reject an empty list; return an empty sequence instead.
            hidden_sequence = x.new_zeros(batch_size, 0, self.hidden_dim)
        time_constant_sequence = tau.expand(batch_size, seq_len, -1).contiguous()

        return hidden_sequence, time_constant_sequence

# Test LTC layer
def test_ltc_layer():
    """Smoke-test LTCLayer on `device` and print its output and time-constant shapes.

    The input width is ``config.encoder_dim``, the width the full model feeds
    into its first LTC layer, rather than a hard-coded 64.
    """
    ltc = LTCLayer(config.encoder_dim, config.ltc_neurons1).to(device)
    x = torch.randn(2, 80, config.encoder_dim).to(device)  # (batch, time, features)
    hidden_seq, tau_seq = ltc(x)
    print(f"LTC layer: input {x.shape} -> output {hidden_seq.shape}")
    print(f"Time constants shape: {tau_seq.shape}")
    print(f"Mean time constant: {ltc.get_time_constants().mean().item():.3f}s")
    return ltc

ltc_layer = test_ltc_layer()

LTC layer: input torch.Size([2, 80, 64]) -> output torch.Size([2, 80, 64])
Time constants shape: torch.Size([2, 80, 64])
Mean time constant: 5.062s


In [9]:
class TemporalAttentionGate(nn.Module):
    """Lightweight temporal self-attention followed by a feature (channel) gate.

    1. Multi-head self-attention over the time axis, with a residual connection
       and LayerNorm.
    2. A squeeze-and-excitation style gate. The time-averaged features pass
       through a bottleneck MLP ending in a sigmoid, and the resulting
       per-feature weights in (0, 1) rescale every time step. A second
       LayerNorm follows.

    Args:
        input_dim: feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attention_dim: stored as an attribute but not used by any layer (the
            attention operates at ``input_dim``); kept for API compatibility.
        dropout: attention dropout probability. When None,
            ``config.dropout_rate`` is used, as before.
    """
    def __init__(self, input_dim, num_heads=2, attention_dim=32, dropout=None):
        super(TemporalAttentionGate, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.attention_dim = attention_dim
        if dropout is None:
            dropout = config.dropout_rate

        # Multi-head attention for temporal patterns
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Channel-frequency gating (bottleneck of width input_dim // 2)
        self.channel_gate = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2),
            nn.GELU(),
            nn.Linear(input_dim // 2, input_dim),
            nn.Sigmoid()
        )

        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.layer_norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Attend over time, then gate features.

        Args:
            x: tensor of shape (batch, time, input_dim).

        Returns:
            (output, attention_weights, channel_weights):
            output (batch, time, input_dim); attention_weights (batch, time, time),
            averaged over heads by nn.MultiheadAttention; channel_weights
            (batch, 1, input_dim) in (0, 1).
        """
        # Temporal attention with residual connection
        attended, attention_weights = self.temporal_attention(x, x, x)
        x = self.layer_norm1(x + attended)

        # Channel gating from the time-averaged features
        channel_weights = self.channel_gate(x.mean(dim=1))  # (batch, features)
        channel_weights = channel_weights.unsqueeze(1)  # (batch, 1, features)
        x = x * channel_weights

        x = self.layer_norm2(x)
        return x, attention_weights, channel_weights

# Test attention gate
def test_attention_gate():
    """Smoke-test TemporalAttentionGate on `device` and report the shapes it produces."""
    gate = TemporalAttentionGate(config.ltc_neurons2).to(device)
    sample = torch.randn(2, 80, config.ltc_neurons2).to(device)
    gated, attn_map, gate_weights = gate(sample)
    print(f"Attention gate: input {sample.shape} -> output {gated.shape}")
    print(f"Attention weights: {attn_map.shape}")
    print(f"Channel weights: {gate_weights.shape}")
    return gate

attention_gate = test_attention_gate()

Attention gate: input torch.Size([2, 80, 48]) -> output torch.Size([2, 80, 48])
Attention weights: torch.Size([2, 80, 80])
Channel weights: torch.Size([2, 1, 48])


In [10]:
class LightLTCSeizNet(nn.Module):
    """LTC-based EEG seizure detection / prediction network.

    Pipeline: SpectralFrontend -> convolutional channel encoder -> two stacked
    LTC layers (plus a linear residual from the encoder output) ->
    TemporalAttentionGate -> concatenated average/max pooling over time ->
    two sigmoid heads (detection and pre-ictal prediction).

    Args:
        config: object exposing ``num_channels``, ``encoder_dim``,
            ``ltc_neurons1``, ``ltc_neurons2`` and ``dropout_rate``.
    """
    def __init__(self, config):
        super(LightLTCSeizNet, self).__init__()
        self.config = config

        # Spectral front-end. The encoder input width is read from the
        # front-end itself rather than repeating its default of 16 filters per
        # channel, so the two cannot drift apart.
        self.spectral_frontend = SpectralFrontend(config.num_channels)
        frontend_output_dim = self.spectral_frontend.num_channels * self.spectral_frontend.out_channels

        # Channel encoder: a full (non-grouped) conv mixes all channel/filter maps.
        self.channel_encoder = nn.Sequential(
            nn.Conv1d(frontend_output_dim, config.encoder_dim, kernel_size=7,
                      stride=1, padding=3),
            nn.BatchNorm1d(config.encoder_dim),
            nn.GELU(),
            nn.Conv1d(config.encoder_dim, config.encoder_dim, kernel_size=1),
            nn.BatchNorm1d(config.encoder_dim),
            nn.GELU(),
        )

        # LTC blocks
        self.ltc1 = LTCLayer(config.encoder_dim, config.ltc_neurons1)
        self.ltc2 = LTCLayer(config.ltc_neurons1, config.ltc_neurons2)

        # Residual connection from the encoder output to the second LTC output
        self.residual_proj = nn.Linear(config.encoder_dim, config.ltc_neurons2)

        # Attention gate
        self.attention_gate = TemporalAttentionGate(config.ltc_neurons2)

        # Readout pooling over time (average and max are concatenated)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # Detection head (instant seizure); outputs a probability
        self.detect_head = nn.Sequential(
            nn.Linear(config.ltc_neurons2 * 2, 64),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Prediction head (pre-ictal); outputs a probability
        self.pred_head = nn.Sequential(
            nn.Linear(config.ltc_neurons2 * 2, 64),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self._initialize_weights()

        print(f"Total parameters: {sum(p.numel() for p in self.parameters()):,}")

    def _initialize_weights(self):
        """Re-initialize every Conv1d, BatchNorm1d and Linear submodule.

        Parameters that are not held by these module types (e.g. the LTC
        recurrent matrices and time constants) keep their own initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the full model on a batch of EEG windows.

        Args:
            x: tensor of shape (batch, num_channels, time).

        Returns:
            dict with:
              'detection', 'prediction': (batch, 1) sigmoid probabilities;
              'attention_weights': temporal attention map from the gate;
              'channel_weights': (batch, 1, ltc_neurons2) feature gate;
              'time_constants1', 'time_constants2': per-step tau from each LTC layer.
        """
        # Spectral front-end
        x_spectral = self.spectral_frontend(x)  # (batch, channels*out_channels, time')

        # Channel encoder
        x_encoded = self.channel_encoder(x_spectral)  # (batch, encoder_dim, time')
        x_encoded = x_encoded.transpose(1, 2)  # (batch, time', encoder_dim)

        # LTC blocks
        h1, tau1 = self.ltc1(x_encoded)
        h2, tau2 = self.ltc2(h1)

        # Residual connection (the slice is a safeguard; the LTC layers preserve length)
        residual = self.residual_proj(x_encoded)
        residual = residual[:, :h2.size(1), :]
        h2 = h2 + residual

        # Attention gate
        h_attended, attention_weights, channel_weights = self.attention_gate(h2)

        # Pooling over time
        avg_pool = self.global_pool(h_attended.transpose(1, 2)).squeeze(-1)
        max_pool = self.global_max_pool(h_attended.transpose(1, 2)).squeeze(-1)
        pooled = torch.cat([avg_pool, max_pool], dim=1)

        # Readout
        p_detect = self.detect_head(pooled)
        p_pred = self.pred_head(pooled)

        return {
            'detection': p_detect,
            'prediction': p_pred,
            'attention_weights': attention_weights,
            'channel_weights': channel_weights,
            'time_constants1': tau1,
            'time_constants2': tau2
        }

# Test complete model
def test_complete_model():
    """Build LightLTCSeizNet on `device`, push one random batch through it, and print output shapes."""
    net = LightLTCSeizNet(config).to(device)
    batch = torch.randn(2, config.num_channels, config.window_size).to(device)
    result = net(batch)

    print(f"Model outputs:")
    print(f"  Detection: {result['detection'].shape}")
    print(f"  Prediction: {result['prediction'].shape}")
    print(f"  Time constants 1: {result['time_constants1'].shape}")
    print(f"  Time constants 2: {result['time_constants2'].shape}")

    return net

model = test_complete_model()

Total parameters: 212,602
Model outputs:
  Detection: torch.Size([2, 1])
  Prediction: torch.Size([2, 1])
  Time constants 1: torch.Size([2, 81, 64])
  Time constants 2: torch.Size([2, 81, 48])


In [11]:
import mne
import glob
from torch.utils.data import Dataset, DataLoader

class CHBMITDataset(Dataset):
    """Windowed CHB-MIT Scalp EEG dataset with detection/prediction labels.

    Every EDF recording of each selected subject is read with MNE, reduced to
    its first ``config.num_channels`` EEG channels, z-scored per channel, and
    cut into windows of ``window_size`` samples every ``stride`` samples. Each
    window is labelled from the subject's ``<subject>-summary.txt`` using the
    window's centre time:

    * detection label 1: the centre lies inside a seizure;
    * prediction label 1: the centre lies within ``preictal_window`` seconds
      before a seizure onset;
    * otherwise both labels are 0 (inter-ictal).

    Args:
        data_path: directory containing ``chb-mit-scalp-eeg-database-1.0.0/``,
            or a directory containing the ``chbXX`` subject folders directly.
        window_size, stride: window length and hop, in samples.
        mode: ``'train'`` enables on-the-fly augmentation in ``__getitem__``.
        subjects: optional whitelist of subject ids (e.g. ``'chb01'``).
        preictal_window: pre-ictal horizon in seconds.
        sampling_rate: target rate in Hz; recordings at another rate are resampled.
    """

    # Folder probed under ``data_path`` before falling back to ``data_path`` itself.
    DATASET_DIRNAME = 'chb-mit-scalp-eeg-database-1.0.0'

    def __init__(self, data_path, window_size=2560, stride=256, mode='train',
                 subjects=None, preictal_window=300, sampling_rate=256):
        self.data_path = data_path
        self.window_size = window_size
        self.stride = stride
        self.mode = mode
        self.preictal_window = preictal_window
        self.sampling_rate = sampling_rate

        # Discover available subjects, optionally restricted to a whitelist
        self.subjects = self._discover_subjects()
        if subjects:
            self.subjects = [s for s in self.subjects if s in subjects]

        # Eagerly load every recording into memory as labelled windows
        self.samples = self._load_all_data()

        print(f"Loaded {len(self.samples)} samples for {mode} mode across {len(self.subjects)} subjects")

    def _resolve_root(self):
        """Return the directory that directly contains the chbXX subject folders.

        Discovery and loading both go through this method, so they always agree
        on the layout (nested PhysioNet folder or flat).
        """
        nested = os.path.join(self.data_path, self.DATASET_DIRNAME)
        return nested if os.path.isdir(nested) else self.data_path

    def _discover_subjects(self):
        """Return the sorted names of ``chb*`` sub-directories under the data root."""
        root = self._resolve_root()
        return sorted(
            name for name in os.listdir(root)
            if name.startswith('chb') and os.path.isdir(os.path.join(root, name))
        )

    def _load_all_data(self):
        """Load every EDF file of the selected subjects and return labelled windows.

        Loading is best effort: a file that raises while being read is reported
        and skipped, so one bad recording does not abort the whole dataset.
        """
        all_samples = []
        root = self._resolve_root()

        for subject in self.subjects:
            subject_dir = os.path.join(root, subject)
            if not os.path.isdir(subject_dir):
                print(f"Subject directory not found: {subject_dir}")
                continue

            # Sorted so the sample order is reproducible (glob order is arbitrary)
            edf_files = sorted(glob.glob(os.path.join(subject_dir, "*.edf")))
            summary_file = os.path.join(subject_dir, f"{subject}-summary.txt")
            seizure_times = self._load_seizure_annotations(summary_file, subject)

            for edf_file in edf_files:
                try:
                    all_samples.extend(self._load_file_windows(edf_file, seizure_times, subject))
                except Exception as e:
                    print(f"Error loading {edf_file}: {e}")

        return all_samples

    def _load_file_windows(self, edf_file, seizure_times, subject):
        """Read one EDF file and return its labelled windows ([] if it has too few EEG channels)."""
        raw = mne.io.read_raw_edf(edf_file, preload=False, verbose=False)

        # Resampling operates on in-memory data, so load the samples first.
        if raw.info['sfreq'] != self.sampling_rate:
            raw.load_data()
            raw = raw.resample(self.sampling_rate)

        # CHB-MIT labels are bipolar montage names such as "FP1-F7" and do not
        # contain "EEG", so select by MNE channel type instead. Placeholder
        # channels named only with dashes/dots (possibly with a numeric
        # de-duplication suffix) are skipped.
        eeg_channels = [
            name for name, kind in zip(raw.ch_names, raw.get_channel_types())
            if kind == 'eeg' and name.strip('-.0123456789')
        ]
        if len(eeg_channels) < config.num_channels:
            print(f"Warning: Only {len(eeg_channels)} EEG channels found in {edf_file}")
            return []

        # Keep the first config.num_channels EEG channels so every window has the same shape
        eeg_channels = eeg_channels[:config.num_channels]
        raw.pick_channels(eeg_channels)

        data, times = raw[:, :]
        data = data.astype(np.float32)

        # Per-channel z-score over the whole recording
        data = (data - np.mean(data, axis=1, keepdims=True)) / (np.std(data, axis=1, keepdims=True) + 1e-8)

        return self._create_windows_for_file(data, times, seizure_times, subject, edf_file)

    def _load_seizure_annotations(self, summary_file, subject):
        """Parse seizure onset/offset times from a CHB-MIT ``*-summary.txt`` file.

        Both ``Seizure Start Time: N seconds`` and the numbered
        ``Seizure 1 Start Time: N seconds`` line formats are accepted.

        Args:
            summary_file: path to the summary file.
            subject: subject id (unused; kept for interface compatibility).

        Returns:
            list of ``{'file': <edf basename>, 'start': int, 'end': int}``
            dicts, with times in seconds from the start of that file. The list
            is empty if the file is missing; entries parsed before a read error
            are kept.
        """
        import re

        seizure_times = []
        if not os.path.exists(summary_file):
            print(f"Summary file not found: {summary_file}")
            return seizure_times

        # group(1) is "Start" or "End", group(2) the time in seconds
        seizure_line = re.compile(r'^Seizure(?:\s+\d+)?\s+(Start|End)\s+Time:\s*(\d+)')

        try:
            with open(summary_file, 'r') as f:
                current_file = None
                start_time = None
                for raw_line in f:
                    line = raw_line.strip()
                    if line.startswith('File Name:'):
                        current_file = line.split(':', 1)[1].strip()
                        start_time = None  # never pair a start with another file's end
                        continue
                    match = seizure_line.match(line)
                    if not match:
                        continue
                    boundary, seconds = match.group(1), int(match.group(2))
                    if boundary == 'Start':
                        start_time = seconds
                    elif current_file and start_time is not None:
                        seizure_times.append({
                            'file': current_file,
                            'start': start_time,
                            'end': seconds
                        })
                        start_time = None
        except Exception as e:
            print(f"Error reading summary file {summary_file}: {e}")

        return seizure_times

    def _create_windows_for_file(self, data, times, seizure_times, subject, edf_file):
        """Cut one recording into labelled windows.

        Args:
            data: (channels, samples) array for the whole file.
            times: sample times in seconds, aligned with ``data``'s second axis.
            seizure_times: annotations for the subject (all files).
            subject, edf_file: provenance stored with every window.

        Returns:
            list of sample dicts; each ``eeg_data`` is a slice (view) of ``data``.
        """
        file_name = os.path.basename(edf_file)
        file_seizures = [sz for sz in seizure_times if sz['file'] == file_name]
        num_samples = data.shape[1]
        half_window = self.window_size // 2

        samples = []
        for start_idx in range(0, num_samples - self.window_size + 1, self.stride):
            # Labels are decided by the time at the window centre
            window_center_sec = times[start_idx + half_window]
            detection_label, prediction_label, timestamp = self._create_labels(
                window_center_sec, file_seizures
            )
            samples.append({
                'eeg_data': data[:, start_idx:start_idx + self.window_size],
                'detection_label': detection_label,
                'prediction_label': prediction_label,
                'timestamp': timestamp,
                'subject': subject,
                'file': file_name
            })

        return samples

    def _create_labels(self, window_center_sec, file_seizures):
        """Return (detection_label, prediction_label, timestamp) for one window.

        Seizures are checked in the order given. The first one whose ictal
        interval or pre-ictal horizon contains the centre decides the label:

        * ictal: (1, 0, centre - onset), which is >= 0;
        * pre-ictal: (0, 1, centre - onset), in [-preictal_window, 0);
        * inter-ictal: (0, 0, d), where d is the distance to the nearest
          seizure boundary, positive if that seizure is still ahead and
          negative if it has passed;
        * file without seizures: (0, 0, 1000).

        NOTE(review): pre-ictal windows get a negative offset while inter-ictal
        windows before a seizure get a positive distance. Confirm this sign
        convention against the loss that consumes ``timestamp``.
        """
        if not file_seizures:
            # No seizures in this file: all windows are inter-ictal
            return 0, 0, 1000  # large positive timestamp for inter-ictal

        nearest_dist = None
        nearest_timestamp = None
        for seizure in file_seizures:
            onset, offset = seizure['start'], seizure['end']

            if onset <= window_center_sec <= offset:
                return 1, 0, window_center_sec - onset

            if onset - self.preictal_window <= window_center_sec < onset:
                return 0, 1, window_center_sec - onset

            # Inter-ictal with respect to this seizure: keep the closest one seen
            dist = min(abs(window_center_sec - onset), abs(window_center_sec - offset))
            if nearest_dist is None or dist < nearest_dist:
                nearest_dist = dist
                nearest_timestamp = dist if window_center_sec < onset else -dist

        return 0, 0, nearest_timestamp

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one window as tensors (augmented when ``mode == 'train'``)."""
        sample = self.samples[idx]

        # torch.tensor always copies, so in-place augmentation below can never
        # write back into the cached window array.
        eeg_data = torch.tensor(sample['eeg_data'], dtype=torch.float32)

        if self.mode == 'train':
            eeg_data = self._augment_data(eeg_data)

        return {
            'eeg': eeg_data,
            'detection_label': torch.FloatTensor([sample['detection_label']]),
            'prediction_label': torch.FloatTensor([sample['prediction_label']]),
            'timestamp': torch.FloatTensor([sample['timestamp']]),
            'subject': sample['subject'],
            'file': sample['file']
        }

    def _augment_data(self, eeg_data):
        """Randomly apply channel dropout, Gaussian noise, burst noise and amplitude scaling.

        ``eeg_data`` is a (channels, time) tensor owned by the caller. Channel
        dropout zeroes it in place; the other steps return new tensors.
        """
        num_channels = eeg_data.shape[0]

        # Channel dropout: zero 20% of the channels
        if torch.rand(1) < 0.3:
            num_drop = int(num_channels * 0.2)
            channels_to_drop = torch.randperm(num_channels)[:num_drop]
            eeg_data[channels_to_drop] = 0

        # Add Gaussian noise
        if torch.rand(1) < 0.5:
            noise = torch.randn_like(eeg_data) * 0.1
            eeg_data = eeg_data + noise

        # Add EMG-like noise bursts on ~10% of the samples
        if torch.rand(1) < 0.3:
            burst_noise = torch.randn_like(eeg_data) * 0.2
            mask = torch.rand_like(eeg_data) < 0.1
            eeg_data = eeg_data + burst_noise * mask

        # Amplitude scaling in [0.8, 1.2)
        if torch.rand(1) < 0.3:
            scale = torch.rand(1) * 0.4 + 0.8
            eeg_data = eeg_data * scale

        return eeg_data

# Create real dataloaders
def create_real_dataloaders(data_path, train_subjects=None, val_subjects=None, num_workers=4):
    """Build train/validation DataLoaders over CHB-MIT recordings.

    If building either dataset or loader raises (including when no windows
    were loaded), this prints the error and falls back to
    ``create_mock_dataloaders()``. Those loaders contain RANDOM data, so check
    the log before trusting any result.

    Args:
        data_path: root passed to ``CHBMITDataset``.
        train_subjects, val_subjects: subject ids for each split. Defaults to a
            small demo split: chb01-chb03 for training, chb04 for validation.
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, val_loader)

    Raises:
        ValueError: if a subject appears in both splits (subject-level leakage).
    """
    # CHB-MIT has subjects chb01 ... chb24 (range end is exclusive)
    all_subjects = [f'chb{i:02d}' for i in range(1, 25)]

    # Default to a small demo subset
    if train_subjects is None:
        train_subjects = all_subjects[:3]
    if val_subjects is None:
        val_subjects = all_subjects[3:4]

    overlap = sorted(set(train_subjects) & set(val_subjects))
    if overlap:
        raise ValueError(f"Subjects {overlap} appear in both the training and validation split")

    # Pinned host memory only helps host-to-GPU copies
    pin_memory = torch.cuda.is_available()

    try:
        train_dataset = CHBMITDataset(
            data_path=data_path,
            window_size=config.window_size,
            stride=config.stride,
            mode='train',
            subjects=train_subjects
        )

        val_dataset = CHBMITDataset(
            data_path=data_path,
            window_size=config.window_size,
            stride=config.stride,
            mode='val',
            subjects=val_subjects
        )

        if len(train_dataset) == 0 or len(val_dataset) == 0:
            raise RuntimeError(
                f"no windows loaded (train={len(train_dataset)}, val={len(val_dataset)}); "
                f"check data_path={data_path!r}"
            )

        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=pin_memory)
        val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=pin_memory)

        return train_loader, val_loader

    except Exception as e:
        # Deliberate best-effort fallback so the notebook still runs end to end.
        print(f"Error creating real dataloaders: {e}")
        print("Creating mock dataloaders for demonstration...")
        return create_mock_dataloaders()

def create_mock_dataloaders(num_train=1000, num_val=200, seed=0):
    """Build DataLoaders over synthetic EEG-like data (for demos and smoke tests only).

    Each sample is Gaussian noise plus one random 2-32 Hz sinusoid per channel.
    Every sample is drawn from a generator seeded by (split seed, index), so
    ``dataset[i]`` returns the same tensor and labels in every epoch.

    Labels are random and mutually exclusive, as in the real dataset: detection
    with probability 0.2 and prediction with probability 0.15.

    Args:
        num_train, num_val: number of samples in each split.
        seed: base seed; the validation split uses ``seed + 1``.

    Returns:
        (train_loader, val_loader)
    """
    class MockDataset(Dataset):
        """Deterministic, index-seeded synthetic EEG windows."""
        def __init__(self, num_samples=1000, seed=0):
            self.num_samples = num_samples
            self.seed = seed

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            gen = torch.Generator().manual_seed(self.seed * 1_000_003 + idx)

            # Noise floor plus one sinusoid per channel
            eeg_data = torch.randn(config.num_channels, config.window_size, generator=gen) * 0.1
            t = torch.arange(config.window_size, dtype=torch.float32) / config.sampling_rate
            freqs = torch.rand(config.num_channels, 1, generator=gen) * 30 + 2  # 2-32 Hz
            eeg_data += torch.sin(2 * np.pi * freqs * t) * 0.05

            # A single uniform draw keeps the two labels mutually exclusive
            u = torch.rand(1, generator=gen).item()
            detection_label = torch.FloatTensor([1.0 if u < 0.2 else 0.0])
            prediction_label = torch.FloatTensor([1.0 if 0.2 <= u < 0.35 else 0.0])
            timestamp = torch.FloatTensor([torch.randn(1, generator=gen).item() * 100])

            return {
                'eeg': eeg_data,
                'detection_label': detection_label,
                'prediction_label': prediction_label,
                'timestamp': timestamp,
                'subject': 'mock',
                'file': 'mock.edf'
            }

    train_dataset = MockDataset(num_train, seed=seed)
    val_dataset = MockDataset(num_val, seed=seed + 1)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    return train_loader, val_loader

# Initialize dataloaders
# Location of the CHB-MIT data (see CHBMITDataset for the expected layout).
# Set the CHBMIT_DATA_PATH environment variable instead of editing this cell.
data_path = os.environ.get("CHBMIT_DATA_PATH", ".")
train_loader, val_loader = create_real_dataloaders(data_path)

ModuleNotFoundError: No module named 'mne'