In [1]:
import numpy as np
import xarray as xr

In [2]:
## We'll create two mock sessions to toy around with.
## Let's assume we have extracted the neural data in the
## format (cells x trials x frames), with slightly differing
## sizes for each session.
## NOTE: no random seed is set, so the mock values differ on every run.
session_1 = np.random.randn(500, 230, 60)  # (cells x trials x frames)
time_array_1 = np.linspace(-2, 4, 60)  # a time array (relative to stimulus onset)
session_2 = np.random.randn(430, 250, 62)  # fewer cells, more trials, two extra frames
time_array_2 = np.linspace(-2, 4, 62)  # same -2..4 span, sampled with 62 points

def create_data_array(sess, time_array):
    '''Wrap a session array in a labelled xr.DataArray.

    Parameters
    ----------
    sess : np.ndarray
        3-D array laid out as (cells x trials x frames). Subclasses of
        np.ndarray are accepted.
    time_array : array-like
        One time stamp per frame; its length must equal sess.shape[2].

    Returns
    -------
    xr.DataArray
        Array with dims ('cell', 'trial', 'time'). 'cell' and 'trial' are
        labelled 0..n-1 and 'time' is labelled with time_array.
        attrs['frequency'] holds the mean step of time_array.

    Raises
    ------
    AssertionError
        If sess is not a 3-D np.ndarray or time_array has the wrong length.
    '''
    assert isinstance(sess, np.ndarray), 'sess must be a np.ndarray'
    # Check the rank explicitly; otherwise sess.shape[2] below fails with an
    # unhelpful IndexError for 1-D / 2-D input.
    assert sess.ndim == 3, 'sess must be 3-D: (cells x trials x frames)'
    n_cells, n_trials, n_frames = sess.shape
    assert len(time_array) == n_frames, 'time_array needs one entry per frame'

    da = xr.DataArray(sess,
                      dims=('cell', 'trial', 'time'),  # names of dimensions of the session np.ndarray
                      coords={'cell': np.arange(n_cells),  # Here we can define the dimensions
                              'trial': np.arange(n_trials),
                              'time': time_array})
    # NOTE: despite the key name, this is the mean time step between frames
    # (a sampling interval), not a rate. The key is kept because other cells
    # read `da.frequency`. For a single-frame session np.diff() is empty and
    # the value is nan.
    da.attrs['frequency'] = np.mean(np.diff(time_array))  # perhaps some extra info
    return da

In [3]:
## Let's create the data arrays: one labelled xr.DataArray per session,
## each paired with its own time axis.
da_1 = create_data_array(sess=session_1, time_array=time_array_1)
da_2 = create_data_array(sess=session_2, time_array=time_array_2)

In [4]:
## A DataArray now contains metadata such as the time axis:
print(da_1.time)
## which can be different per data array. Note that `frequency` holds the mean
## step of the time axis (np.mean(np.diff(time_array))), so the session with
## fewer frames over the same -2..4 window (da_1) has the larger value:
assert da_1.frequency > da_2.frequency

<xarray.DataArray 'time' (time: 60)>
array([-2.      , -1.898305, -1.79661 , -1.694915, -1.59322 , -1.491525,
       -1.389831, -1.288136, -1.186441, -1.084746, -0.983051, -0.881356,
       -0.779661, -0.677966, -0.576271, -0.474576, -0.372881, -0.271186,
       -0.169492, -0.067797,  0.033898,  0.135593,  0.237288,  0.338983,
        0.440678,  0.542373,  0.644068,  0.745763,  0.847458,  0.949153,
        1.050847,  1.152542,  1.254237,  1.355932,  1.457627,  1.559322,
        1.661017,  1.762712,  1.864407,  1.966102,  2.067797,  2.169492,
        2.271186,  2.372881,  2.474576,  2.576271,  2.677966,  2.779661,
        2.881356,  2.983051,  3.084746,  3.186441,  3.288136,  3.389831,
        3.491525,  3.59322 ,  3.694915,  3.79661 ,  3.898305,  4.      ])
Coordinates:
  * time     (time) float64 -2.0 -1.898 -1.797 -1.695 ... 3.695 3.797 3.898 4.0


In [5]:
## The cool thing comes when we create a xr.Dataset
## A Dataset contains (at least 1) xr.DataArray, but can 
## contain additional coordinates, with extra indexed info such as 
## s1_bool for example

def create_data_set(da):
    '''Wrap a session DataArray in an xr.Dataset with mock label coordinates.

    Parameters
    ----------
    da : xr.DataArray
        Session data with 'cell', 'trial' and 'time' coordinates and a
        'frequency' attribute.

    Returns
    -------
    xr.Dataset
        Dataset with da stored under the variable 'data', plus these extra
        coordinates:
        - 's1_bool' / 's2_bool' on 'cell': random booleans and their negation
        - 'frame' on 'time': 0..n_frames-1
        - 'trial_type' on 'trial': 'w' for the first third of the trials,
          'rob' up to three quarters, 'lees' for the rest
        The 'frequency' attribute is copied over from da.

    Raises
    ------
    AssertionError
        If da is not an xr.DataArray.
    '''
    assert isinstance(da, xr.DataArray), 'da must be an xr.DataArray'

    n_cells = len(da.cell)
    n_trials = len(da.trial)
    n_frames = len(da.time)

    # Let's create some (mock) Variables first:
    s1_bool = np.random.randint(low=0, high=2, size=n_cells, dtype='bool')  # random bools
    s2_bool = np.logical_not(s1_bool)
    frames = np.arange(n_frames)  # array from 0:len(time)

    # Derive the boundaries from THIS array's trial count. They were previously
    # taken from the global `da_1`, which mislabelled any session with a
    # different number of trials and broke when `da_1` was not defined.
    first_cut = int(n_trials / 3)
    second_cut = int(n_trials * 3 / 4)
    trial_type = np.zeros(n_trials, dtype='object')  # dtype object to allow strings
    trial_type[:first_cut] = 'w'
    trial_type[first_cut:second_cut] = 'rob'
    trial_type[second_cut:] = 'lees'

    ds = xr.Dataset({'data': da},  # first give the dataArray. You can add multiple in this dictionary form
                    coords={'s1_bool': ('cell', s1_bool),  # here we add a new coordinate, and we bind it to the dimension 'cell'
                            's2_bool': ('cell', s2_bool),
                            'frame': ('time', frames),
                            'trial_type': ('trial', trial_type)})
    # Dataset attrs are separate from the DataArray's attrs, so copy explicitly.
    ds.attrs['frequency'] = da.frequency
    return ds

In [6]:
## Ooh yes: build one Dataset per session
ds_1 = create_data_set(da_1)
ds_2 = create_data_set(da_2)

## Beautiful! Everything is in there:
print(ds_1)
## (See how only cell, trial and time are marked with an *: they are the
## primary (dimension) coordinates. s1_bool, s2_bool, frame and trial_type are
## extra coordinates, each bound to one of those dimensions.)

<xarray.Dataset>
Dimensions:     (cell: 500, time: 60, trial: 230)
Coordinates:
  * cell        (cell) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * trial       (trial) int64 0 1 2 3 4 5 6 7 ... 223 224 225 226 227 228 229
  * time        (time) float64 -2.0 -1.898 -1.797 -1.695 ... 3.797 3.898 4.0
    s1_bool     (cell) bool True False True False False ... True False True True
    s2_bool     (cell) bool False True False True ... False True False False
    frame       (time) int64 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
    trial_type  (trial) object 'w' 'w' 'w' 'w' ... 'lees' 'lees' 'lees' 'lees'
Data variables:
    data        (cell, trial, time) float64 1.604 -1.035 1.016 ... 1.193 1.641
Attributes:
    frequency:  0.1016949152542373


# And now come the perks:

In [7]:
## Index primary coordinates by name. .sel() is given coordinate values; here the
## trial coordinate is np.arange(n_trials), so labels and positions coincide.
print(ds_1.sel(trial=np.array([42, 43, 44])))

## Note that regular numpy indexing is only available for DataArrays,
## not Datasets. The DataArray equivalent of the selection above would be:
# da_1[:, np.array([42, 43, 44]), :]

<xarray.Dataset>
Dimensions:     (cell: 500, time: 60, trial: 3)
Coordinates:
  * cell        (cell) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * trial       (trial) int64 42 43 44
  * time        (time) float64 -2.0 -1.898 -1.797 -1.695 ... 3.797 3.898 4.0
    s1_bool     (cell) bool True False True False False ... True False True True
    s2_bool     (cell) bool False True False True ... False True False False
    frame       (time) int64 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
    trial_type  (trial) object 'w' 'w' 'w'
Data variables:
    data        (cell, trial, time) float64 0.4524 -1.434 ... -0.3506 -0.1327
Attributes:
    frequency:  0.1016949152542373


In [8]:
## Find stuff by labels:

## Note the 'Dimensions: ........' line at the top! It tells you
## how many elements have been returned

## s1_bool is already a boolean mask, so it is passed to .where() directly
## (no `== True` needed). drop=False returns the entire dataset, but fills
## the s1_bool==False cells with NaNs.
ds_1.where(ds_1.s1_bool, drop=True)

<xarray.Dataset>
Dimensions:     (cell: 233, time: 60, trial: 230)
Coordinates:
  * cell        (cell) int64 0 2 5 6 7 8 9 11 ... 487 488 489 494 496 498 499
  * trial       (trial) int64 0 1 2 3 4 5 6 7 ... 223 224 225 226 227 228 229
  * time        (time) float64 -2.0 -1.898 -1.797 -1.695 ... 3.797 3.898 4.0
    s1_bool     (cell) bool True True True True True ... True True True True
    s2_bool     (cell) bool False False False False ... False False False False
    frame       (time) int64 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
    trial_type  (trial) object 'w' 'w' 'w' 'w' ... 'lees' 'lees' 'lees' 'lees'
Data variables:
    data        (cell, trial, time) float64 1.604 -1.035 1.016 ... 1.193 1.641
Attributes:
    frequency:  0.1016949152542373

In [12]:
## Use np.logical_and() or np.logical_or() to combine multiple conditions.
## s1_bool is already boolean, so it goes in as-is (no `== True` needed):
ds_1.where(np.logical_and(ds_1.s1_bool, ds_1.time > 0), drop=True)

## (again, note the Dimensions: ...  line. Both cells and times have been sliced)

<xarray.Dataset>
Dimensions:     (cell: 233, time: 40, trial: 230)
Coordinates:
  * cell        (cell) int64 0 2 5 6 7 8 9 11 ... 487 488 489 494 496 498 499
  * trial       (trial) int64 0 1 2 3 4 5 6 7 ... 223 224 225 226 227 228 229
  * time        (time) float64 0.0339 0.1356 0.2373 0.339 ... 3.797 3.898 4.0
    s1_bool     (cell) bool True True True True True ... True True True True
    s2_bool     (cell) bool False False False False ... False False False False
    frame       (time) int64 20 21 22 23 24 25 26 27 ... 52 53 54 55 56 57 58 59
    trial_type  (trial) object 'w' 'w' 'w' 'w' ... 'lees' 'lees' 'lees' 'lees'
Data variables:
    data        (cell, trial, time) float64 0.6521 -0.7643 ... 1.193 1.641
Attributes:
    frequency:  0.1016949152542373

In [14]:
# I could go on all night!
# Keep only the trials whose trial_type is not 'w', then show the remaining labels:
ds_1.where(ds_1.trial_type != 'w', drop=True).trial_type

<xarray.DataArray 'trial_type' (trial: 154)>
array(['rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'rob',
       'rob', 'rob', 'rob', 'rob', 'rob', 'rob', 'lees', 'lees', 'lees',
       'lees', 'lees', 'lees', 'lees', 'lees', 'lees', 'lees', 'lees',
       'lees', 'lees', 'lees', 'lees', 'lees', 'lees', 'lees', 'lees',
       'lees', 'lees', 'lees', 'lees', '

In [11]:
## And there is so much more good stuff (that I don't know about)!!
## Reading the docs will give you some ideas, e.g. the indexing page:

# http://xarray.pydata.org/en/stable/indexing.html