In [1]:
import mne
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch import load, save

# Define the neural network model
class NeuralNetwork(nn.Module):
    def __init__(self, n_channels, n_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, (3, 3), 1, 1)
        self.conv2 = nn.Conv2d(16, 32, (3, 3), 1, 1)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=0)
        self.lstm = nn.LSTM(input_size=40, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Linear(64, n_classes)
    
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1, x.size(-1))
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.fc(x)
        return x

# Custom Dataset class
class EEGDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long) if labels is not None else None
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.labels is not None:
            label = self.labels[idx]
            return sample, label
        return sample

# Function to preprocess EEG data
def preprocess_eeg_data(file_list):
    all_segments = []
    all_labels = []

    for file in file_list:
        raw_data = mne.io.read_raw_edf(file, preload=True)
        raw_data.pick(['Fp1.', 'Fp2.', 'F1..', 'F2..', 'Fz..', 'Pz..'])
        raw_data.load_data()
        raw_data.filter(1.0, 40.0, fir_design='firwin')

        # Extract events and labels
        events, event_id = mne.events_from_annotations(raw_data)
        epochs = mne.Epochs(raw_data, events, event_id=event_id, tmin=-0.2, tmax=0.8, baseline=(None, 0), preload=True)
        epochs.drop_bad()  # Drop bad epochs

        # Convert epochs to numpy array
        data = epochs.get_data()  # shape: (n_epochs, n_channels, n_times)
        labels = epochs.events[:, -1]  # Extract labels from events

        all_segments.extend(data)
        all_labels.extend(labels - 1)
    
    all_segments = np.array(all_segments, dtype=np.float32)
    all_labels = np.array(all_labels, dtype=np.long)
    return all_segments, all_labels

# List of .edf files and event IDs
file_list = ['S001R03.edf', 'S001R04.edf', 'S001R05.edf', 'S001R06.edf', 'S001R07.edf', 'S001R08.edf']

# Preprocess data
segments, labels = preprocess_eeg_data(file_list)

# Create Dataset and DataLoader
dataset = EEGDataset(segments, labels)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize and train the model
model = NeuralNetwork(n_channels=64, n_classes=3)
loss_fn = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr=0.0001)

epochs = 10
for epoch in range(epochs):
    running_loss = 0
    for batch, labels in data_loader:
        batch = batch.unsqueeze(1)  # Add channel dimension if needed
        outputs = model(batch)
        loss = loss_fn(outputs, labels)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(data_loader):.4f}')


with open('model_state.pt', 'wb') as file :
    save(model.state_dict(), file)

Extracting EDF parameters from c:\Users\aryas\OneDrive\Desktop\Brain Computing\S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 529 samples (3.306 s)

Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Not setting metadata
30 matching events found
Setting baseline interval to [-0.2, 0.0] s
Applying baseline correction (mode: mean

In [2]:
with open('model_state.pt', 'rb') as file :
    model.load_state_dict(load(file))

raw = mne.io.read_raw_edf('S001R09.edf')
#raw.pick(['Fp1.', 'Fp2.', 'F1..', 'F2..', 'Fz..', 'Pz..'])
print(raw)
events, events_id = mne.events_from_annotations(raw)
raw.load_data()
raw.filter(1., 40., fir_design='firwin')
epochs = mne.Epochs(raw, events, event_id=events_id, tmin=-0.2, tmax=0.8, baseline=(None, 0), preload=True)
epochs.drop_bad()
print(epochs)
print(epochs.events)

Extracting EDF parameters from c:\Users\aryas\OneDrive\Desktop\Brain Computing\S001R09.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
<RawEDF | S001R09.edf, 64 x 20000 (125.0 s), ~54 kB, data not loaded>
Used Annotations descriptions: [np.str_('T0'), np.str_('T1'), np.str_('T2')]
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 529 samples (3.306 s)

Not setting metadata
30 matching events found
Setting basel

  model.load_state_dict(load(file))
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


In [3]:
eeg_data = epochs.get_data()
eeg_events = epochs.events[:, -1] - 1

dataset = EEGDataset(eeg_data, eeg_events)
test_loader = DataLoader(dataset, batch_size=32, shuffle=True)

print(data_loader)

<torch.utils.data.dataloader.DataLoader object at 0x000001CD541BD010>


In [4]:
with torch.no_grad() :
    model.eval()
    for batch, label in test_loader :
        batch = batch.unsqueeze(1)
        print(batch)
        outputs = model(batch)
        print(outputs)
        _, outputs = torch.max(outputs, axis=1)
        print(len(label[label != outputs]))
        print(len(label[label == outputs]))

tensor([[[[-2.5573e-05, -1.4067e-05,  5.0598e-06,  ...,  6.7236e-05,
            5.1097e-05,  3.2812e-05],
          [-5.2299e-06,  5.6054e-06,  1.5336e-05,  ...,  7.8555e-05,
            5.0577e-05,  2.1367e-05],
          [ 1.3329e-05,  1.1819e-05,  1.4857e-05,  ...,  7.2106e-05,
            4.1827e-05,  3.8314e-06],
          ...,
          [ 5.8294e-05,  5.2768e-05,  4.3529e-05,  ...,  5.0372e-05,
            3.7464e-05,  3.5254e-05],
          [ 7.5289e-05,  5.9287e-05,  4.5766e-05,  ...,  4.3345e-05,
            3.4538e-05,  4.9353e-05],
          [ 4.1209e-05,  3.8801e-05,  3.3858e-05,  ...,  4.8743e-05,
            3.8644e-05,  4.1415e-05]]],


        [[[ 2.7591e-05,  2.5452e-05,  2.7267e-05,  ..., -2.7679e-05,
           -2.4886e-05, -3.3617e-05],
          [ 2.4454e-05,  2.9940e-05,  3.9300e-05,  ..., -4.9242e-05,
           -4.3626e-05, -4.4428e-05],
          [ 2.7507e-05,  3.4361e-05,  4.5837e-05,  ..., -5.6553e-05,
           -4.7847e-05, -4.6401e-05],
          ...,
   