# Data Preprocessing Steps

1. Running FFT analysis on SEEG data - log seizure times using Z
2. Running FFT projection into 2D tensors, or 3D tensors - log seizure times using the previous step
2D = [numsamps, numfreqs, W, H]
3D = [numsamps, numfreqs, W, H, D]

To project onto 2D or 3D, we will define a mesh grid of the image in brain MRI space. The electrodes therefore need to have defined xyz coordinates in the MRI space (e.g. FreeSurfer or FLIRT space) when doing reconstructions.

Then each step will be saved with the corresponding data samples as a .npz file.

In [2]:
import sys
sys.path.append('../')
# from fragility.signalprocessing import frequencyanalysis
# from datainterface import readmat

# sys.path.append('/home/adamli/tng_tvb/')
# from tvbsim import visualize

import os
import time

import numpy as np
import pandas as pd
import scipy
import scipy.io


import processing.util as util
import processing.frequencytransform as ft
import peakdetect
import processing.preprocessfft as preprocess

# sys.path.append('/Users/adam2392/Documents/tvb/')
# sys.path.append('/Users/adam2392/Documents/tvb/_tvbdata/')
# sys.path.append('/Users/adam2392/Documents/tvb/_tvblibrary/')
# from tvb.simulator.lab import *
# import tvbsim.util

from natsort import natsorted
import ntpath
from scipy.signal import butter, lfilter

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
def path_leaf(path):
    """Return the last component of *path*.

    ``ntpath`` is used so that both ``/`` and ``\\`` count as separators.
    When the path ends with a separator the final component is empty, so
    the base name of the remaining head is returned instead.
    """
    parent, leaf = ntpath.split(path)
    if leaf:
        return leaf
    return ntpath.basename(parent)

def _gettimepoints(numsignals, numwinsamps, numstepsamps):
    # create array of indices of window start times
    timestarts = np.arange(0, numsignals-numwinsamps+1, numstepsamps)
    # create array of indices of window end times
    timeends = np.arange(numwinsamps-1, numsignals, numstepsamps)
    # create the timepoints array for entire data array
    timepoints = np.append(timestarts.reshape(len(timestarts), 1), timeends.reshape(len(timestarts), 1), axis=1)
    return timepoints

def getseiztimes(onsettimes, offsettimes):
    """Collapse candidate onset/offset times into seizure intervals.

    The two sequences are walked pairwise (up to the shorter length). While
    the next onset falls before the current offset, the pair is treated as
    part of the same seizure and the current offset is extended to the next
    offset. Otherwise the current interval is closed and a new one starts.

    Parameters
    ----------
    onsettimes, offsettimes : indexable sequences of candidate times

    Returns
    -------
    (seizonsets, seizoffsets) : two lists of equal length. Both are empty
    when either input is empty, so callers can always unpack the result
    (previously a bare ``0`` was returned, which broke tuple unpacking).
    """
    minsize = min(len(onsettimes), len(offsettimes))
    seizonsets = []
    seizoffsets = []

    # nothing to pair up: report it and hand back empty interval lists
    if minsize == 0:
        print("no full onset/offset available!")
        return seizonsets, seizoffsets

    # the interval currently being built
    _offset = offsettimes[0]
    seizonsets.append(onsettimes[0])

    # start loop after the first onset/offset pair
    for i in range(1, minsize):
        _nextonset = onsettimes[i]
        _nextoffset = offsettimes[i]

        if _nextonset < _offset:
            # next onset starts before the current interval ended: merge
            _offset = _nextoffset
        else:
            # current interval is complete; open the next one
            seizoffsets.append(_offset)
            _offset = _nextoffset
            seizonsets.append(_nextonset)

    # NOTE(review): the interval still open when the loop ends is never
    # closed, so the trim below always discards the last onset (a single
    # onset/offset pair yields two empty lists). Behavior kept as-is;
    # confirm whether the trailing interval should be emitted instead.
    if len(seizonsets) != len(seizoffsets):
        seizonsets = seizonsets[0:len(seizoffsets)]
    return seizonsets, seizoffsets
def findonsetoffset(zts, delta=0.2/8):
    """Locate candidate onset and offset positions in one time series.

    ``peakdetect.peakdetect`` is asked for the maxima and minima of *zts*;
    the positions of the minima are returned as onsets and the positions of
    the maxima as offsets.

    Parameters
    ----------
    zts : 1-D signal handed to ``peakdetect.peakdetect``
    delta : forwarded unchanged as the ``delta`` keyword of peakdetect

    Returns
    -------
    (onsettime, offsettime) : two tuples of peak positions. Either tuple is
    empty when no peak of that kind was found (the previous ``zip(*peaks)``
    unpacking raised ValueError in that case).
    """
    # assumes each peak is a (position, value) pair, as the original
    # two-way unpacking required — TODO confirm against peakdetect version
    maxpeaks, minpeaks = peakdetect.peakdetect(zts, delta=delta)

    # keep only the position of each peak; the value is not needed here
    onsettime = tuple(peak[0] for peak in minpeaks)
    offsettime = tuple(peak[0] for peak in maxpeaks)

    return onsettime, offsettime
def getonsetsoffsets(zts, ezindices, pzindices, delta=0.2/8):
    """Gather sorted onset/offset candidates over the ez and pz rows of *zts*.

    For every index in *ezindices* and then *pzindices*, the matching row of
    *zts* is passed to ``findonsetoffset`` and the results are pooled. An
    index collection is only used when it is a non-empty ``np.ndarray``;
    anything else (lists, None, empty arrays) is skipped.

    Returns
    -------
    (onsettimes, offsettimes) : two 1-D arrays, each sorted ascending.
    """
    def _usable(indices):
        return isinstance(indices, np.ndarray) and len(indices) >= 1

    # seed with an empty float array so the pooled result keeps that dtype
    onset_parts = [np.array([])]
    offset_parts = [np.array([])]

    for group in (ezindices, pzindices):
        if not _usable(group):
            continue
        for index in group:
            found_onsets, found_offsets = findonsetoffset(
                zts[index, :].squeeze(), delta=delta)
            onset_parts.append(np.ravel(found_onsets))
            offset_parts.append(np.ravel(found_offsets))

    onsettimes = np.concatenate(onset_parts)
    offsettimes = np.concatenate(offset_parts)

    # order each pool independently
    onsettimes.sort()
    offsettimes.sort()

    return onsettimes, offsettimes

In [13]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    The cut-offs are divided by the Nyquist frequency (``0.5 * fs``) before
    being handed to ``scipy.signal.butter``.

    Returns the ``(b, a)`` coefficient pair produced by ``butter``.
    """
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band, btype='band', analog=False)
    return numerator, denominator
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass *data* with the filter designed by ``butter_bandpass``.

    The coefficients are applied with ``scipy.signal.filtfilt`` using its
    default axis, and the filtered array is returned.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return scipy.signal.filtfilt(numerator, denominator, data)
def butter_highpass(lowcut, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    *lowcut* is divided by the Nyquist frequency (``0.5 * fs``) before being
    handed to ``scipy.signal.butter``.

    Returns the ``(b, a)`` coefficient pair produced by ``butter``.
    """
    nyquist = 0.5 * fs
    cutoff = lowcut / nyquist
    numerator, denominator = butter(order, cutoff, btype='highpass', analog=False)
    return numerator, denominator
def butter_highpass_filter(data, lowcut, fs, order=5):
    """High-pass *data* with the filter designed by ``butter_highpass``.

    The coefficients are applied with ``scipy.signal.filtfilt`` using its
    default axis, and the filtered array is returned.
    """
    numerator, denominator = butter_highpass(lowcut, fs, order=order)
    return scipy.signal.filtfilt(numerator, denominator, data)

# 1. Run FFT Analysis on All SEEG Simulated Data

SEEG Simulated data is assumed to be bandpass filtered when simulations were generated.
Run a check on the onset/offset times.

First, create a dataset from the old files with bandpass filtering run on them (although these seizures are not as "realistic"), and keep this dataset separate until we loop it into the training procedure. In essence, if any of these files had already been bandpassed, filtering them again would essentially just add another order to the Butterworth filter (another convolution/multiplication).

Then create another dataset that is more robust starting from moved_v2 that has bandpass filtering already applied, and then store that dataset and keep appending, as I generate more and more data (_v3, v4, etc....)

*Need to add the index to start from the largest in the list in the directory so far, else = 0*

In [51]:
# Input simulations, per-patient metadata, and FFT output locations.
datadir = '/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/'
metadatadir = '/Volumes/ADAM LI/pydata/metadata/'
traindir = '/Volumes/ADAM LI/pydata/traindata/fft/'

# Collect every .npz below datadir, skipping any directory whose path
# contains 'allregions_sim'.
datafiles = []
for root, dirs, files in os.walk(datadir):
    if 'allregions_sim' in root:
        continue
    for file in files:
        if not file.endswith(".npz"):
            continue
        datafiles.append(os.path.join(root, file))
print(len(datafiles))
# print(datafiles[50:])

25


In [27]:
# Frequency bands as name -> [low, high] edges (presumably Hz — TODO confirm
# against PreProcess); the power spectrum is later binned into these bands.
freqbands = dict(
    dalpha=[0, 15],
    beta=[15, 30],
    gamma=[30, 90],
    high=[90, 200],
)
postprocessfft = preprocess.PreProcess(freqbands=freqbands)

In [28]:
# --- multitaper FFT configuration ---
typetransform = 'fourier'  # NOTE(review): not passed to MultiTaperFFT below
fs = 1000                  # presumably the sampling rate in Hz — TODO confirm
winsize, stepsize = 1000, 500  # window length and step size, in milliseconds
mtbandwidth = 4
mtfreqs = []

mtaper = ft.MultiTaperFFT(winsize, stepsize, fs, mtbandwidth, mtfreqs)

Default method of tapering is eigen


In [52]:
# Band-pass settings are identical for every file, so define them once.
# The simulated SEEG is filtered again in case it was not done already.
lowcut = 0.1
highcut = 499.
fs = 1000.

for idx, datafile in enumerate(datafiles):
    filename = path_leaf(datafile)

    # 'metadata' is a pickled dict stored as an object array, which needs
    # allow_pickle=True on NumPy >= 1.16.3.
    # NOTE(security): unpickling executes arbitrary code — only load
    # simulation files you generated yourself.
    data = np.load(datafile, encoding='bytes', allow_pickle=True)
    metadata = data['metadata'].item()
    zts = data['zts']
    seegts = data['seegts']

    # Extract location coordinates
    locations = metadata[b'seeg_xyz']
    try:
        patient_id = metadata[b'patient'].decode("utf-8")
    except (KeyError, AttributeError, UnicodeDecodeError):
        # No usable patient entry: fall back to the first two '_'-separated
        # tokens of the file name. (The fallback used to be stored in
        # `patient`, leaving `patient_id` undefined or stale.)
        patient_id = '_'.join(filename.split('_')[0:2])

    ezindices = metadata[b'ezindices']
    pzindices = metadata[b'pzindices']
    x0 = metadata[b'x0ez']
    seeg_contacts = metadata[b'seeg_contacts']

    # candidate onset/offset times pooled over the ez and pz rows of zts
    onsettimes, offsettimes = getonsetsoffsets(zts,
                                               np.array(ezindices).ravel(),
                                               np.array(pzindices).ravel())

    # collapse candidates into seizure intervals, one [onset, offset] row each
    seizonsets, seizoffsets = getseiztimes(onsettimes, offsettimes)
    seizonsets = np.asarray(seizonsets)
    seizoffsets = np.asarray(seizoffsets)
    seiztimes = np.concatenate((seizonsets[:, np.newaxis],
                                seizoffsets[:, np.newaxis]), axis=1)

    newseegts = butter_bandpass_filter(seegts, lowcut, highcut, fs, order=4)

    # multitaper power spectrum, then average into the configured bands
    mtaper.loadrawdata(newseegts)
    power, freqs, timepoints, _ = mtaper.mtwelch()
    power = postprocessfft.binFrequencyValues(power, freqs)

    outname = '{}_nez{}_npz{}_{}_varydist.npz'.format(
        patient_id, len(ezindices), len(pzindices), idx)
    filename = os.path.join(traindir, outname)

    np.savez_compressed(filename,
                        power=power,
                        timepoints=timepoints,
                        seiztimes=seiztimes,
                        locs=locations,
                        seeg_contacts=seeg_contacts,
                        x0ez=x0)
    print(datafile)
    print(power.shape)
    print(freqs.shape)
    print(timepoints.shape)
    print(locations.shape)
    print(seeg_contacts.shape)
    print('\n\n')

Loaded raw data in MultiTaperFFT!


  power_binned[:,idx,:] = np.mean(power[:,indices[0]:indices[1]+1,:], axis=1) #[np.newaxis,:,:]


/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/nez1_npz1_dist7/id001_ac_sim_nez1_npz1.npz
(70, 4, 439)
(501,)
(439, 2)
(70, 3)
(70,)



Loaded raw data in MultiTaperFFT!
/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/nez1_npz1_dist7/id002_cj_sim_nez1_npz1.npz
(162, 4, 439)
(501,)
(439, 2)
(162, 3)
(162,)



Loaded raw data in MultiTaperFFT!
/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/nez1_npz1_dist7/id014_rb_sim_nez1_npz1.npz
(165, 4, 439)
(501,)
(439, 2)
(165, 3)
(165,)



Loaded raw data in MultiTaperFFT!
/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/nez1_npz1_dist1/id001_ac_sim_nez1_npz1.npz
(70, 4, 235)
(501,)
(235, 2)
(70, 3)
(70,)



Loaded raw data in MultiTaperFFT!
/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/nez1_npz1_dist1/id002_cj_sim_nez1_npz1.npz
(162, 4, 229)
(501,)
(229, 2)
(162, 3)
(162,)



Loaded raw data in MultiTaperFFT!
/Volumes/ADAM LI/pydata/tvbforwardsim/varydistance/nez1_npz1_dist1/id014_rb_sim_nez1_npz1.npz
(165, 4, 249)
(501,)
(249, 

# 2. Run Image Transformation in 2D onto a Mesh grid

In [66]:
'''
This code segment should only be ran once, to extract metadata for each patient, 
so it doesn't need to be done in the loop for FFT compression.

EXTRACTS LOCATION DATA FOR EACH PATIENT
'''

patients = []
# collect the unique patient ids: the first two '_'-separated tokens of
# each data file name, in order of first appearance
for datafile in datafiles:
    filename = path_leaf(datafile)
    patient = '_'.join(filename.split('_')[0:2])
    if patient not in patients:
        patients.append(patient)
print(patients)

# patient id -> {'region_centers': ..., 'regions': ...}
patient_dict = {}

# read the region names and centres from each patient's connectivity.zip
for patient in patients:
    project_dir = os.path.join(metadatadir, patient)
    confile = os.path.join(project_dir, "connectivity.zip")

    reader = util.ZipReader(confile)
    region_centers = reader.read_array_from_file("centres", use_cols=(1, 2, 3))
    # `np.str` was only an alias of the builtin and has been removed from
    # NumPy, so use `str` directly.
    regions = reader.read_array_from_file("centres", dtype=str, use_cols=(0,))

    # keep every patient's result; previously each iteration overwrote the
    # last one and patient_dict stayed empty
    patient_dict[patient] = {
        'region_centers': region_centers,
        'regions': regions,
    }

['id001_ac', 'id002_cj', 'id014_rb']


In [71]:
# FFT output to convert into 2D images, plus the per-patient metadata root.
datadir = '/Volumes/ADAM LI/pydata/traindata/fft/after_v2/'
metadatadir = '/Volumes/ADAM LI/pydata/metadata/'

# Collect every .npz found anywhere below datadir.
datafiles = []
for root, dirs, files in os.walk(datadir):
    for file in files:
        if not file.endswith(".npz"):
            continue
        datafiles.append(os.path.join(root, file))
# print(datafiles)
print(len(datafiles))

43


In [72]:
from sklearn.decomposition import PCA
# define the data handler
datahandler = util.DataHandler()
pca = PCA(n_components=2)

# projection used to flatten xyz contact positions: 1 = azimuthal, 0 = PCA
AZIM=0
trainimagedir = '/Volumes/ADAM LI/pydata/traindata/image_2d/after_v2/'
metadir = '/Volumes/ADAM LI/pydata/traindata/image_2d/after_v2/meta/'
for outdir in (trainimagedir, metadir):
    os.makedirs(outdir, exist_ok=True)

# loop through each data file and get grid
for idx, datafile in enumerate(datafiles):
    # load data
    data = np.load(datafile, encoding='bytes')
    power = data['power']
    print(power.shape)
    print(data.keys())

    # load xyz data for this particular dataset
    xyz_data = data['locs']
    seeg_contacts = data['seeg_contacts']
    x0ez = data['x0ez']
    seiztimes = data['seiztimes']
    # Window indices saved alongside THIS file's power. Previously the
    # labels were computed from whatever global `timepoints` the FFT cell
    # left behind, which need not match this file's number of windows.
    timepoints = data['timepoints']

    # project xyz data down to two dimensions
    if AZIM == 1:
        print("using azim projection to grid image")
        new_locs = np.asarray([datahandler.azim_proj(xyz_data[ichan, :])
                               for ichan in range(xyz_data.shape[0])])
    elif AZIM == 0:
        print("using pca to grid image")
        new_locs = pca.fit_transform(xyz_data)
    else:
        # otherwise new_locs would be undefined or reused from a prior file
        raise ValueError("AZIM must be 0 (pca) or 1 (azim projection)")

    ylabels = datahandler.computelabels(seiztimes, timepoints)

    # Tensor of size [samples, freqbands, W, H] containing generated images.
    image_tensor = datahandler.gen_images(new_locs, power,
                                n_gridpoints=32, normalize=True, augment=True,
                                pca=False, std_mult=0.1, edgeless=False)

    # set saving file paths for image and corresponding meta data
    filename = path_leaf(datafile)
    imagefilename = os.path.join(trainimagedir, filename)
    metafilename = os.path.join(metadir, filename)

    # instantiate metadata hash table
    metadata = dict()
    metadata['x0ez'] = x0ez
    metadata['seeg_contacts'] = seeg_contacts
    metadata['new_locs'] = new_locs
    metadata['ylabels'] = ylabels

    # save image and meta data; np.save appends '.npy' to the '.npz' name
    np.save(imagefilename, image_tensor)
    np.savez_compressed(metafilename, metadata=metadata)

    print(new_locs.shape)
#     break

(70, 4, 359)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
(359, 70)
Interpolating 359/359nterpolating 77/359Interpolating 124/359Interpolating 175/359Interpolating 224/359Interpolating 271/359Interpolating 317/359(70, 2)
(70, 4, 359)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
(359, 70)
(70, 2)lating 359/359terpolating 2/359Interpolating 53/359Interpolating 105/359Interpolating 156/359Interpolating 205/359Interpolating 255/359Interpolating 305/359Interpolating 354/359
(70, 4, 359)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
(359, 70)
(70, 2)lating 359/359nterpolating 45/359Interpolating 95/359Interpolating 146/359Interpolating 198/359Interpolating 247/359Interpolating 297/359Interpolating 348/359
(70, 4, 359)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
(359, 70)
(70, 2)lating 359/359nterpol

Interpolating 439/439nterpolating 27/439Interpolating 92/439Interpolating 125/439Interpolating 158/439Interpolating 191/439Interpolating 224/439Interpolating 257/439Interpolating 291/439Interpolating 324/439Interpolating 357/439Interpolating 390/439Interpolating 422/439(162, 2)
(70, 4, 249)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
no seizure times in <computelabels>!
(249, 70)
(70, 2)lating 249/249nterpolating 16/249Interpolating 67/249Interpolating 118/249Interpolating 169/249Interpolating 220/249
(162, 4, 239)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
no seizure times in <computelabels>!
(239, 162)
Interpolating 239/239nterpolating 26/239Interpolating 55/239Interpolating 85/239Interpolating 118/239Interpolating 150/239Interpolating 182/239Interpolating 214/239(162, 2)
(162, 4, 439)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
using pca to grid image
(439, 

# 3. Process FFT Analysis into a 3D Mesh Grid

Here, I just need to redefine how images are generated by the data handler class. I need to create a mesh grid in 3D space now, instead of just in 2D space.

I will name this new function gen_images3d

In [73]:
# FFT output to convert into 3D images, plus the per-patient metadata root.
datadir = '/Volumes/ADAM LI/pydata/traindata/fft/before_v2/'
metadatadir = '/Volumes/ADAM LI/pydata/metadata/'

# Collect every .npz found anywhere below datadir.
datafiles = []
for root, dirs, files in os.walk(datadir):
    for file in files:
        if not file.endswith(".npz"):
            continue
        datafiles.append(os.path.join(root, file))
# print(datafiles)
print(len(datafiles))

27


In [74]:
# define the data handler
datahandler = util.DataHandler()

trainimagedir = '/Volumes/ADAM LI/pydata/traindata/image_3d/before_v2/'
metadir = '/Volumes/ADAM LI/pydata/traindata/image_3d/before_v2/meta/'
for outdir in (trainimagedir, metadir):
    os.makedirs(outdir, exist_ok=True)

# loop through each data file and get grid
for idx, datafile in enumerate(datafiles):
    # load data
    data = np.load(datafile, encoding='bytes')
    power = data['power']
    print(power.shape)
    print(data.keys())

    # load xyz data for this particular dataset
    xyz_data = data['locs']
    seeg_contacts = data['seeg_contacts']
    x0ez = data['x0ez']
    seiztimes = data['seiztimes']
    # Window indices saved alongside THIS file's power. Previously the
    # labels were computed from whatever global `timepoints` an earlier
    # cell left behind, which need not match this file's number of windows.
    timepoints = data['timepoints']

    ylabels = datahandler.computelabels(seiztimes, timepoints)

    # Generated images on a 3D mesh grid; expected to be
    # [samples, freqbands, W, H, D] — TODO confirm against gen_images3d.
    image_tensor = datahandler.gen_images3d(xyz_data, power,
                                n_gridpoints=32, normalize=True, augment=False,
                                std_mult=0.1, edgeless=False)

    # set saving file paths for image and corresponding meta data
    filename = path_leaf(datafile)
    imagefilename = os.path.join(trainimagedir, filename)
    metafilename = os.path.join(metadir, filename)

    # instantiate metadata hash table
    metadata = dict()
    metadata['x0ez'] = x0ez
    metadata['seeg_contacts'] = seeg_contacts
    metadata['seeg_xyz'] = xyz_data
    metadata['ylabels'] = ylabels

    # save image and meta data; np.save appends '.npy' to the '.npz' name
    np.save(imagefilename, image_tensor)
    np.savez_compressed(metafilename, metadata=metadata)

(70, 5, 479)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
(479, 70)
Interpolating 479/479nterpolating 19/479Interpolating 29/479Interpolating 39/479Interpolating 47/479Interpolating 55/479Interpolating 65/479Interpolating 75/479Interpolating 85/479Interpolating 95/479Interpolating 105/479Interpolating 114/479Interpolating 124/479Interpolating 134/479Interpolating 144/479Interpolating 154/479Interpolating 164/479Interpolating 174/479Interpolating 184/479Interpolating 194/479Interpolating 205/479Interpolating 216/479Interpolating 226/479Interpolating 236/479Interpolating 246/479Interpolating 256/479Interpolating 266/479Interpolating 276/479Interpolating 286/479Interpolating 296/479Interpolating 306/479Interpolating 316/479Interpolating 326/479Interpolating 336/479Interpolating 346/479Interpolating 356/479Interpolating 366/479Interpolating 376/479Interpolating 386/479Interpolating 396/479Interpolating 406/479Interpolating 416/479Interpolating 426/479Interpolating 

Interpolating 389/389terpolating 10/389Interpolating 21/389Interpolating 31/389Interpolating 42/389Interpolating 53/389Interpolating 64/389Interpolating 75/389Interpolating 86/389Interpolating 96/389Interpolating 106/389Interpolating 116/389Interpolating 127/389Interpolating 137/389Interpolating 148/389Interpolating 159/389Interpolating 170/389Interpolating 181/389Interpolating 192/389Interpolating 203/389Interpolating 214/389Interpolating 225/389Interpolating 247/389Interpolating 258/389Interpolating 269/389Interpolating 280/389Interpolating 291/389Interpolating 302/389Interpolating 313/389Interpolating 324/389Interpolating 335/389Interpolating 346/389Interpolating 357/389Interpolating 368/389Interpolating 379/389(162, 5, 479)
['power', 'timepoints', 'seiztimes', 'locs', 'seeg_contacts', 'x0ez']
(479, 162)
Interpolating 478/479terpolating 6/479Interpolating 12/479Interpolating 18/479Interpolating 24/479Interpolating 30/479Interpolating 36/479Interpolating 42/479Interpolating 48/479Int

Interpolating 479/479terpolating 3/479Interpolating 9/479Interpolating 15/479Interpolating 21/479Interpolating 27/479Interpolating 33/479Interpolating 39/479Interpolating 45/479Interpolating 51/479Interpolating 57/479Interpolating 63/479Interpolating 69/479Interpolating 75/479Interpolating 81/479Interpolating 87/479Interpolating 93/479Interpolating 99/479Interpolating 105/479Interpolating 111/479Interpolating 117/479Interpolating 123/479Interpolating 129/479Interpolating 135/479Interpolating 141/479Interpolating 147/479Interpolating 154/479Interpolating 160/479Interpolating 166/479Interpolating 172/479Interpolating 178/479Interpolating 184/479Interpolating 190/479Interpolating 196/479Interpolating 202/479Interpolating 214/479Interpolating 220/479Interpolating 226/479Interpolating 232/479Interpolating 238/479Interpolating 244/479Interpolating 250/479Interpolating 256/479Interpolating 262/479Interpolating 268/479Interpolating 274/479Interpolating 280/479Interpolating 286/479Interpolating

Interpolating 479/479terpolating 3/479Interpolating 15/479Interpolating 21/479Interpolating 26/479Interpolating 32/479Interpolating 38/479Interpolating 44/479Interpolating 50/479Interpolating 56/479Interpolating 62/479Interpolating 68/479Interpolating 74/479Interpolating 80/479Interpolating 86/479Interpolating 92/479Interpolating 98/479Interpolating 104/479Interpolating 110/479Interpolating 116/479Interpolating 122/479Interpolating 128/479Interpolating 134/479Interpolating 140/479Interpolating 146/479Interpolating 152/479Interpolating 158/479Interpolating 164/479Interpolating 170/479Interpolating 176/479Interpolating 182/479Interpolating 188/479Interpolating 194/479Interpolating 200/479Interpolating 206/479Interpolating 212/479Interpolating 218/479Interpolating 224/479Interpolating 230/479Interpolating 236/479Interpolating 242/479Interpolating 248/479Interpolating 254/479Interpolating 260/479Interpolating 266/479Interpolating 272/479Interpolating 278/479Interpolating 284/479Interpolati

Interpolating 479/479terpolating 4/479Interpolating 11/479Interpolating 18/479Interpolating 25/479Interpolating 32/479Interpolating 39/479Interpolating 46/479Interpolating 53/479Interpolating 60/479Interpolating 67/479Interpolating 74/479Interpolating 81/479Interpolating 88/479Interpolating 95/479Interpolating 102/479Interpolating 109/479Interpolating 116/479Interpolating 123/479Interpolating 130/479Interpolating 137/479Interpolating 144/479Interpolating 151/479Interpolating 158/479Interpolating 165/479Interpolating 172/479Interpolating 179/479Interpolating 186/479Interpolating 193/479Interpolating 200/479Interpolating 207/479Interpolating 214/479Interpolating 221/479Interpolating 228/479Interpolating 235/479Interpolating 242/479Interpolating 249/479Interpolating 256/479Interpolating 263/479Interpolating 270/479Interpolating 277/479Interpolating 284/479Interpolating 291/479Interpolating 298/479Interpolating 305/479Interpolating 312/479Interpolating 319/479Interpolating 326/479Interpola

In [58]:
import matplotlib
import matplotlib.pyplot as plt

def multi_slice_viewer(volume):
    """Show one slice of *volume* and let the j / k keys step through it.

    Slices are taken along the first axis (``volume[i]``). The volume and
    the current slice index are stored on the axes object so the key
    handler can reach them.
    """
    # free 'j' and 'k' from any matplotlib keymap before binding them
    remove_keymap_conflicts({'j', 'k'})

    # initialize figure to draw on
    fig, ax = plt.subplots()
    ax.volume = volume

    # Start at the middle slice of the first axis. This must be the same
    # axis that is indexed below and wrapped in previous_slice/next_slice;
    # using shape[2] here went out of range for non-cubic volumes.
    ax.index = volume.shape[0] // 2
    ax.imshow(volume[ax.index])
    fig.canvas.mpl_connect('key_press_event', process_key)

def process_key(event):
    """Key-press handler: 'j' shows the previous slice, 'k' the next one.

    Operates on the first axes of the figure that received the event and
    redraws the canvas afterwards, whichever key was pressed.
    """
    figure = event.canvas.figure
    axis = figure.axes[0]
    pressed = event.key
    if pressed == 'j':
        previous_slice(axis)
    elif pressed == 'k':
        next_slice(axis)
    figure.canvas.draw()

def previous_slice(ax):
    """Move the viewer on *ax* one slice back along axis 0, wrapping around."""
    stack = ax.volume
    n_slices = stack.shape[0]
    # modulo keeps the index in range when stepping back from slice 0
    ax.index = (ax.index - 1) % n_slices
    ax.images[0].set_array(stack[ax.index])

def next_slice(ax):
    """Move the viewer on *ax* one slice forward along axis 0, wrapping around."""
    stack = ax.volume
    n_slices = stack.shape[0]
    # modulo sends the index back to 0 after the last slice
    ax.index = (ax.index + 1) % n_slices
    ax.images[0].set_array(stack[ax.index])
    
def remove_keymap_conflicts(new_keys_set):
    '''
    Unbind the keys in new_keys_set from every 'keymap.*' entry of
    plt.rcParams, editing each entry's key list in place.
    '''
    for name in plt.rcParams:
        if not name.startswith('keymap.'):
            continue
        bound = plt.rcParams[name]
        # intersect first so the list is not mutated while being scanned
        for clash in set(bound) & new_keys_set:
            bound.remove(clash)
    
    

In [59]:
# Earlier manual version of the viewer set-up, kept for reference:
# fig, ax = plt.subplots()
# ax.imshow(image_tensor[0,0,:,:,:])
# fig.canvas.mpl_connect('key_press_event', process_key)
# IPython magic (not plain Python): select the interactive notebook backend
# so that key-press events reach process_key.
%matplotlib notebook
# Browse the sub-volume at [0, 0] of the current image_tensor;
# 'j' / 'k' step to the previous / next slice.
multi_slice_viewer(image_tensor[0,0,:,:,:])

<IPython.core.display.Javascript object>

# 4. Create YLabels Corresponding To The Image Dataset

Handle both creating the dataset from scratch and appending new images onto an existing image dataset.

In [None]:
trainimagedir = './traindata/images/'
metadir = './traindata/meta/'
trainlabeldir = './traindata/labels/'
for outdir in (trainimagedir, metadir, trainlabeldir):
    os.makedirs(outdir, exist_ok=True)

# Per-file arrays are gathered in lists and joined once after the loop;
# concatenating inside the loop re-copied the growing dataset every time.
image_chunks = []
label_chunks = []

# loop through each data file and load its image, labels and metadata
for idx, datafile in enumerate(datafiles):
    filename = path_leaf(datafile) + '.npy'
    imagefile = os.path.join(trainimagedir, filename)
    metafile = os.path.join(metadir, path_leaf(datafile))
    labelsfile = os.path.join(trainlabeldir, filename)

    # The metadata dict is stored as a pickled object array, which needs
    # allow_pickle=True on NumPy >= 1.16.3.
    # NOTE(security): unpickling executes arbitrary code — only load
    # metadata files you generated yourself.
    with np.load(metafile, allow_pickle=True) as metanpz:
        metadata = metanpz['metadata'].item()
#     print(metadata.keys())
    image = np.load(imagefile)
    label = np.load(labelsfile)

    image_chunks.append(image)
    label_chunks.append(label)

#     print(imagefile)
#     print(labelsfile)
#     print(image.shape)
#     print(label.shape)

#     break

# With no data files, fall back to empty arrays so the shape report below
# still works (a plain list has no .shape).
images = np.concatenate(image_chunks, axis=0) if image_chunks else np.array([])
ylabels = np.concatenate(label_chunks, axis=0) if label_chunks else np.array([])
print(images.shape)
print(ylabels.shape)