In [None]:
# Mount Google Drive (Colab only) and switch the working directory to MyDrive,
# so that relative paths used further down are resolved inside the Drive folder.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

In [None]:
!pip install mne mne_connectivity torch-geometric -q

#!pip install PyWavelets

In [None]:
import gc
import numpy as np
import pandas as pd
import mne
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch.utils.data import Dataset, DataLoader
from mne.preprocessing import ICA
from scipy.signal import hilbert

class MemoryEfficientPLVProcessor:
    """
    Memory-efficient processor that handles TMS events and resting state data
    """
    def __init__(self, root_dir, output_dir='plv_results', montage='standard_1020'):
        self.root_dir = Path(root_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.montage = montage

        # PLV computation parameters
        self.freq_range = (0.1, 40)
        self.epoch_length = 5.0   # Default window length (shorter for TMS)
        self.overlap = 0

    def scan_subjects(self):
        """Scan for subject folders"""
        subjects = [d for d in self.root_dir.iterdir()
                   if d.is_dir() and d.name.startswith('sub-')]
        print(f"Found {len(subjects)} subjects")
        return subjects

    def _load_tms_events(self, file_path):
        """More robust TMS event loading"""
        try:
            events_path = file_path.parent / file_path.name.replace('_eeg.eeg', '_events.tsv')
            events = pd.read_csv(events_path, sep='\t')

            # Handle multiple TMS event types
            tms_events = events[events['trial_type'].str.contains('TMS|pulse|stim', case=False, na=False)]

            if len(tms_events) == 0:
                print(f"No TMS events found in {events_path}")
                return None

            # Convert to array with [onset, duration]
            return tms_events[['onset', 'duration']].values.astype(float)

        except Exception as e:
            print(f"TMS event loading failed: {e}")
            return None

    def load_and_preprocess_eeg(self, file_path):
        """Memory-optimized EEG loading and preprocessing"""
        try:
            if file_path.suffix == '.eeg':
                vhdr_path = file_path.with_suffix('.vhdr')
                if not vhdr_path.exists():
                    print(f"Missing .vhdr file for {file_path.name}")
                    return None

                # Load without preloading to save memory
                print(f"Loading {vhdr_path.name}...")
                raw = mne.io.read_raw_brainvision(vhdr_path, preload=False, verbose=False)

                # Set montage
                try:
                    montage = mne.channels.make_standard_montage(self.montage)
                    raw.set_montage(montage, match_case=False, verbose=False)
                except Exception as e:
                    print(f"Warning: Could not set montage: {e}")

                # Load data and apply filters
                print("Applying filters...")
                raw.load_data()
                raw.filter(0.1, 40.0, n_jobs=1, verbose=False)
                raw.notch_filter(50.0, n_jobs=1, verbose=False)

                # Memory-efficient ICA
                print("Running ICA...")
                n_components = min(15, len(raw.ch_names) - 2)  # Fewer components
                ica = ICA(
                    n_components=n_components,
                    max_iter=300,  # Fewer iterations
                    random_state=97,
                    verbose=False
                )

                # Fit on decimated data
                ica.fit(raw, reject_by_annotation=True, decim=3)

                # Automatic artifact detection
                if len(mne.pick_types(raw.info, eog=True)) > 0:
                    eog_inds, eog_scores = ica.find_bads_eog(raw)
                    ica.exclude = eog_inds

                # Apply ICA
                print("Applying ICA correction...")
                ica.apply(raw)

                return raw

            else:
                print(f"Unsupported file format: {file_path}")
                return None

        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return None
    '''
    def compute_plv_for_window(self, data, sfreq):
        """Compute PLV matrix with robust error handling"""
        try:
            n_channels, n_samples = data.shape

            # Create temporary raw object
            info = mne.create_info(
                ch_names=[f'CH{i}' for i in range(n_channels)],
                sfreq=sfreq,
                ch_types='eeg'
            )
            raw_temp = mne.io.RawArray(data, info, verbose=False)

            # Create epochs
            events = np.array([[0, 0, 1]])
            epochs = mne.Epochs(
                raw_temp,
                events,
                tmin=0,
                tmax=(n_samples-1)/sfreq,
                baseline=None,
                preload=True,
                verbose=False
            )

            # Compute PLV
            from mne_connectivity import spectral_connectivity_epochs
            con = spectral_connectivity_epochs(
                epochs,
                method='plv',
                fmin=self.freq_range[0],
                fmax=self.freq_range[1],
                sfreq=sfreq,
                faverage=True,  # Average across frequencies
                verbose=False
            )

            # Get dense symmetric matrix
            plv_matrix = con.get_data(output='dense')[0]

            # Ensure matrix is properly shaped
            if plv_matrix.shape != (n_channels, n_channels):
                print(f"Warning: PLV shape {plv_matrix.shape}, expected {(n_channels, n_channels)}")
                plv_matrix = np.eye(n_channels)

            return plv_matrix

        except Exception as e:
            print(f"PLV computation error: {e}")
            return np.eye(data.shape[0])  # Return identity matrix as fallback
    '''

    def compute_plv_for_window(self, data, sfreq):
        analytic_signal = hilbert(data, axis=1)
        phases = np.angle(analytic_signal)
        phase_diff = phases[:, None, :] - phases[None, :, :]  # (n_chan, n_chan, n_samples)
        plv_matrix = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
        return plv_matrix


    def _process_sliding_windows(self, raw, subject_id, file_type):
      """Process with shape validation"""
      data = raw.get_data()
      sfreq = raw.info['sfreq']
      window_samples = int(self.epoch_length * sfreq)

      for i, start_idx in enumerate(range(0, data.shape[1] - window_samples, window_samples)):
          window_data = data[:, start_idx:start_idx+window_samples]

          # Skip empty/invalid windows
          if window_data.size == 0:
              print(f"Skipping empty window {i}")
              continue

          plv_matrix = self.compute_plv_for_window(window_data, sfreq)
          self.save_plv_matrix(plv_matrix, raw.ch_names, subject_id, i, file_type)

          if i % 5 == 0:
              gc.collect()

      return True

    def _process_tms_windows(self, raw, subject_id, tms_events):
        """Process TMS events with artifact exclusion"""
        sfreq = raw.info['sfreq']
        success = False

        for i, (onset, duration) in enumerate(tms_events):
            try:
                # Skip 100ms post-TMS to avoid artifact
                start = onset + 0.1  # 100ms after TMS
                end = start + 5.0    # 5-second window

                # Get data (auto-rejects bad segments)
                window = raw.copy().crop(start, end).get_data()

                if window.size == 0:
                    continue

                plv = self.compute_plv_for_window(window, sfreq)
                self.save_plv_matrix(plv, raw.ch_names, subject_id, i, 'tms')
                success = True

            except Exception as e:
                print(f"Failed TMS window {i}: {e}")

        return success

    def process_file_to_plv(self, file_path, subject_id, file_type='rest'):
        """Main processor with better TMS detection"""
        print(f"\nProcessing {file_path.name} as {file_type}...")

        try:
            raw = self.load_and_preprocess_eeg(file_path)
            if raw is None:
                return False

            # Force TMS detection for files containing 'tms' in name
            if 'tms' in file_path.name.lower():
                file_type = 'tms'

            if file_type == 'tms':
                tms_events = self._load_tms_events(file_path)
                if tms_events is not None:
                    return self._process_tms_windows(raw, subject_id, tms_events)
                else:
                    print("No valid TMS events found, processing as continuous data")

            # Default sliding window processing
            return self._process_sliding_windows(raw, subject_id, file_type)

        finally:
            if 'raw' in locals():
                del raw
            gc.collect()

    def save_plv_matrix(self, plv_matrix, ch_names, subject_id, window_idx, file_type):
        """Save PLV matrix with proper shape handling"""
        try:
            # Ensure the matrix is square
            if plv_matrix.shape[0] != plv_matrix.shape[1]:
                print(f"Warning: Non-square PLV matrix {plv_matrix.shape}, creating square matrix")
                n_channels = len(ch_names)
                plv_matrix = np.eye(n_channels)  # Fallback to identity matrix

            # Ensure matrix matches channel names length
            if plv_matrix.shape[0] != len(ch_names):
                print(f"Shape mismatch: PLV {plv_matrix.shape} vs channels {len(ch_names)}. Adjusting...")
                min_dim = min(plv_matrix.shape[0], len(ch_names))
                plv_matrix = plv_matrix[:min_dim, :min_dim]
                ch_names = ch_names[:min_dim]

            # Create and save DataFrame
            filename = f"sub-{subject_id}_type-{file_type}_window-{window_idx:04d}_plv.csv"
            filepath = self.output_dir / filename

            pd.DataFrame(plv_matrix,
                        columns=ch_names,
                        index=ch_names).to_csv(filepath)

            print(f"Saved: {filename}")

        except Exception as e:
            print(f"Error saving PLV matrix: {e}")
            # Save error log
            with open(self.output_dir / 'error_log.txt', 'a') as f:
                f.write(f"Error saving {filename}: {str(e)}\n")

    def process_all_subjects(self):
        """Process all subjects (unchanged)"""
        subjects = self.scan_subjects()
        processed_count = 0

        for subject_path in subjects:
            subject_id = subject_path.name.replace('sub-', '')
            eeg_dir = subject_path / 'eeg'

            if not eeg_dir.exists():
                continue

            # Process resting state
            for rest_file in eeg_dir.glob('*rest_eeg*.eeg'):
                if self.process_file_to_plv(rest_file, subject_id, 'rest'):
                    processed_count += 1

            # Process TMS files
            for tms_file in eeg_dir.glob('*tmseeg1_eeg.eeg'):
                if self.process_file_to_plv(tms_file, subject_id, 'tms'):
                    processed_count += 1

            gc.collect()
            print(f"Completed subject {subject_id}")

        print(f"Processed {processed_count} files")

# NOTE: the PLVDatasetFromDisk and MemoryEfficientTMSClassifier classes are not defined in this cell; they must be defined before the example below is run.

if __name__ == "__main__":
    # Example usage.
    # Step 1: compute PLV matrices for every subject found under "EEGTMS" and
    # write them as CSV files into "plv_matrices_single". Both paths are
    # relative to the current working directory.
    processor = MemoryEfficientPLVProcessor(
        root_dir="EEGTMS",
        output_dir="plv_matrices_single"
    )
    processor.process_all_subjects()

    # Step 2: train a classifier on the matrices written in step 1.
    # NOTE(review): MemoryEfficientTMSClassifier is not defined in this cell.
    # Confirm it is defined earlier in the session; otherwise the next line
    # raises NameError after all the preprocessing has already been done.
    classifier = MemoryEfficientTMSClassifier("plv_matrices_single")
    classifier.load_dataset()
    classifier.train(epochs=50, learning_rate=0.001)

In [None]:
import gc
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GATConv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import mne
from mne_connectivity import spectral_connectivity_epochs
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from mne.preprocessing import ICA
import warnings
warnings.filterwarnings('ignore')


In [None]:


class EEGDataLoader:
    """
    Loads EEG time series data organized in subject folders
    """
    def __init__(self, root_dir, montage='standard_1020'):
        self.root_dir = Path(root_dir)
        self.montage = montage
        self.subjects = []
        self.resting_data = {}
        self.tms_data = {}
        self.events_data = {}

    def scan_subjects(self):
        """Scan for subject folders following sub-xxx pattern"""
        self.subjects = [d for d in self.root_dir.iterdir()
                        if d.is_dir() and d.name.startswith('sub-')]
        print(f"Found {len(self.subjects)} subjects")
        return self.subjects

    def load_eeg_file(self, file_path):
        """Load .eeg EEG file using MNE and preprocess"""
        try:
            if file_path.suffix == '.eeg':
                # BrainVision requires .vhdr file; .eeg and .vmrk must be present in the same folder
                vhdr_path = file_path.with_suffix('.vhdr')
                if not vhdr_path.exists():
                    print(f"Missing corresponding .vhdr file for {file_path.name}")
                    return None

                # Load the raw data using the .vhdr file
                raw = mne.io.read_raw_brainvision(vhdr_path, preload=True, verbose=False)
            else:
                print(f"Unsupported file format: {file_path}")
                return None

            # Attempt to set standard montage
            try:
                montage = mne.channels.make_standard_montage(self.montage)
                raw.set_montage(montage, match_case=False, verbose=False)
            except Exception as montage_error:
                print(f"Warning: Could not set montage for {file_path}: {montage_error}")

            # Apply preprocessing
            raw = self.preprocess_eeg(raw)
            return raw

        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return None

    def preprocess_eeg(self, raw, l_freq=0.1, h_freq=40.0, notch_freq=50.0,
                      apply_ica=True, n_components=0.95):
        """Bandpass filter and ICA motion correction for EEG"""
        # Bandpass filtering
        raw.filter(l_freq=l_freq, h_freq=h_freq, fir_design='firwin', verbose=False)

        # Notch filter (optional)
        if notch_freq:
            raw.notch_filter(freqs=notch_freq, fir_design='firwin', verbose=False)

        # Run ICA if requested
        if apply_ica:
            n_components = min(20, len(raw.ch_names) - 1)  # Use 20 components or channels-1 instead of 0.95
            ica = mne.preprocessing.ICA(
              n_components=n_components,
              #method='picard',  # More stable than default
              max_iter=500,
              random_state=97,
              verbose=False
            )

            ica.fit(raw, reject_by_annotation=True)

            # Automatic artifact detection
            if len(mne.pick_types(raw.info, eog=True)) > 0:
                  eog_inds, eog_scores = ica.find_bads_eog(raw)
                  ica.exclude = eog_inds

            # Apply ICA
            raw = ica.apply(raw)

        return raw

    def process_large_eeg(self, file_path, chunk_size=300):
        """Yield a BrainVision recording as independently preprocessed chunks.

        This is a generator: nothing is read until it is iterated. The file is
        opened without preloading, cut into consecutive pieces of
        ``chunk_size`` seconds, and every piece is loaded and passed through
        ``self.preprocess_eeg`` (with its default arguments) on its own. A
        piece that fails is reported and skipped; the remaining pieces are
        still produced.

        Args:
            file_path: Path to BrainVision EEG file (.vhdr)
            chunk_size: Duration of chunks in seconds. A value larger than the
                recording is reduced to the recording's duration.

        Yields:
            Preprocessed Raw objects for each chunk

        Raises:
            ValueError: If file cannot be read or chunk_size is invalid. Any
                exception raised outside the per-chunk handling is re-raised
                as ValueError with the file path prepended to the message.
        """
        try:
            # 1. Load file without preloading
            raw = mne.io.read_raw_brainvision(file_path, preload=False, verbose=False)

            # 2. Calculate duration with floating-point precision
            total_samples = raw.n_times
            sfreq = raw.info['sfreq']
            total_duration = total_samples / sfreq

            # 3. Validate chunk size
            # (the ValueError below is caught by the outer handler and
            # re-raised with the file path added to the message)
            if chunk_size <= 0:
                raise ValueError(f"chunk_size must be >0 (got {chunk_size})")
            if chunk_size > total_duration:
                print(f"Warning: chunk_size {chunk_size}s > file duration {total_duration:.2f}s. Using full duration.")
                chunk_size = total_duration

            # 4. Process in chunks with safe cropping
            for start in np.arange(0, total_duration, chunk_size):
                try:
                    # Calculate end time with 1 sample buffer
                    # (never later than one sample period before the end)
                    end = min(start + chunk_size, total_duration - 1/sfreq)

                    # Handle floating-point precision issues
                    end = np.round(end, decimals=5)

                    # Crop and load chunk
                    # (only the cropped copy is loaded into memory)
                    with mne.utils.use_log_level('ERROR'):  # Suppress crop warnings
                        raw_chunk = raw.copy().crop(start, end, include_tmax=False).load_data()

                    # Process chunk
                    processed_chunk = self.preprocess_eeg(raw_chunk)

                    # Clear intermediate objects
                    # (drops only the local name raw_chunk)
                    del raw_chunk

                    yield processed_chunk

                except Exception as chunk_error:
                    # Best effort: report the failed chunk and move on.
                    print(f"Error processing chunk {start:.1f}-{end:.1f}s: {str(chunk_error)}")
                    continue

        except Exception as file_error:
            raise ValueError(f"Error processing {file_path}: {str(file_error)}")

        finally:
            # Ensure raw file is closed
            # ('raw' is missing from locals() when opening the file failed)
            if 'raw' in locals():
                raw.close()
            gc.collect()

    def load_events_file(self, file_path):
        """Load events from TSV file"""
        try:
            events_df = pd.read_csv(file_path, sep='\t')
            return events_df
        except Exception as e:
            print(f"Error loading events {file_path}: {e}")
            return None
    '''
    def load_subject_data(self, subject_path):
        """Load all data for a single subject"""
        eeg_dir = subject_path / 'eeg'
        if not eeg_dir.exists():
            print(f"No EEG directory found for {subject_path.name}")
            return None, None, None

        resting_files = list(eeg_dir.glob('*rest_eeg*.eeg'))
        tms_files = list(eeg_dir.glob('*tmseeg1_eeg.eeg'))
        event_files = list(eeg_dir.glob('*.tsv'))

        resting_data = None
        tms_data = None
        events_data = None

        # Load resting state data
        if resting_files:
            print(resting_files[0])
            resting_data = self.load_eeg_file(resting_files[0])

        # Load TMS data
        if tms_files:
            print(tms_files[0])
            tms_data = self.load_eeg_file(tms_files[0])

        # Load events
        if event_files:
            events_data = self.load_events_file(event_files[0])

        return resting_data, tms_data, events_data
    '''
    def load_subject_data(self, subject_path):
        """Load all data for a single subject using chunked processing"""
        eeg_dir = subject_path / 'eeg'
        if not eeg_dir.exists():
            print(f"No EEG directory found for {subject_path.name}")
            return None, None, None

        resting_files = list(eeg_dir.glob('*rest_eeg*.eeg'))
        tms_files = list(eeg_dir.glob('*tmseeg1_eeg.eeg'))
        event_files = list(eeg_dir.glob('*.tsv'))

        # Process resting data in chunks
        resting_chunks = []
        if resting_files:
            print(resting_files)
            vhdr_path = resting_files[0].with_suffix('.vhdr')
            for chunk in self.process_large_eeg(vhdr_path):
                resting_chunks.append(chunk)

        # Process TMS data in chunks
        tms_chunks = []
        if tms_files:
            print(tms_files)
            vhdr_path = tms_files[0].with_suffix('.vhdr')
            for chunk in self.process_large_eeg(vhdr_path):
                tms_chunks.append(chunk)

        # Load events
        events_data = None
        if event_files:
            events_data = self.load_events_file(event_files[0])

        return resting_chunks, tms_chunks, events_data

    def load_all_subjects(self):
        """Load data for all subjects"""
        self.scan_subjects()

        for subject_path in self.subjects:
            subject_id = subject_path.name
            print(f"Loading {subject_id}...")

            resting, tms, events = self.load_subject_data(subject_path)

            if resting is not None:
                self.resting_data[subject_id] = resting
            if tms is not None:
                self.tms_data[subject_id] = tms
            if events is not None:
                self.events_data[subject_id] = events

        print(f"Loaded resting data for {len(self.resting_data)} subjects")
        print(f"Loaded TMS data for {len(self.tms_data)} subjects")
        print(f"Loaded events data for {len(self.events_data)} subjects")

class PLVGraphConstructor:
    """
    Constructs PLV-based connectivity graphs from EEG data.

    For every configured frequency band the signal is band-pass filtered, the
    phase-locking value (PLV) of every channel pair is computed from the
    Hilbert phases, and the strongest connections (those at or above the
    ``threshold_percentile`` of the off-diagonal PLV values) become the edges
    of the graph.
    """
    def __init__(self, freq_bands=None, threshold_percentile=90, epoch_length=1.0, output_dir='/content/drive/MyDrive/plv_results'):
        """Store the band definitions and thresholds and create ``output_dir``.

        Args:
            freq_bands: Mapping ``band name -> (low, high)``; defaults to
                theta/alpha/beta/gamma.
            threshold_percentile: Percentile of the off-diagonal PLV values
                used as the edge threshold.
            epoch_length: Stored for interface compatibility; the Hilbert-based
                estimator used here works on the whole window and does not
                read it.
            output_dir: Folder the PLV CSV files are written to (created if
                missing).
        """
        if freq_bands is None:
            self.freq_bands = {
                'theta': (4, 8),
                'alpha': (8, 13),
                'beta': (13, 30),
                'gamma': (30, 40)
            }
        else:
            self.freq_bands = freq_bands
        self.threshold_percentile = threshold_percentile
        self.epoch_length = epoch_length

        # Channel labels of the most recently processed recording; set by
        # create_multi_band_edge_index().
        self.ch_names = None

        # Create output directory
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def save_plv_matrix(self, plv_matrix, subject_id, window_idx, band_name):
        """Save a PLV matrix as CSV and return the path that was written.

        Rows and columns are labelled with ``self.ch_names`` when those are
        known and match the matrix size; otherwise plain integer labels are
        used.
        """
        filename = f"sub-{subject_id}_window-{window_idx}_{band_name}_plv.csv"
        filepath = self.output_dir / filename
        # Formatting the Path in an f-string avoids the str + Path TypeError.
        print(f"saving {filepath}")

        labels = self.ch_names
        if labels is not None and len(labels) != np.shape(plv_matrix)[0]:
            labels = None

        df = pd.DataFrame(plv_matrix, columns=labels, index=labels)
        df.to_csv(filepath)
        return filepath

    def create_plv_adjacency_matrix(self, raw_data, freq_band):
        """Compute the PLV matrix of one band and threshold it into a graph.

        Args:
            raw_data: Recording providing ``get_data()`` (channels x samples)
                and ``info['sfreq']``.
            freq_band: ``(low, high)`` band limits in the same unit as the
                sampling frequency. The upper limit is capped just below the
                Nyquist frequency.

        Returns:
            ``(adj_matrix, plv_matrix)``: a 0/1 adjacency matrix without
            self-loops and the dense PLV matrix, both (n_channels, n_channels).

        Raises:
            ValueError: If the band is empty or not representable at the
                recording's sampling frequency, or the window is too short to
                be filtered.
        """
        from scipy.signal import butter, hilbert, sosfiltfilt

        data = np.asarray(raw_data.get_data(), dtype=float)
        sfreq = float(raw_data.info['sfreq'])
        n_channels = data.shape[0]

        low, high = freq_band
        high = min(high, 0.99 * sfreq / 2.0)
        if not 0 < low < high:
            raise ValueError(f"Band {freq_band} is not usable at sfreq={sfreq}")

        # Zero-phase band-pass, so the filter itself adds no phase shift.
        sos = butter(4, [low, high], btype='bandpass', fs=sfreq, output='sos')
        band_data = sosfiltfilt(sos, data, axis=1)

        # PLV[a, b] = |mean_t exp(i * (phi_a - phi_b))|, evaluated as a matrix
        # product of unit phasors so no channels x channels x samples array
        # is needed.
        phasors = np.exp(1j * np.angle(hilbert(band_data, axis=1)))
        plv_matrix = np.abs(phasors @ phasors.conj().T) / phasors.shape[1]

        adj_matrix = np.zeros((n_channels, n_channels))
        if n_channels > 1:
            off_diagonal = plv_matrix[~np.eye(n_channels, dtype=bool)]
            threshold = np.percentile(off_diagonal, self.threshold_percentile)
            adj_matrix = (plv_matrix >= threshold).astype(float)
            np.fill_diagonal(adj_matrix, 0.0)

        return adj_matrix, plv_matrix

    @staticmethod
    def _chain_edge_index(n_channels):
        """Fallback graph 0-1-2-...: a (2, n_channels - 1) edge index."""
        nodes = np.arange(max(n_channels, 1))
        return torch.tensor(np.stack([nodes[:-1], nodes[1:]], axis=0), dtype=torch.long)

    def create_multi_band_edge_index(self, raw_data, subject_id=None, window_idx=None):
        """
        Create edge indices for multiple frequency bands and save PLV matrices.

        A band whose processing fails is reported and replaced by a chain graph
        and an identity PLV matrix, so one bad band never aborts the others.

        Args:
            raw_data: Recording providing ``get_data()``, ``info['sfreq']`` and
                ``ch_names``.
            subject_id: Subject label for the CSV names.
            window_idx: Window number for the CSV names. CSVs are only written
                when both ``subject_id`` and ``window_idx`` are given.

        Returns:
            ``(edge_indices, plv_matrices)``: two dicts keyed by band name,
            holding a (2, n_edges) long tensor and the PLV array respectively.
        """
        edge_indices = {}
        plv_matrices = {}

        # Store channel names for CSV headers
        self.ch_names = raw_data.ch_names

        for band_name, freq_band in self.freq_bands.items():
            try:
                adj_matrix, plv_matrix = self.create_plv_adjacency_matrix(raw_data, freq_band)

                # Save PLV matrix to CSV if subject/window info provided
                if subject_id is not None and window_idx is not None:
                    self.save_plv_matrix(plv_matrix, subject_id, window_idx, band_name)

                # Convert to edge index format
                rows, cols = np.where(adj_matrix > 0)
                if len(rows) > 0:
                    edge_index = np.stack([rows, cols], axis=0)
                    edge_indices[band_name] = torch.tensor(edge_index, dtype=torch.long)
                else:
                    edge_indices[band_name] = self._chain_edge_index(adj_matrix.shape[0])

                plv_matrices[band_name] = plv_matrix

            except Exception as e:
                print(f"Error processing band {band_name}: {e}")
                n_channels = len(raw_data.ch_names)
                edge_indices[band_name] = self._chain_edge_index(n_channels)
                plv_matrices[band_name] = np.eye(n_channels)

        return edge_indices, plv_matrices

class TMSEEGDataset(Dataset):
    """
    Dataset for TMS-EEG classification with PLV-based connectivity
    """
    def __init__(self, eeg_data, events_data, window_size=1000, overlap=0.5, freq_bands=None):
        self.eeg_data = eeg_data
        self.events_data = events_data
        self.window_size = window_size
        self.overlap = overlap
        self.samples = []
        self.labels = []
        self.edge_indices = []
        self.plv_features = []

        if freq_bands is None:
            self.freq_bands = {
                'theta': (4, 8),
                'alpha': (8, 13),
                'beta': (13, 30),
                'gamma': (30, 40)
            }
        else:
            self.freq_bands = freq_bands

        self._prepare_samples()
    '''
    def _prepare_samples(self):
        """Prepare windowed samples with PLV-based connectivity using MNE-Connectivity"""
        graph_constructor = PLVGraphConstructor(freq_bands=self.freq_bands)

        for subject_id in self.eeg_data.keys():
            raw = self.eeg_data[subject_id]
            events = self.events_data.get(subject_id, None)

            # Get EEG data and sampling frequency
            data = raw.get_data()  # Shape: (n_channels, n_timepoints)
            sfreq = raw.info['sfreq']
            n_channels, n_timepoints = data.shape

            # Create sliding windows
            step_size = int(self.window_size * (1 - self.overlap))

            for start_idx in range(0, n_timepoints - self.window_size, step_size):
                end_idx = start_idx + self.window_size
                window_data = data[:, start_idx:end_idx]

                # Create a temporary Raw object for this window
                info = raw.info.copy()
                raw_window = mne.io.RawArray(window_data, info, verbose=False)

                # Compute PLV-based connectivity for this window using MNE-Connectivity
                edge_indices, plv_matrices = graph_constructor.create_multi_band_edge_index(raw_window)

                # Determine label based on events
                label = self._get_window_label(start_idx, end_idx, events, sfreq)

                self.samples.append(torch.tensor(window_data, dtype=torch.float32))
                self.labels.append(label)
                self.edge_indices.append(edge_indices)
                self.plv_features.append(plv_matrices)
    '''

    def _prepare_samples(self):
        """Cut every chunk of every subject into windows and build the samples.

        For each window this stores the raw samples, the per-band edge indices
        and PLV matrices (which the graph constructor also saves as CSV), and
        a label telling whether an event falls inside the window.

        Two details matter for correctness:
        * the window number used in the CSV names keeps counting across the
          chunks of a subject, so a later chunk never overwrites the files of
          an earlier one;
        * the window position handed to ``_get_window_label`` is shifted by the
          chunk's ``first_samp``, so it is compared with the event onsets on
          the time axis of the whole recording rather than of the chunk.
          NOTE(review): this assumes the onsets are measured from the first
          sample of the original recording - confirm for formats whose
          ``first_samp`` does not start at 0.
        """
        graph_constructor = PLVGraphConstructor(freq_bands=self.freq_bands)

        # Number of samples the window start advances by.
        step_size = int(self.window_size * (1 - self.overlap))

        for subject_id in self.eeg_data.keys():
            # Extract subject number if using BIDS format (sub-XXX)
            sub_num = subject_id.split('-')[-1] if 'sub-' in subject_id else subject_id
            events = self.events_data.get(subject_id, None)

            # Runs over all chunks of the subject (see docstring).
            window_idx = 0

            for raw_chunk in self.eeg_data[subject_id]:
                data = raw_chunk.get_data()
                sfreq = raw_chunk.info['sfreq']
                n_timepoints = data.shape[1]

                # Position of the chunk inside the recording it was cut from;
                # 0 for objects that do not carry this information.
                chunk_offset = getattr(raw_chunk, 'first_samp', 0)

                # "+ 1" keeps the last window when it ends exactly at the end
                # of the chunk.
                last_start = n_timepoints - self.window_size + 1
                for start_idx in range(0, last_start, step_size):
                    end_idx = start_idx + self.window_size
                    window_data = data[:, start_idx:end_idx]

                    # Create temporary Raw object
                    info = raw_chunk.info.copy()
                    raw_window = mne.io.RawArray(window_data, info, verbose=False)

                    # Compute connectivity and save PLV matrices
                    edge_indices, plv_matrices = graph_constructor.create_multi_band_edge_index(
                        raw_window,
                        subject_id=sub_num,
                        window_idx=window_idx
                    )

                    # Determine label on the recording's own time axis
                    label = self._get_window_label(chunk_offset + start_idx,
                                                   chunk_offset + end_idx,
                                                   events,
                                                   sfreq)

                    self.samples.append(torch.tensor(window_data, dtype=torch.float32))
                    self.labels.append(label)
                    self.edge_indices.append(edge_indices)
                    self.plv_features.append(plv_matrices)

                    window_idx += 1

    def _get_window_label(self, start_idx, end_idx, events, sfreq):
        """Determine label for a time window based on events"""
        if events is None:
            return 0  # Default to resting state

        start_time = start_idx / sfreq
        end_time = end_idx / sfreq

        # Check if any TMS event falls within this window
        for _, event in events.iterrows():
            event_time = event.get('onset', 0)
            if start_time <= event_time <= end_time:
                return 1  # TMS event present

        return 0  # No TMS event

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return {
            'x': self.samples[idx],
            'edge_indices': self.edge_indices[idx],
            'plv_features': self.plv_features[idx],
            'y': self.labels[idx]
        }

class MultiBandTemporalAttention(nn.Module):
    """
    Multi-band attention over EEG channels that also takes PLV features.

    Each channel's time series (``input_dim`` samples) is projected to a
    ``hidden_dim`` embedding, so the channels are the tokens the attention
    operates on. ``num_bands`` independent multi-head attention modules process
    these tokens; their outputs are concatenated with a projected global PLV
    value and fused back to ``hidden_dim``.
    """
    def __init__(self, input_dim, hidden_dim, num_heads=8, num_bands=4):
        """
        Args:
            input_dim: Number of time samples per channel (last axis of the
                input).
            hidden_dim: Embedding size; must be divisible by ``num_heads``.
            num_heads: Attention heads per band module.
            num_bands: Number of parallel attention modules.
        """
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_bands = num_bands
        self.head_dim = hidden_dim // num_heads

        # Separate attention for each frequency band
        self.band_attentions = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads, dropout=0.1, batch_first=True)
            for _ in range(num_bands)
        ])

        # PLV feature integration
        self.plv_projection = nn.Linear(1, hidden_dim // 4)

        # Input projection
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # Output fusion
        self.fusion = nn.Linear(hidden_dim * num_bands + hidden_dim // 4, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def _project_plv(self, plv_features, batch_size, n_tokens, like):
        """Turn the PLV matrices into a (batch, tokens, hidden_dim // 4) tensor.

        The mean PLV of every available band is averaged into one global
        connectivity value, which is projected and repeated for every token.
        Without any usable PLV input an all-zero block is returned, so the
        fusion layer always receives the width it was built for.
        """
        band_means = []
        if plv_features is not None:
            for band_name in ('theta', 'alpha', 'beta', 'gamma'):
                if band_name in plv_features:
                    # Accepts NumPy arrays and tensors; the result lives on
                    # the same device and uses the same dtype as the input.
                    plv_matrix = torch.as_tensor(
                        plv_features[band_name], dtype=like.dtype, device=like.device
                    )
                    band_means.append(plv_matrix.mean())

        if not band_means:
            return like.new_zeros(batch_size, n_tokens, self.plv_projection.out_features)

        global_plv = torch.stack(band_means).mean().view(1, 1, 1)
        return self.plv_projection(global_plv.expand(batch_size, n_tokens, 1))

    def forward(self, x, plv_features):
        """
        Args:
            x: Tensor of shape (batch_size, n_channels, seq_len), with
                ``seq_len`` equal to the ``input_dim`` given at construction.
            plv_features: Mapping ``band name -> PLV matrix`` or ``None``.

        Returns:
            Contiguous tensor of shape (batch_size, n_channels, hidden_dim).
        """
        batch_size, n_channels, _seq_len = x.shape

        # Embed every channel's time series: (batch, n_channels, hidden_dim).
        # No transpose here - the projection has to see the time axis, which
        # is the axis whose length is input_dim.
        x_proj = self.input_projection(x)

        # Self-attention across channels, once per band module.
        band_outputs = [
            attention(x_proj, x_proj, x_proj)[0] for attention in self.band_attentions
        ]

        # (batch, n_channels, hidden_dim * num_bands)
        multi_band_output = torch.cat(band_outputs, dim=-1)

        plv_block = self._project_plv(plv_features, batch_size, n_channels, x)
        combined = torch.cat([multi_band_output, plv_block], dim=-1)

        # Fusion and normalization
        output = self.fusion(combined)
        output = self.layer_norm(output)
        output = self.dropout(output)

        return output

class PLVSpatioTemporalGAT(nn.Module):
    """
    PLV-based Spatio-Temporal Graph Attention Network for TMS-EEG classification.

    Processing stages (see ``forward``):
      1. a shared multi-band temporal attention block turns the raw input into
         per-channel features,
      2. every frequency band runs its own stack of GAT layers over that
         band's graph,
      3. the band outputs are concatenated and fused by a linear layer,
      4. features are averaged over channels and passed to an MLP classifier.

    Args:
        n_channels: Number of graph nodes per sample (one per channel).
        seq_len: Forwarded as the first argument of ``MultiBandTemporalAttention``.
        hidden_dim: Feature width used by the GAT layers, fusion and classifier.
        num_heads: Attention heads for the temporal block and every ``GATConv``.
        num_gat_layers: Number of ``GATConv`` layers stacked per frequency band.
        num_classes: Number of output logits.
    """
    def __init__(self, n_channels, seq_len, hidden_dim=64, num_heads=8, num_gat_layers=2, num_classes=2):
        super().__init__()
        self.n_channels = n_channels
        self.seq_len = seq_len
        self.hidden_dim = hidden_dim
        self.freq_bands = ['theta', 'alpha', 'beta', 'gamma']

        # Multi-band temporal attention with PLV integration
        self.temporal_attention = MultiBandTemporalAttention(
            seq_len, hidden_dim, num_heads, len(self.freq_bands)
        )

        # Separate GAT layers for each frequency band.
        # concat=False keeps the output width at hidden_dim regardless of num_heads.
        self.band_gat_layers = nn.ModuleDict()
        for band in self.freq_bands:
            self.band_gat_layers[band] = nn.ModuleList([
                GATConv(hidden_dim, hidden_dim, heads=num_heads, dropout=0.1, concat=False)
                for _ in range(num_gat_layers)
            ])

        # Cross-band fusion
        self.band_fusion = nn.Linear(hidden_dim * len(self.freq_bands), hidden_dim)

        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def _batched_edge_index(self, edge_index, batch_size):
        """Repeat one graph's edge index for every sample, shifting node ids by n_channels per sample."""
        return torch.cat(
            [edge_index + sample_idx * self.n_channels for sample_idx in range(batch_size)],
            dim=1,
        )

    def forward(self, x, edge_indices, plv_features, batch=None):
        """
        Compute class logits for a batch.

        Args:
            x: Input tensor; documented by the caller-facing comment as
                (batch_size, n_channels, seq_len).
            edge_indices: Mapping from band name to an edge-index tensor
                (edges along dim 1, presumably the usual (2, num_edges) layout).
                The same graph is applied to every sample in the batch. Bands
                that are missing or have zero edges skip the GAT stack.
            plv_features: Forwarded unchanged to the temporal attention block.
            batch: When ``None`` the edge index is tiled across the batch with
                per-sample node offsets. Otherwise the edge index is used as
                given (assumed to be already batched -- TODO confirm).

        Returns:
            Tensor of shape (batch_size, num_classes) with unnormalised logits.
        """
        batch_size = x.size(0)

        # Expected to yield (batch_size, n_channels, hidden_dim) -- the reshape
        # below relies on that element count.
        x_temporal = self.temporal_attention(x, plv_features)

        # Flattened node features, built lazily so the no-edges path never reshapes.
        x_nodes = None

        band_outputs = []
        for band in self.freq_bands:
            has_edges = band in edge_indices and edge_indices[band].size(1) > 0
            if not has_edges:
                # If no edges for this band, use temporal features directly
                band_outputs.append(x_temporal)
                continue

            if x_nodes is None:
                # reshape (not view): the temporal block may return a transposed,
                # non-contiguous tensor, on which .view() raises a RuntimeError.
                x_nodes = x_temporal.reshape(batch_size * self.n_channels, self.hidden_dim)

            edge_index = edge_indices[band]
            if batch is None:
                edge_index = self._batched_edge_index(edge_index, batch_size)

            # Apply GAT layers for this band
            x_band_out = x_nodes
            for gat_layer in self.band_gat_layers[band]:
                x_band_out = gat_layer(x_band_out, edge_index)
                x_band_out = F.relu(x_band_out)
                # F.dropout's default probability is used here.
                x_band_out = F.dropout(x_band_out, training=self.training)

            # Back to (batch_size, n_channels, hidden_dim)
            band_outputs.append(
                x_band_out.reshape(batch_size, self.n_channels, self.hidden_dim)
            )

        # Fuse information from all frequency bands
        if len(band_outputs) > 1:
            x_fused = torch.cat(band_outputs, dim=-1)  # (batch_size, n_channels, hidden_dim * num_bands)
            x_fused = self.band_fusion(x_fused)  # (batch_size, n_channels, hidden_dim)
        else:
            x_fused = band_outputs[0]

        # Global pooling across channels
        x_fused = x_fused.transpose(1, 2)  # (batch_size, hidden_dim, n_channels)
        x_pooled = self.global_pool(x_fused).squeeze(-1)  # (batch_size, hidden_dim)

        # Classification
        return self.classifier(x_pooled)

class TMSEEGClassifier:
    """
    Main classifier class that combines PLV-based data loading and model training.

    Typical call order: ``load_data`` -> ``prepare_dataset`` -> ``train``
    (``train`` builds a default model via ``create_model`` when none exists).

    Args:
        root_dir: Dataset root handed to ``EEGDataLoader``.
        device: Torch device string; defaults to 'cuda' when available, else 'cpu'.
    """
    def __init__(self, root_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.root_dir = root_dir
        self.device = device
        self.data_loader = EEGDataLoader(root_dir)
        self.model = None
        self.scaler = StandardScaler()

    def load_data(self):
        """Load all EEG data"""
        print("Loading EEG data...")
        self.data_loader.load_all_subjects()

    def prepare_dataset(self, window_size=1000, overlap=0.5):
        """
        Build ``self.dataset`` from the loaded TMS and resting recordings.

        TMS recordings are stored under ``"<subject>_tms"`` together with their
        events (``None`` when the loader has none for that subject); resting
        recordings are stored under ``"<subject>_rest"`` with no events.
        Also records ``self.n_channels`` and ``self.seq_len`` from the first sample.

        Args:
            window_size: Forwarded to ``TMSEEGDataset``.
            overlap: Forwarded to ``TMSEEGDataset``.
        """
        print("Preparing dataset with PLV-based connectivity...")

        # Combine TMS and resting data
        combined_eeg = {}
        combined_events = {}

        # Add TMS data with events
        for subject_id in self.data_loader.tms_data.keys():
            combined_eeg[f"{subject_id}_tms"] = self.data_loader.tms_data[subject_id]
            combined_events[f"{subject_id}_tms"] = self.data_loader.events_data.get(subject_id, None)

        # Add resting data without events (background class)
        for subject_id in self.data_loader.resting_data.keys():
            combined_eeg[f"{subject_id}_rest"] = self.data_loader.resting_data[subject_id]
            combined_events[f"{subject_id}_rest"] = None

        # Create dataset with PLV-based connectivity
        self.dataset = TMSEEGDataset(combined_eeg, combined_events, window_size, overlap)

        print(f"Created dataset with {len(self.dataset)} samples")

        # Get data dimensions from the first sample's 'x' entry
        sample = self.dataset[0]
        self.n_channels = sample['x'].shape[0]
        self.seq_len = sample['x'].shape[1]

        print(f"Data dimensions: {self.n_channels} channels, {self.seq_len} time points")

    def create_model(self, hidden_dim=64, num_heads=8, num_gat_layers=2):
        """
        Create the PLV-based GAT model and move it to ``self.device``.

        Requires ``prepare_dataset`` to have set ``n_channels`` and ``seq_len``.
        """
        self.model = PLVSpatioTemporalGAT(
            n_channels=self.n_channels,
            seq_len=self.seq_len,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_gat_layers=num_gat_layers
        ).to(self.device)

        print(f"Created PLV-based model with {sum(p.numel() for p in self.model.parameters())} parameters")

    def _run_epoch(self, loader, criterion, optimizer=None):
        """
        Run one pass over ``loader``.

        With an ``optimizer`` the model is put in train mode and updated;
        without one it is put in eval mode and gradients are disabled.

        Returns:
            (mean loss per batch, accuracy in percent)
        """
        is_training = optimizer is not None
        self.model.train(is_training)

        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        with torch.set_grad_enabled(is_training):
            for batch in loader:
                x = batch['x'].to(self.device)
                # NOTE(review): only element [0] of the collated graph/PLV data is
                # used, so one connectivity pattern is applied to the whole batch.
                # Whether that matches what TMSEEGDataset returns should be
                # confirmed against its definition.
                edge_indices = {k: v.to(self.device) for k, v in batch['edge_indices'][0].items()}
                plv_features = batch['plv_features'][0]
                # as_tensor avoids re-wrapping an already-collated tensor and
                # guarantees the integer dtype CrossEntropyLoss expects for class ids.
                y = torch.as_tensor(batch['y'], dtype=torch.long).to(self.device)

                if is_training:
                    optimizer.zero_grad()

                logits = self.model(x, edge_indices, plv_features)
                loss = criterion(logits, y)

                if is_training:
                    loss.backward()
                    optimizer.step()

                loss_sum += loss.item()
                n_correct += (logits.argmax(dim=1) == y).sum().item()
                n_seen += y.size(0)

        # max(..., 1) keeps the metrics defined even for an empty loader.
        return loss_sum / max(len(loader), 1), 100 * n_correct / max(n_seen, 1)

    def train(self, epochs=100, batch_size=32, learning_rate=0.001):
        """
        Train the PLV-based model with an 80/20 random train/validation split.

        Progress is printed every 10th epoch and after the final epoch.

        Args:
            epochs: Number of passes over the training split.
            batch_size: Batch size for both loaders.
            learning_rate: Adam learning rate.

        Returns:
            dict with per-epoch lists under 'train_loss', 'train_acc',
            'val_loss' and 'val_acc' (accuracies in percent).

        Raises:
            RuntimeError: if ``prepare_dataset`` has not been called.
            ValueError: if the dataset is too small to leave any training sample.
        """
        if getattr(self, 'dataset', None) is None:
            raise RuntimeError("No dataset available: call prepare_dataset() before train().")

        if self.model is None:
            self.create_model()

        # Split dataset
        # NOTE(review): the split is random over individual samples; if samples
        # are overlapping windows of the same recording, train and validation
        # may share signal -- confirm whether a subject-wise split is needed.
        train_size = int(0.8 * len(self.dataset))
        val_size = len(self.dataset) - train_size
        if train_size == 0:
            raise ValueError(
                f"Dataset has {len(self.dataset)} sample(s); at least 2 are required for an 80/20 split."
            )
        train_dataset, val_dataset = torch.utils.data.random_split(
            self.dataset, [train_size, val_size]
        )

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        # Optimizer and loss
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

        print("Starting PLV-based training...")

        for epoch in range(epochs):
            train_loss, train_acc = self._run_epoch(train_loader, criterion, optimizer)
            val_loss, val_acc = self._run_epoch(val_loader, criterion)

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            # Report every 10th epoch, and always the last one so the final
            # metrics are visible.
            if epoch % 10 == 0 or epoch == epochs - 1:
                print(f"Epoch {epoch}/{epochs}")
                print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
                print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
                print("-" * 50)

        return history



In [None]:
# Example usage: run the full PLV-based pipeline end to end
if __name__ == "__main__":
    # Dataset root handed to the classifier
    root_directory = "EEGTMS"  # Update this path

    # Build the classifier wrapper around that root
    classifier = TMSEEGClassifier(root_dir=root_directory)

    # Read the recordings, then build the windowed dataset
    # (window_size=1000, overlap=0.5)
    classifier.load_data()
    classifier.prepare_dataset(1000, 0.5)

    # Fit the model: 100 epochs, batches of 16, learning rate 0.001
    classifier.train(100, 16, 0.001)
    print("PLV-based training completed!")