In [3]:
import math

import matplotlib.colors as colors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

In [None]:
# Use a 16 pt Times New Roman serif font for every figure drawn after this cell.
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman"],
    "font.size": 16,
})

In [None]:
def plot_TDI_rw(TDI, filename, variables, seq, th):
    """Show a TDI matrix as a log-scaled heatmap and outline entries above a threshold.

    Parameters
    ----------
    TDI : 2-D array
        Square matrix; rows are effects, columns are causes.
    filename : str
        Output stem; the figure is written to ``./results/<filename>.png``.
    variables : array
        Variable names; ``variables[seq]`` supplies one tick label per row/column.
    seq : index array
        Selects / orders the tick labels.
    th : float
        Cells with ``TDI > th`` get a dashed orange outline.
    """
    size = TDI.shape[0]

    fig, ax = plt.subplots()
    image = ax.imshow(TDI, cmap='Blues', norm=colors.LogNorm())
    tick_labels = variables[seq]
    ax.set_xticks(ticks=range(size), labels=tick_labels)
    ax.set_yticks(ticks=range(size), labels=tick_labels)
    ax.set_xlabel('Cause')
    ax.set_ylabel('Effect')
    plt.colorbar(image, ax=ax, shrink=0.9)

    # imshow centres cell (row, col) at x=col, y=row, hence the -0.5 offsets.
    # np.argwhere yields hits in row-major order.
    for row, col in np.argwhere(TDI[:, :size] > th):
        outline = patches.Rectangle((col - 0.5, row - 0.5), 1, 1, linewidth=3,
                                    edgecolor='orange', facecolor='none',
                                    linestyle='--')
        ax.add_patch(outline)

    plt.savefig(f"./results/{filename}.png", bbox_inches = 'tight')
    plt.show()

In [None]:
def plot_TDI_l(A, TDI, strength, variables, seq):
    """Draw the three-panel TDI summary (violin, TDI heatmap, ground truth).

    The panels are filled by ``plot_TDI_onax``. The figure is saved to
    ``./results/TDI-N=<n>.png`` with ``n = A.shape[0]`` and then shown.
    """
    node_count = A.shape[0]

    fig = plt.figure(constrained_layout=True, figsize=(13, 4))
    panels = fig.subplots(1, 3)
    plot_TDI_onax(A, TDI, panels, fig, strength, variables, seq)

    plt.savefig(f"./results/TDI-N={node_count}.png")
    plt.show()
    
def plot_TDI_violin(A, TDIs, ax, fig, strength, show_label = True, show_hline = True):
    """Violin plot of off-diagonal TDI values grouped by ground-truth weight.

    Parameters
    ----------
    A : 2-D array
        Ground-truth connection matrix; its entries are matched against ``strength``.
    TDIs : 3-D array
        Stack of TDI matrices indexed ``TDIs[k, row, col]``.
    ax : matplotlib Axes
        Axes to draw on.
    fig : matplotlib Figure
        Unused; kept for call-site compatibility.
    strength : sequence of numbers
        One violin per value, in the given order. Values ``> 0`` count as
        "with link", all others as "without link".
    show_label : bool
        Draw the x/y axis labels.
    show_hline : bool
        Draw dashed red lines at the largest "without link" TDI and the
        smallest "with link" TDI (only for groups that exist and are non-empty).

    Returns
    -------
    (pos, alldata)
        The object returned by ``ax.violinplot`` and the list of 1-D arrays
        that were plotted, one per ``strength`` value. Each array holds, for
        every off-diagonal edge of that weight (row-major order), its TDI in
        every matrix of ``TDIs``.
    """
    A = np.asarray(A)
    TDIs = np.asarray(TDIs)
    labels = [f'{s}' for s in strength]

    alldata = []
    # Infinite sentinels: a threshold stays non-finite when its group is absent.
    # It is then neither printed as a fake value nor drawn on the log axis.
    max_wo_link = -np.inf
    min_with_link = np.inf
    for s in strength:
        rows, cols = np.where(A == s)
        keep = rows != cols  # ignore self-links on the diagonal
        # Shape (K, n_edges); transpose so values are grouped edge by edge.
        TDI_s = TDIs[:, rows[keep], cols[keep]]
        alldata.append(TDI_s.T.flatten())
        if alldata[-1].size == 0:
            continue
        if s > 0:
            min_with_link = min(min_with_link, np.min(alldata[-1]))
        else:
            max_wo_link = max(max_wo_link, np.max(alldata[-1]))

    print(min_with_link, max_wo_link)

    pos = ax.violinplot(alldata, points=100, widths=0.7, showmeans=True,
                        showextrema=True, showmedians=False, bw_method=0.5)
    ax.set_xticks([x + 1 for x in range(len(strength))], labels=labels, fontsize='x-large')

    if show_label:
        ax.set_xlabel('Connection Weight', fontsize='x-large')
        ax.set_ylabel('TDI', fontsize='x-large')

    ax.set_yscale('log')
    ax.tick_params(labelsize='x-large')

    if show_hline:
        for y in (max_wo_link, min_with_link):
            if np.isfinite(y):
                ax.axhline(y=y, color='r', linestyle='--')

    return pos, alldata

def plot_TDI_onax(A, TDI, axs, fig, strength, variables, seq):
    """Fill three existing axes with a TDI violin plot, a TDI heatmap and the ground truth.

    Parameters
    ----------
    A : 2-D array
        Ground-truth connection matrix, drawn with a 3-bin BoundaryNorm.
    TDI : 2-D array
        TDI matrix, drawn on a log colour scale (panel title 'IRC').
    axs : sequence of three Axes
        Targets in the order (violin, TDI heatmap, ground truth).
    fig : Figure
        Owner of ``axs``; used to attach colorbars.
    strength : sequence of numbers
        Connection weights forwarded to ``plot_TDI_violin``.
    variables, seq
        Tick labels ``variables[seq]`` placed at positions ``0..len(seq)-1``.
    """
    cmap = plt.cm.Blues
    ax_violin, ax_tdi, ax_truth = axs
    ticks = np.arange(len(seq))

    def _decorate(ax, title):
        # Shared tick labels and axis names of the two matrix panels.
        ax.set_xticks(ticks=ticks, labels=variables[seq])
        ax.set_yticks(ticks=ticks, labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(title)

    plot_TDI_violin(A, np.array([TDI]), ax_violin, fig, strength)

    im_tdi = ax_tdi.imshow(TDI, cmap=cmap, norm=colors.LogNorm())
    fig.colorbar(im_tdi, ax=ax_tdi, shrink=0.9)
    _decorate(ax_tdi, 'IRC')

    im_truth = ax_truth.imshow(A, cmap=cmap, norm=colors.BoundaryNorm([0, 0.33, 0.67, 1], ncolors=cmap.N))
    fig.colorbar(im_truth, extend='neither', shrink=0.9, ax=ax_truth, ticks=[0,0.5,1])
    _decorate(ax_truth, 'Ground Truth')

def plot_multi_TDI_onax(A, TDIs, axs, fig, titles, variables, seq):
    """Draw one log-scaled heatmap per TDI matrix, then the ground truth on the last axes.

    NOTE(review): this function is defined twice in this notebook with the
    same body. The later definition shadows the earlier one, so keep both
    copies in sync or delete one.

    Parameters
    ----------
    A : 2-D array
        Ground-truth connection matrix, drawn on ``axs[-1]`` with a 3-bin BoundaryNorm.
    TDIs : sequence of 2-D arrays
        ``TDIs[i]`` is drawn on ``axs[i]`` with title ``titles[i]``.
    axs : sequence of Axes
        Must hold at least ``len(TDIs) + 1`` axes; the last one receives ``A``.
    fig : Figure
        Owner of ``axs``; used to attach colorbars.
    titles : sequence of str
        One title per entry of ``TDIs``.
    variables, seq
        Tick labels ``variables[seq]`` placed at positions ``0..len(seq)-1``.

    Returns
    -------
    list
        The images returned by ``imshow``: one per TDI, then the ground truth.

    Raises
    ------
    ValueError
        If ``axs`` has fewer than ``len(TDIs) + 1`` axes. Otherwise the ground
        truth would be drawn over the last TDI panel.
    """
    if len(axs) < len(TDIs) + 1:
        raise ValueError(
            f"need at least {len(TDIs) + 1} axes for {len(TDIs)} TDI panels "
            f"plus the ground truth, got {len(axs)}")

    cmap = plt.cm.Blues
    ticks = np.arange(len(seq))

    def _decorate(ax, title):
        # Shared tick labels and axis names of every matrix panel.
        ax.set_xticks(ticks=ticks, labels=variables[seq])
        ax.set_yticks(ticks=ticks, labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(title)

    poss = []
    for i, TDI in enumerate(TDIs):
        ax = axs[i]
        pos = ax.imshow(TDI, cmap=cmap, norm=colors.LogNorm())
        fig.colorbar(pos, ax=ax, shrink=0.9)
        _decorate(ax, titles[i])
        poss.append(pos)

    ax = axs[-1]
    pos = ax.imshow(A, cmap=cmap, norm=colors.BoundaryNorm([0, 0.33, 0.67, 1], ncolors=cmap.N))
    fig.colorbar(pos, extend='neither', shrink=0.9, ax=ax, ticks=[0,0.5,1])
    _decorate(ax, 'Ground Truth')
    poss.append(pos)

    return poss

In [None]:
def plot_im_only(TDI, variables, colors = None, textsize = 'xx-large', label = None):
    """Plain (linear colour scale) heatmap of a TDI matrix, optionally annotated.

    Parameters
    ----------
    TDI : 2-D array
        Matrix to draw.
    variables : sequence
        Tick labels, one per row/column, placed at ``0..len(variables)-1``.
    colors : 2-D nested sequence of colour specs, optional
        If given and non-empty, every off-diagonal cell ``(i, j)`` is annotated
        with ``TDI[i, j]`` in scientific notation, using colour ``colors[i][j]``.
        Inside this function the parameter shadows the ``matplotlib.colors`` module.
    textsize : str or float
        Font size of the annotations.
    label : (xlabel, ylabel), optional
        Axis labels.

    The figure is saved to ``./results/TDI_im_only.png`` and shown.
    """
    fig, ax = plt.subplots()
    ax.imshow(TDI, cmap='Blues')

    ticks = np.arange(len(variables))
    ax.set_xticks(ticks=ticks, labels=variables, fontsize='x-large')
    ax.set_yticks(ticks=ticks, labels=variables, fontsize='x-large')

    if label is not None:
        ax.set_xlabel(label[0], fontsize='x-large')
        ax.set_ylabel(label[1], fontsize='x-large')

    # Use an explicit None/len check: a NumPy array of colours has an
    # ambiguous truth value and would raise under a plain `if colors:`.
    if colors is not None and len(colors) > 0:
        for i in range(TDI.shape[0]):
            for j in range(TDI.shape[1]):
                if i == j:
                    continue
                ax.text(j, i, f'{TDI[i, j] :.2e}',
                        ha="center", va="center", color=colors[i][j], fontsize=textsize)

    # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+).
    plt.title(r'$\mathbf{TDI}$', fontsize='xx-large')
    plt.savefig("./results/TDI_im_only.png", bbox_inches = 'tight')
    plt.show()

In [None]:
def plot_TCL_2x5(A_x, TDI_x, strength):
    """Draw per-repetition TDI violin plots on a 2x5 grid and pool their data.

    Parameters
    ----------
    A_x : sequence of 2-D arrays
        Ground-truth matrices, one per repetition. At most 10 fit the grid.
    TDI_x : sequence of 2-D arrays
        TDI matrices, paired with ``A_x`` by index.
    strength : sequence of numbers
        Connection weights forwarded to ``plot_TDI_violin``.

    Returns
    -------
    list of 1-D arrays
        For each ``strength`` value, the TDI values of all repetitions concatenated.

    The figure is saved to ``./results/rep N=<n>.png`` with ``n = len(A_x[0])``.
    """
    fig = plt.figure(constrained_layout=True, figsize=(20, 8))
    axs = fig.subplots(2, 5)

    alldata_x = None
    for rep in range(len(A_x)):
        A = A_x[rep]
        TDI = TDI_x[rep]
        ax = axs[divmod(rep, 5)]  # fill row by row: (0,0)..(0,4), then (1,0)..(1,4)
        _, alldata = plot_TDI_violin(A, np.array([TDI]), ax, fig, strength, show_label = False)

        if alldata_x is None:
            alldata_x = list(alldata)
        else:
            # Pool this repetition's values into the running per-weight arrays.
            alldata_x = [np.append(pooled, new) for pooled, new in zip(alldata_x, alldata)]

    fig.supxlabel('Connection Strength')
    # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+).
    fig.supylabel(r'$\mathbf{TDI}$')

    n_nodes = len(A_x[0])
    plt.suptitle(f"$N={n_nodes}$")
    plt.savefig(f"./results/rep N={n_nodes}.png", bbox_inches = 'tight')
    plt.show()

    return alldata_x

def plot_TDI_rep(A_all, TDI_all, strength):
    """Per-size repetition grids plus one pooled violin panel per network size.

    For every index ``k`` this calls ``plot_TCL_2x5(A_all[k], TDI_all[k], strength)``,
    which saves its own figure. It then draws one summary figure with a panel
    per ``k``, where the data of all repetitions are pooled.

    The dashed red lines mark the largest TDI of the last ``strength`` group and
    the smallest TDI over all earlier groups. The last entry of ``strength`` is
    therefore treated as the "no link" weight and the others as "with link".

    The summary figure is saved to ``./results/rep merged.png``.
    """
    labels = [f'{s}' for s in strength]
    alldata_all = [plot_TCL_2x5(A_all[k], TDI_all[k], strength) for k in range(len(A_all))]

    # One panel per network size (previously hard-coded to 3). squeeze=False
    # keeps `axs` indexable for a single size. The width is 13 for 3 panels, as before.
    n_panels = max(len(A_all), 1)
    fig = plt.figure(constrained_layout=True, figsize=(13 * n_panels / 3, 4))
    axs = fig.subplots(1, n_panels, squeeze=False)[0]

    for k, pooled in enumerate(alldata_all):
        ax = axs[k]
        for s, data in zip(strength, pooled):
            print(f"strength {s} data total count: {data.shape[0]}")
        max_wo_link = np.max(pooled[-1])
        min_with_link = min([np.min(data) for data in pooled[:-1]])
        ax.violinplot(pooled, points=100, widths=0.7, showmeans=True,
                      showextrema=True, showmedians=False, bw_method=0.5)
        ax.set_xticks([x + 1 for x in range(len(strength))], labels=labels)

        ax.set_yscale('log')
        ax.axhline(y=max_wo_link, color='r', linestyle='--')
        ax.axhline(y=min_with_link, color='r', linestyle='--')
        ax.set_title(f"$N={len(A_all[k][0])}$")

    fig.supxlabel('Connection Strength')
    fig.supylabel('TDI')

    plt.savefig("./results/rep merged.png", bbox_inches = 'tight')
    plt.show()

In [None]:
def plot_TDI_l_multi(A, TDI1, TDI2, title1, title2, strength, variables, seq):
    """Compare two TDI estimates side by side with the ground truth.

    The panels, from left to right, are:

    - a violin plot of both TDI matrices pooled,
    - a log-scaled heatmap of ``TDI1`` (titled ``title1``),
    - a log-scaled heatmap of ``TDI2`` (titled ``title2``),
    - the ground-truth matrix ``A`` with a 3-bin BoundaryNorm.

    The figure is saved to ``./results/N=<n> multi compair.png`` with ``n = A.shape[0]``.
    """
    node_count = A.shape[0]

    fig = plt.figure(constrained_layout=True, figsize=(22, 5))
    ax_violin, ax_first, ax_second, ax_truth = fig.subplots(1, 4)

    cmap = plt.cm.Blues
    ticks = np.arange(len(seq))

    def decorate(ax, title):
        ax.set_xticks(ticks=ticks, labels=variables[seq])
        ax.set_yticks(ticks=ticks, labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(title)

    plot_TDI_violin(A, np.array([TDI1, TDI2]), ax_violin, fig, strength)

    for ax, matrix, title in ((ax_first, TDI1, title1), (ax_second, TDI2, title2)):
        image = ax.imshow(matrix, cmap=cmap, norm=colors.LogNorm())
        fig.colorbar(image, ax=ax, shrink=0.9)
        decorate(ax, title)

    truth_image = ax_truth.imshow(A, cmap=cmap, norm=colors.BoundaryNorm([0, 0.33, 0.67, 1], ncolors=cmap.N))
    fig.colorbar(truth_image, extend='neither', shrink=0.9, ax=ax_truth, ticks=[0,0.5,1])
    decorate(ax_truth, 'Ground Truth')

    plt.savefig(f"./results/N={node_count} multi compair.png", bbox_inches = 'tight')
    plt.show()
    
def plot_multi_TDI_onax(A, TDIs, axs, fig, titles, variables, seq):
    """Draw one log-scaled heatmap per TDI matrix, then the ground truth on the last axes.

    NOTE(review): this function is defined twice in this notebook with the
    same body. The later definition shadows the earlier one, so keep both
    copies in sync or delete one.

    Parameters
    ----------
    A : 2-D array
        Ground-truth connection matrix, drawn on ``axs[-1]`` with a 3-bin BoundaryNorm.
    TDIs : sequence of 2-D arrays
        ``TDIs[i]`` is drawn on ``axs[i]`` with title ``titles[i]``.
    axs : sequence of Axes
        Must hold at least ``len(TDIs) + 1`` axes; the last one receives ``A``.
    fig : Figure
        Owner of ``axs``; used to attach colorbars.
    titles : sequence of str
        One title per entry of ``TDIs``.
    variables, seq
        Tick labels ``variables[seq]`` placed at positions ``0..len(seq)-1``.

    Returns
    -------
    list
        The images returned by ``imshow``: one per TDI, then the ground truth.

    Raises
    ------
    ValueError
        If ``axs`` has fewer than ``len(TDIs) + 1`` axes. Otherwise the ground
        truth would be drawn over the last TDI panel.
    """
    if len(axs) < len(TDIs) + 1:
        raise ValueError(
            f"need at least {len(TDIs) + 1} axes for {len(TDIs)} TDI panels "
            f"plus the ground truth, got {len(axs)}")

    cmap = plt.cm.Blues
    ticks = np.arange(len(seq))

    def _decorate(ax, title):
        # Shared tick labels and axis names of every matrix panel.
        ax.set_xticks(ticks=ticks, labels=variables[seq])
        ax.set_yticks(ticks=ticks, labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(title)

    poss = []
    for i, TDI in enumerate(TDIs):
        ax = axs[i]
        pos = ax.imshow(TDI, cmap=cmap, norm=colors.LogNorm())
        fig.colorbar(pos, ax=ax, shrink=0.9)
        _decorate(ax, titles[i])
        poss.append(pos)

    ax = axs[-1]
    pos = ax.imshow(A, cmap=cmap, norm=colors.BoundaryNorm([0, 0.33, 0.67, 1], ncolors=cmap.N))
    fig.colorbar(pos, extend='neither', shrink=0.9, ax=ax, ticks=[0,0.5,1])
    _decorate(ax, 'Ground Truth')
    poss.append(pos)

    return poss

In [1]:
def plot_TDI_ROC_multi(Arrays, labels, strength, N_xs, y_lim = None):
    """One violin + ROC-inset panel per network size, saved to ``./results/ROC.png``.

    Parameters
    ----------
    Arrays : sequence of sequences
        ``Arrays[0][i]`` is the ground-truth matrix for network size ``N_xs[i]``.
        ``Arrays[m][i]`` for ``m >= 1`` are the score matrices compared in that
        panel's ROC inset; ``Arrays[1][i]`` is also the TDI shown as violins.
        Previously exactly ``Arrays[1]`` .. ``Arrays[4]`` were used; now every
        entry after the first is used.
    labels : sequence of str
        Legend label of each score matrix (``labels[m - 1]`` for ``Arrays[m]``).
    strength : sequence of numbers
        Connection weights forwarded to the violin plot.
    N_xs : sequence of int
        Network sizes, one panel each.
    y_lim : (low, high), optional
        Common y limits for every panel.
    """
    n_panels = len(N_xs)
    fig = plt.figure(constrained_layout=True, figsize=(5*n_panels+1, 5))
    # squeeze=False keeps `axs` indexable even for a single network size.
    axs = fig.subplots(1, n_panels, squeeze=False)[0]
    for i, N_x in enumerate(N_xs):
        variables = np.arange(1, N_x+1, 1)
        # At most ~10 tick labels per axis.
        seq = np.arange(0, N_x, math.ceil(N_x / 10))

        scores = [method_arrays[i] for method_arrays in Arrays[1:]]
        plot_TDI_ROC(Arrays[0][i], scores, labels, strength, variables, seq, axs[i], fig)
        if y_lim is not None:
            axs[i].set_ylim(y_lim[0], y_lim[1])

    fig.supxlabel('Connection Strength')
    fig.supylabel('TDI')

    plt.savefig("./results/ROC.png", bbox_inches = 'tight')

    plt.show()

def _offdiag_values(M, n):
    """Off-diagonal entries of the leading n x n block of M, in row-major order."""
    block = np.asarray(M)[:n, :n]
    return block[~np.eye(n, dtype=bool)]


def _roc_points(scores, truth, pos, neg):
    """ROC points from accepting the k highest scores, for k = 0..len(scores).

    ``truth > 0`` marks a positive and ``truth == 0`` a negative. ``pos``/``neg``
    are their counts. Ties keep ``np.argsort`` order. An empty class gives
    NaN/inf rates (division by zero), as before.
    Returns ``(fpr, tpr)``, each of length ``len(scores) + 1``.
    """
    ranked = truth[np.argsort(scores)[::-1]]  # largest score first
    # Cumulative counts replace the former O(n^2) prefix re-summation.
    tpr = np.concatenate(([0], np.cumsum(ranked > 0))) / pos
    fpr = np.concatenate(([0], np.cumsum(ranked == 0))) / neg
    return fpr, tpr


def _auroc(fpr, tpr):
    """Area under the ROC step curve (left Riemann sum, closed at fpr = 1)."""
    widths = np.append(np.diff(fpr), 1 - fpr[-1])
    return np.sum(widths * tpr)


def plot_TDI_ROC(A_gc, Arrays, labels, strength, variables, seq, ax_main, fig):
    """Violin plot of ``Arrays[0]`` on ``ax_main`` with an inset ROC comparison.

    Parameters
    ----------
    A_gc : 2-D array
        Ground truth. Off-diagonal entries ``> 0`` are positives, ``== 0`` negatives.
    Arrays : sequence of 2-D arrays
        Score matrices. Each is ranked against ``A_gc`` using off-diagonal
        entries only. ``Arrays[0]`` is also the TDI shown in the violin plot
        and is drawn with a thicker ROC line.
    labels : sequence of str
        Legend label of each entry of ``Arrays``; the AUROC is appended.
    strength : sequence of numbers
        Forwarded to ``plot_TDI_violin``.
    variables, seq
        Unused; kept for call-site compatibility.
    ax_main : Axes
        Host axes. The ROC curves go into an inset in its lower-left corner.
    fig : Figure
        Forwarded to ``plot_TDI_violin``.
    """
    N = A_gc.shape[0]
    TDI = Arrays[0]

    truth = _offdiag_values(A_gc, N)
    pos = np.sum(truth > 0)
    neg = np.sum(truth == 0)
    print(f'pos:{pos}, neg:{neg}')

    curves = [_roc_points(_offdiag_values(M, N), truth, pos, neg) for M in Arrays]

    ax_main.set_title(f'$N$ = {N}', fontsize='20')
    plot_TDI_violin(A_gc, np.array([TDI]), ax_main, fig, strength, show_label= False, show_hline = False)

    axroc = ax_main.inset_axes([.03, .03, .6, .5])
    axroc.get_xaxis().set_visible(False)
    axroc.get_yaxis().set_visible(False)

    for a, (fpr, tpr) in enumerate(curves):
        auroc = _auroc(fpr, tpr)
        extra = {'linewidth': 2.5} if a == 0 else {}
        axroc.plot(fpr, tpr, '--', label=f'{labels[a]} ({auroc:.2f})', **extra)

    # Chance-level diagonal from (0, 0) to (1, 1).
    axroc.plot([0, 1], [0, 1], '--', color='grey')

    axroc.set_xlim([-0.02,1.02])
    axroc.set_ylim([-0.02,1.02])

    axroc.legend(loc='lower right', fontsize='15')
    
def plot_array_2x2(Arrays, titles, variables, seq):
    """Show up to four matrices, each divided by its maximum, on a 2x2 grid.

    Panels are filled column by column: top-left, bottom-left, top-right,
    bottom-right. Tick labels are ``variables[seq]`` at ``0..len(seq)-1``.
    The figure is saved to ``./results/ROC-2x2-N=<n>.png`` with
    ``n = Arrays[0].shape[0]``.
    """
    fig = plt.figure(constrained_layout=True, figsize=(10, 8))
    grid = fig.subplots(2, 2)
    panel_order = grid.T.ravel()  # (0,0), (1,0), (0,1), (1,1)

    for k, matrix in enumerate(Arrays):
        ax = panel_order[k]
        heading = titles[k]

        image = ax.imshow(matrix / np.max(matrix), cmap='Blues')
        fig.colorbar(image, ax=ax, shrink=0.9)
        positions = np.arange(len(seq))
        ax.set_xticks(ticks=positions, labels=variables[seq])
        ax.set_yticks(ticks=positions, labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(heading)

    plt.savefig(f"./results/ROC-2x2-N={Arrays[0].shape[0]}.png", bbox_inches = 'tight')

    plt.show()
    
def plot_array_1x5(Arrays, titles, variables, seq):
    """Ground truth plus up to four max-normalised reconstructions in one row.

    ``Arrays[0]`` is drawn with a 3-bin BoundaryNorm ([0, 0.33, 0.67, 1]).
    Every other matrix is divided by its maximum. Ticks are placed at the
    indices ``seq`` and labelled ``variables[seq]``. The figure is saved to
    ``./results/ROC-1x5-N=<n>.png`` with ``n = Arrays[0].shape[0]``.
    """
    cmap = plt.cm.Blues

    fig = plt.figure(constrained_layout=True, figsize=(24, 4))
    axs = fig.subplots(1, 5)
    for i in range(len(Arrays)):
        A = Arrays[i]
        ax = axs[i]
        title = titles[i]
        if i == 0:
            # The norm must be given to imshow. Figure.colorbar takes the norm
            # from the mappable, so a `norm=` passed to colorbar has no effect.
            pos = ax.imshow(A, cmap=cmap, norm=colors.BoundaryNorm([0, 0.33, 0.67, 1], ncolors=cmap.N))
            fig.colorbar(pos, ax=ax, shrink=0.9, extend='neither', ticks=[0,0.5,1])
        else:
            pos = ax.imshow(A / np.max(A), cmap=cmap)
            fig.colorbar(pos, ax=ax, shrink=0.9)

        ax.set_xticks(ticks=seq,labels=variables[seq])
        ax.set_yticks(ticks=seq,labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(title)

    pathname = f"./results/ROC-1x5-N={Arrays[0].shape[0]}.png"
    plt.savefig(pathname, bbox_inches = 'tight')

    plt.show()

In [None]:
def plot_missing_info(A_gc, TDI, TDI_mi, miss_from, miss_to, variables, seq):
    """Ground truth vs. full reconstruction vs. reconstruction with missing data.

    The three panels are:

    - ``A_gc``, with a 3-bin BoundaryNorm,
    - ``TDI``, divided by its own maximum,
    - ``TDI_mi``, divided by its own maximum.

    Every panel also gets dashed outlines:

    - grey on cells ``(i, j)`` with both ``i`` and ``j`` in ``miss_from`` (``i != j``),
    - blue on cells with ``i`` in ``miss_from`` and ``j`` in ``miss_to`` (``i != j``).

    The figure is saved to ``./results/missinginfo.png``.
    """
    cmap = plt.cm.Blues
    ticks = np.arange(len(seq))

    fig = plt.figure(constrained_layout=True, figsize=(14, 4))

    (ax1,ax2,ax3) = fig.subplots(1, 3)

    def _decorate(ax, title):
        ax.set_xticks(ticks=ticks, labels=variables[seq])
        ax.set_yticks(ticks=ticks, labels=variables[seq])
        ax.set_xlabel('Cause')
        ax.set_ylabel('Effect')
        ax.set_title(title)

    pos1 = ax1.imshow(A_gc, cmap=cmap, norm=colors.BoundaryNorm([0, 0.33, 0.67, 1], ncolors=cmap.N))
    fig.colorbar(pos1, extend='neither', shrink=0.9, ax=ax1, ticks=[0,0.5,1])
    _decorate(ax1, 'Ground Truth')

    pos2 = ax2.imshow(TDI / np.max(TDI), cmap=cmap)
    fig.colorbar(pos2, ax=ax2)
    _decorate(ax2, 'Reconstruction')

    # Private copy, so set_over does not alter the shared "Blues" colormap.
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    # NOTE(review): the data is normalised to max 1 and the colour limits are
    # autoscaled, so the 'over' colour may never be used; confirm intent.
    cmap_over = plt.cm.Blues.copy()
    cmap_over.set_over('black')

    pos3 = ax3.imshow(TDI_mi / np.max(TDI_mi), cmap=cmap_over)
    fig.colorbar(pos3, ax=ax3)
    _decorate(ax3, 'Reconstruction without Node 10')

    def _outline(ax, i, j, edgecolor):
        ax.add_patch(patches.Rectangle((j-0.5, i-0.5), 1, 1, linewidth=3,
                                       edgecolor=edgecolor, facecolor='none',
                                       linestyle='--'))

    for ax in [ax1,ax2,ax3]:
        for i in miss_from:
            for j in miss_from:
                if i != j:
                    _outline(ax, i, j, 'grey')
        for i in miss_from:
            for j in miss_to:
                if i != j:
                    _outline(ax, i, j, 'tab:blue')

    plt.savefig("./results/missinginfo.png", bbox_inches = 'tight')
    plt.show()

In [1]:
def plot_various_N_evo(various_list, various_TDI, A, strength_list, variables, 
                       title = None, yscale_log = False, xscale_log = False, legend_loc = 'upper left'):
    """Show how TDI values evolve over a swept parameter, split by link type.

    Parameters
    ----------
    various_list : sequence of numbers
        Swept parameter values (x axis of the right panel).
    various_TDI : sequence of 2-D arrays
        ``various_TDI[k]`` is the TDI matrix obtained for ``various_list[k]``.
    A : 2-D array
        Ground-truth matrix, drawn on the left with a 2-bin BoundaryNorm.
    strength_list : sequence of numbers
        Weights used to group the off-diagonal entries of ``A``. Values ``> 0``
        are labelled "with causation", others "without causation".
    variables : sequence
        Tick labels of the ``A`` panel.
    title : str, optional
        Figure title; also the output file stem (default ``various_N_evo``).
    yscale_log, xscale_log : bool
        Log-scale the axes of the evolution panel.
    legend_loc : str
        Legend location of the evolution panel.
    """
    dots_colors = ['midnightblue', 'tab:brown']

    # Convert once, instead of once per edge inside the innermost loop.
    TDI_stack = np.array(various_TDI)
    # TDI_categorized[strength_index][various_index] = [TDI of each off-diagonal edge with that weight]
    TDI_categorized = []
    for strength in strength_list:
        rows, cols = np.where(A == strength)
        off_diag = rows != cols
        rows, cols = rows[off_diag], cols[off_diag]
        TDI_categorized.append([list(TDI_stack[k, rows, cols]) for k in range(len(various_list))])

    plt_means = np.array([np.mean(TDI, axis=1) for TDI in TDI_categorized])
    plt_stds = np.array([np.std(TDI, axis=1) for TDI in TDI_categorized])

    fig = plt.figure(constrained_layout=True, figsize=(18, 6))
    (axA, axplot) = fig.subplots(1, 2)
    axA.set_xticks(ticks=np.arange(len(variables)),labels=variables, fontsize='x-large')
    axA.set_yticks(ticks=np.arange(len(variables)),labels=variables, fontsize='x-large')

    cmap = plt.cm.Blues
    pos = axA.imshow(A, cmap=cmap, norm=colors.BoundaryNorm([0, 0.5, 1], ncolors=cmap.N))
    fig.colorbar(pos, extend='neither', shrink=0.9, ax=axA, ticks=[0,1])

    axA.set_xlabel('Cause')
    axA.set_ylabel('Effect')
    axA.set_title('A', fontsize='x-large')

    for strength_index in range(len(strength_list)):
        axplot.plot(various_list, plt_means[strength_index], label=f'{"with" if strength_list[strength_index] > 0 else "without"} causation')
        axplot.fill_between(various_list, plt_means[strength_index] - plt_stds[strength_index], plt_means[strength_index] + plt_stds[strength_index], alpha=0.2)
        # Cycle the dot colours so more than two weight groups do not raise IndexError.
        dot_color = dots_colors[strength_index % len(dots_colors)]
        for var_index in range(len(various_list)):
            var = various_list[var_index]
            values = TDI_categorized[strength_index][var_index]
            axplot.plot(np.ones(len(values)) * var, values, 'o', color=dot_color)
    if yscale_log:
        axplot.set_yscale('log')
    if xscale_log:
        axplot.set_xscale('log')
    axplot.legend(loc=legend_loc, fontsize='large')

    plt.suptitle(title, fontsize='xx-large')

    filename = 'various_N_evo' if title is None else title
    plt.savefig(f"./results/{filename}.png", bbox_inches = 'tight')
    plt.show()

In [None]:
def plot_various(N_scenarrio, various_list, various_TDI_categorized, strength_list, filename, 
                 title = None, xlabel = None, ylim = None,
                 dots_colors = ['midnightblue', 'tab:brown'], legend_loc = 'upper left',
                 xscale_log = False, yscale_log = False, 
                 x_ticks = None, zoom_inset_settings = None, inset_x_limit = None):
    """Mean +- std of TDI over a swept parameter, one stacked row per scenario.

    Parameters
    ----------
    N_scenarrio : int
        Number of scenarios; one row of axes each, sharing the x axis.
    various_list : array
        Swept parameter values. It must support fancy indexing if an inset
        is requested.
    various_TDI_categorized : nested sequence
        ``various_TDI_categorized[strength_index][scenario][various_index]`` is
        a sequence of TDI values. The mean/std are taken over these values.
    strength_list : sequence of numbers
        Weight of each group. Values ``> 0`` are labelled "with causation".
    filename : str
        The figure is saved to ``./results/<filename>.png``.
    title, xlabel : str, optional
        Title of the top row; x label of every row.
    ylim : sequence of (low, high), optional
        Per-scenario y limits.
    dots_colors : sequence of colours
        Colour of the individual points of each group.
    legend_loc : str
        Legend location (top row only).
    xscale_log, yscale_log : bool
        Log-scale the main axes.
    x_ticks : sequence, optional
        Explicit x ticks of the bottom row.
    zoom_inset_settings : tuple, optional
        ``(axins_pos, xscale_log_inset, yscale_log_inset, inset_show_list)``.
        If given, each row gets a zoom inset showing the entries
        ``various_list[inset_show_list]``.
    inset_x_limit : (low, high), optional
        x limits of the inset.
    """
    N_strength = len(strength_list)
    plt_means = np.zeros((N_scenarrio, N_strength, len(various_list)))
    plt_stds = np.zeros((N_scenarrio, N_strength, len(various_list)))
    for strength_index in range(N_strength):
        for sc_index in range(N_scenarrio):
            values = np.array(various_TDI_categorized[strength_index][sc_index])
            plt_means[sc_index, strength_index, :] = np.mean(values, axis=1)
            plt_stds[sc_index, strength_index, :] = np.std(values, axis=1)

    fig = plt.figure(figsize=(10, 6 * N_scenarrio))
    # squeeze=False keeps `axs` indexable when there is a single scenario.
    axs = fig.subplots(N_scenarrio, 1, sharex=True, squeeze=False)[:, 0]
    fig.subplots_adjust(hspace=0)

    if zoom_inset_settings is not None:
        axins_pos, xscale_log_inset, yscale_log_inset, inset_show_list = zoom_inset_settings

    for s in range(N_scenarrio):
        ax = axs[s]
        if zoom_inset_settings is not None:
            if inset_x_limit is not None:
                axins = ax.inset_axes(axins_pos, xticklabels=[], xlim=(inset_x_limit[0],inset_x_limit[1]))
            else:
                axins = ax.inset_axes(axins_pos, xticklabels=[])

            if xscale_log_inset:
                axins.set_xscale('log')
            if yscale_log_inset:
                axins.set_yscale('log')

        for strength_index in range(N_strength):
            mean = plt_means[s, strength_index, :]
            upper = mean + plt_stds[s, strength_index, :]
            lower = mean - plt_stds[s, strength_index, :]
            groups = various_TDI_categorized[strength_index][s]
            ax.plot(various_list, mean, label=f'{"with" if strength_list[strength_index] > 0 else "without"} causation')

            ax.fill_between(various_list, lower, upper, alpha=0.2)

            for var_index in range(len(various_list)):
                var = various_list[var_index]
                ax.plot(np.ones(len(groups[var_index])) * var, groups[var_index], 'o', color=dots_colors[strength_index])

            if zoom_inset_settings is not None:
                axins.plot(various_list[inset_show_list], mean[inset_show_list])
                axins.fill_between(various_list[inset_show_list], lower[inset_show_list], upper[inset_show_list], alpha=0.2)
                for var_index in inset_show_list:
                    var = various_list[var_index]
                    axins.plot(np.ones(len(groups[var_index])) * var, groups[var_index], 'o', color=dots_colors[strength_index])

        if zoom_inset_settings is not None:
            rectangle_patch, connector_lines = ax.indicate_inset_zoom(axins, edgecolor="black")
            # Keep only two of the four connector lines visible.
            connector_lines[0].set(visible = False)
            connector_lines[1].set(visible = True)
            connector_lines[2].set(visible = True)
            connector_lines[3].set(visible = False)

        if yscale_log:
            ax.set_yscale('log')
        if xscale_log:
            ax.set_xscale('log')

        if ylim is not None:
            ax.set_ylim(ylim[s][0], ylim[s][1])

        ax.set_xlabel(xlabel, fontsize = 'x-large')
        # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+).
        ax.set_ylabel(r'$\mathbf{TDI}$ for ' + f'scenarrio {s+1}', fontsize = 'x-large')

    axs[0].legend(loc=legend_loc, fontsize='large')
    axs[0].set_title(title, fontsize='xx-large')
    if x_ticks is not None:
        axs[-1].set_xticks(ticks = x_ticks)

    pathname = f"./results/{filename}.png"
    plt.savefig(pathname, bbox_inches = 'tight')
    plt.show()

In [None]:
def plot_dot(A_S, All_TDI_categorized, N_data, strength_list, 
             filename = None, dots_colors = ['midnightblue', 'tab:brown'], cmap = plt.cm.Blues, 
             suptitle = None):
    """Per-scenario scatter of TDI values by data set, with a marginal histogram.

    Parameters
    ----------
    A_S : sequence of 2-D arrays
        Ground-truth matrix of each scenario, shown as a small inset.
    All_TDI_categorized : sequence
        ``All_TDI_categorized[k][g]`` holds the TDI values of scenario ``k`` and
        weight group ``g``. It is treated as 2-D: rows are data sets, columns
        are values (this layout is inferred from the ``axis=1`` reductions;
        confirm against the producer).
    N_data : int
        Number of data sets; only used in the panel title.
    strength_list : sequence of numbers
        Weight of each group. The first two groups are compared with a t-test
        and a per-data-set separation check.
    filename : str, optional
        If given, the figure is saved to ``./results/<filename>.png``.
    dots_colors : sequence of colours
        Colour per group.
    cmap : Colormap
        Colormap of the ground-truth inset. It was previously overwritten
        inside the function and so ignored.
    suptitle : str, optional
        Figure title.
    """
    N_scenarrio = len(A_S)
    N_strength = len(strength_list)
    fig = plt.figure(figsize=(10, 6 * N_scenarrio))
    # One row per scenario (previously hard-coded to 2). squeeze=False keeps
    # the result indexable for a single scenario. right=0.75 presumably
    # leaves room for the marginal histograms drawn to the right of each axes.
    axs = fig.add_gridspec(N_scenarrio, 1, right=0.75).subplots(squeeze=False)[:, 0]

    for A_index in range(len(A_S)):
        A = A_S[A_index]
        N_x = len(A)
        ax = axs[A_index]

        TDI_categorized = All_TDI_categorized[A_index]
        result = [np.asarray(TDI_categorized[0]), np.asarray(TDI_categorized[1])]

        # Two-sample t-test between the first two weight groups.
        _, p_value = stats.ttest_ind(result[0].flatten(), result[1].flatten())

        # A data set (row) counts as correct when the two groups are fully
        # separated: every value of the larger-weight group exceeds every
        # value of the other group.
        if strength_list[0] > strength_list[1]:
            correct_matrix = np.min(result[0], axis=1) > np.max(result[1], axis=1)
        else:
            correct_matrix = np.max(result[0], axis=1) < np.min(result[1], axis=1)
        correct = np.sum(np.array(correct_matrix, dtype=int))

        title = f'$p$-value: {p_value:.1e}, correct: {correct} / {N_data}'
        ax.set_title(title, fontsize='xx-large')

        ax_histy = ax.inset_axes([1.05, 0, 0.25, 1], sharey=ax)
        ax_histy.tick_params(axis="y", labelleft=False)

        for strength_index in range(N_strength):
            data = np.asarray(TDI_categorized[strength_index])
            n_sets = len(data)
            # x = 1-based data-set index, repeated for every value in that row.
            # The row width is derived from the data rather than from an
            # `edge_counts` global, which is not defined anywhere in this file.
            x_axis = np.repeat(np.arange(1, n_sets + 1), data.size // max(n_sets, 1))
            y_axis = data.flatten()
            ax.scatter(x_axis, y_axis, color=dots_colors[strength_index], label=f'{"with" if strength_list[strength_index] > 0 else "without"} causation')

            # Log-spaced bins with 10% headroom on both ends.
            bins = np.logspace(np.log10(np.min(y_axis) / 1.1), np.log10(np.max(y_axis) * 1.1), 30)
            ax_histy.hist(y_axis, bins=bins, orientation='horizontal', color=dots_colors[strength_index])

        ax.set_yscale('log')
        # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+).
        ax.set_ylabel(r'$\mathbf{TDI}$', fontsize = 'x-large')

        ax_inside = ax.inset_axes([0.02, 0.08, 0.3, 0.3])
        ax_inside.imshow(A, cmap=cmap)
        ax_inside.set_xticks(ticks=np.arange(N_x),labels=np.arange(1,N_x+1,1))
        ax_inside.set_yticks(ticks=np.arange(N_x),labels=np.arange(1,N_x+1,1))

    axs[0].legend(loc='upper left', fontsize='large')

    if suptitle is not None:
        plt.suptitle(suptitle, fontsize='xx-large')

    if filename is not None:
        pathname = f"./results/{filename}.png"
        plt.savefig(pathname, bbox_inches = 'tight')

    plt.show()

In [None]:
def get_TDI_categorized(TDIs, A, strength_list):
    """Collect off-diagonal TDI values grouped by ground-truth weight.

    Parameters
    ----------
    TDIs : 3-D array
        Indexed ``TDIs[data_index, row, col]``.
    A : 2-D array
        Ground-truth matrix whose entries are matched against ``strength_list``.
    strength_list : sequence
        Weights defining the groups, in output order.

    Returns
    -------
    list of lists
        One flat list per weight. The off-diagonal positions ``(row, col)``
        with ``A[row, col] == weight`` are visited in row-major order. For each
        position the list holds ``TDIs[0, row, col], TDIs[1, row, col], ...``
        before moving on to the next position.
    """
    grouped = []
    for weight in strength_list:
        rows, cols = np.where(A == weight)
        edges = [(r, c) for r, c in zip(rows, cols) if r != c]
        grouped.append([TDIs[d, r, c] for r, c in edges for d in range(len(TDIs))])
    return grouped