# Preprocessing Guide
This notebooks serves as a guide on converting the initial graph representations created by [Díaz-Montiel & Lankarany (2023)](https://www.biorxiv.org/content/10.1101/2023.06.02.543277v1.abstract) from the OpenNeuro ds003029 dataset into a format that can be used by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/). This is fully automated using the `patch` function in `src/patch.py`. The processed data we are using can be found in the Graham cluster directory:

`/User/projects/def-milad777/gr_research/brain-greg/data/ds003029-processed/graph_representation_elements`

which contains folders for each patient and their runs.

In [45]:
# hyper-parameters
freq = 256
ws = int(1*freq)
step = int(0.0625*256)
dim = 23

In [46]:
base_data_path = '/Users/dentira/anomaly-detection/epilepsy-dataset/chb01/'

In [47]:
# data loading 
import pandas as pd
import numpy as np
import mne
import random
# use the labels.csv to load the data
path_labels = "/Users/dentira/anomaly-detection/epilepsy-dataset/labels_chb01.csv"
df = pd.read_csv(path_labels)
preitcal_list = []
ictal_list = []
for index, row in df.iterrows():
    file_name = row['File_names']
    label = row['Labels']
    start = row['Start_time']
    end = row['End_time']
    file_path = f"{base_data_path}/{file_name}"
    raw = mne.io.read_raw_edf(file_path)
    data, times = raw[:]
    if label == 0: 
        preitcal_list.append(data)
    else:
        preitcal_list.append(data[:, 0:start])
        if end < len(data)-1:
           preitcal_list.append(data[:, end+1:len(data)])
        ictal_list.append(data[:, start:end])
preictal_mat = np.hstack(preitcal_list)
ictal_mat = np.hstack(ictal_list)
print(preictal_mat.shape)
print(ictal_mat.shape)
# print(len(preitcal_list))
# print(len(ictal_list))


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_02.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)
  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_05.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)
  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_06.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_07.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_09.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_10.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_11.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_13.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_14.edf...
EDF file detected


  raw = mne.io.read_raw_edf(file_path)


Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_15.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_16.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_17.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_18.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_19.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_20.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_21.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)
  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_22.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_23.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_24.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_25.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_26.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_27.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_29.edf...
EDF file detected


  raw = mne.io.read_raw_edf(file_path)
  raw = mne.io.read_raw_edf(file_path)
  raw = mne.io.read_raw_edf(file_path)


Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_30.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_31.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_32.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_33.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_34.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_36.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_37.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_38.edf...
EDF file detected


  raw = mne.io.read_raw_edf(file_path)


Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_39.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)


Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_40.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_41.edf...


  raw = mne.io.read_raw_edf(file_path)


EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_42.edf...
EDF file detected


  raw = mne.io.read_raw_edf(file_path)


Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_43.edf...
EDF file detected


  raw = mne.io.read_raw_edf(file_path)


Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/dentira/anomaly-detection/epilepsy-dataset/chb01/chb01_46.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf(file_path)
  raw = mne.io.read_raw_edf(file_path)


(23, 31259247)
(23, 442)


In [48]:
import mne
print(mne.__version__)


1.9.0


In [49]:
from mne_connectivity import spectral_connectivity_epochs
from scipy.signal import welch,detrend
def get_band_energies(data, dim):
    freq_bands = {
        "Delta": (1, 4),
        "Theta": (4, 8),
        "Alpha": (8, 13),
        "Beta": (13, 30),
        "Gamma1": (30, 70),
        "Gamma2": (70, 100),
        "HighGamma1": (100, 250),
        "HighGamma2": (250, 500),
    }
    data = np.nan_to_num(data)  # Replace NaN/Inf with finite numbers
    if np.var(data) == 0:
        raise ValueError("Input data has zero variance. PSD cannot be computed.")

    data = detrend(data)  # Remove trends or biases
    nperseg = min(len(data), 1024)  # Ensure nperseg <= signal length
    freqs, psd = welch(data, fs=freq, nperseg=nperseg, axis=-1)
    n_bands = len(freq_bands)
    band_energy_matrix = np.zeros((dim, n_bands))

    for i,(band_name, (fmin, fmax)) in enumerate(freq_bands.items()):
        band_mask = (freqs >= fmin) & (freqs <= fmax)
        band_energy_matrix[:, i] = np.sum(psd[:, band_mask], axis=1)

    return band_energy_matrix
def generate_coh_array(segments):
    n_channels = segments.shape[0]
    n_times = segments.shape[1]
    segments = np.array(segments)

    # Create artificial epochs (e.g., splitting into 4 sub-epochs)
    n_sub_epochs = 4
    sub_epoch_length = n_times // n_sub_epochs
    segments_split = segments.reshape(n_sub_epochs, n_channels, sub_epoch_length)
    con = spectral_connectivity_epochs(
        segments_split,
        method='coh',
        mode='fourier',            # Use FFT for spectral estimation
        sfreq=freq,
        tmin=0,                  # Start coherence computation from 1 second
        tmax=None,                 # Use until the end of the data
        faverage=True,            # Do not average frequencies
        verbose=False           # Print detailed logs
    )
    coherence_matrix = con.get_data().reshape(dim, dim)
    coherence_matrix = np.nan_to_num(coherence_matrix)
    return coherence_matrix

def generate_plv_array(segments):
    n_channels = segments.shape[0]
    n_times = segments.shape[1]
    segments = np.array(segments)
    # Create artificial epochs (e.g., splitting into 4 sub-epochs)
    n_sub_epochs = 4
    sub_epoch_length = n_times // n_sub_epochs
    segments_split = segments.reshape(n_sub_epochs, n_channels, sub_epoch_length)
    con = spectral_connectivity_epochs(
        segments_split, method='plv', mode='multitaper', sfreq=freq,
        faverage=True, tmin=0, tmax=None, verbose=False
    )
    plv_matrix = con.get_data().reshape(dim, dim)
    plv_matrix = np.nan_to_num(plv_matrix)
    return plv_matrix
def generate_fcns(data, dim):
    fcns = []
    i = 0
    while i < len(data):
        fcn = []
        # all ones
        fcn.append(np.ones((dim, dim)))
        # correlation
        fcn.append(np.corrcoef(data[i]))
        # coherence
        fcn.append(generate_coh_array(data[i]))       
        # PLV
        fcn.append(generate_plv_array(data[i]))

        fcns.append(fcn)
        i+=1
    return fcns

def generate_node_features(data, dim):
    features = []
    # all ones
    features.append(np.ones((dim, dim)))
    # avg energy 
    features.append(np.mean(data**2, axis=1, keepdims=True))
    # band energies
    features.append(get_band_energies(data,dim)) 
    return features

In [50]:
def generate_segements(data):
     # print(data.shape)
    dim, num = data.shape
    i = 0
    segments = []
    while i < num:
        if num - i > ws:
            curr_win = data[:, i:i+256]
        segments.append(curr_win)
        i += step
    
    return segments

preictal_segments = generate_segements(preictal_mat)
ictal_segments = generate_segements(ictal_mat)
preictal_segments = random.sample(preictal_segments, 10000)

    

In [51]:
print(len(preictal_segments))
print(len(ictal_segments))

10000
28


In [52]:

def generate_graphs(data, fcns):
    i = 0
    graphs = []
    while i < len(data):
        NF = generate_node_features(data[i], dim)
        graphs.append([fcns[i], NF, np.expand_dims(fcns[i], axis=-1)])
        i += 1
    return graphs

In [53]:

ictal_fcns = generate_fcns(ictal_segments, dim)
precital_fcns = generate_fcns(preictal_segments, dim)



# preictal_graphs = generate_graphs(preictal_segments, precital_fcns)
# ictal_graphs = generate_graphs(ictal_segments, ictal_fcns)

  self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy)
  self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy)
  self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy)
  self.con_scores[con_idx] = np.abs(csd_mean) / np.sqrt(psd_xx * psd_yy)


In [54]:
print(len(ictal_fcns))
print(len(precital_fcns))

28
10000


In [55]:
# ictal_fcns[0][2].shape
precital_fcns[0][2].shape

(23, 23)

In [56]:
preictal_graphs = generate_graphs(preictal_segments, precital_fcns)
ictal_graphs = generate_graphs(ictal_segments, ictal_fcns)

In [57]:
import pickle
base_path = "/Users/dentira/anomaly-detection/ssl-based-model/ssl-seizure-detection/ssl_seizure_detection/dummy_data/supervised/"
with open(f"{base_path}ictal_data.pkl",'wb') as file:
    pickle.dump(ictal_graphs, file)

with open(f"{base_path}preictal_data.pkl",'wb') as file:
    pickle.dump(preictal_graphs, file)

In [58]:
import torch
import pickle
import numpy as np
import sys
sys.path.append('/Users/dentira/anomaly-detection/ssl-based-model/ssl-seizure-detection/ssl_seizure_detection/src/data')


In [59]:
from preprocess import new_grs, create_tensordata_new, convert_to_Data, pseudo_data, convert_to_PairData, convert_to_TripletData


### Step 1: Extracting Graph Representations
For each patient and each run, there are three files: preictal (before seizure), ictal (seizure occurring), and postictal (after seizure). Each file is a list with entries of the form the form `graph = [A, NF, EF]`. Where `A`, `NF`, and `EF` are lists of length 4, 3, and 4 respectively defined below.

`A = [A0, A1, A2, A3]`, where :
-   `A0` = Ones, shape `(107,107)`.
-   `A1` = Correlation, shape `(107,107)`.  
-   `A2` = Coherence, shape `(107,107)`.
-   `A3` = Phase, shape `(107,107)`.

`NF = [NF0, NF1, NF2]`, where:

-  `NF0` = Ones, shape `(107,1)`.
-  `NF1` = Average Energy, shape `(107,1)`.
-  `NF2` = Band Energy, shape `(107,8)`.


`EF = [EF0, EF1, EF2, EF3]`, where:

-  `EF0` = Ones, shape `(107,107,1)`.
-  `EF1` = Correlation, shape `(107,107,1)`.
-  `EF2` = Coherence, shape `(107,107,1)`.
-  `EF3` = Phase, shape `(107,107,1)`.

All the information above has been (experimentally) confirmed by the above and Alan's documentation of `get_nf`, `get_adj`, and `get_ef` helper functions in his `load_data()` function, but should talk to Alan about confirming these details for absolute certainty.

We'll first load the preictal, ictal, and postictal files for a single patient and run. In this case, the patient folder is `jh101` and we are using run $1$.

In [60]:
# Mac
path_ictal = f"/Users/dentira/anomaly-detection/ssl-based-model/ssl-seizure-detection/ssl_seizure_detection/dummy_data/supervised/ictal_data.pkl"
path_preictal = f"/Users/dentira/anomaly-detection/ssl-based-model/ssl-seizure-detection/ssl_seizure_detection/dummy_data/supervised/preictal_data.pkl"
# path_postictal = f"/Users/dentira/anomaly-detection/ssl-based-model/ssl-seizure-detection/ssl_seizure_detection/dummy_data/supervised/ictal_data.pkl"


with open(path_preictal, 'rb') as f:
    data_preictal = pickle.load(f)
with open(path_ictal, 'rb') as f:
    data_ictal = pickle.load(f)
# with open(path_postictal, 'rb') as f:
    # data_postictal = pickle.load(f)

### Step 2: Selecting Graph Representations
For simplicity we're going to select the most extensive graph representation:
-  `A` = None
-  `NF` = Average Energy and Band Energy, shape `(107,9)`.
-  `EF` = Correlation, Coherence, Phase, shape `(107, 107, 3)`.

Note that because most PyG layers do not use a separate adjacency matrix with weights, we will not use it, and instead we'll use all the possible edge features. This is facilitated by the `new_grs` functions which gives us the data with a label of `Y = [Y_1, Y_2]` where `Y_1` is the binary label for ictal (1) or nonictal (0), and `Y_2` is the multiclass label for preictal (0), ictal (1), or postictal (2).

##### Binary Classification

In [61]:
# Select the graph representation for each class
new_data_preictal = new_grs(data_preictal, type="preictal")
new_data_ictal = new_grs(data_ictal, type="ictal")
# new_data_postictal = new_grs(data_postictal, type="postictal")


After selecting the GRs for each class, we concatenate them temporally into a single list `[preictal, ictal, postictal]`.

In [62]:
new_data = new_data_preictal + new_data_ictal

In [63]:
num_electrodes = new_data[0][0][0].shape[0]
print(f"Number of electrodes: {num_electrodes}")

Number of electrodes: 23


### Step 3: Standard GRs $\rightarrow$ PyG GRs
The function `create_tensordata_new` converts the pickle file list of standard graph representations, a list with entries of the form $[ [NF, EF] , Y]$, where $NF$ are the node features, $EF$ are the edge features, and $Y$ is the graph label. The function first inserts an `edge_index` for a **complete graph** in the PyG format, which is a tensor of shape `[2, num_edges]` where each column $[i \ \ j]^T$ indicates the directed edge $i \to j$; this is built using the helper function `build_K_n` found in `preprocess.py`. The node features $NF$ are untouched, but converted to float32 a tensor, notated by `x` in PyG. The edge features are converted to `edge_attr` which is a float32 tensor of shape `[num_edges, num_edge_features]` which follows the `edge_index` accordingly, i.e. the 4th column in `edge_index` (4th edge) will correspond to the edge feature `edge_attr[3,:]`, and so on. The label $Y$ is converted to a long torch tensor. The output is a list with entries of the form `[[edge_index, x, edge_attr], y]`.

In [64]:
pyg_grs = create_tensordata_new(num_nodes=num_electrodes, data_list=new_data, complete=True, save=False, logdir=None)
print(len(pyg_grs))
zero = 0
one = 0
two = 0
for g in pyg_grs:
    # print(g[1][0][0])
    if g[1][0][1] == 0: zero+=1
    elif g[1][0][1] == 1: one+=1
    else: two+=1
print(zero)
print(one)
print(two)


10028
10000
28
0


In [65]:
# Look inside of pyg_grs
print(len(pyg_grs))
print(type(pyg_grs[0][0][0]))
print("Edge features shape:", pyg_grs[0][0][2].shape)
print("Edge features stored in edge_attr:", pyg_grs[0][0][2])

10028
<class 'torch.Tensor'>
Edge features shape: torch.Size([506, 3])
Edge features stored in edge_attr: tensor([[ 0.3355,  0.0000,  0.0000],
        [-0.1031,  0.0000,  0.0000],
        [-0.1673,  0.0000,  0.0000],
        ...,
        [-0.4085,  0.3932,  0.3770],
        [ 0.1278,  0.3777,  0.2925],
        [ 0.3122,  0.7194,  0.4664]])


### Step 4: PyG GRs $\rightarrow$ PyG Data
<u>**Stop after this step**</u> if you only need PyG Data for <u>**supervised learning**</u>. 

Here we take the PyG graph representations, and apply the `convert_to_Data` function to create a new list where each entry is now a `torch_geometric.data.Data` object. This is the main object uses to hold graphs in PyG, so we need to use it, especially for batching (for more details see my tutorial `tutorial.ipynb`, or click [here](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html) for the official tutorial from PyG).

In [66]:
# Convert the PyG GRs to the PyG Data format
pyg_Data_path = "/Users/dentira/anomaly-detection/ssl-based-model/ssl-seizure-detection/ssl_seizure_detection/patient_gr/jh101_pyg_Data.pt"
Data_list = convert_to_Data(pyg_grs, save=True, logdir=pyg_Data_path)

### Step 5: Relative Positioning
In this step we take the output of Step 3 (`pyg_grs`) and create the pseudolabeled dataset of graph pairs for the relative positioning self-supervised method.  Given our list `pyg_grs` and hyperparameters $\tau_+$ and $\tau_-$. The function `pseudo_data` below returns a list of graph pairs where each entry is of the form `[[edge_index1, x1, edge_attr1], [edge_index2, x2, edge_attr2], y]`, where `y` is a pseudolabel (not the old label). Since the total size of the pseudolabeled dataset can be quite large, we use the `sample_ratio` argument to randomly sample a certain portion of it (e.g., `sample_ratio = 0.2` will give us 20% of the total pseudolabeled dataset). Also note that the function will return an equal number of positive and negative samples, as `pseudo_data` automatically balances out the correspondingly classes.

In [91]:
pdata = pseudo_data(pyg_grs, tau_pos=12 // 0.12, tau_neg=60 // 0.12, stats=True, save=False, patientid="", 
                            logdir=None, model="relative_positioning", sample_ratio=0.10)

Number of examples: 0
Series([], Name: count, dtype: int64)


In [92]:
# Look inside of pdata
print(len(pdata))
example = pdata[0]
graph1, graph2, label = example
edge_index1, x1, edge_attr1 = graph1
edge_index2, x2, edge_attr2 = graph2
print("Edge features shape:", edge_attr1.shape)
print("Edge features stored in edge_attr:", edge_attr1)

0


IndexError: list index out of range

Now instead of converting each graph pair to `torch_geometric.data.Data` object, we instead create a new class called `PairData` that inherits from the `torch_geometric.data.Data` class, allowing us to batch *pairs* of graphs. We use the `convert_to_PairData` function to convert the list of graph pairs to a list of `PairData` objects (see [here](https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html) for more details).

In [16]:
Pair_Data = convert_to_PairData(pdata, save=False, logdir=None)

### Step 6: Temporal Shuffling
This step is nearly identical to Step 5, we take the `pyg_grs` and use them to create a pseudolabeled dataset for the temporal shuffling self-supervised method. However, in this method we generate *graph triplets* of the form `[[edge_index1, x1, edge_attr1], [edge_index2, x2, edge_attr2], [edge_index3, x3, edge_attr3], y]` where `y` is the pseudolabel. The size of the pseudolabeled dataset for temporal shuffling can be extremely large, therefore it is <u>**strongly encouraged**</u> to use the `sample_ratio` argument to scale down the dataset.

In [19]:
pdata = pseudo_data(pyg_grs, tau_pos=12 // 0.12, tau_neg=60 // 0.12, stats=True, save=False, patientid="patient", logdir=None, 
                    model="temporal_shuffling", sample_ratio=0.3)

In [None]:
print(pdata[0][0][2])

Similar to Step 5, we create a new class called `TripletData` that inherits from the `torch_geometric.data.Data` class for batching graph triplets in PyG.

In [None]:
Triplet_Data = convert_to_TripletData(pdata, save=False, logdir=None)

### Step 7: Automatic Conversion
The `patch` function in `patch.py` does all of the above, converting the original preictal, ictal, and postictal files from a single patient run. Please see documentation in `patch.py`

In [1]:
import sys
import torch
from ssl_seizure_detection.src.patch import single_patient_patcher

# PC
patient_dir = r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_gr"
patient = "jh101"
logdir = r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_pyg"

# Patch the data
data = single_patient_patcher(user="xmootoo", patient_dir=patient_dir, patient=patient, logdir=logdir, model="VICRegT1", stats=True, save=True,
                              sigma=5, tau=0.68)

jh101_data = torch.load(r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_pyg\jh101\VICRegT1\jh101_combined.pt")

print(len(jh101_data))