In [1]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.patches import Circle
from helpers import general_helpers as gh
from helpers import mpl_plotting_helpers as mph

database_decoding = {
    "TARGET_GENES" : "Transcription Factor Targets",
    "PID_" : "Pathway Interaction Database",
    "KEGG_MEDICUS_ENV_FACTOR_" : "KEGG Environmental Factor Signatures",
    "KEGG_MEDICUS_PATHOGEN_" : "KEGG Pathogen Signatures",
    "KEGG_MEDICUS_REFERENCE_" : "KEGG Reference Signatures",
    "KEGG_MEDICUS_VARIANT_" : "KEGG Gene Variant Signatures",
    "KEGG_" : "Kyoto Encyclopedia of Genes and Genomes,\nLegacy Signatures",
    "BIOCARTA_" : "BIOCARTA Pathway Signatures",
    "REACTOME_" : "Reactome Pathway Signatures",
    "WP_" : "WikiPath Pathway Signatures",
    "HALLMARK_" : "GSEA Hallmark Pathway Signatures"
}

def safe_float(a_number):
    try:
        return float(a_number)
    except:
        if a_number == "NA":
            return float("nan")
    else:
        return a_number

def str_replace_value(string, delim = "_", new_delim = " "):
    newstr = ""
    for char in string:
        if char != delim:
            newstr += char
        else:
            newstr += new_delim
    return newstr

def bin_by_db_id(a_gsea_row, decoder = {},
                 id_col = 0, delim = "_"):
    keys = [key for key, value in decoder.items() if key in a_gsea_row[id_col]]
    # This will happen for KEGG/KEGG_MEDICUS
    if len(keys) > 1:
        keys = sorted(keys, key= lambda x: -len(x))
    elif len(keys) == 1:
        pass
    else:
        return None, None
    # Then, we need to remove the substring with the decoder
    # key and replace "_" with " "
    new_idcol = a_gsea_row[0].split(keys[0])
    new_idcol = gh.list_to_str(new_idcol, delimiter = " ", newline = False)
    new_idcol = str_replace_value(new_idcol, delim = "_", new_delim = " ")
    new_gsea_row = []
    for i in range(len(a_gsea_row)):
        if i == id_col:
            new_gsea_row.append(new_idcol)
        else:
            new_gsea_row.append(safe_float(a_gsea_row[i]))
    #if a_gsea_row[0] == "BIOCARTA_TCR_PATHWAY":
    #    print(new_gsea_row)
    return keys[0], new_gsea_row

def split_by_db(sea_file, 
                decoder = database_decoding, # Set to the global var
                delim = "", id_col = 0):
    # Make a blank dict with heads to hold the parsed values
    parsed = {key : [sea_file[0]] for key, value in decoder.items()}
    # Loop over the file
    for row in sea_file[1:]:
        # Parse the row
        cur_key, cur_row = bin_by_db_id(row, decoder = decoder,
                                       id_col = id_col, delim = delim)
        if cur_key != None:
            parsed[cur_key].append(cur_row)
    return parsed

def find_radius(sig, sig_mapper = {1 : 0.1,
                                   0.1 : 0.2,
                                   0.05 : 0.3,
                                   0.005 : 0.49}):
    if float(sig) != float(sig):
        return 0
    holder = 0.1
    for thresh in list(sig_mapper.keys()):
        if sig <= thresh:
            holder = sig_mapper[thresh]
    return holder

def score_to_colour(score, low = 0, high = 0.99999999, max_score = 1, 
                    cmap = mph.trans):
    if float(score) != float(score):
        return "grey"
    mid = (low+high)/2
    moved_low = low - mid
    moved_high = high - mid
    scalar = moved_high/max_score
    color = (score*scalar) + mid
    return cmap(color)

def all_scores_to_colours(score_matrix, **stc_kwargs):
    # This should be a n (rows) x m (cols) matrix of scores (numbers)
    return [[score_to_colour(score, **stc_kwargs) for score in row] for row in score_matrix]

def find_all_radii(sig_matrix, sig_mapper = {1 : 0.05,
                                             0.1 : 0.1,
                             0.05 : 0.2,
                             0.005 : 0.35,
                             0.001 : 0.49}):
    return [[find_radius(sig, sig_mapper = sig_mapper) for sig in row] for row in sig_matrix]

def sea_arr_poss(w,h, colours = None, radii = None, labels = None):
    if colours == None:
        colours = [["white" for j in range(w)] for i in range(h)]
    if radii == None:
        radii = [[0.1 for j in range(w)] for i in range(h)]
    arr = [[Circle((i,j), radii[i][j],ec="black", lw=0.25, color=colours[i][j]) 
            for j in range(w)] for i in range(h)]
    return arr

def legend_points(axis, sig_mapper, leg_scale, 
                  fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 2)
                 ):
    # Make circles that are 0.5 apart with text next to them
    circs = []
    index = 0
    for key, value in sig_mapper.items():
        circs.append(Circle((1,index), radius = value, color = "white", ec = 'black',lw=0.25 ) )
        index += 1
    keys = list(sig_mapper.keys())
    for i in range(len(circs)):
        axis.add_patch(circs[i])
        axis.text(1.5, i, f"$q < {keys[i]}$",**fontdict)
    return None

def add_legend(ax1,ax2, cmap, vmin = -10, vmax = 10, 
               leg_scale = 69,
               sig_mapper = {1 : 0.1,
                                             0.1 : 0.2,
                             0.05 : 0.3,
                             0.005 : 0.49},
               fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 2),
               **stc_kwargs):
    # First, make fake circles that are white
    sig_matr = [[1,0.05,0.005,0.0005,0.00005]]
    rads= find_all_radii(sig_matr, sig_mapper)
    #points = sea_arr_poss(len(rads[0]), len(rads), labels = [[str(item) for item in row] for row in sig_matr])
    img = plt.imshow([[-10,10]], cmap = cmap, vmin = vmin, vmax = vmax)
    img.set_visible(False)
    cb = plt.colorbar(ax = ax1, fraction = 0.046, pad = 0.04)
    cb.set_ticks([item for item in cb.get_ticks() if item == int(item)])
    cb.set_ticklabels([int(item) for item in cb.get_ticks()], **fontdict)
    #cb.set_label(label, fontfamily = "sans-serif",
    #                  font = "Arial", fontweight= "bold", loc = "top")
    legend_points(ax2, sig_mapper, leg_scale, fontdict= fontdict)
    return ax1,ax2

Loading the module: helpers.general_helpers

Loading the module: helpers.mpl_plotting_helpers

Loading the module: helpers.argcheck_helpers

Loading the module: helpers.pandas_helpers

Loading the module: helpers.stats_helpers.py

numpy        1.22.4
scipy         1.8.1
pandas        1.4.2

pandas        1.4.2
numpy         1.22.4

matplotlib    3.5.2
numpy         1.22.4



In [21]:
gsea_file = "./figs/gary/mike/output_combined_heatmap.txt"

outdirs = "./figs/gary/mike/"

def enrich_bubbleplot(enriched_dict,
                      savefile, # Should equal len(keys), be a directory to put files
                      filetype = "pdf",
                      group_heads = ["Large dong" for _ in range(20)],
                      bubblenum = 20, 
                      colourmap = mph.trans,
                      max_score = 10,
                      fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 6),
                      sig_mapper = {1 : 0.05,
                                    0.05 : 0.1,
                                    0.005 : 0.2,
                                    0.0005 : 0.3,
                                    0.0001 : 0.45},
                      vmin = -10, vmax = 10,
                      ):
    # Make an enrichment bubbleplot for every grouping, containing only
    # bubblenum groups.
    # Any filtering should be done preemptively.
    cmap = cm.get_cmap(colourmap)
    rgba = cmap(0.4999999995)
    # Loop over the keys and values of the dictionary
    saver = 0
    for key, value in enriched_dict.items():
        # The 0th row is headers, so ignore them.We will use group_heads for this
        # Grab the q-values and the enrichment scores. the number of
        # q-cols and nes-cols = len(group_heads)
        # Also, the 0th column is the labels
        qs = [row[1:len(group_heads)+1] for row in value[1:]]
        #qs = gh.transpose(*qs)
        qs = [qs[bubblenum*i:bubblenum*(i+1)] for i in range(len(qs)//bubblenum + 1)]
        #qs = [gh.transpose(*c) for c in qs]
        nes = [row[len(group_heads)+1:] for row in value[1:]]
        #nes = gh.transpose(*nes)
        nes = [nes[bubblenum*i:bubblenum*(i+1)] for i in range(len(nes)//bubblenum + 1)]
        #nes = [gh.transpose(*c) for c in nes]
        labels = [row[0] for row in value[1:]]
        labels = [labels[bubblenum*i:bubblenum*(i+1)] for i in range(len(labels)//bubblenum + 1)]
        # The sea_arr_poss fails if there are no entries
        points = []
        try:
            for i in range(len(qs)):
                points.append(sea_arr_poss(len(qs[i]),    # Rows
                                  len(qs[i][0]), # Cols
                                  colours = all_scores_to_colours(gh.transpose(*nes[i]),
                                                                  max_score=max_score,
                                                                  cmap=cmap),
                                  radii= find_all_radii(gh.transpose(*qs[i]),
                                                                     sig_mapper = sig_mapper)) )
        except:
            print(f"Skipping category {database_decoding[key]}: No groups were enriched.\n")
        else:
            index = 0
            # Begin plotting each cluster
            for cluster in points:
                fig, ax = plt.subplots(1,2, figsize = (6,6))
                for row in cluster:
                    for p in row:
                        ax[0].add_artist(p)
                # Set some of the parameters
                ax[0].set_xticks(list(range(len(group_heads))))
                ax[0].set_xticklabels(group_heads, rotation = 90, **fontdict)
                ax[0].set_yticks(list(range(len(labels[index]))))
                ax[0].set_yticklabels(labels[index], **fontdict)
                ax[0].set_xlim(-1, bubblenum+1)
                ax[0].set_ylim(-1, bubblenum+1)
                ax[1].set_xlim(-1, bubblenum+1)
                ax[1].set_ylim(-1, bubblenum+1)
                ax[0].set_aspect("equal")
                ax[1].set_aspect("equal")
                ax[1].axis('off')
                ax[0].set_title(database_decoding[key],**fontdict)
                add_legend(ax[0], ax[1], cmap, leg_scale = bubblenum,
                          vmin = vmin, vmax = vmax,
                           sig_mapper = sig_mapper,
                           fontdict = fontdict)
                ax[0].spines[:].set_visible(False)
                plt.tight_layout()
                plt.savefig(f"{savefile}/{key}_{index}.{filetype}")
                plt.close()
                index +=1
        saver += 1
    return None



In [22]:
def enrich_bubbleplot_list(ptmsea_outfiles, 
                           savefiles,
                           sig_exception = ["filenames"],
                           significance = 1, # 0 < significance < 1
                           filetype = "pdf",
                           group_heads = [["Large dong" for _ in range(20)]
                                          for i in range(20)],
                           bubblenum = 15, 
                           colourmap = mph.trans,
                           max_score = 10,
                           fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 6),
                           vmin = -10, vmax = 10,
                           sig_mapper = {1 : 0.1,
                                    0.05 : 0.2,
                                    0.005 : 0.3,
                                    0.0005 : 0.4,
                                    0.0001 : 0.49},
                          decoder = database_decoding):
    # Wraps bubbleplot to do a list of them and send them where they need to go
    for i in range(len(ptmsea_outfiles)):
        print(ptmsea_outfiles[i], len(group_heads))
        file = gh.read_file(ptmsea_outfiles[i])
        file = split_by_db(file, decoder = decoder)
        if ptmsea_outfiles[i] in sig_exception:
            print("exception")
            file = {key : [value[0]] + [row for row in value[1:] if any([item < 1 for item in row[1:len(group_heads[i])+1]])] 
               for key, value in file.items()}
        else:
            file = {key : [value[0]] + [row for row in value[1:] if any([item < significance for item in row[1:len(group_heads[i])+1]])] 
               for key, value in file.items()}
        # Remove any empties
        file = {key : value for key, value in file.items() if len(value) >1}
        #for row in file["BIOCARTA_"]:
        #    if row[0] == '  TCR PATHWAY':
        #        print(row)
        #    else:
        #        print("Not found")
        enrich_bubbleplot(file, savefiles[i],
                         filetype = filetype,
                         group_heads = group_heads[i],
                         bubblenum = bubblenum,
                         colourmap = colourmap,
                          max_score = max_score,
                          fontdict = fontdict,
                          vmin = vmin, vmax = vmax,
                          sig_mapper = sig_mapper)
        #break
    return None

In [23]:
gsea_file = "./figs/gary/mike/output_combined_heatmap.txt"
outdirs = "./figs/gary/mike/"
bubble_heads = ["J.TCPTP- (T2 1C10)", "J.PTPN22- (P1 1C1)", "J.SHP1- (S1 2C4)"]

enrich_bubbleplot_list([gsea_file], [outdirs],
                       sig_exception = ["ween"],
                       significance = 0.05,
                       group_heads = [bubble_heads],
                       fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 2),
                       bubblenum = 45,
                       max_score = 10,
                       vmin = -6, vmax = 6, 
                       sig_mapper = {1 : 0.1,
                                    0.05 : 0.2,
                                    0.005 : 0.3,
                                    0.0005 : 0.4,
                                    0.0001 : 0.49})

./figs/gary/mike/output_combined_heatmap.txt 1
[[[0.843140214279774, 0.0467764031980683, 0.656927134744756], [0.84286976351609, 0.0313540526282949, 0.662320408345299], [0.395720011325218, 0.0463219026548673, 0.673067625043914], [0.20323863271981, 0.00564030612244898, 0.26851801829081], [0.404982285539569, 0.00564030612244898, 0.668348706126076], [0.689192103900046, 0.0358771705877517, 0.221065669801866], [0.941118518608714, 0.0394860914662895, 0.899692897813839], [0.910246406412068, 0.043261455921155, 0.390520801815431], [0.960619574149416, 0.00875011476857905, 0.853074149059446], [0.35937477028027, 0.0498117696357726, 0.110025207969577], [0.456919944798109, 0.00564030612244898, 0.887684880659297], [0.0157627278145488, 0.556913825699184, 0.0217532467532468], [0.823545840217959, 0.0612206149355042, 0.00875011476857905], [0.352259926003069, 0.00564030612244898, 0.799872524639413], [0.00564030612244898, 0.00564030612244898, 0.992935913666913], [0.811438490441535, 0.00564030612244898, 0.96