In [7]:
import pandas as pd
import plotly.graph_objects as go
import os
import seaborn as sns

RESULT_DIR = './results'
DATA_DIR = RESULT_DIR + '/aggregate'
DATASET_TAGS = ['tabula-muris-heart', 'tabula-muris-marrow_P7_3', 'peripheal-blood', 'zheng-4', 'zheng-8']

In [8]:

def plot_sankey(labels, source, target, value, title, cluster_colors, out_file):
    fig = go.Figure(data=[go.Sankey(
        node = dict(
        pad = 15,
        thickness = 20,
        line = dict(color = "black", width = 0.5),
        label = labels,
        color = cluster_colors + cluster_colors
        ),
        link = dict(
        source = source,
        target = target,
        value = value
        ))])

    fig.update_layout(title_text=title, font_size=10)
    fig.show()
    fig.write_image(out_file + '.pdf')
    fig.write_html(out_file + '.html')
    

In [9]:
for dataset in DATASET_TAGS:
    cur_path = DATA_DIR + '/' + dataset + '/'
    # check if labels.csv exists
    if os.path.exists(cur_path + "labels.csv"):
        labels = pd.read_csv(cur_path + "labels.csv")

        if labels.columns[-1] != 'true_labels':
            print("ERROR: true_label not found for dataset " + dataset)
        else:
            tool_tags = labels.columns[1:-1]
            cluster_num = labels['true_labels'].nunique()
            plot_label = []
            for i in range(cluster_num):
                plot_label.append(str(i))

            mapping_path = "./dataset/" + dataset + "-filtered/mapping.csv"
            mapping_df = pd.read_csv(mapping_path)
            mapping_df.sort_values(by=['id'], inplace=True)
            go_strings = mapping_df['go'].tolist()
            plot_label = plot_label + go_strings

            for tool_label in tool_tags:
                # generate confusion matrix between labels
                confusion_matrix = pd.crosstab(labels[tool_label], labels['true_labels'], colnames=['Predicted'], rownames=['True'], margins=True)
                cur_source = []
                cur_target = []
                cur_value = []
                tool = tool_label[0:-6]
                for i in range(cluster_num):
                    for j in range(cluster_num):
                        if confusion_matrix.iloc[i, j] != 0:
                            cur_source.append(i)
                            cur_target.append(cluster_num + j)
                            cur_value.append(confusion_matrix.iloc[i, j])
                out_dir = RESULT_DIR + '/' + dataset + '/' + tool
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
                out_file = out_dir + '/sankey'
                plot_sankey(plot_label, 
                            cur_source, 
                            cur_target, 
                            cur_value, 
                            dataset.replace('-', ' ').title() + ' ' + tool,
                            # sns.color_palette(None, cluster_num).as_hex(), 
                            sns.color_palette("husl", 9),
                            out_file)
      

ValueError: 
    Invalid element(s) received for the 'color' property of sankey.node
        Invalid elements include: [0.9677975592919913, 0.44127456009157356, 0.5358103155058701, 0.8369430560927636, 0.5495828952802333, 0.1952683223448124, 0.6430915736746491, 0.6271955086583126, 0.19381135329796756, 0.3126890019504329]

    The 'color' property is a color and may be specified as:
      - A hex string (e.g. '#ff0000')
      - An rgb/rgba string (e.g. 'rgb(255,0,0)')
      - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
      - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
      - A named CSS color:
            aliceblue, antiquewhite, aqua, aquamarine, azure,
            beige, bisque, black, blanchedalmond, blue,
            blueviolet, brown, burlywood, cadetblue,
            chartreuse, chocolate, coral, cornflowerblue,
            cornsilk, crimson, cyan, darkblue, darkcyan,
            darkgoldenrod, darkgray, darkgrey, darkgreen,
            darkkhaki, darkmagenta, darkolivegreen, darkorange,
            darkorchid, darkred, darksalmon, darkseagreen,
            darkslateblue, darkslategray, darkslategrey,
            darkturquoise, darkviolet, deeppink, deepskyblue,
            dimgray, dimgrey, dodgerblue, firebrick,
            floralwhite, forestgreen, fuchsia, gainsboro,
            ghostwhite, gold, goldenrod, gray, grey, green,
            greenyellow, honeydew, hotpink, indianred, indigo,
            ivory, khaki, lavender, lavenderblush, lawngreen,
            lemonchiffon, lightblue, lightcoral, lightcyan,
            lightgoldenrodyellow, lightgray, lightgrey,
            lightgreen, lightpink, lightsalmon, lightseagreen,
            lightskyblue, lightslategray, lightslategrey,
            lightsteelblue, lightyellow, lime, limegreen,
            linen, magenta, maroon, mediumaquamarine,
            mediumblue, mediumorchid, mediumpurple,
            mediumseagreen, mediumslateblue, mediumspringgreen,
            mediumturquoise, mediumvioletred, midnightblue,
            mintcream, mistyrose, moccasin, navajowhite, navy,
            oldlace, olive, olivedrab, orange, orangered,
            orchid, palegoldenrod, palegreen, paleturquoise,
            palevioletred, papayawhip, peachpuff, peru, pink,
            plum, powderblue, purple, red, rosybrown,
            royalblue, rebeccapurple, saddlebrown, salmon,
            sandybrown, seagreen, seashell, sienna, silver,
            skyblue, slateblue, slategray, slategrey, snow,
            springgreen, steelblue, tan, teal, thistle, tomato,
            turquoise, violet, wheat, white, whitesmoke,
            yellow, yellowgreen
      - A list or array of any of the above