diff --git a/.github/workflows/cdci.yml b/.github/workflows/cdci.yml index cc00de9..f7b6a3c 100644 --- a/.github/workflows/cdci.yml +++ b/.github/workflows/cdci.yml @@ -9,33 +9,32 @@ on: types: [published] jobs: - # format: - # runs-on: ubuntu-latest - # steps: - # - uses: actions/checkout@v4 - # - uses: psf/black@stable - # lint: - # name: Lint with ruff - # runs-on: ubuntu-latest - # steps: - # - uses: actions/checkout@v4 + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + lint: + name: Lint with ruff + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 - # - uses: actions/setup-python@v5 - # with: - # python-version: "3.11" - # - name: Install ruff - # run: | - # pip install ruff - # - name: Lint with ruff - # run: | - # # stop the build if there are Python syntax errors or undefined names - # ruff check . + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install ruff + run: | + pip install ruff + - name: Lint with ruff + run: | + ruff check . test: name: Test runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -51,7 +50,6 @@ jobs: pip install -e . # - name: Run tests # run: python -m pytest tests - publish: diff --git a/pyproject.toml b/pyproject.toml index 2edb025..16e5363 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dynamic = ["version"] description = "A Python package for plotting related to multimodal molecular data. Works with acore." license = { text = "GNU General Public License v3" } readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.9,<3.13" classifiers = [ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", @@ -27,12 +27,16 @@ dependencies = [ "beautifulsoup4", "requests", "dash", # from dash import html - # "networkx", + "networkx", "matplotlib", - # "cy-jupyterlab", - # "nltk", - # "webweb" - # "acore", + "kaleido", + "pyvis", + "wordcloud", + "cyjupyter", + "nltk", + "webweb", + "acore", + "dash-cytoscape", ] [project.optional-dependencies] @@ -64,5 +68,7 @@ Documentation = "https://analytics-core.readthedocs.io/" [build-system] -requires = ["setuptools"] +requires = ["setuptools", "setuptools_scm>=8"] build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] diff --git a/src/vuecore/Dendrogram.py b/src/vuecore/Dendrogram.py index 6515c0e..a0e292a 100644 --- a/src/vuecore/Dendrogram.py +++ b/src/vuecore/Dendrogram.py @@ -1,12 +1,24 @@ +from collections import OrderedDict + import numpy as np import scipy as scp -from collections import OrderedDict -import plotly.graph_objs as go -def plot_dendrogram(Z_dendrogram, cutoff_line=True, value=15, orientation='bottom', hang=30, hide_labels=False, labels=None, - colorscale=None, hovertext=None, color_threshold=None): + +def plot_dendrogram( + Z_dendrogram, + cutoff_line=True, + value=15, + orientation="bottom", + hang=30, + hide_labels=False, + labels=None, + colorscale=None, + hovertext=None, + color_threshold=None, +): """ - Modified version of Plotly _dendrogram.py that returns a dendrogram Plotly figure object with cutoff line. + Modified version of Plotly _dendrogram.py that + returns a dendrogram Plotly figure object with cutoff line. :param Z_dendrogram: Matrix of observations as array of arrays :type Z_dendrogram: ndarray @@ -36,25 +48,59 @@ def plot_dendrogram(Z_dendrogram, cutoff_line=True, value=15, orientation='botto figure = plot_dendrogram(dendro_tree, hang=0.9, cutoff_line=False) """ - dendrogram = Dendrogram(Z_dendrogram, orientation, hang, hide_labels, labels, colorscale, hovertext=hovertext, color_threshold=color_threshold) - - if cutoff_line == True: - dendrogram.layout.update({'shapes':[{'type':'line', - 'xref':'paper', - 'yref':'y', - 'x0':0, 'y0':value, - 'x1':1, 'y1':value, - 'line':{'color':'red'}}]}) + dendrogram = Dendrogram( + Z_dendrogram, + orientation, + hang, + hide_labels, + labels, + colorscale, + hovertext=hovertext, + color_threshold=color_threshold, + ) + + if cutoff_line: + dendrogram.layout.update( + { + "shapes": [ + { + "type": "line", + "xref": "paper", + "yref": "y", + "x0": 0, + "y0": value, + "x1": 1, + "y1": value, + "line": {"color": "red"}, + } + ] + } + ) figure = dict(data=dendrogram.data, layout=dendrogram.layout) - figure['layout']['template'] = 'plotly_white' + figure["layout"]["template"] = "plotly_white" return figure + class Dendrogram(object): """Refer to plot_dendrogram() for docstring.""" - def __init__(self, Z_dendrogram, orientation='bottom', hang=1, hide_labels=False, labels=None, colorscale=None, hovertext=None, - color_threshold=None, width=np.inf, height=np.inf, xaxis='xaxis', yaxis='yaxis'): + + def __init__( + self, + Z_dendrogram, + orientation="bottom", + hang=1, + hide_labels=False, + labels=None, + colorscale=None, + hovertext=None, + color_threshold=None, + width=np.inf, + height=np.inf, + xaxis="xaxis", + yaxis="yaxis", + ): self.orientation = orientation self.labels = labels self.xaxis = xaxis @@ -64,20 +110,19 @@ def __init__(self, Z_dendrogram, orientation='bottom', hang=1, hide_labels=False self.sign = {self.xaxis: 1, self.yaxis: 1} self.layout = {self.xaxis: {}, self.yaxis: {}} - if self.orientation in ['left', 'bottom']: + if self.orientation in ["left", "bottom"]: self.sign[self.xaxis] = 1 else: self.sign[self.xaxis] = -1 - if self.orientation in ['right', 'bottom']: + if self.orientation in ["right", "bottom"]: self.sign[self.yaxis] = 1 else: self.sign[self.yaxis] = -1 - (dd_traces, xvals, yvals, - ordered_labels, leaves) = self.get_dendrogram_traces(Z_dendrogram, hang, colorscale, - hovertext, - color_threshold) + (dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces( + Z_dendrogram, hang, colorscale, hovertext, color_threshold + ) self.labels = ordered_labels self.leaves = leaves @@ -93,9 +138,9 @@ def __init__(self, Z_dendrogram, orientation='bottom', hang=1, hide_labels=False if len(self.zero_vals) > len(yvals) + 1: l_border = int(min(self.zero_vals)) r_border = int(max(self.zero_vals)) - correct_leaves_pos = range(l_border, - r_border + 1, - int((r_border - l_border) / len(yvals))) + correct_leaves_pos = range( + l_border, r_border + 1, int((r_border - l_border) / len(yvals)) + ) self.zero_vals = [v for v in correct_leaves_pos] self.zero_vals.sort() @@ -113,26 +158,29 @@ def get_color_dict(self, colorscale): # These are the color codes returned for dendrograms # We're replacing them with nicer colors - d = {'r': 'red', - 'g': 'green', - 'b': 'blue', - 'c': 'cyan', - 'm': 'magenta', - 'y': 'yellow', - 'k': 'black', - 'w': 'white'} + d = { + "r": "red", + "g": "green", + "b": "blue", + "c": "cyan", + "m": "magenta", + "y": "yellow", + "k": "black", + "w": "white", + } default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0])) if colorscale is None: colorscale = [ - 'rgb(0,116,217)', # blue - 'rgb(35,205,205)', # cyan - 'rgb(61,153,112)', # green - 'rgb(40,35,35)', # black - 'rgb(133,20,75)', # magenta - 'rgb(255,65,54)', # red - 'rgb(255,255,255)', # white - 'rgb(255,220,0)'] # yellow + "rgb(0,116,217)", # blue + "rgb(35,205,205)", # cyan + "rgb(61,153,112)", # green + "rgb(40,35,35)", # black + "rgb(133,20,75)", # magenta + "rgb(255,65,54)", # red + "rgb(255,255,255)", # white + "rgb(255,220,0)", + ] # yellow for i in range(len(default_colors.keys())): k = list(default_colors.keys())[i] # PY3 won't index keys @@ -150,32 +198,34 @@ def set_axis_layout(self, axis_key, hide_labels): :return (dict): An axis_key dictionary with set parameters. """ axis_defaults = { - 'type': 'linear', - 'ticks': 'outside', - 'mirror': 'allticks', - 'rangemode': 'tozero', - 'showticklabels': True, - 'zeroline': False, - 'showgrid': False, - 'showline': True, - } + "type": "linear", + "ticks": "outside", + "mirror": "allticks", + "rangemode": "tozero", + "showticklabels": True, + "zeroline": False, + "showgrid": False, + "showline": True, + } if len(self.labels) != 0: axis_key_labels = self.xaxis - if self.orientation in ['left', 'right']: + if self.orientation in ["left", "right"]: axis_key_labels = self.yaxis if axis_key_labels not in self.layout: self.layout[axis_key_labels] = {} - self.layout[axis_key_labels]['tickvals'] = \ - [zv*self.sign[axis_key] for zv in self.zero_vals] - self.layout[axis_key_labels]['ticktext'] = self.labels - self.layout[axis_key_labels]['tickmode'] = 'array' + self.layout[axis_key_labels]["tickvals"] = [ + zv * self.sign[axis_key] for zv in self.zero_vals + ] + self.layout[axis_key_labels]["ticktext"] = self.labels + self.layout[axis_key_labels]["tickmode"] = "array" self.layout[axis_key].update(axis_defaults) - if hide_labels == True: - self.layout[axis_key].update({'showticklabels': False}) - else: pass + if hide_labels: + self.layout[axis_key].update({"showticklabels": False}) + else: + pass return self.layout[axis_key] @@ -191,24 +241,27 @@ def set_figure_layout(self, width, height, hide_labels): :type hide_labels: boolean :return: Plotly layout """ - self.layout.update({ - 'showlegend': False, - 'autosize': False, - 'hovermode': 'closest', - 'width': width, - 'height': height - }) + self.layout.update( + { + "showlegend": False, + "autosize": False, + "hovermode": "closest", + "width": width, + "height": height, + } + ) self.set_axis_layout(self.xaxis, hide_labels=hide_labels) self.set_axis_layout(self.yaxis, hide_labels=False) return self.layout - - def get_dendrogram_traces(self, Z_dendrogram, hang, colorscale, hovertext, color_threshold): + def get_dendrogram_traces( + self, Z_dendrogram, hang, colorscale, hovertext, color_threshold + ): """ Calculates all the elements needed for plotting a dendrogram. - + :param Z_dendrogram: Matrix of observations as array of arrays :type Z_dendrogram: ndarray :param hang: dendrogram distance of leaf lines @@ -218,32 +271,32 @@ def get_dendrogram_traces(self, Z_dendrogram, hang, colorscale, hovertext, color :param hovertext: List of hovertext for constituent traces of dendrogram :type hovertext: list :return (tuple): Contains all the traces in the following order: - + a. trace_list: List of Plotly trace objects for dendrogram tree b. icoord: All X points of the dendrogram tree as array of arrays with length 4 c. dcoord: All Y points of the dendrogram tree as array of arrays with length 4 d. ordered_labels: leaf labels in the order they are going to appear on the plot e. Z_dendrogram['leaves']: left-to-right traversal of the leaves """ - icoord = scp.array(Z_dendrogram['icoord']) - dcoord = scp.array(Z_dendrogram['dcoord']) - ordered_labels = scp.array(Z_dendrogram['ivl']) - color_list = scp.array(Z_dendrogram['color_list']) - colors = self.get_color_dict(colorscale) + icoord = scp.array(Z_dendrogram["icoord"]) + dcoord = scp.array(Z_dendrogram["dcoord"]) + ordered_labels = scp.array(Z_dendrogram["ivl"]) + # color_list = scp.array(Z_dendrogram["color_list"]) + # colors = self.get_color_dict(colorscale) trace_list = [] for i in range(len(icoord)): - if self.orientation in ['top', 'bottom']: + if self.orientation in ["top", "bottom"]: xs = icoord[i] else: xs = dcoord[i] - if self.orientation in ['top', 'bottom']: + if self.orientation in ["top", "bottom"]: ys = dcoord[i] else: ys = icoord[i] - color_key = color_list[i] + # color_key = color_list[i] # not used hovertext_label = None if hovertext: hovertext_label = hovertext[i] @@ -264,28 +317,31 @@ def get_dendrogram_traces(self, Z_dendrogram, hang, colorscale, hovertext, color y_coord.append(y) trace = dict( - type='scattergl', + type="scattergl", x=np.multiply(self.sign[self.xaxis], x_coord), y=np.multiply(self.sign[self.yaxis], y_coord), - mode='lines', - marker=dict(color='rgb(40,35,35)'), - line=dict(color='rgb(40,35,35)', width=1), #dict(color=colors[color_key]), + mode="lines", + marker=dict(color="rgb(40,35,35)"), + line=dict( + color="rgb(40,35,35)", width=1 + ), # dict(color=colors[color_key]), text=hovertext_label, - hoverinfo='text') + hoverinfo="text", + ) try: x_index = int(self.xaxis[-1]) except ValueError: - x_index = '' + x_index = "" try: y_index = int(self.yaxis[-1]) except ValueError: - y_index = '' + y_index = "" - trace['xaxis'] = 'x' + x_index - trace['yaxis'] = 'y' + y_index + trace["xaxis"] = "x" + x_index + trace["yaxis"] = "y" + y_index trace_list.append(trace) - return trace_list, icoord, dcoord, ordered_labels, Z_dendrogram['leaves'] + return trace_list, icoord, dcoord, ordered_labels, Z_dendrogram["leaves"] diff --git a/src/vuecore/color_list.py b/src/vuecore/color_list.py index 3f1498f..9ed0bb9 100644 --- a/src/vuecore/color_list.py +++ b/src/vuecore/color_list.py @@ -52,25 +52,26 @@ 190 190 190 grey """ + def make_color_dict(colors=COLORS): """Returns a dictionary that maps color names to RGB strings. The format of RGB strings is '#RRGGBB'. """ # regular expressions to match numbers and color names - number = r'(\d+)' - space = r'[ \t]*' - name = r'([ \w]+)' + number = r"(\d+)" + space = r"[ \t]*" + name = r"([ \w]+)" pattern = space + (number + space) * 3 + name prog = re.compile(pattern) # read the file d = dict() - for line in colors.split('\n'): + for line in colors.split("\n"): ro = prog.match(line) if ro: r, g, b, name = ro.groups() - rgb = '#%02x%02x%02x' % (int(r), int(g), int(b)) + rgb = "#%02x%02x%02x" % (int(r), int(g), int(b)) d[name] = rgb return d @@ -108,7 +109,7 @@ def invert_dict(d): return inv -if __name__ == '__main__': +if __name__ == "__main__": color_dict = make_color_dict() for name, rgb in color_dict.items(): print(name, rgb) diff --git a/src/vuecore/linkers.py b/src/vuecore/linkers.py new file mode 100644 index 0000000..cefef38 --- /dev/null +++ b/src/vuecore/linkers.py @@ -0,0 +1,20 @@ +from io import StringIO + +import requests + + +def get_clustergrammer_link(net, filename=None): + clustergrammer_url = "http://amp.pharm.mssm.edu/clustergrammer/matrix_upload/" + if filename is None: + file_string = net.write_matrix_to_tsv() + file_obj = StringIO(file_string) + if "filename" not in net.dat or net.dat["filename"] is None: + fake_filename = "Network.txt" + else: + fake_filename = net.dat["filename"] + r = requests.post(clustergrammer_url, files={"file": (fake_filename, file_obj)}) + else: + file_obj = open(filename, "r") + r = requests.post(clustergrammer_url, files={"file": file_obj}) + link = r.text + return link diff --git a/src/vuecore/translate.py b/src/vuecore/translate.py new file mode 100644 index 0000000..22d7142 --- /dev/null +++ b/src/vuecore/translate.py @@ -0,0 +1,20 @@ +import base64 +import io + +from dash import html + + +def hex2rgb(color): + _hex = color.lstrip("#") + rgb = tuple(int(_hex[i : i + 2], 16) for i in (0, 2, 4)) + rgba = rgb + (0.6,) + return rgba + + +def mpl_to_html_image(plot, width=800): + buf = io.BytesIO() + plot.savefig(buf, format="png") + data = base64.b64encode(buf.getbuffer()).decode("utf8") + figure = html.Img(src="data:image/png;base64,{}".format(data), width=f"{width}") + + return figure diff --git a/src/vuecore/utils.py b/src/vuecore/utils.py index 8b89e8b..d714eb7 100644 --- a/src/vuecore/utils.py +++ b/src/vuecore/utils.py @@ -1,17 +1,17 @@ import random -from Bio import Entrez, Medline from collections import defaultdict -import pandas as pd -import io -import base64 +from urllib import error + import bs4 as bs -import dash_html_components as html -import requests import networkx as nx +import pandas as pd +import requests +from Bio import Entrez, Medline +from dash import html from networkx.readwrite import json_graph -from urllib import error -Entrez.email = 'alberto.santos@cpr.ku.dk' # TODO: This should probably be changed to the email of the person installing ckg? +# TODO: This should probably be changed to the email of the person installing ckg? +Entrez.email = "alberto.santos@cpr.ku.dk" def check_columns(df, cols): @@ -21,15 +21,6 @@ def check_columns(df, cols): return True -def mpl_to_html_image(plot, width=800): - buf = io.BytesIO() - plot.savefig(buf, format="png") - data = base64.b64encode(buf.getbuffer()).decode("utf8") - figure = html.Img(src="data:image/png;base64,{}".format(data), width="800") - - return figure - - def generate_html(network): """ This method gets the data structures supporting the nodes, edges, @@ -53,8 +44,19 @@ def generate_html(network): template = network.template nodes, edges, height, width, options = network.get_network_data() - network.html = template.render(height=height, width=width, nodes=nodes, edges=edges, options=options, use_DOT=network.use_DOT, dot_lang=network.dot_lang, - widget=network.widget, bgcolor=network.bgcolor, conf=network.conf, tooltip_link=use_link_template) + network.html = template.render( + height=height, + width=width, + nodes=nodes, + edges=edges, + options=options, + use_DOT=network.use_DOT, + dot_lang=network.dot_lang, + widget=network.widget, + bgcolor=network.bgcolor, + conf=network.conf, + tooltip_link=use_link_template, + ) def append_to_list(mylist, myappend): @@ -64,7 +66,7 @@ def append_to_list(mylist, myappend): mylist.append(myappend) -def neo4j_path_to_networkx(paths, key='path'): +def neo4j_path_to_networkx(paths, key="path"): nodes = set() rels = set() for path in paths: @@ -72,10 +74,10 @@ def neo4j_path_to_networkx(paths, key='path'): relationships = path[key] if len(relationships) == 3: node1, rel, node2 = relationships - if 'name' in node1: - source = node1['name'] - if 'name' in node2: - target = node2['name'] + if "name" in node1: + source = node1["name"] + if "name" in node2: + target = node2["name"] nodes.update([source, target]) rels.add((source, target, rel)) @@ -90,20 +92,20 @@ def neo4j_path_to_networkx(paths, key='path'): def neo4j_schema_to_networkx(schema): nodes = set() rels = set() - if 'relationships' in schema[0]: - relationships = schema[0]['relationships'] + if "relationships" in schema[0]: + relationships = schema[0]["relationships"] for node1, rel, node2 in relationships: - if 'name' in node1: - source = node1['name'] - if 'name' in node2: - target = node2['name'] + if "name" in node1: + source = node1["name"] + if "name" in node2: + target = node2["name"] nodes.update([source, target]) rels.add((source, target, rel)) G = nx.Graph() G.add_nodes_from(nodes) colors = dict(zip(nodes, get_hex_colors(len(nodes)))) - nx.set_node_attributes(G, colors, 'color') + nx.set_node_attributes(G, colors, "color") for s, t, label in rels: G.add_edge(s, t, label=label) @@ -112,8 +114,8 @@ def neo4j_schema_to_networkx(schema): def networkx_to_cytoscape(graph): cy_graph = json_graph.cytoscape_data(graph) - cy_nodes = cy_graph['elements']['nodes'] - cy_edges = cy_graph['elements']['edges'] + cy_nodes = cy_graph["elements"]["nodes"] + cy_edges = cy_graph["elements"]["edges"] cy_elements = cy_nodes cy_elements.extend(cy_edges) mouseover_node = dict(graph.nodes(data=True)) @@ -130,17 +132,17 @@ def networkx_to_neo4j_document(graph): seen_rels = set() for n, attr in graph.nodes(data=True): rels = defaultdict(list) - attr.update({'id': n}) + attr.update({"id": n}) for r in graph[n]: edge = graph[n][r] - edge.update({'id': r}) - if 'type' in edge: - rel_type = edge['type'] - if 'type' in graph.nodes()[r]: - edge['type'] = graph.nodes()[r]['type'] - if not (n, r, edge['type']) in seen_rels: + edge.update({"id": r}) + if "type" in edge: + rel_type = edge["type"] + if "type" in graph.nodes()[r]: + edge["type"] = graph.nodes()[r]["type"] + if (n, r, edge["type"]) not in seen_rels: rels[rel_type].append(edge) - seen_rels.update({(n, r, edge['type']), (r, n, edge['type'])}) + seen_rels.update({(n, r, edge["type"]), (r, n, edge["type"])}) attr.update(rels) graph_json.append(attr) @@ -149,7 +151,7 @@ def networkx_to_neo4j_document(graph): def json_network_to_gml(graph_json, path): graph = json_network_to_networkx(graph_json) - with open(path, 'wb') as out: + with open(path, "wb") as out: nx.write_gml(graph, out) @@ -159,7 +161,7 @@ def networkx_to_graphml(graph, path): def json_network_to_graphml(graph_json, path): graph = json_network_to_networkx(graph_json) - with open(path, 'wb') as out: + with open(path, "wb") as out: nx.write_graphml(graph, out) @@ -174,18 +176,18 @@ def get_clustergrammer_link(net, filename=None): from StringIO import StringIO except ImportError: from io import StringIO - clustergrammer_url = 'http://amp.pharm.mssm.edu/clustergrammer/matrix_upload/' + clustergrammer_url = "http://amp.pharm.mssm.edu/clustergrammer/matrix_upload/" if filename is None: file_string = net.write_matrix_to_tsv() file_obj = StringIO(file_string) - if 'filename' not in net.dat or net.dat['filename'] is None: - fake_filename = 'Network.txt' + if "filename" not in net.dat or net.dat["filename"] is None: + fake_filename = "Network.txt" else: - fake_filename = net.dat['filename'] - r = requests.post(clustergrammer_url, files={'file': (fake_filename, file_obj)}) + fake_filename = net.dat["filename"] + r = requests.post(clustergrammer_url, files={"file": (fake_filename, file_obj)}) else: - file_obj = open(filename, 'r') - r = requests.post(clustergrammer_url, files={'file': file_obj}) + file_obj = open(filename, "r") + r = requests.post(clustergrammer_url, files={"file": file_obj}) link = r.text return link @@ -199,25 +201,65 @@ def generator_to_dict(genvar): def parse_html(html_snippet): - html_parsed = bs.BeautifulSoup(html_snippet, 'html.parser') + html_parsed = bs.BeautifulSoup(html_snippet, "html.parser") return html_parsed def convert_html_to_dash(el, style=None): - ALLOWED_CST = {'div', 'span', 'a', 'hr', 'br', 'p', 'b', 'i', 'u', 's', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'ol', 'ul', 'li', - 'em', 'strong', 'cite', 'tt', 'pre', 'small', 'big', 'center', 'blockquote', 'address', 'font', 'img', - 'table', 'tr', 'td', 'caption', 'th', 'textarea', 'option'} + ALLOWED_CST = { + "div", + "span", + "a", + "hr", + "br", + "p", + "b", + "i", + "u", + "s", + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", + "ol", + "ul", + "li", + "em", + "strong", + "cite", + "tt", + "pre", + "small", + "big", + "center", + "blockquote", + "address", + "font", + "img", + "table", + "tr", + "td", + "caption", + "th", + "textarea", + "option", + } def __extract_style(el): if not el.attrs.get("style"): return None - return {k.strip(): v.strip() for k, v in [x.split(": ") for x in el.attrs["style"].split(";") if x != '']} + return { + k.strip(): v.strip() + for k, v in [x.split(": ") for x in el.attrs["style"].split(";") if x != ""] + } - if type(el) is str: + if isinstance(el, str): return convert_html_to_dash(parse_html(el)) - if type(el) == bs.element.NavigableString: + if isinstance(el, bs.element.NavigableString): return str(el) else: name = el.name @@ -228,13 +270,6 @@ def __extract_style(el): return getattr(html, name.title())(contents, style=style) -def hex2rgb(color): - hex = color.lstrip('#') - rgb = tuple(int(hex[i:i+2], 16) for i in (0, 2, 4)) - rgba = rgb + (0.6,) - return rgba - - def get_rgb_colors(n): colors = [] r = int(random.random() * 256) @@ -264,11 +299,21 @@ def get_hex_colors(n): def getMedlineAbstracts(idList): - fields = {"TI": "title", "AU": "authors", "JT": "journal", "DP": "date", "MH": "keywords", "AB": "abstract", "PMID": "PMID"} + fields = { + "TI": "title", + "AU": "authors", + "JT": "journal", + "DP": "date", + "MH": "keywords", + "AB": "abstract", + "PMID": "PMID", + } pubmedUrl = "https://www.ncbi.nlm.nih.gov/pubmed/" abstracts = pd.DataFrame() try: - handle = Entrez.efetch(db="pubmed", id=idList, rettype="medline", retmode="json") + handle = Entrez.efetch( + db="pubmed", id=idList, rettype="medline", retmode="json" + ) records = Medline.parse(handle) results = [] for record in records: diff --git a/src/vuecore/viz.py b/src/vuecore/viz.py index 58511f7..2be89bf 100644 --- a/src/vuecore/viz.py +++ b/src/vuecore/viz.py @@ -1,81 +1,107 @@ -import os -import numpy as np -import pandas as pd +"""To be broken down into modules.""" + import ast +import json +import math +import os from collections import defaultdict -import dash_core_components as dcc -import dash_html_components as html -import matplotlib + +import dash_cytoscape as cyto import matplotlib.pyplot as plt +import networkx as nx +import nltk +import numpy as np +import pandas as pd import plotly -import plotly.tools as tls -import plotly.graph_objs as go -import plotly.figure_factory as FF import plotly.express as px -import math -import dash_table -import plotly.subplots as tools +import plotly.figure_factory as FF +import plotly.graph_objs as go import plotly.io as pio -from scipy.spatial.distance import pdist, squareform -from scipy.stats import zscore -import networkx as nx +import plotly.subplots as tools +import plotly.tools as tls +from acore import network_analysis, wgcna_analysis from cyjupyter import Cytoscape -from pyvis.network import Network as visnet -from webweb import Web +from dash import dash_table, dcc, html from networkx.readwrite import json_graph -import json -from ckg.analytics_core import utils -from ckg.analytics_core.analytics import analytics -from wordcloud import WordCloud, STOPWORDS from nltk.corpus import stopwords -import nltk +from pyvis.network import Network as visnet +from scipy import stats +from scipy.spatial.distance import pdist, squareform +from scipy.stats import zscore +from webweb import Web +from wordcloud import STOPWORDS, WordCloud +from . import dendrogram, utils, wgcna +from .linkers import get_clustergrammer_link +from .translate import hex2rgb, mpl_to_html_image -from ckg.analytics_core.analytics import wgcnaAnalysis -from ckg.analytics_core.viz import wgcnaFigures, Dendrogram -import dash_cytoscape as cyto +# matplotlib.use("Agg") -matplotlib.use("Agg") - -def getPlotTraces(data, key='full', type='lines', div_factor=float(10^10000), horizontal=False): +def getPlotTraces( + data, key="full", type="lines", div_factor=float(10 ^ 10000), horizontal=False +): """ This function returns traces for different kinds of plots. - :param data: Pandas DataFrame with one variable as data.index (i.e. 'x') and all others as columns (i.e. 'y'). + :param data: Pandas DataFrame with one variable as data.index (i.e. 'x') + and all others as columns (i.e. 'y'). :param str type: 'lines', 'scaled markers', 'bars'. :param float div_factor: relative size of the markers. :param bool horizontal: bar orientation. :return: list of traces. - Exmaple 1:: + Example 1:: result = getPlotTraces(data, key='full', type = 'lines', horizontal=False) Example 2:: - result = getPlotTraces(data, key='full', type = 'scaled markers', div_factor=float(10^3000), horizontal=True) + result = getPlotTraces(data, key='full', type = 'scaled markers', + div_factor=float(10^3000), horizontal=True) """ - if type == 'lines': - traces = [go.Scattergl(x=data.index, y=data[col], name = col+' '+key, mode='markers+lines') for col in data.columns] - - elif type == 'scaled markers': - traces = [go.Scattergl(x = data.index, y = data[col], name = col+' '+key, mode = 'markers', marker = dict(size = data[col].values/div_factor, sizemode = 'area')) for col in data.columns] - - elif type == 'bars': - traces = [go.Bar(x = data.index, y = data[col], orientation = 'v', name = col+' '+key) for col in data.columns] - if horizontal == True: - traces = [go.Bar(x = data[col], y = data.index, orientation = 'h', name = col+' '+key) for col in data.columns] + if type == "lines": + traces = [ + go.Scattergl( + x=data.index, y=data[col], name=col + " " + key, mode="markers+lines" + ) + for col in data.columns + ] + + elif type == "scaled markers": + traces = [ + go.Scattergl( + x=data.index, + y=data[col], + name=col + " " + key, + mode="markers", + marker=dict(size=data[col].values / div_factor, sizemode="area"), + ) + for col in data.columns + ] + + elif type == "bars": + traces = [ + go.Bar(x=data.index, y=data[col], orientation="v", name=col + " " + key) + for col in data.columns + ] + if horizontal: + traces = [ + go.Bar(x=data[col], y=data.index, orientation="h", name=col + " " + key) + for col in data.columns + ] else: - return 'Option not found' + return "Option not found" return traces def get_markdown(text, args={}): """ - Converts a given text into a Dash Markdown component. It includes a syntax for things like bold text and italics, links, inline code snippets, lists, quotes, and more. + Converts a given text into a Dash Markdown component. It includes a syntax for things + like bold text and italics, links, inline code snippets, lists, quotes, and more. + For more information visit https://dash.plot.ly/dash-core-components/markdown. :param str text: markdown string (or array of strings) that adhreres to the CommonMark spec. @@ -96,18 +122,28 @@ def get_pieplot(data, identifier, args): :param dict args: see below. :Arguments: * **valueCol** (str) -- name of the column with the values to be plotted. - * **textCol** (str) -- name of the column containing information for the hoverinfo parameter. + * **textCol** (str) -- name of the column \ + containing information for the hoverinfo parameter. * **height** (str) -- height of the plot. * **width** (str) -- width of the plot. :return: Pieplot figure within the
. """ figure = {} - figure['data'] = [] - figure['data'].append(go.Pie(labels=data.index, values=data[args['valueCol']], hovertext=data[args['textCol']], hoverinfo='label+text+percent')) - figure["layout"] = go.Layout(height=args['height'], - width=args['width'], - annotations=[dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white') + figure["data"] = [] + figure["data"].append( + go.Pie( + labels=data.index, + values=data[args["valueCol"]], + hovertext=data[args["textCol"]], + hoverinfo="label+text+percent", + ) + ) + figure["layout"] = go.Layout( + height=args["height"], + width=args["width"], + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) return dcc.Graph(id=identifier, figure=figure) @@ -125,7 +161,7 @@ def get_distplot(data, identifier, args): df = data.copy() graphs = [] - df = df.set_index(args['group']) + df = df.set_index(args["group"]) df = df.transpose() df = df.dropna() @@ -135,9 +171,17 @@ def get_distplot(data, identifier, args): hist_data.append(df.loc[i, c].values.tolist()) group_labels = df.columns.unique().tolist() # Create distplot with custom bin_size - fig = FF.create_distplot(hist_data, group_labels, bin_size=.5, curve_type='normal') - fig['layout'].update(height=600, width=1000, title='Distribution plot ' + i, annotations=[dict(xref='paper', yref='paper', showarrow=False, text='')], template='plotly_white') - graphs.append(dcc.Graph(id=identifier+"_"+i, figure=fig)) + fig = FF.create_distplot( + hist_data, group_labels, bin_size=0.5, curve_type="normal" + ) + fig["layout"].update( + height=600, + width=1000, + title="Distribution plot " + i, + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) + graphs.append(dcc.Graph(id=identifier + "_" + i, figure=fig)) return graphs @@ -146,7 +190,8 @@ def get_boxplot_grid(data, identifier, args): """ This function plots a boxplot in a grid based on column values. - :param data: pandas DataFrame with columns: 'x' values and 'y' values to plot, 'color' and 'facet' (color and facet can be the same). + :param data: pandas DataFrame with columns: 'x' values and 'y' values to plot, + 'color' and 'facet' (color and facet can be the same). :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: @@ -161,31 +206,60 @@ def get_boxplot_grid(data, identifier, args): Example:: - result = get_boxplot_grid(data, identifier='Boxplot', args:{"title":"Boxplot", 'x':'sample', 'y':'identifier', 'color':'group', 'facet':'qc_class', 'axis':'cols'}) + result = get_boxplot_grid(data, + identifier='Boxplot', + args:{"title":"Boxplot", + 'x':'sample', + 'y':'identifier', + 'color':'group', + 'facet':'qc_class', + 'axis':'cols'} + ) """ fig = {} - if 'x' in args and 'y' in args and 'color' in args: - if 'axis' not in args: - args['axis'] = 'cols' - if 'facet' not in args: + if "x" in args and "y" in args and "color" in args: + if "axis" not in args: + args["axis"] = "cols" + if "facet" not in args: args["facet"] = None - if 'width' not in args: - args['width'] = 2500 - if 'title' not in args: - args['title'] = 'Boxplot' - if 'colors' in args: - color_map = args['colors'] + if "width" not in args: + args["width"] = 2500 + if "title" not in args: + args["title"] = "Boxplot" + if "colors" in args: + color_map = args["colors"] else: color_map = {} - - if args['axis'] == 'rows': - fig = px.box(data, x=args["x"], y=args["y"], color=args['color'], color_discrete_map=color_map, points="all", facet_row=args["facet"], width=args['width']) + + if args["axis"] == "rows": + fig = px.box( + data, + x=args["x"], + y=args["y"], + color=args["color"], + color_discrete_map=color_map, + points="all", + facet_row=args["facet"], + width=args["width"], + ) else: - fig = px.box(data, x=args["x"], y=args["y"], color=args['color'], color_discrete_map=color_map, points="all", facet_col=args["facet"], width=args['width']) - fig.update_xaxes(type='category') - fig.update_layout(annotations=[dict(xref='paper', yref='paper', showarrow=False, text='')], template='plotly_white') + fig = px.box( + data, + x=args["x"], + y=args["y"], + color=args["color"], + color_discrete_map=color_map, + points="all", + facet_col=args["facet"], + width=args["width"], + ) + fig.update_xaxes(type="category") + fig.update_layout( + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) else: - fig = get_markdown(text='Missing arguments. Please, provide: x, y, color') + fig = get_markdown(text="Missing arguments. Please, provide: x, y, color") return dcc.Graph(id=identifier, figure=fig) @@ -194,7 +268,8 @@ def get_barplot(data, identifier, args): """ This function plots a simple barplot. - :param data: pandas DataFrame with three columns: 'name' of the bars, 'x' values and 'y' values to plot. + :param data: pandas DataFrame with three columns: + 'name' of the bars, 'x' values and 'y' values to plot. :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: @@ -211,69 +286,71 @@ def get_barplot(data, identifier, args): """ figure = {} figure["data"] = [] - if 'title' not in args: - args['title'] = 'Barplot {} - {}'.format(args['x'], args['y']) - if 'x_title' not in args: - args['x_title'] = args['x'] - if 'y_title' not in args: - args['y_title'] = args['y'] - if 'height' not in args: - args['height'] = 600 - if 'width' not in args: - args['width'] = 600 + if "title" not in args: + args["title"] = "Barplot {} - {}".format(args["x"], args["y"]) + if "x_title" not in args: + args["x_title"] = args["x"] + if "y_title" not in args: + args["y_title"] = args["y"] + if "height" not in args: + args["height"] = 600 + if "width" not in args: + args["width"] = 600 if "group" in args: for g in data[args["group"]].unique(): color = None - if 'colors' in args: - if g in args['colors']: - color = args['colors'][g] + if "colors" in args: + if g in args["colors"]: + color = args["colors"][g] errors = [] - if 'errors' in args: - errors = data.loc[data[args["group"]] == g, args['errors']] + if "errors" in args: + errors = data.loc[data[args["group"]] == g, args["errors"]] - if 'orientation' in args: + if "orientation" in args: trace = go.Bar( - x=data.loc[data[args["group"]] == g, args['x']], - y=data.loc[data[args["group"]] == g, args['y']], - error_y=dict(type='data', array=errors), - name=g, - marker=dict(color=color), - orientation=args['orientation'] - ) + x=data.loc[data[args["group"]] == g, args["x"]], + y=data.loc[data[args["group"]] == g, args["y"]], + error_y=dict(type="data", array=errors), + name=g, + marker=dict(color=color), + orientation=args["orientation"], + ) else: trace = go.Bar( - x=data.loc[data[args["group"]] == g, args['x']], # assign x as the dataframe column 'x' - y=data.loc[data[args["group"]] == g, args['y']], - error_y=dict(type='data', array=errors), - name=g, - marker=dict(color=color), - ) + x=data.loc[ + data[args["group"]] == g, args["x"] + ], # assign x as the dataframe column 'x' + y=data.loc[data[args["group"]] == g, args["y"]], + error_y=dict(type="data", array=errors), + name=g, + marker=dict(color=color), + ) figure["data"].append(trace) else: - if 'orientation' in args: + if "orientation" in args: figure["data"].append( - go.Bar( - x=data[args['x']], - y=data[args['y']], - orientation=args['orientation'] - ) - ) + go.Bar( + x=data[args["x"]], + y=data[args["y"]], + orientation=args["orientation"], + ) + ) else: figure["data"].append( - go.Bar( - x=data[args['x']], - y=data[args['y']], - ) - ) + go.Bar( + x=data[args["x"]], + y=data[args["y"]], + ) + ) figure["layout"] = go.Layout( - title=args['title'], - xaxis={"title": args["x_title"], "type": "category"}, - yaxis={"title": args["y_title"]}, - height=args['height'], - width=args['width'], - annotations=[dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white' - ) + title=args["title"], + xaxis={"title": args["x_title"], "type": "category"}, + yaxis={"title": args["y_title"]}, + height=args["height"], + width=args["width"], + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) return dcc.Graph(id=identifier, figure=figure) @@ -284,7 +361,7 @@ def get_histogram(data, identifier, args): :param data: pandas dataframe with at least values to be plotted. :param str identifier: id used to identify the div where the figure will be generated. - :param ditc args: see below. + :param dict args: see below. :Arguments: * **x** (str) -- name of the column containing values to plot in the x axis. * **y** (str) -- name of the column containing values to plot in the y axis (if used). @@ -298,72 +375,106 @@ def get_histogram(data, identifier, args): Example:: - result = get_histogram(data, identifier='histogram', args={'x':'a', 'color':'group', 'facet_row':'sample', 'title':'Facet Grid Plot'}) + result = get_histogram(data, + identifier='histogram', + args={'x':'a', + 'color':'group', + 'facet_row':'sample', + 'title':'Facet Grid Plot'} + ) """ figure = None - if 'x' in args and args['x'] in data: - if 'y' not in args: - args['y'] = None - elif args['y'] not in data: - args['y'] = None - if 'color' not in args: - args['color'] = None - elif args['color'] not in data: - args['color'] = None - if 'facet_row' not in args: - args['facet_row'] = None - elif args['facet_row'] not in data: - args['facet_row'] = None - if 'facet_col' not in args: - args['facet_col'] = None - elif args['facet_col'] not in data: - args['facet_col'] = None - if 'height' not in args: - args['height'] = 800 - if 'width' not in args: - args['width'] = None - if 'title' not in args: - args['title'] = None - - figure = px.histogram(data, x=args['x'], y=args['y'], color=args['color'], facet_row=args['facet_row'], facet_col=args['facet_col'], height=args['height'], width=args['width']) + if "x" in args and args["x"] in data: + if "y" not in args: + args["y"] = None + elif args["y"] not in data: + args["y"] = None + if "color" not in args: + args["color"] = None + elif args["color"] not in data: + args["color"] = None + if "facet_row" not in args: + args["facet_row"] = None + elif args["facet_row"] not in data: + args["facet_row"] = None + if "facet_col" not in args: + args["facet_col"] = None + elif args["facet_col"] not in data: + args["facet_col"] = None + if "height" not in args: + args["height"] = 800 + if "width" not in args: + args["width"] = None + if "title" not in args: + args["title"] = None + + figure = px.histogram( + data, + x=args["x"], + y=args["y"], + color=args["color"], + facet_row=args["facet_row"], + facet_col=args["facet_col"], + height=args["height"], + width=args["width"], + ) return dcc.Graph(id=identifier, figure=figure) -##ToDo + +# ToDo def get_facet_grid_plot(data, identifier, args): """ - This function plots a scatterplot matrix where we can plot one variable against another to form a regular scatter plot, and we can pick a third faceting variable - to form panels along the columns to segment the data even further, forming a bunch of vertical panels. For more information visit https://plot.ly/python/facet-trellis/. + This function plots a scatterplot matrix where we can plot one variable against another + to form a regular scatter plot, and we can pick a third faceting variable + to form panels along the columns to segment the data even further, + forming a bunch of vertical panels. - :param data: pandas dataframe with format: 'group', 'name', 'type', and 'x' and 'y' values to be plotted. + For more information visit https://plot.ly/python/facet-trellis/. + + :param data: pandas dataframe with format: + 'group', 'name', 'type', and 'x' and 'y' values to be plotted. :param str identifier: id used to identify the div where the figure will be generated. - :param ditc args: see below. + :param dict args: see below. :Arguments: * **x** (str) -- name of the column containing values to plot in the x axis. * **y** (str) -- name of the column containing values to plot in the y axis. * **group** (str) -- name of the column containing the group. * **class** (str) -- name of the column to be used as 'facet' column. - * **plot_type** (str) -- decides the type of plot to appear in the facet grid. The options are 'scatter', 'scattergl', 'histogram', 'bar', and 'box'. + * **plot_type** (str) -- decides the type of plot to appear in the facet grid. \ + The options are 'scatter', 'scattergl', 'histogram', 'bar', and 'box'. * **title** (str) -- plot title. :return: facet grid figure within the
. Example:: - result = get_facet_grid_plot(data, identifier='facet_grid', args={'x':'a', 'y':'b', 'group':'group', 'class':'type', 'plot_type':'bar', 'title':'Facet Grid Plot'}) - """ - figure = FF.create_facet_grid(data, - x=args['x'], - y=args['y'], - marker={'opacity': 1.}, - facet_col=args['class'], - color_name=args['group'], - color_is_cat=True, - trace_type=args['plot_type']) - figure['layout'] = dict(title=args['title'].title(), - paper_bgcolor=None, - legend=None, - annotations=[dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white') + result = get_facet_grid_plot(data, + identifier='facet_grid', + args={'x':'a', + 'y':'b', + 'group':'group', + 'class':'type', + 'plot_type':'bar', + 'title':'Facet Grid Plot'} + ) + """ + figure = FF.create_facet_grid( + data, + x=args["x"], + y=args["y"], + marker={"opacity": 1.0}, + facet_col=args["class"], + color_name=args["group"], + color_is_cat=True, + trace_type=args["plot_type"], + ) + figure["layout"] = dict( + title=args["title"].title(), + paper_bgcolor=None, + legend=None, + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) return dcc.Graph(id=identifier, figure=figure) @@ -372,24 +483,39 @@ def get_ranking_plot(data, identifier, args): """ Creates abundance multiplots (one per sample group). - :param data: long-format pandas dataframe with group as index, 'name' (protein identifiers) and 'y' (LFQ intensities) as columns. + :param data: long-format pandas dataframe with group as index, + 'name' (protein identifiers) and 'y' (LFQ intensities) as columns. :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below :Arguments: * **group** (str) -- name of the column containing the group. - * **index** (bool) -- set to True when multi samples per group. Calculates the mean intensity for each protein in each group. + * **index** (bool) -- set to True when multi samples per group. \ + Calculates the mean intensity for each protein in each group. * **x_title** (str) -- title of plot x axis. * **y_title** (str) -- title of plot y axis. * **title** (str) -- plot title. * **width** (int) -- plot width. * **height** (int) -- plot height. - * **annotations** (dict, optional) -- dictionary where data points names are the keys and descriptions are the values. + * **annotations** (dict, optional) -- dictionary where data points names are \ + the keys and descriptions are the values. :return: multi abundance plot figure within the
. Example:: - result = get_ranking_plot(data, identifier='ranking', args={'group':'group', 'index':'', 'x_title':'x_axis', 'y_title':'y_axis', \ - 'title':'Ranking Plot', 'width':100, 'height':150, 'annotations':{'GPT~P24298': 'liver disease', 'CP~P00450': 'Wilson disease'}}) + result = get_ranking_plot(data, + identifier='ranking', + args={'group':'group', + 'index':'', + 'x_title':'x_axis', + 'y_title':'y_axis', + 'title':'Ranking Plot', + 'width':100, + 'height':150, + 'annotations':{ + 'GPT~P24298': 'liver disease', + 'CP~P00450': 'Wilson disease'} + } + ) """ # data['y'] = data['y'].rpow(2) # data['y'] = np.log10(data['y']) @@ -397,42 +523,52 @@ def get_ranking_plot(data, identifier, args): num_cols = 3 fig = {} layouts = [] - if 'index' in args and args['index']: + if "index" in args and args["index"]: num_groups = len(data.index.unique()) - num_rows = math.ceil(num_groups/num_cols) - fig = tools.make_subplots(rows=num_rows, cols=num_cols, shared_yaxes=True, print_grid=False) + num_rows = math.ceil(num_groups / num_cols) + fig = tools.make_subplots( + rows=num_rows, cols=num_cols, shared_yaxes=True, print_grid=False + ) r = 1 c = 1 - range_y = [data['y'].min(), data['y'].max()+1] + range_y = [data["y"].min(), data["y"].max() + 1] i = 0 for index in data.index.unique(): - gdata = data.loc[index, :].dropna().groupby('name', as_index=False).mean().sort_values(by='y', ascending=False) + gdata = ( + data.loc[index, :] + .dropna() + .groupby("name", as_index=False) + .mean() + .sort_values(by="y", ascending=False) + ) gdata = gdata.reset_index().reset_index() - cols = ['x', 'group', 'name', 'y'] + cols = ["x", "group", "name", "y"] cols.extend(gdata.columns[4:]) gdata.columns = cols - if 'colors' in args: - gdata['colors'] = args['colors'][index] - - gfig = get_simple_scatterplot(gdata, identifier+'_'+str(index), args) - trace = gfig.figure['data'].pop() - glayout = gfig.figure['layout']['annotations'] - - for l in glayout: - nlayout = dict(x = l.x, - y = l.y, - xref = 'x'+str(i+1), - yref = 'y'+str(i+1), - text = l.text, - showarrow = True, - ax = l.ax, - ay = l.ay, - font = l.font, - align='center', - arrowhead=1, - arrowsize=1, - arrowwidth=1, - arrowcolor='#636363') + if "colors" in args: + gdata["colors"] = args["colors"][index] + + gfig = get_simple_scatterplot(gdata, identifier + "_" + str(index), args) + trace = gfig.figure["data"].pop() + glayout = gfig.figure["layout"]["annotations"] + + for _l in glayout: + nlayout = dict( + x=_l.x, + y=_l.y, + xref="x" + str(i + 1), + yref="y" + str(i + 1), + text=_l.text, + showarrow=True, + ax=_l.ax, + ay=_l.ay, + font=_l.font, + align="center", + arrowhead=1, + arrowsize=1, + arrowwidth=1, + arrowcolor="#636363", + ) layouts.append(nlayout) trace.name = index fig.append_trace(trace, r, c) @@ -443,24 +579,39 @@ def get_ranking_plot(data, identifier, args): else: c += 1 i += 1 - fig['layout'].update(dict(height = args['height'], - width=args['width'], - title=args['title'], - xaxis= {"title": args['x_title'], 'autorange':True}, - yaxis= {"title": args['y_title'], 'range':range_y}, - template='plotly_white')) - [fig['layout'][e].update(range=range_y) for e in fig['layout'] if e[0:5] == 'yaxis'] - fig['layout'].annotations = [dict(xref='paper', yref='paper', showarrow=False, text='')] + layouts + fig["layout"].update( + dict( + height=args["height"], + width=args["width"], + title=args["title"], + xaxis={"title": args["x_title"], "autorange": True}, + yaxis={"title": args["y_title"], "range": range_y}, + template="plotly_white", + ) + ) + [ + fig["layout"][e].update(range=range_y) + for e in fig["layout"] + if e[0:5] == "yaxis" + ] + fig["layout"].annotations = [ + dict(xref="paper", yref="paper", showarrow=False, text="") + ] + layouts else: - fig = get_simple_scatterplot(data, identifier+'_'+group, args).figure + if "group" in args: + identifier = identifier + f"_{args['group']}" + # ! get_simple_scatterplot does not use identifier... + fig = get_simple_scatterplot(data, identifier, args).figure return dcc.Graph(id=identifier, figure=fig) + def get_scatterplot_matrix(data, identifier, args): """ This function pltos a multi scatterplot (one for each unique element in args['group']). - :param data: pandas dataframe with four columns: 'name' of the data points, 'x' and 'y' values to plot, and 'group' they belong to. + :param data: pandas dataframe with four columns: 'name' of the data points, + 'x' and 'y' values to plot, and 'group' they belong to. :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below :Arguments: @@ -470,35 +621,52 @@ def get_scatterplot_matrix(data, identifier, args): * **y_title** (str) -- plot y axis title. * **height** (int) -- plot height. * **width** (int) -- plot width. - * **annotations** (dict, optional) -- dictionary where data points names are the keys and descriptions are the values. + * **annotations** (dict, optional) -- dictionary where data points names are \ + the keys and descriptions are the values. :return: multi scatterplot figure within the
. Example:: - result = get_scatterplot_matrix(data, identifier='scatter matrix', args={'group':'group', 'title':'Scatter Plot Matrix', 'x_title':'x_axis', \ - 'y_title':'y_axis', 'height':100, 'width':100, 'annotations':{'GPT~P24298': 'liver disease', 'CP~P00450': 'Wilson disease'}}) + result = get_scatterplot_matrix(data, + identifier='scatter matrix', + args={'group':'group', + 'title':'Scatter Plot Matrix', + 'x_title':'x_axis', + 'y_title':'y_axis', + 'height':100, + 'width':100, + 'annotations':{ + 'GPT~P24298': 'liver disease', + 'CP~P00450': 'Wilson disease'} + } + ) """ num_cols = 3 fig = {} - if 'group' in args and args['group'] in data.columns: - group = args['group'] + if "group" in args and args["group"] in data.columns: + group = args["group"] num_groups = len(data[group].unique()) - num_rows = math.ceil(num_groups/num_cols) - if 'colors' not in data.columns: - if 'colors' in args: - data['colors'] = [args['colors'][g] if g in args['colors'] else '#999999' for g in data[group]] + num_rows = math.ceil(num_groups / num_cols) + if "colors" not in data.columns: + if "colors" in args: + data["colors"] = [ + args["colors"][g] if g in args["colors"] else "#999999" + for g in data[group] + ] - fig = tools.make_subplots(rows=num_rows, cols=num_cols, shared_yaxes=True,print_grid=False) + fig = tools.make_subplots( + rows=num_rows, cols=num_cols, shared_yaxes=True, print_grid=False + ) r = 1 c = 1 range_y = None - if pd.api.types.is_numeric_dtype(data['y']): - range_y = [data['y'].min(), data['y'].max()+1] + if pd.api.types.is_numeric_dtype(data["y"]): + range_y = [data["y"].min(), data["y"].max() + 1] for g in data[group].unique(): gdata = data[data[group] == g].dropna() - gfig = get_simple_scatterplot(gdata, identifier+'_'+str(g), args) - trace = gfig.figure['data'].pop() + gfig = get_simple_scatterplot(gdata, identifier + "_" + str(g), args) + trace = gfig["data"].pop() trace.name = g fig.append_trace(trace, r, c) @@ -508,29 +676,36 @@ def get_scatterplot_matrix(data, identifier, args): else: c += 1 - fig['layout'].update(dict(height = args['height'], - width=args['width'], - title=args['title'], - xaxis= {"title": args['x_title'], 'autorange':True}, - yaxis= {"title": args['y_title'], 'range':range_y}, - template='plotly_white')) - - fig['layout'].annotations = [dict(xref='paper', yref='paper', showarrow=False, text='')] + fig["layout"].update( + dict( + height=args["height"], + width=args["width"], + title=args["title"], + xaxis={"title": args["x_title"], "autorange": True}, + yaxis={"title": args["y_title"], "range": range_y}, + template="plotly_white", + ) + ) - return dcc.Graph(id=identifier, figure=fig) + fig["layout"].annotations = [ + dict(xref="paper", yref="paper", showarrow=False, text="") + ] + return fig def get_simple_scatterplot(data, identifier, args): """ Plots a simple scatterplot with the possibility of including in-plot annotations of data points. - :param data: long-format pandas dataframe with columns: 'x' (ranking position), 'group' (original dataframe position), \ - 'name' (protein identifier), 'y' (LFQ intensity), 'symbol' (data point shape) and 'size' (data point size). + :param data: long-format pandas dataframe with columns: 'x' (ranking position), + 'group' (original dataframe position), 'name' (protein identifier), + 'y' (LFQ intensity), 'symbol' (data point shape) and 'size' (data point size). :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: - * **annotations** (dict) -- dictionary where data points names are the keys and descriptions are the values. + * **annotations** (dict) -- dictionary where data points names are \ + the keys and descriptions are the values. * **title** (str) -- plot title. * **x_title** (str) -- plot x axis title. * **y_title** (str) -- plot y axis title. @@ -540,72 +715,92 @@ def get_simple_scatterplot(data, identifier, args): Example:: - result = get_scatterplot_matrix(data, identifier='scatter plot', args={'annotations':{'GPT~P24298': 'liver disease', 'CP~P00450': 'Wilson disease'}', \ - 'title':'Scatter Plot', 'x_title':'x_axis', 'y_title':'y_axis', 'height':100, 'width':100}) + result = get_simple_scatterplot(data, + identifier='scatter plot', + args={'annotations':{'GPT~P24298': 'liver disease', + 'CP~P00450': 'Wilson disease'}', + 'title':'Scatter Plot', + 'x_title':'x_axis', + 'y_title':'y_axis', + 'height':100, + 'width':100} + ) """ figure = {} - m = {'size': 15, 'line': {'width': 0.5, 'color': 'grey'}} + m = {"size": 15, "line": {"width": 0.5, "color": "grey"}} text = data.name - if 'colors' in data.columns: - m.update({'color':data['colors'].tolist()}) - elif 'colors' in args: - m.update({'color':args['colors']}) - if 'size' in data.columns: - m.update({'size':data['size'].tolist()}) - if 'symbol' in data.columns: - m.update({'symbol':data['symbol'].tolist()}) - - annots=[] - if 'annotations' in args: + if "colors" in data.columns: + m.update({"color": data["colors"].tolist()}) + elif "colors" in args: + m.update({"color": args["colors"]}) + if "size" in data.columns: + m.update({"size": data["size"].tolist()}) + if "symbol" in data.columns: + m.update({"symbol": data["symbol"].tolist()}) + + annots = [] + if "annotations" in args: for index, row in data.iterrows(): - name = str(row['name']).split(' ')[0] - if name in args['annotations']: - annots.append({'x': row['x'], - 'y': row['y'], - 'xref':'x', - 'yref': 'y', - 'text': name, - 'showarrow': False, - 'ax': 55, - 'ay': -1, - 'font': dict(size = 8)}) - figure['data'] = [go.Scattergl(x = data.x, - y = data.y, - text = text, - mode = 'markers', - opacity=0.7, - marker= m, - )] - - figure["layout"] = go.Layout(title = args['title'], - xaxis= {"title": args['x_title']}, - yaxis= {"title": args['y_title']}, - #margin={'l': 40, 'b': 40, 't': 30, 'r': 10}, - legend={'x': 0, 'y': 1}, - hovermode='closest', - height=args['height'], - width=args['width'], - annotations = annots + [dict(xref='paper', yref='paper', showarrow=False, text='')], - showlegend=False, - template='plotly_white' - ) + name = str(row["name"]).split(" ")[0] + if name in args["annotations"]: + annots.append( + { + "x": row["x"], + "y": row["y"], + "xref": "x", + "yref": "y", + "text": name, + "showarrow": False, + "ax": 55, + "ay": -1, + "font": dict(size=8), + } + ) + figure["data"] = [ + go.Scattergl( + x=data.x, + y=data.y, + text=text, + mode="markers", + opacity=0.7, + marker=m, + ) + ] - return dcc.Graph(id= identifier, figure = figure) + figure["layout"] = go.Layout( + title=args["title"], + xaxis={"title": args["x_title"]}, + yaxis={"title": args["y_title"]}, + # margin={'l': 40, 'b': 40, 't': 30, 'r': 10}, + legend={"x": 0, "y": 1}, + hovermode="closest", + height=args["height"], + width=args["width"], + annotations=annots + + [dict(xref="paper", yref="paper", showarrow=False, text="")], + showlegend=False, + template="plotly_white", + ) + + return figure def get_scatterplot(data, identifier, args): """ This function plots a simple Scatterplot. - :param data: is a Pandas DataFrame with four columns: "name", x values and y values (provided as variables) to plot. + :param data: is a Pandas DataFrame with four columns: "name", x values and y values + (provided as variables) to plot. :param str identifier: is the id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: * **title** (str) -- title of the figure. * **x** (str) -- column in dataframe with values for x * **y** (str) -- column in dataframe with values for y - * **group** (str) -- column in dataframe with the groups - translates into colors (default None) - * **hovering_cols** (list) -- list of columns in dataframe that will be shown when hovering over a dot + * **group** (str) -- column in dataframe with the groups - translates into colors \ + (default None) + * **hovering_cols** (list) -- list of columns in dataframe that will be shown when \ + hovering over a dot * **size** (str) -- column in dataframe that contains the size of the dots (default None) * **trendline** (bool) -- whether or not to draw a trendline * **text** (str) -- column in dataframe that contains the values shown for each dot @@ -618,81 +813,130 @@ def get_scatterplot(data, identifier, args): Example:: - result = get_scatteplot(data, identifier='scatter plot', 'title':'Scatter Plot', 'x_title':'x_axis', 'y_title':'y_axis', 'height':100, 'width':100})) + result = get_scatterplot(data, + identifier='scatter plot', + args={'title':'Scatter Plot', + 'x_title':'x_axis', + 'y_title':'y_axis', + 'height':100, + 'width':100} + ) """ - annotation =[] - title = 'Scatter plot' - x_title = 'x' - y_title = 'y' + annotation = [] + title = "Scatter plot" + x_title = "x" + y_title = "y" height = 800 width = 800 size = None symbol = None - x = 'x' - y = 'y' + x = "x" + y = "y" trendline = None group = None text = None - if 'x' in args: - x = args['x'] - if 'y' in args: - y = args['y'] - if 'group' in args: - group = args['group'] + if "x" in args: + x = args["x"] + if "y" in args: + y = args["y"] + if "group" in args: + group = args["group"] if "hovering_cols" in args: annotation = args["hovering_cols"] - if 'title' in args: - title = args['title'] - if 'x_title' in args: - x_title = args['x_title'] - if 'y_title' in args: - y_title = args['y_title'] - if 'height' in args: - height = args['height'] - if 'width' in args: - width = args['width'] - if 'size' in args: - size = args['size'] - if 'symbol' in args: - symbol = args['symbol'] - if 'trendline' in args: - trendline = args['trendline'] - if 'text' in args: - text = args['text'] - - if 'colors' in args and isinstance(args['colors'], dict): - figure = px.scatter(data, x=x, y=y, color=group, color_discrete_map=args['colors'], hover_data=annotation, size=size, symbol=symbol, trendline=trendline, text=text) + if "title" in args: + title = args["title"] + if "x_title" in args: + x_title = args["x_title"] + if "y_title" in args: + y_title = args["y_title"] + if "height" in args: + height = args["height"] + if "width" in args: + width = args["width"] + if "size" in args: + size = args["size"] + if "symbol" in args: + symbol = args["symbol"] + if "trendline" in args: + trendline = args["trendline"] + if "text" in args: + text = args["text"] + + if "colors" in args and isinstance(args["colors"], dict): + figure = px.scatter( + data, + x=x, + y=y, + color=group, + color_discrete_map=args["colors"], + hover_data=annotation, + size=size, + symbol=symbol, + trendline=trendline, + text=text, + ) + elif "density" in args and args["density"]: + color = get_density(data[x], data[y]) + figure = px.scatter( + data, + x=x, + y=y, + color=color, + hover_data=annotation, + size=size, + symbol=symbol, + trendline=trendline, + text=text, + ) else: - figure = px.scatter(data, x=x, y=y, color=group, hover_data=annotation, size=size, symbol=symbol, trendline=trendline, text=text) - - figure.update_traces(marker=dict(size=14, - opacity=0.7, - line=dict(width=0.5, color='DarkSlateGrey')), - selector=dict(mode='markers')) - figure["layout"] = go.Layout(title = title, - xaxis= {"title": x_title}, - yaxis= {"title": y_title}, - legend=dict(orientation="h", - yanchor="bottom", - y=1.0, - xanchor="right", - x=1), - hovermode='closest', - height=height, - width=width, - annotations = [dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white' - ) + figure = px.scatter( + data, + x=x, + y=y, + color=group, + hover_data=annotation, + size=size, + symbol=symbol, + trendline=trendline, + text=text, + ) + + figure.update_traces( + marker=dict(size=14, opacity=0.7, line=dict(width=0.5, color="DarkSlateGrey")), + selector=dict(mode="markers"), + ) + figure["layout"] = go.Layout( + title=title, + xaxis={"title": x_title}, + yaxis={"title": y_title}, + legend=dict(orientation="h", yanchor="bottom", y=1.0, xanchor="right", x=1), + hovermode="closest", + height=height, + width=width, + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) + + return figure - return dcc.Graph(id= identifier, figure = figure) +def get_density(x: np.ndarray, y: np.ndarray): + """Get kernal density estimate for each (x, y) point.""" + values = np.vstack([x, y]) + kernel = stats.gaussian_kde(values) + density = kernel(values) + + return density + def get_volcanoplot(results, args): """ This function plots volcano plots for each internal dictionary in a nested dictionary. - :param dict[dict] results: nested dictionary with pairwise group comparisons as keys and internal dictionaries containing 'x' (log2FC values), \ - 'y' (-log10 p-values), 'text', 'color', 'pvalue' and 'annotations' (number of hits to be highlighted). + :param dict[dict] results: nested dictionary with pairwise group comparisons as keys and + internal dictionaries containing 'x' (log2FC values), + 'y' (-log10 p-values), 'text', 'color', + 'pvalue' and 'annotations' (number of hits to be highlighted). :param dict args: see below. :Arguments: * **fc** (float) -- fold change threshold. @@ -700,7 +944,8 @@ def get_volcanoplot(results, args): * **range_y** (list) -- list with minimum and maximum values for y axis. * **x_title** (str) -- plot x axis title. * **y_title** (str) -- plot y axis title. - * **colorscale** (str) -- string for predefined plotly colorscales or dict containing one or more of the keys listed in \ + * **colorscale** (str) -- string for predefined plotly colorscales or dict containing \ + one or more of the keys listed in \ https://plot.ly/python/reference/#layout-colorscale. * **showscale** (bool) -- determines whether or not a colorbar is displayed for a trace. * **marker_size** (int) -- sets the marker size (in px). @@ -708,114 +953,156 @@ def get_volcanoplot(results, args): Example:: - result = get_volcanoplot(results, args={'fc':2.0, 'range_x':[0, 1], 'range_y':[-1, 1], 'x_title':'x_axis', 'y_title':'y_title', 'colorscale':'Blues', \ - 'showscale':True, 'marker_size':7}) + result = get_volcanoplot(results, + args={'fc':2.0, + 'range_x':[0, 1], + 'range_y':[-1, 1], + 'x_title':'x_axis', + 'y_title':'y_title', + 'colorscale':'Blues', + 'showscale':True, 'marker_size':7} + ) """ figures = [] - for identifier,title in results: - result = results[(identifier,title)] - figure = {"data":[],"layout":None} + for identifier, title in results: + result = results[(identifier, title)] + figure = {"data": [], "layout": None} if "range_x" not in args: - range_x = [-max(abs(result['x']))-0.1, max(abs(result['x']))+0.1]#if symmetric_x else [] + range_x = [ + -max(abs(result["x"])) - 0.1, + max(abs(result["x"])) + 0.1, + ] # if symmetric_x else [] else: range_x = args["range_x"] if "range_y" not in args: - range_y = [0,max(abs(result['y']))+1.] + range_y = [0, max(abs(result["y"])) + 1.0] else: range_y = args["range_y"] - traces = [go.Scatter(x=result['x'], - y=result['y'], - mode='markers', - text=result['text'], - hoverinfo='text', - marker={'color':result['color'], - 'colorscale': args["colorscale"], - 'showscale': args['showscale'], - 'size': args['marker_size'], - 'line': {'color':result['color'], 'width':2} - } - )] + traces = [ + go.Scatter( + x=result["x"], + y=result["y"], + mode="markers", + text=result["text"], + hoverinfo="text", + marker={ + "color": result["color"], + "colorscale": args["colorscale"], + "showscale": args["showscale"], + "size": args["marker_size"], + "line": {"color": result["color"], "width": 2}, + }, + ) + ] shapes = [] - if ('is_samr' in result and not result['is_samr']) or 'is_samr' not in result: - shapes = [{'type': 'line', - 'x0': np.log2(args['fc']), - 'y0': 0, - 'x1': np.log2(args['fc']), - 'y1': range_y[1], - 'line': { - 'color': 'grey', - 'width': 2, - 'dash':'dashdot' - }, - }, - {'type': 'line', - 'x0': -np.log2(args['fc']), - 'y0': 0, - 'x1': -np.log2(args['fc']), - 'y1': range_y[1], - 'line': { - 'color': 'grey', - 'width': 2, - 'dash': 'dashdot' - }, - }, - {'type': 'line', - 'x0': -max(abs(result['x']))-0.1, - 'y0': result['pvalue'], - 'x1': max(abs(result['x']))+0.1, - 'y1': result['pvalue'], - 'line': { - 'color': 'grey', - 'width': 1, - 'dash': 'dashdot' - }, - }] - #traces.append(go.Scattergl(x=result['upfc'][0], y=result['upfc'][1])) - #traces.append(go.Scattergl(x=result['downfc'][0], y=result['downfc'][1])) + if ("is_samr" in result and not result["is_samr"]) or "is_samr" not in result: + shapes = [ + { + "type": "line", + "x0": np.log2(args["fc"]), + "y0": 0, + "x1": np.log2(args["fc"]), + "y1": range_y[1], + "line": {"color": "grey", "width": 2, "dash": "dashdot"}, + }, + { + "type": "line", + "x0": -np.log2(args["fc"]), + "y0": 0, + "x1": -np.log2(args["fc"]), + "y1": range_y[1], + "line": {"color": "grey", "width": 2, "dash": "dashdot"}, + }, + { + "type": "line", + "x0": -max(abs(result["x"])) - 0.1, + "y0": result["pvalue"], + "x1": max(abs(result["x"])) + 0.1, + "y1": result["pvalue"], + "line": {"color": "grey", "width": 1, "dash": "dashdot"}, + }, + ] + # traces.append(go.Scattergl(x=result['upfc'][0], y=result['upfc'][1])) + # traces.append(go.Scattergl(x=result['downfc'][0], y=result['downfc'][1])) figure["data"] = traces - figure["layout"] = go.Layout(title=title, - xaxis={'title': args['x_title'], 'range': range_x}, - yaxis={'title': args['y_title'], 'range': range_y}, - hovermode='closest', - shapes=shapes, - width=950, - height=1050, - annotations = result['annotations']+[dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white', - showlegend=False) - - figures.append(dcc.Graph(id= identifier, figure = figure)) + figure["layout"] = go.Layout( + title=title, + xaxis={"title": args["x_title"], "range": range_x}, + yaxis={"title": args["y_title"], "range": range_y}, + hovermode="closest", + shapes=shapes, + width=950, + height=1050, + annotations=result["annotations"] + + [dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + showlegend=False, + ) + + figures.append(dcc.Graph(id=identifier, figure=figure)) return figures -def run_volcano(data, identifier, args={'alpha':0.05, 'fc':2, 'colorscale':'Blues', 'showscale': False, 'marker_size':8, 'x_title':'log2FC', 'y_title':'-log10(pvalue)', 'num_annotations':10, 'annotate_list':[]}): - """ - This function parsers the regulation data from statistical tests and creates volcano plots for all distinct group comparisons. Significant hits with lowest adjusted p-values are highlighed. - :param data: pandas dataframe with format: 'identifier', 'group1', 'group2', 'mean(group1', 'mean(group2)', 'log2FC', 'std_error', 'tail', 't-statistics', 'padj_THSD', \ - 'effsize', 'efftype', 'FC', 'rejected', 'F-statistics', 'pvalue', 'padj', 'correction', '-log10 pvalue' and 'Method'. +def run_volcano( + data, + identifier, + args={ + "alpha": 0.05, + "fc": 2, + "colorscale": "Blues", + "showscale": False, + "marker_size": 8, + "x_title": "log2FC", + "y_title": "-log10(pvalue)", + "num_annotations": 10, + "annotate_list": [], + }, +): + """ + This function parsers the regulation data from statistical tests and + creates volcano plots for all distinct group comparisons. + Significant hits with lowest adjusted p-values are highlighed. + + :param data: pandas dataframe with format: + 'identifier', 'group1', 'group2', 'mean(group1', + 'mean(group2)', 'log2FC', 'std_error', 'tail', + 't-statistics', 'padj_THSD', + 'effsize', 'efftype', 'FC', 'rejected', + 'F-statistics', 'pvalue', 'padj', 'correction', '-log10 pvalue' and 'Method'. :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: * **alpha** (float) -- adjusted p-value threshold for significant hits. * **fc** (float) -- fold change threshold. - * **colorscale** (str or dict) -- name of predefined plotly colorscale or dictionary containing one or more of the keys listed in \ - https://plot.ly/python/reference/#layout-colorscale. + * **colorscale** (str or dict) -- name of predefined plotly colorscale or dictionary \ + containing one or more of the keys listed in \ + https://plot.ly/python/reference/#layout-colorscale. * **showscale** (bool) -- determines whether or not a colorbar is displayed for a trace. * **marker_size** (int) -- sets the marker size (in px). * **x_title** (str) -- plot x axis title. * **y_title** (str) -- plot y axis title. - * **num_annotations** (int) -- number of hits to be highlighted (if num_annotations = 10, highlights 10 hits with lowest significant adjusted p-value). + * **num_annotations** (int) -- number of hits to be highlighted (if num_annotations = 10,\ + highlights 10 hits with lowest significant adjusted p-value). :return: list of volcano plot figures within the
. Example:: - result = run_volcano(data, identifier='volvano data', args={'alpha':0.05, 'fc':2.0, 'colorscale':'Blues', 'showscale':False, 'marker_size':6, 'x_title':'log2FC', \ - 'y_title':'-log10(pvalue)', 'num_annotations':10}) + result = run_volcano(data, + identifier='volvano data', + args={'alpha':0.05, + 'fc':2.0, + 'colorscale':'Blues', + 'showscale':False, + 'marker_size':6, + 'x_title':'log2FC', + 'y_title':'-log10(pvalue)', + 'num_annotations':10} + ) """ # Loop through signature volcano_plot_results = {} - grouping = data.groupby(['group1','group2']) + grouping = data.groupby(["group1", "group2"]) for group in grouping.groups: signature = grouping.get_group(group) @@ -824,13 +1111,13 @@ def run_volcano(data, identifier, args={'alpha':0.05, 'fc':2, 'colorscale':'Blue line_colors = [] text = [] annotations = [] - num_annotations = args['num_annotations'] if 'num_annotations' in args else 10 - gidentifier = identifier + "_".join(map(str,group)) - title = 'Comparison: '+str(group[0])+' vs '+str(group[1]) - sig_pval = False + num_annotations = args["num_annotations"] if "num_annotations" in args else 10 + gidentifier = identifier + "_".join(map(str, group)) + title = "Comparison: " + str(group[0]) + " vs " + str(group[1]) + # sig_pval = False # ! not used padj_col = "padj" pval_col = "pvalue" - is_samr = 's0' in signature + is_samr = "s0" in signature if "posthoc padj" in signature: padj_col = "posthoc padj" pval_col = "posthoc pvalue" @@ -838,74 +1125,103 @@ def run_volcano(data, identifier, args={'alpha':0.05, 'fc':2, 'colorscale':'Blue elif "padj" in signature: signature = signature.sort_values(by="padj", ascending=True) - signature = signature.reindex(signature['log2FC'].abs().sort_values(ascending=False).index) + signature = signature.reindex( + signature["log2FC"].abs().sort_values(ascending=False).index + ) pvals = [] for index, row in signature.iterrows(): # Text - text.append(''+str(row['identifier'])+": "+str(index)+'
Comparison: '+str(row['group1'])+' vs '+str(row['group2'])+'
log2FC = '+str(round(row['log2FC'], ndigits=2))+'
p = '+'{:.2e}'.format(row[pval_col])+'
FDR = '+'{:.2e}'.format(row[padj_col])) + text.append( + "" + + str(row["identifier"]) + + ": " + + str(index) + + "
Comparison: " + + str(row["group1"]) + + " vs " + + str(row["group2"]) + + "
log2FC = " + + str(round(row["log2FC"], ndigits=2)) + + "
p = " + + "{:.2e}".format(row[pval_col]) + + "
FDR = " + + "{:.2e}".format(row[padj_col]) + ) # Color - if row[padj_col] < args['alpha']: - pvals.append(row['-log10 pvalue']) - sig_pval = True - if row['log2FC'] <= -np.log2(args['fc']): - annotations.append({'x': row['log2FC'], - 'y': row['-log10 pvalue'], - 'xref':'x', - 'yref': 'y', - 'text': str(row['identifier']), - 'showarrow': False, - 'ax': 0, - 'ay': -10, - 'font': dict(color = "#2c7bb6", size = 13)}) - color.append('rgba(44, 123, 182, 0.7)') - color_dict[row['identifier']] = "#2c7bb6" - line_colors.append('#2c7bb6') - elif row['log2FC'] >= np.log2(args['fc']): - annotations.append({'x': row['log2FC'], - 'y': row['-log10 pvalue'], - 'xref':'x', - 'yref': 'y', - 'text': str(row['identifier']), - 'showarrow': False, - 'ax': 0, - 'ay': -10, - 'font': dict(color = "#d7191c", size = 13)}) - color.append('rgba(215, 25, 28, 0.7)') - color_dict[row['identifier']] = "#d7191c" - line_colors.append('#d7191c') - elif row['log2FC'] < 0.: - color.append('rgba(171, 217, 233, 0.5)') - color_dict[row['identifier']] = '#abd9e9' - line_colors.append('#abd9e9') - elif row['log2FC'] > 0.: - color.append('rgba(253, 174, 97, 0.5)') - color_dict[row['identifier']] = '#fdae61' - line_colors.append('#fdae61') + if row[padj_col] < args["alpha"]: + pvals.append(row["-log10 pvalue"]) + # sig_pval = True # ! not used + if row["log2FC"] <= -np.log2(args["fc"]): + annotations.append( + { + "x": row["log2FC"], + "y": row["-log10 pvalue"], + "xref": "x", + "yref": "y", + "text": str(row["identifier"]), + "showarrow": False, + "ax": 0, + "ay": -10, + "font": dict(color="#2c7bb6", size=13), + } + ) + color.append("rgba(44, 123, 182, 0.7)") + color_dict[row["identifier"]] = "#2c7bb6" + line_colors.append("#2c7bb6") + elif row["log2FC"] >= np.log2(args["fc"]): + annotations.append( + { + "x": row["log2FC"], + "y": row["-log10 pvalue"], + "xref": "x", + "yref": "y", + "text": str(row["identifier"]), + "showarrow": False, + "ax": 0, + "ay": -10, + "font": dict(color="#d7191c", size=13), + } + ) + color.append("rgba(215, 25, 28, 0.7)") + color_dict[row["identifier"]] = "#d7191c" + line_colors.append("#d7191c") + elif row["log2FC"] < 0.0: + color.append("rgba(171, 217, 233, 0.5)") + color_dict[row["identifier"]] = "#abd9e9" + line_colors.append("#abd9e9") + elif row["log2FC"] > 0.0: + color.append("rgba(253, 174, 97, 0.5)") + color_dict[row["identifier"]] = "#fdae61" + line_colors.append("#fdae61") else: - color.append('rgba(153, 153, 153, 0.3)') - color_dict[row['identifier']] = '#999999' - line_colors.append('#999999') + color.append("rgba(153, 153, 153, 0.3)") + color_dict[row["identifier"]] = "#999999" + line_colors.append("#999999") else: - color.append('rgba(153, 153, 153, 0.3)') - line_colors.append('#999999') + color.append("rgba(153, 153, 153, 0.3)") + line_colors.append("#999999") - if 'annotate_list' in args: - if len(args['annotate_list']) > 0: + if "annotate_list" in args: + if len(args["annotate_list"]) > 0: annotations = [] - hits = args['annotate_list'] - selected = signature[signature['identifier'].isin(hits)] + hits = args["annotate_list"] + selected = signature[signature["identifier"].isin(hits)] for index, row in selected.iterrows(): - annotations.append({'x': row['log2FC'], - 'y': row['-log10 pvalue'], - 'xref':'x', - 'yref': 'y', - 'text': str(row['identifier']), - 'showarrow': False, - 'ax': 0, - 'ay': -10, - 'font': dict(color = color_dict[row['identifier']], size = 12)}) + annotations.append( + { + "x": row["log2FC"], + "y": row["-log10 pvalue"], + "xref": "x", + "yref": "y", + "text": str(row["identifier"]), + "showarrow": False, + "ax": 0, + "ay": -10, + "font": dict(color=color_dict[row["identifier"]], size=12), + } + ) if len(annotations) < num_annotations: num_annotations = len(annotations) @@ -916,18 +1232,29 @@ def run_volcano(data, identifier, args={'alpha':0.05, 'fc':2, 'colorscale':'Blue else: min_pval_sign = 0 - volcano_plot_results[(gidentifier, title)] = {'x': signature['log2FC'].values, 'y': signature['-log10 pvalue'].values, 'text':text, 'color': color, 'line_color':line_colors, 'pvalue':min_pval_sign, 'is_samr':is_samr, 'annotations':annotations[0:num_annotations]} + volcano_plot_results[(gidentifier, title)] = { + "x": signature["log2FC"].values, + "y": signature["-log10 pvalue"].values, + "text": text, + "color": color, + "line_color": line_colors, + "pvalue": min_pval_sign, + "is_samr": is_samr, + "annotations": annotations[0:num_annotations], + } figures = get_volcanoplot(volcano_plot_results, args) return figures + def get_heatmapplot(data, identifier, args): """ This function plots a simple Heatmap. - :param data: is a Pandas DataFrame with the shape of the heatmap where index corresponds to rows \ - and column names corresponds to columns, values in the heatmap corresponds to the row values. + :param data: is a Pandas DataFrame with the shape of the heatmap where index corresponds to rows + and column names corresponds to columns, + values in the heatmap corresponds to the row values. :param str identifier: is the id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: @@ -940,110 +1267,155 @@ def get_heatmapplot(data, identifier, args): Example:: - result = get_heatmapplot(data, identifier='heatmap', args={'format':'edgelist', 'source':'node1', 'target':'node2', 'values':'score', 'title':'Heatmap Plot'}) + result = get_heatmapplot(data, + identifier='heatmap', + args={'format':'edgelist', + 'source':'node1', + 'target':'node2', + 'values':'score', + 'title':'Heatmap Plot'}) """ df = data.copy() - if args['format'] == "edgelist": - df = df.set_index(args['source']) - df = df.pivot_table(values=args['values'], index=df.index, columns=args['target'], aggfunc='first') + if args["format"] == "edgelist": + df = df.set_index(args["source"]) + df = df.pivot_table( + values=args["values"], + index=df.index, + columns=args["target"], + aggfunc="first", + ) df = df.fillna(0) figure = {} figure["data"] = [] - figure["layout"] = {"title":args['title'], - "height": 500, - "width": 700, - "annotations" : [dict(xref='paper', yref='paper', showarrow=False, text='')], - "template":'plotly_white'} - figure['data'].append(go.Heatmap(z=df.values.tolist(), - x = list(df.columns), - y = list(df.index))) + figure["layout"] = { + "title": args["title"], + "height": 500, + "width": 700, + "annotations": [dict(xref="paper", yref="paper", showarrow=False, text="")], + "template": "plotly_white", + } + figure["data"].append( + go.Heatmap(z=df.values.tolist(), x=list(df.columns), y=list(df.index)) + ) - return dcc.Graph(id = identifier, figure = figure) + return dcc.Graph(id=identifier, figure=figure) def get_complex_heatmapplot(data, identifier, args): df = data.copy() - figure = {'data':[], 'layout':{}} - if args['format'] == "edgelist": - df = df.set_index(args['source']) - df = df.pivot_table(values=args['values'], index=df.index, columns=args['target'], aggfunc='first') + figure = {"data": [], "layout": {}} + if args["format"] == "edgelist": + df = df.set_index(args["source"]) + df = df.pivot_table( + values=args["values"], + index=df.index, + columns=args["target"], + aggfunc="first", + ) df = df.fillna(0) - dendro_up = FF.create_dendrogram(df.values, orientation='bottom', labels=df.columns) - for i in range(len(dendro_up['data'])): - dendro_up['data'][i]['yaxis'] = 'y2' + dendro_up = FF.create_dendrogram(df.values, orientation="bottom", labels=df.columns) + for i in range(len(dendro_up["data"])): + dendro_up["data"][i]["yaxis"] = "y2" - dendro_side = FF.create_dendrogram(df.values, orientation='right') - for i in range(len(dendro_side['data'])): - dendro_side['data'][i]['xaxis'] = 'x2' + dendro_side = FF.create_dendrogram(df.values, orientation="right") + for i in range(len(dendro_side["data"])): + dendro_side["data"][i]["xaxis"] = "x2" - figure['data'].extend(dendro_up['data']) - figure['data'].extend(dendro_side['data']) + figure["data"].extend(dendro_up["data"]) + figure["data"].extend(dendro_side["data"]) - if args['dist']: + if args["dist"]: data_dist = pdist(df.values) heat_data = squareform(data_dist) else: heat_data = df.values - dendro_leaves = dendro_side['layout']['yaxis']['ticktext'] + dendro_leaves = dendro_side["layout"]["yaxis"]["ticktext"] dendro_leaves = list(map(int, dendro_leaves)) - heat_data = heat_data[dendro_leaves,:] - heat_data = heat_data[:,dendro_leaves] + heat_data = heat_data[dendro_leaves, :] + heat_data = heat_data[:, dendro_leaves] heatmap = [ go.Heatmap( - x = dendro_leaves, - y = dendro_leaves, - z = heat_data, - colorscale = 'YlOrRd', - reversescale=True + x=dendro_leaves, + y=dendro_leaves, + z=heat_data, + colorscale="YlOrRd", + reversescale=True, ) ] - heatmap[0]['x'] = dendro_up['layout']['xaxis']['tickvals'] - heatmap[0]['y'] = dendro_side['layout']['yaxis']['tickvals'] - figure['data'].extend(heatmap) - - figure['layout'] = dendro_up['layout'] - - figure['layout'].update({'width':800, 'height':800, - 'showlegend':False, 'hovermode': 'closest', - "template":'plotly_white' - }) - figure['layout']['xaxis'].update({'domain': [.15, 1], - 'mirror': False, - 'showgrid': False, - 'showline': False, - 'zeroline': False, - 'ticks':""}) - figure['layout'].update({'xaxis2': {'domain': [0, .15], - 'mirror': False, - 'showgrid': False, - 'showline': False, - 'zeroline': False, - 'showticklabels': False, - 'ticks':""}}) - - figure['layout']['yaxis'].update({'domain': [0, .85], - 'mirror': False, - 'showgrid': False, - 'showline': False, - 'zeroline': False, - 'showticklabels': False, - 'ticks': ""}) - figure['layout'].update({'yaxis2':{'domain':[.825, .975], - 'mirror': False, - 'showgrid': False, - 'showline': False, - 'zeroline': False, - 'showticklabels': False, - 'ticks':""}, - 'annotations': [dict(xref='paper', yref='paper', showarrow=False, text='')] - }) - - - return dcc.Graph(id=identifier, figure=figure,) + heatmap[0]["x"] = dendro_up["layout"]["xaxis"]["tickvals"] + heatmap[0]["y"] = dendro_side["layout"]["yaxis"]["tickvals"] + figure["data"].extend(heatmap) + + figure["layout"] = dendro_up["layout"] + + figure["layout"].update( + { + "width": 800, + "height": 800, + "showlegend": False, + "hovermode": "closest", + "template": "plotly_white", + } + ) + figure["layout"]["xaxis"].update( + { + "domain": [0.15, 1], + "mirror": False, + "showgrid": False, + "showline": False, + "zeroline": False, + "ticks": "", + } + ) + figure["layout"].update( + { + "xaxis2": { + "domain": [0, 0.15], + "mirror": False, + "showgrid": False, + "showline": False, + "zeroline": False, + "showticklabels": False, + "ticks": "", + } + } + ) + + figure["layout"]["yaxis"].update( + { + "domain": [0, 0.85], + "mirror": False, + "showgrid": False, + "showline": False, + "zeroline": False, + "showticklabels": False, + "ticks": "", + } + ) + figure["layout"].update( + { + "yaxis2": { + "domain": [0.825, 0.975], + "mirror": False, + "showgrid": False, + "showline": False, + "zeroline": False, + "showticklabels": False, + "ticks": "", + }, + "annotations": [dict(xref="paper", yref="paper", showarrow=False, text="")], + } + ) + + return dcc.Graph( + id=identifier, + figure=figure, + ) + def get_notebook_network_pyvis(graph, args={}): """ @@ -1060,18 +1432,19 @@ def get_notebook_network_pyvis(graph, args={}): result = get_notebook_network_pyvis(graph, args={'height':100, 'width':100}) """ - if 'width' not in args: - args['width'] = 800 - if 'height' not in args: - args['height'] = 850 - notebook_net = visnet(args['width'], args['height'], notebook=True) + if "width" not in args: + args["width"] = 800 + if "height" not in args: + args["height"] = 850 + notebook_net = visnet(args["width"], args["height"], notebook=True) notebook_net.barnes_hut(overlap=0.8) notebook_net.from_nx(graph) - notebook_net.show_buttons(['nodes', 'edges', 'physics']) + notebook_net.show_buttons(["nodes", "edges", "physics"]) utils.generate_html(notebook_net) return notebook_net + def get_notebook_network_web(graph, args): """ This function converts a networkX graph into a webweb interactive network in a browser. @@ -1084,6 +1457,7 @@ def get_notebook_network_web(graph, args): return notebook_net + def network_to_tables(graph, source, target): """ Creates the graph edge list and node list and returns them as separate Pandas DataFrames. @@ -1092,130 +1466,156 @@ def network_to_tables(graph, source, target): :return: two Pandas DataFrames. """ edges_table = nx.to_pandas_edgelist(graph, source, target) - nodes_table = pd.DataFrame.from_dict(dict(graph.nodes(data=True))).transpose().reset_index() + nodes_table = ( + pd.DataFrame.from_dict(dict(graph.nodes(data=True))).transpose().reset_index() + ) return nodes_table, edges_table + def generate_configuration_tree(report_pipeline, dataset_type): """ - This function retrieves the analysis pipeline from a dataset .yml file and creates a Cytoscape network, organized hierarchically. + This function retrieves the analysis pipeline from a dataset .yml file + and creates a Cytoscape network, organized hierarchically. - :param dict report_pipeline: dictionary with dataset type analysis and visualization pipeline (conversion of .yml files to python dictionary). - :param str dataset_type: type of dataset ('clinical', 'proteomics', 'DNAseq', 'RNAseq', 'multiomics'). + :param dict report_pipeline: dictionary with dataset type analysis and visualization pipeline + (conversion of .yml files to python dictionary). + :param str dataset_type: type of dataset + ('clinical', 'proteomics', 'DNAseq', 'RNAseq', 'multiomics'). :return: new Dash div with title and Cytoscape network, summarizing analysis pipeline. """ nodes = [] edges = [] args = {} conf_plot = None - if len(report_pipeline) >=1: + if len(report_pipeline) >= 1: root = dataset_type.title() + " default analysis pipeline" - nodes.append({'data':{'id':0, 'label':root}, 'classes': 'root'}) + nodes.append({"data": {"id": 0, "label": root}, "classes": "root"}) i = 0 for section in report_pipeline: if section == "args": continue - nodes.append({'data':{'id':i+1, 'label':section.title()}, 'classes': 'section'}) - edges.append({'data':{'source':0, 'target':i+1}}) + nodes.append( + {"data": {"id": i + 1, "label": section.title()}, "classes": "section"} + ) + edges.append({"data": {"source": 0, "target": i + 1}}) i += 1 k = i for subsection in report_pipeline[section]: - nodes.append({'data':{'id':i+1, 'label':subsection.title()}, 'classes': 'subsection'}) - edges.append({'data':{'source':k, 'target':i+1}}) + nodes.append( + { + "data": {"id": i + 1, "label": subsection.title()}, + "classes": "subsection", + } + ) + edges.append({"data": {"source": k, "target": i + 1}}) i += 1 j = i conf = report_pipeline[section][subsection] - data_names = conf['data'] - analysis_types = conf['analyses'] - arguments = conf['args'] + data_names = conf["data"] + analysis_types = conf["analyses"] + arguments = conf["args"] if isinstance(data_names, dict): for d in data_names: - nodes.append({'data':{'id':i+1, 'label':d+':'+data_names[d]}, 'classes': 'data'}) - edges.append({'data':{'source':j, 'target':i+1}}) + nodes.append( + { + "data": {"id": i + 1, "label": d + ":" + data_names[d]}, + "classes": "data", + } + ) + edges.append({"data": {"source": j, "target": i + 1}}) i += 1 else: - nodes.append({'data':{'id':i+1, 'label':data_names}, 'classes': 'data'}) - edges.append({'data':{'source':j, 'target':i+1}}) + nodes.append( + {"data": {"id": i + 1, "label": data_names}, "classes": "data"} + ) + edges.append({"data": {"source": j, "target": i + 1}}) i += 1 for at in analysis_types: - nodes.append({'data':{'id':i+1, 'label':at},'classes': 'analysis'}) - edges.append({'data':{'source':j, 'target':i+1}}) + nodes.append( + {"data": {"id": i + 1, "label": at}, "classes": "analysis"} + ) + edges.append({"data": {"source": j, "target": i + 1}}) i += 1 f = i if len(analysis_types): for a in arguments: - nodes.append({'data':{'id':i+1, 'label':a+':'+str(arguments[a])},'classes': 'argument'}) - edges.append({'data':{'source':f, 'target':i+1}}) + nodes.append( + { + "data": { + "id": i + 1, + "label": a + ":" + str(arguments[a]), + }, + "classes": "argument", + } + ) + edges.append({"data": {"source": f, "target": i + 1}}) i += 1 config_stylesheet = [ - # Group selectors - { - 'selector': 'node', - 'style': { - 'content': 'data(label)' - } - }, - # Class selectors - { - 'selector': '.root', - 'style': { - 'background-color': '#66c2a5', - 'line-color': 'black', - 'font-size': '14' - } - }, - { - 'selector': '.section', - 'style': { - 'background-color': '#a6cee3', - 'line-color': 'black', - 'font-size': '12' - } - }, - { - 'selector': '.subsection', - 'style': { - 'background-color': '#1f78b4', - 'line-color': 'black', - 'font-size': '12' - } - }, - { - 'selector': '.data', - 'style': { - 'background-color': '#b2df8a', - 'line-color': 'black', - 'font-size': '12' - } - }, - { - 'selector': '.analysis', - 'style': { - 'background-color': '#33a02c', - 'line-color': 'black', - 'font-size': '12' - } - }, - { - 'selector': '.argument', - 'style': { - 'background-color': '#fb9a99', - 'line-color': 'black', - 'font-size': '12' - } - }, - ] + # Group selectors + {"selector": "node", "style": {"content": "data(label)"}}, + # Class selectors + { + "selector": ".root", + "style": { + "background-color": "#66c2a5", + "line-color": "black", + "font-size": "14", + }, + }, + { + "selector": ".section", + "style": { + "background-color": "#a6cee3", + "line-color": "black", + "font-size": "12", + }, + }, + { + "selector": ".subsection", + "style": { + "background-color": "#1f78b4", + "line-color": "black", + "font-size": "12", + }, + }, + { + "selector": ".data", + "style": { + "background-color": "#b2df8a", + "line-color": "black", + "font-size": "12", + }, + }, + { + "selector": ".analysis", + "style": { + "background-color": "#33a02c", + "line-color": "black", + "font-size": "12", + }, + }, + { + "selector": ".argument", + "style": { + "background-color": "#fb9a99", + "line-color": "black", + "font-size": "12", + }, + }, + ] net = [] net.extend(nodes) net.extend(edges) - args['stylesheet'] = config_stylesheet - args['title'] = 'Analysis Pipeline' - args['layout'] = {'name': 'breadthfirst', 'roots': '#0'} - #args['mouseover_node'] = {} + args["stylesheet"] = config_stylesheet + args["title"] = "Analysis Pipeline" + args["layout"] = {"name": "breadthfirst", "roots": "#0"} + # args['mouseover_node'] = {} conf_plot = get_cytoscape_network(net, dataset_type, args) return conf_plot + def get_network(data, identifier, args): """ This function filters an input dataframe based on a threshold score and builds a cytoscape network. For more information on \ @@ -1242,82 +1642,93 @@ def get_network(data, identifier, args): 'node_size':'degree', 'title':'Network Figure', 'color_weight': True}) """ net = None - if 'cutoff_abs' not in args: - args['cutoff_abs'] = False - - if 'title' not in args: - args['title'] = identifier + if "cutoff_abs" not in args: + args["cutoff_abs"] = False + + if "title" not in args: + args["title"] = identifier if not data.empty: - if utils.check_columns(data, cols=[args['source'], args['target']]): + if utils.check_columns(data, cols=[args["source"], args["target"]]): if "values" not in args: - args["values"] = 'width' + args["values"] = "width" data[args["values"]] = 1 - if 'cutoff' in args: - if args['cutoff_abs']: - data = data[np.abs(data[args['values']]) >= args['cutoff']] + if "cutoff" in args: + if args["cutoff_abs"]: + data = data[np.abs(data[args["values"]]) >= args["cutoff"]] else: - data = data[data[args['values']] >= args['cutoff']] + data = data[data[args["values"]] >= args["cutoff"]] if not data.empty: - data[args["source"]] = [str(n).replace("'","") for n in data[args["source"]]] - data[args["target"]] = [str(n).replace("'","") for n in data[args["target"]]] - + data[args["source"]] = [ + str(n).replace("'", "") for n in data[args["source"]] + ] + data[args["target"]] = [ + str(n).replace("'", "") for n in data[args["target"]] + ] - data = data.rename(index=str, columns={args['values']: "width"}) - data['width'] = data['width'].fillna(1.0) - data = data.fillna('null') - data.columns = [c.replace('_', '') for c in data.columns] - data['edgewidth'] = data['width'].apply(np.abs) - min_edge_value = data['edgewidth'].min() - max_edge_value = data['edgewidth'].max() + data = data.rename(index=str, columns={args["values"]: "width"}) + data["width"] = data["width"].fillna(1.0) + data = data.fillna("null") + data.columns = [c.replace("_", "") for c in data.columns] + data["edgewidth"] = data["width"].apply(np.abs) + min_edge_value = data["edgewidth"].min() + max_edge_value = data["edgewidth"].max() if min_edge_value == max_edge_value: - min_edge_value = 0. - graph = nx.from_pandas_edgelist(data, args['source'], args['target'], edge_attr=True) + min_edge_value = 0.0 + graph = nx.from_pandas_edgelist( + data, args["source"], args["target"], edge_attr=True + ) degrees = dict(graph.degree()) - nx.set_node_attributes(graph, degrees, 'degree') + nx.set_node_attributes(graph, degrees, "degree") betweenness = None ev_centrality = None if data.shape[0] < 150 and data.shape[0] > 5: try: - betweenness = nx.betweenness_centrality(graph, weight='width') + betweenness = nx.betweenness_centrality(graph, weight="width") ev_centrality = nx.eigenvector_centrality_numpy(graph) - ev_centrality = {k:"%.3f" % round(v, 3) for k,v in ev_centrality.items()} - nx.set_node_attributes(graph, betweenness, 'betweenness') - nx.set_node_attributes(graph, ev_centrality, 'eigenvector') + ev_centrality = { + k: "%.3f" % round(v, 3) for k, v in ev_centrality.items() + } + nx.set_node_attributes(graph, betweenness, "betweenness") + nx.set_node_attributes(graph, ev_centrality, "eigenvector") except Exception as e: - print("There was an exception when calculating centralities: {}".format(e)) + print( + "There was an exception when calculating centralities: {}".format( + e + ) + ) min_node_size = 0 max_node_size = 0 - if 'node_size' not in args: - args['node_size'] = 'degree' + if "node_size" not in args: + args["node_size"] = "degree" - if args['node_size'] == 'betweenness' and betweenness is not None: + if args["node_size"] == "betweenness" and betweenness is not None: min_node_size = min(betweenness.values()) max_node_size = max(betweenness.values()) - nx.set_node_attributes(graph, betweenness, 'radius') - elif args['node_size'] == 'ev_centrality' and ev_centrality is not None: + nx.set_node_attributes(graph, betweenness, "radius") + elif args["node_size"] == "ev_centrality" and ev_centrality is not None: min_node_size = min(ev_centrality.values()) max_node_size = max(ev_centrality.values()) - nx.set_node_attributes(graph, ev_centrality, 'radius') - elif args['node_size'] == 'degree' and len(degrees) > 0: - min_node_size = min(degrees.values()) - max_node_size = max(degrees.values()) - nx.set_node_attributes(graph, degrees, 'radius') + nx.set_node_attributes(graph, ev_centrality, "radius") + elif args["node_size"] == "degree" and len(degrees) > 0: + min_node_size = min(degrees.values()) + max_node_size = max(degrees.values()) + nx.set_node_attributes(graph, degrees, "radius") - clusters = analytics.get_network_communities(graph, args) + clusters = network_analysis.get_network_communities(graph, args) col = utils.get_hex_colors(len(set(clusters.values()))) - colors = {n:col[clusters[n]] for n in clusters} - nx.set_node_attributes(graph, colors, 'color') - nx.set_node_attributes(graph, clusters, 'cluster') + colors = {n: col[clusters[n]] for n in clusters} + nx.set_node_attributes(graph, colors, "color") + nx.set_node_attributes(graph, clusters, "cluster") vis_graph = graph - limit=500 - if 'limit' in args: - limit = args['limit'] + limit = 500 + if "limit" in args: + limit = args["limit"] if limit is not None: if len(vis_graph.edges()) > 500: max_nodes = 150 @@ -1328,7 +1739,10 @@ def get_network(data, identifier, args): cluster_nums[clusters[n]] = 0 cluster_members[clusters[n]].append(n) cluster_nums[clusters[n]] += 1 - valid_clusters = [c for c,n in sorted(cluster_nums.items() , key=lambda x: x[1])] + valid_clusters = [ + c + for c, n in sorted(cluster_nums.items(), key=lambda x: x[1]) + ] valid_nodes = [] for c in valid_clusters: valid_nodes.extend(cluster_members[c]) @@ -1337,74 +1751,138 @@ def get_network(data, identifier, args): break vis_graph = vis_graph.subgraph(valid_nodes) - nodes_table, edges_table = network_to_tables(graph, source=args["source"], target=args["target"]) - nodes_fig_table = get_table(nodes_table, identifier=identifier+"_nodes_table", args={'title':args['title']+" nodes table"}) - edges_fig_table = get_table(edges_table, identifier=identifier+"_edges_table", args={'title':args['title']+" edges table"}) - - stylesheet, layout = get_network_style(colors, args['color_weight']) - stylesheet.append({'selector':'edge','style':{'width':'mapData(edgewidth,'+ str(min_edge_value) +','+ str(max_edge_value) +', .5, 8)'}}) + nodes_table, edges_table = network_to_tables( + graph, source=args["source"], target=args["target"] + ) + nodes_fig_table = get_table( + nodes_table, + identifier=identifier + "_nodes_table", + args={"title": args["title"] + " nodes table"}, + ) + edges_fig_table = get_table( + edges_table, + identifier=identifier + "_edges_table", + args={"title": args["title"] + " edges table"}, + ) + + stylesheet, layout = get_network_style(colors, args["color_weight"]) + stylesheet.append( + { + "selector": "edge", + "style": { + "width": "mapData(edgewidth," + + str(min_edge_value) + + "," + + str(max_edge_value) + + ", .5, 8)" + }, + } + ) if min_node_size > 0 and max_node_size > 0: - mapper = 'mapData(radius,'+ str(min_node_size) +','+ str(max_node_size) +', 15, 50)' - stylesheet.append({'selector':'node','style':{'width':mapper, 'height':mapper}}) - args['stylesheet'] = stylesheet - args['layout'] = layout + mapper = ( + "mapData(radius," + + str(min_node_size) + + "," + + str(max_node_size) + + ", 15, 50)" + ) + stylesheet.append( + { + "selector": "node", + "style": {"width": mapper, "height": mapper}, + } + ) + args["stylesheet"] = stylesheet + args["layout"] = layout cy_elements, mouseover_node = utils.networkx_to_cytoscape(vis_graph) app_net = get_cytoscape_network(cy_elements, identifier, args) - #args['mouseover_node'] = mouseover_node - - net = {"notebook":[cy_elements, stylesheet, layout], "app": app_net, "net_tables": (nodes_table, edges_table), "net_tables_viz":(nodes_fig_table, edges_fig_table), "net_json":json_graph.node_link_data(graph)} + # args['mouseover_node'] = mouseover_node + + net = { + "notebook": [cy_elements, stylesheet, layout], + "app": app_net, + "net_tables": (nodes_table, edges_table), + "net_tables_viz": (nodes_fig_table, edges_fig_table), + "net_json": json_graph.node_link_data(graph), + } return net + def get_network_style(node_colors, color_edges): - ''' + """ This function uses a dictionary of nodes and colors and creates a stylesheet and layout for a network. :param dict node_colors: dictionary with node names as keys and colors as values. :param bool color_edges: if True, add edge coloring to stylesheet (red for positive width, blue for negative). :return: stylesheet (list of dictionaries specifying the style for a group of elements, a class of elements, or a single element) and \ layout (dictionary specifying how the nodes should be positioned on the canvas). - ''' - - color_selector = "{'selector': '[name = \"KEY\"]', 'style': {'background-color': \"VALUE\"}}" - stylesheet=[{'selector': 'node', 'style': {'label': 'data(name)', - 'text-valign': 'center', - 'text-halign': 'center', - 'border-color':'gray', - 'border-width': '1px', - 'font-size': '12', - 'opacity':0.75}}, - {'selector':'edge','style':{'label':'data(label)', - 'curve-style': 'bezier', - 'opacity':0.7, - 'font-size': '4'}}] - - layout = {'name': 'cose', - 'idealEdgeLength': 100, - 'nodeOverlap': 20, - 'refresh': 20, - 'randomize': False, - 'componentSpacing': 100, - 'nodeRepulsion': 400000, - 'edgeElasticity': 100, - 'nestingFactor': 5, - 'gravity': 80, - 'numIter': 1000, - 'initialTemp': 200, - 'coolingFactor': 0.95, - 'minTemp': 1.0} + """ - if color_edges: - stylesheet.extend([{'selector':'[width < 0]', 'style':{'line-color':'#4dc3d6'}},{'selector':'[width > 0]', 'style':{'line-color':'#d6604d'}}]) + color_selector = ( + "{'selector': '[name = \"KEY\"]', 'style': {'background-color': \"VALUE\"}}" + ) + stylesheet = [ + { + "selector": "node", + "style": { + "label": "data(name)", + "text-valign": "center", + "text-halign": "center", + "border-color": "gray", + "border-width": "1px", + "font-size": "12", + "opacity": 0.75, + }, + }, + { + "selector": "edge", + "style": { + "label": "data(label)", + "curve-style": "bezier", + "opacity": 0.7, + "font-size": "4", + }, + }, + ] + layout = { + "name": "cose", + "idealEdgeLength": 100, + "nodeOverlap": 20, + "refresh": 20, + "randomize": False, + "componentSpacing": 100, + "nodeRepulsion": 400000, + "edgeElasticity": 100, + "nestingFactor": 5, + "gravity": 80, + "numIter": 1000, + "initialTemp": 200, + "coolingFactor": 0.95, + "minTemp": 1.0, + } + + if color_edges: + stylesheet.extend( + [ + {"selector": "[width < 0]", "style": {"line-color": "#4dc3d6"}}, + {"selector": "[width > 0]", "style": {"line-color": "#d6604d"}}, + ] + ) - for k,v in node_colors.items(): - stylesheet.append(ast.literal_eval(color_selector.replace("KEY", k.replace("'", "")).replace("VALUE",v))) + for k, v in node_colors.items(): + stylesheet.append( + ast.literal_eval( + color_selector.replace("KEY", k.replace("'", "")).replace("VALUE", v) + ) + ) return stylesheet, layout -def visualize_notebook_network(network, notebook_type='jupyter', layout={}): + +def visualize_notebook_network(network, notebook_type="jupyter", layout={}): """ This function returns a Cytoscape network visualization for Jupyter notebooks @@ -1423,33 +1901,40 @@ def visualize_notebook_network(network, notebook_type='jupyter', layout={}): """ net = None if len(layout) == 0: - layout = {'name': 'cose', - 'idealEdgeLength': 100, - 'nodeOverlap': 20, - 'refresh': 20, - 'randomize': False, - 'componentSpacing': 100, - 'nodeRepulsion': 400000, - 'edgeElasticity': 100, - 'nestingFactor': 5, - 'gravity': 80, - 'numIter': 1000, - 'initialTemp': 200, - 'coolingFactor': 0.95, - 'minTemp': 1.0} - if notebook_type == 'jupyter': - net = Cytoscape(data={'elements':network[0]}, visual_style=network[1], layout=layout) - elif notebook_type == 'jupyterlab': + layout = { + "name": "cose", + "idealEdgeLength": 100, + "nodeOverlap": 20, + "refresh": 20, + "randomize": False, + "componentSpacing": 100, + "nodeRepulsion": 400000, + "edgeElasticity": 100, + "nestingFactor": 5, + "gravity": 80, + "numIter": 1000, + "initialTemp": 200, + "coolingFactor": 0.95, + "minTemp": 1.0, + } + if notebook_type == "jupyter": + net = Cytoscape( + data={"elements": network[0]}, visual_style=network[1], layout=layout + ) + elif notebook_type == "jupyterlab": pass return net -def visualize_notebook_path(path, notebook_type='jupyter'): + +def visualize_notebook_path(path, notebook_type="jupyter"): """ This function returns a Cytoscape network visualization for Jupyter notebooks - :param path object: dash_html_components object with the cytoscape network (returned by get_cytoscape_network()) - :param str notebook_type: the type of notebook where the network will be visualized (currently only jupyter notebook is supported) + :param pathlib.Path object: dash_html_components object with the cytoscape network + (returned by get_cytoscape_network()) + :param str notebook_type: the type of notebook where the network will be visualized + (currently only jupyter notebook is supported) :param dict layout: specific layout properties (see https://dash.plot.ly/cytoscape/layout) :return: cyjupyter.cytoscape.Cytoscape object @@ -1461,13 +1946,14 @@ def visualize_notebook_path(path, notebook_type='jupyter'): visualize_notebook_path(net, notebook_type='jupyter') """ net = None - if notebook_type == 'jupyter': + if notebook_type == "jupyter": net = path.children[1] - elif notebook_type == 'jupyterlab': + elif notebook_type == "jupyterlab": pass return net + def get_pca_plot(data, identifier, args): """ This function creates a pca plot with scores and top "args['loadings']" loadings. @@ -1476,7 +1962,8 @@ def get_pca_plot(data, identifier, args): :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below :Arguments: - * **loadings** (int) -- number of features with highest loading values to be displayed in the pca plot + * **loadings** (int) -- number of features with highest loading values \ + to be displayed in the pca plot * **title** (str) -- title of the figure * **x_title** (str) -- plot x axis title * **y_title** (str) -- plot y axis title @@ -1486,65 +1973,89 @@ def get_pca_plot(data, identifier, args): Example:: - result = get_pca_plot(data, identifier='pca', args={'loadings':15, 'title':'PCA Plot', 'x_title':'PC1', 'y_title':'PC2', 'height':100, 'width':100}) + result = get_pca_plot(data, + identifier='pca', + args={'loadings':15, + 'title':'PCA Plot', + 'x_title':'PC1', + 'y_title':'PC2', + 'height':100, + 'width':100} + ) """ pca_data, loadings, variance = data figure = {} traces = [] annotations = [] - sct = get_scatterplot(pca_data, identifier, args).figure - traces.extend(sct['data']) - figure['layout'] = sct['layout'] + sct = get_scatterplot(pca_data, identifier, args) + traces.extend(sct["data"]) + figure["layout"] = sct["layout"] factor = 50 num_loadings = 15 - if 'factor' in args: - factor = args['factor'] - if 'loadings' in args: - num_loadings = args['loadings'] + if "factor" in args: + factor = args["factor"] + if "loadings" in args: + num_loadings = args["loadings"] for index in list(loadings.index)[0:num_loadings]: - x = loadings.loc[index,'x'] * factor - y = loadings.loc[index, 'y'] * factor - value = loadings.loc[index, 'value'] - - trace = go.Scattergl(x= [0,x], - y = [0,y], - mode='markers+lines', - text=str(index)+" loading: {0:.2f}".format(value), - name = index, - marker= dict(size=3, - symbol=1, - color='darkgrey', #set color equal to a variable - showscale=False, - opacity=0.9, - ), - showlegend=False, - ) - annotation = dict( x=x * 1.15, - y=y * 1.15, - xref='x', - yref='y', - text=index, - showarrow=False, - font=dict( - size=12, - color='darkgrey' - ), - align='center', - ax=20, - ay=-30, - ) + x = loadings.loc[index, "x"] * factor + y = loadings.loc[index, "y"] * factor + value = loadings.loc[index, "value"] + + trace = go.Scattergl( + x=[0, x], + y=[0, y], + mode="markers+lines", + text=str(index) + " loading: {0:.2f}".format(value), + name=index, + marker=dict( + size=3, + symbol=1, + color="darkgrey", # set color equal to a variable + showscale=False, + opacity=0.9, + ), + showlegend=False, + ) + annotation = dict( + x=x * 1.15, + y=y * 1.15, + xref="x", + yref="y", + text=index, + showarrow=False, + font=dict(size=12, color="darkgrey"), + align="center", + ax=20, + ay=-30, + ) annotations.append(annotation) traces.append(trace) - figure['data'] = traces - figure['layout'].annotations = annotations - figure['layout']['template'] = 'plotly_white' - - return dcc.Graph(id = identifier, figure = figure) - - -def get_sankey_plot(data, identifier, args={'source':'source', 'target':'target', 'weight':'weight','source_colors':'source_colors', 'target_colors':'target_colors', 'orientation': 'h', 'valueformat': '.0f', 'width':800, 'height':800, 'font':12, 'title':'Sankey plot'}): + figure["data"] = traces + figure["layout"].annotations = annotations + figure["layout"]["template"] = "plotly_white" + + return figure + + +def get_sankey_plot( + data, + identifier, + args={ + "source": "source", + "target": "target", + "weight": "weight", + "source_colors": "source_colors", + "target_colors": "target_colors", + "orientation": "h", + "valueformat": ".0f", + "width": 800, + "height": 800, + "font": 12, + "title": "Sankey plot", + }, +): """ This function generates a Sankey plot in Plotly. @@ -1555,8 +2066,10 @@ def get_sankey_plot(data, identifier, args={'source':'source', 'target':'target' * **source** (str) -- name of the column containing the source * **target** (str) -- name of the column containing the target * **weight** (str) -- name of the column containing the weight - * **source_colors** (str) -- name of the column in data that contains the colors of each source item - * **target_colors** (str) -- name of the column in data that contains the colors of each target item + * **source_colors** (str) -- name of the column in data that contains\ + the colors of each source item + * **target_colors** (str) -- name of the column in data that contains\ + the colors of each target item * **title** (str) -- plot title * **orientation** (str) -- whether to plot horizontal ('h') or vertical ('v') * **valueformat** (str) -- how to show the value ('.0f') @@ -1567,67 +2080,97 @@ def get_sankey_plot(data, identifier, args={'source':'source', 'target':'target' Example:: - result = get_sankey_plot(data, identifier='sankeyplot', args={'source':'source', 'target':'target', 'weight':'weight','source_colors':'source_colors', \ - 'target_colors':'target_colors', 'orientation': 'h', 'valueformat': '.0f', 'width':800, 'height':800, 'font':12, 'title':'Sankey plot'}) + result = get_sankey_plot(data, + identifier='sankeyplot', + args={'source':'source', + 'target':'target', + 'weight':'weight', + 'source_colors':'source_colors', + 'target_colors':'target_colors',' + 'orientation': 'h', + 'valueformat': '.0f', + 'width':800, 'height':800, + 'font':12, + 'title':'Sankey plot'} + ) """ figure = {} if data is not None and not data.empty: - nodes = list(set(data[args['source']].tolist() + data[args['target']].tolist())) - if 'source_colors' in args: - node_colors = dict(zip(data[args['source']],data[args['source_colors']])) + nodes = list(set(data[args["source"]].tolist() + data[args["target"]].tolist())) + if "source_colors" in args: + node_colors = dict(zip(data[args["source"]], data[args["source_colors"]])) else: - scolors = ['#045a8d'] * len(data[args['source']].tolist()) - node_colors = dict(zip(data[args['source']],scolors)) - args['source_colors'] = 'source_colors' - data['source_colors'] = scolors + scolors = ["#045a8d"] * len(data[args["source"]].tolist()) + node_colors = dict(zip(data[args["source"]], scolors)) + args["source_colors"] = "source_colors" + data["source_colors"] = scolors hover_data = [] - if 'hover' in args: - hover_data = [str(t).upper() for t in data[args['hover']].tolist()] + if "hover" in args: + hover_data = [str(t).upper() for t in data[args["hover"]].tolist()] - if 'target_colors' in args: - node_colors.update(dict(zip(data[args['target']],data[args['target_colors']]))) + if "target_colors" in args: + node_colors.update( + dict(zip(data[args["target"]], data[args["target_colors"]])) + ) else: - scolors = ['#a6bddb'] * len(data[args['target']].tolist()) - node_colors.update(dict(zip(data[args['target']],scolors))) - args['target_colors'] = 'target_colors' - data['target_colors'] = scolors - - data_trace = dict(type='sankey', - orientation = 'h' if 'orientation' not in args else args['orientation'], - valueformat = ".0f" if 'valueformat' not in args else args['valueformat'], - arrangement = 'snap', - node = dict(pad = 10 if 'pad' not in args else args['pad'], - thickness = 10 if 'thickness' not in args else args['thickness'], - line = dict(color = "black", width = 0.3), - label = nodes, - color = ["rgba"+str(utils.hex2rgb(node_colors[c])) if node_colors[c].startswith('#') else node_colors[c] for c in nodes] - ), - link = dict(source = [list(nodes).index(i) for i in data[args['source']].tolist()], - target = [list(nodes).index(i) for i in data[args['target']].tolist()], - value = data[args['weight']].tolist(), - color = ["rgba"+str(utils.hex2rgb(c)) if c.startswith('#') else c for c in data[args['source_colors']].tolist()], - label = hover_data - )) - layout = dict( - width= 800 if 'width' not in args else args['width'], - height= 800 if 'height' not in args else args['height'], - title = args['title'], - annotations = [dict(xref='paper', yref='paper', showarrow=False, text='')], - font = dict( - size = 12 if 'font' not in args else args['font'], + scolors = ["#a6bddb"] * len(data[args["target"]].tolist()) + node_colors.update(dict(zip(data[args["target"]], scolors))) + args["target_colors"] = "target_colors" + data["target_colors"] = scolors + + data_trace = dict( + type="sankey", + orientation="h" if "orientation" not in args else args["orientation"], + valueformat=".0f" if "valueformat" not in args else args["valueformat"], + arrangement="snap", + node=dict( + pad=10 if "pad" not in args else args["pad"], + thickness=10 if "thickness" not in args else args["thickness"], + line=dict(color="black", width=0.3), + label=nodes, + color=[ + ( + "rgba" + str(hex2rgb(node_colors[c])) + if node_colors[c].startswith("#") + else node_colors[c] + ) + for c in nodes + ], ), - template='plotly_white' - ) + link=dict( + source=[list(nodes).index(i) for i in data[args["source"]].tolist()], + target=[list(nodes).index(i) for i in data[args["target"]].tolist()], + value=data[args["weight"]].tolist(), + color=[ + "rgba" + str(hex2rgb(c)) if c.startswith("#") else c + for c in data[args["source_colors"]].tolist() + ], + label=hover_data, + ), + ) + layout = dict( + width=800 if "width" not in args else args["width"], + height=800 if "height" not in args else args["height"], + title=args["title"], + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + font=dict( + size=12 if "font" not in args else args["font"], + ), + template="plotly_white", + ) figure = dict(data=[data_trace], layout=layout) - return dcc.Graph(id = identifier, figure = figure) + return dcc.Graph(id=identifier, figure=figure) def get_table(data, identifier, args): """ - This function converts a pandas dataframe into an interactive table for viewing, editing and exploring large datasets. For more information visit https://dash.plot.ly/datatable. + This function converts a pandas dataframe into an interactive table for viewing, + editing and exploring large datasets. + + For more information visit https://dash.plot.ly/datatable. :param data: pandas dataframe. :param str identifier: id used to identify the div where the figure will be generated. @@ -1644,96 +2187,132 @@ def get_table(data, identifier, args): if data is not None and isinstance(data, pd.DataFrame) and not data.empty: cols = [] title = "Table" - if 'title' in args: - title = args['title'] - if 'index' in args: - if isinstance(args['index'], list): - cols = args['index'] + if "title" in args: + title = args["title"] + if "index" in args: + if isinstance(args["index"], list): + cols = args["index"] else: - cols.append(args['index']) - if 'cols' in args: - if args['cols'] is not None and len(args['cols']) > 0: - selected_cols = list(set(args['cols']).intersection(data.columns)) + cols.append(args["index"]) + if "cols" in args: + if args["cols"] is not None and len(args["cols"]) > 0: + selected_cols = list(set(args["cols"]).intersection(data.columns)) if len(selected_cols) > 0: data = data[selected_cols + cols] else: data = pd.DataFrame() - table.append(html.Div(children=[dcc.Markdown("### Columns not found: {}".format(','.join(args['cols'])))])) - if 'rows' in args: - if args['rows'] is not None and len(args['rows']) > 0: - selected_rows = list(set(args['rows']).intersection(data.index)) + table.append( + html.Div( + children=[ + dcc.Markdown( + "### Columns not found: {}".format( + ",".join(args["cols"]) + ) + ) + ] + ) + ) + if "rows" in args: + if args["rows"] is not None and len(args["rows"]) > 0: + selected_rows = list(set(args["rows"]).intersection(data.index)) if len(selected_rows) > 0: data = data.loc[selected_rows] else: data = pd.DataFrame() - table.append(html.Div(children=[dcc.Markdown("### Rows not found: {}".format(','.join(args['rows'])))])) - if 'head' in args: - if len(args['head']) > 1: - data = data.iloc[:args['head'][0], :args['head'][1]] + table.append( + html.Div( + children=[ + dcc.Markdown( + "### Rows not found: {}".format( + ",".join(args["rows"]) + ) + ) + ] + ) + ) + if "head" in args: + if len(args["head"]) > 1: + data = data.iloc[: args["head"][0], : args["head"][1]] list_cols = data.applymap(lambda x: isinstance(x, list)).all() list_cols = list_cols.index[list_cols].tolist() for c in list_cols: data[c] = data[c].apply(lambda x: ";".join([str(i) for i in x])) - data_trace = dash_table.DataTable(id='table_'+identifier, - data=data.astype(str).to_dict("rows"), - columns=[{"name": str(i).replace('_', ' ').title(), "id": i} for i in data.columns], - # css=[{ - # 'selector': '.dash-cell div.dash-cell-value', - # 'rule': 'display: inline; white-space: inherit; overflow: inherit; text-overflow: inherit;' - # }], - style_data={'whiteSpace': 'normal', 'height': 'auto'}, - style_cell={ - 'height': 'fit-content', 'whiteSpace': 'normal', - 'minWidth': '130px', 'maxWidth': '200px', - 'textAlign': 'left', 'padding': '1px', 'vertical-align': 'top', - 'overflow': 'hidden', 'textOverflow': 'ellipsis' - }, - style_cell_conditional=[{ - 'if': {'column_id': i}, - 'width': str(20 + round(len(i)*20))+'px'} for i in data.columns], - style_table={ - "height": "fit-content", - # "max-height": "500px", - # "width": "fit-content", - # "max-width": "1500px", - # 'overflowY': 'scroll', - 'overflowX': 'scroll' - }, - style_header={ - 'backgroundColor': '#2b8cbe', - 'fontWeight': 'bold', - 'position': 'sticky' - }, - style_data_conditional=[{ - "if": - {"column_id": "rejected", "filter_query": '{rejected} eq "True"'}, - "backgroundColor": "#3B8861", - 'color': 'white' - }, - ], - fixed_rows={ 'headers': True, 'data': 0}, - filter_action='native', - row_selectable='multi', - page_current= 0, - page_size = 25, - page_action='native', - sort_action='custom', - ) - table.extend([html.H2(title),data_trace]) + data_trace = dash_table.DataTable( + id="table_" + identifier, + data=data.astype(str).to_dict("rows"), + columns=[ + {"name": str(i).replace("_", " ").title(), "id": i} + for i in data.columns + ], + # css=[{ + # 'selector': '.dash-cell div.dash-cell-value', + # 'rule': 'display: inline; white-space: inherit;' + # ' overflow: inherit; text-overflow: inherit;' + # }], + style_data={"whiteSpace": "normal", "height": "auto"}, + style_cell={ + "height": "fit-content", + "whiteSpace": "normal", + "minWidth": "130px", + "maxWidth": "200px", + "textAlign": "left", + "padding": "1px", + "vertical-align": "top", + "overflow": "hidden", + "textOverflow": "ellipsis", + }, + style_cell_conditional=[ + {"if": {"column_id": i}, "width": str(20 + round(len(i) * 20)) + "px"} + for i in data.columns + ], + style_table={ + "height": "fit-content", + # "max-height": "500px", + # "width": "fit-content", + # "max-width": "1500px", + # 'overflowY': 'scroll', + "overflowX": "scroll", + }, + style_header={ + "backgroundColor": "#2b8cbe", + "fontWeight": "bold", + "position": "sticky", + }, + style_data_conditional=[ + { + "if": { + "column_id": "rejected", + "filter_query": '{rejected} eq "True"', + }, + "backgroundColor": "#3B8861", + "color": "white", + }, + ], + fixed_rows={"headers": True, "data": 0}, + filter_action="native", + row_selectable="multi", + page_current=0, + page_size=25, + page_action="native", + sort_action="custom", + ) + table.extend([html.H2(title), data_trace]) return html.Div(id=identifier, children=table) -def get_multi_table(data,identifier, title): +def get_multi_table(data, identifier, title): tables = [html.H2(title)] if data is not None and isinstance(data, dict): for subtitle in data: df = data[subtitle] if len(df.columns) > 10: df = df.transpose() - table = get_table(df, identifier=identifier+"_"+subtitle, args={'title':subtitle}) + table = get_table( + df, identifier=identifier + "_" + subtitle, args={"title": subtitle} + ) if table is not None: tables.append(table) @@ -1754,31 +2333,45 @@ def get_violinplot(data, identifier, args): Example:: - result = get_violinplot(data, identifier='violinplot, args={'drop_cols':['sample', 'subject'], 'group':'group'}) + result = get_violinplot(data, + identifier='violinplot, + args={'drop_cols':['sample', 'subject'], + 'group':'group'} + ) """ df = data.copy() graphs = [] color_map = {} - if 'colors' in args: - color_map = args['colors'] - if 'drop_cols' in args: - if len(list(set(args['drop_cols']).intersection(df.columns))) == len(args['drop_cols']): - df = df.drop(args['drop_cols'], axis=1) + if "colors" in args: + color_map = args["colors"] + if "drop_cols" in args: + if len(list(set(args["drop_cols"]).intersection(df.columns))) == len( + args["drop_cols"] + ): + df = df.drop(args["drop_cols"], axis=1) for c in df.columns.unique(): - if c != args['group']: - figure = create_violinplot(df, x=args['group'], y=c, color=args['group'], color_map=color_map) - figure.update_layout(annotations=[dict(xref='paper', yref='paper', showarrow=False, text='')], template='plotly_white') - graphs.append(dcc.Graph(id=identifier+"_"+c, figure=figure)) + if c != args["group"]: + figure = create_violinplot( + df, x=args["group"], y=c, color=args["group"], color_map=color_map + ) + figure.update_layout( + annotations=[ + dict(xref="paper", yref="paper", showarrow=False, text="") + ], + template="plotly_white", + ) + graphs.append(dcc.Graph(id=identifier + "_" + c, figure=figure)) return graphs + def create_violinplot(df, x, y, color, color_map={}): """ This function creates traces for a simple violin plot. :param df: pandas dataframe with samples as rows and dependent variables as columns. - :pram (str) x: name of the column containing the group. + :param (str) x: name of the column containing the group. :param (str) y: name of the column with the dependent variable. :param (str) color: name of the column used for coloring. :param (dict) color_map: dictionary with custom colors @@ -1788,18 +2381,23 @@ def create_violinplot(df, x, y, color, color_map={}): result = create_violinplot(df, x='group', y='protein a', color='group', color_map={}) """ - traces = [] - violin = px.violin(df, x=x, y=y, color=color, color_discrete_map=color_map, box=True, points="all") + # traces = [] # ! or is this some hack? + violin = px.violin( + df, x=x, y=y, color=color, color_discrete_map=color_map, box=True, points="all" + ) return violin def get_clustergrammer_plot(data, identifier, args): """ - This function takes a pandas dataframe, calculates clustering, and generates the visualization json. + This function takes a pandas dataframe, calculates clustering, + and generates the visualization json. + For more information visit https://github.com/MaayanLab/clustergrammer-py. - :param data: long-format pandas dataframe with columns 'node1' (source), 'node2' (target) and 'weight' + :param data: long-format pandas dataframe with columns 'node1' (source), 'node2' (target) + and 'weight' :param str identifier: id used to identify the div where the figure will be generated :param dict args: see below :Arguments: @@ -1808,48 +2406,60 @@ def get_clustergrammer_plot(data, identifier, args): :return: Dash Div with heatmap plot from Clustergrammer web-based tool """ from clustergrammer2 import net as clustergrammer_net + div = None if not data.empty: - if 'format' in args: - if args['format'] == 'edgelist': - data = data[['node1', 'node2', 'weight']].pivot(index='node1', columns='node2') + if "format" in args: + if args["format"] == "edgelist": + data = data[["node1", "node2", "weight"]].pivot( + index="node1", columns="node2" + ) clustergrammer_net.load_df(data) - link = utils.get_clustergrammer_link(clustergrammer_net, filename=None) + link = get_clustergrammer_link(clustergrammer_net, filename=None) iframe = html.Iframe(src=link, width=1000, height=900) - div = html.Div([html.H2(args['title']),iframe]) + div = html.Div([html.H2(args["title"]), iframe]) return div + def get_parallel_plot(data, identifier, args): """ - This function creates a parallel coordinates plot, with sample groups as the different dimensions. + This function creates a parallel coordinates plot, with sample groups as the different + dimensions. :param data: pandas dataframe with groups as rows and dependent variables as columns. :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: * **group** (str) -- name of the column containing the groups. - * **zscore** (bool) -- if True, calculates the z score of each values in the row, relative to the row mean and standard deviation. + * **zscore** (bool) -- if True, calculates the z score of each values in the row, \ + relative to the row mean and standard deviation. * **color** (str) -- line color. * **title** (str) -- plot title. :return: parallel plot figure within
. Example:: - result = get_parallel_plot(data, identifier='parallel plot', args={'group':'group', 'zscore':True, 'color':'blue', 'title':'Parallel Plot'}) + result = get_parallel_plot(data, + identifier='parallel plot', + args={'group':'group', + 'zscore':True, + 'color':'blue', + 'title':'Parallel Plot'} + ) """ fig = None - if 'group' in args: - group = args['group'] - if 'zscore' in args: - if args['zscore']: + if "group" in args: + group = args["group"] + if "zscore" in args: + if args["zscore"]: data = data.set_index(group).apply(zscore) data = data.reset_index() - color = '#de77ae' - if 'color' in args: - color = args['color'] + color = "#de77ae" + if "color" in args: + color = args["color"] group_values = data.groupby(group).mean() min_val = group_values._get_numeric_data().min().min() max_val = group_values._get_numeric_data().max().max() @@ -1860,27 +2470,25 @@ def get_parallel_plot(data, identifier, args): dim = dict(label=i, range=[min_val, max_val], values=values) dims.append(dim) - fig_data = [ - go.Parcoords( - line = dict(color = color), - dimensions = dims) - ] + fig_data = [go.Parcoords(line=dict(color=color), dimensions=dims)] layout = go.Layout( - title=args['title'], - annotations= [dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white' + title=args["title"], + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", ) - fig = dict(data = fig_data, layout = layout) + fig = dict(data=fig_data, layout=layout) return dcc.Graph(id=identifier, figure=fig) + def get_parallel_coord_plot(data, identifier, args): - color = args['color'] - labels = {c:c for c in data.columns if c != color} + color = args["color"] + labels = {c: c for c in data.columns if c != color} fig = px.parallel_coordinates(data, color=color, labels=labels) return fig + def get_WGCNAPlots(data, identifier): """ Takes data from runWGCNA function and builds WGCNA plots. @@ -1893,43 +2501,157 @@ def get_WGCNAPlots(data, identifier): data = tuple(data[k] for k in data) if data is not None: - dissTOM, moduleColors, Features_per_Module, MEs,\ - moduleTraitCor, textMatrix, METDiss, METcor = data + ( + dissTOM, + moduleColors, + Features_per_Module, + MEs, + moduleTraitCor, + textMatrix, + METDiss, + METcor, + ) = data plots = [] - plots.append(wgcnaFigures.plot_complex_dendrogram(dissTOM, moduleColors, title='Co-expression: dendrogram and module colors', dendro_labels=dissTOM.columns, distfun=None, linkagefun='ward', hang=0.1, subplot='module colors', col_annotation=True, width=1000, height=800)) - plots.append(get_table(Features_per_Module, identifier='', args={'title':'Proteins/Genes module color', 'colors': ('#C2D4FF','#F5F8FF'), 'cols': None, 'rows': None})) - plots.append(wgcnaFigures.plot_labeled_heatmap(moduleTraitCor, textMatrix, title='Module-Clinical variable relationships', colorscale=[[0,'#67a9cf'],[0.5,'#f7f7f7'],[1,'#ef8a62']], row_annotation=True, width=1000, height=800)) - dendro_tree = wgcnaAnalysis.get_dendrogram(METDiss, METDiss.index, distfun=None, linkagefun='ward', div_clusters=False) + plots.append( + wgcna.plot_complex_dendrogram( + dissTOM, + moduleColors, + title="Co-expression: dendrogram and module colors", + dendro_labels=dissTOM.columns, + distfun=None, + linkagefun="ward", + hang=0.1, + subplot="module colors", + col_annotation=True, + width=1000, + height=800, + ) + ) + plots.append( + get_table( + Features_per_Module, + identifier="", + args={ + "title": "Proteins/Genes module color", + "colors": ("#C2D4FF", "#F5F8FF"), + "cols": None, + "rows": None, + }, + ) + ) + plots.append( + wgcna.plot_labeled_heatmap( + moduleTraitCor, + textMatrix, + title="Module-Clinical variable relationships", + colorscale=[[0, "#67a9cf"], [0.5, "#f7f7f7"], [1, "#ef8a62"]], + row_annotation=True, + width=1000, + height=800, + ) + ) + dendro_tree = wgcna_analysis.get_dendrogram( + METDiss, METDiss.index, distfun=None, linkagefun="ward", div_clusters=False + ) if dendro_tree is not None: - dendrogram = Dendrogram.plot_dendrogram(dendro_tree, hang=0.9, cutoff_line=False) + dendrogram_ = dendrogram.plot_dendrogram( + dendro_tree, hang=0.9, cutoff_line=False + ) - layout = go.Layout(width=900, height=900, showlegend=False, title='', - xaxis=dict(domain=[0, 1], range=[np.min(dendrogram['layout']['xaxis']['tickvals']) - 6, np.max(dendrogram['layout']['xaxis']['tickvals'])+4], showgrid=False, - zeroline=True, ticks='', automargin=True, anchor='y'), - yaxis=dict(domain=[0.7, 1], autorange=True, showgrid=False, zeroline=False, ticks='outside', title='Height', automargin=True, anchor='x'), - xaxis2=dict(domain=[0, 1], autorange=True, showgrid=True, zeroline=False, ticks='', showticklabels=False, automargin=True, anchor='y2'), - yaxis2=dict(domain=[0, 0.64], autorange=True, showgrid=False, zeroline=False, automargin=True, anchor='x2')) + layout = go.Layout( + width=900, + height=900, + showlegend=False, + title="", + xaxis=dict( + domain=[0, 1], + range=[ + np.min(dendrogram_["layout"]["xaxis"]["tickvals"]) - 6, + np.max(dendrogram_["layout"]["xaxis"]["tickvals"]) + 4, + ], + showgrid=False, + zeroline=True, + ticks="", + automargin=True, + anchor="y", + ), + yaxis=dict( + domain=[0.7, 1], + autorange=True, + showgrid=False, + zeroline=False, + ticks="outside", + title="Height", + automargin=True, + anchor="x", + ), + xaxis2=dict( + domain=[0, 1], + autorange=True, + showgrid=True, + zeroline=False, + ticks="", + showticklabels=False, + automargin=True, + anchor="y2", + ), + yaxis2=dict( + domain=[0, 0.64], + autorange=True, + showgrid=False, + zeroline=False, + automargin=True, + anchor="x2", + ), + ) - if all(list(METcor.columns.map(lambda x: METcor[x].between(-1, 1, inclusive=True).all()))) != True: - df = wgcnaAnalysis.get_percentiles_heatmap(METcor, dendro_tree, bydendro=True, bycols=False).T + if not ( + all( + list( + METcor.columns.map( + lambda x: METcor[x].between(-1, 1, inclusive=True).all() + ) + ) + ) + ): + df = wgcna_analysis.get_percentiles_heatmap( + METcor, dendro_tree, bydendro=True, bycols=False + ).T else: - df = wgcnaAnalysis.df_sort_by_dendrogram(wgcnaAnalysis.df_sort_by_dendrogram(METcor, dendro_tree).T, dendro_tree) - - heatmap = wgcnaFigures.get_heatmap(df, colorscale=[[0,'#67a9cf'],[0.5,'#f7f7f7'],[1,'#ef8a62']], color_missing=False) + df = wgcna_analysis.df_sort_by_dendrogram( + wgcna_analysis.df_sort_by_dendrogram(METcor, dendro_tree).T, + dendro_tree, + ) + + heatmap = wgcna.get_heatmap( + df, + colorscale=[[0, "#67a9cf"], [0.5, "#f7f7f7"], [1, "#ef8a62"]], + color_missing=False, + ) figure = tools.make_subplots(rows=2, cols=1, print_grid=False) - for i in list(dendrogram['data']): + for i in list(dendrogram_["data"]): figure.append_trace(i, 1, 1) - for j in list(heatmap['data']): + for j in list(heatmap["data"]): figure.append_trace(j, 2, 1) - figure['layout'] = layout - figure['layout']['template'] = 'plotly_white' - figure['layout'].update({'xaxis':dict(domain=[0, 1], ticks='', showticklabels=False, anchor='y'), - 'xaxis2':dict(domain=[0, 1], ticks='', showticklabels=True, anchor='y2'), - 'yaxis':dict(domain=[0.635, 1], anchor='x'), - 'yaxis2':dict(domain=[0., 0.635], ticks='', showticklabels=True, anchor='x2')}) + figure["layout"] = layout + figure["layout"]["template"] = "plotly_white" + figure["layout"].update( + { + "xaxis": dict( + domain=[0, 1], ticks="", showticklabels=False, anchor="y" + ), + "xaxis2": dict( + domain=[0, 1], ticks="", showticklabels=True, anchor="y2" + ), + "yaxis": dict(domain=[0.635, 1], anchor="x"), + "yaxis2": dict( + domain=[0.0, 0.635], ticks="", showticklabels=True, anchor="x2" + ), + } + ) plots.append(figure) @@ -1938,7 +2660,7 @@ def get_WGCNAPlots(data, identifier): if isinstance(j, html.Div): graphs.append(j) else: - graphs.append(dcc.Graph(id=identifier+'_'+str(i), figure=j)) + graphs.append(dcc.Graph(id=identifier + "_" + str(i), figure=j)) return graphs @@ -1953,26 +2675,49 @@ def getMapperFigure(data, identifier, title): :param str title: plot title. :return: plotly FigureWidget within
. """ - pl_brewer = [[0.0, '#67001f'], - [0.1, '#b2182b'], - [0.2, '#d6604d'], - [0.3, '#f4a582'], - [0.4, '#fddbc7'], - [0.5, '#000000'], - [0.6, '#d1e5f0'], - [0.7, '#92c5de'], - [0.8, '#4393c3'], - [0.9, '#2166ac'], - [1.0, '#053061']] - figure = plotlyviz.plotlyviz(data, title=title, colorscale=pl_brewer, color_function_name="Group",factor_size=7, edge_linewidth=2.5, - node_linecolor="rgb(200,200,200)", width=1200, height=1200, bgcolor="rgba(240, 240, 240, 0.95)", - left=50, bottom=50, summary_height=300, summary_width=400, summary_left=20, summary_right=20, - hist_left=25, hist_right=25, member_textbox_width=800) - return dcc.Graph(id = identifier, figure=figure) + from kmapper import plotlyviz + + pl_brewer = [ + [0.0, "#67001f"], + [0.1, "#b2182b"], + [0.2, "#d6604d"], + [0.3, "#f4a582"], + [0.4, "#fddbc7"], + [0.5, "#000000"], + [0.6, "#d1e5f0"], + [0.7, "#92c5de"], + [0.8, "#4393c3"], + [0.9, "#2166ac"], + [1.0, "#053061"], + ] + figure = plotlyviz.plotlyviz( + data, + title=title, + colorscale=pl_brewer, + color_function_name="Group", + factor_size=7, + edge_linewidth=2.5, + node_linecolor="rgb(200,200,200)", + width=1200, + height=1200, + bgcolor="rgba(240, 240, 240, 0.95)", + left=50, + bottom=50, + summary_height=300, + summary_width=400, + summary_left=20, + summary_right=20, + hist_left=25, + hist_right=25, + member_textbox_width=800, + ) + return dcc.Graph(id=identifier, figure=figure) + def get_2_venn_diagram(data, identifier, cond1, cond2, args): """ - This function extracts the exlusive features in cond1 and cond2 and their common features, and build a two-circle venn diagram. + This function extracts the exlusive features in cond1 and cond2 and their common features, + and build a two-circle venn diagram. :param data: pandas dataframe with features as rows and group identifiers as columns. :param str identifier: id used to identify the div where the figure will be generated. @@ -1986,17 +2731,31 @@ def get_2_venn_diagram(data, identifier, cond1, cond2, args): Example:: - result = get_2_venn_diagram(data, identifier='venn2', cond1='group1', cond2='group2', args={'color':{'group1':'blue', 'group2':'red'}, \ - 'title':'Two-circle Venn diagram'}) + result = get_2_venn_diagram(data, + identifier='venn2', + cond1='group1', + cond2='group2', + args={'color':{'group1':'blue', 'group2':'red'}, + 'title':'Two-circle Venn diagram'} + ) """ figure = {} figure["data"] = [] - unique1 = len(set(data[cond1].dropna().index).difference(data[cond2].dropna().index))#/total - unique2 = len(set(data[cond2].dropna().index).difference(data[cond1].dropna().index))#/total - intersection = len(set(data[cond1].dropna().index).intersection(data[cond2].dropna().index))#/total + unique1 = len( + set(data[cond1].dropna().index).difference(data[cond2].dropna().index) + ) # /total + unique2 = len( + set(data[cond2].dropna().index).difference(data[cond1].dropna().index) + ) # /total + intersection = len( + set(data[cond1].dropna().index).intersection(data[cond2].dropna().index) + ) # /total + + return plot_2_venn_diagram( + cond1, cond2, unique1, unique2, intersection, identifier, args + ) - return plot_2_venn_diagram(cond1, cond2, unique1, unique2, intersection, identifier, args) def plot_2_venn_diagram(cond1, cond2, unique1, unique2, intersection, identifier, args): """ @@ -2016,92 +2775,110 @@ def plot_2_venn_diagram(cond1, cond2, unique1, unique2, intersection, identifier Example:: - result = plot_2_venn_diagram(cond1='group1', cond2='group2', unique1=10, unique2=15, intersection=8, identifier='vennplot', \ - args={'color':{'group1':'blue', 'group2':'red'}, 'title':'Two-circle Venn diagram'}) + result = plot_2_venn_diagram(cond1='group1', + cond2='group2', + unique1=10, + unique2=15, + intersection=8, + identifier='vennplot', + args={'color':{'group1':'blue', 'group2':'red'}, + 'title':'Two-circle Venn diagram'} + ) """ figure = {} figure["data"] = [] - figure["data"] = [go.Scattergl( - x=[1, 1.75, 2.5], - y=[1, 1, 1], - text=[str(unique1), str(intersection), str(unique2)], - mode='text', - textfont=dict( - color='black', - size=14, - family='Arial', + figure["data"] = [ + go.Scattergl( + x=[1, 1.75, 2.5], + y=[1, 1, 1], + text=[str(unique1), str(intersection), str(unique2)], + mode="text", + textfont=dict( + color="black", + size=14, + family="Arial", + ), ) - )] - - if 'colors' not in args: - args['colors'] = {cond1:'#a6bddb', cond2:'#045a8d'} + ] + if "colors" not in args: + args["colors"] = {cond1: "#a6bddb", cond2: "#045a8d"} figure["layout"] = { - 'xaxis': { - 'showticklabels': False, - 'showgrid': False, - 'zeroline': False, + "xaxis": { + "showticklabels": False, + "showgrid": False, + "zeroline": False, }, - 'yaxis': { - 'showticklabels': False, - 'showgrid': False, - 'zeroline': False, + "yaxis": { + "showticklabels": False, + "showgrid": False, + "zeroline": False, }, - 'shapes': [ + "shapes": [ { - 'opacity': 0.3, - 'xref': 'x', - 'yref': 'y', - 'fillcolor': args['colors'][cond1], - 'x0': 0, - 'y0': 0, - 'x1': 2, - 'y1': 2, - 'type': 'circle', - 'line': { - 'color': args['colors'][cond1], + "opacity": 0.3, + "xref": "x", + "yref": "y", + "fillcolor": args["colors"][cond1], + "x0": 0, + "y0": 0, + "x1": 2, + "y1": 2, + "type": "circle", + "line": { + "color": args["colors"][cond1], }, }, { - 'opacity': 0.3, - 'xref': 'x', - 'yref': 'y', - 'fillcolor': args['colors'][cond2], - 'x0': 1.5, - 'y0': 0, - 'x1': 3.5, - 'y1': 2, - 'type': 'circle', - 'line': { - 'color': args['colors'][cond2], + "opacity": 0.3, + "xref": "x", + "yref": "y", + "fillcolor": args["colors"][cond2], + "x0": 1.5, + "y0": 0, + "x1": 3.5, + "y1": 2, + "type": "circle", + "line": { + "color": args["colors"][cond2], }, - } + }, ], - 'margin': { - 'l': 20, - 'r': 20, - 'b': 100 - }, - 'height': 600, - 'width': 800, - 'title':args['title'], - "template":'plotly_white' + "margin": {"l": 20, "r": 20, "b": 100}, + "height": 600, + "width": 800, + "title": args["title"], + "template": "plotly_white", } - return dcc.Graph(id = identifier, figure=figure) + return dcc.Graph(id=identifier, figure=figure) + -def get_wordcloud(data, identifier, args={'stopwords':[], 'max_words': 400, 'max_font_size': 100, 'width':700, 'height':700, 'margin': 1}): +def get_wordcloud( + data, + identifier, + args={ + "stopwords": [], + "max_words": 400, + "max_font_size": 100, + "width": 700, + "height": 700, + "margin": 1, + }, +): """ This function generates a Wordcloud based on the natural text in a pandas dataframe column. - :param data: pandas dataframe with columns: 'PMID', 'abstract', 'authors', 'date', 'journal', 'keywords', 'title', 'url', 'Proteins', 'Diseases'. + :param data: pandas dataframe with columns: 'PMID', 'abstract', 'authors', 'date', 'journal', + 'keywords', 'title', 'url', 'Proteins', 'Diseases'. :param str identifier: id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: - * **text_col** (str) -- name of column containing the natural text used to generate the wordcloud. + * **text_col** (str) -- name of column containing the natural text used to \ + generate the wordcloud. * **stopwords** (list) -- list of words that will be eliminated. * **max_words** (int) -- maximum number of words. * **max_font_size** (int) -- maximum font size for the largest word. @@ -2113,36 +2890,165 @@ def get_wordcloud(data, identifier, args={'stopwords':[], 'max_words': 400, 'max Example:: - result = get_wordcloud(data, identifier='wordcloud', args={'stopwords':['BACKGROUND','CONCLUSION','RESULT','METHOD','CONCLUSIONS','RESULTS','METHODS'], \ - 'max_words': 400, 'max_font_size': 100, 'width':700, 'height':700, 'margin': 1}) + result = get_wordcloud(data, + identifier='wordcloud', + args={'stopwords':['BACKGROUND', + 'CONCLUSION', + 'RESULT', + 'METHOD', + 'CONCLUSIONS', + 'RESULTS', + 'METHODS'], + 'max_words': 400, + 'max_font_size': 100, + 'width':700, + 'height':700, + 'margin': 1} + ) """ - figure=None + figure = None if data is not None: - nltk.download('stopwords') - sw = set(stopwords.words('english')).union(set(STOPWORDS)) - sw.update(['patient', 'protein', 'gene', 'proteins', 'genes', 'using', 'review', - 'identified', 'identify', 'expression', 'role', 'functions', 'factors', - 'cell', 'function', 'human', 'may', 'found', 'well', 'analysis', 'recent', - 'data', 'include', 'including', 'specific', 'involve', 'study', 'information', - 'studies', 'demonstrate', 'demonstrated', 'associated', 'show', 'Changes', - 'showed', 'high', 'low', 'level', 'system', 'change', 'different', 'pathways', - 'similar', 'several', 'many', 'factor', 'peptide', 'based', 'lead', 'processes', - 'finding', 'pathway', 'process', 'disease', 'performed', 'p', 'p-value', 'model', - 'value', 'line', 'mechanism', 'patients', 'mechanisms', 'known', 'increased', 'cells', - 'decreased', 'involved', 'complex', 'levels', 'type', 'activity', 'new', 'contribute', - 'approach', 'significant', 'significantly', 'propose', 'suggests', 'biological', 'suggest', - 'lines', 'clinical', 'provide', 'tissue', 'number', 'via', 'important', 'co', 'profile', - 'vitro', 'complexes', 'domain', 'signaling', 'induced', 'regulation', 'signature', 'dependent', - 'diseases', 'vivo', 'promote', 'Thus', 'used', 'related', 'key', 'shown', 'first', 'one', 'two', - 'research', 'pattern', 'major', 'novel', 'effect', 'affect', 'multiple', 'present', 'normal', - 'Finally', 'signalling', 'Although', 'play', 'discuss', 'risk', 'association', 'underlying']) - if 'stopwords' in args: - sw = sw.union(args['stopwords']) + nltk.download("stopwords") + sw = set(stopwords.words("english")).union(set(STOPWORDS)) + sw.update( + [ + "patient", + "protein", + "gene", + "proteins", + "genes", + "using", + "review", + "identified", + "identify", + "expression", + "role", + "functions", + "factors", + "cell", + "function", + "human", + "may", + "found", + "well", + "analysis", + "recent", + "data", + "include", + "including", + "specific", + "involve", + "study", + "information", + "studies", + "demonstrate", + "demonstrated", + "associated", + "show", + "Changes", + "showed", + "high", + "low", + "level", + "system", + "change", + "different", + "pathways", + "similar", + "several", + "many", + "factor", + "peptide", + "based", + "lead", + "processes", + "finding", + "pathway", + "process", + "disease", + "performed", + "p", + "p-value", + "model", + "value", + "line", + "mechanism", + "patients", + "mechanisms", + "known", + "increased", + "cells", + "decreased", + "involved", + "complex", + "levels", + "type", + "activity", + "new", + "contribute", + "approach", + "significant", + "significantly", + "propose", + "suggests", + "biological", + "suggest", + "lines", + "clinical", + "provide", + "tissue", + "number", + "via", + "important", + "co", + "profile", + "vitro", + "complexes", + "domain", + "signaling", + "induced", + "regulation", + "signature", + "dependent", + "diseases", + "vivo", + "promote", + "Thus", + "used", + "related", + "key", + "shown", + "first", + "one", + "two", + "research", + "pattern", + "major", + "novel", + "effect", + "affect", + "multiple", + "present", + "normal", + "Finally", + "signalling", + "Although", + "play", + "discuss", + "risk", + "association", + "underlying", + ] + ) + if "stopwords" in args: + sw = sw.union(args["stopwords"]) if isinstance(data, pd.DataFrame): if not data.empty: if "text_col" in args: - text = ''.join(str(a) for a in data[args["text_col"]].unique().tolist()) + text = "".join( + str(a) for a in data[args["text_col"]].unique().tolist() + ) else: return None else: @@ -2150,19 +3056,21 @@ def get_wordcloud(data, identifier, args={'stopwords':[], 'max_words': 400, 'max else: text = data - wc = WordCloud(stopwords = sw, - max_words = args['max_words'], - max_font_size = args['max_font_size'], - background_color='white', - margin=args['margin']) + wc = WordCloud( + stopwords=sw, + max_words=args["max_words"], + max_font_size=args["max_font_size"], + background_color="white", + margin=args["margin"], + ) wc.generate(text) - word_list=[] - freq_list=[] - fontsize_list=[] - position_list=[] - orientation_list=[] - color_list=[] + word_list = [] + freq_list = [] + fontsize_list = [] + position_list = [] + orientation_list = [] + color_list = [] for (word, freq), fontsize, position, orientation, color in wc.layout_: word_list.append(word) @@ -2173,87 +3081,90 @@ def get_wordcloud(data, identifier, args={'stopwords':[], 'max_words': 400, 'max color_list.append(color) # get the positions - x=[] - y=[] + x = [] + y = [] j = 0 for i in position_list: - x.append(i[1]+fontsize_list[j]+10) - y.append(i[0]+5) + x.append(i[1] + fontsize_list[j] + 10) + y.append(i[0] + 5) j += 1 # get the relative occurence frequencies new_freq_list = [] for i in freq_list: - new_freq_list.append(i*70) + new_freq_list.append(i * 70) new_freq_list - trace = go.Scattergl(x=x, - y=y, - textfont = dict(size=new_freq_list, - color=color_list), - hoverinfo='text', - hovertext=['{0} freq: {1}'.format(w, f) for w, f in zip(word_list, freq_list)], - mode="text", - text=word_list - ) + trace = go.Scattergl( + x=x, + y=y, + textfont=dict(size=new_freq_list, color=color_list), + hoverinfo="text", + hovertext=[ + "{0} freq: {1}".format(w, f) for w, f in zip(word_list, freq_list) + ], + mode="text", + text=word_list, + ) layout = go.Layout( - xaxis=dict(showgrid=False, - showticklabels=False, - zeroline=False, - automargin=True), - yaxis=dict(showgrid=False, - showticklabels=False, - zeroline=False, - automargin=True), - width=args['width'], - height=args['height'], - title=args['title'], - annotations = [dict(xref='paper', yref='paper', showarrow=False, text='')], - template='plotly_white' - ) + xaxis=dict( + showgrid=False, showticklabels=False, zeroline=False, automargin=True + ), + yaxis=dict( + showgrid=False, showticklabels=False, zeroline=False, automargin=True + ), + width=args["width"], + height=args["height"], + title=args["title"], + annotations=[dict(xref="paper", yref="paper", showarrow=False, text="")], + template="plotly_white", + ) figure = dict(data=[trace], layout=layout) - return dcc.Graph(id = identifier, figure=figure) + return dcc.Graph(id=identifier, figure=figure) def get_cytoscape_network(net, identifier, args): """ - This function creates a Cytoscpae network in dash. For more information visit https://dash.plot.ly/cytoscape. + This function creates a Cytoscpae network in dash. - :param dict net: dictionary in which each element (key) is defined by a dictionary with 'id' and 'label' \ - (if it is a node) or 'source', 'target' and 'label' (if it is an edge). + For more information visit https://dash.plot.ly/cytoscape. + + :param dict net: dictionary in which each element (key) is defined by a dictionary with 'id' + and 'label' (if it is a node) or 'source', 'target' and 'label' (if it is an edge). :param str identifier: is the id used to identify the div where the figure will be generated. :param dict args: see below. :Arguments: * **title** (str) -- title of the figure. - * **stylesheet** (list[dict]) -- specifies the style for a group of elements, a class of elements, or a single element \ - (accepts two keys 'selector' and 'style'). + * **stylesheet** (list[dict]) -- specifies the style for a group of elements, \ + a class of elements, or a single element (accepts two keys 'selector' and 'style'). * **layout** (dict) -- specifies how the nodes should be positioned on the screen. :return: network figure within
. """ - height = '700px' - width = '100%' - if 'height' in args: - height = args['height'] - if 'width' in args: - width = args['width'] - cytonet = cyto.Cytoscape(id=identifier, - stylesheet=args['stylesheet'], - elements=net, - layout=args['layout'], - minZoom = 0.2, - maxZoom = 1.5, - #mouseoverNodeData=args['mouseover_node'], - style={'width': width, 'height': height} - ) - net_div = html.Div([html.H2(args['title']), cytonet]) - + height = "700px" + width = "100%" + if "height" in args: + height = args["height"] + if "width" in args: + width = args["width"] + cytonet = cyto.Cytoscape( + id=identifier, + stylesheet=args["stylesheet"], + elements=net, + layout=args["layout"], + minZoom=0.2, + maxZoom=1.5, + # mouseoverNodeData=args['mouseover_node'], + style={"width": width, "height": height}, + ) + net_div = html.Div([html.H2(args["title"]), cytonet]) return net_div -def save_DASH_plot(plot, name, plot_format='svg', directory='.', width=800, height=700): + +def save_DASH_plot(plot, name, plot_format="svg", directory=".", width=800, height=700): """ This function saves a plotly figure to a specified directory, in a determined format. @@ -2265,68 +3176,67 @@ def save_DASH_plot(plot, name, plot_format='svg', directory='.', width=800, heig Example:: - result = save_DASH_plot(plot, name='Plot example', plot_format='svg', directory='/data/plots') + result = save_DASH_plot(plot, name='Plot example', + plot_format='svg', directory='/data/plots') """ try: if not os.path.exists(directory): os.mkdir(directory) - plot_file = os.path.join(directory, str(name)+'.'+str(plot_format)) - if plot_format in ['svg', 'pdf', 'png', 'jpeg', 'jpg']: - if hasattr(plot, 'figure'): + plot_file = os.path.join(directory, str(name) + "." + str(plot_format)) + if plot_format in ["svg", "pdf", "png", "jpeg", "jpg"]: + if hasattr(plot, "figure"): pio.write_image(plot.figure, plot_file, width=width, height=height) else: pio.write_image(plot, plot_file, width=width, height=height) - elif plot_format == 'json': + elif plot_format == "json": figure_json = json.dumps(plot.figure, cls=plotly.utils.PlotlyJSONEncoder) - with open(plot_file, 'w') as f: + with open(plot_file, "w") as f: f.write(figure_json) except ValueError as err: print("Plot could not be saved. Error: {}".format(err)) def mpl_to_plotly(fig, ci=True, legend=True): - ##ToDo Test how it works for multiple groups - ##ToDo Allow visualization of CI + # ToDo Test how it works for multiple groups + # ToDo Allow visualization of CI # Convert mpl fig obj to plotly fig obj, resize to plotly's default py_fig = tls.mpl_to_plotly(fig, resize=True) # Add fill property to lower limit line - if ci == True: - style1 = dict(fill='tonexty') + if ci: + style1 = dict(fill="tonexty") # apply style - py_fig['data'][2].update(style1) + py_fig["data"][2].update(style1) # change the default line type to 'step' - #py_fig.update_traces(dict(line=go.layout.shape.Line({'dash':'solid'}))) - py_fig['data'] = py_fig['data'][0:2] + # py_fig.update_traces(dict(line=go.layout.shape.Line({'dash':'solid'}))) + py_fig["data"] = py_fig["data"][0:2] # Delete misplaced legend annotations - py_fig['layout'].pop('annotations', None) + py_fig["layout"].pop("annotations", None) if legend: # Add legend, place it at the top right corner of the plot - py_fig['layout'].update( + py_fig["layout"].update( font=dict(size=14), showlegend=True, height=400, width=1000, - template='plotly_white', - legend=go.layout.Legend( - x=1.05, - y=1 - ) + template="plotly_white", + legend=go.layout.Legend(x=1.05, y=1), ) # Send updated figure object to Plotly, show result in notebook return py_fig + def get_km_plot_old(data, identifier, args): figure = {} if len(data) == 3: kmfit, summary, plot = data if kmfit is not None: - title = 'Kaplan-meier plot' + summary - if 'title' in args: - title = args['title'] + "--" + summary + title = "Kaplan-meier plot" + summary + if "title" in args: + title = args["title"] + "--" + summary plot.set_title(title) - #p = kmfit.plot(ci_force_lines=True, title=title, show_censors=True) + # p = kmfit.plot(ci_force_lines=True, title=title, show_censors=True) figure = mpl_to_plotly(plot.figure, legend=True) return dcc.Graph(id=identifier, figure=figure) @@ -2335,12 +3245,12 @@ def get_km_plot_old(data, identifier, args): def get_km_plot(data, identifier, args): figure = {} plot = plt.subplot(111) - environment = 'app' + environment = "app" colors = None - if 'environment' in args: - environment = args['environment'] - if 'colors' in args: - colors = args['colors'] + if "environment" in args: + environment = args["environment"] + if "colors" in args: + colors = args["colors"] if len(data) == 2: kmfit, summary = data if kmfit is not None: @@ -2351,35 +3261,36 @@ def get_km_plot(data, identifier, args): c = colors[i] plot = kmf.plot_survival_function(show_censors=True, ax=plot, c=c) i += 1 - title = 'Kaplan-meier plot ' + summary - xlabel = 'Time' - ylabel = 'Survival' - if 'title' in args: - title = args['title'] + "--" + summary - if 'xlabel' in args: - xlabel = args['xlabel'] - if 'ylabel' in args: - ylabel = args['ylabel'] + title = "Kaplan-meier plot " + summary + xlabel = "Time" + ylabel = "Survival" + if "title" in args: + title = args["title"] + "--" + summary + if "xlabel" in args: + xlabel = args["xlabel"] + if "ylabel" in args: + ylabel = args["ylabel"] plot.set_title(title) plot.set_xlabel(xlabel) plot.set_ylabel(ylabel) - if environment == 'app': - figure = utils.mpl_to_html_image(plot.figure, width=800) + if environment == "app": + figure = mpl_to_html_image(plot.figure, width=800) result = html.Div(id=identifier, children=[figure]) else: result = plot.figure - + return result + def get_cumulative_hazard_plot(data, identifier, args): figure = {} plot = plt.subplot(111) - environment = 'app' + environment = "app" colors = None - if 'environment' in args: - environment = args['environment'] - if 'colors' in args: - colors = args['colors'] + if "environment" in args: + environment = args["environment"] + if "colors" in args: + colors = args["colors"] if len(data) == 2: hrfit = data if hrfit is not None: @@ -2390,24 +3301,24 @@ def get_cumulative_hazard_plot(data, identifier, args): c = colors[i] plot = hrdf.plot_cumulative_hazard(ax=plot, c=c) i += 1 - title = 'Cumulative Hazard plot ' - xlabel = 'Time' - ylabel = 'Nelson Aalen - Cumulative Hazard' - if 'title' in args: - title = args['title'] + "--" - if 'xlabel' in args: - xlabel = args['xlabel'] - if 'ylabel' in args: - ylabel = args['ylabel'] + title = "Cumulative Hazard plot " + xlabel = "Time" + ylabel = "Nelson Aalen - Cumulative Hazard" + if "title" in args: + title = args["title"] + "--" + if "xlabel" in args: + xlabel = args["xlabel"] + if "ylabel" in args: + ylabel = args["ylabel"] plot.set_title(title) plot.set_xlabel(xlabel) plot.set_ylabel(ylabel) - if environment == 'app': - figure = utils.mpl_to_html_image(plot.figure, width=800) + if environment == "app": + figure = mpl_to_html_image(plot.figure, width=800) result = html.Div(id=identifier, children=[figure]) else: result = plot.figure - + return result @@ -2415,113 +3326,155 @@ def get_polar_plot(df, identifier, args): """ This function creates a Polar plot with data aggregated for a given group. - :param dataframe df: dataframe with the data to plot + :param pandas.DataFrame df: dataframe with the data to plot :param str identifier: identifier to be used in the app - :param dict args: dictionary containing the arguments needed to plot the figure (value_col (value to aggregate), group_col (group by), color_col (color by)) + :param dict args: dictionary containing the arguments needed to plot the figure ( + value_col (value to aggregate), group_col (group by), color_col (color by)) :return: Dash Graph Example:: - figure = get_polar_plot(df, identifier='polar', args={'value_col':'intensity', 'group_col':'modifier', 'color_col':'group'}) + + figure = get_polar_plot(df, + identifier='polar', + args={'value_col':'intensity', + 'group_col':'modifier', + 'color_col':'group'} + ) """ figure = {} - line_close = True - ptype = 'bar' - title = 'Polar plot' + ptype = "bar" width = 800 height = 700 - value = 'value' + value = "value" group = None colors = None if not df.empty: - if 'value_col' in args: - value = args['value_col'] - if 'theta_col' in args: - group = args['theta_col'] - if 'color_col' in args: - colors = args['color_col'] - if 'line_close' in args: - line_close = args['line_close'] - if 'title' in args: - title = args['title'] - if 'width' in args: - width = args['width'] - if 'height' in args: - height = args['height'] - if 'type' in args: - ptype = args['type'] + if "value_col" in args: + value = args["value_col"] + if "theta_col" in args: + group = args["theta_col"] + if "color_col" in args: + colors = args["color_col"] + if "width" in args: + width = args["width"] + if "height" in args: + height = args["height"] + if "type" in args: + ptype = args["type"] figure = go.Figure() if value is not None and group is not None and colors is not None: if not df.empty: min_value = df[value].min() max_value = df[value].max() - if ptype == 'line': + if ptype == "line": for color in df[colors].unique(): cdf = df[df[colors] == color] - figure.add_trace(go.Scatterpolar(r = cdf[value], - theta = cdf[group], - mode = 'lines', - name = color, - fill='toself')) + figure.add_trace( + go.Scatterpolar( + r=cdf[value], + theta=cdf[group], + mode="lines", + name=color, + fill="toself", + ) + ) else: - print("Type {} not available. Try with 'line' or 'bar' types.".format(ptype)) + print( + "Type {} not available. Try with 'line' or 'bar' types.".format( + ptype + ) + ) - layout = figure.update_layout(width=width, height=height, polar = dict(radialaxis=dict(range=[min_value, max_value]))) + figure.update_layout( + width=width, + height=height, + polar=dict(radialaxis=dict(range=[min_value, max_value])), + ) return dcc.Graph(id=identifier, figure=figure) + def get_enrichment_plots(enrichment_results, identifier, args): """ - This function generates a scatter plot with enriched terms (y-axis) and their adjusted pvalues (x-axis) + This function generates a scatter plot with enriched terms (y-axis) + and their adjusted pvalues (x-axis) - :param dataframe enrichment_results: dataframe with the enrichment data to plot (see enrichment functions for format) + :param pandas.DataFrame enrichment_results: dataframe with the enrichment data to plot + (see enrichment functions for format) :param str identifier: identifier to be used in the app - :param dict args: dictionary containing the arguments needed to plot the figure (width, height, title) - :return list: list of scatter plots one for each enrichment table available (i.e pairwise comparisons) + :param dict args: dictionary containing the arguments needed to plot the figure + (width, height, title) + :return list: list of scatter plots one for each enrichment table available + (i.e pairwise comparisons) Example:: - figure = get_enrichment_plots(df, identifier='enrichment', args={'width':1500, 'height':800, 'title':'Enrichment'}) + + figure = get_enrichment_plots(df, + identifier='enrichment', + args={'width':1500, + 'height':800, + 'title':'Enrichment'} + ) """ figures = [] width = 900 height = 800 - colors = {'upregulated': '#cb181d', 'downregulated': '#3288bd', 'regulated': '#ae017e', 'non-regulated': '#fcc5c0'} + colors = { + "upregulated": "#cb181d", + "downregulated": "#3288bd", + "regulated": "#ae017e", + "non-regulated": "#fcc5c0", + } title = "Enrichment" - if 'width' in args: - width = args['width'] - if 'height' in args: - height = args['height'] - if 'title' in args: - title = args['title'] + if "width" in args: + width = args["width"] + if "height" in args: + height = args["height"] + if "title" in args: + title = args["title"] if not isinstance(enrichment_results, dict): aux = enrichment_results.copy() - enrichment_results = {'regulated~non-regulated': aux} + enrichment_results = {"regulated~non-regulated": aux} for g in enrichment_results: - g1, g2 = g.split('~') - group = 'direction' - nid = identifier+'_{}_{}'.format(g1, g2) + g1, g2 = g.split("~") + group = "direction" + nid = identifier + "_{}_{}".format(g1, g2) if not enrichment_results[g].empty: df = enrichment_results[g][enrichment_results[g].rejected] - if 'direction' not in df: + if "direction" not in df: group = None if not df.empty: - df = df.sort_values(by=[group, 'padj'], ascending=False) - df['x'] = -np.log10(df['padj']) - fig = get_scatterplot(df, identifier=nid, - args={'x': 'x', - 'y': 'terms', - 'group': group, - 'title':'{} {} vs {}'.format(title, g1, g2), - 'symbol':group, - 'colors': colors, - 'x_title':'-log10(padj)', - 'y_title':'Enriched terms', - 'width': width, - 'height':height, - 'hovering_cols':['foreground', 'foreground_pop', 'background', 'background_pop', 'pvalue', 'padj', 'identifiers'], - 'size':'foreground'}) + df = df.sort_values(by=[group, "padj"], ascending=False) + df["x"] = -np.log10(df["padj"]) + fig = get_scatterplot( + df, + identifier=nid, + args={ + "x": "x", + "y": "terms", + "group": group, + "title": "{} {} vs {}".format(title, g1, g2), + "symbol": group, + "colors": colors, + "x_title": "-log10(padj)", + "y_title": "Enriched terms", + "width": width, + "height": height, + "hovering_cols": [ + "foreground", + "foreground_pop", + "background", + "background_pop", + "pvalue", + "padj", + "identifiers", + ], + "size": "foreground", + }, + ) figures.append(fig) return figures diff --git a/src/vuecore/wgcna.py b/src/vuecore/wgcna.py new file mode 100644 index 0000000..92c29ec --- /dev/null +++ b/src/vuecore/wgcna.py @@ -0,0 +1,795 @@ +import numpy as np +import pandas as pd +import plotly.graph_objs as go +import plotly.subplots as tools +import scipy as scp +from acore import wgcna_analysis + +from . import color_list, dendrogram + + +def get_module_color_annotation( + map_list, + col_annotation=False, + row_annotation=False, + bygene=False, + module_colors=[], + dendrogram=[], +): + """ + This function takes a list of values, converts them into colors, and creates a new plotly object to be used as an annotation. + Options module_colors and dendrogram only apply when map_list is a list of experimental features used in module eigenegenes calculation. + + :param list map_list: dendrogram leaf labels. + :param bool col_annotation: if True, adds color annotations as a row. + :param bool row_annotation: if True, adds color annotations as a column. + :param bool bygene: determines wether annotation colors have to be reordered to match dendrogram leaf labels. + :param list module_colors: dendrogram leaf module color. + :param dict dendrogram: dendrogram represented as a plotly object figure. + :return: Plotly object figure. + + .. note:: map_list and module_colors must have the same length. + """ + colors_dict = color_list.make_color_dict() + + n = len(map_list) + val = 1 / (n - 1) + number = 0 + colors = [] + vals = [] + + # Use if color annotation is for experimental features in dendrogram + if bygene: + module_colors = [i.lower().replace(" ", "") for i in module_colors] + gene_colors = dict(zip(map_list, module_colors)) + + for i in map_list: + name = gene_colors[i] + color = colors_dict[name] + n = number + colors.append([round(n, 4), color]) + vals.append((i, round(n, 4))) + number = n + val + + labels = list(dendrogram["layout"]["xaxis"]["ticktext"]) + y = [1] * len(labels) + + df = pd.DataFrame([labels, y], index=["labels", "y"]).T + df["vals"] = df["labels"].map(dict(vals)) + + # Use if map_list is a list of co-expression modules names + else: + for i in map_list: + name = i.split("ME") + if len(name) == 2: + name = name[1] + color = colors_dict[name] + n = number + colors.append([round(n, 4), color]) + vals.append((i, round(n, 4))) + number = n + val + else: + name = name[0] + n = number + colors.append([round(n, 4), "#ffffff"]) + vals.append((i, round(n, 4))) + number = n + val + + y = [1] * len(map_list) + df = pd.DataFrame([map_list, y], index=["labels", "y"]).T + df["vals"] = df["labels"].map(dict(vals)) + + if row_annotation and col_annotation: + r_annot = go.Heatmap( + z=df.vals, + x=df.y, + y=df.labels, + showscale=False, + colorscale=colors, + xaxis="x", + yaxis="y", + ) + c_annot = go.Heatmap( + z=df.vals, + x=df.labels, + y=df.y, + showscale=False, + colorscale=colors, + xaxis="x2", + yaxis="y2", + ) + return r_annot, c_annot + elif row_annotation: + r_annot = go.Heatmap( + z=df.vals, + x=df.y, + y=df.labels, + showscale=False, + colorscale=colors, + xaxis="x2", + yaxis="y2", + ) + return r_annot + elif col_annotation: + c_annot = go.Heatmap( + z=df.vals, + x=df.labels, + y=df.y, + showscale=False, + colorscale=colors, + xaxis="x2", + yaxis="y2", + ) + return c_annot + + return None + + +def get_heatmap(df, colorscale=None, color_missing=True): + """ + This function plots a simple Plotly heatmap. + + :param df: pandas dataframe containing experimental data, with samples/subjects as rows and features as columns. + :param list[list] colorscale: heatmap colorscale (e.g. [[0,'#67a9cf'],[0.5,'#f7f7f7'],[1,'#ef8a62']]). If colorscale is not defined, will take [[0, 'rgb(255,255,255)'], [1, 'rgb(255,51,0)']] as default. + :param bool color_missing: if set to True, plots missing values as grey in the heatmap. + :return: Plotly object figure. + """ + figure = {} + if df is not None: + if colorscale: + colors = colorscale + else: + colors = [[0, "rgb(255,255,255)"], [1, "rgb(255,51,0)"]] + + figure = {"layout": {"template": None}, "data": []} + figure["layout"]["template"] = "plotly_white" + figure["data"].append( + go.Heatmap( + z=df.values.tolist(), + y=list(df.index), + x=list(df.columns), + colorscale=colors, + showscale=True, + colorbar=dict( + x=1, y=0, xanchor="left", yanchor="bottom", len=0.35, thickness=15 + ), + ) + ) + if color_missing: + df_missing = wgcna_analysis.get_miss_values_df(df) + figure["data"].append( + go.Heatmap( + z=df_missing.values.tolist(), + y=list(df.index), + x=list(df.columns), + colorscale=[[0, "rgb(201,201,201)"], [1, "rgb(201,201,201)"]], + showscale=False, + ) + ) + + return figure + + +def plot_labeled_heatmap( + df, + textmatrix, + title, + colorscale=[[0, "rgb(0,255,0)"], [0.5, "rgb(255,255,255)"], [1, "rgb(255,0,0)"]], + width=1200, + height=800, + row_annotation=False, + col_annotation=False, +): + """ + This function plots a simple Plotly heatmap with column and/or row annotations and heatmap annotations. + + :param df: pandas dataframe containing data to be plotted in the heatmap. + :param textmatrix: pandas dataframe with heatmap annotations as values. + :param str title: the title of the figure. + :param list[list] colorscale: heatmap colorscale (e.g. [[0,'rgb(0,255,0)'],[0.5,'rgb(255,255,255)'],[1,'rgb(255,0,0)']]) + :param int width: the width of the figure. + :param int height: the height of the figure. + :param bool row_annotation: if True, adds a color-coded column at the left of the heatmap. + :param bool col_annotation: if True, adds a color-coded row at the bottom of the heatmap. + :return: Plotly object figure. + """ + figure = {} + if df is not None: + figure = get_heatmap(df, colorscale=colorscale, color_missing=False) + figure["data"].append( + get_module_color_annotation( + list(df.index), + row_annotation=row_annotation, + col_annotation=col_annotation, + bygene=False, + ) + ) + + annotations = [] + for n, row in enumerate(textmatrix.values): + for m, val in enumerate(row): + annotations.append( + go.layout.Annotation( + text=str(textmatrix.values[n][m]), + font=dict(size=8), + x=df.columns[m], + y=df.index[n], + xref="x", + yref="y", + showarrow=False, + ) + ) + + layout = go.Layout( + width=width, + height=height, + title=title, + xaxis=dict( + domain=[0.015, 1], + autorange=True, + showgrid=False, + zeroline=False, + showline=False, + ticks="", + showticklabels=True, + automargin=True, + anchor="y", + ), + yaxis=dict( + autorange="reversed", + ticklen=5, + ticks="outside", + tickcolor="white", + showticklabels=False, + automargin=True, + showgrid=False, + anchor="x", + ), + xaxis2=dict( + domain=[0, 0.01], + autorange=True, + showgrid=False, + zeroline=False, + showline=False, + ticks="", + showticklabels=False, + automargin=True, + anchor="y2", + ), + yaxis2=dict( + autorange="reversed", + showgrid=False, + zeroline=False, + showline=False, + ticks="", + showticklabels=True, + automargin=True, + anchor="x2", + ), + ) + + figure["layout"] = layout + figure["layout"]["template"] = "plotly_white" + figure["layout"].update(annotations=annotations) + + return figure + + +def plot_dendrogram_guidelines(Z_tree, dendrogram): + """ + This function takes a dendrogram tree dictionary and its plotly object and creates shapes to be plotted as vertical dashed lines in the dendrogram. + + :param dict Z_tree: dictionary of data structures computed to render the dendrogram. Keys: 'icoords', 'dcoords', 'ivl' and 'leaves'. + :param dendrogram: dendrogram represented as a plotly object figure. + :return: List of dictionaries. + """ + shapes = [] + if dendrogram is not None: + tickvals = list(dendrogram["layout"]["xaxis"]["tickvals"]) + maximum = len(tickvals) + step = int(maximum / 8) + minimum = int(0 + step) + + keys = ["type", "x0", "y0", "x1", "y1", "line"] + line_keys = ["color", "width", "dash"] + line_vals = ["rgb(192,192,192)", 0.1, "dot"] + line = dict(zip(line_keys, line_vals)) + + values = [] + for i in tickvals[minimum::step]: + values.append(("line", i, 0.3, i, np.max(Z_tree["dcoord"]))) + + values = [list(i) + [line] for i in values] + shapes = [] + for i in values: + d = dict(zip(keys, i)) + shapes.append(d) + + return shapes + + +def plot_intramodular_correlation( + MM, FS, feature_module_df, title, width=1000, height=800 +): + """ + This function uses the Feature significance and Module Membership measures, and plots a multi-scatter plot of all modules against all clinical traits. + + :param MM: pandas dataframe with module membership data + :param FS: pandas dataframe with feature significance data + :param feature_module_df: pandas DataFrame of experimental features and module colors (use mode='dataframe' in get_FeaturesPerModule) + :param str title: plot title + :param int width: plot width + :param int height: plot height + :return: Plotly object figure. + + Example:: + + plot = plot_intramodular_correlation(MM, FS, feature_module_df, title='Plot', width=1000, height=800): + + .. note:: There is a limit in the number of subplots one can make in Plotly. This function limits the number of modules shown to 5. + """ + figure = {} + if MM is not None: + MM = MM.iloc[:, -6] + MM["modColor"] = MM.index.map( + feature_module_df.set_index("name")["modColor"].get + ) + + figure = tools.make_subplots( + rows=len(FS.columns), + cols=len(MM.columns) - 1, + shared_xaxes=False, + shared_yaxes=False, + vertical_spacing=0.015, + horizontal_spacing=0.1, + print_grid=True, + ) + + figure.layout.template = "plotly_white" + layout = dict(width=width, height=height, showlegend=False, title=title) + figure.layout.update(layout) + + axis_dict = {} + for i, j in enumerate(MM.columns[MM.columns.str.startswith("MM")]): + n_p = len(FS.columns) * (len(MM.columns) - 1) - len( + MM.columns[MM.columns.str.startswith("MM")] + ) + axis_dict["xaxis{}".format(n_p + i + 1)] = dict( + title=j, titlefont=dict(size=13) + ) + print(axis_dict) + n = 1 + for a, b in enumerate(FS.columns): + name = b.split(" ") + if len(name) > 1: + label = ["
".join(name[i : i + 3]) for i in range(0, len(name), 3)][ + 0 + ] + else: + label = name[0] + axis_dict["yaxis{}".format(a + n)] = dict( + title=label, titlefont=dict(size=13) + ) + n += len(MM.columns[MM.columns.str.startswith("MM")]) - 1 + + annotation = [] + x_axis = 1 + y_axis = 1 + for a, b in enumerate(FS.columns): + for i, j in enumerate(MM.columns[MM.columns.str.startswith("MM")]): + name = MM[MM["modColor"] == j[2:]].index + x = abs(MM[MM["modColor"] == j[2:]][j].values) + y = abs(FS[FS.index.isin(name)][b].values) + + slope, intercept, r_value, p_value, std_err = scp.stats.linregress(x, y) + line = slope * x + intercept + + figure.append_trace( + go.Scattergl( + x=x, + y=y, + text=name, + mode="markers", + opacity=0.7, + marker={ + "size": 7, + "color": "white", + "line": {"width": 1.5, "color": j[2:]}, + }, + ), + a + 1, + i + 1, + ) + + figure.append_trace( + go.Scattergl(x=x, y=line, mode="lines", marker={"color": "black"}), + a + 1, + i + 1, + ) + + annot = dict( + x=0.7, + y=0.7, + xref="x{}".format(x_axis), + yref="y{}".format(y_axis), + text="R={:0.2}, p={:.0e}".format(r_value, p_value), + showarrow=False, + ) + annotation.append(annot) + x_axis += 1 + y_axis += 1 + + figure.layout.update(axis_dict) + figure.layout.update(annotations=annotation) + + return figure + + +def plot_complex_dendrogram( + dendro_df, + subplot_df, + title, + dendro_labels=[], + distfun="euclidean", + linkagefun="average", + hang=0.04, + subplot="module colors", + subplot_colorscale=[], + color_missingvals=True, + row_annotation=False, + col_annotation=False, + width=1000, + height=800, +): + """ + This function plots a dendrogram with a subplot below that can be a heatmap (annotated or not) or module colors. + + :param dendro_df: pandas dataframe containing data used to generate dendrogram, columns will result in dendrogram leaves. + :param subplot_df: pandas dataframe containing data used to generate plot below dendrogram. + :param str title: the title of the figure. + :param list dendro_labels: list of strings for dendrogram leaf nodes labels. + :param str distfun: distance measure to be used (‘euclidean‘, ‘maximum‘, ‘manhattan‘, ‘canberra‘, ‘binary‘, ‘minkowski‘ or ‘jaccard‘). + :param str linkagefun: hierarchical/agglomeration method to be used (‘single‘, ‘complete‘, ‘average‘, ‘weighted‘, ‘centroid‘, ‘median‘ or ‘ward‘). + :param float hang: height at which the dendrogram leaves should be placed. + :param str subplot: type of plot to be shown below the dendrogram (´module colors´ or ´heatmap´). + :param list subplot_colorscale: colorscale to be used in the subplot. + :param bool color_missingvals: if set to `True`, plots missing values as grey in the heatmap. + :param bool row_annotation: if `True`, adds a color-coded column at the left of the heatmap. + :param bool col_annotation: if `True`, adds a color-coded row at the bottom of the heatmap. + :param int width: the width of the figure. + :param int height: the height of the figure. + :return: Plotly object figure. + """ + figure = {} + dendro_tree = wgcna_analysis.get_dendrogram( + dendro_df, + dendro_labels, + distfun=distfun, + linkagefun=linkagefun, + div_clusters=False, + ) + if dendro_tree is not None: + dendrogram_ = dendrogram.plot_dendrogram( + dendro_tree, hang=hang, cutoff_line=False + ) + + layout = go.Layout( + width=width, + height=height, + showlegend=False, + title=title, + xaxis=dict( + domain=[0, 1], + range=[ + np.min(dendrogram_["layout"]["xaxis"]["tickvals"]) - 6, + np.max(dendrogram_["layout"]["xaxis"]["tickvals"]) + 4, + ], + showgrid=False, + zeroline=True, + ticks="", + automargin=True, + anchor="y", + ), + yaxis=dict( + domain=[0.7, 1], + autorange=True, + showgrid=False, + zeroline=False, + ticks="outside", + title="Height", + automargin=True, + anchor="x", + ), + xaxis2=dict( + domain=[0, 1], + autorange=True, + showgrid=True, + zeroline=False, + ticks="", + showticklabels=False, + automargin=True, + anchor="y2", + ), + yaxis2=dict( + domain=[0, 0.64], + autorange=True, + showgrid=False, + zeroline=False, + automargin=True, + anchor="x2", + ), + ) + + if subplot == "module colors": + figure = tools.make_subplots(rows=2, cols=1, print_grid=False) + + for i in list(dendrogram_["data"]): + figure.append_trace(i, 1, 1) + + shapes = plot_dendrogram_guidelines(dendro_tree, dendrogram_) + moduleColors = get_module_color_annotation( + dendro_labels, + col_annotation=col_annotation, + bygene=True, + module_colors=subplot_df, + dendrogram=dendrogram_, + ) + figure.append_trace(moduleColors, 2, 1) + figure["layout"] = layout + figure.layout.template = "plotly_white" + figure["layout"].update( + { + "shapes": shapes, + "xaxis": dict(showticklabels=False), + "yaxis": dict(domain=[0.2, 1]), + "yaxis2": dict( + domain=[0, 0.19], + title="Module colors", + ticks="", + showticklabels=False, + ), + } + ) + + elif subplot == "heatmap": + if not all( + list( + subplot_df.columns.map( + lambda x: subplot_df[x].between(-1, 1, inclusive=True).all() + ) + ) + ): + df = wgcna_analysis.get_percentiles_heatmap( + subplot_df, dendro_tree, bydendro=True, bycols=False + ).T + else: + df = wgcna_analysis.df_sort_by_dendrogram( + wgcna_analysis.df_sort_by_dendrogram(subplot_df, dendro_tree).T, + dendro_tree, + ) + + heatmap = get_heatmap( + df, colorscale=subplot_colorscale, color_missing=color_missingvals + ) + + if row_annotation and col_annotation: + figure = tools.make_subplots( + rows=3, + cols=2, + specs=[[{"colspan": 2}, None], [{}, {}], [{"colspan": 2}, None]], + print_grid=False, + ) + for i in list(dendrogram_["data"]): + figure.append_trace(i, 1, 1) + for j in list(heatmap["data"]): + figure.append_trace(j, 2, 2) + + r_annot, c_annot = get_module_color_annotation( + list(df.index), + row_annotation=row_annotation, + col_annotation=col_annotation, + bygene=False, + ) + figure.append_trace(r_annot, 2, 1) + figure.append_trace(c_annot, 3, 1) + + figure["layout"] = layout + figure.layout.template = "plotly_white" + figure["layout"].update( + { + "xaxis": dict(ticks="", showticklabels=False, anchor="y"), + "xaxis2": dict( + domain=[0, 0.01], + ticks="", + showticklabels=False, + automargin=True, + anchor="y2", + ), + "xaxis3": dict( + domain=[0.015, 1], + ticks="", + showticklabels=False, + automargin=True, + anchor="y3", + ), + "xaxis4": dict( + domain=[0.015, 1], + ticks="", + showticklabels=True, + automargin=True, + anchor="y4", + ), + "yaxis": dict(domain=[0.635, 1], automargin=True, anchor="x"), + "yaxis2": dict( + domain=[0.015, 0.635], + autorange="reversed", + ticks="", + showticklabels=True, + automargin=True, + anchor="x2", + ), + "yaxis3": dict( + domain=[0.01, 0.635], + autorange="reversed", + ticks="", + showticklabels=False, + automargin=True, + anchor="x3", + ), + "yaxis4": dict( + domain=[0, 0.01], + ticks="", + showticklabels=False, + automargin=True, + anchor="x4", + ), + } + ) + + elif not row_annotation and not col_annotation: + figure = tools.make_subplots(rows=2, cols=1, print_grid=False) + + for i in list(dendrogram_["data"]): + figure.append_trace(i, 1, 1) + for j in list(heatmap["data"]): + figure.append_trace(j, 2, 1) + + figure["layout"] = layout + figure.layout.template = "plotly_white" + figure.layout.update( + { + "xaxis": dict( + ticktext=np.array( + dendrogram_["layout"]["xaxis"]["ticktext"] + ), + tickvals=list(dendrogram_["layout"]["xaxis"]["tickvals"]), + ), + "yaxis2": dict(autorange="reversed"), + } + ) + + elif row_annotation: + figure = tools.make_subplots( + rows=2, + cols=2, + specs=[[{"colspan": 2}, None], [{}, {}]], + print_grid=False, + ) + for i in list(dendrogram_["data"]): + figure.append_trace(i, 1, 1) + for j in list(heatmap["data"]): + figure.append_trace(j, 2, 2) + + r_annot = get_module_color_annotation( + list(df.index), + row_annotation=row_annotation, + col_annotation=col_annotation, + bygene=False, + ) + figure.append_trace(r_annot, 2, 1) + + figure["layout"] = layout + figure.layout.template = "plotly_white" + figure["layout"].update( + { + "xaxis": dict( + domain=[0.015, 1], + ticktext=np.array( + dendrogram_["layout"]["xaxis"]["ticktext"] + ), + tickvals=list(dendrogram_["layout"]["xaxis"]["tickvals"]), + automargin=True, + anchor="y", + ), + "xaxis2": dict( + domain=[0, 0.010], + ticks="", + showticklabels=False, + automargin=True, + anchor="y2", + ), + "xaxis3": dict( + domain=[0.015, 1], + ticks="", + showticklabels=False, + automargin=True, + anchor="y3", + ), + "yaxis": dict(automargin=True, anchor="x"), + "yaxis2": dict( + autorange="reversed", + ticks="", + showticklabels=True, + automargin=True, + anchor="x2", + ), + "yaxis3": dict( + domain=[0, 0.64], + ticks="", + showticklabels=False, + automargin=True, + anchor="x3", + ), + } + ) + + elif col_annotation: + figure = tools.make_subplots( + rows=3, cols=1, specs=[[{}], [{}], [{}]], print_grid=False + ) + + for i in list(dendrogram_["data"]): + figure.append_trace(i, 1, 1) + for j in list(heatmap["data"]): + figure.append_trace(j, 3, 1) + + c_annot = get_module_color_annotation( + list(df.index), + row_annotation=row_annotation, + col_annotation=col_annotation, + bygene=False, + ) + figure.append_trace(c_annot, 2, 1) + + figure["layout"] = layout + figure.layout.template = "plotly_white" + figure["layout"].update( + { + "xaxis": dict( + ticktext=np.array( + dendrogram_["layout"]["xaxis"]["ticktext"] + ), + tickvals=list(dendrogram_["layout"]["xaxis"]["tickvals"]), + automargin=True, + anchor="y", + ), + "xaxis2": dict( + ticks="", showticklabels=False, automargin=True, anchor="y2" + ), + "xaxis3": dict( + domain=[0, 1], + ticks="", + showticklabels=False, + automargin=True, + anchor="y3", + ), + "yaxis": dict(domain=[0.70, 1], automargin=True, anchor="x"), + "yaxis2": dict( + domain=[0.615, 0.625], + ticks="", + showticklabels=False, + automargin=True, + anchor="x2", + ), + "yaxis3": dict( + domain=[0, 0.61], + autorange="reversed", + ticks="", + showticklabels=False, + automargin=True, + anchor="x3", + ), + } + ) + + return figure diff --git a/src/vuecore/wgcnaFigures.py b/src/vuecore/wgcnaFigures.py deleted file mode 100644 index b0d3390..0000000 --- a/src/vuecore/wgcnaFigures.py +++ /dev/null @@ -1,412 +0,0 @@ -import pandas as pd -import numpy as np -import scipy as scp -from ckg.analytics_core.viz import color_list -import plotly.graph_objs as go -import plotly.subplots as tools -from ckg.analytics_core.viz import Dendrogram -from ckg.analytics_core.analytics import wgcnaAnalysis - - -def get_module_color_annotation(map_list, col_annotation=False, row_annotation=False, bygene=False, module_colors=[], dendrogram=[]): - """ - This function takes a list of values, converts them into colors, and creates a new plotly object to be used as an annotation. - Options module_colors and dendrogram only apply when map_list is a list of experimental features used in module eigenegenes calculation. - - :param list map_list: dendrogram leaf labels. - :param bool col_annotation: if True, adds color annotations as a row. - :param bool row_annotation: if True, adds color annotations as a column. - :param bool bygene: determines wether annotation colors have to be reordered to match dendrogram leaf labels. - :param list module_colors: dendrogram leaf module color. - :param dict dendrogram: dendrogram represented as a plotly object figure. - :return: Plotly object figure. - - .. note:: map_list and module_colors must have the same length. - """ - colors_dict = color_list.make_color_dict() - - n = len(map_list) - val = 1/(n-1) - number = 0 - colors = [] - vals = [] - - #Use if color annotation is for experimental features in dendrogram - if bygene: - module_colors = [i.lower().replace(' ', '') for i in module_colors] - gene_colors = dict(zip(map_list, module_colors)) - - for i in map_list: - name = gene_colors[i] - color = colors_dict[name] - n = number - colors.append([round(n,4), color]) - vals.append((i, round(n,4))) - number = n+val - - labels = list(dendrogram['layout']['xaxis']['ticktext']) - y = [1]*len(labels) - - df = pd.DataFrame([labels, y], index=['labels', 'y']).T - df['vals'] = df['labels'].map(dict(vals)) - - #Use if map_list is a list of co-expression modules names - else: - for i in map_list: - name = i.split('ME') - if len(name) == 2: - name = name[1] - color = colors_dict[name] - n = number - colors.append([round(n,4), color]) - vals.append((i, round(n,4))) - number = n+val - else: - name = name[0] - n = number - colors.append([round(n,4), '#ffffff']) - vals.append((i, round(n,4))) - number = n+val - - y = [1]*len(map_list) - df = pd.DataFrame([map_list, y], index=['labels', 'y']).T - df['vals'] = df['labels'].map(dict(vals)) - - if row_annotation and col_annotation: - r_annot = go.Heatmap(z=df.vals, x=df.y, y=df.labels, showscale=False, colorscale=colors, xaxis='x', yaxis='y') - c_annot = go.Heatmap(z=df.vals, x=df.labels, y=df.y, showscale=False, colorscale=colors, xaxis='x2', yaxis='y2') - return r_annot, c_annot - elif row_annotation: - r_annot = go.Heatmap(z=df.vals, x=df.y, y=df.labels, showscale=False, colorscale=colors, xaxis='x2', yaxis='y2') - return r_annot - elif col_annotation: - c_annot = go.Heatmap(z=df.vals, x=df.labels, y=df.y, showscale=False, colorscale=colors, xaxis='x2', yaxis='y2') - return c_annot - - return None - - -def get_heatmap(df, colorscale=None, color_missing=True): - """ - This function plots a simple Plotly heatmap. - - :param df: pandas dataframe containing experimental data, with samples/subjects as rows and features as columns. - :param list[list] colorscale: heatmap colorscale (e.g. [[0,'#67a9cf'],[0.5,'#f7f7f7'],[1,'#ef8a62']]). If colorscale is not defined, will take [[0, 'rgb(255,255,255)'], [1, 'rgb(255,51,0)']] as default. - :param bool color_missing: if set to True, plots missing values as grey in the heatmap. - :return: Plotly object figure. - """ - figure = {} - if df is not None: - if colorscale: - colors = colorscale - else: - colors = [[0, 'rgb(255,255,255)'], [1, 'rgb(255,51,0)']] - - figure = {'layout': {'template': None}, 'data': []} - figure['layout']['template'] = 'plotly_white' - figure['data'].append(go.Heatmap(z=df.values.tolist(), y=list(df.index), x=list(df.columns), - colorscale=colors, showscale=True, - colorbar=dict(x=1, y=0, xanchor='left', yanchor='bottom', len=0.35, thickness=15))) - if color_missing: - df_missing = wgcnaAnalysis.get_miss_values_df(df) - figure['data'].append(go.Heatmap(z=df_missing.values.tolist(), - y=list(df.index), - x=list(df.columns), - colorscale=[[0, 'rgb(201,201,201)'], [1, 'rgb(201,201,201)']], - showscale=False)) - - return figure - - -def plot_labeled_heatmap(df, textmatrix, title, colorscale=[[0, 'rgb(0,255,0)'], [0.5, 'rgb(255,255,255)'], [1, 'rgb(255,0,0)']], width=1200, height=800, row_annotation=False, col_annotation=False): - """ - This function plots a simple Plotly heatmap with column and/or row annotations and heatmap annotations. - - :param df: pandas dataframe containing data to be plotted in the heatmap. - :param textmatrix: pandas dataframe with heatmap annotations as values. - :param str title: the title of the figure. - :param list[list] colorscale: heatmap colorscale (e.g. [[0,'rgb(0,255,0)'],[0.5,'rgb(255,255,255)'],[1,'rgb(255,0,0)']]) - :param int width: the width of the figure. - :param int height: the height of the figure. - :param bool row_annotation: if True, adds a color-coded column at the left of the heatmap. - :param bool col_annotation: if True, adds a color-coded row at the bottom of the heatmap. - :return: Plotly object figure. - """ - figure = {} - if df is not None: - figure = get_heatmap(df, colorscale=colorscale, color_missing=False) - figure['data'].append(get_module_color_annotation(list(df.index), row_annotation=row_annotation, col_annotation=col_annotation, bygene=False)) - - annotations = [] - for n, row in enumerate(textmatrix.values): - for m, val in enumerate(row): - annotations.append(go.layout.Annotation(text=str(textmatrix.values[n][m]), font=dict(size=8), - x=df.columns[m], y=df.index[n], xref='x', yref='y', showarrow=False)) - - layout = go.Layout(width=width, height=height, title=title, - xaxis=dict(domain=[0.015, 1], autorange=True, showgrid=False, zeroline=False, showline=False, ticks='', showticklabels=True, automargin=True, anchor='y'), - yaxis=dict(autorange='reversed', ticklen=5, ticks='outside', tickcolor='white', showticklabels=False, automargin=True, showgrid=False, anchor='x'), - xaxis2=dict(domain=[0, 0.01], autorange=True, showgrid=False, zeroline=False, showline=False, ticks='', showticklabels=False, automargin=True, anchor='y2'), - yaxis2=dict(autorange='reversed', showgrid=False, zeroline=False, showline=False, ticks='', showticklabels=True, automargin=True, anchor='x2')) - - figure['layout'] = layout - figure['layout']['template'] = 'plotly_white' - figure['layout'].update(annotations=annotations) - - - return figure - - -def plot_dendrogram_guidelines(Z_tree, dendrogram): - """ - This function takes a dendrogram tree dictionary and its plotly object and creates shapes to be plotted as vertical dashed lines in the dendrogram. - - :param dict Z_tree: dictionary of data structures computed to render the dendrogram. Keys: 'icoords', 'dcoords', 'ivl' and 'leaves'. - :param dendrogram: dendrogram represented as a plotly object figure. - :return: List of dictionaries. - """ - shapes = [] - if dendrogram is not None: - tickvals = list(dendrogram['layout']['xaxis']['tickvals']) - maximum = len(tickvals) - step = int(maximum/8) - minimum = int(0+step) - - keys = ['type', 'x0', 'y0', 'x1', 'y1', 'line'] - line_keys = ['color', 'width', 'dash'] - line_vals = ['rgb(192,192,192)', 0.1, 'dot'] - line = dict(zip(line_keys,line_vals)) - - values = [] - for i in tickvals[minimum::step]: - values.append(('line', i, 0.3, i, np.max(Z_tree['dcoord']))) - - values = [list(i)+[line] for i in values] - shapes = [] - for i in values: - d = dict(zip(keys, i)) - shapes.append(d) - - return shapes - - -def plot_intramodular_correlation(MM, FS, feature_module_df, title, width=1000, height=800): - """ - This function uses the Feature significance and Module Membership measures, and plots a multi-scatter plot of all modules against all clinical traits. - - :param MM: pandas dataframe with module membership data - :param FS: pandas dataframe with feature significance data - :param feature_module_df: pandas DataFrame of experimental features and module colors (use mode='dataframe' in get_FeaturesPerModule) - :param str title: plot title - :param int width: plot width - :param int height: plot height - :return: Plotly object figure. - - Example:: - - plot = plot_intramodular_correlation(MM, FS, feature_module_df, title='Plot', width=1000, height=800): - - .. note:: There is a limit in the number of subplots one can make in Plotly. This function limits the number of modules shown to 5. - """ - figure = {} - if MM is not None: - MM = MM.iloc[:, -6] - MM['modColor'] = MM.index.map(feature_module_df.set_index('name')['modColor'].get) - - figure = tools.make_subplots(rows=len(FS.columns), cols=len(MM.columns) - 1, shared_xaxes=False, shared_yaxes=False, vertical_spacing=0.015, horizontal_spacing=0.1, print_grid=True) - - figure.layout.template = 'plotly_white' - layout = dict(width=width, height=height, showlegend=False, title=title) - figure.layout.update(layout) - - axis_dict = {} - for i, j in enumerate(MM.columns[MM.columns.str.startswith('MM')]): - n_p = len(FS.columns) * (len(MM.columns)-1)-len(MM.columns[MM.columns.str.startswith('MM')]) - axis_dict['xaxis{}'.format(n_p+i+1)] = dict(title=j, titlefont=dict(size=13)) - print(axis_dict) - n = 1 - for a, b in enumerate(FS.columns): - name = b.split(' ') - if len(name) > 1: - label = ['
'.join(name[i:i+3]) for i in range(0, len(name), 3)][0] - else: - label = name[0] - axis_dict['yaxis{}'.format(a+n)] = dict(title=label, titlefont=dict(size=13)) - n += len(MM.columns[MM.columns.str.startswith('MM')])-1 - - annotation = [] - x_axis = 1 - y_axis = 1 - for a, b in enumerate(FS.columns): - for i, j in enumerate(MM.columns[MM.columns.str.startswith('MM')]): - name = MM[MM['modColor'] == j[2:]].index - x = abs(MM[MM['modColor'] == j[2:]][j].values) - y = abs(FS[FS.index.isin(name)][b].values) - - slope, intercept, r_value, p_value, std_err = scp.stats.linregress(x, y) - line = slope*x+intercept - - figure.append_trace(go.Scattergl(x = x, - y = y, - text = name, - mode = 'markers', - opacity=0.7, - marker={'size': 7, - 'color': 'white', - 'line': {'width': 1.5, 'color': j[2:]}}), a+1, i+1) - - figure.append_trace(go.Scattergl(x = x, y = line, mode = 'lines', marker={'color': 'black'}), a+1, i+1) - - annot = dict(x = 0.7, y = 0.7, - xref = 'x{}'.format(x_axis), yref = 'y{}'.format(y_axis), - text = 'R={:0.2}, p={:.0e}'.format(r_value, p_value), - showarrow = False) - annotation.append(annot) - x_axis += 1 - y_axis += 1 - - - figure.layout.update(axis_dict) - figure.layout.update(annotations = annotation) - - return figure - -def plot_complex_dendrogram(dendro_df, subplot_df, title, dendro_labels=[], distfun='euclidean', linkagefun='average', hang=0.04, subplot='module colors', subplot_colorscale=[], color_missingvals=True, row_annotation=False, col_annotation=False, width=1000, height=800): - """ - This function plots a dendrogram with a subplot below that can be a heatmap (annotated or not) or module colors. - - :param dendro_df: pandas dataframe containing data used to generate dendrogram, columns will result in dendrogram leaves. - :param subplot_df: pandas dataframe containing data used to generate plot below dendrogram. - :param str title: the title of the figure. - :param list dendro_labels: list of strings for dendrogram leaf nodes labels. - :param str distfun: distance measure to be used (‘euclidean‘, ‘maximum‘, ‘manhattan‘, ‘canberra‘, ‘binary‘, ‘minkowski‘ or ‘jaccard‘). - :param str linkagefun: hierarchical/agglomeration method to be used (‘single‘, ‘complete‘, ‘average‘, ‘weighted‘, ‘centroid‘, ‘median‘ or ‘ward‘). - :param float hang: height at which the dendrogram leaves should be placed. - :param str subplot: type of plot to be shown below the dendrogram (´module colors´ or ´heatmap´). - :param list subplot_colorscale: colorscale to be used in the subplot. - :param bool color_missingvals: if set to `True`, plots missing values as grey in the heatmap. - :param bool row_annotation: if `True`, adds a color-coded column at the left of the heatmap. - :param bool col_annotation: if `True`, adds a color-coded row at the bottom of the heatmap. - :param int width: the width of the figure. - :param int height: the height of the figure. - :return: Plotly object figure. - """ - figure = {} - dendro_tree = wgcnaAnalysis.get_dendrogram(dendro_df, dendro_labels, distfun=distfun, linkagefun=linkagefun, div_clusters=False) - if dendro_tree is not None: - dendrogram = Dendrogram.plot_dendrogram(dendro_tree, hang=hang, cutoff_line=False) - - layout = go.Layout(width=width, height=height, showlegend=False, title=title, - xaxis=dict(domain=[0, 1], range=[np.min(dendrogram['layout']['xaxis']['tickvals'])-6,np.max(dendrogram['layout']['xaxis']['tickvals'])+4], showgrid=False, - zeroline=True, ticks='', automargin=True, anchor='y'), - yaxis=dict(domain=[0.7, 1], autorange=True, showgrid=False, zeroline=False, ticks='outside', title='Height', automargin=True, anchor='x'), - xaxis2=dict(domain=[0, 1], autorange=True, showgrid=True, zeroline=False, ticks='', showticklabels=False, automargin=True, anchor='y2'), - yaxis2=dict(domain=[0, 0.64], autorange=True, showgrid=False, zeroline=False, automargin=True, anchor='x2')) - - - if subplot == 'module colors': - figure = tools.make_subplots(rows=2, cols=1, print_grid=False) - - for i in list(dendrogram['data']): - figure.append_trace(i, 1, 1) - - shapes = plot_dendrogram_guidelines(dendro_tree, dendrogram) - moduleColors = get_module_color_annotation(dendro_labels, col_annotation=col_annotation, bygene=True, module_colors=subplot_df, dendrogram=dendrogram) - figure.append_trace(moduleColors, 2, 1) - figure['layout'] = layout - figure.layout.template = 'plotly_white' - figure['layout'].update({'shapes':shapes, - 'xaxis':dict(showticklabels=False), - 'yaxis':dict(domain=[0.2, 1]), - 'yaxis2':dict(domain=[0, 0.19], title='Module colors', ticks='', showticklabels=False)}) - - - elif subplot == 'heatmap': - if all(list(subplot_df.columns.map(lambda x: subplot_df[x].between(-1,1, inclusive=True).all()))) != True: - df = wgcnaAnalysis.get_percentiles_heatmap(subplot_df, dendro_tree, bydendro=True, bycols=False).T - else: - df = wgcnaAnalysis.df_sort_by_dendrogram(wgcnaAnalysis.df_sort_by_dendrogram(subplot_df, dendro_tree).T, dendro_tree) - - heatmap = get_heatmap(df, colorscale=subplot_colorscale, color_missing=color_missingvals) - - - if row_annotation == True and col_annotation == True: - figure = tools.make_subplots(rows=3, cols=2, specs=[[{'colspan':2}, None], - [{}, {}], - [{'colspan':2}, None]], print_grid=False) - for i in list(dendrogram['data']): - figure.append_trace(i, 1, 1) - for j in list(heatmap['data']): - figure.append_trace(j, 2, 2) - - r_annot, c_annot = get_module_color_annotation(list(df.index), row_annotation=row_annotation, col_annotation=col_annotation, bygene=False) - figure.append_trace(r_annot, 2, 1) - figure.append_trace(c_annot, 3, 1) - - figure['layout'] = layout - figure.layout.template = 'plotly_white' - figure['layout'].update({'xaxis':dict(ticks='', showticklabels=False, anchor='y'), - 'xaxis2':dict(domain=[0, 0.01], ticks='', showticklabels=False, automargin=True, anchor='y2'), - 'xaxis3':dict(domain=[0.015, 1], ticks='', showticklabels=False, automargin=True, anchor='y3'), - 'xaxis4':dict(domain=[0.015, 1], ticks='', showticklabels=True, automargin=True, anchor='y4'), - 'yaxis':dict(domain=[0.635, 1], automargin=True, anchor='x'), - 'yaxis2':dict(domain=[0.015, 0.635], autorange='reversed', ticks='', showticklabels=True, automargin=True, anchor='x2'), - 'yaxis3':dict(domain=[0.01, 0.635], autorange='reversed', ticks='', showticklabels=False, automargin=True, anchor='x3'), - 'yaxis4':dict(domain=[0,0.01], ticks='', showticklabels=False, automargin=True, anchor='x4')}) - - - - elif row_annotation == False and col_annotation == False: - figure = tools.make_subplots(rows=2, cols=1, print_grid=False) - - for i in list(dendrogram['data']): - figure.append_trace(i, 1, 1) - for j in list(heatmap['data']): - figure.append_trace(j, 2, 1) - - figure['layout'] = layout - figure.layout.template = 'plotly_white' - figure.layout.update({'xaxis':dict(ticktext=np.array(dendrogram['layout']['xaxis']['ticktext']), tickvals=list(dendrogram['layout']['xaxis']['tickvals'])), - 'yaxis2':dict(autorange='reversed')}) - - elif row_annotation == True: - figure = tools.make_subplots(rows=2, cols=2, specs=[[{'colspan':2}, None], - [{}, {}]], print_grid=False) - for i in list(dendrogram['data']): - figure.append_trace(i, 1, 1) - for j in list(heatmap['data']): - figure.append_trace(j, 2, 2) - - r_annot = get_module_color_annotation(list(df.index), row_annotation=row_annotation, col_annotation=col_annotation, bygene=False) - figure.append_trace(r_annot, 2, 1) - - figure['layout'] = layout - figure.layout.template = 'plotly_white' - figure['layout'].update({'xaxis':dict(domain=[0.015, 1], ticktext=np.array(dendrogram['layout']['xaxis']['ticktext']), tickvals=list(dendrogram['layout']['xaxis']['tickvals']), automargin=True, anchor='y'), - 'xaxis2':dict(domain=[0, 0.010], ticks='', showticklabels=False, automargin=True, anchor='y2'), - 'xaxis3':dict(domain=[0.015, 1], ticks='', showticklabels=False, automargin=True, anchor='y3'), - 'yaxis':dict(automargin=True, anchor='x'), - 'yaxis2':dict(autorange='reversed', ticks='', showticklabels=True, automargin=True, anchor='x2'), - 'yaxis3':dict(domain=[0, 0.64], ticks='', showticklabels=False, automargin=True, anchor='x3')}) - - elif col_annotation == True: - figure = tools.make_subplots(rows=3, cols=1, specs=[[{}], [{}], [{}]], print_grid=False) - - for i in list(dendrogram['data']): - figure.append_trace(i, 1, 1) - for j in list(heatmap['data']): - figure.append_trace(j, 3, 1) - - c_annot = get_module_color_annotation(list(df.index), row_annotation=row_annotation, col_annotation=col_annotation, bygene=False) - figure.append_trace(c_annot, 2, 1) - - figure['layout'] = layout - figure.layout.template = 'plotly_white' - figure['layout'].update({'xaxis':dict(ticktext=np.array(dendrogram['layout']['xaxis']['ticktext']), tickvals=list(dendrogram['layout']['xaxis']['tickvals']), automargin=True, anchor='y'), - 'xaxis2':dict(ticks='', showticklabels=False, automargin=True, anchor='y2'), - 'xaxis3':dict(domain=[0, 1], ticks='', showticklabels=False, automargin=True, anchor='y3'), - 'yaxis':dict(domain=[0.70, 1], automargin=True, anchor='x'), - 'yaxis2':dict(domain=[0.615, 0.625], ticks='', showticklabels=False, automargin=True, anchor='x2'), - 'yaxis3':dict(domain=[0, 0.61], autorange='reversed', ticks='', showticklabels=False, automargin=True, anchor='x3')}) - - return figure