In [1]:
import os
import scipy.io as sio
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn as nn
import torch.optim as optim

import os
import torch
# import scipy.io as sio
from torch.utils.data import Dataset
from collections import defaultdict
import numpy as np
class MatFileDataset(Dataset):
    """Multi-channel EEG windows assembled from per-channel ``.mat`` files.

    Each ``*.mat`` file in ``directory`` contributes one channel; the channel name is
    taken from the second-to-last ``_``-separated token of the file name. A file must
    contain an ``all_window_features`` variable: a sequence of participants, each a
    sequence of window structs read for ``before_label``, ``after_label``,
    ``raw_window_signal`` and ``Dq``.

    Windows sharing the same (participant index, window index) across files are merged
    into one sample. A sample is kept only when:
      * both labels exist and lie in 0..3,
      * exactly ``num_channels`` distinct channels contributed to it, and
      * every channel has a non-empty ``Dq`` and signal, with matching shapes.

    Item ``i`` is ``(x, y)`` with ``x = {'fractal': (C, F), 'signal': (C, S)}`` (channels
    stacked in sorted channel-name order) and ``y`` holding ``before_label`` /
    ``after_label`` as ``torch.long`` scalars.

    Args:
        directory: Folder containing the ``.mat`` files.
        num_channels: Number of distinct channels a window needs to be kept (default 4).
    """

    def __init__(self, directory, num_channels=4):
        self.features = []  # per-sample fractal (Dq) tensors, shape (num_channels, F)
        self.signals = []   # per-sample raw-signal tensors, shape (num_channels, S)
        self.labels = []    # per-sample (before_label, after_label) int tuples
        self.fractal_feature_length = None  # F of the first kept sample
        self.num_channels = num_channels
        self._load_data(directory)
        self._validate_labels()

    def _load_data(self, directory):
        """Read every ``.mat`` file in ``directory`` and build the sample lists."""
        # (participant_index, window_index) -> {"channels", "fractal", "signal", "label"}
        grouped_data = defaultdict(lambda: defaultdict(list))

        # sorted(): os.listdir order is arbitrary, and it decides the sample order
        # (and therefore which samples a later random split puts in each subset).
        for filename in sorted(os.listdir(directory)):
            if not filename.endswith(".mat"):
                continue
            filepath = os.path.join(directory, filename)
            mat_data = sio.loadmat(filepath, struct_as_record=False, squeeze_me=True)

            all_window_features = mat_data.get("all_window_features")
            if all_window_features is None:
                continue

            # Assumes the channel name (e.g. "TP10") is the second-to-last "_" token.
            channel_name = filename.split('_')[-2]
            self._collect_windows(all_window_features, channel_name, grouped_data)

        self._combine_groups(grouped_data)

    @staticmethod
    def _collect_windows(all_window_features, channel_name, grouped_data):
        """Append one channel's labelled windows to ``grouped_data``."""
        for participant_index, participant_data in enumerate(all_window_features):
            if participant_data is None:
                continue

            # With squeeze_me=True a participant holding a single window loads as a bare
            # struct rather than an array; atleast_1d makes it iterable either way.
            for window_index, window in enumerate(np.atleast_1d(participant_data)):
                before_label = getattr(window, "before_label", None)
                after_label = getattr(window, "after_label", None)

                # Both labels must exist and lie in 0..3.
                if (before_label is None
                        or after_label is None
                        or not (0 <= before_label <= 3)
                        or not (0 <= after_label <= 3)):
                    continue

                raw_signal = getattr(window, "raw_window_signal", None)
                dq = getattr(window, "Dq", None)
                # A missing field becomes an empty tensor; _combine_groups drops those windows.
                raw_signals = raw_signal.flatten() if raw_signal is not None else []
                fractal_features = dq.flatten() if dq is not None else []

                group = grouped_data[(participant_index, window_index)]
                group["channels"].append(channel_name)
                group["fractal"].append(torch.tensor(fractal_features, dtype=torch.float32))
                group["signal"].append(torch.tensor(raw_signals, dtype=torch.float32))
                group["label"] = (int(before_label), int(after_label))

    def _combine_groups(self, grouped_data):
        """Stack each complete window's per-channel tensors into one sample."""
        for data in grouped_data.values():
            channels = data["channels"]
            # Exactly num_channels entries from num_channels *distinct* channels: a
            # duplicated channel file must not stand in for a missing channel.
            if len(channels) != self.num_channels or len(set(channels)) != self.num_channels:
                continue

            # Stack in channel-name order so row i is the same channel in every sample,
            # independent of the order files were read.
            order = sorted(range(len(channels)), key=channels.__getitem__)
            fractal_parts = [data["fractal"][i] for i in order]
            signal_parts = [data["signal"][i] for i in order]
            # Empty or ragged channels would make torch.stack raise; skip the window.
            if not (self._stackable(fractal_parts) and self._stackable(signal_parts)):
                continue

            fractal_features = torch.stack(fractal_parts, dim=0)  # (num_channels, F)
            raw_signals = torch.stack(signal_parts, dim=0)        # (num_channels, S)

            if self.fractal_feature_length is None:
                self.fractal_feature_length = fractal_features.shape[1]

            self.features.append(fractal_features)
            self.signals.append(raw_signals)
            self.labels.append(data["label"])

    @staticmethod
    def _stackable(tensors):
        """True when every tensor is non-empty and all tensors share one shape."""
        return all(t.numel() > 0 for t in tensors) and len({t.shape for t in tensors}) == 1

    def _validate_labels(self):
        """Ensure all labels are valid integers 0-3"""
        valid_before = all(0 <= lbl[0] <= 3 for lbl in self.labels)
        valid_after = all(0 <= lbl[1] <= 3 for lbl in self.labels)
        if not (valid_before and valid_after):
            invalid = [
                (i, lbl) for i, lbl in enumerate(self.labels)
                if not (0 <= lbl[0] <= 3 and 0 <= lbl[1] <= 3)
            ]
            raise ValueError(f"Invalid labels found at indices: {invalid[:10]}")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = {
            'fractal': self.features[idx],  # Shape: (num_channels, fractal_feature_length)
            'signal': self.signals[idx],    # Shape: (num_channels, signal_length)
        }
        y = {
            "before_label": torch.tensor(self.labels[idx][0], dtype=torch.long),
            "after_label": torch.tensor(self.labels[idx][1], dtype=torch.long),
        }
        return x, y


In [3]:
# Folder holding the per-channel *.mat feature files. Set the EEG_DATA_DIR environment
# variable to point elsewhere instead of editing a machine-specific absolute path.
DATA_DIR = os.environ.get(
    "EEG_DATA_DIR",
    "/Users/athenasaghi/VSProjects/CognitiveFatigueDetection/Prediction/",
)
dataset = MatFileDataset(DATA_DIR)
print(f"Loaded {len(dataset)} valid samples")

Loaded 28908 valid samples


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EEGTransformerClassifier(nn.Module):
    """Transformer-encoder classifier with two output heads.

    The input is channels-first, ``(batch, input_channels, seq_length)``. Every time
    step's channel vector is linearly projected to ``d_model``, and a learnable
    positional embedding (zero-initialised) is added. The sequence then runs through
    a stack of ``nn.TransformerEncoderLayer`` (sequence-first layout), and the
    flattened encoding feeds a small MLP that emits ``2 * num_classes`` logits.

    Returns:
        Logits of shape ``(batch, 2, num_classes)``, one row per head.
    """

    def __init__(self, input_channels, seq_length, num_classes, num_heads=4, d_model=64, num_layers=2, dim_feedforward=128, dropout=0.1):
        super().__init__()
        # Construction order matches the original so seeded initialisation is identical.
        self.embedding = nn.Linear(input_channels, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, seq_length, d_model))
        self.num_classes = num_classes

        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)

        hidden_units = 128
        flat_features = seq_length * d_model
        self.fc = nn.Sequential(
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, num_classes * 2),
        )

    def forward(self, x):
        steps = x.transpose(1, 2)  # (batch, seq_length, input_channels)
        steps = self.embedding(steps) + self.positional_encoding[:, :steps.size(1), :]
        # The encoder layers are sequence-first: (seq_length, batch, d_model).
        encoded = self.transformer_encoder(steps.transpose(0, 1)).transpose(0, 1)
        logits = self.fc(encoded.flatten(1))  # (batch, 2 * num_classes)
        return logits.view(-1, 2, self.num_classes)

In [20]:
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# batch_size = 32
# input_channels = 4
# seq_length = 41
# num_classes = 4
# learning_rate = 0.001

# num_epochs = 10

# criterion = nn.CrossEntropyLoss()
# model = EEGTransformerClassifier(input_channels, seq_length, num_classes)
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# for epoch in range(num_epochs):

#     for batch in dataloader:
#         x, y = batch
#         # print(x['fractal'].shape, x['signal'].shape)
#         # print(y['before_label'].shape, y['after_label'].shape)
#         if np.isinf(x['fractal']).any():
#             max_finite = torch.max(x['fractal'][~torch.isinf(x['fractal'])])
#             x['fractal'][torch.isinf(x['fractal'])] = max_finite
#         if np.isinf(x['fractal']).any():
#             raise ValueError("Infinite values found in fractal features")
#         if np.isnan(x['fractal']).any():
#             mean_finite = torch.mean(x['fractal'][~torch.isnan(x['fractal'])])
#             x['fractal'][torch.isnan(x['fractal'])] = mean_finite
        

#         inputs = x['fractal']
#         outputs = model(inputs)
#         loss_1 = criterion(outputs[:, 0, :], y['before_label'])
#         loss_2 = criterion(outputs[:, 1, :], y['after_label'])
#         loss = loss_1 + loss_2
#         loss.backward()
#         optimizer.step()

#         with torch.no_grad():
#             pred_1 = outputs[:, 0, :].argmax(dim=1)
#             pred_2 = outputs[:, 1, :].argmax(dim=1)
#             acc_1 = (pred_1 == y['before_label']).float().mean().item()
#             acc_2 = (pred_2 == y['after_label']).float().mean().item()

#     print(f"Epoch {epoch + 1}/{num_epochs}, Loss_1: {loss_1.item():.4f}, Loss_2: {loss_2.item():.4f}, Total Loss: {loss.item():.4f}, Acc_1: {acc_1:.4f}, Acc_2: {acc_2:.4f}")


In [21]:
# from torch.utils.data import DataLoader, random_split
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.optim as optim

# # Dataset split
# train_size = int(0.7 * len(dataset))
# val_size = int(0.15 * len(dataset))
# test_size = len(dataset) - train_size - val_size
# train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# batch_size = 16


# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# input_channels = 4
# seq_length = 41
# num_classes = 4
# learning_rate = 0.01

# num_epochs = 10

# criterion = nn.CrossEntropyLoss()
# model = EEGTransformerClassifier(input_channels, seq_length, num_classes)
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# def process_batch(loader, model, criterion, optimizer=None, train_mode=False):
#     total_loss = 0
#     total_acc_1 = 0
#     total_acc_2 = 0
#     if train_mode:
#         model.train()
#     else:
#         model.eval()

#     with torch.set_grad_enabled(train_mode):
#         for batch in loader:
#             x, y = batch
#             if np.isinf(x['fractal']).any():
#                 max_finite = torch.max(x['fractal'][~torch.isinf(x['fractal'])])
#                 x['fractal'][torch.isinf(x['fractal'])] = max_finite
#             if np.isnan(x['fractal']).any():
#                 mean_finite = torch.mean(x['fractal'][~torch.isnan(x['fractal'])])
#                 x['fractal'][torch.isnan(x['fractal'])] = mean_finite

#             inputs = x['fractal']
#             outputs = model(inputs)
#             loss_1 = criterion(outputs[:, 0, :], y['before_label'])
#             loss_2 = criterion(outputs[:, 1, :], y['after_label'])
#             loss = loss_1 + loss_2

#             if train_mode:
#                 optimizer.zero_grad()
#                 loss.backward()
#                 optimizer.step()

#             total_loss += loss.item()
#             pred_1 = outputs[:, 0, :].argmax(dim=1)
#             pred_2 = outputs[:, 1, :].argmax(dim=1)
#             total_acc_1 += (pred_1 == y['before_label']).float().mean().item()
#             total_acc_2 += (pred_2 == y['after_label']).float().mean().item()

#     avg_loss = total_loss / len(loader)
#     avg_acc_1 = total_acc_1 / len(loader)
#     avg_acc_2 = total_acc_2 / len(loader)

#     return avg_loss, avg_acc_1, avg_acc_2

# for epoch in range(num_epochs):
#     train_loss, train_acc_1, train_acc_2 = process_batch(train_loader, model, criterion, optimizer, train_mode=True)
#     val_loss, val_acc_1, val_acc_2 = process_batch(val_loader, model, criterion, train_mode=False)

#     print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc_1: {train_acc_1:.4f}, Train Acc_2: {train_acc_2:.4f}, Val Acc_1: {val_acc_1:.4f}, Val Acc_2: {val_acc_2:.4f}")

# test_loss, test_acc_1, test_acc_2 = process_batch(test_loader, model, criterion, train_mode=False)

# print(f"Test Loss: {test_loss:.4f}, Test Acc_1: {test_acc_1:.4f}, Test Acc_2: {test_acc_2:.4f}")


In [26]:
from torch.utils.data import DataLoader, random_split
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Reproducibility: seed model initialisation / loader shuffling and make the
# train/val/test split deterministic so "Restart & Run All" gives comparable numbers.
SEED = 42
torch.manual_seed(SEED)

# 70% / 15% / 15% split; the test subset absorbs the rounding remainder.
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Take the model dimensions from the data instead of hard-coding them: each sample's
# 'fractal' tensor is (channels, feature_length) = (input_channels, seq_length).
sample_inputs, _ = dataset[0]
input_channels, seq_length = sample_inputs["fractal"].shape
num_classes = 4  # MatFileDataset keeps only labels in 0..3
learning_rate = 0.001

num_epochs = 10

criterion = nn.CrossEntropyLoss()
model = EEGTransformerClassifier(input_channels, seq_length, num_classes)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def process_batch(loader, model, criterion, optimizer=None, train_mode=False):
    total_loss = 0
    total_acc = 0
    if train_mode:
        model.train()
    else:
        model.eval()

    with torch.set_grad_enabled(train_mode):
        for batch in loader:
            x, y = batch
            label = y['before_label']
            if np.isinf(x['fractal']).any():
                max_finite = torch.max(x['fractal'][~torch.isinf(x['fractal'])])
                x['fractal'][torch.isinf(x['fractal'])] = max_finite
            if np.isnan(x['fractal']).any():
                mean_finite = torch.mean(x['fractal'][~torch.isnan(x['fractal'])])
                x['fractal'][torch.isnan(x['fractal'])] = mean_finite

            inputs = x['fractal']
            outputs = model(inputs)
            loss = criterion(outputs[:, 0, :], label)

            if train_mode:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            preds = outputs[:, 0, :].argmax(dim=1)
            total_acc += (preds == label).float().mean().item()

    avg_loss = total_loss / len(loader)
    avg_acc = total_acc / len(loader)

    return avg_loss, avg_acc

# Train for num_epochs epochs, reporting train/validation metrics after each one,
# then score the final model once on the held-out test subset.
for epoch in range(num_epochs):
    train_loss, train_acc = process_batch(
        train_loader, model, criterion, optimizer, train_mode=True
    )
    val_loss, val_acc = process_batch(
        val_loader, model, criterion, train_mode=False
    )
    print(
        f"Epoch {epoch + 1}/{num_epochs}, "
        f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
        f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
    )

test_loss, test_acc = process_batch(
    test_loader, model, criterion, train_mode=False
)
print(
    f"Test Loss: {test_loss:.4f}, "
    f"Test Acc: {test_acc:.4f}"
)


Epoch 1/10, Train Loss: 0.9862, Val Loss: 0.9859, Train Acc: 0.6297, Val Acc: 0.6264
Epoch 2/10, Train Loss: 0.9733, Val Loss: 0.9722, Train Acc: 0.6317, Val Acc: 0.6264
Epoch 3/10, Train Loss: 0.9710, Val Loss: 0.9655, Train Acc: 0.6317, Val Acc: 0.6264
Epoch 4/10, Train Loss: 0.9703, Val Loss: 0.9695, Train Acc: 0.6317, Val Acc: 0.6264
Epoch 5/10, Train Loss: 0.9689, Val Loss: 0.9653, Train Acc: 0.6316, Val Acc: 0.6264
Epoch 6/10, Train Loss: 0.9698, Val Loss: 0.9647, Train Acc: 0.6315, Val Acc: 0.6264
Epoch 7/10, Train Loss: 0.9675, Val Loss: 0.9688, Train Acc: 0.6315, Val Acc: 0.6264
Epoch 8/10, Train Loss: 0.9663, Val Loss: 0.9681, Train Acc: 0.6318, Val Acc: 0.6264
Epoch 9/10, Train Loss: 0.9659, Val Loss: 0.9610, Train Acc: 0.6314, Val Acc: 0.6264
Epoch 10/10, Train Loss: 0.9641, Val Loss: 0.9665, Train Acc: 0.6319, Val Acc: 0.6264
Test Loss: 0.9515, Test Acc: 0.6404
