In [1]:
# get imports
import os
import json
import copy
import numpy as np
import pandas as pd
import plotly.graph_objects as go

In [2]:
# Pipeline step whose output is visualised (1-based index into `step_names`, defined in
# the data-loading cell):
#   1  '1_compiled_data'
#   2  '2_compiled_data_no_outliers'
#   3  '3_compiled_data_interpolation'
#   4  -- no output file (its `step_names` entry is None)
#   5  '5_compiled_data_warped'
#   6  '6_cell_coordinates_by_timepoint'
#   7  '7_cell_coordinates_by_timepoint_smoothed'
step = 7
model_name = '2020_08_20-OD1599_NU+NerveRing'
# Folder holding the step outputs; the trailing slash matters because file names are
# appended to it by plain string concatenation.
folderpath = 'Y:/RyanC/model_building_code/workspace/' + model_name + '/'

In [3]:
def plot_2d_model(pos_data, og_timepoints=None, interval=1, model="test", axis='zx', cell_info=None,
                  show_plot=False, n_timepoints=420, output_dir=None):
    """Write an animated 2D projection of a cell-position model to ``<model>_<axis>.html``.

    Every frame shows all cells of ``pos_data['seam_cells']`` and
    ``pos_data['annotations']`` at one timepoint, projected onto the plane named by
    ``axis``. Cells whose name contains ``'nr'`` are additionally joined, in name order,
    into a single white line trace called ``'nerve_ring'``. That trace is drawn first so
    it sits behind the cell markers.

    Parameters
    ----------
    pos_data : dict
        ``{'seam_cells': {name: entry}, 'annotations': {name: entry}}``. Each ``entry`` is
        either a dict holding ``'coordinates'`` or the coordinate array itself. The
        coordinates are indexed as ``[timepoint, (x, y, z)]``.
    og_timepoints : sequence of numbers, optional
        Original volume numbers. When given and non-empty, the plotted timepoint range is
        linearly rescaled onto the range of ``og_timepoints``. Each slider label then also
        shows the nearest original volume number.
    interval : int
        Step between consecutive plotted timepoints.
    model : str
        Output file prefix.
    axis : str
        Projection as (horizontal, vertical): ``'zy'``, ``'xy'``, or ``'zx'``. Any other
        value is drawn like ``'zx'``.
    cell_info : dict, optional
        Metadata keyed by lower-case cell name. The first value of
        ``cell_info[name]['colors']`` is used as an ``(r, g, b)`` marker colour. Seam
        cells without an entry are grey; other cells use plotly's default colour.
    show_plot : bool
        Also display the figure.
    n_timepoints : int
        Plotted timepoints are ``range(0, n_timepoints, interval)``. Defaults to the
        previously hard-coded 420. Every cell needs coordinate rows up to the last plotted
        timepoint.
    output_dir : str, optional
        Destination folder. Defaults to the notebook-level ``output_folder``.
    """
    cell_info = cell_info or {}

    # Projection per plane: (horizontal column, vertical column, width, height, x range, y range).
    # Coordinate columns are 0 = x, 1 = y, 2 = z.
    views = {
        'zy': (2, 1, 1000, 500, [0, 1500], [0, 300]),
        'xy': (0, 1, 1000, 1000, [0, 300], [0, 300]),
        'zx': (2, 0, 1000, 500, [0, 1500], [0, 300]),
    }
    h_col, v_col, width, height, x_range, y_range = views.get(axis, views['zx'])

    layout = {
        'template': "plotly_dark",
        'scene': {
            'aspectmode': 'data',
            'xaxis': {'showgrid': False, "title": "x"},
            'yaxis': {'showgrid': False, "title": "y"},
            "camera": {
                "up": {"x": 0, "y": 0, "z": 1},
                "center": {"x": 0, "y": 0, "z": 0},
                "eye": {"x": 2, "y": 0, "z": 0.5},
                "projection": {"type": "perspective"}
            }
        },
        'showlegend': True,
        'height': height,
        'width': width,
        'xaxis': {"range": x_range},
        # equal scaling so the projection is not distorted
        'yaxis': {"range": y_range, 'scaleanchor': 'x', 'scaleratio': 1},
        "hovermode": "closest",
        "updatemenus": [{
            "buttons": [],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }],
    }
    fig_dict = {"data": [], "layout": layout, "frames": []}

    sliders_dict = {
        "active": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Time:",
            "visible": True,
            "xanchor": "right"
        },
        "transition": {"duration": 10},
        "pad": {"b": 10, "t": 50},
        "len": 1,
        "x": 0,
        "y": 0,
        "steps": []
    }

    timepoints = np.arange(0, n_timepoints, interval)

    # Slider labels: map each plotted timepoint to the nearest original volume number.
    # `is not None` (not truthiness) so numpy arrays are accepted.
    vol_points = None
    og = None if og_timepoints is None else np.asarray(og_timepoints)
    if og is not None and og.size > 0 and timepoints.size > 0:
        rel = timepoints - timepoints.min()
        span = rel.max()
        if span > 0:
            scaled = rel / span * (og.max() - og.min()) + og.min()
        else:  # a single plotted timepoint: avoid dividing by zero
            scaled = np.full(timepoints.shape, og.min(), dtype=float)
        nearest = np.abs(og[np.newaxis, :] - scaled[:, np.newaxis]).argmin(axis=1)
        vol_points = og[nearest].tolist()

    # Resolve every cell's coordinates and colour once, not once per frame.
    cells = []  # (name, coordinates, colour)
    for cell_type in ['seam_cells', 'annotations']:
        for cell_name, entry in pos_data[cell_type].items():
            try:
                coords = np.asarray(entry['coordinates'])
            except (TypeError, KeyError, IndexError):
                # step 6 layout: the entry is the coordinate array itself
                coords = np.asarray(entry)
            if cell_name.lower() in cell_info:
                rgb = list(cell_info[cell_name.lower()]['colors'].values())[0]
                color = 'rgb({},{},{})'.format(rgb[0], rgb[1], rgb[2])
            elif cell_type == 'seam_cells':
                color = '#AAA'
            else:
                color = None
            cells.append((cell_name, coords, color))

    # Nerve-ring members in name order. The sort is stable and keeps each cell's
    # coordinates together, so x and y can never be paired across different cells.
    ring = sorted((cell for cell in cells if 'nr' in cell[0]), key=lambda cell: cell[0])

    for timepoint_idx, timepoint in enumerate(timepoints):
        frame = {"data": [], "name": str(timepoint)}

        # Nerve ring first so it is drawn behind the cells.
        traces = [{
            "x": [coords[timepoint, h_col] for _, coords, _ in ring],
            "y": [coords[timepoint, v_col] for _, coords, _ in ring],
            "mode": "lines",
            "text": [name for name, _, _ in ring],
            "line": {"color": '#FFF', 'width': 10},
            "name": 'nerve_ring'
        }]
        for name, coords, color in cells:
            traces.append({
                "x": [coords[timepoint, h_col]], "y": [coords[timepoint, v_col]],
                "mode": "markers+text",
                "text": name,
                "marker": {"sizemode": "area", "size": 20, "color": color},
                "name": name
            })

        if timepoint_idx == 0:  # the first frame also initialises the figure's traces
            fig_dict['data'].extend(traces)
        frame["data"].extend(traces)
        fig_dict["frames"].append(frame)

        if vol_points is not None:
            label_name = '{} ({})'.format(str(timepoint), vol_points[timepoint_idx])
        else:
            label_name = '{}'.format(str(timepoint))
        sliders_dict["steps"].append({
            "args": [
                [str(timepoint)],
                {
                    "frame": {"duration": 10, "redraw": False},
                    "mode": "immediate",
                    "transition": {"duration": 10}
                }
            ],
            "label": label_name,
            "method": "animate"
        })

    fig_dict["layout"]["sliders"] = [sliders_dict]

    fig = go.Figure(fig_dict)
    if show_plot:
        fig.show()
    # Fall back to the notebook global `output_folder` (set in the data-loading cell).
    out_dir = output_folder if output_dir is None else output_dir
    fig.write_html(os.path.join(out_dir, "{}_{}.html".format(model, axis)))

In [4]:
def plot_3d_model(pos_data, og_timepoints=None, interval=1, model="test", cell_info=None, show_plot=False,
                  n_timepoints=420, output_dir=None):
    """Write an animated 3D scatter of a cell-position model to ``<model>_3D.html``.

    Every frame shows all cells of ``pos_data['seam_cells']`` and
    ``pos_data['annotations']`` at one timepoint. Data axes are remapped for display:
    - the data's z is drawn on the scene x axis (range 0-1500, aspect 5);
    - data y is drawn on the scene y axis;
    - data x is drawn on the scene z axis.
    Cells whose name contains ``'nr'`` are also joined, in name order, into a white
    ``'nerve_ring'`` line trace that is added before the cell markers.

    Parameters
    ----------
    pos_data : dict
        ``{'seam_cells': {name: entry}, 'annotations': {name: entry}}``. Each ``entry`` is
        either a dict holding ``'coordinates'`` or the coordinate array itself. The
        coordinates are indexed as ``[timepoint, (x, y, z)]``.
    og_timepoints : sequence of numbers, optional
        Original volume numbers. When given and non-empty, each slider label also shows
        the nearest original volume number, after linearly rescaling the plotted range.
    interval : int
        Step between consecutive plotted timepoints.
    model : str
        Output file prefix.
    cell_info : dict, optional
        Metadata keyed by lower-case cell name. The first value of
        ``cell_info[name]['colors']`` is used as an ``(r, g, b)`` marker colour. Seam
        cells without an entry are grey; other cells use plotly's default colour.
    show_plot : bool
        Also display the figure (after the HTML file is written).
    n_timepoints : int
        Plotted timepoints are ``range(0, n_timepoints, interval)``. Defaults to the
        previously hard-coded 420.
    output_dir : str, optional
        Destination folder. Defaults to the notebook-level ``output_folder``.
    """
    cell_info = cell_info or {}

    layout = {
        'template': "plotly_dark",
        'scene': {
            'aspectmode': 'manual',
            'aspectratio': {'x': 5, 'y': 1, 'z': 1},
            'xaxis': {'showgrid': False, "title": "z", "range": [0, 1500]},
            'yaxis': {'showgrid': False, "title": "y", "range": [0, 300], "visible": False},
            'zaxis': {'showgrid': False, "title": "x", "range": [0, 300], "visible": False},
        },
        'showlegend': True,
        'height': 600,
        'width': 1000,
        "hovermode": "closest",
    }
    fig_dict = {"data": [], "layout": layout, "frames": []}

    sliders_dict = {
        "active": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Time:",
            "visible": True,
            "xanchor": "right"
        },
        "transition": {"duration": 10, "easing": "cubic-in-out"},
        "len": 1,
        "x": 0,
        "y": 0,
        "steps": []
    }

    timepoints = np.arange(0, n_timepoints, interval)

    # Slider labels: map each plotted timepoint to the nearest original volume number.
    # `is not None` (not truthiness) so numpy arrays are accepted.
    vol_points = None
    og = None if og_timepoints is None else np.asarray(og_timepoints)
    if og is not None and og.size > 0 and timepoints.size > 0:
        rel = timepoints - timepoints.min()
        span = rel.max()
        if span > 0:
            scaled = rel / span * (og.max() - og.min()) + og.min()
        else:  # a single plotted timepoint: avoid dividing by zero
            scaled = np.full(timepoints.shape, og.min(), dtype=float)
        nearest = np.abs(og[np.newaxis, :] - scaled[:, np.newaxis]).argmin(axis=1)
        vol_points = og[nearest].tolist()

    # Resolve every cell's coordinates and colour once, not once per frame.
    cells = []  # (name, coordinates, colour)
    for cell_type in ['seam_cells', 'annotations']:
        for cell_name, entry in pos_data[cell_type].items():
            try:
                coords = np.asarray(entry['coordinates'])
            except (TypeError, KeyError, IndexError):
                # the entry is the coordinate array itself
                coords = np.asarray(entry)
            if cell_name.lower() in cell_info:
                rgb = list(cell_info[cell_name.lower()]['colors'].values())[0]
                color = 'rgb({},{},{})'.format(rgb[0], rgb[1], rgb[2])
            elif cell_type == 'seam_cells':
                color = '#AAA'
            else:
                color = None
            cells.append((cell_name, coords, color))

    # Nerve-ring members in name order. The sort is stable and keeps each cell's
    # coordinates together, so x/y/z are never paired across different cells.
    ring = sorted((cell for cell in cells if 'nr' in cell[0]), key=lambda cell: cell[0])

    for timepoint_idx, timepoint in enumerate(timepoints):
        # Frames are named by index; the slider steps below refer to the same names.
        frame = {"data": [], "name": str(timepoint_idx)}

        traces = [go.Scatter3d(
            x=[coords[timepoint, 2] for _, coords, _ in ring],
            y=[coords[timepoint, 1] for _, coords, _ in ring],
            z=[coords[timepoint, 0] for _, coords, _ in ring],
            mode="lines",
            text=[name for name, _, _ in ring],
            line={"color": '#FFF', 'width': 10},
            name='nerve_ring'
        )]
        for name, coords, color in cells:
            traces.append(go.Scatter3d(
                x=[coords[timepoint, 2]],
                y=[coords[timepoint, 1]],
                z=[coords[timepoint, 0]],
                marker={"sizemode": "area", "size": 10, "color": color},
                mode="markers+text",
                text=name,
                name=name
            ))

        if timepoint_idx == 0:  # the first frame also initialises the figure's traces
            fig_dict['data'].extend(traces)
        frame["data"].extend(traces)
        fig_dict["frames"].append(frame)

        if vol_points is not None:
            label_name = '{} ({})'.format(str(timepoint), vol_points[timepoint_idx])
        else:
            label_name = '{}'.format(str(timepoint))
        sliders_dict["steps"].append({
            "args": [
                [str(timepoint_idx)],
                {
                    # 3D traces need a full redraw to update between frames
                    "frame": {"duration": 500, "redraw": True},
                    "mode": "immediate",
                    "transition": {"duration": 500}
                }
            ],
            "label": label_name,
            "method": "animate"
        })

    fig_dict["layout"]["sliders"] = [sliders_dict]

    fig = go.Figure(fig_dict)
    # Fall back to the notebook global `output_folder` (set in the data-loading cell).
    out_dir = output_folder if output_dir is None else output_dir
    fig.write_html(os.path.join(out_dir, "{}_3D.html".format(model)))

    if show_plot:
        fig.show()

In [5]:
# load in data
# Output file of each pipeline step, indexed by `step - 1`. Step 4 has no file to plot.
step_names = [
    '1_compiled_data',
    '2_compiled_data_no_outliers',
    '3_compiled_data_interpolation',
    None,
    '5_compiled_data_warped',
    '6_cell_coordinates_by_timepoint',
    '7_cell_coordinates_by_timepoint_smoothed'
]
# Fail early with a clear message. Otherwise step 4 would hit a TypeError (None + str),
# and step <= 0 would silently index from the end of the list.
plottable_steps = [idx + 1 for idx, name in enumerate(step_names) if name is not None]
if step not in plottable_steps:
    raise ValueError('step must be one of {}, got {!r}'.format(plottable_steps, step))

with open('config.json') as f:
    config = json.load(f)
cell_info_filepath = config['settings']['mipav_output']['cell_info']

# per-cell metadata (marker colours) used by the plotting functions
with open(cell_info_filepath) as f:
    cell_info = json.load(f)

filename = step_names[step - 1]
filepath = folderpath + filename + '.json'
filepath_raw = folderpath + step_names[0] + '.json'
# NOTE(review): inputs are read from `folderpath` (an absolute path), but plots are written
# relative to the current working directory -- confirm this is intended.
output_folder = os.path.join('workspace', model_name, 'visualizations')
os.makedirs(output_folder, exist_ok=True)

with open(filepath) as f:
    data = json.load(f)
with open(filepath_raw) as f:
    step_1_data = json.load(f)

if step < 6:
    # Steps 1-5 hold one model per strain/position; plot each separately.
    for strain_name, strain_data in data.items():
        for pos_name, pos_data in strain_data.items():
            # Slider volume numbers come from the raw (step 1) annotation with the most
            # timepoints. On ties, the last such cell wins.
            raw_annotations = step_1_data[strain_name][pos_name]['annotations']
            og_timepoints = None
            timepoints_max = 0
            for cell_data in raw_annotations.values():
                if len(cell_data['timepoints']) >= timepoints_max:
                    og_timepoints = cell_data['timepoints']
                    timepoints_max = len(cell_data['timepoints'])

            plot_name = '{}_{}_step_{}'.format(strain_name, pos_name, step)
            for axis in ['zx', 'zy', 'xy']:
                plot_2d_model(pos_data, og_timepoints=og_timepoints,
                              model=plot_name, show_plot=False, axis=axis, cell_info=cell_info)
else:
    # Steps 6+ are plotted as a single combined model.
    plot_name = 'combined_model_step_{}'.format(step)
    for axis in ['zx', 'zy', 'xy']:
        plot_2d_model(data, og_timepoints=None, cell_info=cell_info,
                      model=plot_name, show_plot=False, axis=axis)

In [6]:
# Steps >= 6 are plotted as one combined model (as in the previous cell); those models
# additionally get an interactive 3D view.
if step >= 6:
    plot_name = 'combined_model_step_{}'.format(step)
    plot_3d_model(data, cell_info=cell_info, model=plot_name, show_plot=False)