In [1]:
import torch.backends.cudnn as cudnn
# cudnn.enabled = False

import os
import glob
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


###############################################################################
# 1) Dataset that loads ALL epochs from a single .fif file
#    Returns only (data_torch, labels_torch) to avoid the string/tuple confusion
###############################################################################
class FIFFileDataset(Dataset):
    """
    One item per .fif file: all epochs stored in that file plus per-epoch labels.

    Files found in ``epilepsy_dir`` are labelled 1, files in ``pnes_dir`` are
    labelled 0. ``__getitem__`` returns ``(data_torch, labels_torch)``:
      - data_torch: float32 tensor of shape
        (num_epochs, expected_channels, expected_length); the time axis is
        right-zero-padded or truncated to ``expected_length``.
      - labels_torch: int64 tensor of shape (num_epochs,), every entry equal to
        the file's label.

    Only the (path, label) list is built up front; each ``__getitem__`` call
    reads its .fif file from disk again (nothing is cached).
    """
    def __init__(
        self,
        epilepsy_dir: str,
        pnes_dir: str,
        expected_channels: int = 22,
        expected_length: int = 6300
    ):
        super().__init__()

        self.expected_channels = expected_channels
        self.expected_length = expected_length

        # (file_path_str, label_int) pairs; epilepsy files first, then PNES.
        # glob() order depends on the filesystem, so sort within each directory
        # to make the file order (and therefore logs / validation order)
        # reproducible across machines.
        self.file_info = []
        for directory, label in ((epilepsy_dir, 1), (pnes_dir, 0)):
            pattern = os.path.join(directory, "*.fif")
            self.file_info.extend((path, label) for path in sorted(glob.glob(pattern)))

    def __len__(self):
        return len(self.file_info)

    def _fit_length(self, data_torch):
        """Right-zero-pad or truncate the last (time) axis to ``expected_length``."""
        current_len = data_torch.size(-1)
        if current_len > self.expected_length:
            return data_torch[..., :self.expected_length]
        if current_len < self.expected_length:
            pad_amt = self.expected_length - current_len
            return F.pad(data_torch, (0, pad_amt), "constant", 0.0)
        return data_torch

    def __getitem__(self, idx):
        """
        Load every epoch of file ``idx``.

        Raises:
            ValueError: if the file holds no epochs or its channel count differs
                from ``expected_channels``.
        """
        file_path, label = self.file_info[idx]

        epochs = mne.read_epochs(file_path, preload=True, verbose=False)
        data = epochs.get_data(copy=True)  # => (num_epochs, n_channels, n_times)

        if data.shape[0] == 0:
            raise ValueError(f"No epochs found in file: {file_path}")
        if data.shape[1] != self.expected_channels:
            raise ValueError(
                f"Expected {self.expected_channels} channels, got {data.shape[1]} in {file_path}"
            )

        data_torch = self._fit_length(torch.from_numpy(data).float())

        # Every epoch of a file shares the file's label.
        labels_torch = torch.full((data_torch.size(0),), label, dtype=torch.long)

        return data_torch, labels_torch


###############################################################################
# 2) LSTM + CNN Model
###############################################################################
class LSTMCNNNet(nn.Module):
    """
    LSTM (per channel) + CNN (across channels) for EEG classification.

    Each EEG channel is treated as an independent univariate sequence and
    passed through one *shared* LSTM. The top layer's final hidden state
    summarizes that channel. The per-channel summaries are mixed by a 1D
    convolution that slides along the channel axis. The result is globally
    average-pooled and mapped to class logits by a linear layer.

    Input shape:  (batch_size, num_eeg_channels, seq_len).
    Output shape: (batch_size, num_classes), unnormalized logits.
    """
    def __init__(
        self,
        input_size=1,
        lstm_hidden_size=32,
        lstm_num_layers=1,
        num_classes=2,
        num_eeg_channels=22
    ):
        super().__init__()
        self.num_eeg_channels = num_eeg_channels
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers

        # forward() feeds exactly one feature per time step, so input_size is
        # expected to be 1.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True
        )

        # Conv over the channel axis: features are (batch, hidden, channels).
        # kernel_size=3 mixes neighbouring channels in input order.
        # NOTE(review): this assumes channel order carries spatial meaning
        # (montage order) — confirm for the data being used.
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels=lstm_hidden_size, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)  # pools the channel axis down to length 1
        )

        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, channels, seq_len). Only the first
               ``num_eeg_channels`` channels are used.

        Returns:
            Logits of shape (batch_size, num_classes).

        Raises:
            ValueError: if ``x`` is not 3-D or has fewer than
                ``num_eeg_channels`` channels.
        """
        if x.dim() != 3 or x.size(1) < self.num_eeg_channels:
            raise ValueError(
                f"Expected input of shape (batch, >={self.num_eeg_channels}, seq_len), "
                f"got {tuple(x.shape)}"
            )
        bsz, _, seq_len = x.shape
        n_ch = self.num_eeg_channels

        # Fold the channel axis into the batch axis: row b * n_ch + c is
        # channel c of sample b. The LSTM treats every row as an independent
        # sequence, so one call equals running it once per channel, without
        # the Python loop of n_ch sequential LSTM launches.
        seqs = x[:, :n_ch, :].reshape(bsz * n_ch, seq_len, 1)
        _, (h, _) = self.lstm(seqs)  # h: (num_layers, bsz * n_ch, hidden)

        # Top layer's final hidden state, unfolded back to (bsz, n_ch, hidden).
        channel_feats = h[-1].reshape(bsz, n_ch, self.lstm_hidden_size)

        # Conv1d expects (batch, features, length) with length = channels here.
        channel_feats = channel_feats.permute(0, 2, 1)

        out = self.cnn(channel_feats)  # => (bsz, 16, 1)
        out = out.squeeze(-1)          # => (bsz, 16)

        return self.fc(out)            # => (bsz, num_classes)


###############################################################################
# 3) Main training script (with chunking in the training loop)
###############################################################################
def _iter_chunks(file_data, file_labels, chunk_size):
    """Yield ``(start, end, data_slice, label_slice)`` over consecutive chunks.

    ``end`` is exclusive and clipped to the number of epochs, so the final
    chunk may be shorter than ``chunk_size``.
    """
    n_total = file_data.size(0)
    for start in range(0, n_total, chunk_size):
        end = min(start + chunk_size, n_total)
        yield start, end, file_data[start:end], file_labels[start:end]


def _run_pass(model, loader, criterion, device, chunk_size, optimizer=None, log_chunks=True):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Each loader item is one file's ``(data, labels)`` with a leading batch dim
    of 1 (DataLoader batch_size=1). Files are processed in chunks of
    ``chunk_size`` epochs to bound memory; with an optimizer, one optimizer
    step is taken per chunk.

    Returns:
        (mean_loss, accuracy_pct, num_files). Loss and accuracy are averaged
        over individual epochs (samples), not over chunks or files.
    """
    training = optimizer is not None
    tag = "Train" if training else "Val"
    model.train(training)
    loss_sum, n_correct, n_seen, n_files = 0.0, 0, 0, 0

    # No autograd bookkeeping at all during evaluation.
    grad_mode = torch.enable_grad() if training else torch.no_grad()
    with grad_mode:
        for file_idx, (file_data, file_labels) in enumerate(loader, start=1):
            # Drop the size-1 batch dim added by DataLoader(batch_size=1).
            file_data = file_data.squeeze(0)    # => (num_epochs_in_file, channels, length)
            file_labels = file_labels.squeeze(0)  # => (num_epochs_in_file,)
            print(f"[{tag}] File {file_idx}/{len(loader)}, epochs={file_data.size(0)}")

            for start, end, sub_data, sub_labels in _iter_chunks(file_data, file_labels, chunk_size):
                sub_data = sub_data.to(device)
                sub_labels = sub_labels.to(device)

                if training:
                    optimizer.zero_grad()
                logits = model(sub_data)
                loss = criterion(logits, sub_labels)
                if training:
                    loss.backward()
                    optimizer.step()

                n = sub_labels.size(0)
                # criterion is assumed to use reduction='mean' (the
                # CrossEntropyLoss default, as built in main). Weight by chunk
                # size so a short final chunk doesn't skew the average.
                loss_sum += loss.item() * n
                n_correct += (logits.argmax(dim=1) == sub_labels).sum().item()
                n_seen += n

                if log_chunks:
                    print(f"   [Chunk] epochs {start}..{end - 1}, shape={sub_data.shape} "
                          f"loss={loss.item():.4f}")

            n_files += 1

    mean_loss = loss_sum / n_seen if n_seen else 0.0
    accuracy = 100.0 * n_correct / n_seen if n_seen else 0.0
    return mean_loss, accuracy, n_files


def main(
    data_root=r"C:\Users\mhfar\OneDrive\Desktop\New folder\20second_MNE_2CLASS",
    n_epochs=10,
    chunk_size=32,
    lr=1e-3,
    log_chunks=True,
):
    """Train LSTMCNNNet on Epilepsy (label 1) vs PNES (label 0) .fif files.

    Expects ``<data_root>/Epilepsy/{train,val}/*.fif`` and
    ``<data_root>/PNES/{train,val}/*.fif``. The defaults reproduce the
    original hard-coded settings, so ``main()`` behaves as before.

    Args:
        data_root: directory containing the ``Epilepsy`` and ``PNES`` folders.
        n_epochs: number of passes over the training files.
        chunk_size: number of EEG epochs per forward pass (bounds memory use).
        lr: Adam learning rate.
        log_chunks: print one line per chunk (very verbose) when True.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _dataset(split):
        # One dataset per split; both classes read from the same split folder.
        return FIFFileDataset(
            epilepsy_dir=os.path.join(data_root, "Epilepsy", split),
            pnes_dir=os.path.join(data_root, "PNES", split),
            expected_channels=22,
            expected_length=6300
        )

    print("[Info] Building dataset...")
    train_dataset = _dataset("train")
    val_dataset = _dataset("val")
    print(f"[Info] #Train files: {len(train_dataset)}, #Val files: {len(val_dataset)}")

    # batch_size=1 => one file (with all its epochs) per iteration.
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    print("[Info] Building LSTM+CNN model...")
    model = LSTMCNNNet(
        input_size=1,         # each channel is a 1D signal
        lstm_hidden_size=32,
        lstm_num_layers=1,
        num_classes=2,
        num_eeg_channels=22
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): FIFFileDataset gives every epoch of a file the same label,
    # and the loader yields one file at a time. Every chunk, and so every
    # optimizer step, therefore contains a single class, which pushes the model
    # toward whichever class the current file has. Consider mixing epochs from
    # several files (of both classes) into each step.
    print("[Info] Starting training...")
    for epoch in range(n_epochs):
        print(f"\n[Epoch {epoch+1}] ============")
        train_loss, train_acc, _ = _run_pass(
            model, train_loader, criterion, device, chunk_size,
            optimizer=optimizer, log_chunks=log_chunks
        )
        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")

        print("[Validate] ----------------")
        val_loss, val_acc, _ = _run_pass(
            model, val_loader, criterion, device, chunk_size,
            optimizer=None, log_chunks=log_chunks
        )
        print(f"[Epoch {epoch+1}] Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    print("[Info] Training complete!")


# Script entry point. Inside a Jupyter/IPython kernel __name__ is also
# "__main__", so executing this cell starts training immediately.
if __name__ == "__main__":
    main()


[Info] Building dataset...
[Info] #Train files: 73, #Val files: 9
[Info] Building LSTM+CNN model...
[Info] Starting training...

[Train] File 1/73, epochs=532
   [Chunk] epochs 0..31, shape=torch.Size([32, 22, 6300]) loss=0.6835
   [Chunk] epochs 32..63, shape=torch.Size([32, 22, 6300]) loss=0.6729
   [Chunk] epochs 64..95, shape=torch.Size([32, 22, 6300]) loss=0.6644
   [Chunk] epochs 96..127, shape=torch.Size([32, 22, 6300]) loss=0.6536
   [Chunk] epochs 128..159, shape=torch.Size([32, 22, 6300]) loss=0.6489
   [Chunk] epochs 160..191, shape=torch.Size([32, 22, 6300]) loss=0.6427
   [Chunk] epochs 192..223, shape=torch.Size([32, 22, 6300]) loss=0.6215
   [Chunk] epochs 224..255, shape=torch.Size([32, 22, 6300]) loss=0.6112
   [Chunk] epochs 256..287, shape=torch.Size([32, 22, 6300]) loss=0.5915
   [Chunk] epochs 288..319, shape=torch.Size([32, 22, 6300]) loss=0.5927
   [Chunk] epochs 320..351, shape=torch.Size([32, 22, 6300]) loss=0.5851
   [Chunk] epochs 352..383, shape=torch.Size([

KeyboardInterrupt: 