# In [None]:  (Jupyter cell marker left over from notebook export)
import torch.backends.cudnn as cudnn
# cudnn.enabled = False

import os
import glob
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse, to_undirected


###############################################################################
# 1) Dataset that loads ALL epochs from a single .fif file
#    Returns only (data_torch, labels_torch) to avoid the string/tuple confusion
###############################################################################
class FIFFileDataset(Dataset):
    """
    File-level dataset: one item is the full set of epochs stored in one .fif file.

    Every ``*.fif`` file directly inside ``epilepsy_dir`` is labelled 1 and
    every ``*.fif`` file directly inside ``pnes_dir`` is labelled 0. Files are
    indexed in sorted path order (epilepsy files first, then PNES files) so
    that index ``i`` refers to the same file on every run and every machine.

    Returns (data_torch, labels_torch) for a single .fif file:
      - data_torch: (num_epochs, n_channels, expected_length), float32
      - labels_torch: (num_epochs,), int64, every entry equal to the file label

    Args:
        epilepsy_dir: Directory holding the epilepsy (label 1) .fif files.
        pnes_dir: Directory holding the PNES (label 0) .fif files.
        expected_channels: Required size of the channel axis; a file with a
            different channel count raises ``ValueError`` when it is loaded.
        expected_length: Number of time samples every epoch is trimmed or
            zero-padded to.
    """
    def __init__(
        self,
        epilepsy_dir: str,
        pnes_dir: str,
        expected_channels=22,
        expected_length=6300
    ):
        super().__init__()

        self.expected_channels = expected_channels
        self.expected_length = expected_length

        # (file_path_str, label_int) pairs. glob.glob() returns matches in
        # arbitrary (filesystem-dependent) order, so each directory listing is
        # sorted to keep the index -> file mapping reproducible.
        self.file_info = []
        for directory, label in ((epilepsy_dir, 1), (pnes_dir, 0)):
            for path in sorted(glob.glob(os.path.join(directory, "*.fif"))):
                self.file_info.append((path, label))

    def __len__(self):
        """Return the number of .fif files (not the number of epochs)."""
        return len(self.file_info)

    def __getitem__(self, idx):
        """
        Load every epoch of file ``idx`` and return ``(data_torch, labels_torch)``.

        Raises:
            ValueError: If the file contains no epochs, or if its channel count
                differs from ``expected_channels``.
        """
        # Here, file_path is a string, label is an int
        file_path, label = self.file_info[idx]

        # Load epochs from this file
        # (Keeping prints minimal)
        epochs = mne.read_epochs(file_path, preload=True, verbose=False)
        data = epochs.get_data(copy=True)  # => (num_epochs, n_channels, n_times)

        if data.shape[0] == 0:
            raise ValueError(f"No epochs found in file: {file_path}")
        if data.shape[1] != self.expected_channels:
            raise ValueError(
                f"Expected {self.expected_channels} channels, got {data.shape[1]} in {file_path}"
            )

        # Convert to torch (float32)
        data_torch = torch.from_numpy(data).float()  # => (num_epochs, n_channels, n_times)

        # Force a fixed time length: drop trailing samples, or zero-pad at the end
        current_len = data_torch.size(-1)
        if current_len > self.expected_length:
            data_torch = data_torch[:, :, :self.expected_length]
        elif current_len < self.expected_length:
            pad_amt = self.expected_length - current_len
            data_torch = F.pad(data_torch, (0, pad_amt), "constant", 0.0)

        # Every epoch in a file shares that file's class label
        num_epochs = data_torch.size(0)
        labels_torch = torch.full((num_epochs,), label, dtype=torch.long)

        return data_torch, labels_torch


###############################################################################
# 2) TCN building blocks
###############################################################################
class TemporalBlock(nn.Module):
    """
    Residual block of two dilated 1-D convolutions.

    Each convolution is followed by ReLU and dropout. The input is added back
    as a shortcut (through a 1x1 convolution when the channel counts differ)
    and the sum goes through a final ReLU.

    Input:  (batch, in_channels, length)
    Output: (batch, out_channels, length'), where length' is at most the
            input length (see ``forward``).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        # Both convolutions share the same geometry.
        conv_geometry = dict(stride=stride, padding=padding, dilation=dilation)

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_geometry)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, **conv_geometry)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        # A 1x1 projection is only needed when the shortcut must change width.
        if in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> ReLU -> dropout twice, add the shortcut, then ReLU."""
        input_len = x.size(2)

        hidden = self.dropout1(self.relu1(self.conv1(x)))
        hidden = self.dropout2(self.relu2(self.conv2(hidden)))

        # Padding can make the conv output longer than the input; keep only
        # the leading input_len steps (slicing is a no-op if already shorter).
        hidden = hidden[:, :, :input_len]

        shortcut = x if self.downsample is None else self.downsample(x)
        # Conversely, the shortcut may be longer than the conv branch.
        shortcut = shortcut[:, :, :hidden.size(2)]

        return self.relu(hidden + shortcut)


class TemporalConvNet(nn.Module):
    """
    Stack of TemporalBlocks whose dilation doubles at every level.

    Level ``i`` maps ``num_channels[i - 1]`` (``num_inputs`` for the first
    level) to ``num_channels[i]`` feature maps with dilation ``2 ** i`` and
    padding ``(kernel_size - 1) * 2 ** i``.
    """
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        # widths[i] -> widths[i + 1] are the in/out channel counts of level i.
        widths = [num_inputs, *num_channels]
        blocks = [
            TemporalBlock(
                widths[level], widths[level + 1], kernel_size,
                stride=1,
                dilation=2 ** level,
                padding=(kernel_size - 1) * 2 ** level,
                dropout=dropout,
            )
            for level in range(len(num_channels))
        ]
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        return self.network(x)


###############################################################################
# 3) Hybrid TCN + GCN Model
###############################################################################
class EEGEpilepsyNet(nn.Module):
    """
    TCN + GCN for EEG classification.
    Input shape: (batch_size, num_eeg_channels, seq_len).

    Pipeline:
      1. Each EEG channel is treated as an independent single-feature sequence
         and passed through one shared TCN, then average-pooled over time
         => one feature vector per channel.
      2. The channels become graph nodes; two GCN layers (each followed by
         ReLU) mix information along ``edge_index``.
      3. Node embeddings are averaged into one graph embedding per sample and
         fed to a linear classifier.

    Args:
        adjacency: Optional square (num_eeg_channels x num_eeg_channels) tensor.
            It is symmetrised and binarised (any positive entry becomes an
            edge) and stored as the ``edge_index`` buffer. The caller's tensor
            is not modified. If omitted, ``edge_index`` must be passed to
            ``forward``.
        in_channels: Number of TCN input features. NOTE: ``forward`` always
            feeds the TCN one feature per EEG channel, so only 1 works.
        tcn_channels: Output width of each TCN level (any sequence of ints).
        hidden_dim: Width of both GCN layers.
        num_classes: Number of output logits.
        kernel_size: TCN convolution kernel size.
        dropout: Dropout probability inside the TCN.
        num_eeg_channels: Number of EEG channels (graph nodes) read from the input.
    """
    def __init__(
        self,
        adjacency=None,
        in_channels=1,
        tcn_channels=(8, 16),
        hidden_dim=32,
        num_classes=2,
        kernel_size=2,
        dropout=0.2,
        num_eeg_channels=22
    ):
        super().__init__()
        self.num_eeg_channels = num_eeg_channels
        # Immutable default in the signature; work on a private list copy.
        tcn_channels = list(tcn_channels)

        # Convert adjacency => edge_index for PyG
        if adjacency is not None:
            # The sum creates a new tensor, so the in-place binarisation below
            # does not touch the caller's adjacency.
            adjacency = (adjacency + adjacency.t()) / 2
            adjacency[adjacency > 0] = 1
            edge_index, _ = dense_to_sparse(adjacency)
            edge_index = to_undirected(edge_index)
        else:
            edge_index = None
        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer("edge_index", edge_index)

        # TCN part
        self.tcn = TemporalConvNet(
            num_inputs=in_channels,
            num_channels=tcn_channels,
            kernel_size=kernel_size,
            dropout=dropout
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.cnn_features = tcn_channels[-1]

        # GCN part
        self.gcn1 = GCNConv(self.cnn_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # Final classifier
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index=None):
        """
        x: (batch_size, num_eeg_channels, seq_len)
        edge_index: optional graph connectivity overriding the stored buffer.

        Returns logits of shape (batch_size, num_classes).

        Raises:
            ValueError: If no edge_index is available, or if ``x`` is not 3-D
                or has fewer than ``num_eeg_channels`` channels.
        """
        if edge_index is None:
            edge_index = self.edge_index
            if edge_index is None:
                raise ValueError("No adjacency or edge_index provided.")

        n_nodes = self.num_eeg_channels
        if x.dim() != 3 or x.size(1) < n_nodes:
            raise ValueError(
                f"Expected input of shape (batch, >={n_nodes}, seq_len), got {tuple(x.shape)}"
            )
        bsz = x.size(0)

        # TCN feature extraction per channel. The TCN is shared and treats
        # every channel independently, so the channels are folded into the
        # batch dimension and processed in ONE call instead of one call per
        # channel. Row b * n_nodes + c of the folded batch is sample b,
        # channel c. Only the first n_nodes channels are used.
        # (Under no_grad this holds all channels' activations at once; lower
        # the caller's batch size if that is too much memory.)
        folded = x[:, :n_nodes, :].reshape(bsz * n_nodes, 1, x.size(2))
        feats = self.pool(self.tcn(folded)).squeeze(-1)   # => (bsz * n_nodes, cnn_features)

        # Unfold => (bsz, num_eeg_channels, cnn_features)
        channel_feats = feats.reshape(bsz, n_nodes, self.cnn_features)

        # GCN over channels, one sample (graph) at a time
        graph_embs = []
        for node_feats in channel_feats:               # => (num_eeg_channels, cnn_features)
            g = self.gcn1(node_feats, edge_index)      # => (num_eeg_channels, hidden_dim)
            g = F.relu(g)
            g = self.gcn2(g, edge_index)               # => (num_eeg_channels, hidden_dim)
            g = F.relu(g)
            graph_embs.append(g.mean(dim=0))           # => (hidden_dim,)

        graph_embs = torch.stack(graph_embs, dim=0)     # => (bsz, hidden_dim)
        return self.fc(graph_embs)


###############################################################################
# 4) Main with CHUNKING in the training loop (less verbose, no basename calls)
###############################################################################
def _iter_chunks(file_data, file_labels, chunk_size, device):
    """Yield (start, end, data, labels) sub-batches of one file, moved to ``device``.

    ``end`` is exclusive and clamped to the number of epochs in the file, so
    the final chunk may be shorter than ``chunk_size``.
    """
    num_epochs_in_file = file_data.size(0)
    for start in range(0, num_epochs_in_file, chunk_size):
        end = min(start + chunk_size, num_epochs_in_file)
        yield (
            start,
            end,
            file_data[start:end].to(device),
            file_labels[start:end].to(device),
        )


def main(
    data_root=r"C:\Users\mhfar\OneDrive\Desktop\New folder\20second_MNE_2CLASS",
    n_epochs=10,
    chunk_size=16,
    lr=1e-3,
):
    """
    Train and validate EEGEpilepsyNet on per-file .fif epoch data.

    Each DataLoader item is one whole file; the file is then processed in
    sub-batches of ``chunk_size`` epochs to bound memory use. One optimizer
    step is taken per chunk.

    Args:
        data_root: Directory containing ``Epilepsy/{train,val}`` and
            ``PNES/{train,val}`` sub-directories of .fif files.
        n_epochs: Number of passes over the training files.
        chunk_size: Number of EEG epochs per forward pass.
        lr: Adam learning rate.

    The reported train/val losses are per-sample averages (each chunk's mean
    loss is weighted by the number of samples in the chunk).
    """
    num_channels = 22
    seq_len = 6300
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("[Info] Building dataset...")
    train_dataset = FIFFileDataset(
        epilepsy_dir=os.path.join(data_root, "Epilepsy", "train"),
        pnes_dir=os.path.join(data_root, "PNES", "train"),
        expected_channels=num_channels,
        expected_length=seq_len
    )
    val_dataset = FIFFileDataset(
        epilepsy_dir=os.path.join(data_root, "Epilepsy", "val"),
        pnes_dir=os.path.join(data_root, "PNES", "val"),
        expected_channels=num_channels,
        expected_length=seq_len
    )
    print(f"[Info] #Train files: {len(train_dataset)}, #Val files: {len(val_dataset)}")

    # batch_size=1 => one file per iteration (files hold different numbers of
    # epochs, so they cannot be stacked by the default collate function)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Dummy adjacency: a RANDOM 0/1 matrix, i.e. a placeholder graph that
    # differs on every run. Replace with a real electrode-connectivity matrix.
    adjacency = torch.randint(0, 2, (num_channels, num_channels)).float()

    print("[Info] Building model...")
    model = EEGEpilepsyNet(
        adjacency=adjacency,
        in_channels=1,
        tcn_channels=[8, 16],
        hidden_dim=32,
        num_classes=2,
        kernel_size=2,
        dropout=0.2,
        num_eeg_channels=num_channels
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    print("[Info] Starting training...")
    for epoch in range(n_epochs):
        print(f"\n[Epoch {epoch+1}] ============")
        model.train()
        train_loss_sum = 0.0   # sum of per-sample losses
        train_samples = 0

        for batch_idx, (file_data, file_labels) in enumerate(train_loader, start=1):
            # Drop the DataLoader's leading batch dim of size 1:
            # file_data => (num_epochs_in_file, channels, length)
            # file_labels => (num_epochs_in_file,)
            file_data = file_data.squeeze(0)
            file_labels = file_labels.squeeze(0)

            num_epochs_in_file = file_data.size(0)
            print(f"[Train] File {batch_idx}/{len(train_loader)}, epochs={num_epochs_in_file}")

            # Process the file in sub-batches (to avoid OOM)
            for start, end, sub_data, sub_labels in _iter_chunks(
                file_data, file_labels, chunk_size, device
            ):
                print(f"   [Chunk] epochs {start}..{end-1}, shape={sub_data.shape}", end=" ")

                optimizer.zero_grad()
                logits = model(sub_data)
                loss = criterion(logits, sub_labels)
                loss.backward()
                optimizer.step()

                # criterion returns the chunk MEAN; weight it by chunk size so
                # short final chunks and long files do not skew the average.
                loss_value = loss.item()
                chunk_samples = sub_labels.size(0)
                train_loss_sum += loss_value * chunk_samples
                train_samples += chunk_samples
                print(f"loss={loss_value:.4f}")

        avg_train_loss = train_loss_sum / train_samples if train_samples > 0 else 0.0
        print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss_sum, correct, total = 0.0, 0, 0

        print("[Validate] ----------------")
        with torch.no_grad():
            for batch_idx, (file_data, file_labels) in enumerate(val_loader, start=1):
                file_data = file_data.squeeze(0)
                file_labels = file_labels.squeeze(0)

                num_epochs_in_file = file_data.size(0)
                print(f"[Val] File {batch_idx}/{len(val_loader)}, epochs={num_epochs_in_file}")

                # Evaluate in chunks
                for start, end, sub_data, sub_labels in _iter_chunks(
                    file_data, file_labels, chunk_size, device
                ):
                    print(f"   [Chunk] epochs {start}..{end-1}, shape={sub_data.shape}", end=" ")

                    logits = model(sub_data)
                    loss_value = criterion(logits, sub_labels).item()
                    chunk_samples = sub_labels.size(0)

                    val_loss_sum += loss_value * chunk_samples
                    preds = torch.argmax(logits, dim=1)
                    correct += (preds == sub_labels).sum().item()
                    total += chunk_samples
                    print(f"loss={loss_value:.4f}")

        if total > 0:
            avg_val_loss = val_loss_sum / total
            accuracy = 100.0 * correct / total
        else:
            avg_val_loss = 0.0
            accuracy = 0.0
        print(f"[Epoch {epoch+1}] Val Loss: {avg_val_loss:.4f}, Val Acc: {accuracy:.2f}%")

    print("[Info] Training complete!")


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
