In [1]:
################################################################################
#
# Importables, just the basics :P

    # My module, cuz I'm cool
from helpers import mpl_plotting_helpers as mph
from helpers import stats_helpers as sh
from helpers import general_helpers as gh
from helpers import western_helpers as wh
from helpers.mph_modules.dotplots import get_data_info, add_errorbar

    # Standard packages
import matplotlib.pyplot as plt
import matplotlib.font_manager as mpl_fm
from math import floor, ceil, log2
import pandas as pd
import glob

#
#
################################################################################

Loading the module: helpers.mpl_plotting_helpers

Loading the module: helpers.general_helpers

Loading the module: helpers.argcheck_helpers

Loading the module: helpers.pandas_helpers

Loading the module: helpers.stats_helpers.py

numpy        1.22.4
scipy         1.8.1
pandas        1.4.2

pandas        1.4.2
numpy         1.22.4

matplotlib    3.5.2
numpy         1.22.4



In [2]:
###################################################################
#
# Making line plots and doing the statistics
#  #  #  # Add to mpl_plotting_helpers at some point

def _logical_ignore_comps(labelled_line_groups,
                          group_strs,
                          xgroup_strs):
    """
    Collect the pairwise comparisons that the statistics should skip.

    Only comparisons along a line group (e.g. a timecourse) or down a
    single x-column (e.g. JE6 DMSO 0m vs JE6 U0126 0m) are meaningful;
    every other pair is returned here so it can be ignored.

    labelled_line_groups -> list of lists of [label, [d1, ..., dn]]
    group_strs -> substrings that identify a line group
    xgroup_strs -> substrings that identify an x-column
    #####
    Returns the pairs (as built by gh.make_pairs) whose two labels share
    neither a group substring nor an x-group substring
    """
    # Flatten the per-line lists into one list of labelled groups
    all_groups = [grp for line in labelled_line_groups for grp in line]

    def _both_contain(substrings, first_label, second_label):
        # True when some substring occurs in both labels
        return any(sub in first_label and sub in second_label
                   for sub in substrings)

    candidate_pairs = gh.make_pairs(all_groups,
                                    dupes = False,
                                    reverse = False)
    return [pair for pair in candidate_pairs
            if not _both_contain(group_strs, pair[0][0], pair[1][0])
            and not _both_contain(xgroup_strs, pair[0][0], pair[1][0])]

def perform_line_statistics(labelled_line_groups,
                            ignore_comps,
                            comp_type,
                            statsfile):
    """
    Run a post-hoc comparison over every labelled group and write it out.

    labelled_line_groups -> data with labels
                            list of lists of [label, [d1,d2,...,dn]]
    ignore_comps -> list of pairs ("group 1", "group 2") that should not
                    be compared
    comp_type -> statistics to use; only "HolmSidak" and "TukeyHSD" are
                 accepted (per the original notes, both run an ANOVA first)
    statsfile -> output path and filename for the statistics file
    #####
    Returns None; the only effect is writing statsfile as csv
    """
    assert comp_type in ["HolmSidak", "TukeyHSD"], f"Invalid comparison type: {comp_type}"
    # Every group from every line goes into one comparison
    flat_groups = [grp for line in labelled_line_groups for grp in line]
    test_factory = sh.HolmSidak if comp_type == "HolmSidak" else sh.TukeyHSD
    result = test_factory(*flat_groups,
                          labels = True,
                          override = True,
                          alpha = 0.05,
                          no_comp = ignore_comps)
    result.write_output(filename = statsfile,
                        file_type = "csv")
    return None

def find_centres(plotting_info):
    """
    Pick the longest "centers" list to use as the x-tick positions.

    plotting_info -> output from get_data_info per line; element [0] of
                     each entry is a dict with a "centers" list
    #####
    Returns the longest centers list (ties go to the later entry), or []
    when plotting_info is empty
    """
    centre_lists = [info[0]["centers"] for info in plotting_info]
    widest = []
    for candidate in centre_lists:
        if len(candidate) >= len(widest):
            widest = candidate
    return widest

def line_plot(labelled_line_groups,
              show_points = False,
              show_legend = False,
              colours = ["grey" for _ in range(20)],
              group_labs = [f"Thing {i}" for i in range(20)],
              markers = ["s" for _ in range(20)],
              linestyles = ["solid" for _ in range(20)],
              xlabels = [f"Time {i}" for i in range(20)],
              ylabel = "Fold change",
              ylims = None,
              ignore_comps = [],
              statsfile = None,
              comp_type = "HolmSidak",
              figfile = None):
    """
    Draw one line (means with SEM error bars) per line group, optionally
    run the statistics, then show or save the figure.

    labelled_line_groups -> list of lists; each sublist is one line and
                            holds labelled groups [label, [d1, ..., dn]]
    show_points -> if True, scatter the individual values of each group;
                   otherwise scatter one marker per mean
    show_legend -> if True, shrink the axes and put a legend on the right
    colours, group_labs, markers, linestyles -> per-line styling, indexed
                   by the line's position in labelled_line_groups
    xlabels -> x tick labels, truncated to the number of tick positions
    ylabel -> y-axis label (a string)
    ylims -> (ymin, ymax); if None, taken from the finite numbers in the
             data, floored/ceiled to integers
    ignore_comps, comp_type, statsfile -> passed to perform_line_statistics;
             statistics only run when statsfile is not None
    figfile -> output path; if None the figure is shown instead
    #####
    Returns None
    """
    from math import isfinite
    # First, get some basic plotting information (centers/means/sems/xs)
    plotting_info = [get_data_info(line) for line in labelled_line_groups]
    # Then manage the statistics
    if statsfile is not None:
        perform_line_statistics(labelled_line_groups,
                                ignore_comps,
                                comp_type,
                                statsfile)
    # Begin plotting c::
    if ylims is None:
        # NaNs (e.g. produced by replace_neg / safe_log2) make min/max
        # unreliable and floor/ceil raise, so only finite numbers count.
        # bool is excluded, as the original int/float type check did.
        values = [item for item in gh.unpack_list(labelled_line_groups)
                  if isinstance(item, (int, float))
                  and not isinstance(item, bool)
                  and isfinite(item)]
        if not values:
            raise ValueError("No finite numeric data to infer ylims from; pass ylims explicitly")
        ylims = floor(min(values)), ceil(max(values))
    fig, ax = plt.subplots(figsize = (6,6))
    for i in range(len(labelled_line_groups)):
        info = plotting_info[i][0]
        ax.plot(info["centers"],
                info["means"],
                color = colours[i],
                label = group_labs[i],
                linestyle = linestyles[i])
        for j in range(len(labelled_line_groups[i])):
            add_errorbar(ax,
                         info["centers"][j],
                         info["means"][j],
                         info["sems"][j],
                         color = colours[i])
            if show_points:
                # Raw values of group j (assumes plotting_info[i][1][j]
                # is [label, values] -- TODO confirm with get_data_info)
                ax.scatter(info["xs"][j],
                           plotting_info[i][1][j][1],
                           color = colours[i],
                           edgecolor = "black", alpha = 0.3,
                           marker = markers[i],
                           s = 10)
        if not show_points:
            # One marker per mean, drawn once per line. This used to sit
            # inside the j loop and redraw every mean len(group) times,
            # stacking the alpha; markers now really render at alpha 0.3.
            ax.scatter(info["centers"],
                       info["means"],
                       color = colours[i],
                       edgecolor = "black", alpha = 0.3,
                       marker = markers[i],
                       s = 30)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    xticks = find_centres(plotting_info)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels[:len(xticks)],
                       fontfamily = "sans-serif",
                       font = "Arial",
                       fontweight = "bold",
                       fontsize = 12,
                       rotation = 45,
                       ha = "center")
    ax.set_ylim(*ylims)
    mph.update_ticks(ax, which = "y")
    ax.set_ylabel(ylabel, fontfamily = "sans-serif",
                  font = "Arial", fontweight = "bold",
                  fontsize = 14)
    if show_legend:
        # Make room on the right for the legend
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5),
                  prop = mpl_fm.FontProperties(family = "sans-serif",
                                               weight = "bold"))
    if figfile is None:
        plt.show()
    else:
        fig.savefig(figfile)
    plt.close(fig)
    return None

def replace_neg(a_list, value = float("nan")):
    """
    Replace every negative entry of a_list with `value` (NaN by default).

    Items that cannot be compared with 0 (e.g. label strings, None) are
    kept unchanged, as are non-negative numbers.

    a_list -> iterable of items
    value -> replacement for items < 0
    #####
    Returns a new list; a_list itself is not modified
    """
    newlist = []
    for item in a_list:
        try:
            is_negative = item < 0
        except TypeError:
            # Not orderable against 0 (e.g. a string): keep as-is
            newlist.append(item)
            continue
        newlist.append(value if is_negative else item)
    return newlist

def safe_log2(number):
    """
    Return log2(number), or NaN when it is undefined.

    NaN is returned for numbers <= 0 (math.log2 raises ValueError) and
    for non-numeric input such as None or strings (TypeError). The
    logarithm is computed only once.
    """
    try:
        return log2(number)
    except (TypeError, ValueError):
        return float("nan")
    
#
#
###################################################################

In [9]:
###################################################################
#
# Variable definitions

# File architecture change because we can't normalise multiple
# blot files at the same time (Ab switches :/)
#data_path = "./excel_sheets/20241128_" # and then the rest of the file name

colour_dict = {"JE6" : "grey",
               "J.TCPTP-" : "deeppink",
               "J.PTPN22-" : "deepskyblue",
               "J.SHP1-" : "blueviolet"}

linestyle_dict = {"0.125% DMSO" : "solid",
                  r"20 $\mu$M" : "dashdot"}

markers = ["o", "s", "D","^", "o", "s", "D","^"]

experiment = [700.0, '685Ex-720Em']
control = [800.0, '785Ex-820Em']

    # Titration variables

#titr_files = [f"{data_path}titration_4g10.xlsx",
#              f"{data_path}titration_erk.xlsx"]

titr_groups = ["JE6", "J.PTPN22-", "J.SHP1-", "J.TCPTP-",]

# These strings are both the x tick labels and the substrings that
# _logical_ignore_comps looks for in the data labels, so they must match
# the sheet metadata exactly. "2 $\mu$M" previously read "2 $\,mu$M",
# which rendered as a broken label and could never match.
# NOTE(review): "0.2$\mu$M" has no space unlike the other doses -- confirm
# against the sheet metadata.
titr_xgroups = ["0.125% DMSO 0m", "0.125% DMSO 5m",
                r"0.2$\mu$M 5m", r"2 $\mu$M 5m",
                r"20 $\mu$M 5m", r"50 $\mu$M 5m" ]

    # Time course variables

#time_files = [f"{data_path}4g10.xlsx",
#              f"{data_path}erk.xlsx",
#              f"{data_path}plcg1.xlsx",
#              f"{data_path}lck.xlsx"]

time_groups = ["JE6 0.05% DMSO", 
               "J.PTPN22- 0.05% DMSO", 
               "J.SHP1- 0.05% DMSO",
               "J.TCPTP- 0.05% DMSO",
               r"JE6 20 $\mu$M", 
               r"J.PTPN22- 20 $\mu$M", 
               r"J.SHP1- 20 $\mu$M",
               r"J.TCPTP- 20 $\mu$M",]

# The leading space in " 0m" keeps the substring match from also hitting
# "10m" (same reason norm_string defaults to " 0m")
time_xgroups = [" 0m", "2m", "5m", "10m"]

#
#
###################################################################

In [4]:
###################################################################
#
#  Want to normalise two files individually, then merge with 
#  correct labels

    # Assume we have files, and write from there
    
def find_correct_signal(a_file_df,
                        signal_values,
                        gs_kwargs):
    """
    Try each candidate in signal_values with wh.get_signal and return the
    first non-empty result.

    a_file_df -> dataframe handed to wh.get_signal
    signal_values -> candidate channel values, tried in order
    gs_kwargs -> keyword arguments forwarded to wh.get_signal
    #####
    Returns the first result with len() > 0; if none qualifies, the
    result of the last attempt, or [] when signal_values is empty
    """
    result = []
    for candidate in signal_values:
        result = wh.get_signal(a_file_df,
                               candidate,
                               **gs_kwargs)
        if len(result) > 0:
            break
    return result
    

def get_all_signals(file_dfs,
                    expr_signals = [700, "685Ex-720Em"],
                    load_signals = [800, "785Ex-820Em"],
                    gs_kwargs = dict(signal_column = "Signal",
                                 channel_column = "Channel"),
                    df_meta = ["Group", "Condition", "Time"]):
    """
    Go through each file and grab the experimental and loading signals,
    plus the requested metadata columns for the experimental rows.

    file_dfs -> list of dataframes, one per file
    expr_signals / load_signals -> candidate channel values tried in order
    gs_kwargs -> keyword arguments for wh.get_signal; its "channel_column"
                 (default "Channel") is also used for the metadata lookup
    df_meta -> metadata columns joined (space-separated) into one label
               per experimental row
    #####
    Returns 3 lists, one entry per file: experimental signal, loading
    signal, and metadata labels per experimental row ([] metadata for a
    file whose metadata could not be read, with a warning)
    """
    import warnings
    # Metadata must come from the same channel column as the signals;
    # this used to be hard-coded to "Channel"
    meta_channel = gs_kwargs.get("channel_column", "Channel")
    expr_out = []
    load_out = []
    metadata = []
    for df in file_dfs:
        expr_out.append(find_correct_signal(df, expr_signals, gs_kwargs))
        load_out.append(find_correct_signal(df, load_signals, gs_kwargs))
        try:
            meta_hold = [find_correct_signal(df, expr_signals,
                                             {"signal_column" : metacheck,
                                              "channel_column" : meta_channel})
                         for metacheck in df_meta]
        except Exception as err:
            # Best effort: keep going without metadata, but say so
            warnings.warn(f"Could not read metadata columns {df_meta}: {err!r}")
            meta_hold = []
        metadata.append(meta_hold)
    # If metadata isn't empty, turn the per-column lists into one
    # space-joined label per row
    if metadata != []:
        metadata = [gh.transpose(*g) for g in metadata]
        metadata = [[gh.list_to_str(subg,
                                    delimiter = " ",
                                    newline = False) for subg in g]
                    for g in metadata]
    return expr_out, load_out, metadata

def read_norm_merge(file_dir,
                    expr_signal = [700, "685Ex-720Em"],
                    load_signal = [800, "785Ex-820Em"],
                    df_meta = ["Group", "Condition", "Time"],
                    norm_string = " 0m",
                    gs_kwargs = dict(signal_column = "Signal",
                                 channel_column = "Channel"),
                    log2_trans = True):
    """
    Read every Excel file in file_dir, correct and normalise each file on
    its own, then merge the values by group label.

    file_dir -> directory whose files all merge into one final dataset
    expr_signal / load_signal -> candidate channel values for the
                 experimental and loading signals
    df_meta -> metadata columns joined into the group labels
    norm_string -> substring of the labels each file is normalised to
                   (its mean becomes 1)
    gs_kwargs -> keyword arguments forwarded to wh.get_signal
    log2_trans -> if True, log2-transform the fold changes (NaN where
                  undefined)
    #####
    Returns a list of [group label, values] pairs; negative fold changes
    become NaN
    """
    import os
    # glob order depends on the filesystem: sort so the replicate order is
    # reproducible, and skip the "~$" lock files Excel leaves next to open
    # workbooks (read_excel cannot parse them)
    paths = sorted(p for p in glob.glob(f"{file_dir}/*")
                   if not os.path.basename(p).startswith("~$"))
    files = [pd.read_excel(p) for p in paths]

    expr_sigs, load_sigs, metadata = get_all_signals(files,
                                                     expr_signals = expr_signal,
                                                     load_signals = load_signal,
                                                     gs_kwargs = gs_kwargs,
                                                     df_meta = df_meta)
    # Correct each file's experimental signal with its loading signal
    corrected_sigs = [wh.licor_correction(expr, load)
                      for expr, load in zip(expr_sigs, load_sigs)]
    # Rows whose label contains norm_string define each file's reference
    norm_inds = [[j for j, label in enumerate(labels) if norm_string in label]
                 for labels in metadata]
    norm_means = [sh.mean([sigs[j] for j in inds],
                          filter_nans = True,
                          threshold = 1)
                  for sigs, inds in zip(corrected_sigs, norm_inds)]
    # Fold change relative to the reference mean; negatives become NaN
    norm_signals = [replace_neg([sig / ref for sig in sigs],
                                value = float("nan"))
                    for sigs, ref in zip(corrected_sigs, norm_means)]
    if log2_trans:
        norm_signals = [[safe_log2(item) for item in group] for group in norm_signals]
    # Merge: build per-file [signal, label] rows with a header, bin them
    # by label with gh.bin_by_col, then concatenate the bins across files
    d_heads = ["Norm Signal", "Group Labels"]
    data = [[norm_signals[i], metadata[i]] for i in range(len(norm_signals))]
    data = [gh.transpose(*g) for g in data]
    data = [[d_heads] + g for g in data]
    data = [gh.bin_by_col(g, d_heads.index("Group Labels")) for g in data]
    # Drop the first row of each bin (presumably the header row that
    # bin_by_col carries into every bin -- TODO confirm)
    data = [{key : value[1:] for key, value in g.items()} for g in data]
    newdict = {}
    for subdict in data:
        for key, value in subdict.items():
            # Column 0 of the binned rows is the normalised signal
            if key not in newdict.keys():
                newdict[key] = gh.transpose(*value)[0]
            else:
                newdict[key] += gh.transpose(*value)[0]
    return [[key, value] for key, value in newdict.items()]

def grab_all_files(file_dirs, 
                   expr_signal = [700, "685Ex-720Em"],
                   load_signal = [800, "785Ex-820Em"],
                   df_meta = ["Group", "Condition", "Time"],
                   norm_string = " 0m",
                   gs_kwargs = dict(signal_column = "Signal",
                                    channel_column = "Channel"),
                   log2_trans = True):
    """
    Run read_norm_merge on every directory in file_dirs with the same
    settings.

    file_dirs -> iterable of directories, each merged on its own
    remaining arguments -> forwarded unchanged to read_norm_merge
    #####
    Returns a list with one read_norm_merge result per directory, in the
    order of file_dirs
    """
    shared_options = dict(expr_signal = expr_signal,
                          load_signal = load_signal,
                          df_meta = df_meta,
                          norm_string = norm_string,
                          gs_kwargs = gs_kwargs,
                          log2_trans = log2_trans)
    return [read_norm_merge(directory, **shared_options)
            for directory in file_dirs]


#
#
###################################################################

In [5]:
## Make the titration plots (one per antibody directory)
titr_stats = ["./stats/titration_4g10_holm",
              "./stats/titration_erk_holm"]
titr_figs = ["./figures/titration_4g10.pdf",
             "./figures/titration_erk.pdf"]

titr_ylims = [[-4,4], [-4,8]]

# glob order depends on the filesystem, but the lists above are matched by
# position. Sort (case-insensitively) so the directories line up, assuming
# their suffixes sort like the lists ("4g10" < "erk").
titr_dirs = sorted(glob.glob("./excel_sheets/titration_*"), key = str.lower)
titr_glabs = [dr.split("_")[-1] for dr in titr_dirs]

titr_data = {glab : grab_all_files(sorted(glob.glob(f"{dr}/*")))
             for glab, dr in zip(titr_glabs, titr_dirs)}
assert len(titr_data) == len(titr_stats) == len(titr_figs) == len(titr_ylims), \
    f"Expected {len(titr_stats)} titration datasets, found {list(titr_data)}"

for value, stats_path, fig_path, ylims in zip(titr_data.values(),
                                               titr_stats,
                                               titr_figs,
                                               titr_ylims):
    # Keep comparisons along one cell line (titr_groups) or down one
    # dose column (titr_xgroups). time_groups was passed here before, but
    # titration labels never contain those strings, so every within-line
    # comparison was being ignored.
    ignore = _logical_ignore_comps(value,
                                   group_strs = titr_groups,
                                   xgroup_strs = titr_xgroups)

    line_plot(value,
              ylims = ylims,
              colours = [colour_dict[grp] for grp in titr_groups],
              markers = markers,
              linestyles = ["dashdot" for _ in range(4)],
              xlabels = titr_xgroups,
              show_points = False,
              show_legend = True,
              group_labs = titr_groups,
              ignore_comps = ignore,
              statsfile = stats_path,
              figfile = fig_path,
              comp_type = "HolmSidak",
              ylabel = r"$\log_{2}$ Fold Change")
    


In [11]:
time_stats = ["./stats/timecourse_4g10_holm",
              "./stats/timecourse_erk_holm",
              "./stats/timecourse_lck_holm",
              "./stats/timecourse_plc_holm"]
time_figs = ["./figures/timecourse_4g10.pdf",
             "./figures/timecourse_erk.pdf",
             "./figures/timecourse_lck.pdf",
             "./figures/timecourse_plc.pdf"]

time_ylims = [[-2,4],
              [-2,8],
              [-2,4],
              [-2,6]]

# glob order depends on the filesystem, but the lists above are matched by
# position (and are alphabetical). Sort case-insensitively so each antibody
# directory lines up with its stats/figure/ylim entry.
time_dirs = sorted(glob.glob("./excel_sheets/timecourse_*"), key = str.lower)
time_glabs = [dr.split("_")[-1] for dr in time_dirs]

time_data = {glab : grab_all_files(sorted(glob.glob(f"{dr}/*/*")))
             for glab, dr in zip(time_glabs, time_dirs)}
assert len(time_data) == len(time_stats) == len(time_figs) == len(time_ylims), \
    f"Expected {len(time_stats)} timecourse datasets, found {list(time_data)}"

# One colour per line, keyed by the cell line (first word of each label)
time_colours = [colour_dict[grp.split(" ")[0]] for grp in time_groups]

for value, stats_path, fig_path, ylims in zip(time_data.values(),
                                               time_stats,
                                               time_figs,
                                               time_ylims):
    # Keep comparisons along one line or down one time point only
    ignore = _logical_ignore_comps(value,
                                   group_strs = time_groups,
                                   xgroup_strs = time_xgroups)
    line_plot(value,
              ylims = ylims,
              colours = time_colours,
              markers = markers,
              linestyles = ["solid" for _ in range(4)] + ["dotted" for _ in range(4)],
              xlabels = time_xgroups,
              show_points = False,
              show_legend = True,
              group_labs = time_groups,
              ignore_comps = ignore,
              statsfile = stats_path,
              figfile = fig_path,
              comp_type = "HolmSidak",
              ylabel = r"$\log_{2}$ Fold Change")
    

In [7]:
# Debug check: print the ignored pairs that involve a "20 uM 0m" group.
# NOTE: `ignore` is whatever the last loop iteration of an earlier plotting
# cell left behind, so the output depends on which cell ran last.
debug_label = r"20 $\mu$M 0m"  # raw string: "\m" is not a valid escape
for pair in ignore:
    if debug_label in pair[0][0] or debug_label in pair[1][0]:
        print(pair)

[['JE6 0.05% DMSO 2m', [2.4928959300132565, 2.019582131133139, 1.1864720943902385, 1.200172415798116]], ['JE6 20 $\\mu$M 0m', [0.07737048681912326, -0.08175610661061157, -0.601867767168747, 0.4234164232344189]]]
[['JE6 0.05% DMSO 2m', [2.4928959300132565, 2.019582131133139, 1.1864720943902385, 1.200172415798116]], ['J.PTPN22- 20 $\\mu$M 0m', [-0.08155721577194895, 0.07719234345448227, -0.3442493357372907, 0.27772565063908045]]]
[['JE6 0.05% DMSO 2m', [2.4928959300132565, 2.019582131133139, 1.1864720943902385, 1.200172415798116]], ['J.SHP1- 20 $\\mu$M 0m', [-0.1840212576755952, 0.16318176401879197, -0.16390082310485446, 0.14716557351125276]]]
[['JE6 0.05% DMSO 2m', [2.4928959300132565, 2.019582131133139, 1.1864720943902385, 1.200172415798116]], ['J.TCPTP- 20 $\\mu$M 0m', [-0.08263888298311667, 0.07816061404591497, -0.25829097183997096, 0.21899468018479962]]]
[['JE6 0.05% DMSO 5m', [2.9012498089958845, 2.7315564372603625, 2.169851766163232, 2.02610274670599]], ['JE6 20 $\\mu$M 0m', [0.07