In [1]:
import os
import scipy.io as sio
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn as nn
import torch.optim as optim

import os
import torch
# import scipy.io as sio
from torch.utils.data import Dataset
from collections import defaultdict
import numpy as np
class MatFileDataset(Dataset):
    def __init__(self, directory):
        self.features = []  # To store 4-channel fractal features
        self.signals = []   # To store 4-channel raw signals
        self.labels = []    # To store labels
        self.fractal_feature_length = None
        self._load_data(directory)
        self._validate_labels()

    def _load_data(self, directory):
        # Temporary storage for grouping channels
        grouped_data = defaultdict(lambda: defaultdict(list))

        # Step 1: Read all files and group by participant and window
        for filename in os.listdir(directory):
            if filename.endswith(".mat"):
                filepath = os.path.join(directory, filename)
                mat_data = sio.loadmat(filepath, struct_as_record=False, squeeze_me=True)

                all_window_features = mat_data.get("all_window_features")
                if all_window_features is None:
                    continue

                # Extract the channel name from the file name
                channel_name = filename.split('_')[-2]  # Assumes "TP10" or similar is always at the second last position

                for participant_index, participant_data in enumerate(all_window_features):
                    if participant_data is None:
                        continue

                    for window_index, window in enumerate(participant_data):
                        # Extract labels first for filtering
                        before_label = getattr(window, "before_label", None)
                        after_label = getattr(window, "after_label", None)
                        
                        # Filter condition: both labels must exist and be <=3
                        if (before_label is None or 
                            after_label is None or 
                            not (0 <= before_label <= 3) or 
                            not (0 <= after_label <= 3)):
                            continue

                        # Extract features and raw signals
                        fractal_features = []
                        raw_signals = []

                        if hasattr(window, "raw_window_signal") and window.raw_window_signal is not None:
                            raw_signals = window.raw_window_signal.flatten()
                        if hasattr(window, "Dq") and window.Dq is not None:
                            fractal_features = window.Dq.flatten()
                        
                        # Store data temporarily, grouped by participant and window
                        grouped_data[(participant_index, window_index)]["channels"].append(channel_name)
                        grouped_data[(participant_index, window_index)]["fractal"].append(torch.tensor(fractal_features, dtype=torch.float32))
                        grouped_data[(participant_index, window_index)]["signal"].append(torch.tensor(raw_signals, dtype=torch.float32))
                        grouped_data[(participant_index, window_index)]["label"] = (int(before_label), int(after_label))
                        
                        # Set fractal feature length from the first valid feature
                        if self.fractal_feature_length is None and len(fractal_features) > 0:
                            self.fractal_feature_length = len(fractal_features)

        # Step 2: Combine data for all channels
        for (participant_index, window_index), data in grouped_data.items():
            channels = data["channels"]
            if len(channels) != 4:  # Ensure all 4 channels are present
                continue
            
            # Stack fractal features and signals along a new dimension (channel dimension)
            fractal_features = torch.stack(data["fractal"], dim=0)  # Shape: (4, fractal_feature_length)
            raw_signals = torch.stack(data["signal"], dim=0)        # Shape: (4, signal_length)

            self.features.append(fractal_features)
            self.signals.append(raw_signals)
            self.labels.append(data["label"])

    def _validate_labels(self):
        """Ensure all labels are valid integers 0-3"""
        valid_before = all(0 <= lbl[0] <= 3 for lbl in self.labels)
        valid_after = all(0 <= lbl[1] <= 3 for lbl in self.labels)
        if not (valid_before and valid_after):
            invalid = [
                (i, lbl) for i, lbl in enumerate(self.labels)
                if not (0 <= lbl[0] <= 3 and 0 <= lbl[1] <= 3)
            ]
            raise ValueError(f"Invalid labels found at indices: {invalid[:10]}")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = {
            'fractal': self.features[idx],  # Shape: (4, fractal_feature_length)
            'signal': self.signals[idx],    # Shape: (4, signal_length)
        }
        y = {
            "before_label": torch.tensor(self.labels[idx][0], dtype=torch.long),
            "after_label": torch.tensor(self.labels[idx][1], dtype=torch.long),
        }
        return x, y


In [5]:
dataset = MatFileDataset("/Users/athenasaghi/VSProjects/CognitiveFatigueDetection/Prediction/")
print(f"Loaded {len(dataset)} valid samples")

Loaded 28908 valid samples


In [24]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Dataset split
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Helper function to preprocess batches
def preprocess_data(loader):
    X, y = [], []
    for batch in loader:
        x, labels = batch
        label = labels['before_label']

        # Handle inf/nan values in 'fractal'
        if np.isinf(x['fractal']).any():
            max_finite = torch.max(x['fractal'][~torch.isinf(x['fractal'])])
            x['fractal'][torch.isinf(x['fractal'])] = max_finite
        if np.isnan(x['fractal']).any():
            mean_finite = torch.mean(x['fractal'][~torch.isnan(x['fractal'])])
            x['fractal'][torch.isnan(x['fractal'])] = mean_finite

        inputs = x['fractal']
        X.append(inputs.numpy().reshape(inputs.shape[0], -1))
        y.append(label.numpy())

    return np.vstack(X), np.hstack(y)

# Preprocess data for non-PyTorch models
X_train, y_train = preprocess_data(train_loader)
X_val, y_val = preprocess_data(val_loader)
X_test, y_test = preprocess_data(test_loader)

# Multichannel LSTM model definition
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        _, (h_n, _) = self.lstm(x)  # h_n: hidden state of the last LSTM cell
        out = self.fc(h_n[-1])     # Fully connected layer for classification
        return out

input_size = X_train.shape[1]
hidden_size = 64
num_classes = len(np.unique(y_train))
learning_rate = 0.001
num_epochs = 10

lstm_model = LSTMClassifier(41, hidden_size, num_classes)
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(lstm_model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    lstm_model.train()
    total_loss = 0
    total_acc = 0

    for batch in train_loader:
        x, labels = batch
        label = labels['before_label']

        # Handle inf/nan values in 'fractal'
        if np.isinf(x['fractal']).any():
            max_finite = torch.max(x['fractal'][~torch.isinf(x['fractal'])])
            x['fractal'][torch.isinf(x['fractal'])] = max_finite
        if np.isnan(x['fractal']).any():
            mean_finite = torch.mean(x['fractal'][~torch.isnan(x['fractal'])])
            x['fractal'][torch.isnan(x['fractal'])] = mean_finite

        # Reshape inputs to (batch_size, seq_length, num_channels)
        inputs = x['fractal'].float()  # Ensure inputs are float
        labels = label.long()  # Ensure labels are of type long

        optimizer.zero_grad()
        outputs = lstm_model(inputs)  # Forward pass

        # Calculate loss
        loss = criterion(outputs, labels)
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        total_acc += (preds == labels).float().mean().item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}, Acc: {total_acc / len(train_loader):.4f}")

# Evaluate LSTM
lstm_model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for batch in test_loader:
        x, labels = batch
        label = labels['before_label']

        # Handle inf/nan values in 'fractal'
        if np.isinf(x['fractal']).any():
            max_finite = torch.max(x['fractal'][~torch.isinf(x['fractal'])])
            x['fractal'][torch.isinf(x['fractal'])] = max_finite
        if np.isnan(x['fractal']).any():
            mean_finite = torch.mean(x['fractal'][~torch.isnan(x['fractal'])])
            x['fractal'][torch.isnan(x['fractal'])] = mean_finite

        # Reshape inputs to (batch_size, seq_length, num_channels)
        inputs = x['fractal'].float()
        labels = label.long()

        # Forward pass
        outputs = lstm_model(inputs)
        preds = outputs.argmax(dim=1)

        # Append to true and predicted lists
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# Calculate metrics
lstm_acc = accuracy_score(y_true, y_pred)
lstm_precision = precision_score(y_true, y_pred, average='weighted')
lstm_recall = recall_score(y_true, y_pred, average='weighted')
lstm_f1 = f1_score(y_true, y_pred, average='weighted')

print(f"LSTM Test Accuracy: {lstm_acc:.4f}")
print(f"LSTM Test Precision: {lstm_precision:.4f}")
print(f"LSTM Test Recall: {lstm_recall:.4f}")
print(f"LSTM Test F1 Score: {lstm_f1:.4f}")

# Train and evaluate XGBoost
xgb_model = XGBClassifier()
xgb_model.fit(X_train, y_train)
xgb_preds = xgb_model.predict(X_test)
xgb_acc = accuracy_score(y_test, xgb_preds)
xgb_precision = precision_score(y_test, xgb_preds, average='weighted')
xgb_recall = recall_score(y_test, xgb_preds, average='weighted')
xgb_f1 = f1_score(y_test, xgb_preds, average='weighted')

print(f"XGBoost Test Accuracy: {xgb_acc:.4f}")
print(f"XGBoost Test Precision: {xgb_precision:.4f}")
print(f"XGBoost Test Recall: {xgb_recall:.4f}")
print(f"XGBoost Test F1 Score: {xgb_f1:.4f}")

# Train and evaluate Random Forest
rf_model = RandomForestClassifier()
rf_model.fit(X_train, y_train)
rf_preds = rf_model.predict(X_test)
rf_acc = accuracy_score(y_test, rf_preds)
rf_precision = precision_score(y_test, rf_preds, average='weighted')
rf_recall = recall_score(y_test, rf_preds, average='weighted')
rf_f1 = f1_score(y_test, rf_preds, average='weighted')

print(f"Random Forest Test Accuracy: {rf_acc:.4f}")
print(f"Random Forest Test Precision: {rf_precision:.4f}")
print(f"Random Forest Test Recall: {rf_recall:.4f}")
print(f"Random Forest Test F1 Score: {rf_f1:.4f}")

# Train and evaluate SVM
svm_model = SVC()
svm_model.fit(X_train, y_train)
svm_preds = svm_model.predict(X_test)
svm_acc = accuracy_score(y_test, svm_preds)
svm_precision = precision_score(y_test, svm_preds, average='weighted')
svm_recall = recall_score(y_test, svm_preds, average='weighted')
svm_f1 = f1_score(y_test, svm_preds, average='weighted')

print(f"SVM Test Accuracy: {svm_acc:.4f}")
print(f"SVM Test Precision: {svm_precision:.4f}")
print(f"SVM Test Recall: {svm_recall:.4f}")
print(f"SVM Test F1 Score: {svm_f1:.4f}")

# Summary of results
print("\nModel Comparison:")
print(f"LSTM Accuracy: {lstm_acc:.4f}, Precision: {lstm_precision:.4f}, Recall: {lstm_recall:.4f}, F1 Score: {lstm_f1:.4f}")
print(f"XGBoost Accuracy: {xgb_acc:.4f}, Precision: {xgb_precision:.4f}, Recall: {xgb_recall:.4f}, F1 Score: {xgb_f1:.4f}")
print(f"Random Forest Accuracy: {rf_acc:.4f}, Precision: {rf_precision:.4f}, Recall: {rf_recall:.4f}, F1 Score: {rf_f1:.4f}")
print(f"SVM Accuracy: {svm_acc:.4f}, Precision: {svm_precision:.4f}, Recall: {svm_recall:.4f}, F1 Score: {svm_f1:.4f}")


Epoch [1/10], Loss: 1.2503, Acc: 0.5640
Epoch [2/10], Loss: 1.1136, Acc: 0.6339
Epoch [3/10], Loss: 1.0456, Acc: 0.6337
Epoch [4/10], Loss: 1.0126, Acc: 0.6339
Epoch [5/10], Loss: 0.9957, Acc: 0.6341
Epoch [6/10], Loss: 0.9870, Acc: 0.6338
Epoch [7/10], Loss: 0.9815, Acc: 0.6338
Epoch [8/10], Loss: 0.9778, Acc: 0.6338
Epoch [9/10], Loss: 0.9750, Acc: 0.6339
Epoch [10/10], Loss: 0.9730, Acc: 0.6340
LSTM Test Accuracy: 0.6279
LSTM Test Precision: 0.3942
LSTM Test Recall: 0.6279
LSTM Test F1 Score: 0.4843


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


XGBoost Test Accuracy: 0.6129
XGBoost Test Precision: 0.4693
XGBoost Test Recall: 0.6129
XGBoost Test F1 Score: 0.4927


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Random Forest Test Accuracy: 0.6262
Random Forest Test Precision: 0.4388
Random Forest Test Recall: 0.6262
Random Forest Test F1 Score: 0.4857
SVM Test Accuracy: 0.6279
SVM Test Precision: 0.3942
SVM Test Recall: 0.6279
SVM Test F1 Score: 0.4843

Model Comparison:
LSTM Accuracy: 0.6279, Precision: 0.3942, Recall: 0.6279, F1 Score: 0.4843
XGBoost Accuracy: 0.6129, Precision: 0.4693, Recall: 0.6129, F1 Score: 0.4927
Random Forest Accuracy: 0.6262, Precision: 0.4388, Recall: 0.6262, F1 Score: 0.4857
SVM Accuracy: 0.6279, Precision: 0.3942, Recall: 0.6279, F1 Score: 0.4843


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
