In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


STESA-Net Architecture

In [2]:
import torch
import torch.nn.functional as F
import math

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x

class SelfAttention(torch.nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = torch.nn.Linear(input_dim, input_dim)
        self.key = torch.nn.Linear(input_dim, input_dim)
        self.value = torch.nn.Linear(input_dim, input_dim)
        self.scale = torch.sqrt(torch.tensor(input_dim, dtype=torch.float32))

    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        attention_weights = F.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.scale, dim=-1)
        return torch.bmm(attention_weights, V)

class STESA_Net(torch.nn.Module):
    def __init__(self, classes=2, sampleChannel=30, sampleLength=384, num_filters=16, d=2, kernel_size=64, lstm_hidden=64, lstm_layers=5):
        super(STESA_Net, self).__init__()
        self.bottleneck = torch.nn.Conv2d(1, num_filters, (sampleChannel, 1))
        self.temporal = torch.nn.Conv2d(num_filters, d * num_filters, (1, kernel_size), groups=num_filters)
        self.activ = torch.nn.ELU(alpha=0.0001)
        self.batchnorm = torch.nn.BatchNorm2d(d * num_filters, track_running_stats=False)
        self.pool = torch.nn.AvgPool1d(kernel_size=4, stride=4)
        self.pos_encoding = PositionalEncoding(d_model=32, max_len=500)
        self.attention = SelfAttention(32)
        self.lstm = torch.nn.LSTM(
            input_size=32,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = torch.nn.Dropout(0.3)
        self.fc = torch.nn.Linear(lstm_hidden * 2, classes)
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, inputdata):
        block1 = self.bottleneck(inputdata)
        block1 = self.temporal(block1)
        block1 = self.activ(block1)
        block1 = self.batchnorm(block1)
        block1 = block1.squeeze(2).permute(0, 2, 1)
        block1 = self.pool(block1.permute(0, 2, 1))
        block1 = block1.permute(0, 2, 1)
        block1 = self.pos_encoding(block1)
        block2 = self.attention(block1)
        block3, _ = self.lstm(block2)
        block3 = self.dropout(block3)
        block3 = block3[:, -1, :]
        output = self.fc(block3)
        output = self.softmax(output)
        return output


Leave-One-Out-validation (LOOV) Main Code

In [None]:
import torch
import scipy.io as sio
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torch.optim as optim

torch.cuda.empty_cache()
torch.manual_seed(0)


def run():
    filename = r'/content/drive/MyDrive/Balanced Nature Dataset.mat'
    tmp = sio.loadmat(filename)
    xdata = np.array(tmp['EEGsample'])
    label = np.array(tmp['substate'])
    subIdx = np.array(tmp['subindex'])
    label = label.astype(int)
    subIdx = subIdx.astype(int)

    samplenum = label.shape[0]
    channelnum = 30
    subjnum = 11
    samplelength = 3
    sf = 128
    lr = 1e-3
    batch_size = 50
    n_epoch = 11

    ydata = np.zeros(samplenum, dtype=np.longlong)
    for i in range(samplenum):
        ydata[i] = label[i]

    results = np.zeros(subjnum)
    precision_results = np.zeros(subjnum)
    recall_results = np.zeros(subjnum)
    f1_results = np.zeros(subjnum)

    for i in range(1, subjnum + 1):
        trainindx = np.where(subIdx != i)[0]
        xtrain = xdata[trainindx]
        x_train = xtrain.reshape(xtrain.shape[0], 1, channelnum, samplelength * sf)
        y_train = ydata[trainindx]

        testindx = np.where(subIdx == i)[0]
        xtest = xdata[testindx]
        x_test = xtest.reshape(xtest.shape[0], 1, channelnum, samplelength * sf)
        y_test = ydata[testindx]

        train = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
        train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)

        my_net = STESA_Net().double().cuda()
        optimizer = optim.Adam(my_net.parameters(), lr=lr)
        loss_class = torch.nn.NLLLoss().cuda()

        for p in my_net.parameters():
            p.requires_grad = True

        for epoch in range(n_epoch):
            for j, data in enumerate(train_loader, 0):
                inputs, labels = data
                input_data = inputs.cuda()
                class_label = labels.cuda()
                my_net.zero_grad()
                my_net.train()
                class_output = my_net(input_data)
                err = loss_class(class_output, class_label)
                err.backward()
                optimizer.step()

        my_net.eval()
        with torch.no_grad():
            x_test_tensor = torch.DoubleTensor(x_test).cuda()
            answer = my_net(x_test_tensor)
            probs = answer.cpu().numpy()
            preds = probs.argmax(axis=-1)

            acc = accuracy_score(y_test, preds)
            precision = precision_score(y_test, preds, average='macro', zero_division=0)
            recall = recall_score(y_test, preds, average='macro', zero_division=0)
            f1 = f1_score(y_test, preds, average='macro', zero_division=0)

            print(f'Subject {i}:')
            print(f'  Accuracy     = {acc:.4f}')
            print(f'  Precision    = {precision:.4f}')
            print(f'  Recall       = {recall:.4f} (Sensitivity / Selectivity)')
            print(f'  F1 Score     = {f1:.4f}')
            print('---------------------------------------------')

            results[i - 1] = acc
            precision_results[i - 1] = precision
            recall_results[i - 1] = recall
            f1_results[i - 1] = f1

    print('===== Overall Averages =====')
    print('Mean Accuracy     :', np.mean(results))
    print('Mean Precision    :', np.mean(precision_results))
    print('Mean Recall       :', np.mean(recall_results))
    print('Mean F1 Score     :', np.mean(f1_results))


if __name__ == '__main__':
    run()


  ydata[i] = label[i]
