In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, ResGatedGraphConv, TAGConv, ARMAConv, MFConv, global_mean_pool, BatchNorm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score

In [None]:
# Load the input array from disk. Later cells index it with three subscripts,
# X[sample, channel, :], and feed the last axis to an LSTM as the sequence axis.
X = np.load(".../data.npy")

In [None]:
# Load the class labels from disk. They are indexed with the same sample
# indices as X downstream, so one label per sample is assumed (TODO confirm).
labels = np.load(".../labels.npy")

In [None]:
# Handle for the CUDA device. None of the cells shown here move a tensor or a
# model onto it, so everything below still runs on the CPU.
device = torch.device("cuda")

# Turn the loaded arrays into tensors: labels keep their stored dtype, the
# signals are converted to float32.
labels = torch.tensor(labels)
X = torch.tensor(X, dtype=torch.float32)

In [None]:
# LSTM network for feature extraction
class Featurizer(nn.Module):
    """Single-layer LSTM over a one-feature sequence with a 3-way linear head.

    Args:
        embedding_dim: hidden size of the LSTM, i.e. the size of the embedding
            returned for every input sequence.

    ``forward(x)`` returns a tuple ``(h, logits)``:

    * ``h`` is the final hidden state exactly as ``nn.LSTM`` returns it, so it
      keeps the leading ``num_layers`` axis (size 1 here).
    * ``logits`` is the linear projection of ``h`` with that leading axis
      squeezed away.

    The logits are intentionally left unnormalised. ``get_embeddings`` trains
    this model with ``nn.CrossEntropyLoss``, which applies log-softmax itself;
    returning softmax probabilities made the loss apply softmax twice and
    flattened the gradients. ``argmax`` over the logits gives the same class
    as over the probabilities; call ``F.softmax(logits, dim=-1)`` if
    probabilities are needed.
    """

    def __init__(self, embedding_dim):
        super(Featurizer, self).__init__()
        # One input feature per time step; batch_first means (batch, seq, 1).
        self.lstm = nn.LSTM(1, embedding_dim, batch_first=True)
        self.project = nn.Linear(embedding_dim, 3)
        # Only the projection head gets Xavier initialisation; the LSTM keeps
        # PyTorch's default initialisation, as before.
        nn.init.xavier_normal_(self.project.weight, gain=1)

    def forward(self, x):
        # Keep only the final hidden state; per-step outputs and the cell
        # state are not used.
        _, (h, _) = self.lstm(x)
        return h, self.project(h).squeeze(0)
    
# Basic architecture of our GCNN

class EEGraph(nn.Module):
    """Stack of graph convolutions followed by mean pooling and a 3-way head.

    Args:
        embedding_dim: size of the feature vector attached to every node.
        first_conv: output width of the first convolution; each following
            layer doubles the width of the one before it.
        n_layers: number of convolution + batch-norm stages.
        conv_layer: callable invoked as ``conv_layer(in_dim, out_dim)`` that
            returns a module called as ``conv(x, edge_index)``.

    ``forward(x, edge_index)`` returns unnormalised class logits of shape
    ``(batch, 3)``. The training cells score this model with
    ``nn.CrossEntropyLoss``, which applies log-softmax itself; returning
    softmax probabilities made the loss apply softmax twice. ``argmax`` over
    the logits selects the same class as over the probabilities; call
    ``F.softmax(logits, dim=-1)`` if probabilities are needed.
    """

    def __init__(self, embedding_dim, first_conv, n_layers, conv_layer):
        super(EEGraph, self).__init__()
        self.n_layers = n_layers
        convs, bns = [], []
        d_in, d_out = embedding_dim, first_conv
        for i in range(n_layers):
            convs.append(conv_layer(d_in, d_out))
            bns.append(BatchNorm(d_out, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True))
            # Widen for the next stage, but leave d_out at the width of the
            # last stage so the projection head below matches it.
            if i < n_layers - 1:
                d_in, d_out = d_out, 2 * d_out

        # ModuleList registers the layers so their parameters reach the optimizer.
        self.convs = torch.nn.ModuleList(convs)
        self.bns = torch.nn.ModuleList(bns)
        self.project = nn.Linear(d_out, 3)
        nn.init.xavier_normal_(self.project.weight, gain=1)

    def forward(self, x, edge_index):
        # The permute calls require a 3-D input; it is assumed to be laid out
        # as (batch, nodes, features) -- TODO confirm against the conv layers.
        for conv, bn in zip(self.convs, self.bns):
            # Batch norm wants the feature axis second, hence the permute
            # before it and the permute back after the activation.
            x = conv(x, edge_index).permute(0, 2, 1)
            x = bn(x)
            x = F.dropout(F.leaky_relu(x, negative_slope=0.01), p=0.5, training=self.training).permute(0, 2, 1)
        # Average over axis 1 to get one vector per sample. The former
        # squeeze(dim=-1) was a no-op except for a width-1 final layer, where
        # it removed the feature axis the projection needs.
        out = x.mean(dim=1)
        return self.project(out)

In [None]:
# Here we extract the embeddings features of the LSTM network for each electrode
def get_embeddings(X, labels, channel, embedding_dim, n_epochs=10, lr=0.1, batch_size=15):
    """Train a fresh Featurizer on one channel and return its embeddings.

    Args:
        X: 3-D tensor indexed as ``X[sample, channel, :]``; the last axis is
            fed to the LSTM as the sequence.
        labels: per-sample targets for ``nn.CrossEntropyLoss`` (presumably
            integer class indices in ``{0, 1, 2}``, matching the 3-way head
            -- confirm against the label file).
        channel: index of the channel to train on.
        embedding_dim: hidden size of the LSTM, i.e. the embedding size.
        n_epochs: number of passes over the data.
        lr: Adam learning rate.
        batch_size: samples per optimisation step. The last batch of an epoch
            is smaller when the sample count is not a multiple of it.

    Returns:
        numpy array of shape ``(n_samples, 1, embedding_dim)`` holding the
        final LSTM hidden state for every sample.

    Every fifth epoch a progress line is printed with the sample-weighted
    mean loss and the training accuracy of that epoch.
    """
    m = Featurizer(embedding_dim)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(m.parameters(), lr=lr)
    # Use the real sample count instead of the former hard-coded 15*3*15.
    n_samples = X.shape[0]
    for epoch in range(n_epochs):
        indices = torch.randperm(n_samples)
        loss_sum = 0.
        n_correct = 0
        # split() tolerates a short final batch, unlike view(-1, batch_size).
        for batch in indices.split(batch_size):
            optimizer.zero_grad()
            # (batch, 1, time) -> (batch, time, 1): one feature per time step.
            _, outputs = m(X[batch, channel:channel+1, :].permute(0, 2, 1))
            loss = criterion(outputs, labels[batch])
            loss.backward()
            optimizer.step()
            # Weight by batch length so a short batch does not skew the mean.
            loss_sum += loss.item() * len(batch)
            n_correct += (torch.argmax(outputs, dim=-1) == labels[batch]).sum().item()
        if epoch % 5 == 4:
            print("Dim", embedding_dim, "Channel:", channel, "Epoch:", epoch, "Loss:", loss_sum / n_samples, "Accuracy", n_correct / n_samples)
    # Inference only: no autograd graph is needed for the final pass.
    with torch.no_grad():
        hidden, _ = m(X[:, channel:channel+1, :].permute(0, 2, 1))
    # (1, n_samples, dim) -> (n_samples, 1, dim), so that per-channel results
    # can be concatenated along axis 1.
    return hidden.squeeze(0).unsqueeze(1).cpu().numpy()

In [None]:
# For each embedding size (16, 32, 64): train one featurizer per channel of X,
# join the per-channel embeddings along axis 1 and save them to their own file.
for D in (16, 32, 64):
    embeddings = [
        get_embeddings(X, labels, channel, embedding_dim=D, n_epochs=50)
        for channel in range(X.shape[1])
    ]
    np.save(f".../graph_{D}.npy", np.concatenate(embeddings, axis=1))
    # Drop the per-channel list before the next, larger embedding size.
    del embeddings

In [None]:
# Electrode names in the order of the dataset's channels. A name's position in
# this list is what the connectivity cell uses as that electrode's node index.
# str.split() with no argument splits on any whitespace and discards empty
# strings, so the list holds exactly the 62 names; splitting on '\n' left a
# trailing '' entry after the final newline.

channel_order = """
FP1
FPZ
FP2
AF3
AF4
F7
F5
F3
F1
FZ
F2
F4
F6
F8
FT7
FC5
FC3
FC1
FCZ
FC2
FC4
FC6
FT8
T7
C5
C3
C1
CZ
C2
C4
C6
T8
TP7
CP5
CP3
CP1
CPZ
CP2
CP4
CP6
TP8
P7
P5
P3
P1
PZ
P2
P4
P6
P8
PO7
PO5
PO3
POZ
PO4
PO6
PO8
CB1
O1
OZ
O2
CB2
""".split()
channel_order[:3]

In [None]:
# Edges of the electrode graph. Each entry is a [source, target] pair of
# electrode names. The 110 connections are written once below and then
# mirrored, so the final list holds every connection in both directions:
# the forward pairs first, then the reversed pairs in the same order.

edges = [
    ['O2', 'CB2'], ['O2', 'OZ'],
    ['O1', 'OZ'], ['O1', 'CB1'],
    ['PO8', 'P8'], ['PO8', 'CB2'],
    ['PO6', 'PO4'], ['PO6', 'PO8'], ['PO6', 'P6'], ['PO6', 'CB2'],
    ['PO4', 'P2'], ['PO4', 'O2'],
    ['POZ', 'PZ'], ['POZ', 'PO3'], ['POZ', 'PO4'], ['POZ', 'OZ'],
    ['PO3', 'P1'], ['PO3', 'PO5'], ['PO3', 'O1'],
    ['PO5', 'CB1'],
    ['PO7', 'P7'], ['PO7', 'PO5'], ['PO7', 'CB1'],
    ['P4', 'CP4'], ['P4', 'P2'], ['P4', 'P6'],
    ['P2', 'CP2'],
    ['PZ', 'P1'], ['PZ', 'P2'],
    ['P5', 'P7'], ['P5', 'P3'], ['P5', 'PO5'],
    ['P1', 'P3'],
    ['P6', 'CP6'], ['P6', 'P8'],
    ['CP4', 'CP6'], ['CP4', 'CP2'],
    ['CPZ', 'CP2'], ['CPZ', 'PZ'], ['CPZ', 'CZ'],
    ['TP7', 'T7'], ['TP7', 'P7'],
    ['CP5', 'CP3'], ['CP5', 'TP7'], ['CP5', 'P5'], ['CP5', 'C5'],
    ['CP1', 'CPZ'], ['CP1', 'CP3'], ['CP1', 'P1'], ['CP1', 'C1'],
    ['CP3', 'P3'], ['CP3', 'C3'],
    ['TP8', 'CP6'], ['TP8', 'P8'], ['TP8', 'T8'],
    ['C4', 'CP4'],
    ['C2', 'C4'], ['C2', 'CZ'], ['C2', 'CP2'], ['C2', 'FC2'],
    ['CZ', 'C1'],
    ['C5', 'T7'],
    ['C3', 'C1'], ['C3', 'C5'], ['C3', 'FC3'],
    ['C6', 'C4'], ['C6', 'T8'], ['C6', 'CP6'], ['C6', 'FC6'],
    ['FC4', 'FC2'], ['FC4', 'F4'], ['FC4', 'C4'],
    ['FCZ', 'FC2'], ['FCZ', 'FC1'], ['FCZ', 'CZ'],
    ['FT7', 'F7'], ['FT7', 'T7'],
    ['FC5', 'FC3'], ['FC5', 'FT7'], ['FC5', 'C5'], ['FC5', 'F5'],
    ['FC1', 'FC3'], ['FC1', 'C1'],
    ['FT8', 'T8'],
    ['FC6', 'FC4'], ['FC6', 'FT8'], ['FC6', 'F6'],
    ['F5', 'F3'], ['F5', 'F7'], ['F5', 'AF3'],
    ['F8', 'FT8'], ['F8', 'F6'],
    ['F6', 'AF4'],
    ['F4', 'F6'], ['F4', 'AF4'],
    ['F2', 'FC2'], ['F2', 'F4'], ['F2', 'AF4'],
    ['FZ', 'FCZ'], ['FZ', 'F2'], ['FZ', 'F1'],
    ['F1', 'FC1'], ['F1', 'F3'], ['F1', 'AF3'],
    ['F3', 'FC3'],
    ['AF4', 'FP2'],
    ['AF3', 'F3'], ['AF3', 'FP1'],
    ['FPZ', 'FP1'], ['FPZ', 'FP2'],
]
# The comprehension is fully evaluated before the list is extended, so only
# the 110 forward pairs above are mirrored.
edges += [[target, source] for source, target in edges]

In [None]:
# Handle for the CUDA device; the visible cells never move data onto it.
device = torch.device("cuda")
# Translate every [source, target] name pair into node indices (positions in
# channel_order), then transpose to one row of sources and one row of targets.
connectivity = torch.tensor(
    [[channel_order.index(source), channel_order.index(target)] for source, target in edges]
).t().contiguous()

In [None]:
# Grid search over node-feature size, graph-convolution type, depth and width.
# Every configuration is trained from scratch on the training split and scored
# by macro-F1 on the validation split. Whenever a configuration beats the best
# validation score so far, its test-split classification report is printed.
best_f1_score = -1
best_trial_name = None
n_epochs = 500
lr = 1e-3
weight_decay = 1e-5
batch_size = 54
criterion = nn.CrossEntropyLoss()
for node_dim in [16, 32, 64]:
    node_features = np.load(f".../graph_{node_dim}.npy")
    # First hold out 20% as the test split, then 20% of the remainder as the
    # validation split; both splits are stratified and use a fixed seed.
    A, Xte, yA, yte = train_test_split(node_features, labels, test_size=0.2, shuffle=True, stratify=labels, random_state=0)
    Xtr, Xtr_valid, ytr, ytr_valid = train_test_split(A, yA, test_size=0.2, shuffle=True, stratify=yA, random_state=0)
    Xtr = torch.tensor(Xtr).float()
    Xtr_valid = torch.tensor(Xtr_valid).float()
    Xte = torch.tensor(Xte).float()
    # as_tensor accepts both arrays and tensors, so the labels convert without
    # a copy-construct warning whichever type the split hands back.
    ytr = torch.as_tensor(ytr)
    ytr_valid = torch.as_tensor(ytr_valid)
    yte = torch.as_tensor(yte)
    for conv_fn in [TAGConv, ARMAConv, MFConv, GCNConv, SAGEConv, ResGatedGraphConv]:
        for n_layers in range(1, 4):
            for conv_dim in [32, 64, 128, 256]:
                trial_name = f"node_dim_{node_dim}-conv_fn_{conv_fn.__name__}-conv_layers_{n_layers}-conv_dim_{conv_dim}"
                print(f"@: {trial_name}")
                model = EEGraph(embedding_dim=Xtr.shape[-1], 
                        first_conv=conv_dim,
                        n_layers=n_layers,
                        conv_layer=conv_fn)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
                for epoch in range(n_epochs):
                    model.train()
                    indices = torch.randperm(len(Xtr))
                    # Use the configured batch_size; split() allows a short
                    # final batch where view(-1, 54) raised on any split size
                    # that is not a multiple of 54.
                    for batch in indices.split(batch_size):
                        optimizer.zero_grad()
                        batch_input = Xtr[batch]
                        outputs = model(batch_input, connectivity)
                        loss = criterion(outputs, ytr[batch])
                        loss.backward()
                        optimizer.step()
                # Evaluation: switch off dropout and use the batch-norm running statistics.
                model.eval()
                with torch.no_grad():
                    outputs = model(Xtr_valid, connectivity)
                    output_classes = torch.argmax(outputs, dim=-1).cpu().numpy()
                    f1 = f1_score(ytr_valid, output_classes, average="macro")
                    if f1 > best_f1_score:
                        best_trial_name = trial_name
                        best_f1_score = f1
                        print("-"*100)
                        print(f"Best model so far: {best_trial_name}")
                        print(f"Best F1 Score: %{100*best_f1_score:.2f}")
                        # The test split is only reported here; selection uses
                        # the validation score above.
                        test_outputs = model(Xte, connectivity)
                        test_output_classes = torch.argmax(test_outputs, dim=-1).cpu().numpy()
                        print(classification_report(yte, test_output_classes, target_names=["Negative", "Neutral", "Positive"]))
                        print("-"*100)
                        print()

In [None]:
# Retrain the configuration chosen by the grid search on the full training
# split (training + validation) and report test-split metrics whenever the
# test macro-F1 improves.
# NOTE(review): the reported epoch is picked by its score on the test split
# itself, so the printed "best" figures are optimistic.
node_dim = ... # The best parameter that you get from the previous cell
conv_fn = ... # The best parameter that you get from the previous cell
n_layers = ... # The best parameter that you get from the previous cell
conv_dim = ... # The best parameter that you get from the previous cell
lr = 1e-3
weight_decay = 1e-5
n_epochs = 500
# Defined here as well so this cell does not depend on the grid-search cell.
batch_size = 54
node_features = np.load(f".../graph_{node_dim}.npy")
Xtr, Xte, ytr, yte = train_test_split(node_features, labels, test_size=0.2, shuffle=True, stratify=labels, random_state=0)
model = EEGraph(embedding_dim=Xtr.shape[-1], 
                        first_conv=conv_dim,
                        n_layers=n_layers,
                        conv_layer=conv_fn)
Xtr = torch.tensor(Xtr).float()
Xte = torch.tensor(Xte).float()
# as_tensor accepts both arrays and tensors, so the labels convert without a
# copy-construct warning whichever type the split hands back.
ytr = torch.as_tensor(ytr)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
best_f1_score = -1
best_acc = -1
for epoch in range(n_epochs):
    model.train()
    indices = torch.randperm(len(Xtr))
    # split() allows a short final batch where view(-1, 54) raised on any
    # training-set size that is not a multiple of 54.
    for batch in indices.split(batch_size):
        optimizer.zero_grad()
        batch_input = Xtr[batch]
        outputs = model(batch_input, connectivity)
        loss = criterion(outputs, ytr[batch])
        loss.backward()
        optimizer.step()
    # Evaluation: switch off dropout and use the batch-norm running statistics.
    model.eval()
    with torch.no_grad():
        outputs = model(Xte, connectivity)
        output_classes = torch.argmax(outputs, dim=-1).cpu().numpy()
        f1 = f1_score(yte, output_classes, average="macro") 
        if f1 > best_f1_score:
            acc = accuracy_score(yte, output_classes)
            best_f1_score = f1
            # Accuracy at the best-F1 epoch; it was computed before but never
            # stored or shown.
            best_acc = acc
            print("-"*100)
            print(f"Best F1 Score: %{100*best_f1_score:.2f}")
            print(f"Accuracy: %{100*best_acc:.2f}")
            print(classification_report(yte, output_classes, target_names=["Negative", "Neutral", "Positive"]))
            print("-"*100)
            print()