# Process FFT Datasets Into Images

Here, we walk through the class-based functions that format the FFT datasets into their image representations. We can:

- map data onto the xyz coordinates via a PCA method
- map data onto the 84 parcellated regions in MRI xyz space using the transposed inverse gain matrix

These methods can be compared to determine the accuracy of encoding different types of structural information into the models.

In [9]:
import numpy as np
import os
import sys

sys.path.append('../')
from dnn.processing.format import formatfft

import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def main(fftdatadir, rawdatadir, metadatadir, outputdatadir):
    """Run the FFT-to-image formatting pipeline over the given directories.

    Constructs a ``formatfft.FormatFFT`` from the four directories, calls its
    ``getdatafiles()`` and then ``formatdata()`` methods, and prints a
    completion message.
    """
    converter = formatfft.FormatFFT(
        fftdatadir,
        rawdatadir,
        metadatadir,
        outputdatadir,
    )
    # NOTE(review): presumably getdatafiles() gathers inputs that
    # formatdata() then consumes, so the call order matters.
    converter.getdatafiles()
    converter.formatdata()
    print('finished!')

In [3]:
# Directory layout on the external data drive: FFT input, raw recordings,
# per-patient metadata, and where the formatted images are written.
fftdatadir = '/Volumes/ADAM LI/pydata/output_fft/tngcenter/win500_step250/'
# alternative FFT source (TVB simulations):
# fftdatadir = '/Volumes/ADAM LI/pydata/output_fft/tvbsim/full/'
rawdatadir = '/Volumes/ADAM LI/pydata/convertedtng/'
metadatadir = '/Volumes/ADAM LI/pydata/metadata/'
outputdatadir = '/Volumes/ADAM LI/pydata/output_fft/asimages/regions/'

# Per-patient files (currently unused):
# patid = 'id001_ac'
# confile = os.path.join(metadatadir, patid, 'connectivity.zip')
# gainmatfile = os.path.join(metadatadir, patid, 'gain_inv-square.txt')

# Uncomment to run the conversion:
# main(fftdatadir, rawdatadir, metadatadir, outputdatadir)

# Check Processed Data By Visualization

The data should have been processed into region mappings using each patient's inverse gain matrix, obtained from the dipole-dipole model in TVB.

In [29]:
import matplotlib
import matplotlib.pyplot as plt

def multi_slice_viewer(volume):
    remove_keymap_conflicts({'j', 'k'})
    
    # initialize figure to draw on
    fig, ax = plt.subplots()
    ax.volume = volume
    
    # set index as the first axis
    ax.index = volume.shape[0] // 2
    ax.imshow(volume[ax.index], cmap='jet',origin='lower')

    fig.canvas.mpl_connect('key_press_event', process_key)

def process_key(event):
    """Key-press callback for the slice viewer.

    'j' shows the previous slice and 'k' the next one on the figure's first
    Axes (whose ``volume``/``index`` attributes are set by
    ``multi_slice_viewer``). Any other key is ignored.

    Parameters
    ----------
    event : matplotlib KeyEvent
        The key-press event delivered by ``mpl_connect``.
    """
    fig = event.canvas.figure
    ax = fig.axes[0]
    if event.key == 'j':  # go to previous slice
        previous_slice(ax)
    elif event.key == 'k':  # go to next slice
        next_slice(ax)
    else:
        # Unbound key: nothing changed, so skip the redraw entirely.
        return
    # draw_idle() schedules a redraw (coalescing rapid key repeats) instead
    # of forcing a synchronous full render on every press.
    fig.canvas.draw_idle()

def previous_slice(ax):
    """Display the slice before the current one on *ax*.

    Decrements ``ax.index`` (wrapping from the first slice to the last) and
    pushes ``ax.volume[ax.index]`` into the Axes' first image.
    """
    stack = ax.volume
    depth = stack.shape[0]
    step_back = ax.index - 1
    ax.index = step_back % depth  # Python's % wraps -1 to depth - 1
    first_image = ax.images[0]
    first_image.set_array(stack[ax.index])

def next_slice(ax):
    """Display the slice after the current one on *ax*.

    Increments ``ax.index`` (wrapping from the last slice back to the first)
    and pushes ``ax.volume[ax.index]`` into the Axes' first image.
    """
    stack = ax.volume
    depth = stack.shape[0]
    step_forward = ax.index + 1
    ax.index = step_forward % depth
    first_image = ax.images[0]
    first_image.set_array(stack[ax.index])
    
def remove_keymap_conflicts(new_keys_set, rcparams=None):
    '''
    Unbind the given keys from every ``keymap.*`` rc setting.

    Use this before connecting a custom key-press handler so matplotlib's
    built-in shortcuts do not also react to those keys.

    Parameters
    ----------
    new_keys_set : iterable of str
        Keys to remove wherever they appear in a ``keymap.*`` entry.
    rcparams : mapping, optional
        The rc mapping to edit. Defaults to ``plt.rcParams``.
    '''
    rc = plt.rcParams if rcparams is None else rcparams
    new_keys = set(new_keys_set)
    # Snapshot the keys so reassigning values below cannot disturb iteration.
    for prop in list(rc):
        if not prop.startswith('keymap.'):
            continue
        keys = rc[prop]
        # Guard against a bare-string value: a str is not a list of bindings
        # and has no .remove().
        if isinstance(keys, str):
            keys = [keys]
        if set(keys) & new_keys:
            # Assign a fresh filtered list instead of mutating in place, so
            # the update goes through rcParams' own validation and no other
            # holder of the old list object is silently changed.
            rc[prop] = [key for key in keys if key not in new_keys]

In [39]:
# Inspect the first formatted "real" data file: print its contents and plot
# its ylabels.
realdir = os.path.join(outputdatadir, 'real')
# Only consider .npz archives (directory listings can contain stray files such
# as .DS_Store), sorted so the file inspected below is the same on every run.
datafiles = sorted(f for f in os.listdir(realdir) if f.endswith('.npz'))

for datafile in datafiles:
    print(datafile)

    # 'metadata' holds a pickled object (it is unpacked with .item() below),
    # which NumPy >= 1.16.3 refuses to load unless allow_pickle=True.
    # Only load files you produced yourself this way. The `with` block closes
    # the archive's file handle once the arrays have been read out.
    with np.load(os.path.join(realdir, datafile), allow_pickle=True) as data:
        image_tensor = data['image_tensor']
        metadata = data['metadata'].item()
        print(data.files)

    chanlabels = metadata['chanlabels']
    seeg_xyz = metadata['seeg_xyz']
    reg_xyz = metadata['reg_xyz']
    ylabels = metadata['ylabels']

    print(metadata.keys())
    print(image_tensor.shape)

    plt.figure()
    plt.plot(ylabels)
    plt.show()
    # Only the first file is inspected.
    break

id001_ac_absence_fftmodel.npz
['image_tensor', 'metadata']
dict_keys(['chanlabels', 'seeg_xyz', 'reg_xyz', 'ylabels', 'samplerate', 'timepoints'])
(84, 5, 1799)


<IPython.core.display.Javascript object>

In [54]:
# Fold the 84 region rows into a 12 x 7 grid, then reverse the axis order:
# (84, F, T) -> (12, 7, F, T) -> (T, F, 7, 12).
# F and T are taken from the loaded tensor rather than hard-coded, so recordings
# with a different number of windows still reshape correctly.
# NOTE(review): presumably F = frequency bands and T = time windows; confirm.
grid_rows, grid_cols = 12, 7
image = image_tensor.reshape((grid_rows, grid_cols) + image_tensor.shape[1:])
image = image.transpose(3, 2, 1, 0)  # == .swapaxes(0, 3).swapaxes(1, 2)
print(image.shape)

# Heatmap of the 7 x 12 grid at index 80 on axis 0 and index 3 on axis 1.
plt.figure()
sns.heatmap(image[80, 3, ...], cmap='jet')
plt.show()

(1799, 5, 7, 12)


<IPython.core.display.Javascript object>

In [32]:
image = image_tensor.reshape(12,7,5,1799)
image = image.swapaxes(0,3)
image = image.swapaxes(1,2)
print(image.shape)
# fig, ax = plt.subplots()
# ax.imshow(image_tensor[0,0,:,:,:])
# fig.canvas.mpl_connect('key_press_event', process_key)
%matplotlib notebook
multi_slice_viewer(image[:,0,:,:])

(1799, 5, 7, 12)


<IPython.core.display.Javascript object>