In [1]:
import pandas as pd
import plotly.graph_objects as go
import os
import seaborn as sns

# Root of all per-dataset outputs; the aggregated clustering labels live in a subfolder.
RESULT_DIR = '../results'
DATA_DIR = f'{RESULT_DIR}/aggregate'

# Datasets whose clustering results are turned into Sankey diagrams.
DATASET_TAGS = [
    'tabula-muris-heart',
    'tabula-muris-marrow_P7_3',
    'peripheral-blood',
    'zheng-4',
    'zheng-8',
]

In [2]:

def plot_sankey(labels, source, target, value, title, cluster_colors, out_file):
    """Draw a Sankey diagram, show it inline, and save it as PDF and HTML.

    Parameters
    ----------
    labels : list of str
        Node labels, indexed by the integers used in ``source`` / ``target``.
    source, target : list of int
        Start and end node index of each link.
    value : list of numbers
        Flow size of each link.
    title : str
        Figure title.
    cluster_colors : list of str
        Node colours. The list is concatenated with itself, so node ``k`` and
        node ``k + len(cluster_colors)`` are drawn in the same colour.
    out_file : str
        Output path without extension; ``.pdf`` and ``.html`` are appended.
    """
    node_spec = dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=labels,
        color=cluster_colors + cluster_colors,
    )
    link_spec = dict(source=source, target=target, value=value)

    fig = go.Figure(data=[go.Sankey(node=node_spec, link=link_spec)])
    fig.update_layout(title_text=title, font_size=10)
    fig.show()

    # NOTE(review): write_image needs plotly's static export engine (e.g. kaleido).
    for suffix, writer in (('.pdf', fig.write_image), ('.html', fig.write_html)):
        writer(out_file + suffix)
    

In [3]:
def _sankey_links(confusion):
    """Flatten a contingency table into Sankey link lists.

    Row position ``i`` becomes source node ``i`` and column position ``j``
    becomes target node ``n_rows + j``; zero cells produce no link.

    Returns
    -------
    (sources, targets, values) : tuple of three lists
    """
    n_rows, n_cols = confusion.shape
    sources, targets, values = [], [], []
    for i in range(n_rows):
        for j in range(n_cols):
            count = confusion.iat[i, j]
            if count != 0:
                sources.append(i)
                targets.append(n_rows + j)
                values.append(int(count))
    return sources, targets, values


for dataset in DATASET_TAGS:
    labels_path = DATA_DIR + '/' + dataset + '/labels.csv'
    if not os.path.exists(labels_path):
        print("WARNING: labels.csv not found for dataset " + dataset)
        continue

    labels = pd.read_csv(labels_path)
    if labels.columns[-1] != 'true_labels':
        print("ERROR: true_label not found for dataset " + dataset)
        continue

    # All true clusters in sorted order (the same order crosstab uses for columns).
    true_clusters = sorted(labels['true_labels'].dropna().unique())

    # GO term names of the true clusters, ordered by id.
    # NOTE(review): assumes the i-th smallest mapping id corresponds to the i-th
    # smallest value in 'true_labels' — confirm against the dataset preparation.
    mapping_path = "./dataset/" + dataset + "-filtered/mapping.csv"
    go_strings = pd.read_csv(mapping_path).sort_values(by='id')['go'].tolist()
    if len(go_strings) != len(true_clusters):
        print("ERROR: mapping.csv has %d GO terms but there are %d true clusters for dataset %s"
              % (len(go_strings), len(true_clusters), dataset))
        continue

    # The first column is skipped (presumably a cell id) and the last is the ground
    # truth; every column in between holds one clustering tool's labels.
    for tool_label in labels.columns[1:-1]:
        # Rows: this tool's predicted clusters; columns: true clusters.
        # No margins: an "All" total row/column would otherwise be drawn as a flow.
        confusion = pd.crosstab(labels[tool_label], labels['true_labels'],
                                rownames=['Predicted'], colnames=['True'])
        # Keep every true cluster as a column so target node j always matches go_strings[j],
        # even if a cluster has no (non-NaN) predictions from this tool.
        confusion = confusion.reindex(columns=true_clusters, fill_value=0)
        n_pred, n_true = confusion.shape

        cur_source, cur_target, cur_value = _sankey_links(confusion)
        # Source nodes are named after the tool's actual predicted cluster ids.
        plot_label = [str(cluster) for cluster in confusion.index] + go_strings

        # plot_sankey repeats the palette for the target side, so it needs at least
        # max(n_pred, n_true) colours; when the counts are equal, target j keeps
        # the same colour as source j.
        palette = sns.color_palette("husl", max(n_pred, n_true)).as_hex()

        # Strips a 6-character suffix from the column name (presumably '_label') — TODO confirm.
        tool = tool_label[:-6]
        out_dir = RESULT_DIR + '/' + dataset + '/' + tool
        os.makedirs(out_dir, exist_ok=True)
        plot_sankey(plot_label,
                    cur_source,
                    cur_target,
                    cur_value,
                    dataset.replace('-', ' ').title() + ' ' + tool,
                    palette,
                    out_dir + '/sankey')
      