In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import networkx as nx

import pynwb
from pynwb import get_class
from pynwb import register_class
from pynwb.form.utils import docval, getargs, popargs, call_docval_func

from datetime import datetime
from dateutil import tz

import matplotlib.pyplot as plt

### Build the track in NetworkX

In [None]:
# Build the W-track as an undirected NetworkX graph.
# Every node has a 'kind' ('segment', 'point' or 'polygon') and a list of
# (x, y) 'coords' in meters; edges express which components touch.
G = nx.Graph(name='w-track')

# Segment nodes: start/stop coords, no intermediate points.
G.add_nodes_from(
    (seg_name, {'coords': seg_coords, 'intermediate_coords': None, 'kind': 'segment'})
    for seg_name, seg_coords in [
        ('L',  [(1.6, 2.2), (1.6, 0.65)]),
        ('LC', [(1.6, 0.65), (1.9, 0.65)]),
        ('C',  [(1.9, 0.65), (1.9, 2.2)]),
        ('RC', [(1.9, 0.65), (2.2, 0.65)]),
        ('R',  [(2.2, 0.65), (2.2, 2.2)]),
        ('EH', [(1.9, 0.4), (1.9, 0.65)]),
    ]
)

# Point nodes (wells, choice point, sleep-box door): a single coordinate each.
G.add_nodes_from(
    (pt_name, {'coords': [pt_xy], 'kind': 'point'})
    for pt_name, pt_xy in [
        ('LW',  (1.6, 2.2)),
        ('CW',  (1.9, 2.2)),
        ('RW',  (2.2, 2.2)),
        ('CP',  (1.9, 0.65)),
        ('SBD', (1.9, 0.4)),
    ]
)

# Polygon area (sleep box): vertices, including the door coordinate.
G.add_node('SB', coords=[(1.8, 0.4), (1.9, 0.4), (2.0, 0.4), (2.0, 0.0), (1.8, 0.0)],
           interior_coords=None, kind='polygon')

# Edges, in insertion order:
#   segment <-> segment, segment <-> well, choice point <-> segments,
#   door <-> segment, polygon <-> door (connected nodes share a coordinate).
G.add_edges_from([
    ('L', 'LC'), ('RC', 'R'),
    ('L', 'LW'), ('C', 'CW'), ('R', 'RW'),
    ('CP', 'C'), ('CP', 'LC'), ('CP', 'RC'), ('CP', 'EH'),
    ('SBD', 'EH'),
    ('SB', 'SBD'),
])

### Plot the track topology from NetworkX

In [None]:
# Draw the node/edge topology of the track (layout is NetworkX's default).
fig, ax = plt.subplots()
ax.set_title("W-track topology")
# `font_size` is the draw_networkx keyword for label size; the former `fontsize`
# is not a recognized argument (ignored or rejected depending on networkx version).
nx.draw_networkx(G, ax=ax, font_size=6)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

### Plot geometry from NetworkX

In [None]:
# Draw the physical layout of the track: points as red dots, segments as black
# lines, polygon areas as shaded regions.
fig, ax = plt.subplots()
ax.set_title('W-track geometry')
for node_name, attrs in G.nodes(data=True):
    # .get() so a node without 'kind' reaches the explicit error below
    # instead of failing with a bare KeyError.
    kind = attrs.get('kind')
    if kind == 'point':
        x, y = attrs['coords'][0]
        ax.scatter([x], [y], color='r')
    elif kind == 'segment':
        (x0, y0), (x1, y1) = attrs['coords']
        ax.plot([x0, x1], [y0, y1], color='k')
    elif kind == 'polygon':
        xs, ys = zip(*attrs['coords'])
        ax.fill(xs, ys, color='grey', alpha=0.4)
    else:
        raise TypeError("Node %r: nodes must have 'kind' point, segment, or polygon." % (node_name,))
# Coordinates are in meters on both axes; keep them to scale.
ax.set_aspect('equal')
ax.set_xlabel('x pos (meters)')
ax.set_ylabel('y pos (meters)')
plt.show()

### Load extension and define Python classes for the FL_Apparatus types

Note that classes generated by `get_class()` turn all parameters into type list/tuple/dict/set.
This works for Nodes, since by convention their coords can be viewed as a list of (x, y)
tuples. It also works for Edges, which by convention can be viewed as a list of the names of two Nodes.
However, at some point we might want to write our own Python classes that enforce the behavior we want,
instead of just using named lists for these parameters.

In [None]:
# Register the Frank Lab extension namespace with pynwb so its types can be used.
ns_path = 'franklab.namespace.yaml'
pynwb.load_namespaces(ns_path)

In [None]:
from pynwb import register_class, load_namespaces
from pynwb.file import MultiContainerInterface, NWBContainer

@register_class('Node', 'franklab')
class Node(NWBContainer):
    """Base node of an apparatus graph: a named container holding a set of coordinates."""

    __nwbfields__ = ('name', 'coords')

    __help = 'info about FL_ApparatusNode'

    @docval({'name': 'name', 'type': str, 'doc': 'name of this node'},
            {'name': 'coords', 'type': ('array_data', 'data'), 'doc': 'coords for this node'})
    def __init__(self, **kwargs):
        node_name, node_coords = getargs('name', 'coords', kwargs)
        super(Node, self).__init__(name=node_name)
        self.coords = node_coords
        
@register_class('Edge', 'franklab')
class Edge(NWBContainer):
    """Undirected edge of an apparatus graph, identified by the names of its two endpoint nodes."""

    __nwbfields__ = ('name', 'edge_nodes')

    @docval({'name': 'name', 'type': str, 'doc': 'name of this edge'},
            {'name': 'edge_nodes', 'type': ('array_data', 'data'),
             'doc': 'the names of the two nodes in this undirected edge'})
    def __init__(self, **kwargs):
        super(Edge, self).__init__(name=kwargs['name'])
        # Endpoint node names, e.g. ('L', 'LC'); read back via `edge.edge_nodes`.
        self.edge_nodes = kwargs['edge_nodes']
        
@register_class('FL_PointNode', 'franklab')
class FL_PointNode(Node):
    """Node located at a single coordinate (e.g. a well, choice point or door)."""

    __nwbfields__ = ('name', 'coords')

    @docval({'name': 'name', 'type': str, 'doc': 'name of this point node'},
            {'name': 'coords', 'type': ('array_data', 'data'), 'doc': 'coords of this node'})
    def __init__(self, **kwargs):
        point_name, point_coords = getargs('name', 'coords', kwargs)
        super(FL_PointNode, self).__init__(name=point_name, coords=point_coords)
        
@register_class('FL_SegmentNode', 'franklab')
class FL_SegmentNode(Node):
    """Track segment node described by its start/stop coords, plus optional intermediate coords."""

    __nwbfields__ = ('name', 'coords', 'intermediate_coords')

    @docval({'name': 'name', 'type': str, 'doc': 'name of this segment node'},
            {'name': 'coords', 'type': ('array_data', 'data'), 'doc': 'start/stop coords of this segment'},
            {'name': 'intermediate_coords', 'type': ('array_data', 'data'),
             'doc': 'intermediate coords between the start/stop coords of this node', 'default': None})
    def __init__(self, **kwargs):
        super(FL_SegmentNode, self).__init__(name=kwargs['name'], coords=kwargs['coords'])
        # None unless intermediate points are supplied (None is the docval default).
        self.intermediate_coords = kwargs['intermediate_coords']
        
@register_class('FL_PolygonNode', 'franklab')
class FL_PolygonNode(Node):
    """Area node described by its polygon vertices, plus optional coords inside the area."""

    __nwbfields__ = ('name', 'coords', 'interior_coords')

    @docval({'name': 'name', 'type': str, 'doc': 'name of this polygon node'},
            {'name': 'coords', 'type': ('array_data', 'data'), 'doc': 'vertices and exterior control points (i.e. doors) of this polygon'},
            {'name': 'interior_coords', 'type': ('array_data', 'data'),
             'doc': 'coords inside this polygon area (i.e. wells, objects)', 'default': None})
    def __init__(self, **kwargs):
        super(FL_PolygonNode, self).__init__(name=kwargs['name'], coords=kwargs['coords'])
        # None unless interior points are supplied (None is the docval default).
        self.interior_coords = kwargs['interior_coords']
 

@register_class('FL_ApparatusGraph', 'franklab')
class FL_ApparatusGraph(MultiContainerInterface):
    """
    Behaviorally reachable components of an apparatus, stored as a graph:
    Node containers (segments, points, polygons) joined by undirected Edge containers.

    MultiContainerInterface generates the accessors declared in ``__clsconf__``:
    ``add_edge``/``get_edge`` for ``edges`` and ``add_node``/``get_node`` for ``nodes``.
    """

    __nwbfields__ = ('name', 'edges', 'nodes')

    __clsconf__ = [
        {'attr': 'edges', 'type': Edge, 'add': 'add_edge', 'get': 'get_edge'},
        {'attr': 'nodes', 'type': Node, 'add': 'add_node', 'get': 'get_node'},
    ]

    __help = 'info about FL_ApparatusGraph'

### Function for converting from NX nodes to FL_ApparatusNodes

In [None]:
def nx_to_fl_node(node_name, attrs):
    """Convert one NetworkX node into the matching Frank Lab apparatus node container.

    Parameters
    ----------
    node_name : str
        Name of the NetworkX node; becomes the ``name`` of the returned container.
    attrs : dict
        The node's NetworkX attribute dict. Must contain ``'kind'`` (one of
        ``'segment'``, ``'point'``, ``'polygon'``) and ``'coords'``. Segment nodes
        must also carry ``'intermediate_coords'`` and polygon nodes
        ``'interior_coords'``; either may be None.

    Returns
    -------
    FL_SegmentNode, FL_PointNode or FL_PolygonNode, depending on ``attrs['kind']``.

    Raises
    ------
    TypeError
        If a required attribute is missing or ``'kind'`` is not recognized.
    """
    if 'kind' not in attrs:
        raise TypeError("NX node attributes must contain a 'kind' field")
    if 'coords' not in attrs:
        raise TypeError("NX node attributes must contain a 'coords' field")
    kind = attrs['kind']
    if kind == 'segment':
        if 'intermediate_coords' not in attrs:
            raise TypeError("NX 'segment' nodes must contain a 'intermediate_coords' field. It can be set to None.")
        return FL_SegmentNode(name=node_name, coords=attrs['coords'],
                              intermediate_coords=attrs['intermediate_coords'])
    elif kind == 'point':
        # Use the `node_name` argument, not a notebook-global loop variable.
        return FL_PointNode(name=node_name, coords=attrs['coords'])
    elif kind == 'polygon':
        if 'interior_coords' not in attrs:
            raise TypeError("NX 'polygon' nodes must contain a 'interior_coords' field. It can be set to None.")
        return FL_PolygonNode(name=node_name, coords=attrs['coords'],
                              interior_coords=attrs['interior_coords'])
    else:
        raise TypeError('Nodes must be of type point, segment, or polygon.')

### Load NetworkX nodes into FL_ApparatusNode objects and add to the FL_ApparatusGraph

In [None]:
# Convert each NetworkX node into its FL_*Node counterpart and attach it to a new apparatus graph.
appar = FL_ApparatusGraph(name='W-track with sleep box')
for n, attrs in G.nodes(data=True):
    appar.add_node(nx_to_fl_node(n, attrs))

### Load NetworkX edges into FL_ApparatusEdge objects

In [None]:
# One Edge container per undirected NetworkX edge, named '<node1>-<node2>'.
for n1, n2 in G.edges:
    appar.add_edge(Edge(name='{}-{}'.format(n1, n2), edge_nodes=(n1, n2)))

### Examine our Apparatus (before saving to file)

In [None]:
# Print a text summary of the in-memory apparatus graph before it is written to file.
print(appar)

### Create a new NWBFile

In [None]:
# Session metadata; the recording date is encoded as 2006-01-<day>.
anim = 'Bon'
day = 4
day_str = '{:02d}'.format(day)
dataset_zero_time = datetime(2006, 1, day, 12, 0, 0, tzinfo=tz.gettz('US/Pacific'))
file_create_date = datetime.now(tz.tzlocal())
nwb_filename = 'apparatus_extension_test.nwb'

nwbf = pynwb.NWBFile(
    session_description='Example NWBFile with behavioral track data',
    identifier='{}{}'.format(anim, day_str),
    session_start_time=dataset_zero_time,
    file_create_date=file_create_date,
    lab='Frank Laboratory',
    experimenter='Mattias Karlsson',
    institution='UCSF',
    experiment_description='Recordings from awake behaving rat',
)

### Add the FL_ApparatusGraph into the NWBFile
Here we add it as a data interface of a processing module called 'Behavior', even though the docs state that data interfaces should be used for containers that do not include metadata. (https://pynwb.readthedocs.io/en/latest/building_api.html#nwbdatainterface)

In [None]:
# Store the apparatus graph inside a 'Behavior' processing module of the file.
behav_mod = nwbf.create_processing_module(
    name='Behavior',
    description='Behavioral data and metadata',
)
behav_mod.add_container(appar)

### Write the NWBFile

In [None]:
# Write
with pynwb.NWBHDF5IO(nwb_filename, mode='w') as iow:
    iow.write(nwbf, cache_spec=True)
print('Wrote nwb file: ' + nwb_filename)

### Read the NWBFile

In [None]:
# Read
io = pynwb.NWBHDF5IO(nwb_filename, mode='r')
nwbf_read = io.read()
print('Finished reading nwb file: ' + nwb_filename)

### Make sure our FL_ApparatusGraph is still looking good

In [None]:
# Fetch the apparatus graph from the file we just read. It is stored under the
# name it was created with ('W-track with sleep box'); looking up 'W-track'
# raises a KeyError.
appar = nwbf_read.modules['Behavior']['W-track with sleep box']
print(appar)

### Function to directly input FL_ApparatusNodes into an NX graph

In [None]:
def add_fl_node_to_nx_graph(fl_node, nx_graph):
    """Add a Frank Lab apparatus node to a NetworkX graph.

    The node is added under ``fl_node.name`` with the same attributes used when
    the original graph was built: ``coords``, ``kind`` and, for segments,
    ``intermediate_coords`` or, for polygons, ``interior_coords``. Storing
    ``kind`` is what lets the geometry-plotting code handle round-tripped graphs.

    Parameters
    ----------
    fl_node : FL_SegmentNode, FL_PointNode or FL_PolygonNode
        Node container to copy into the graph.
    nx_graph : networkx.Graph
        Graph to add the node to (modified in place).

    Raises
    ------
    TypeError
        If ``fl_node`` is not one of the three supported node types.
    """
    def _as_coord_list(coords):
        # Copy coordinates (which may be file-backed after reading) into a plain
        # list of tuples so the graph stays usable after the NWB file is closed.
        # Assumes coords is a sequence of (x, y) pairs, as built in this notebook.
        if coords is None:
            return None
        return [tuple(xy) for xy in coords]

    if isinstance(fl_node, FL_SegmentNode):
        nx_graph.add_node(fl_node.name, coords=_as_coord_list(fl_node.coords),
                          intermediate_coords=_as_coord_list(fl_node.intermediate_coords),
                          kind='segment')
    elif isinstance(fl_node, FL_PointNode):
        nx_graph.add_node(fl_node.name, coords=_as_coord_list(fl_node.coords), kind='point')
    elif isinstance(fl_node, FL_PolygonNode):
        nx_graph.add_node(fl_node.name, coords=_as_coord_list(fl_node.coords),
                          interior_coords=_as_coord_list(fl_node.interior_coords),
                          kind='polygon')
    else:
        raise TypeError("'fl_node' must be of type FL_SegmentNode, FL_PointNode, or FL_PolygonNode")

### Extract graph data back out into NetworkX

In [None]:
# Rebuild a NetworkX graph from the apparatus read back from file.
# The container is stored under the name it was created with.
appar = nwbf_read.modules['Behavior']['W-track with sleep box']
H = nx.Graph(name='w-track round trip')

# The `nodes`/`edges` collections are keyed by container name, so iterate over
# their values to get the Node/Edge containers themselves.
for fl_node in appar.nodes.values():
    add_fl_node_to_nx_graph(fl_node, H)

for e in appar.edges.values():
    # Edge keeps its endpoints in `edge_nodes` (there is no `nodes` field).
    # Decode defensively in case the names come back from HDF5 as bytes.
    n1, n2 = [x.decode('utf-8') if isinstance(x, bytes) else x for x in e.edge_nodes]
    H.add_edge(n1, n2)

### Plot track topology from NetworkX (roundtrip)

In [None]:
# Draw the node/edge topology of the round-tripped graph (NetworkX default layout).
fig, ax = plt.subplots()
ax.set_title("W-track topology")
# `font_size` is the draw_networkx keyword for label size; the former `fontsize`
# is not a recognized argument (ignored or rejected depending on networkx version).
nx.draw_networkx(H, ax=ax, font_size=6)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

### Plot track geometry from NetworkX (roundtrip)

In [None]:
# Draw the physical layout of the round-tripped graph: points as red dots,
# segments as black lines, polygon areas as shaded regions.
fig, ax = plt.subplots()
ax.set_title('W-track geometry (round trip)')
for node_name, attrs in H.nodes(data=True):
    # .get() so a node without 'kind' reaches the explicit error below
    # instead of failing with a bare KeyError.
    kind = attrs.get('kind')
    if kind == 'point':
        x, y = attrs['coords'][0]
        ax.scatter([x], [y], color='r')
    elif kind == 'segment':
        (x0, y0), (x1, y1) = attrs['coords']
        ax.plot([x0, x1], [y0, y1], color='k')
    elif kind == 'polygon':
        xs, ys = zip(*attrs['coords'])
        ax.fill(xs, ys, color='grey', alpha=0.4)
    else:
        raise TypeError("Node %r: nodes must have 'kind' point, segment, or polygon." % (node_name,))
# Coordinates are in meters on both axes; keep them to scale.
ax.set_aspect('equal')
ax.set_xlabel('x pos (meters)')
ax.set_ylabel('y pos (meters)')
plt.show()

In [None]:
io.close()  # Close the reading file IO; do this only after all data read from the file has been used