Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,47 +1,57 @@
import datetime
from pynwb import NWBFile, NWBHDF5IO
import pandas as pd
import allensdk.brain_observatory.nwb as nwb
import numpy as np
import SimpleITK as sitk
import pytz
import uuid
from pandas.util.testing import assert_frame_equal
import os
import math
import uuid
import warnings

import numpy as np
import xarray as xr
import pandas as pd
import pytz
import SimpleITK as sitk
import xarray as xr

from pandas.util.testing import assert_frame_equal
from pynwb import NWBHDF5IO, NWBFile

from allensdk.core.lazy_property import LazyProperty
import allensdk.brain_observatory.nwb as nwb
from allensdk.brain_observatory.behavior.metadata_processing import (
get_expt_description
)
from allensdk.brain_observatory.behavior.behavior_ophys_api import (
BehaviorOphysApiBase
)
from allensdk.brain_observatory.behavior.schemas import (
BehaviorTaskParametersSchema, OphysBehaviorMetadataSchema)
from allensdk.brain_observatory.behavior.trials_processing import (
TRIAL_COLUMN_DESCRIPTION_DICT
)
from allensdk.brain_observatory.nwb.metadata import load_pynwb_extension
from allensdk.brain_observatory.nwb.nwb_api import NwbApi
from allensdk.brain_observatory.nwb.nwb_utils import set_omitted_stop_time
from allensdk.brain_observatory.behavior.trials_processing import TRIAL_COLUMN_DESCRIPTION_DICT
from allensdk.brain_observatory.behavior.schemas import OphysBehaviorMetadataSchema, BehaviorTaskParametersSchema
from allensdk.brain_observatory.nwb.metadata import load_pynwb_extension
from allensdk.brain_observatory.behavior.behavior_ophys_api import BehaviorOphysApiBase

from allensdk.core.lazy_property import LazyProperty

load_pynwb_extension(OphysBehaviorMetadataSchema, 'ndx-aibs-behavior-ophys')
load_pynwb_extension(BehaviorTaskParametersSchema, 'ndx-aibs-behavior-ophys')


class BehaviorOphysNwbApi(NwbApi, BehaviorOphysApiBase):


def __init__(self, *args, **kwargs):
self.filter_invalid_rois = kwargs.pop("filter_invalid_rois", False)
super(BehaviorOphysNwbApi, self).__init__(*args, **kwargs)


def save(self, session_object):

session_type = str(session_object.metadata['session_type'])

nwbfile = NWBFile(
session_description=str(session_object.metadata['session_type']),
session_description=session_type,
identifier=str(session_object.ophys_experiment_id),
session_start_time=session_object.metadata['experiment_datetime'],
file_create_date=pytz.utc.localize(datetime.datetime.now())
file_create_date=pytz.utc.localize(datetime.datetime.now()),
institution="Allen Institute for Brain Science",
keywords=["2-photon", "calcium imaging", "visual cortex",
"behavior", "task"],
experiment_description=get_expt_description(session_type)
)

# Add stimulus_timestamps to NWB in-memory object:
Expand Down Expand Up @@ -93,7 +103,9 @@ def save(self, session_object):
nwb.add_task_parameters(nwbfile, session_object.task_parameters)

# Add roi metrics to NWB in-memory object:
nwb.add_cell_specimen_table(nwbfile, session_object.cell_specimen_table)
nwb.add_cell_specimen_table(nwbfile,
session_object.cell_specimen_table,
session_object.metadata)

# Add dff to NWB in-memory object:
nwb.add_dff_traces(nwbfile, session_object.dff_traces, session_object.ophys_timestamps)
Expand Down Expand Up @@ -122,19 +134,19 @@ def get_running_data_df(self, **kwargs):
running_data_df[key] = self.nwbfile.get_acquisition(key).data

for key in ['dx']:
if ('running' in self.nwbfile.modules) and (key in self.nwbfile.modules['running'].fields['data_interfaces']):
running_data_df[key] = self.nwbfile.modules['running'].get_data_interface(key).data
if ('running' in self.nwbfile.processing) and (key in self.nwbfile.processing['running'].fields['data_interfaces']):
running_data_df[key] = self.nwbfile.processing['running'].get_data_interface(key).data

return running_data_df[['speed', 'dx', 'v_sig', 'v_in']]

def get_stimulus_templates(self, **kwargs):
return {key: val.data[:] for key, val in self.nwbfile.stimulus_template.items()}

def get_ophys_timestamps(self) -> np.ndarray:
return self.nwbfile.modules['two_photon_imaging'].get_data_interface('dff').roi_response_series['traces'].timestamps[:]
return self.nwbfile.processing['ophys'].get_data_interface('dff').roi_response_series['traces'].timestamps[:]

def get_stimulus_timestamps(self) -> np.ndarray:
return self.nwbfile.modules['stimulus'].get_data_interface('timestamps').timestamps[:]
return self.nwbfile.processing['stimulus'].get_data_interface('timestamps').timestamps[:]

def get_trials(self) -> pd.DataFrame:
trials = self.nwbfile.trials.to_dataframe()
Expand All @@ -144,35 +156,36 @@ def get_trials(self) -> pd.DataFrame:
return trials

def get_licks(self) -> np.ndarray:
if 'licking' in self.nwbfile.modules:
return pd.DataFrame({'time': self.nwbfile.modules['licking'].get_data_interface('licks')['timestamps'].timestamps[:]})
if 'licking' in self.nwbfile.processing:
return pd.DataFrame({'time': self.nwbfile.processing['licking'].get_data_interface('licks')['timestamps'].timestamps[:]})
else:
return pd.DataFrame({'time': []})

def get_rewards(self) -> np.ndarray:
if 'rewards' in self.nwbfile.modules:
time = self.nwbfile.modules['rewards'].get_data_interface('autorewarded').timestamps[:]
autorewarded = self.nwbfile.modules['rewards'].get_data_interface('autorewarded').data[:]
volume = self.nwbfile.modules['rewards'].get_data_interface('volume').data[:]
if 'rewards' in self.nwbfile.processing:
time = self.nwbfile.processing['rewards'].get_data_interface('autorewarded').timestamps[:]
autorewarded = self.nwbfile.processing['rewards'].get_data_interface('autorewarded').data[:]
volume = self.nwbfile.processing['rewards'].get_data_interface('volume').data[:]
return pd.DataFrame({'volume': volume, 'timestamps': time, 'autorewarded': autorewarded}).set_index('timestamps')
else:
return pd.DataFrame({'volume': [], 'timestamps': [], 'autorewarded': []}).set_index('timestamps')

def get_max_projection(self, image_api=None) -> sitk.Image:
return self.get_image('max_projection', 'two_photon_imaging', image_api=image_api)
return self.get_image('max_projection', 'ophys', image_api=image_api)

def get_average_projection(self, image_api=None) -> sitk.Image:
return self.get_image('average_image', 'two_photon_imaging', image_api=image_api)
return self.get_image('average_image', 'ophys', image_api=image_api)

def get_segmentation_mask_image(self, image_api=None) -> sitk.Image:
return self.get_image('segmentation_mask_image', 'two_photon_imaging', image_api=image_api)
return self.get_image('segmentation_mask_image', 'ophys', image_api=image_api)

def get_metadata(self) -> dict:

metadata_nwb_obj = self.nwbfile.lab_meta_data['metadata']
data = OphysBehaviorMetadataSchema(exclude=['experiment_datetime']).dump(metadata_nwb_obj)
data = OphysBehaviorMetadataSchema(
exclude=['experiment_datetime']).dump(metadata_nwb_obj)

# Add subject related metadata to behavior ophys metadata
# Add pyNWB Subject metadata to behavior ophys session metadata
nwb_subject = self.nwbfile.subject
data['LabTracks_ID'] = int(nwb_subject.subject_id)
data['sex'] = nwb_subject.sex
Expand All @@ -181,9 +194,31 @@ def get_metadata(self) -> dict:
data['reporter_line'] = list(nwb_subject.reporter_line)
data['driver_line'] = list(nwb_subject.driver_line)

experiment_datetime = metadata_nwb_obj.experiment_datetime
data['experiment_datetime'] = OphysBehaviorMetadataSchema().load({'experiment_datetime': experiment_datetime}, partial=True)['experiment_datetime']
data['behavior_session_uuid'] = uuid.UUID(data['behavior_session_uuid'])
# Add pyNWB OpticalChannel and ImagingPlane metadata to behavior ophys
# session metadata
try:
ophys_module = self.nwbfile.processing['ophys']
except KeyError:
warnings.warn("Could not locate 'ophys' module in "
"NWB file. The following metadata fields will be "
"missing: 'ophys_frame_rate', 'indicator', "
"'targeted_structure', 'excitation_lambda', "
"'emission_lambda'")
else:
image_seg = ophys_module.data_interfaces['image_segmentation']
imaging_plane = image_seg.plane_segmentations['cell_specimen_table'].imaging_plane
optical_channel = imaging_plane.optical_channel[0]

data['ophys_frame_rate'] = imaging_plane.imaging_rate
data['indicator'] = imaging_plane.indicator
data['targeted_structure'] = imaging_plane.location
data['excitation_lambda'] = imaging_plane.excitation_lambda
data['emission_lambda'] = optical_channel.emission_lambda

# Add other metadata stored in nwb file to behavior ophys session meta
data['experiment_datetime'] = self.nwbfile.session_start_time
data['behavior_session_uuid'] = uuid.UUID(
data['behavior_session_uuid'])
return data

def get_task_parameters(self) -> dict:
Expand All @@ -193,10 +228,11 @@ def get_task_parameters(self) -> dict:
return data

def get_cell_specimen_table(self) -> pd.DataFrame:
df = self.nwbfile.modules['two_photon_imaging'].data_interfaces['image_segmentation'].plane_segmentations['cell_specimen_table'].to_dataframe()
# NOTE: ROI masks are stored in full frame width and height arrays
df = self.nwbfile.processing['ophys'].data_interfaces['image_segmentation'].plane_segmentations['cell_specimen_table'].to_dataframe()
df.index.rename('cell_roi_id', inplace=True)
df['cell_specimen_id'] = [None if csid == -1 else csid for csid in df['cell_specimen_id'].values]
df['image_mask'] = [mask.astype(bool) for mask in df['image_mask'].values]

df.reset_index(inplace=True)
df.set_index('cell_specimen_id', inplace=True)

Expand All @@ -206,31 +242,40 @@ def get_cell_specimen_table(self) -> pd.DataFrame:
return df

def get_dff_traces(self) -> pd.DataFrame:
dff_nwb = self.nwbfile.modules['two_photon_imaging'].data_interfaces['dff'].roi_response_series['traces']
dff_traces = dff_nwb.data[:]
dff_nwb = self.nwbfile.processing['ophys'].data_interfaces['dff'].roi_response_series['traces']
# dff traces stored as timepoints x rois in NWB
# We want rois x timepoints, hence the transpose
dff_traces = dff_nwb.data[:].T
number_of_cells, number_of_dff_frames = dff_traces.shape
num_of_timestamps = len(self.get_ophys_timestamps())
assert num_of_timestamps == number_of_dff_frames

df = pd.DataFrame({'dff': [x for x in dff_traces]}, index=pd.Index(data=dff_nwb.rois.table.id[:], name='cell_roi_id'))

df = pd.DataFrame({'dff': dff_traces.tolist()},
index=pd.Index(data=dff_nwb.rois.table.id[:],
name='cell_roi_id'))
cell_specimen_table = self.get_cell_specimen_table()
df = cell_specimen_table[['cell_roi_id']].join(df, on='cell_roi_id')
return df

def get_corrected_fluorescence_traces(self) -> pd.DataFrame:
corrected_fluorescence_nwb = self.nwbfile.modules['two_photon_imaging'].data_interfaces['corrected_fluorescence'].roi_response_series['traces']
df = pd.DataFrame({'corrected_fluorescence': [x for x in corrected_fluorescence_nwb.data[:]]},
index=pd.Index(data=corrected_fluorescence_nwb.rois.table.id[:], name='cell_roi_id'))
corrected_fluorescence_nwb = self.nwbfile.processing['ophys'].data_interfaces['corrected_fluorescence'].roi_response_series['traces']
# f traces stored as timepoints x rois in NWB
# We want rois x timepoints, hence the transpose
f_traces = corrected_fluorescence_nwb.data[:].T
df = pd.DataFrame({'corrected_fluorescence': f_traces.tolist()},
index=pd.Index(data=corrected_fluorescence_nwb.rois.table.id[:],
name='cell_roi_id'))

cell_specimen_table = self.get_cell_specimen_table()
df = cell_specimen_table[['cell_roi_id']].join(df, on='cell_roi_id')
return df

def get_motion_correction(self) -> pd.DataFrame:
ophys_module = self.nwbfile.processing['ophys']

motion_correction_data = {}
motion_correction_data['x'] = self.nwbfile.modules['motion_correction'].get_data_interface('x').data[:]
motion_correction_data['y'] = self.nwbfile.modules['motion_correction'].get_data_interface('y').data[:]
motion_correction_data['x'] = ophys_module.get_data_interface('ophys_motion_correction_x').data[:]
motion_correction_data['y'] = ophys_module.get_data_interface('ophys_motion_correction_y').data[:]

return pd.DataFrame(motion_correction_data)

Expand Down Expand Up @@ -296,4 +341,4 @@ def compare_fields(x1, x2, err_msg=""):
assert x1[key] == x2[key], key_err_msg

else:
assert x1 == x2, err_msg
assert x1 == x2, err_msg
76 changes: 75 additions & 1 deletion allensdk/brain_observatory/behavior/metadata_processing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,77 @@
OPHYS_1_3_DESCRIPTION = (
"2-photon calcium imaging in the visual cortex of the mouse "
"brain as the mouse performs a visual change detection task "
"with a set of natural scenes upon which it has previously been "
"trained."
)
OPHYS_2_DESCRIPTION = (
"2-photon calcium imaging in the visual cortex of the "
"mouse brain as the mouse is shown images from a "
"change detection task with a set of natural scenes "
"upon which it has previously been trained, but with "
"the lick-response sensor withdrawn (passive/open "
"loop mode)."
)
OPHYS_4_6_DESCRIPTION = (
"2-photon calcium imaging in the visual cortex of the mouse "
"brain as the mouse performs a visual change detection task "
"with a set of natural scenes that are unique from those on "
"which it was previously trained."
)
OPHYS_5_DESCRIPTION = (
"2-photon calcium imaging in the visual cortex of the "
"mouse brain as the mouse is shown images from a "
"change detection task with a set of natural scenes "
"that are unique from those on which it was "
"previously trained, but with the lick-response "
"sensor withdrawn (passive/open loop mode)."
)


def get_expt_description(session_type: str) -> str:
"""Determine a behavior ophys session's experiment description based on
session type.

Parameters
----------
session_type : str
A session description string (e.g. OPHYS_1_images_B )

Returns
-------
str
A description of the experiment based on the session_type.

Raises
------
RuntimeError
Behavior ophys sessions should only have 6 different session types.
Unknown session types (or malformed session_type strings) will raise
an error.
"""
# Experiment descriptions for different session types:
# OPHYS_1 -> OPHYS_6
ophys_1_3 = dict.fromkeys(["OPHYS_1", "OPHYS_3"], OPHYS_1_3_DESCRIPTION)
ophys_4_6 = dict.fromkeys(["OPHYS_4", "OPHYS_6"], OPHYS_4_6_DESCRIPTION)
ophys_2_5 = {"OPHYS_2": OPHYS_2_DESCRIPTION,
"OPHYS_5": OPHYS_5_DESCRIPTION}

expt_description_dict = {**ophys_1_3, **ophys_2_5, **ophys_4_6}

# Session type string will look something like: OPHYS_4_images_A
truncated_session_type = "_".join(session_type.split("_")[:2])

try:
return expt_description_dict[truncated_session_type]
except KeyError as e:
e_msg = (
f"Encountered an unknown session type "
f"({truncated_session_type}) when trying to determine "
f"experiment descriptions. Valid session types are: "
f"{expt_description_dict.keys()}")
raise RuntimeError(e_msg) from e


def get_task_parameters(data):

task_parameters = {}
Expand All @@ -15,4 +89,4 @@ def get_task_parameters(data):
n_stimulus_frames += sum(stim_table.get("draw_log", []))
task_parameters['n_stimulus_frames'] = n_stimulus_frames

return task_parameters
return task_parameters
Loading