In [None]:
def _to_hex(rgb):
    """
    Convert RGB color(s) into hexadecimal.

    :param rgb: Either a single color given as a tuple of three integers in
        0-255, or a list / 2-D numpy array of colors accepted by
        ``matplotlib.colors.to_hex`` (e.g. the RGBA rows returned by
        ``plt.get_cmap(name)(np.linspace(0, 1, n))``).
    :return: For a tuple, a 6-character hex string *without* a leading '#'
        (e.g. ``'ff0010'``); for a list/array, a list of ``'#rrggbb'``
        strings as produced by ``matplotlib.colors.to_hex``.
    """
    # Sampling a colormap returns an ndarray rather than a list, so both are
    # treated as a sequence of colors.
    if isinstance(rgb, (list, np.ndarray)):
        return [matplotlib.colors.to_hex(c) for c in rgb]
    # Single integer triple; callers prepend the '#' themselves.
    return '%02x%02x%02x' % rgb

def get_dict_legend(l, seed=None):
    """
    Return a dictionary that maps each distinct value of ``l`` to its own random color.

    Colors are drawn uniformly at random. A color that was already assigned
    to another value is re-drawn, so every value gets a different color.

    :param l: List of categorical values
    :param seed: Optional seed for reproducible colors. When None (default),
        the global ``random`` module is used, so ``random.seed`` still
        controls the output.
    :return: Dictionary mapping each unique value (as returned by
        ``np.unique``) to a ``'#rrggbb'`` color string
    :raises ValueError: If there are more distinct values than 24-bit colors
    """
    rng = random if seed is None else random.Random(seed)
    unique_values = np.unique(l)
    if len(unique_values) > 256 ** 3:
        # Otherwise the re-draw loop below could never terminate.
        raise ValueError("More distinct values than available 24-bit colors")

    dict_legend = {}
    used_colors = set()
    for e in unique_values:
        color = None
        # Independent draws can collide; re-draw until the color is unused.
        while color is None or color in used_colors:
            color = "#%02x%02x%02x" % (
                rng.randint(0, 255),
                rng.randint(0, 255),
                rng.randint(0, 255),
            )
        used_colors.add(color)
        dict_legend[e] = color
    return dict_legend

def get_colors(l, dict_legend):
    """
    Translate every categorical value in ``l`` into its color.

    :param l: Iterable of categorical values
    :param dict_legend: Dictionary from categorical value to color
    :return: List of colors, one per element of ``l`` and in the same order
    :raises KeyError: If a value of ``l`` is missing from ``dict_legend``
    """
    return [dict_legend[value] for value in l]

def get_categorical_values(tr_id_list, df_cat):
    """
    Get the categorical value associated with each patient ID.

    For every ID, the value is taken from the first row of ``df_cat`` with
    that ID, in the first column other than ``'ID'``.

    :param tr_id_list: List of ID patients
    :param df_cat: Dataframe with an ``'ID'`` column and the categorical
        value(s) for each ID patient
    :return: List of categorical values, one per element of ``tr_id_list``
        and in the same order
    :raises IndexError: If an ID from ``tr_id_list`` does not appear in ``df_cat``
    """
    # Index the first row of each ID once, instead of scanning the whole
    # frame for every requested ID.
    first_rows = df_cat.drop_duplicates('ID', keep='first').set_index('ID')
    lookup = first_rows.iloc[:, 0]

    cat_values = []
    for tr_id in tr_id_list:
        if tr_id not in lookup.index:
            raise IndexError("ID {!r} not found in df_cat".format(tr_id))
        cat_values.append(lookup.loc[tr_id])
    return cat_values

In [1]:
def get_greater_trajectory_duration(df_values, df_timestamps):
    """
    Return the largest follow-up duration over all trajectories.

    The duration of a trajectory is taken as its last timestamp, as returned
    by ``get_filtered_trajectory_values_timestamps``.

    :param df_values: Dataframe that contains values from trajectories
        (one trajectory per index label)
    :param df_timestamps: Dataframe that contains timestamps from trajectories
    :return: Largest duration, or ``-inf`` when ``df_values`` has no rows
    """
    longest = -np.inf
    for traj_id in df_values.index.values:
        _, traj_timestamps = get_filtered_trajectory_values_timestamps(
            traj_id, df_values, df_timestamps
        )
        # Only a strictly larger final timestamp replaces the current maximum.
        last_timestamp = traj_timestamps[-1]
        if last_timestamp > longest:
            longest = last_timestamp
    return longest

In [None]:
def _get_embedding(umap_object):
    """
    Return the embedding stored on ``umap_object``.

    ``embedding_`` is preferred; ``embedding`` is used as a fallback.

    :param umap_object: Object carrying an ``embedding_`` or ``embedding`` attribute
    :return: The value of the first attribute found
    :raises ValueError: If neither attribute exists
    """
    for attr_name in ("embedding_", "embedding"):
        if hasattr(umap_object, attr_name):
            return getattr(umap_object, attr_name)
    raise ValueError("Could not find embedding attribute of umap_object")

def _cmap_hex_palette(cmap_name, n_colors):
    """
    Sample ``n_colors`` evenly spaced colors from a matplotlib colormap.

    :param cmap_name: Colormap name (anything ``plt.get_cmap`` accepts)
    :param n_colors: Number of colors to sample
    :return: List of ``'#rrggbb'`` strings
    """
    return [
        matplotlib.colors.to_hex(c)
        for c in plt.get_cmap(cmap_name)(np.linspace(0, 1, n_colors))
    ]


def interactive(
    umap_object,
    labels=None,
    values=None,
    hover_data=None,
    theme=None,
    cmap="Blues",
    color_key=None,
    color_key_cmap="Spectral",
    background="white",
    width=800,
    height=800,
    point_size=None,
    subset_points=None,):
    """
    Build an interactive scatter plot of a 2-D embedding.

    Small inputs are drawn as a Bokeh figure (with optional tooltips). When
    there are more than ``width * height // 10`` points, the plot is
    rasterized with HoloViews/Datashader instead.

    :param umap_object: Object exposing ``embedding_`` or ``embedding``
    :param labels: Optional categorical label per point (exclusive with ``values``)
    :param values: Optional continuous value per point, colored with ``cmap``
    :param hover_data: Optional DataFrame whose columns are shown as tooltips
    :param theme: Optional key into ``_themes`` that overrides ``cmap``,
        ``color_key_cmap`` and ``background``
    :param cmap: Colormap name for ``values`` and for uncolored points
    :param color_key: Optional dict label -> color, or sequence of colors
        assigned to the unique labels in sorted order
    :param color_key_cmap: Colormap used to build ``color_key`` when not given
    :param background: Background color
    :param width: Plot width in pixels
    :param height: Plot height in pixels
    :param point_size: Marker size; defaults to ``100 / sqrt(n_points)``
    :param subset_points: Optional boolean mask, one entry per point,
        selecting the points to plot
    :return: A Bokeh figure, or a datashaded HoloViews element for large inputs
    :raises ValueError: If both ``labels`` and ``values`` are given, if
        ``subset_points`` has the wrong length, if the embedding is not 2-D,
        or if ``color_key`` has too few colors
    """
    if theme is not None:
        cmap = _themes[theme]["cmap"]
        color_key_cmap = _themes[theme]["color_key_cmap"]
        background = _themes[theme]["background"]

    if labels is not None and values is not None:
        raise ValueError(
            "Conflicting options; only one of labels or values should be set"
        )

    points = _get_embedding(umap_object)
    if subset_points is not None:
        if len(subset_points) != points.shape[0]:
            raise ValueError(
                "Size of subset points ({}) does not match number of input points ({})".format(
                    len(subset_points), points.shape[0]
                )
            )
        points = points[subset_points]

    if points.shape[1] != 2:
        raise ValueError("Plotting is currently only implemented for 2D embeddings")

    if point_size is None:
        point_size = 100.0 / np.sqrt(points.shape[0])

    # Built from the full embedding so labels/values (one per input point)
    # line up; the subset is applied further below.
    data = pd.DataFrame(_get_embedding(umap_object), columns=("x", "y"))

    if labels is not None:
        data["label"] = labels

        if color_key is None:
            unique_labels = np.unique(labels)
            color_key = _cmap_hex_palette(color_key_cmap, unique_labels.shape[0])

        if isinstance(color_key, dict):
            data["color"] = pd.Series(labels).map(color_key)
        else:
            unique_labels = np.unique(labels)
            if len(color_key) < unique_labels.shape[0]:
                raise ValueError(
                    "Color key must have enough colors for the number of labels"
                )

            new_color_key = {k: color_key[i] for i, k in enumerate(unique_labels)}
            data["color"] = pd.Series(labels).map(new_color_key)

        colors = "color"

    elif values is not None:
        data["value"] = values
        palette = _cmap_hex_palette(cmap, 256)
        colors = btr.linear_cmap(
            "value", palette, low=np.min(values), high=np.max(values)
        )

    else:
        colors = matplotlib.colors.rgb2hex(plt.get_cmap(cmap)(0.5))

    if subset_points is not None:
        # Copy so later column assignments act on an independent frame.
        data = data[subset_points].copy()
        if hover_data is not None:
            hover_data = hover_data[subset_points]

    if points.shape[0] <= width * height // 10:

        if hover_data is not None:
            tooltip_dict = {}
            for col_name in hover_data:
                data[col_name] = hover_data[col_name]
                tooltip_dict[col_name] = "@" + col_name
            tooltips = list(tooltip_dict.items())
        else:
            tooltips = None

        data["alpha"] = 1

        bpl.output_notebook(hide_banner=True)
        data_source = bpl.ColumnDataSource(data)

        plot = bpl.figure(
            width=width,
            height=height,
            tooltips=tooltips,
            background_fill_color=background,
        )
        plot.circle(
            x="x",
            y="y",
            source=data_source,
            color=colors,
            size=point_size,
            alpha="alpha",
        )

        plot.grid.visible = False
        plot.axis.visible = False

    else:
        if hover_data is not None:
            warn(
                "Too many points for hover data -- tooltips will not "
                "be displayed. Sorry; try subsampling your data."
            )
        hv.extension("bokeh")
        hv.output(size=300)
        hv.opts('RGB [bgcolor="{}", xaxis=None, yaxis=None]'.format(background))
        if labels is not None:
            # count_cat needs a categorical column. "color" already holds each
            # point's hex color, so every category simply maps to itself.
            data["color"] = pd.Categorical(data["color"])
            point_plot = hv.Points(data, kdims=["x", "y"], vdims=["color"])
            plot = hd.datashade(
                point_plot,
                aggregator=ds.count_cat("color"),
                color_key={c: c for c in data["color"].cat.categories},
                cmap=plt.get_cmap(cmap),
                width=width,
                height=height,
            )
        elif values is not None:
            # Bin the values into 256 equal-width categories (pd.cut also copes
            # with a constant column); category i is drawn with palette[i].
            data["val_cat"] = pd.cut(data["value"], bins=256, labels=np.arange(256))
            point_plot = hv.Points(data, kdims=["x", "y"], vdims=["val_cat"])
            plot = hd.datashade(
                point_plot,
                aggregator=ds.count_cat("val_cat"),
                color_key=palette,
                cmap=plt.get_cmap(cmap),
                width=width,
                height=height,
            )
        else:
            point_plot = hv.Points(data, kdims=["x", "y"])
            plot = hd.datashade(
                point_plot,
                aggregator=ds.count(),
                cmap=plt.get_cmap(cmap),
                width=width,
                height=height,
            )

    return plot