In [None]:
import sys
import re
import io
import mygene
import base64
import warnings
import logging
import requests
import urllib.request as url
import numpy as np
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from hublib.ui import Download
from bokeh.io import show, output_notebook
from bokeh.models import ColorBar, ColumnDataSource, CategoricalColorMapper
from bokeh.plotting import figure
from bokeh.transform import transform
import bokeh.palettes

# Render Bokeh figures inline in this notebook, without the Bokeh banner.
output_notebook(hide_banner=True)
# Make modules located in the parent directory importable.
sys.path.append('..')
# Suppress UserWarning messages so they do not clutter the app output.
warnings.simplefilter("ignore", UserWarning)
# Raise the ROOT logger threshold: every logger that inherits the root level
# will drop records below CRITICAL.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# Alliance of Genome Resources gene endpoint; the phenotypes of one gene are
# fetched from f"{api_url}/{gene_id}/phenotypes".
api_url = "https://www.alliancegenome.org/api/gene"

# Help text shown above the gene input area.
header = """\
<div class="app-sidebar">
Submit a list of mouse genes, as gene symbols (e.g. Pax6) or MGI IDs (MGI:...), separated by commas, semicolons, spaces or new lines. The program will create a table showing, for each gene, the list of related phenotypes.
</div>
"""

In [None]:
class App:
    """Notebook widget app that retrieves Alliance Genome phenotypes for a gene list.

    The user types mouse gene symbols and/or MGI identifiers. On Submit the app:

    * maps symbols to MGI ids through mygene;
    * downloads each gene's phenotype list from the Alliance API;
    * shows the result as a table with a CSV download link;
    * draws a heatmap of the number of phenotypes shared by every pair of
      genes, filtered by an interactive threshold slider.

    ``final_container`` is the root widget to display.
    """

    # Name of the CSV file written to the working directory and offered for download.
    _CSV_NAME = "pheno_table.csv"
    # Seconds to wait for each Alliance API request before giving up.
    _REQUEST_TIMEOUT = 30
    # Any run of commas, semicolons or whitespace separates two genes.
    _SPLIT_RE = re.compile(r"[,;\s]+")

    def __init__(self):
        self.data = []
        self._plot_container = widgets.Output()
        self._table_container = widgets.Output()
        self._load_button, load_box = self._create_load_button()
        self._reset_button = self._create_reset_button()

        self.container = widgets.VBox(
            [load_box, widgets.HBox([self._load_button, self._reset_button])],
            layout=widgets.Layout(align_items='center'))
        # Filled after a successful submit: the download link and the threshold slider.
        self.down_container = widgets.HBox([], layout=widgets.Layout(align_items='flex-end', flex='3 0 auto'))
        self.slide_container = widgets.HBox([], layout=widgets.Layout(align_items='flex-end', flex='3 0 auto'))
        self.final_container = widgets.VBox([
            widgets.HTML(('<h1><em><b>Phenotype retrieving</b></em></h1>'),
                         layout=widgets.Layout(margin='0 0 5em 0')),
            widgets.HTML(header, layout=widgets.Layout(align_items='center')),
            self.container, self._table_container, self.down_container,
            self.slide_container, self._plot_container,
        ], layout=widgets.Layout(align_items='center', flex='3 0 auto'))

    def _create_load_button(self):
        """Create the gene text area and the (initially disabled) Submit button.

        Returns:
            (submit_button, hbox_with_label_and_text_area)
        """
        load_label = widgets.Label('List of genes:')
        self.list_area = widgets.Textarea(placeholder="E.g: Pax6,MGI:00001,...")
        load_b = widgets.Button(description="Submit", tooltip="submit your gene list", disabled=True)
        load_b.on_click(self._on_change_l)
        # Enable/disable Submit whenever the text changes.
        self.list_area.observe(self._on_change_text, names=["value"])
        sub_box = widgets.HBox([load_label, self.list_area])
        return (load_b, sub_box)

    def _create_reset_button(self):
        """Create the (initially disabled) Reset button."""
        reset_b = widgets.Button(description="Reset", tooltip="reset the environment to generate a new table", disabled=True)
        reset_b.on_click(self._on_click_r)
        return reset_b

    @staticmethod
    def _empty_container(box):
        """Detach and close every widget in ``box``, including their direct children."""
        old_children = box.children
        box.children = ()
        for widget in old_children:
            # Closing an HBox does not close its children (e.g. the slider), so do it explicitly.
            for child in getattr(widget, "children", ()):
                child.close()
            widget.close()

    def _on_click_r(self, _):
        """Reset the app so that a new gene list can be submitted."""
        self.data = []
        # Tolerate a reset after a run that failed before producing every result.
        for attr in ("gene_dataframe", "csv_file", "gen_matr"):
            if hasattr(self, attr):
                delattr(self, attr)
        self._empty_container(self.slide_container)
        self._empty_container(self.down_container)
        self.list_area.value = ""
        self._reset_button.disabled = True
        self._load_button.disabled = True
        self._table_container.clear_output(wait=False)
        self._plot_container.clear_output(wait=False)

    def _on_change_text(self, _):
        """Enable Submit only when the text area contains non-whitespace text."""
        self._load_button.disabled = not self.list_area.value.strip()

    def _on_change_l(self, _):
        """Handle Submit: map the genes, fetch phenotypes, then show table, link and heatmap.

        Mapping or download failures are reported in the table area, and Submit
        is re-enabled so the user can fix the input and retry.
        """
        import contextlib

        with self._table_container:
            self._table_container.clear_output(wait=False)
            display(widgets.HTML(value='<i class=\"fa fa-spinner fa-spin fa-5x fa-fw\" style="color: #f09f14;"></i>'))
            self._load_button.disabled = True
            # map_gene prints its warnings; capture them so the clear_output calls
            # below (which remove the spinner) do not wipe them.
            messages = io.StringIO()
            try:
                with contextlib.redirect_stdout(messages):
                    self.data = self.map_gene()
                self.gene_dataframe = self._fetch_phenotypes(self.data)
            except (ValueError, requests.RequestException) as err:
                self._table_container.clear_output(wait=False)
                print(messages.getvalue(), end="")
                print(err)
                self._on_change_text(None)
                return

            self.gen_matr = self.create_matrix()
            self._table_container.clear_output(wait=False)
            print(messages.getvalue(), end="")
            display(HTML(self.gene_dataframe.to_html(justify="left", index=False)))
            self.gene_dataframe.to_csv(self._CSV_NAME, index=False)
            self.csv_file = self._CSV_NAME
            self._reset_button.disabled = False

            # Replace, rather than append to, widgets left over from a previous submit.
            link = widgets.HTML(value="")
            self.create_download_link(self._CSV_NAME, link)
            self._empty_container(self.down_container)
            self.down_container.children = (link,)

            self._slider, self.slider_box = self._create_slider()
            self._empty_container(self.slide_container)
            self.slide_container.children = (self.slider_box,)
            self._update_app()

    def _fetch_phenotypes(self, genes):
        """Return a DataFrame with columns ``gene`` and ``phenotypes`` (list of names per gene).

        Raises:
            requests.RequestException: on connection errors, timeouts or HTTP error statuses.
            ValueError: if a response body is not valid JSON.
        """
        phenotypes = []
        with requests.Session() as session:
            for gene in genes:
                response = session.get(f"{api_url}/{gene}/phenotypes", timeout=self._REQUEST_TIMEOUT)
                response.raise_for_status()
                results = response.json().get("results", [])
                phenotypes.append([entry["phenotype"] for entry in results])
        return pd.DataFrame({"gene": list(genes), "phenotypes": phenotypes}, columns=["gene", "phenotypes"])

    def create_matrix(self):
        """Return a square int DataFrame counting the distinct phenotypes shared by each gene pair.

        Rows and columns follow the order of ``self.gene_dataframe.gene``. The
        index is named ``gene_ind`` and the columns ``gene_col``. The diagonal
        holds each gene's own number of distinct phenotypes.
        """
        genes = list(self.gene_dataframe["gene"])
        pheno_sets = {}
        for gene, phenos in zip(self.gene_dataframe["gene"], self.gene_dataframe["phenotypes"]):
            # First occurrence wins if a gene is repeated.
            pheno_sets.setdefault(gene, set(phenos))
        counts = [[len(pheno_sets[row] & pheno_sets[col]) for col in genes] for row in genes]
        matrix = pd.DataFrame(counts, index=genes, columns=genes, dtype=int)
        matrix.index.name = "gene_ind"
        matrix.columns.name = "gene_col"
        return matrix

    def _create_slider(self):
        """Create the threshold slider. Its maximum is set later by ``_update_app``.

        Returns:
            (slider, hbox_with_label_and_slider)
        """
        slider_label = widgets.Label('Threshold: ')
        slider = widgets.IntSlider(value=0, min=0, max=0, step=1, orientation='horizontal', readout=True, readout_format="d")
        slider.observe(self._on_change, names=['value'])
        slider_box = widgets.HBox([slider_label, slider])
        return (slider, slider_box)

    def _on_change(self, _):
        """Redraw the heatmap when the threshold slider moves."""
        self._update_app()

    def _update_app(self):
        """Sync the slider range with the matrix, then redraw the heatmap for the current threshold."""
        values = self.gen_matr.to_numpy()
        self._slider.max = int(values.max()) if values.size else 0
        threshold = self._slider.value
        with self._plot_container:
            try:
                plot = self._create_plot(threshold)
            except (NameError, AttributeError) as err:
                # Report the problem instead of silently leaving a stale plot.
                self._plot_container.clear_output(wait=True)
                print(f"Unable to draw the heatmap: {err}")
                return
            self._reset_button.disabled = False
            self._plot_container.clear_output(wait=True)
            show(plot, notebook_handle=True)

    def _create_plot(self, threshold):
        """Build the Bokeh heatmap of shared-phenotype counts.

        A gene is removed from both axes when its largest count in the matrix
        (its row maximum) is below ``threshold``. A threshold of 0 keeps every gene.
        """
        matrix = self.gen_matr.copy()
        if threshold != 0:
            keep = (matrix.max(axis=1) >= threshold).to_numpy()
            # The matrix is square with identical row/column order, so one mask fits both axes.
            matrix = matrix.loc[keep, keep]

        # Values are turned into strings so that a categorical colour mapper can
        # be used. It gets one colour per count actually present in the matrix,
        # which gives a clearer legend.
        cells = matrix.stack().rename("value").reset_index()
        factors = [str(v) for v in sorted(cells["value"].unique())]
        cells["value"] = cells["value"].astype(str)

        mapper = CategoricalColorMapper(palette=bokeh.palettes.inferno(len(factors)), factors=factors, nan_color='gray')

        p = figure(
            width=1280,
            height=800,
            x_range=list(cells["gene_col"].drop_duplicates()),
            y_range=list(cells["gene_ind"].drop_duplicates()[::-1]),
            tooltips=[('common phenotypes: ', '@value')],
            x_axis_location="above",
            output_backend="webgl",
            toolbar_location="right",
            tools="pan,wheel_zoom,box_zoom,reset,save")

        # One unit rectangle per (gene_col, gene_ind) cell.
        p.rect(
            x="gene_col",
            y="gene_ind",
            width=1,
            height=1,
            source=ColumnDataSource(cells),
            fill_color=transform('value', mapper))
        p.xaxis.major_label_orientation = 45

        color_bar = ColorBar(
            color_mapper=mapper,
            label_standoff=6,
            border_line_color=None)
        p.add_layout(color_bar, 'right')
        return p

    def map_gene(self):
        """Return the list of MGI ids for the genes typed in the text area.

        Tokens are separated by any run of commas, semicolons or whitespace.
        Tokens containing ``MGI:`` are used as-is. The other tokens are treated
        as mouse gene symbols and mapped through mygene. Unmapped symbols are
        reported with a printed warning. Repeated ids are removed, keeping the
        first occurrence.

        Raises:
            ValueError: if no MGI id could be obtained from the input.
        """
        tokens = [tok for tok in self._SPLIT_RE.split(self.list_area.value) if tok]
        final_list = [tok for tok in tokens if 'MGI:' in tok]
        sym_list = [tok for tok in tokens if 'MGI:' not in tok]

        if sym_list:
            # symbol for symbols, mgi for MGI : https://docs.mygene.info/en/latest/doc/query_service.html#available-fields
            mg = mygene.MyGeneInfo()
            ginfo = mg.querymany(sym_list, scopes='symbol', fields="symbol,MGI", species='mouse', verbose=False)
            mapped = []
            discarded = []
            for hit in ginfo:
                mgi = hit.get('MGI')
                if mgi is None:
                    discarded.append(hit['query'])
                elif isinstance(mgi, list):
                    mapped.extend(mgi)
                else:
                    mapped.append(mgi)
            final_list.extend(mapped)
            if not mapped and final_list:
                print("Warning: it was not possible to map any of the symbol ids. Only MGI ids will be used.")
            elif discarded:
                print("Warning: it was not possible to map these elements: " + ",".join(discarded) + "\n")

        final_list = list(dict.fromkeys(final_list))
        if not final_list:
            raise ValueError("Error: it was not possible to map the input.")
        return final_list

    def create_download_link(self, filename, htmlWidget):
        """Set ``htmlWidget.value`` to a button that downloads ``filename`` as a CSV file.

        The file content is read now and embedded, base64-encoded, in a
        ``data:`` URI, so the link reflects the file as it was at call time.
        """
        title = "Click here to download the table in csv format"
        with open(filename, "rb") as handle:
            payload = base64.b64encode(handle.read()).decode()

        template = ('<a download="{filename}" href="data:text/csv;base64,{payload}" title="{title}" '
                    'target="_blank"><button class="button-style">Download table as csv file</button></a>')
        htmlWidget.value = template.format(payload=payload, title=title, filename=filename)
        
        
    
        
    
# Build the app and leave its root container as the cell's last expression,
# so the notebook renders the widget tree.
app = App()
app.final_container    