# In [4]:  (Jupyter cell prompt kept as a comment so the file parses as Python)
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import mne

# For GCN
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse, to_undirected


class EEGEpochDataset(Dataset):
    """
    In-memory dataset of EEG epochs read from MNE ``.fif`` epoch files.

    Every ``*.fif`` file in ``epilepsy_dir`` is labelled 1 and every ``*.fif``
    file in ``pnes_dir`` is labelled 0. A file may contain several epochs; each
    epoch becomes one sample of shape ``(expected_channels, expected_length)``,
    trimmed or right-zero-padded along the time axis as needed.

    Args:
        epilepsy_dir: Folder with epilepsy epoch files (label 1).
        pnes_dir: Folder with PNES epoch files (label 0).
        expected_channels: Required channel count. A file with any other count
            raises ``ValueError`` while the dataset is being built.
        expected_length: Number of time samples each epoch is trimmed/padded to.
    """
    def __init__(
        self,
        epilepsy_dir: str,
        pnes_dir: str,
        expected_channels=22,
        expected_length=6300
    ):
        super().__init__()
        self.epoch_data = []  # float32 arrays, (expected_channels, expected_length)
        self.labels = []      # int labels aligned with epoch_data
        self.expected_channels = expected_channels
        self.expected_length = expected_length

        # sorted() makes the sample order independent of the filesystem's
        # directory-listing order, so runs are reproducible across machines.
        for folder, label in ((epilepsy_dir, 1), (pnes_dir, 0)):
            for f in sorted(glob.glob(os.path.join(folder, "*.fif"))):
                self._load_epoch_file(f, label=label)

    @staticmethod
    def _fit_length(epoch, length):
        """Trim or right-zero-pad a 2-D (n_channels, n_times) array to ``length`` samples."""
        n_times = epoch.shape[1]
        if n_times > length:
            return epoch[:, :length]
        if n_times < length:
            padded = F.pad(torch.from_numpy(epoch), (0, length - n_times), "constant", 0.0)
            return padded.numpy()
        return epoch

    def _load_epoch_file(self, file_path, label):
        """
        Load every epoch in ``file_path`` and append it together with ``label``.

        Raises:
            ValueError: if the file's channel count differs from ``expected_channels``.
        """
        epochs = mne.read_epochs(file_path, preload=True, verbose=False)
        data = epochs.get_data(copy=True)  # shape (n_epochs, n_channels, n_times)

        if data.shape[1] != self.expected_channels:
            raise ValueError(f"Expected {self.expected_channels} channels, got {data.shape[1]} in {file_path}")

        for single_epoch in data:
            fitted = self._fit_length(single_epoch, self.expected_length)
            # Store as float32 (the dtype __getitem__ returns anyway). This halves
            # memory, and the copy detaches trimmed views from `data` so the
            # whole per-file array can be freed.
            self.epoch_data.append(fitted.astype("float32"))
            self.labels.append(label)

    def __len__(self):
        return len(self.epoch_data)

    def __getitem__(self, idx):
        """Return ``(eeg, label)``: a float32 (n_channels, n_times) tensor and an int64 scalar."""
        # torch.tensor copies, so in-place edits by the caller never touch the
        # stored epoch.
        eeg_tensor = torch.tensor(self.epoch_data[idx], dtype=torch.float32)
        return eeg_tensor, torch.tensor(self.labels[idx], dtype=torch.long)



###############################################################################
# 2) TCN Building Blocks (TemporalBlock, TemporalConvNet)
###############################################################################
class TemporalBlock(nn.Module):
    """
    Residual block of two dilated 1-D convolutions.

    Each convolution is followed by ReLU and dropout. A skip connection is added
    back: identity when ``in_channels == out_channels``, otherwise a 1x1 Conv1d.
    A final ReLU is applied to the sum. The convolution branch is cut to the
    input's time length. With ``padding = (kernel_size - 1) * dilation`` (as
    ``TemporalConvNet`` passes), each kept output sample depends only on the
    current and earlier input samples.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_opts = dict(stride=stride, padding=padding, dilation=dilation)

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_opts)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, **conv_opts)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # Project the skip path only when the channel counts differ.
        if in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = x
        stages = (
            (self.conv1, self.relu1, self.dropout1),
            (self.conv2, self.relu2, self.dropout2),
        )
        for conv, act, drop in stages:
            h = drop(act(conv(h)))

        # Padding makes the conv output longer than x; keep only the leading steps.
        h = h[:, :, :x.size(2)]

        skip = self.downsample(x) if self.downsample is not None else x
        return self.relu(h + skip[:, :, :h.size(2)])


class TemporalConvNet(nn.Module):
    """
    Stack of TemporalBlocks whose dilation doubles at every level.

    Level ``i`` maps the previous width (``num_inputs`` for level 0) to
    ``num_channels[i]`` using dilation ``2**i``, stride 1 and padding
    ``(kernel_size - 1) * 2**i``. Input and output are (batch, channels, time).
    """
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        widths = [num_inputs, *num_channels]
        blocks = []
        for level, (width_in, width_out) in enumerate(zip(widths, widths[1:])):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    width_in, width_out, kernel_size, stride=1,
                    dilation=dilation, padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        return self.network(x)


###############################################################################
# 3) Hybrid TCN + GCN Model
###############################################################################
class EEGEpilepsyNet(nn.Module):
    """
    A hybrid TCN + GCN model for binary classification: Epilepsy vs PNES.

    Every EEG channel is fed separately through one shared TemporalConvNet and
    average-pooled over time, giving one feature vector per channel. The channels
    are then treated as graph nodes and passed through two GCNConv + ReLU layers.
    The nodes are mean-pooled per sample, and a linear layer produces the logits.

    Input shape: (batch_size, num_eeg_channels, seq_len); channels beyond
    ``num_eeg_channels`` are ignored. Output: (batch_size, num_classes) logits.

    Args:
        adjacency: Optional dense (n, n) channel adjacency with
            n <= num_eeg_channels. It is symmetrised, and every non-zero entry of
            the result becomes an unweighted edge. Nodes with index >= n get no
            edges from it. If None, an ``edge_index`` must be passed to ``forward``.
        in_channels: Input channels of the TCN. Each EEG channel is fed as a
            single-channel sequence, so this must be 1.
        tcn_channels: Output widths of the TCN levels (default ``[8, 16]``).
        hidden_dim: Width of both GCN layers.
        num_classes: Number of output logits.
        kernel_size: TCN kernel size.
        dropout: TCN dropout probability.
        num_eeg_channels: EEG channels per sample, i.e. graph nodes per sample.

    Raises:
        ValueError: if ``adjacency`` is not square, or is larger than
            ``num_eeg_channels`` (its node ids would not exist in the graph).
    """
    def __init__(
        self,
        adjacency=None,
        in_channels=1,
        tcn_channels=None,
        hidden_dim=32,
        num_classes=2,
        kernel_size=2,
        dropout=0.2,
        num_eeg_channels=21
    ):
        super().__init__()
        if tcn_channels is None:
            tcn_channels = [8, 16]
        self.num_eeg_channels = num_eeg_channels

        # Convert adjacency to edge_index for PyG
        if adjacency is not None:
            if adjacency.dim() != 2 or adjacency.size(0) != adjacency.size(1):
                raise ValueError(
                    f"adjacency must be a square 2-D matrix, got shape {tuple(adjacency.shape)}"
                )
            if adjacency.size(0) > num_eeg_channels:
                raise ValueError(
                    f"adjacency has {adjacency.size(0)} nodes but num_eeg_channels={num_eeg_channels}"
                )
            adjacency = (adjacency + adjacency.t()) / 2  # new tensor; caller's is untouched
            adjacency[adjacency > 0] = 1
            edge_index, _ = dense_to_sparse(adjacency)
            edge_index = to_undirected(edge_index)
        else:
            edge_index = None
        self.register_buffer("edge_index", edge_index)

        # TCN shared by all channels
        self.tcn = TemporalConvNet(
            num_inputs=in_channels,
            num_channels=tcn_channels,
            kernel_size=kernel_size,
            dropout=dropout
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.cnn_features = tcn_channels[-1]

        # GCN layers
        self.gcn1 = GCNConv(self.cnn_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # Final classifier
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index=None):
        """
        x: (batch_size, >= num_eeg_channels, seq_len). Returns (batch_size, num_classes).

        ``edge_index`` (2, E) overrides the adjacency given at construction. Its
        node ids must be < num_eeg_channels.
        """
        num_nodes = self.num_eeg_channels
        if edge_index is None:
            edge_index = self.edge_index
            if edge_index is None:
                raise ValueError("No adjacency or edge_index provided.")
        elif edge_index.numel() > 0 and int(edge_index.max()) >= num_nodes:
            # Once the per-sample graphs are batched below, such ids would point
            # into the next sample's graph instead of failing, so reject them.
            raise ValueError(
                f"edge_index refers to node {int(edge_index.max())} but the graph has {num_nodes} nodes"
            )
        if x.size(1) < num_nodes:
            raise ValueError(f"Expected at least {num_nodes} channels in x, got {x.size(1)}")

        bsz, seq_len = x.size(0), x.size(-1)

        # One TCN pass over every channel of every sample: (bsz, C, T) -> (bsz*C, 1, T).
        # Row b*C + c holds sample b, channel c.
        per_channel = x[:, :num_nodes, :].reshape(bsz * num_nodes, 1, seq_len)
        node_feats = self.pool(self.tcn(per_channel)).squeeze(-1)  # (bsz*C, cnn_features)

        # Batch the per-sample graphs as one block-diagonal graph (PyG mini-batching):
        # sample b's edges are shifted by b*C, so no edge crosses samples.
        offsets = torch.arange(bsz, device=edge_index.device) * num_nodes
        batch_edge_index = (edge_index.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1)

        g = F.relu(self.gcn1(node_feats, batch_edge_index))  # (bsz*C, hidden_dim)
        g = F.relu(self.gcn2(g, batch_edge_index))           # (bsz*C, hidden_dim)

        graph_embs = g.reshape(bsz, num_nodes, -1).mean(dim=1)  # (bsz, hidden_dim)
        return self.fc(graph_embs)


###############################################################################
# 4) Training / validation entry point
###############################################################################
def main(
    data_root=r"C:\Users\mhfar\OneDrive\Desktop\New folder\20second_MNE_2CLASS",
    num_epochs=10,
):
    """
    Train and validate EEGEpilepsyNet on epoched Epilepsy/PNES EEG.

    Expects the layout ``<data_root>/{Epilepsy,PNES}/{train,val}/*.fif``.

    Args:
        data_root: Folder that contains the ``Epilepsy`` and ``PNES`` class folders.
        num_epochs: Number of training epochs.

    Returns:
        The trained model.

    Raises:
        FileNotFoundError: if no training epochs are found under ``data_root``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Single source of truth for the montage size. The dataset's channel check,
    # the adjacency and the model's node count must all agree.
    n_channels = 22
    seq_len = 6300

    def make_dataset(split):
        return EEGEpochDataset(
            epilepsy_dir=os.path.join(data_root, "Epilepsy", split),
            pnes_dir=os.path.join(data_root, "PNES", split),
            expected_channels=n_channels,
            expected_length=seq_len,
        )

    train_dataset = make_dataset("train")
    val_dataset = make_dataset("val")
    if len(train_dataset) == 0:
        raise FileNotFoundError(f"No training epochs found under {data_root}")

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # Dummy random adjacency. It uses a fixed-seed generator so the graph is the
    # same on every run. TODO: replace with an adjacency derived from the montage.
    adjacency = torch.randint(
        0, 2, (n_channels, n_channels), generator=torch.Generator().manual_seed(0)
    ).float()

    model = EEGEpilepsyNet(
        adjacency=adjacency,
        in_channels=1,
        tcn_channels=[8, 16],
        hidden_dim=32,
        num_classes=2,
        kernel_size=2,
        dropout=0.2,
        num_eeg_channels=n_channels,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # Training: losses are weighted by batch size, so the smaller last batch
        # does not skew the per-sample average.
        model.train()
        train_loss_sum, train_count = 0.0, 0
        for batch_eeg, batch_labels in train_loader:
            batch_eeg, batch_labels = batch_eeg.to(device), batch_labels.to(device)
            optimizer.zero_grad()
            logits = model(batch_eeg)
            loss = criterion(logits, batch_labels)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item() * batch_labels.size(0)
            train_count += batch_labels.size(0)

        print(f"[Epoch {epoch}] Train Loss: {train_loss_sum / train_count:.4f}")

        # Validation
        model.eval()
        val_loss_sum, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for batch_eeg, batch_labels in val_loader:
                batch_eeg, batch_labels = batch_eeg.to(device), batch_labels.to(device)
                logits = model(batch_eeg)
                loss = criterion(logits, batch_labels)
                val_loss_sum += loss.item() * batch_labels.size(0)

                preds = torch.argmax(logits, dim=1)
                correct += (preds == batch_labels).sum().item()
                total += batch_labels.size(0)

        if total == 0:
            print("    Val: no validation samples found, skipped.")
            continue
        avg_val_loss = val_loss_sum / total
        accuracy = 100.0 * correct / total
        print(f"    Val Loss: {avg_val_loss:.4f},  Val Acc: {accuracy:.2f}%")

    return model

# Entry point: train and validate when this code is executed directly;
# importing the module only defines the classes and main().
if __name__ == "__main__":
    main()


# KeyboardInterrupt:  (notebook cell output residue, kept as a comment so the file parses as Python)