In [1]:
"""
Some data preprocessing codes for TUSZ are adapted from https://github.com/tsy935/eeg-gnn-ssl
Compatible with TUSZ v2.0.0
"""
from tqdm import tqdm
import numpy as np
import pandas as pd
import os
import pyedflib
import h5py
from scipy.signal import resample
from scipy.fftpack import fft

# Channels of interest
INCLUDED_CHANNELS = [
    'EEG FP1',
    'EEG FP2',
    'EEG F3',
    'EEG F4',
    'EEG C3',
    'EEG C4',
    'EEG P3',
    'EEG P4',
    'EEG O1',
    'EEG O2',
    'EEG F7',
    'EEG F8',
    'EEG T3',
    'EEG T4',
    'EEG T5',
    'EEG T6',
    'EEG FZ',
    'EEG CZ',
    'EEG PZ']

# All seizure labels available in TUH
ALL_LABEL_DICT = {'fnsz': 1, 'gnsz': 2, 'spsz': 3, 'cpsz': 4,
                  'absz': 5, 'tnsz': 6, 'tcsz': 7, 'mysz': 8}

# Dataset
FREQUENCY = 100 # Hz
clip_len = 10
time_step = 1
is_fft = False
# is_fft = True
split = "eval"
# anom_cls_inc = ['fnsz'] 
anom_cls_inc = []

if is_fft:
    fft_label='fft'
else:
    fft_label='nofft'
    
if anom_cls_inc == []:
    anom_cls_label = 'allcs'
    anom_cls_idx = [1,2,3,4,5,6,7,8]
else:
    anom_cls_label = 'inc'
    anom_cls_idx = []
    for lab in anom_cls_inc:
        anom_cls_label = lab[0] + anom_cls_label 
        anom_cls_idx.append(ALL_LABEL_DICT[lab])

# Data paths
raw_edf_dir = '../data/TUSZ/eval' # dir of unprecessed data tree 
save_dir = '../data/TUSZ/eval_resampled' # dir of resampled outputs
output_dir = '../data/TUSZ'  # dir of preprocessed outputs

In [2]:
# Statistics
edf_files = []
for path, subdirs, files in os.walk(raw_edf_dir):
    for name in files:
        if ".edf" in name:
            edf_files.append(os.path.join(path, name))
            
seizure_class_counts =  {'fnsz': 0, 'gnsz': 0, 'spsz': 0, 'cpsz': 0, 
                         'absz': 0, 'tnsz': 0, 'tcsz': 0, 'mysz': 0}

for cls in tqdm(seizure_class_counts): # for each seizure class
    for file_name in edf_files:
        csv_file = file_name.split(".edf")[0] + ".csv"
        with open(csv_file) as f:
            for line in f.readlines():
                if "#" in line:
                    continue
                if "channel,start_time,stop_time,label,confidence" in line:
                    continue
                if cls in line: 
                    seizure_class_counts[cls] += 1
                    break

print(seizure_class_counts)

100%|██████████| 8/8 [00:05<00:00,  1.56it/s]

{'fnsz': 102, 'gnsz': 58, 'spsz': 0, 'cpsz': 31, 'absz': 1, 'tnsz': 1, 'tcsz': 8, 'mysz': 0}





In [5]:
# Step 1 Resampling
def getEDFsignals(edf):
    """
    Get EEG signal in edf file
    Args:
        edf: edf object
    Returns:
        signals: shape (num_channels, num_data_points)
    """
    n = edf.signals_in_file
    samples = edf.getNSamples()[0]
    signals = np.zeros((n, samples))
    for i in range(n):
        try:
            signals[i, :] = edf.readSignal(i)
        except:
            pass
    return signals

def getOrderedChannels(file_name, verbose, labels_object, channel_names):
    labels = list(labels_object)
    for i in range(len(labels)):
        labels[i] = labels[i].split("-")[0]

    ordered_channels = []
    for ch in channel_names:
        try:
            ordered_channels.append(labels.index(ch))
        except:
            if verbose:
                print(file_name + " failed to get channel " + ch)
            raise Exception("channel not match")
    return ordered_channels

def resampleData(signals, to_freq, window_size):
    """
    Resample signals from its original sampling freq to another freq
    Args:
        signals: EEG signal slice, (num_channels, num_data_points)
        to_freq: Re-sampled frequency in Hz
        window_size: time window in seconds
    Returns:
        resampled: (num_channels, resampled_data_points)
    """
    num = int(to_freq * window_size)
    resampled = resample(signals, num=num, axis=1)
    return resampled

In [4]:
failed_files = []
for idx in tqdm(range(len(edf_files))):
    edf_fn = edf_files[idx]
    save_fn = os.path.join(save_dir, edf_fn.split("\\")[-1].split(".edf")[0] + ".h5")
    if os.path.exists(save_fn):
        continue
    else:
        try:
            f = pyedflib.EdfReader(edf_fn)
            orderedChannels = getOrderedChannels(
                edf_fn, False, f.getSignalLabels(), INCLUDED_CHANNELS
            )
            signals = getEDFsignals(f)
            signal_array = np.array(signals[orderedChannels, :])

            sample_freq = f.getSampleFrequency(0)
            if sample_freq != FREQUENCY:
                signal_array = resampleData(
                    signal_array,
                    to_freq=FREQUENCY,
                    window_size=int(signal_array.shape[1] / sample_freq),
                )

            with h5py.File(save_fn, "w") as hf:
                hf.create_dataset("resampled_signal", data=signal_array)
    #             hf.create_dataset("resample_freq", data=FREQUENCY)

        except BaseException as e:
            failed_files.append(edf_fn)

print("RESAMPLING DONE")
print("Failed files: ",failed_files)
num_removed = 0
for failed in failed_files:
    edf_files.remove(failed)
    num_removed += 1
print(f'{num_removed} files removed.')

100%|██████████| 881/881 [10:07<00:00,  1.45it/s]

RESAMPLING DONE
Failed files:  []
0 files removed.





In [6]:
# Step 2 Main Data Processing
# - Split into clips 
# - Fft
# - Seizure/class labels

def getAnomalousChannels(montage, channels):
    """
    Args:
        montage: string of bipolar montage
        channels: list of included channels
    Returns:
        anom_channels: one hot vector of anomalous channels
    """
    anom_channels = []
    for ch in INCLUDED_CHANNELS:
        if montage.split("-")[0] in ch:
            anom_channels.append(1)
        elif montage.split("-")[1] in ch:
            anom_channels.append(1)
        else: 
            anom_channels.append(0)
            
    return anom_channels

def getSeizureTimes(file_name, verbose=False):
    """
    Args:
        file_name: edf file name
    Returns:
        seizure_times: list of times of seizure onset in seconds
        seizure_classes: list of seizure classes corresponding to seizure onset
        seizure_channels: list of seizure channels corresponding to seizure onset
    """
    csv_file = file_name.split(".edf")[0] + ".csv"

    seizure_times = []
    seizure_classes = []
    seizure_channels = []
    with open(csv_file) as f:
        for line in f.readlines():
            if "#" in line:
                continue
            if "channel,start_time,stop_time,label,confidence" in line:
                continue
            seizure_time = [float(line.strip().split(",")[1]),float(line.strip().split(",")[2])]
            for cls in ALL_LABEL_DICT: # for each seizure class
                if cls in line: 
                    # seizure start and end time
                    seizure_times.append(seizure_time)
                    seizure_classes.append(int(ALL_LABEL_DICT[cls]))
                    seizure_channels.append(getAnomalousChannels(line.strip().split(",")[0],INCLUDED_CHANNELS))

    if verbose:
        print(csv_file)
        print(seizure_times)
        print(seizure_classes)
        print(seizure_channels)
        print(len(seizure_times))

    return seizure_times, seizure_classes, seizure_channels

def computeFFT(signals, n):
    """
    Args:
        signals: EEG signals, (number of channels, number of data points)
        n: length of positive frequency terms of fourier transform
    Returns:
        FT: log amplitude of FFT of signals, (number of channels, number of data points)
        P: phase spectrum of FFT of signals, (number of channels, number of data points)
    """
    # fourier transform
    fourier_signal = fft(signals, n=n, axis=-1)  # FFT on the last dimension

    # only take the positive freq part
    idx_pos = int(np.floor(n / 2))
    fourier_signal = fourier_signal[:, :idx_pos]
    amp = np.abs(fourier_signal)
    amp[amp == 0.0] = 1e-8  # avoid log of 0

    FT = np.log(amp)
    P = np.angle(fourier_signal)

    return FT, P

def computeSliceMatrix(
        h5_fn,
        seizure_times,
        seizure_classes,
        seizure_channels,
        clip_idx,
        clip_len=12,
        time_step_size=1,
        is_fft=False,
        n_fft=FREQUENCY,
        res_freq=FREQUENCY,
        inc_chan=INCLUDED_CHANNELS,
        verbose=False):
    """
    Convert entire EEG sequence into clips of length clip_len
    Args:
        h5_fn: file name of resampled signal h5 file (full path)
        clip_idx: index of current clip/sliding window
        time_step_size: length of each time_step_size, in seconds, int
        clip_len: sliding window size or EEG clip length, in seconds, int
        is_fft: whether to perform FFT on raw EEG data
        n_fft: number of frequency terms for FFT
    Returns:
        slices: list of EEG clips, each having shape (num_channels, n_fft/2)
        is_seizure: list of seizure labels for each clip, 1 for seizure, 0 for no seizure
        is_seizure_one_hot: list of one-hot matrices of each clip that contains anomalous sensors
        is_seizure_class: list of classes of each clip, 0 for normal, >0 for seizure
        
    """
    with h5py.File(h5_fn, 'r') as f:
        signal_array = f["resampled_signal"][()]
    
    # clip slicing
    physical_clip_len = int(res_freq * clip_len)
    physical_time_step_size = int(res_freq * time_step_size)
    
    start_window = clip_idx * physical_clip_len
    end_window = start_window + physical_clip_len
    window = pd.Interval(start_window, end_window)
    # (num_channels, physical_clip_len)
    curr_slc = signal_array[:, start_window:end_window]
    
    start_time_step = 0
    time_steps = []
    while start_time_step <= curr_slc.shape[1] - physical_time_step_size:
        end_time_step = start_time_step + physical_time_step_size
        # (num_channels, physical_time_step_size)
        curr_time_step = curr_slc[:, start_time_step:end_time_step]
        if is_fft:
            curr_time_step, _ = computeFFT(
                curr_time_step, n=physical_time_step_size)
        time_steps.append(curr_time_step)
        start_time_step = end_time_step
        
    eeg_clip = np.concatenate(time_steps, axis=1)
    
    # determine if there's seizure in current clip
    is_seizure = 0
    is_seizure_one_hot = np.array([0]*len(inc_chan))
    is_seizure_class = []
    
    for ivl, t in enumerate(seizure_times):
        start_t = int(t[0] * res_freq)
        end_t = int(t[1] * res_freq)
        seizure_ivl = pd.Interval(start_t,end_t)
                            
        if window.overlaps(seizure_ivl):
            is_seizure = 1
            is_seizure_one_hot = is_seizure_one_hot + np.array(seizure_channels[ivl])
            is_seizure_one_hot = is_seizure_one_hot > 0
            is_seizure_one_hot = is_seizure_one_hot.astype(int)
            if seizure_classes[ivl] not in is_seizure_class:
                is_seizure_class.append(seizure_classes[ivl])
        
        if is_seizure_class==[]:
            is_seizure_class.append(0)

    if verbose:
        print(clip_idx)
        print(window)
        print(is_seizure)
        print(is_seizure_one_hot)
        print(is_seizure_class)
        
    return eeg_clip, is_seizure, is_seizure_one_hot, is_seizure_class

def getNumClips(h5_fn_path,
        clip_len=12,
        res_freq=FREQUENCY):
    """
    Args:
        h5_fn_path: file name of resampled signal h5 file (full path)
        clip_len: length of each clips (in seconds)
        res_freq: sampling frequency
    Returns:
        num_clips: number of clips available in the file
    """
    physical_clip_len = int(res_freq * clip_len)
    
    # read file
    with h5py.File(h5_fn_path, "r") as f:
        signal_array = f["resampled_signal"][()]
        
    # calculate number of clips
    num_clips = int(signal_array.shape[1] / physical_clip_len)
    
    return num_clips

In [7]:
output_file = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    ".h5")

f = h5py.File(output_file, "w")
f.close()

ds_idx = 0
nm_idx = 0
am_idx = 0
normal_count = 0
anom_node_count = 0
seizure_clip_counts =  {'fnsz': 0, 'gnsz': 0, 'spsz': 0, 'cpsz': 0, 
                         'absz': 0, 'tnsz': 0, 'tcsz': 0, 'mysz': 0}
seizure_node_counts =  {'fnsz': 0, 'gnsz': 0, 'spsz': 0, 'cpsz': 0, 
                         'absz': 0, 'tnsz': 0, 'tcsz': 0, 'mysz': 0}
clips_excluded = 0

# for idx in tqdm(range(100)):
for idx in tqdm(range(len(edf_files))):
    h5_fn = edf_files[idx].split("\\")[-1].split('.edf')[0] + '.h5'
    h5_fn_path = os.path.join(save_dir,h5_fn)
    edf_fn = edf_files[idx]
    verbosity = False
    
    num_clips = getNumClips(h5_fn_path,clip_len=clip_len,res_freq=FREQUENCY)
    seizure_times, seizure_classes, seizure_channels = getSeizureTimes(edf_fn.split('.edf')[0], verbose=verbosity)
    
    for clp in range(num_clips):
        eeg_clip, is_seizure, is_seizure_one_hot, is_seizure_class = computeSliceMatrix(
            h5_fn=h5_fn_path,
            seizure_times=seizure_times,
            seizure_classes=seizure_classes,
            seizure_channels=seizure_channels,
            clip_idx=clp,
            clip_len=clip_len,
            time_step_size=time_step,
            is_fft=is_fft,
            n_fft=FREQUENCY,
            res_freq=FREQUENCY,
            inc_chan=INCLUDED_CHANNELS,
            verbose=verbosity)
        
        for cls in is_seizure_class:
            if cls>0 and cls not in anom_cls_idx:
                clips_excluded += 1
                continue
        
        with h5py.File(output_file,'r+') as hf: # normal + abnormal
            grp = hf.create_group("clip"+str(ds_idx))
            grp.create_dataset("data",data=eeg_clip)
            grp.create_dataset("anom_label",data=is_seizure)
            grp.create_dataset("anom_class",data=is_seizure_class)
            grp.create_dataset("anom_channels",data=is_seizure_one_hot)


        ds_idx += 1
        anom_node_count += np.sum(is_seizure_one_hot)

        if is_seizure==0:
            normal_count += 1
            
        for cls in seizure_clip_counts:
            cls_label = ALL_LABEL_DICT[cls]
            if cls_label in is_seizure_class:
                seizure_clip_counts[cls] += 1
                seizure_node_counts[cls] += sum(is_seizure_one_hot)

anom_clips = 0
for key in seizure_clip_counts:
    anom_clips += seizure_clip_counts[key]
seizure_clip_ratios = {key: seizure_clip_counts[key] / (anom_clips+normal_count) for key in seizure_clip_counts}
seizure_node_ratios = {key: seizure_node_counts[key] / (anom_node_count + normal_count*19) for key in seizure_node_counts}

print("DONE")
print("Abnormal ratio= {}".format(1-normal_count/(ds_idx+1)))
print("Abnormal node ratio= {}".format(anom_node_count/((ds_idx+1)*len(INCLUDED_CHANNELS))))
print("Seizure class counts: {}".format(seizure_clip_counts))
print("Seizure node class counts: {}".format(seizure_node_counts))
print("Seizure class ratios: {}".format(seizure_clip_ratios))
print("Seizure node ratios: {}".format(seizure_node_ratios))
print("Normal count: {}".format(normal_count))
print("Normal node count: {}".format(normal_count*19))
print("Clips excluded: {}".format(clips_excluded))

100%|██████████| 881/881 [10:47<00:00,  1.36it/s]

DONE
Abnormal ratio= 0.06785009015197041
Abnormal node ratio= 0.03927485279697053
Seizure class counts: {'fnsz': 1598, 'gnsz': 1024, 'spsz': 0, 'cpsz': 395, 'absz': 73, 'tnsz': 5, 'tcsz': 68, 'mysz': 0}
Seizure node class counts: {'fnsz': 13517, 'gnsz': 15163, 'spsz': 0, 'cpsz': 3933, 'absz': 1241, 'tnsz': 39, 'tcsz': 921, 'mysz': 0}
Seizure class ratios: {'fnsz': 0.0342992058381627, 'gnsz': 0.021978965443228162, 'spsz': 0.0, 'cpsz': 0.008478214209057738, 'absz': 0.0015668598411676326, 'tnsz': 0.0001073191672032625, 'tcsz': 0.00145954067396437, 'mysz': 0.0}
Seizure node ratios: {'fnsz': 0.01571967186042671, 'gnsz': 0.01763389690165349, 'spsz': 0.0, 'cpsz': 0.0045739046701974, 'absz': 0.001443227992808282, 'tnsz': 4.535527132918856e-05, 'tcsz': 0.0010710821767739144, 'mysz': 0.0}
Normal count: 43427
Normal node count: 825113
Clips excluded: 0





In [8]:
# Step 3 Calculate mean and std
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

output_file = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    ".h5")

with h5py.File(output_file, "r") as f:
    keys = list(f.keys())

    for i, key in tqdm(enumerate(keys)):
        data = f[key]['data'][()]
        data = np.reshape(data, (-1,1))
        scaler.partial_fit(data)

means = scaler.mean_[0]
stds = scaler.scale_[0]

print(means)
print(stds)

46587it [00:20, 2229.17it/s]

3.697534735305587
1.3957092773450142





In [9]:
# Step 3: Standardize data
output_file = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    ".h5")

output_file_norm = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    "_normalized" +
    ".h5")

f = h5py.File(output_file_norm, "w")
f.close()

with h5py.File(output_file, "r") as f:
    with h5py.File(output_file_norm,'r+') as fn: # normal + abnormal
        keys = list(f.keys())

        for i, key in tqdm(enumerate(keys)):
            x = f[key]['data'][()]
            x = np.reshape(x, (-1,1))
            x = scaler.transform(x)
            x = np.reshape(x, (19,-1))
            
            grp = fn.create_group(key)
            grp.create_dataset("data",data=x)
            grp.create_dataset("anom_label",data=f[key]['anom_label'][()])
            grp.create_dataset("anom_class",data=f[key]['anom_class'][()])
            grp.create_dataset("anom_channels",data=f[key]['anom_channels'][()])

46587it [01:09, 666.42it/s]


In [3]:
# Create labels
output_file = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    "_normalized" +
    ".h5")

with h5py.File(output_file, "r") as f:
    keys = range(len(f.keys()))
    labels = []
    cls_labels = []
    for i, key in tqdm(enumerate(keys)):
        key = "clip" + str(key)
        labels.append(f[key]['anom_label'][()])
        if len(f[key]['anom_class'][()])==1:
            cls_labels.append(f[key]['anom_class'][()][0])
        else:
            cls_labels.append(0)

labels = np.array(labels)
print(labels.shape)
cls_labels = np.array(cls_labels)
fnsz_labels = cls_labels==1
gnsz_labels = cls_labels==2
cpsz_labels = cls_labels==4
print(np.mean(labels))


46587it [00:29, 1575.11it/s]


(46587,)
0.06783008135316719


In [4]:
print(np.sum(fnsz_labels))
print(np.sum(gnsz_labels))
print(np.sum(cpsz_labels))
print(np.mean(fnsz_labels))
print(np.mean(gnsz_labels))
print(np.mean(cpsz_labels))
print(fnsz_labels.shape)
print(gnsz_labels.shape)
print(cpsz_labels.shape)
print(labels.shape)

815
471
208
0.01749415072874407
0.010110116556120806
0.0044647648485629035
(46587,)
(46587,)
(46587,)
(46587,)


In [20]:
from sklearn.model_selection import train_test_split

# Split into train and test
train_idx, test_idx, train_labels, test_labels = train_test_split(
    np.arange(len(labels)), labels, test_size=0.1, random_state=123, stratify=labels
)

print(train_idx)
print(train_idx.shape)
print(test_idx)
print(test_idx.shape)
print(np.mean(train_labels))
print(np.mean(test_labels))
print(np.sum(test_labels))


[33421 16042  4216 ... 25709 22472 14133]
(41928,)
[29744 15218 32651 ... 12666  9081  9430]
(4659,)
0.067830566685747
0.0678257136724619
316


In [6]:
sz_classes = ['fnsz','gnsz','cpsz']
shots_per_class = 30

fnsz_train_idx = train_idx[fnsz_labels[train_idx]]
gnsz_train_idx = train_idx[gnsz_labels[train_idx]]
cpsz_train_idx = train_idx[cpsz_labels[train_idx]]

iso_anom_idx = []
iso_anom_idx.extend(fnsz_train_idx[:shots_per_class])
iso_anom_idx.extend(gnsz_train_idx[:shots_per_class])
iso_anom_idx.extend(cpsz_train_idx[:shots_per_class])

iso_anom_idx = np.array(iso_anom_idx)
print(iso_anom_idx)
print(iso_anom_idx.shape)

[11477 44255 32456 23248 11131 28450 35694 31883 32130 23769 31894 32575
 32486 43285 32611 42076 32862 43127 44256 32098 39658 32060 32059 32049
 43053 32296 16444 35701 32339 32345 10194  1884  9673  1452 45548 45749
 45538  9247 46019 45578 22634 46002 46024 45561 29867 45967 17362 45752
 22642 22644 46026 10050 10095 45571 30032 45994  9207 10197 45978  9819
 19878 24308 25289  5125  5290 24451 19626 24519 24305 24988 23957 24599
 24456 23967  5827  5284 24452 20272 24590 24596 24528 19932 19751  5828
 25229 20185 25109 24598 19881  5128]
(90,)


In [10]:
output_file_norm = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    "_normalized" +
    ".h5")

# Get isolated anomaly data
iso_anom_data = []
iso_anom_labels = []
print(labels)
print(labels[iso_anom_idx])
print(iso_anom_idx)
group_keys = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(iso_anom_idx):
        group_key = "clip" + str(idx)
        group_keys.append(group_key)
        iso_anom_data.append(f[group_key]['data'][()])
        iso_anom_labels.append(int(f[group_key]['anom_label'][()]))

iso_anom_data = np.stack(iso_anom_data,axis=0)
iso_anom_labels = np.array(iso_anom_labels)
print(iso_anom_data.shape)
print(iso_anom_labels.shape)
print(iso_anom_labels)
print(group_keys)

[0 0 0 ... 0 0 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[11477 44255 32456 23248 11131 28450 35694 31883 32130 23769 31894 32575
 32486 43285 32611 42076 32862 43127 44256 32098 39658 32060 32059 32049
 43053 32296 16444 35701 32339 32345 10194  1884  9673  1452 45548 45749
 45538  9247 46019 45578 22634 46002 46024 45561 29867 45967 17362 45752
 22642 22644 46026 10050 10095 45571 30032 45994  9207 10197 45978  9819
 19878 24308 25289  5125  5290 24451 19626 24519 24305 24988 23957 24599
 24456 23967  5827  5284 24452 20272 24590 24596 24528 19932 19751  5828
 25229 20185 25109 24598 19881  5128]


100%|██████████| 90/90 [00:00<00:00, 3102.07it/s]

(90, 19, 1000)
(90,)
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
['clip11477', 'clip44255', 'clip32456', 'clip23248', 'clip11131', 'clip28450', 'clip35694', 'clip31883', 'clip32130', 'clip23769', 'clip31894', 'clip32575', 'clip32486', 'clip43285', 'clip32611', 'clip42076', 'clip32862', 'clip43127', 'clip44256', 'clip32098', 'clip39658', 'clip32060', 'clip32059', 'clip32049', 'clip43053', 'clip32296', 'clip16444', 'clip35701', 'clip32339', 'clip32345', 'clip10194', 'clip1884', 'clip9673', 'clip1452', 'clip45548', 'clip45749', 'clip45538', 'clip9247', 'clip46019', 'clip45578', 'clip22634', 'clip46002', 'clip46024', 'clip45561', 'clip29867', 'clip45967', 'clip17362', 'clip45752', 'clip22642', 'clip22644', 'clip46026', 'clip10050', 'clip10095', 'clip45571', 'clip30032', 'clip45994', 'clip9207', 'clip10197', 'clip45978', 'clip9819', 'clip19878', 'clip24




In [15]:
output_file_norm = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    "_normalized" +
    ".h5")

# Create test dataset file
test_data = []
test_labels = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(test_idx):
        group_key = "clip" + str(idx)
        test_data.append(f[group_key]['data'][()])
        test_labels.append(int(f[group_key]['anom_label'][()]))

test_data = np.stack(test_data,axis=0)
test_labels = np.array(test_labels)
print(test_data.shape)
print(test_labels.shape)

print(test_data.shape)
print(test_labels.shape)

test_file = os.path.join(output_dir, "tusz_test_{}.h5".format(fft_label))
with h5py.File(test_file, "w") as f:
    f.create_dataset("X", data=test_data)
    f.create_dataset("y", data=test_labels)


100%|██████████| 4659/4659 [00:01<00:00, 2976.98it/s]


(4659, 19, 500)
(4659,)
(4659, 19, 500)
(4659,)


In [31]:
# Create test dataset file - hard setting
output_file_norm = os.path.join(
    output_dir,
    split +
    "_" +
    anom_cls_label +
    "_" +
    fft_label +
    "_normalized" +
    ".h5")

# Create test dataset file
test_data = []
test_labels = []
test_class_labels = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(test_idx):
        group_key = "clip" + str(idx)
        test_data.append(f[group_key]['data'][()])
        test_labels.append(int(f[group_key]['anom_label'][()]))
        cls_label = f[group_key]['anom_class'][()]
        # print(cls_label)
        if cls_label.size>0:
            if cls_label.size==1 and cls_label[0]==0:
                test_class_labels.append(0)
            else:
                cls_label = np.setdiff1d(cls_label,[0])
                test_class_labels.append(int(cls_label[0]))
        else:
            test_class_labels.append(0)
            
            

test_data = np.stack(test_data,axis=0)
test_labels = np.array(test_labels)
test_class_labels = np.array(test_class_labels)
print(test_data.shape)
print(test_labels.shape)
print(test_class_labels.shape)

100%|██████████| 4659/4659 [00:01<00:00, 2447.96it/s]


(4659, 19, 1000)
(4659,)
(4659,)


In [32]:
print(test_class_labels[:100])
print("Normal ratio: ", len(np.where(np.array(test_class_labels) == 0)[0])/len(test_class_labels))
print("fnsz ratio: ", len(np.where(np.array(test_class_labels) == 1)[0])/len(test_class_labels))
print("gnsz ratio: ", len(np.where(np.array(test_class_labels) == 2)[0])/len(test_class_labels))
print("spsz ratio: ", len(np.where(np.array(test_class_labels) == 3)[0])/len(test_class_labels))
print("cpsz ratio: ", len(np.where(np.array(test_class_labels) == 4)[0])/len(test_class_labels))
print("absz ratio: ", len(np.where(np.array(test_class_labels) == 5)[0])/len(test_class_labels))
print("tnsz ratio: ", len(np.where(np.array(test_class_labels) == 6)[0])/len(test_class_labels))
print("tcsz ratio: ", len(np.where(np.array(test_class_labels) == 7)[0])/len(test_class_labels))
print("mysz ratio: ", len(np.where(np.array(test_class_labels) == 8)[0])/len(test_class_labels))
print("Abnormal ratio: ", len(np.where(np.array(test_class_labels) != 0)[0])/len(test_class_labels))


[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 2 0 0 2 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 5 0 0 0 0 0 2 0 0
 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
Normal ratio:  0.9321742863275381
fnsz ratio:  0.0352006868426701
gnsz ratio:  0.0223223867782786
spsz ratio:  0.0
cpsz ratio:  0.0068684267010088
absz ratio:  0.00128783000643915
tnsz ratio:  0.000214638334406525
tcsz ratio:  0.0019317450096587251
mysz ratio:  0.0
Abnormal ratio:  0.0678257136724619


In [35]:
seen_class = "cpsz"

class_idx = {'fnsz': 1, 'gnsz': 2, 'cpsz': 4}

seen_unseen_labels = []
for lbl in test_class_labels:
    if lbl == 0:
        seen_unseen_labels.append(0)
    elif lbl == class_idx[seen_class]:
        seen_unseen_labels.append(1)
    else:
        seen_unseen_labels.append(2)

seen_unseen_labels = np.array(seen_unseen_labels)
print(seen_unseen_labels)
print(len(seen_unseen_labels))
print(len(np.where(np.array(seen_unseen_labels) == 0)[0]))
print(len(np.where(np.array(seen_unseen_labels) == 1)[0]))
print(len(np.where(np.array(seen_unseen_labels) == 2)[0]))
print("seen %: {}".format(len(np.where(np.array(seen_unseen_labels) == 1)[0])/len(seen_unseen_labels)))
print("unseen %: {}".format(len(np.where(np.array(seen_unseen_labels) == 2)[0])/len(seen_unseen_labels)))

test_file = os.path.join(output_dir, "tusz_test_{}.h5".format(seen_class))
with h5py.File(test_file, "w") as f:
    f.create_dataset("X", data=test_data)
    f.create_dataset("y", data=test_labels)
    f.create_dataset("seen_unseen_labels", data=seen_unseen_labels)

[0 0 0 ... 0 0 0]
4659
4343
32
284
seen %: 0.0068684267010088
unseen %: 0.060957286971453105


In [17]:
# Create train dataset file
train_data = []
train_labels = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(train_idx):
        group_key = "clip" + str(idx)
        train_data.append(f[group_key]['data'][()])
        train_labels.append(int(f[group_key]['anom_label'][()]))

train_data = np.stack(train_data,axis=0)
train_labels = np.array(train_labels)
print(train_data.shape)
print(train_labels.shape)

# Get isolated anomaly data
iso_anom_data = []
iso_anom_labels = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(iso_anom_idx):
        group_key = "clip" + str(idx)
        iso_anom_data.append(f[group_key]['data'][()])
        iso_anom_labels.append(int(f[group_key]['anom_label'][()]))

iso_anom_data = np.stack(iso_anom_data,axis=0)
iso_anom_labels = np.array(iso_anom_labels)
print(iso_anom_data.shape)
print(iso_anom_labels.shape)
print(iso_anom_labels)

train_file = os.path.join(output_dir, "tusz_train_{}.h5".format(fft_label))
with h5py.File(train_file, "w") as f:
    f.create_dataset("X", data=train_data)
    f.create_dataset("y", data=train_labels)
    f.create_dataset("X_anom", data=iso_anom_data)

100%|██████████| 41928/41928 [00:22<00:00, 1848.54it/s]


(41928, 19, 500)
(41928,)


100%|██████████| 90/90 [00:00<00:00, 742.65it/s]

(90, 19, 500)
(90,)
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]





In [9]:
# Create train dataset file - hard setting
train_data = []
train_labels = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(train_idx):
        group_key = "clip" + str(idx)
        train_data.append(f[group_key]['data'][()])
        train_labels.append(int(f[group_key]['anom_label'][()]))

train_data = np.stack(train_data,axis=0)
train_labels = np.array(train_labels)
print(train_data.shape)
print(train_labels.shape)

100%|██████████| 41928/41928 [01:01<00:00, 683.88it/s]


(41928, 19, 1000)
(41928,)


In [13]:
seen_class = 'cpsz'
shots_per_class = 30

iso_anom_idx = []
if seen_class == 'fnsz':
    iso_anom_idx.extend(fnsz_train_idx[:shots_per_class])
elif seen_class == 'gnsz':
    iso_anom_idx.extend(gnsz_train_idx[:shots_per_class])
elif seen_class == 'cpsz':
    iso_anom_idx.extend(cpsz_train_idx[:shots_per_class])
else: 
    print('Invalid seen class')

iso_anom_idx = np.array(iso_anom_idx)

print(iso_anom_idx)
print(iso_anom_idx.shape)

# Get isolated anomaly data
iso_anom_data = []
iso_anom_labels = []
with h5py.File(output_file_norm, "r") as f:
    for idx in tqdm(iso_anom_idx):
        group_key = "clip" + str(idx)
        iso_anom_data.append(f[group_key]['data'][()])
        iso_anom_labels.append(int(f[group_key]['anom_label'][()]))

iso_anom_data = np.stack(iso_anom_data,axis=0)
iso_anom_labels = np.array(iso_anom_labels)
print(iso_anom_data.shape)
print(iso_anom_labels.shape)
print(iso_anom_labels)

train_file = os.path.join(output_dir, "tusz_train_{}.h5".format(seen_class))
with h5py.File(train_file, "w") as f:
    f.create_dataset("X", data=train_data)
    f.create_dataset("y", data=train_labels)
    f.create_dataset("X_anom", data=iso_anom_data)

[19878 24308 25289  5125  5290 24451 19626 24519 24305 24988 23957 24599
 24456 23967  5827  5284 24452 20272 24590 24596 24528 19932 19751  5828
 25229 20185 25109 24598 19881  5128]
(30,)


100%|██████████| 30/30 [00:00<00:00, 450.20it/s]

(30, 19, 1000)
(30,)
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]



