### For working with any machine-learning model in Python, it is easier to work with NumPy arrays. We can save each subject's NumPy arrays inside a structured pickle object.

In [None]:
!pip install --verbose --no-cache-dir torch-scatter
!pip install --verbose --no-cache-dir torch-sparse
!pip install --verbose --no-cache-dir torch-cluster
!pip install --verbose --no-cache-dir torch-spline-conv (optional)
!pip install torch-geometric
!pip install eeg-positions

In [None]:
from scipy.io import loadmat
import numpy as np
import pickle
import os
import torch
from torch import optim, linalg
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch_geometric.nn import SGConv, global_add_pool
from torch_scatter import scatter_add
#from torch.utils.data import Dataset,DataLoader
from torch_geometric.data import Data, DataLoader
from tqdm import tqdm
from pprint import pprint
from eeg_positions import (
    get_alias_mapping,
    get_available_elec_names,
    get_elec_coords,
    plot_coords,
)
import time
import copy
from sklearn.model_selection import train_test_split

The sampling frequency is 128 Hz. So if, on a particular day, a subject recorded 40 minutes of EEG data, the total number of rows in the array will be 128x60x40. As mentioned by @inancigdem, for each subject and on every day, the first 10 minutes of data correspond to the 'focussed' state, the next 10 minutes to 'unfocussed', and the remainder to 'drowsed'. So I have sliced each individual array using that information (i.e. up to row 128x10x60 for 'focussed', from 128x10x60 to 128x20x60 for 'unfocussed', and from 128x20x60 to the last row for 'drowsed').

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Sampling frequency of the recordings, in Hz.
fs = 128
# Number of subjects in the dataset.
n_subjects = 5
# Row indices of the 10-minute and 20-minute marks (minutes * 60 s * fs rows/s).
mkpt1, mkpt2 = (int(fs * minutes * 60) for minutes in (10, 20))

The cell below is not really important. I didn't want to write the relevant file names for each subject by hand, so I did the following. It is based on what the data contributor mentioned in a comment. There are 5 subjects, and they recorded data on 7 separate days (except for subject 5, who only did so for 6 days). I have treated each day's data as a trial. Of these 7, the first 2 were used to familiarise the subject with the process. That is why I have only considered the last 5 trials of each subject as training data.

In [None]:
def _trial_file_ids(subject):
    """Return the recording numbers kept for one subject.

    Each subject owns a run of 7 consecutive recording numbers; the first two
    of each run are skipped, hence the ``+ 3`` offset. Subject 5 keeps 4
    recordings, every other subject keeps 5.
    """
    first = int(7 * (subject - 1)) + 3
    kept = 4 if subject == 5 else 5
    return list(range(first, first + kept))


# subject number -> list of recording numbers used for that subject
subject_map = {s: _trial_file_ids(s) for s in range(1, n_subjects + 1)}
print(subject_map)

In [None]:
# Channel names in the column order of the extracted EEG array.
channels = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
# Subset of channels that is actually kept.
useful_channels = ['F7','F3','P7','O1','O2','P8','AF4']
# Column index of every kept channel, in the order of `useful_channels`;
# names that are not present in `channels` are skipped.
use_channel_inds = [channels.index(name) for name in useful_channels if name in channels]

In [None]:
# Directory containing the eeg_record<N>.mat files. The trailing '/' is
# required because file names are appended with plain string concatenation.
inp_dir = '../input/eeg-data-for-mental-attention-state-detection/EEG Data/' 

### Saving the pickle object for each subject

In [None]:
print(mkpt1)
print(mkpt2)
# NOTE(review): 214540 looks like a recording length (in rows) chosen so that
# every state segment has the same size -- confirm against the data.
mkpt3 = 214540
# Number of rows kept for each of the three states.
interval = mkpt3 - mkpt2
for s in range(1, n_subjects+1):
    # One dictionary per subject: metadata first, then one entry per trial.
    data = {'channels': useful_channels, 'fs': fs}
    for trial_no, t in enumerate(subject_map[s], start=1):
        recording = loadmat(inp_dir + f'eeg_record{t}.mat')
        # Columns 3..16 are taken as the 14 EEG channels (presumably in the
        # order of `channels` -- verify), then reduced to the kept channels.
        eeg = recording['o']['data'][0][0][:, 3:17][:, use_channel_inds]
        # Each state contributes `interval` rows starting at its marker.
        data[f'trial_{trial_no}'] = {
            'focussed': eeg[:interval],
            'unfocussed': eeg[mkpt1:mkpt1+interval],
            'drowsed': eeg[mkpt2:mkpt2+interval],
        }
    with open(f'subject_{s}.pkl', 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
        

### Loading the pickle objects

Suppose I just want the data related to 'focussed' state from trial_1 of subject_1

In [None]:
# Reload subject 1's dictionary from the pickle file written earlier.
with open('subject_1.pkl', 'rb') as f: 
    data = pickle.load(f)

In [None]:
# Display the loaded dictionary ('channels', 'fs' and one entry per trial).
data

In [None]:
def _load_subject(path):
    """Read and return the object pickled at `path`."""
    with open(path, 'rb') as f:
        return pickle.load(f)


# One dictionary per subject, as written by the saving cell.
data1 = _load_subject('subject_1.pkl')
data2 = _load_subject('subject_2.pkl')
data3 = _load_subject('subject_3.pkl')
data4 = _load_subject('subject_4.pkl')
data5 = _load_subject('subject_5.pkl')

In [None]:
# Class index for each mental state, assigned in the listed order.
state_num = {name: code for code, name in enumerate(('focussed', 'unfocussed', 'drowsed'))}

In [None]:
def maybe_num_nodes(index, num_nodes=None):
    """Return `num_nodes` if given, else infer it as the largest index + 1."""
    if num_nodes is not None:
        return num_nodes
    return index.max().item() + 1


def add_remaining_self_loops(edge_index,
                             edge_weight=None,
                             fill_value=1,
                             num_nodes=None):
    """Give every node a self-loop entry, i.e. build A' = A + I.

    Self-loops already present in `edge_index` are removed from their original
    position and one self-loop per node is appended at the end, in node order.

    Args:
        edge_index: tensor with two rows (source indices, target indices).
        edge_weight: optional 1-D tensor with one weight per edge.
        fill_value: weight given to self-loops that did not exist before.
        num_nodes: number of nodes; inferred from `edge_index` when None.

    Returns:
        (edge_index, edge_weight) with the self-loops appended. `edge_weight`
        stays None when no weights were passed in.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    src, dst = edge_index
    off_diagonal = src != dst
    on_diagonal = ~off_diagonal

    # Start from `fill_value` for every node; existing self-loop weights
    # overwrite their entries below.
    weight_dtype = edge_weight.dtype if edge_weight is not None else None
    self_loop_weight = torch.full(
        (num_nodes, ),
        fill_value,
        dtype=weight_dtype,
        device=edge_index.device)

    if edge_weight is not None:
        # There must be exactly one weight per edge.
        assert edge_weight.numel() == edge_index.size(1)
        existing = edge_weight[on_diagonal]
        if existing.numel() > 0:
            self_loop_weight[src[on_diagonal]] = existing
        edge_weight = torch.cat([edge_weight[off_diagonal], self_loop_weight], dim=0)

    # (i, i) index pairs for all nodes, shaped (2, num_nodes).
    node_ids = torch.arange(0, num_nodes, dtype=src.dtype, device=src.device)
    self_loop_index = torch.stack([node_ids, node_ids], dim=0)
    edge_index = torch.cat([edge_index[:, off_diagonal], self_loop_index], dim=1)

    return edge_index, edge_weight


class NewSGConv(SGConv):
    """SGConv variant whose normalisation tolerates negative edge weights.

    The degree used for the symmetric normalisation is computed from the
    absolute values of the edge weights, so signed weights can be used.
    """

    def __init__(self, num_features, num_classes, K=1, cached=False,
                 bias=True):
        """Forward all arguments to SGConv (input size, output size, K hops)."""
        super(NewSGConv, self).__init__(num_features, num_classes, K=K, cached=cached, bias=bias)

    # allow negative edge weights
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Return (edge_index, weights) after adding self-loops and normalising.

        Each weight w_ij is scaled to deg_i^-1/2 * w_ij * deg_j^-1/2, where
        deg is the per-node sum of absolute edge weights.
        """
        if edge_weight is None:
            # No weights given: use a weight of 1 for every edge in edge_index.
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        row, col = edge_index
        # Build the degree vector (diagonal of D): for every node, sum the
        # absolute weights of the edges whose source index is that node.
        deg = scatter_add(torch.abs(edge_weight), row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with zero degree would give inf; zero them out instead.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate `x` K times over the normalised graph, then apply self.lin."""
        # NOTE(review): relies on the base class providing `cached_result`
        # when cached=True -- confirm for the installed PyG version.
        if not self.cached or self.cached_result is None:
            edge_index, norm = NewSGConv.norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype)

            # Overall this computes Z = S^K X W: the loop applies S (the
            # normalised adjacency) K times, and W is applied by self.lin below.
            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        return self.lin(self.cached_result)

    def message(self, x_j, norm):
        """Scale every neighbour feature row by its normalised edge weight."""
        # x_j: one feature row per edge; norm: one scalar per edge.
        return norm.view(-1, 1) * x_j

class ReverseLayerF(Function):
    """Gradient reversal: identity in the forward pass, -alpha * grad backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by alpha.
        reversed_grad = torch.neg(grad_output).mul(ctx.alpha)
        # `alpha` receives no gradient.
        return reversed_grad, None


class SymSimGCNNet(torch.nn.Module):
    """Graph network with a learnable symmetric adjacency and optional RevGrad head."""

    def __init__(self, num_nodes, learn_edge_weight, edge_weight, num_features, num_hiddens, num_classes, K, dropout=0.7, domain_adaptation=""):
        """
            num_nodes: number of nodes in the graph
            learn_edge_weight: if True, the edge_weight is learnable
            edge_weight: initial edge matrix
            num_features: feature dim for each node/channel
            num_hiddens: a tuple of hidden dimensions (only the first is used)
            num_classes: number of output classes
            K: number of propagation steps in the graph convolution
            dropout: dropout rate before the final linear layer
            domain_adaptation: "RevGrad" enables the domain classifier
        """
        super().__init__()
        self.domain_adaptation = domain_adaptation
        self.num_nodes = num_nodes
        # Row/column indices of the lower triangle, diagonal included.
        self.xs, self.ys = torch.tril_indices(self.num_nodes, self.num_nodes, offset=0)
        # Only the lower triangle (with diagonal) is stored as the parameter;
        # forward() mirrors it to obtain a symmetric matrix.
        lower_triangle = edge_weight.reshape(self.num_nodes, self.num_nodes)[self.xs, self.ys]
        self.edge_weight = nn.Parameter(lower_triangle, requires_grad=learn_edge_weight)
        self.dropout = dropout
        self.conv1 = NewSGConv(num_features=num_features, num_classes=num_hiddens[0], K=K)
        self.fc = nn.Linear(num_hiddens[0], num_classes)
        if self.domain_adaptation == "RevGrad":
            self.domain_classifier = nn.Linear(num_hiddens[0], 2)

    def forward(self, data, alpha=0):
        """Return (class logits per graph, domain logits per node or None)."""
        batch_size = len(data.y)
        node_features, edge_index = data.x, data.edge_index
        target_device = edge_index.device

        # Rebuild the full symmetric matrix: write the lower triangle, add the
        # transpose, and subtract the diagonal that was counted twice.
        adjacency = torch.zeros((self.num_nodes, self.num_nodes), device=target_device)
        adjacency[self.xs.to(target_device), self.ys.to(target_device)] = self.edge_weight
        adjacency = adjacency + adjacency.transpose(1,0) - torch.diag(adjacency.diagonal())
        # Flatten and repeat once per graph in the batch.
        flat_weight = adjacency.reshape(-1).repeat(batch_size)

        hidden = F.relu(self.conv1(node_features, edge_index, flat_weight))

        # Domain classification on the gradient-reversed node features.
        domain_output = None
        if self.domain_adaptation == "RevGrad":
            domain_output = self.domain_classifier(ReverseLayerF.apply(hidden, alpha))

        pooled = global_add_pool(hidden, data.batch, size=batch_size)
        pooled = F.dropout(pooled, p=self.dropout, training=self.training)
        return self.fc(pooled), domain_output

In [None]:
# Initial adjacency: one node per kept channel, weighted by inverse distance.
num_channels = len(useful_channels)

coords = get_elec_coords(
    elec_names = useful_channels,
    drop_landmarks = False,
    dim = "3d",
)

# electrode label -> (x, y, z) position
electrodes_data = {}
for _, electrode in coords.iterrows():
    electrodes_data[str(electrode.label)] = (electrode.x, electrode.y, electrode.z)

# Symmetric matrix of pairwise Euclidean distances; the diagonal stays 0.
# float32 matches the dtype of the tensor handed to the model.
distances = np.zeros((num_channels, num_channels), dtype=np.float32)
for i in range(num_channels):
    for j in range(i+1, num_channels):
        vec_from = np.array(electrodes_data[useful_channels[i]])
        vec_to = np.array(electrodes_data[useful_channels[j]])
        dist = np.linalg.norm(vec_from - vec_to)
        distances[i, j] = dist
        distances[j, i] = dist

# weight = 0.065 / distance for distinct electrodes. The division is only
# evaluated where the distance is positive, so the zero diagonal is never
# divided and stays 0.
inverse_distance = np.zeros_like(distances)
np.divide(0.065, distances, out=inverse_distance, where=distances > 0)

edge_weight = torch.from_numpy(inverse_distance).float()
print(type(edge_weight))

In [None]:
def get_dataloader(batch_size=4, train_size=0.6, random_state=None):
    """Build the train and validation loaders from the five subject dictionaries.

    Each (subject, trial, state) recording becomes one graph sample: `x` holds
    one row per channel, `y` is the state's class index, and every sample
    shares the global `edge_index`.

    Args:
        batch_size: number of graphs per batch for both loaders.
        train_size: fraction of the samples that goes into the training split.
        random_state: optional seed for the split; None keeps it random.

    Returns:
        (train_iterator, val_iterator)
    """
    dataset = []
    for subject_data in (data1, data2, data3, data4, data5):
        # Only trials 1-4 are used; subject 5 has no fifth trial.
        for trial_no in range(1, 5):
            trial = subject_data['trial_' + str(trial_no)]
            for state, samples in trial.items():
                y = torch.tensor([state_num[state]]).long()
                # Stored as (rows, channels); the graph wants (channels, rows).
                x = torch.tensor(samples).float().permute(1, 0)
                dataset.append(Data(x=x, edge_index=edge_index, y=y))

    train, val = train_test_split(dataset, train_size=train_size, random_state=random_state)
    train_iterator = DataLoader(train, shuffle=True, batch_size=batch_size, num_workers=2)
    val_iterator = DataLoader(val, shuffle=False, batch_size=batch_size, num_workers=2)

    return train_iterator, val_iterator

In [None]:
# Fully connected 7-node graph (self-loops included), listed in row-major
# order: the source index changes slowly, the target index changes fast.
row = torch.tensor(np.repeat(np.arange(7), 7)).unsqueeze(0)
col = torch.tensor(np.tile(np.arange(7), 7)).unsqueeze(0)

# Stack into the (2, number_of_edges) layout.
edge_index = torch.cat((row, col), dim=0)

print(edge_index.size())

In [None]:
def labelize(label, num_classes=3):
    """One-hot encode a batch of class indices.

    Args:
        label: 1-D tensor of class indices, of any length.
        num_classes: width of the one-hot encoding.

    Returns:
        Float tensor of shape (len(label), num_classes) on the same device as
        `label`, with `requires_grad` set to True as before.
    """
    indices = label.detach().long().reshape(-1)
    count = indices.numel()
    tensor = torch.zeros(count, num_classes, device=indices.device)
    # Put a 1 in the column named by each row's class index.
    tensor[torch.arange(count, device=indices.device), indices] = 1

    # Kept for backward compatibility with callers of the earlier version.
    tensor.requires_grad = True
    return tensor

In [None]:
# Every model trained by fit_model() is appended here.
models = []


def _calc_domain_loss(model, domain_output, mode):
    """Binary cross-entropy of the domain branch against an all-0 or all-1 target.

    The domain logits go through a sigmoid, are multiplied with the weight
    matrix of `model.domain_classifier`, and softmax-normalised per row before
    the loss is taken.
    NOTE(review): multiplying the classifier output by the classifier's own
    weight is unusual for a domain loss -- confirm this is intended.

    Args:
        model: network that owns `domain_classifier`.
        domain_output: domain logits returned by the model.
        mode: 'zero' for an all-zeros target, 'one' for an all-ones target.

    Raises:
        ValueError: if `mode` is neither 'zero' nor 'one'.
    """
    activated = torch.sigmoid(domain_output)
    weight = model.domain_classifier.weight
    p = F.softmax(torch.mm(activated, weight), dim=1)

    if mode == 'zero':
        target = torch.zeros_like(p)
    elif mode == 'one':
        target = torch.ones_like(p)
    else:
        raise ValueError(f"mode must be 'zero' or 'one', got {mode!r}")
    return F.binary_cross_entropy(p, target).float()


def fit_model(edge_weight, edge_index):
    """Train one SymSimGCNNet for a single pass and append it to `models`.

    Training batches provide the classification loss and the 'zero' domain
    loss; validation batches are consumed in lock-step and provide the 'one'
    domain loss. The pass stops once the shorter of the two loaders runs out.

    Args:
        edge_weight: initial (num_nodes x num_nodes) adjacency for the model.
        edge_index: accepted for interface compatibility and not used here;
            get_dataloader() attaches the global edge index to every sample.
    """
    train_iterator, valid_iterator = get_dataloader()
    len_dataloader = min(len(train_iterator), len(valid_iterator))

    model = SymSimGCNNet(num_nodes=7, learn_edge_weight=True, edge_weight=edge_weight, num_features=60940, num_hiddens=(10,20), num_classes=3, K=2, dropout=0.7, domain_adaptation="RevGrad")
    # The batches are moved to `device`, so the parameters must live there too.
    model = model.to(device)

    for p in model.parameters():
        p.requires_grad = True

    opt = optim.Adam(model.parameters(), lr=0.001)
    # NOTE(review): the scheduler is stepped once per batch, so the learning
    # rate drops 10x every 2 batches -- confirm this is intended.
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    targets = iter(valid_iterator)

    for j, data in enumerate(train_iterator):
        if j == len_dataloader:
            break

        data = data.to(device)

        opt.zero_grad()
        # NOTE(review): alpha is left at its default of 0, so the reversed
        # gradient reaching the shared layers is zero -- confirm.
        y_pred, domain_output_tr = model(data)
        y_pred = y_pred / linalg.norm(y_pred)
        y_pred = F.softmax(y_pred, dim=1)

        # One-hot targets, placed on the same device as the predictions.
        y = labelize(data.y).to(device)
        loss_tr = F.kl_div(y_pred.log(), y, reduction='sum')

        loss_tr_domain = _calc_domain_loss(model, domain_output_tr, 'zero')

        # A validation batch stands in for the target domain.
        data_eval = next(targets).to(device)
        _, domain_output_te = model(data_eval)
        loss_te_domain = _calc_domain_loss(model, domain_output_te, 'one')

        total_loss = loss_tr + loss_tr_domain + loss_te_domain
        total_loss.backward()
        opt.step()

        print(f'Epoch: {j+1} / Lr: {scheduler.get_last_lr()[0]} / Loss: {total_loss}')
        scheduler.step()

    models.append(model)
        #epoch_mins, epoch_secs = (end_time - start_time)//60, round((end_time - start_time)%60)

In [None]:
# Build the loaders once as a sanity check. The returned pair is only
# displayed; fit_model() creates its own loaders.
get_dataloader()

In [None]:
# Train one model; fit_model() appends it to the global `models` list.
fit_model(edge_weight, edge_index)