In [1]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import random
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay, f1_score

In [2]:
def over_connected(graph, upper, is_cov, revert) :
    """Tell whether a graph matrix holds too many strong entries to be kept.

    Parameters
    ----------
    graph : array-like exposing ``flatten()``
        Matrix of edge weights to inspect.
    upper : bool
        True when only the upper triangle of the matrix is populated; when
        False the pixel-count limits are doubled.
    is_cov : bool
        True for covariance matrices, which are never flagged.
    revert : bool
        True when the matrix was turned into a Laplacian, which is never
        flagged either (the thresholds below do not apply to it).

    Returns
    -------
    bool
        True when at least 90 entries (180 on a full matrix) exceed 0.6, or at
        least 90 entries (180 on a full matrix) exceed 0.4; False otherwise.
    """
    weights = graph.flatten()

    # Covariance matrices and Laplacians are exempt from this filter
    if is_cov or revert :
        return False

    # The limits are pixel counts for an upper-triangular matrix; a full
    # symmetric matrix stores every off-diagonal edge twice
    scale = 1 if upper else 2
    strong_limit = 90 * scale
    moderate_limit = 90 * scale

    too_many_strong = (weights > 0.6).sum() >= strong_limit
    return bool(too_many_strong or (weights > 0.4).sum() >= moderate_limit)

def load_graphs(input_dir, class_dict, is_cov, upper, revert, over_conn) :
    """Load every graph matrix stored under ``input_dir/<class name>/``.

    Each file is read with ``np.load``, scaled by its largest entry,
    optionally replaced by its Laplacian ``D - A``, optionally reduced to its
    upper triangle, and finally wrapped in a tensor viewed as ``(1, 20, 20)``.

    Parameters
    ----------
    input_dir : str
        Directory holding one sub-directory per key of ``class_dict``.
    class_dict : dict
        Maps a class sub-directory name to the label stored for its graphs.
    is_cov : bool
        When True the Laplacian conversion is skipped.
    upper : bool
        When True only the upper triangle (diagonal included) is kept.
    revert : bool
        When True and ``is_cov`` is False, the matrix is replaced by ``D - A``.
    over_conn : bool
        When True, graphs flagged by ``over_connected`` are dropped.

    Returns
    -------
    tuple
        ``(data, data_labels)``: an object array built from the kept tensors
        and an array with the matching labels, in the same order.
    """
    data, data_labels = [], [] # data contains the graphs as tensors and data_labels the associated seizure type labels

    for szr_type, szr_label in class_dict.items() :

        # os.walk also descends into nested folders, so every path must be
        # built from the folder that actually holds the file
        for root, dirs, files in os.walk(os.path.join(input_dir, szr_type)) :

            # Sorting makes the traversal (and so the sample order) reproducible
            dirs.sort()

            for npy_file in sorted(files) :
                A = np.load(os.path.join(root, npy_file))

                # Normalise A (already normalised depending on the input).
                # An all-zero matrix is left untouched: dividing by its
                # maximum would fill it with NaN
                peak = np.amax(A)
                if peak != 0 :
                    A = A / peak

                if not is_cov and revert :
                    L = np.diag(np.sum(A, axis=1)) - A
                else :
                    L = A

                # Only keep upper triangle as matrix is symmetric
                if upper :
                    L = np.triu(L, 0)

                # The filter is only consulted when it was requested
                if over_conn and over_connected(L, upper=upper, is_cov=is_cov, revert=revert) :
                    continue

                # Change to tensor and reshape for dataloader
                data.append(torch.tensor(L).view(1,20,20))
                data_labels.append(szr_label)

    # NOTE(review): the object array is kept because callers currently receive
    # it; NumPy's handling of array-like elements such as tensors has changed
    # between versions, so verify the resulting shape before relying on it
    return np.array(data, dtype=object), np.array(data_labels)

def train_test_data(input_dir, class_dict, is_cov, upper, revert, over_conn) :
    """Load the 'train' and 'dev' sub-directories of ``input_dir`` with ``load_graphs``.

    All other arguments are forwarded unchanged to ``load_graphs``.

    Returns
    -------
    tuple
        ``(train, test, train_labels, test_labels)`` where the test split is
        read from the 'dev' sub-directory.
    """
    shared_args = (class_dict, is_cov, upper, revert, over_conn)

    splits = {}
    for split_name in ('train', 'dev') :
        graphs, labels = load_graphs(os.path.join(input_dir, split_name), *shared_args)
        splits[split_name] = (graphs, labels)

    (train_graphs, train_targets), (dev_graphs, dev_targets) = splits['train'], splits['dev']
    return train_graphs, dev_graphs, train_targets, dev_targets


In [60]:
def load_data(input_dir, class_dict, is_cov, upper, revert, over_conn, excluded_dev_ids=(1027, 6546)) :
    """Load the graphs of both the 'train' and 'dev' splits into one dataset.

    Files are read from ``input_dir/<split>/<class name>/``. Each matrix is
    scaled by its largest entry, optionally replaced by its Laplacian
    ``D - A``, optionally reduced to its upper triangle, and wrapped in a
    tensor viewed as ``(1, 20, 20)``. The integer found in the fourth
    underscore-separated field of the file name (``p_id``, presumably the
    patient ID) is recorded for every kept file.

    Parameters
    ----------
    input_dir : str
        Directory holding the 'train' and 'dev' sub-directories.
    class_dict : dict
        Maps a class sub-directory name to the label stored for its graphs.
    is_cov : bool
        When True the Laplacian conversion is skipped.
    upper : bool
        When True only the upper triangle (diagonal included) is kept.
    revert : bool
        When True and ``is_cov`` is False, the matrix is replaced by ``D - A``.
    over_conn : bool
        When True, graphs flagged by ``over_connected`` are dropped.
    excluded_dev_ids : iterable of int, optional
        IDs whose 'dev' files are skipped. Defaults to the two IDs that were
        previously hard-coded; pass an empty iterable to keep every file.

    Returns
    -------
    tuple
        ``(data, data_labels, train_ids, test_ids)``: an object array built
        from the kept tensors (train files first, then dev files), the array
        of matching labels, and the per-file IDs of the train and dev splits.

    Raises
    ------
    IndexError, ValueError
        If the name of a kept file has no integer fourth field.
    """
    data, data_labels = [], [] # data contains the graphs as tensors and data_labels the associated seizure type labels
    train_ids, test_ids = [], []
    excluded_dev_ids = set(excluded_dev_ids or ())

    for set_ in ['train','dev'] :

        for szr_type, szr_label in class_dict.items() :

            # os.walk also descends into nested folders, so every path must be
            # built from the folder that actually holds the file
            for root, dirs, files in os.walk(os.path.join(input_dir, set_, szr_type)) :

                # Sorting makes the traversal (and so the sample order) reproducible
                dirs.sort()

                for npy_file in sorted(files) :
                    A = np.load(os.path.join(root, npy_file))

                    # Normalise A (already normalised depending on the input).
                    # An all-zero matrix is left untouched: dividing by its
                    # maximum would fill it with NaN
                    peak = np.amax(A)
                    if peak != 0 :
                        A = A / peak

                    if not is_cov and revert :
                        L = np.diag(np.sum(A, axis=1)) - A
                    else :
                        L = A

                    # Only keep upper triangle as matrix is symmetric
                    if upper :
                        L = np.triu(L, 0)

                    # The filter is only consulted when it was requested
                    if over_conn and over_connected(L, upper=upper, is_cov=is_cov, revert=revert) :
                        continue

                    # Parsed once, and only for files that passed the filter above
                    p_id = int(npy_file.split('_')[3])

                    if set_ == 'dev' and p_id in excluded_dev_ids :
                        continue

                    # Change to tensor and reshape for dataloader
                    data.append(torch.tensor(L).view(1,20,20))
                    data_labels.append(szr_label)

                    if set_ == 'train' :
                        train_ids.append(p_id)
                    else :
                        test_ids.append(p_id)

    print('Total : ', len(data))

    # NOTE(review): the object array is kept because callers currently receive
    # it; NumPy's handling of array-like elements such as tensors has changed
    # between versions, so verify the resulting shape before relying on it
    return np.array(data, dtype=object), np.array(data_labels), train_ids, test_ids

In [61]:
# Root folder of the graph matrices; load_data reads '<input_dir>/train' and
# '<input_dir>/dev', each holding one sub-folder per class name
input_dir = '../data/v1.5.2/graph_lapl_low_50'
# True for covariance inputs: skips the Laplacian conversion and the over-connection filter
is_cov = False
# Keep only the upper triangle (diagonal included) of each matrix
upper = True
# When True (and is_cov is False), replace each matrix A by its Laplacian D - A
revert = False
# When True, drop the graphs flagged by over_connected
over_conn = False

# Class sub-folder names; the position in this list becomes the integer label
classes = ['FNSZ','GNSZ']

class_dict = {}
for i, szr_type in enumerate(classes) :
    class_dict[szr_type] = i

# Load all graphs of both splits, together with the per-file IDs of each split :
data, data_labels, train_ids, test_ids = load_data(input_dir, class_dict, is_cov, upper, revert, over_conn)

Total :  2344


  return np.array(data, dtype=object), np.array(data_labels), train_ids, test_ids


In [62]:
# Number of 'dev' files kept by load_data (test_ids holds one entry per kept file, duplicates included)
print(len(test_ids))

399


In [56]:
# Remove duplicates (currently disabled, so both lists keep one entry per loaded file;
# uncommenting the two lines below keeps each ID once, in first-seen order)
#train_ids = list(dict.fromkeys(train_ids))
#test_ids = list(dict.fromkeys(test_ids))

# Sort the lists in place, in ascending order
train_ids.sort()
test_ids.sort()

In [51]:
# Sizes of the ID lists: train first, then dev/test.
# NOTE(review): the output recorded for this cell comes from an earlier execution (In [51])
# and its test count does not match the 399 printed by the earlier cell; re-run the
# notebook top to bottom to refresh it
print(len(train_ids))
print(len(test_ids))

1945
75


In [None]:
# Dump both ID lists in full for manual inspection (one entry per loaded file, so the output is long)
print('Train :\n',train_ids,'\nTest :\n',test_ids)

In [None]:
# Report every train entry whose ID also appears in the test split
# (an ID is printed once per occurrence in train_ids)
test_id_set = set(test_ids)  # set membership avoids rescanning the test list for every train entry
for patient_id in train_ids :  # loop variable no longer shadows the built-in id()
    if patient_id in test_id_set :
        print('In both : ',patient_id)