In [5]:
# handle imports
import os
import json
import numpy as np
import pandas as pd

In [6]:
# Run configuration: which compiled data file to load and which traces to draw.
model_name = '2020_09_09-Nerve_Ring'  # folder under ./workspace holding the compiled JSON files
step = 1  # pipeline step whose compiled-data file is loaded and plotted

# Substring switch read by the plotting cell: containing "seam" draws
# seam_cells, containing "anno" draws annotations ("seam+annotations" = both).
display_type = "seam+annotations"

# Optional name filters; None disables the filter.
strain_target = None  # only plot strains whose name contains this substring
pos_target = None  # only plot positions whose name contains this substring

In [7]:
# get data
# Each pipeline step wrote its own compiled JSON file into the model's workspace
# folder. Steps missing from this map (e.g. 4) are rejected explicitly instead of
# leaving `filename` undefined -- or silently stale from a previous run.
STEP_FILENAMES = {
    1: '1_compiled_data',
    2: '2_compiled_data_no_outliers',
    3: '3_compiled_data_interpolation',
    5: '5_compiled_data_warped',
    6: '6_cell_coord_stats_by_timepoint',
    7: '7_cell_coord_stats_by_timepoint_smoothed',
}

folderpath = os.path.join(os.getcwd(), 'workspace', model_name)
try:
    filename = STEP_FILENAMES[step]
except KeyError:
    raise ValueError(
        "Unsupported step {!r}; expected one of {}".format(step, sorted(STEP_FILENAMES))
    ) from None
filepath = os.path.join(folderpath, filename + '.json')
# JSON text is UTF-8; don't depend on the platform's default encoding.
with open(filepath, encoding='utf-8') as f:
    data = json.load(f)

In [8]:
# plot figure
# Writes one HTML file per (strain, position). Each file has three stacked
# subplots showing the x, y and z coordinate of every selected cell against
# its timepoints.
# NOTE(review): these plotly imports would normally live in the top import cell.
from plotly.subplots import make_subplots
import plotly.graph_objects as go

axes = ['x', 'y', 'z']


def _extract_trace(entry):
    """Return ``(timepoints, coordinates)`` numpy arrays for one cell entry.

    An entry is either a dict with 'timepoints' and 'coordinates' keys, or a
    bare coordinate sequence. For a bare sequence, timepoints are 0..len-1.
    Coordinates are indexed as ``coordinates[:, axis_idx]`` for axis_idx 0-2,
    so they are presumably (n_timepoints, 3) -- TODO confirm against the
    files written by each pipeline step.
    """
    if isinstance(entry, dict):
        return np.array(entry['timepoints']), np.array(entry['coordinates'])
    coordinates = np.array(entry)
    return np.arange(coordinates.shape[0]), coordinates


# The cell groups to draw are fixed for the whole run, so resolve them once:
# "seam" in display_type selects seam_cells, "anno" selects annotations.
cell_types = [
    cell_type
    for cell_type, keyword in (('seam_cells', 'seam'), ('annotations', 'anno'))
    if keyword in display_type
]

# Make sure the output folder exists before any figure is written into it.
visualizations_dir = os.path.join(folderpath, 'visualizations')
os.makedirs(visualizations_dir, exist_ok=True)

for strain, positions in data.items():
    if strain_target and strain_target not in strain:
        continue

    for pos_name, pos_data in positions.items():
        if pos_target and pos_target not in pos_name:
            continue

        # Extract each selected cell's arrays once, not once per axis.
        traces = [
            (cell_name, *_extract_trace(entry))
            for cell_type in cell_types
            for cell_name, entry in pos_data[cell_type].items()
        ]

        # make the plot for this position
        fig = make_subplots(rows=3, cols=1)
        for axis_idx, axis in enumerate(axes):
            row = axis_idx + 1
            for cell_name, timepoints, coordinates in traces:
                fig.add_trace(go.Scatter(
                    x=timepoints,
                    y=coordinates[:, axis_idx],
                    mode='lines',
                    name=cell_name,
                    # One legendgroup per cell ties its x/y/z lines together;
                    # only the first subplot's line gets a legend entry.
                    legendgroup=cell_name,
                    showlegend=(axis_idx == 0),
                ), row=row, col=1)

            fig.update_xaxes(title_text="Timepoints", row=row, col=1)
            fig.update_yaxes(title_text=axis, row=row, col=1)

        id_text = "_{}_{}".format(strain, pos_name)

        fig.update_layout(height=1920, width=1080,
                          title_text="{}{} Step {} ({})".format(model_name.upper(), id_text, step, display_type))

        # Add fig.show() here to also preview the figure inline.
        output_filename = "{}{}_step_{}_1D.html".format(model_name.upper(), id_text, step)
        output_filepath = os.path.join(visualizations_dir, output_filename)
        fig.write_html(output_filepath)
