In [1]:
import os
import random
import torch
from torch import nn
from torch.utils.data import TensorDataset, Dataset

import numpy as np 
import pickle
import matplotlib.pyplot as plt
import time
import copy

from sklearn.model_selection import train_test_split
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random_seed = 1
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/gnn')

# Load data and normalize

In [2]:
X = pickle.load(open('../dataset/train/cross_subject_data_5_subjects.pickle', 'rb'))
y = X['train_y']

X = X['train_x'].astype(np.float32)

label_map = {'imagine_both_feet': 0, 'imagine_both_fist': 1, 'imagine_left_fist': 2, 'imagine_right_fist': 3}
y = np.vectorize(label_map.__getitem__)(y)

In [3]:
# Normalize
mean, std = X.mean(), X.std()
X = (X - mean) / std

# Import Adj Matrix

In [4]:
from convert_to_graphs import n_graph, d_graph, s_graph, normalize_adj

seq_len = 100
n_channels = 64
batch_size = 32

A = n_graph()
A = np.array(A, dtype=np.float32)
A = normalize_adj(A)
A = A + np.eye(A.shape[0], dtype=np.float32)


A_big = np.zeros((n_channels*batch_size, n_channels*batch_size), dtype=np.float32)
print(A_big.shape)
for i in range(0, A_big.shape[0], n_channels):
    A_big[i:i+n_channels, i:i+n_channels] = A

# A_bigger = []
# for i in range(batch_size):
#     A_bigger.append(A_big)

# A_bigger = np.array(A_bigger, dtype=np.float32)

A = torch.Tensor(A_big).to(device)

print('Adjacency Matrix A:')
print(A.shape)



(2048, 2048)
Adjacency Matrix A:
torch.Size([2048, 2048])


In [5]:
# import mne 
# import pandas as pd
# import numpy as np 

# ten_twenty_montage = mne.channels.make_standard_montage("standard_1020")
# ch_names = pd.read_csv("../dataset/physionet.org_csv/S001/S001R01.csv")
# ch_names = ch_names.columns[2:]

# ch_pos_1020 = ten_twenty_montage.get_positions()["ch_pos"]

# ch_pos_1010 = {}
# for ch_name_orig in ch_names:
#     ch_name = ch_name_orig.upper().rstrip(".")
#     if "Z" in ch_name:
#         ch_name = ch_name.replace("Z", "z")
#     if "P" in ch_name and len(ch_name) > 2:
#         ch_name = ch_name.replace("P", "p")
#     if "Cp" in ch_name:
#         ch_name = ch_name.replace("Cp", "CP")
#     if "Tp" in ch_name:
#         ch_name = ch_name.replace("Tp", "TP")
#     if "pO" in ch_name:
#         ch_name = ch_name.replace("pO", "PO")
#     ch_pos_1010[ch_name_orig] = ch_pos_1020[ch_name]
# print(len(ch_pos_1010))

# ch_pos_1010_names = []
# ch_pos_1010_dist = []
# for name, value in ch_pos_1010.items():
#     ch_pos_1010_names.append(name)
#     ch_pos_1010_dist.append(value)
# ch_pos_1010_dist = np.array(ch_pos_1010_dist)

# A = d_graph(n_channels, ch_pos_1010_dist)
# A = np.array(A, dtype=np.float32)

# A = normalize_adj(A)
# A = A + np.eye(A.shape[0], dtype=np.float32)


# A_big = np.zeros((n_channels*batch_size, n_channels*batch_size), dtype=np.float32)
# print(A_big.shape)
# for i in range(0, A_big.shape[0], n_channels):
#     A_big[i:i+n_channels, i:i+n_channels] = A

# # A_bigger = []
# # for i in range(batch_size):
# #     A_bigger.append(A_big)

# # A_bigger = np.array(A_bigger, dtype=np.float32)

# A = torch.Tensor(A_big).to(device)

# print('Adjacency Matrix A:')
# print(A.shape)

# Convert data to [n_samples, n_channels] -> [n_samples, seq_len, n_channels]

In [6]:
# def reshape_data_gnn(X, y, seq_len):
#     print('X original shape:', X.shape)
#     print('y original shape:', y.shape)
#     print('Seq len:', seq_len)
#     len_tail = X.shape[0] % seq_len
#     if len_tail == 0:
#         X = X.reshape(-1, seq_len*n_channels, 1)
#         y = y.reshape(-1, seq_len)
#     else:
#         X = X[:-len_tail].reshape(-1, seq_len*n_channels, 1)
#         y = y[:-len_tail].reshape(-1, seq_len)
#     y = y[:, -1]
#     print('X conversion shape:', X.shape)
#     print('y conversion shape:', y.shape)
#     return X, y

def reshape_data_gnn_2(X, y, seq_len):
    print('X original shape:', X.shape)
    print('y original shape:', y.shape)
    print('Seq len:', seq_len)
    len_tail = X.shape[0] % seq_len
    if len_tail == 0:
        X = X.reshape(-1, seq_len, n_channels)
        X = np.moveaxis(X, 1, -1)
        y = y.reshape(-1, seq_len)
    else:
        X = X[:-len_tail].reshape(-1, seq_len, n_channels)
        X = np.moveaxis(X, 1, -1)
        y = y[:-len_tail].reshape(-1, seq_len)
    y = y[:, -1]
    print('X conversion shape:', X.shape)
    print('y conversion shape:', y.shape)
    return X, y

X, y = reshape_data_gnn_2(X, y, seq_len)

X original shape: (295008, 64)
y original shape: (295008,)
Seq len: 100
X conversion shape: (2950, 64, 100)
y conversion shape: (2950,)


In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=random_seed, stratify=y)
def print_class_dist(y):
    dist = {}

    labels = np.unique(y)
    for label in labels:
        dist[str(label)] = len(y[y == label]) / len(y)
    print(dist)
print_class_dist(y)
print_class_dist(y_train)
print_class_dist(y_test)

{'0': 0.24745762711864408, '1': 0.25254237288135595, '2': 0.25864406779661014, '3': 0.24135593220338983}
{'0': 0.24745762711864408, '1': 0.25254237288135595, '2': 0.2584745762711864, '3': 0.24152542372881355}
{'0': 0.24745762711864408, '1': 0.25254237288135595, '2': 0.2593220338983051, '3': 0.24067796610169492}


In [8]:
X_train, y_train = torch.tensor(X_train).to(device), torch.tensor(y_train).to(device)
X_test, y_test = torch.tensor(X_test).to(device), torch.tensor(y_test).to(device)

train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, drop_last=True)

dataset_sizes = {'train': len(train_dataset), 'val': len(test_dataset)}
dataloaders = {'train': train_loader, 'val': test_loader}
class_names = list(label_map.keys())
print(class_names)

['imagine_both_feet', 'imagine_both_fist', 'imagine_left_fist', 'imagine_right_fist']


In [9]:
def train_model(model, criterion, optimizer, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0.0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            writer.add_scalar(f'{phase} loss', epoch_loss, epoch)
            writer.add_scalar(f'{phase} accuracy', epoch_acc, epoch)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed%60:.0f}s')
    print(f'Best val acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# Test computation

In [10]:
# import time
# import torch
# from layers_original import GraphConvolution

# torch.manual_seed(0)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# x = torch.randn(16, 64, 32).to(device)
# x = x.view(-1, 32)
# # A = torch.randn(16*64, 16*64).to(device)
# # weight = torch.randn(32, 512).to(device)
# # A = A.to_sparse()
# print(A.is_sparse)
# gcn = GraphConvolution(32, 128).to(device)
# now = time.time()

# out = gcn(x, A)
# print(out.shape)
# out = out.view(16, 64, 128)
# # support = torch.matmul(x, weight)
# # output = torch.bmm(A, support)
# print(out.shape)
# print('elapsed time:', time.time() - now)

In [11]:
in_features = seq_len
hidden_size_1 = 256
hidden_size_2 = 512
hidden_size_3 = 256
hidden_size_4 = 4
num_classes = 4
num_epochs = 100

# from layers_batchwise import BatchwiseGraphConvolution
from layers_batchwise_2 import GraphConvolution
import torch.nn.functional as F

class AVWGCN(nn.Module):
    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(AVWGCN, self).__init__()
        self.cheb_k = cheb_k
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
    def forward(self, x, node_embeddings):
        #x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N]
        #output shape [B, N, C]
        node_num = node_embeddings.shape[0]
        supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)
        support_set = [torch.eye(node_num).to(supports.device), supports]
        #default cheb_k = 3
        for k in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        supports = torch.stack(support_set, dim=0)
        weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  #N, cheb_k, dim_in, dim_out
        bias = torch.matmul(node_embeddings, self.bias_pool)                       #N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, x)      #B, cheb_k, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, cheb_k, dim_in
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias     #b, N, dim_out
        return x_gconv

class GCN(nn.Module):
    def __init__(self, in_features, n_nodes, embed_dim_nodes, embed_dim_adj, num_classes):
        super(GCN, self).__init__()
        # self.gc1 = GraphConvolution(in_features, hidden_size_1, batch_size, n_channels)
        # self.gc2 = GraphConvolution(hidden_size_1, hidden_size_2, batch_size, n_channels)
        # self.gc3 = GraphConvolution(hidden_size_2, hidden_size_3, batch_size, n_channels)
        # self.gc4 = GraphConvolution(hidden_size_3, hidden_size_4, batch_size, n_channels)
        # self.flatten = nn.Flatten()
        # self.linear = nn.Linear(hidden_size_4*n_channels, num_classes)
        # self.node_embeddings = nn.Parameter(torch.randn(n_nodes, embed_dim), requires_grad=True)

        self.gc1 = AVWGCN(in_features, hidden_size_1, cheb_k=2, embed_dim=embed_dim_nodes)
        self.gc2 = AVWGCN(hidden_size_1, hidden_size_2, cheb_k=2, embed_dim=embed_dim_nodes)
        # self.gc3 = AVWGCN(hidden_size_2, hidden_size_3, cheb_k=2, embed_dim=embed_dim)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(hidden_size_2*n_channels, num_classes)

        self.node_embeddings = nn.Parameter(torch.randn(n_nodes, embed_dim_adj), requires_grad=True)
    def forward(self, x):
        # print(x.shape)
        out = F.relu(self.gc1(x, self.node_embeddings))
        # print(out.shape)
        out = F.relu(self.gc2(out, self.node_embeddings))
        # print(out.shape)
        # out = F.relu(self.gc3(out, self.node_embeddings))
        # print(out.shape)
        # out = F.relu(self.gc4(out, A))
        # print(out.shape)
        out = self.flatten(out)
        # print(out.shape)
        out = self.linear(out)
        # print(out.shape)
        return out


model = GCN(in_features=in_features, n_nodes=n_channels, embed_dim_nodes=1, embed_dim_adj=10, num_classes=num_classes).to(device)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    else:
        nn.init.uniform_(p)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())


# writer.add_graph(model, X_train[:batch_size])

In [12]:
model = train_model(model, criterion, optimizer, num_epochs=num_epochs)

Epoch 0/99
----------


RuntimeError: Function MmBackward returned an invalid gradient at index 0 - got [64, 1] but expected shape compatible with [64, 10]

# Computation using for loop

In [None]:
y_preds = []
y_true = []
for inputs, labels in test_loader:
    _, y_pred = torch.max(model(inputs), 1)
    y_preds.append(y_pred)
    y_true.append(labels)
y_preds = torch.cat(y_preds)
y_true = torch.cat(y_true)

In [None]:
print(y_preds.shape)
print(y_true.shape)

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
cr = classification_report(y_true.cpu().numpy(), y_preds.cpu().numpy())
print(cr)

cm = confusion_matrix(y_true.cpu().numpy(), y_preds.cpu().numpy())
print(cm)

y_pred_ohe = np.zeros((y_preds.size(0), num_classes))
for i, j in enumerate(y_pred):
    y_pred_ohe[i, j] = 1

y_true_ohe = np.zeros((y_true.size(0), num_classes))
for i, j in enumerate(y_true):
    y_true_ohe[i, j] = 1
auroc = roc_auc_score(y_true_ohe, y_pred_ohe, multi_class='ovo')
writer.add_scalar('AUROC OvO', auroc)
print('AUROC ovo:', auroc)
auroc = roc_auc_score(y_true_ohe, y_pred_ohe, multi_class='ovr')
writer.add_scalar('AUROC OvR', auroc)
print('AUROC ovr:', auroc)

In [None]:
import seaborn as sns
import pandas as pd
import io

figure = plt.figure(figsize=(7, 5))
cm_df = pd.DataFrame(cm, columns=class_names, index=class_names)
sns.heatmap(cm_df, annot=True, fmt='g')
plt.ylabel('True')
plt.xlabel('Pred')
plt.tight_layout()
plt.savefig('runs/gnn/cm.png')
plt.show()

In [None]:
print('Number of trainable parameters')
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
[p.numel() for p in model.parameters() if p.requires_grad]