In [1]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.patches import Circle
from helpers import general_helpers as gh
from helpers import mpl_plotting_helpers as mph

database_decoding = {
    "PERT-PSP" : "PhosphoSitePlus Perturbation\nSignatures",
    "DISEASE-PSP" : "PhosphoSitePlus Disease\nSignatures",
    "KINASE-PSP" : "PhosphoSitePlus Kinase\nSignatures",
    "KINASE-iKiP" : "In Vitro Kinase\nSubstrate Signatures",
    "PATH-BI" : "BI Pathway Signatures",
    "PATH-NP" : "NetPath Pathway\nSignatures",
    "PATH-WP" : "WikiPathway Signatures",
    "PERT-P100-DIA2" : "P100-DIA2\nPertubation Signatures",
}

def find_ids(str_list, delim = "_", position = 0, obj = []):
    new_ids = {}
    for s in str_list:
        newid = s.split(delim)[position]
        if newid not in list(new_ids.keys()):
            new_ids[newid] = obj
    return new_ids

def split_by_dbtype(sea_hm_file, delim = "_", id_col = 0, position = 0):
    # read the file, bin by id_col split on delim, return dict of sublists
    parsed = find_ids(gh.transpose(*sea_hm_file)[0][1:], delim = delim, position = position)
    parsed = {key : [sea_hm_file[0]] for key, value in parsed.items()}
    for row in sea_hm_file[1:]:
        newrow = [gh.list_to_str(row[0].split(delim)[1:], delimiter = " ", newline = False)] + [float(item) for item in row[1:]]
        parsed[row[id_col].split(delim)[position]].append(newrow)
    return parsed

def find_radius(sig, sig_mapper = {1 : 0.1,
                                             0.1 : 0.2,
                             0.05 : 0.3,
                             0.005 : 0.49}):
    if float(sig) != float(sig):
        return 0
    holder = 0.1
    for thresh in list(sig_mapper.keys()):
        if sig <= thresh:
            holder = sig_mapper[thresh]
    return holder

def score_to_colour(score, low = 0, high = 0.99999999, max_score = 1, 
                    cmap = mph.trans):
    if float(score) != float(score):
        return "grey"
    mid = (low+high)/2
    moved_low = low - mid
    moved_high = high - mid
    scalar = moved_high/max_score
    color = (score*scalar) + mid
    return cmap(color)

def all_scores_to_colours(score_matrix, **stc_kwargs):
    # This should be a n (rows) x m (cols) matrix of scores (numbers)
    return [[score_to_colour(score, **stc_kwargs) for score in row] for row in score_matrix]

def find_all_radii(sig_matrix, sig_mapper = {1 : 0.05,
                                             0.1 : 0.1,
                             0.05 : 0.2,
                             0.005 : 0.35,
                             0.001 : 0.49}):
    return [[find_radius(sig, sig_mapper = sig_mapper) for sig in row] for row in sig_matrix]

def sea_arr_poss(w,h, colours = None, radii = None, labels = None):
    if colours == None:
        colours = [["white" for j in range(w)] for i in range(h)]
    if radii == None:
        radii = [[0.1 for j in range(w)] for i in range(h)]
    arr = [[Circle((i,j), radii[i][j],ec="black", lw=0.25, color=colours[i][j]) 
            for j in range(w)] for i in range(h)]
    return arr

def legend_points(axis, sig_mapper, leg_scale, 
                  fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 2)
                 ):
    # Make circles that are 0.5 apart with text next to them
    circs = []
    index = 0
    for key, value in sig_mapper.items():
        circs.append(Circle((1,index), radius = value, color = "white", ec = 'black',lw=0.25 ) )
        index += 1
    keys = list(sig_mapper.keys())
    for i in range(len(circs)):
        axis.add_patch(circs[i])
        axis.text(1.5, i, f"$q < {keys[i]}$",**fontdict)
    return None

def add_legend(ax1,ax2, cmap, vmin = -10, vmax = 10, 
               leg_scale = 69,
               sig_mapper = {1 : 0.1,
                                             0.1 : 0.2,
                             0.05 : 0.3,
                             0.005 : 0.49},
               fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 2),
               **stc_kwargs):
    # First, make fake circles that are white
    sig_matr = [[1,0.05,0.005,0.0005,0.00005]]
    rads= find_all_radii(sig_matr, sig_mapper)
    #points = sea_arr_poss(len(rads[0]), len(rads), labels = [[str(item) for item in row] for row in sig_matr])
    img = plt.imshow([[-10,10]], cmap = cmap, vmin = vmin, vmax = vmax)
    img.set_visible(False)
    cb = plt.colorbar(ax = ax1, fraction = 0.046, pad = 0.04)
    cb.set_ticks([item for item in cb.get_ticks() if item == int(item)])
    cb.set_ticklabels([int(item) for item in cb.get_ticks()], **fontdict)
    #cb.set_label(label, fontfamily = "sans-serif",
    #                  font = "Arial", fontweight= "bold", loc = "top")
    legend_points(ax2, sig_mapper, leg_scale, fontdict= fontdict)
    return ax1,ax2

Loading the module: helpers.general_helpers

Loading the module: helpers.mpl_plotting_helpers

Loading the module: helpers.argcheck_helpers

Loading the module: helpers.pandas_helpers

Loading the module: helpers.stats_helpers.py

numpy        1.22.4
scipy         1.8.1
pandas        1.4.2

pandas        1.4.2
numpy         1.22.4

matplotlib    3.5.2
numpy         1.22.4



In [2]:
from py_scripts import managing_ssgsea_outputs as mso

Loading the module: helpers.general_helpers

Loading the module: helpers.argcheck_helpers



In [4]:
ptm_files = gh.get_file_list("./figs/ptmsea_output",
                             "output_combined_heatmap.txt")


outdirs = [gh.list_to_str(f.split("/")[:-1],
                          delimiter = "/",
                          newline = False) for f in ptm_files]

def enrich_bubbleplot(enriched_dict,
                      savefile, # Should equal len(keys), be a directory to put files
                      filetype = "pdf",
                      group_heads = [" stuff n thangs " for _ in range(20)],
                      bubblenum = 20, 
                      colourmap = mph.trans,
                      max_score = 10,
                      fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 6),
                      sig_mapper = {1 : 0.05,
                                    0.05 : 0.1,
                                    0.005 : 0.2,
                                    0.0005 : 0.3,
                                    0.0001 : 0.45},
                      vmin = -10, vmax = 10,
                      ):
    # Make an enrichment bubbleplot for every grouping, containing only
    # bubblenum groups.
    # Any filtering should be done preemptively.
    cmap = cm.get_cmap(colourmap)
    rgba = cmap(0.4999999995)
    # Loop over the keys and values of the dictionary
    saver = 0
    for key, value in enriched_dict.items():
        # The 0th row is headers, so ignore them.We will use group_heads for this
        # Grab the q-values and the enrichment scores. the number of
        # q-cols and nes-cols = len(group_heads)
        # Also, the 0th column is the labels
        qs = [row[1:len(group_heads)+1] for row in value[1:]]
        #qs = gh.transpose(*qs)
        qs = [qs[bubblenum*i:bubblenum*(i+1)] for i in range(len(qs)//bubblenum + 1)]
        #qs = [gh.transpose(*c) for c in qs]
        nes = [row[len(group_heads)+1:] for row in value[1:]]
        #nes = gh.transpose(*nes)
        nes = [nes[bubblenum*i:bubblenum*(i+1)] for i in range(len(nes)//bubblenum + 1)]
        #nes = [gh.transpose(*c) for c in nes]
        labels = [row[0] for row in value[1:]]
        labels = [labels[bubblenum*i:bubblenum*(i+1)] for i in range(len(labels)//bubblenum + 1)]
        # The sea_arr_poss fails if there are no entries
        points = []
        try:
            for i in range(len(qs)):
                points.append(sea_arr_poss(len(qs[i]),    # Rows
                                  len(qs[i][0]), # Cols
                                  colours = all_scores_to_colours(gh.transpose(*nes[i]),
                                                                  max_score=max_score,
                                                                  cmap=cmap),
                                  radii= find_all_radii(gh.transpose(*qs[i]),
                                                                     sig_mapper = sig_mapper)) )
        except:
            print(f"Skipping category {database_decoding[key]}: No groups were enriched.\n")
        else:
            index = 0
            # Begin plotting each cluster
            for cluster in points:
                fig, ax = plt.subplots(1,2, figsize = (6,6))
                for row in cluster:
                    for p in row:
                        ax[0].add_artist(p)
                # Set some of the parameters
                ax[0].set_xticks(list(range(len(group_heads))))
                ax[0].set_xticklabels(group_heads, rotation = 90, **fontdict)
                ax[0].set_yticks(list(range(len(labels[index]))))
                ax[0].set_yticklabels(labels[index], **fontdict)
                ax[0].set_xlim(-1, bubblenum+1)
                ax[0].set_ylim(-1, bubblenum+1)
                ax[1].set_xlim(-1, bubblenum+1)
                ax[1].set_ylim(-1, bubblenum+1)
                ax[0].set_aspect("equal")
                ax[1].set_aspect("equal")
                ax[1].axis('off')
                ax[0].set_title(database_decoding[key],**fontdict)
                add_legend(ax[0], ax[1], cmap, leg_scale = bubblenum,
                          vmin = vmin, vmax = vmax,
                           sig_mapper = sig_mapper,
                           fontdict = fontdict)
                ax[0].spines[:].set_visible(False)
                plt.tight_layout()
                plt.savefig(f"{savefile}/{key}_{index}.{filetype}")
                plt.close()
                index +=1
        saver += 1
    return None



In [5]:
outdirs

['/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/ko_wt_dmso',
 '/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/ko_wt_u0126',
 '/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/pm_u0126',
 '/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/timecourse_wo_u0126',
 '/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/timecourse_w_u0126']

In [6]:
def enrich_bubbleplot_list(ptmsea_outfiles, 
                           savefiles,
                           sig_exception = ["filenames"],
                           significance = 1, # 0 < significance < 1
                           filetype = "pdf",
                           group_heads = [["stuff n thangs" for _ in range(20)]
                                          for i in range(20)],
                           bubblenum = 15, 
                           colourmap = mph.trans,
                           max_score = 10,
                           fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 6),
                           vmin = -10, vmax = 10,
                           sig_mapper = {1 : 0.1,
                                    0.05 : 0.2,
                                    0.005 : 0.3,
                                    0.0005 : 0.4,
                                    0.0001 : 0.49},):
    # Wraps bubbleplot to do a list of them and send them where they need to go
    for i in range(len(ptmsea_outfiles)):
        print(ptmsea_outfiles[i])
        file = gh.read_file(ptmsea_outfiles[i])
        file = split_by_dbtype(file)
        if ptmsea_outfiles[i] in sig_exception:
            print("exception")
            file = {key : [value[0]] + [row for row in value[1:] if any([item < 1 for item in row[1:len(group_heads)+1]])] 
               for key, value in file.items()}
        else:
            file = {key : [value[0]] + [row for row in value[1:] if any([item < significance for item in row[1:len(group_heads)+1]])] 
               for key, value in file.items()}
        # Remove any empties
        file = {key : value for key, value in file.items() if len(value) >1}
        enrich_bubbleplot(file, savefiles[i],
                         filetype = filetype,
                         group_heads = group_heads[i],
                         bubblenum = bubblenum,
                         colourmap = colourmap,
                          max_score = max_score,
                          fontdict = fontdict,
                          vmin = vmin, vmax = vmax,
                          sig_mapper = sig_mapper)
        #break
    return None

In [7]:
g_heads = [["J.TCPTP- vs Jurkat 0 min", "J.PTPN22- vs Jurkat 0 min", "J.SHP1- vs Jurkat 0 min",
            "J.TCPTP- vs Jurkat 5 min", "J.PTPN22- vs Jurkat 5 min", "J.SHP1- vs Jurkat 5 min" ],
           ["J.TCPTP- vs Jurkat 0 min", "J.PTPN22- vs Jurkat 0 min", "J.SHP1- vs Jurkat 0 min",
            "J.TCPTP- vs Jurkat 5 min", "J.PTPN22- vs Jurkat 5 min", "J.SHP1- vs Jurkat 5 min" ],
           ["Jurkat 0 min", "J.TCPTP- 0 min", "J.PTPN22- 0 min", "J.SHP1- 0 min",
            "Jurkat 5 min", "J.TCPTP- 5 min", "J.PTPN22- 5 min", "J.SHP1- 5 min" ],
           ["Jurkat", "J.TCPTP-", "J.PTPN22-", "J.SHP1-"],
           ["Jurkat", "J.TCPTP-", "J.PTPN22-", "J.SHP1-"]]

enrich_bubbleplot_list(ptm_files, outdirs,
                       sig_exception = ["ween"],
                       significance = 0.05,
                       group_heads = g_heads,
                       fontdict = dict(fontfamily="sans-serif",
                                      font = "Arial",
                                      fontweight= "bold",
                                      fontsize = 3),
                       bubblenum = 45,
                       max_score = 6,
                       vmin = -6, vmax = 6)

/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/ko_wt_dmso/output_combined_heatmap.txt
/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/ko_wt_u0126/output_combined_heatmap.txt
/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/pm_u0126/output_combined_heatmap.txt
/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/timecourse_wo_u0126/output_combined_heatmap.txt
Skipping category In Vitro Kinase
Substrate Signatures: No groups were enriched.

/mnt/c/Users/redas/Desktop/jupyter_directory/arts_lab/phosphatase_ko/figs/ptmsea_output/timecourse_w_u0126/output_combined_heatmap.txt


In [None]:
for f in ptm_files:
    print(f)
