<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Motif-Search" data-toc-modified-id="Motif-Search-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Motif Search</a></span><ul class="toc-item"><li><span><a href="#Pre-run-MEME" data-toc-modified-id="Pre-run-MEME-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Pre-run MEME</a></span></li><li><span><a href="#Compare-motifs-to-RegulonDB-motifs" data-toc-modified-id="Compare-motifs-to-RegulonDB-motifs-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Compare motifs to RegulonDB motifs</a></span></li></ul></li><li><span><a href="#Helper-functions" data-toc-modified-id="Helper-functions-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Helper functions</a></span><ul class="toc-item"><li><span><a href="#Fonts" data-toc-modified-id="Fonts-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Fonts</a></span></li><li><span><a href="#Colors" data-toc-modified-id="Colors-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Colors</a></span></li><li><span><a href="#Fitting" data-toc-modified-id="Fitting-2.3"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>Fitting</a></span></li><li><span><a href="#Misc" data-toc-modified-id="Misc-2.4"><span class="toc-item-num">2.4&nbsp;&nbsp;</span>Misc</a></span></li></ul></li><li><span><a href="#Plot-functions" data-toc-modified-id="Plot-functions-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Plot functions</a></span></li><li><span><a href="#Final-PDF-Maker" data-toc-modified-id="Final-PDF-Maker-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Final PDF Maker</a></span></li></ul></div>

In [1]:
import sys
from tqdm import tqdm_notebook as tqdm
import matplotlib.gridspec as gridspec
from PIL import Image

In [2]:
# Make the parent directory importable so the local `icaviz` package is found.
# NOTE(review): the star import below is presumably what provides pd, np, plt, sns,
# re, load_data, Rectangle, adjust_text, venn2, combinations, ... used later — confirm.
sys.path.append('../')
from icaviz.plotting import *

In [8]:
# Data locations, relative to the current working directory
DATA_DIR = '../data/precise_data/'
GENE_DIR = '../data/annotation/'

# Curated I-modulon enrichment table; its 'name' column supplies the
# component names handed to load_data
enrich = pd.read_csv(DATA_DIR + 'curated_enrichments.csv')
names = enrich.loc[:, 'name'].tolist()

In [10]:
# Load expression (X, log-TPM), gene coefficients (S) and activities (A) together
# with sample metadata, gene annotation, the regulatory network (TRN) and the
# genome FASTA.
# NOTE(review): the meaning of cutoff=550 is defined inside load_data — confirm.
ica_data = load_data(
    X=DATA_DIR + 'log_tpm.csv',
    S=DATA_DIR + 'S.csv',
    A=DATA_DIR + 'A.csv',
    metadata=DATA_DIR + 'metadata.csv',
    annotation=GENE_DIR + 'gene_info.csv',
    trn=GENE_DIR + 'TRN.csv',
    fasta=GENE_DIR + 'NC_000913.3.fasta',
    cutoff=550,
    names=names,
)

In [12]:
import warnings

# Silence DeprecationWarnings from the plotting stack
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Non-interactive mode: figures are only rendered when saved/shown explicitly
plt.ioff()
sns.set_style('ticks')

# Global style: Arial text, fonts embedded as TrueType (Type 42) in PDF/PS output,
# and small inward-pointing ticks on both axes
params = {
    'mathtext.default': 'regular',
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
}
params.update({
    '{}.{}'.format(axis_name, key): value
    for axis_name in ('xtick', 'ytick')
    for key, value in (('labelsize', 7), ('major.size', 3), ('direction', 'in'))
})
plt.rcParams.update(params)

# Motif Search

## Pre-run MEME

If you need to re-run this notebook: this section pre-computes all of the information needed for MEME and TOMTOM, so this block only has to be run once. Set `force = True` below the first time it is run.

In [None]:
# Passed as `force` to every find_motifs call in the motif search below.
# Set to True the first time this notebook is run; with False the search
# presumably reuses previously computed MEME output (TODO confirm in find_motifs).
force = False

In [None]:
def _search_sites(motif_names, upstream, downstream):
    """Run a single-motif MEME search (via find_motifs) for each I-modulon name.

    Parameters
    ----------
    motif_names : iterable of I-modulon names
    upstream, downstream : sequence window passed to find_motifs

    Returns a list with one DataFrame per I-modulon that yielded sites. Each holds
    the rows of find_motifs' site table that have a site sequence, re-indexed by
    (I-modulon name, second level of the original index).
    """
    found_sites = []
    for k in tqdm(motif_names):
        _, DF_sites = find_motifs(ica_data, k, nmotifs=1, verbose=False, evt=1e-3,
                                  force=force, upstream=upstream,
                                  downstream=downstream, maxw=30)
        if len(DF_sites) == 0:
            continue
        tmp_sites = DF_sites[DF_sites.site_seq.notnull()]
        # MultiIndex.from_tuples([]) raises, so skip I-modulons whose sites
        # all lack a sequence (they would contribute no rows anyway)
        if len(tmp_sites) == 0:
            continue
        tmp_sites.index = pd.MultiIndex.from_tuples([(k, x[1]) for x in tmp_sites.index])
        found_sites.append(tmp_sites)
    return found_sites


sites = []
motifs = []

comp_names = enrich[enrich.Category!='Genomic Alterations'].name

# Pass 1: 300 nt upstream, nothing downstream
sites += _search_sites(comp_names, upstream=300, downstream=0)
found_motifs1 = [k for k,v in ica_data.motif_info.items() if len(v[0])>0]

# Pass 2: retry the I-modulons still without a motif using 600 nt upstream and
# 100 nt downstream. Iterate in comp_names order (deduplicated) instead of over a
# set difference, whose order depends on string hashing and made the row order
# of the saved site table differ between runs.
found_set = set(found_motifs1)
remaining = [k for k in dict.fromkeys(comp_names) if k not in found_set]
sites += _search_sites(remaining, upstream=600, downstream=100)

found_motifs2 = [k for k,v in ica_data.motif_info.items() if len(v[0])>0]

## Compare motifs to RegulonDB motifs

In [None]:
# Passed to compare_motifs in the next cell; True forces the TOMTOM comparison to be
# recomputed. Note this rebinds the same `force` name used by the MEME search above.
force = True

In [None]:
# Motif database (MEME format) that the found motifs are compared against with TOMTOM
motif_db = '../data/annotation/regulonDB_MEME.txt'
dfs = []

for k in tqdm(found_motifs2):
    # Reload this I-modulon's motif summary (force=False) and attach the logo
    # path plus the TOMTOM comparison results as extra columns
    DF_motif, _ = find_motifs(ica_data, k, verbose=False, force=False)
    DF_motif.index = [k]
    DF_motif['motif_img'] = ['motifs/' + k.replace('/', '_') + '/logo1.eps']

    _, compare_str, tomtom_img = compare_motifs(k, motif_db, evt=0.01, force=force)
    DF_motif['tomtom_str'] = [compare_str]
    DF_motif['tomtom_img'] = [tomtom_img]

    dfs.append(DF_motif)

# One row per I-modulon with a motif
DF_motif_final = pd.concat(dfs)

In [None]:
# Display the motif summary ordered by ascending MEME E-value (most significant first)
DF_motif_final.sort_values('e_value')

In [None]:
# Save every binding site found by the motif search, and the per-I-modulon
# motif/TOMTOM summary, to the current working directory
pd.concat(sites).to_csv('raw_motif_search.csv')
DF_motif_final.to_csv('raw_motif_comparison.csv')

# Helper functions

## Fonts

In [13]:
from fontTools.ttLib import TTFont
import os

# Arial TrueType file used to measure text widths (getTextWidth / word_wrap).
# The default is the msttcorefonts location; set the ARIAL_TTF environment
# variable to point to the font elsewhere instead of editing this cell.
ARIAL_TTF = os.environ.get('ARIAL_TTF', '/usr/share/fonts/truetype/msttcorefonts/arial.ttf')
font = TTFont(ARIAL_TTF)

In [None]:
# Lookup tables used by getTextWidth, built from the module-level `font`:
#   font_units_per_em -- font design units per em (converts widths to points)
#   font_s            -- glyph set; font_s[name].width is a glyph's advance width
#   font_t            -- code point -> glyph name, from the platform 3 / encoding 1
#                        (Windows Unicode BMP) cmap subtable
font_units_per_em = font['head'].unitsPerEm
font_s = font.getGlyphSet()
font_cmap = font['cmap']
font_t = font_cmap.getcmap(3,1).cmap

In [None]:
from fontTools.ttLib import TTFont
from fontTools.ttLib.tables._c_m_a_p import CmapSubtable

def getTextWidth(text,pointSize):
    """Measure `text` set in the module-level font at `pointSize` points.

    Sums each character's glyph advance width (no kerning); characters without
    a glyph are measured with the '.notdef' glyph. The width is converted from
    font units to points and then to inches (72 points per inch).

    Returns (width_in_inches, 1); the second element is a fixed placeholder
    height, kept so callers can treat the result as a (w, h) extent.
    """
    def glyph_width(ch):
        glyph_name = font_t.get(ord(ch))
        if glyph_name is not None and glyph_name in font_s:
            return font_s[glyph_name].width
        return font_s['.notdef'].width

    units = sum(glyph_width(ch) for ch in text)
    points = units*float(pointSize)/font_units_per_em
    return (points/72, 1)

In [None]:
def word_wrap(text, width, fontsize, extent_func=None):
    '''
    Word wrap for proportional (as opposed to fixed-width) fonts.

    `text`: string to wrap; existing line breaks are preserved.
    `width`: maximum line width, in the units returned by `extent_func`
             (inches with the default getTextWidth measurement).
    `fontsize`: font size in points, used by the default measurement.
    `extent_func`: optional function returning a (w, h) tuple for a string;
                   only w is used. Defaults to getTextWidth(s, fontsize).

    Widths are measured per character and summed, so kerning is ignored.
    A single word wider than `width` is put on a line of its own rather than
    being broken.

    Returns a list of stripped lines, or [''] when there is nothing to wrap.
    '''
    import re

    if extent_func is None:
        extent_func = lambda x: getTextWidth(x,fontsize)
    lines = []
    # Splitting on a capturing group keeps the whitespace runs, so tokens
    # alternate word, whitespace, word, ...
    pattern = re.compile(r'(\s+)')
    lookup = dict((c, extent_func(c)[0]) for c in set(text))
    for line in text.splitlines():
        tokens = pattern.split(line)
        # Pad to an even length so every word at index i has a separator at i + 1
        tokens.append('')
        widths = [sum(lookup[c] for c in token) for token in tokens]
        start, total = 0, 0
        for index in range(0, len(tokens), 2):
            if total + widths[index] > width:
                # An over-wide word that starts a line is emitted on its own
                end = index + 2 if index == start else index
                lines.append(''.join(tokens[start:end]))
                start, total = end, 0
                if end == index + 2:
                    continue
            total += widths[index] + widths[index + 1]
        if start < len(tokens):
            lines.append(''.join(tokens[start:]))
    lines = [line.strip() for line in lines]
    return lines or ['']

## Colors

In [None]:
def clamp(val, minimum=0, maximum=255):
    """Return `val` limited to the closed range [minimum, maximum]."""
    if val < minimum:
        return minimum
    if val > maximum:
        return maximum
    return val

def colorscale(hexstr, scalefactor):
    """
    Scale every RGB channel of a '#RRGGBB' color by ``scalefactor``.

    Factors between 0 and 1 darken the color; factors above 1 brighten it.
    Each scaled channel is clamped to [0, 255] and truncated to an integer.
    Returns a lowercase '#rrggbb' string.

    Input that cannot be scaled (a negative factor, or not exactly six
    characters once '#' is stripped) is returned unchanged. Six characters
    that are not hexadecimal raise ValueError.

    >>> colorscale("#DF3C3C", .5)
    '#6f1e1e'
    >>> colorscale("#52D24F", 1.6)
    '#83ff7e'
    >>> colorscale("#4F75D2", 1)
    '#4f75d2'
    """

    digits = hexstr.strip('#')

    if scalefactor < 0 or len(digits) != 6:
        # Return the caller's string as given; returning the stripped digits
        # dropped the '#', which matplotlib does not accept as a color
        return hexstr

    r, g, b = int(digits[:2], 16), int(digits[2:4], 16), int(digits[4:], 16)

    r = clamp(r * scalefactor)
    g = clamp(g * scalefactor)
    b = clamp(b * scalefactor)

    return "#{:02x}{:02x}{:02x}".format(int(r), int(g), int(b))

In [None]:
def rgb2hex(x):
    """Convert an (r, g, b[, ...]) sequence of floats in [0, 1] to a '#rrggbb' string.

    Channels are rounded to the nearest integer, as matplotlib.colors.to_hex
    does, instead of truncated, so they are not biased one step downward.
    Elements after the third (e.g. alpha) are ignored.
    """
    return '#{:02x}{:02x}{:02x}'.format(*(int(round(x[i]*255)) for i in range(3)))

## Fitting

In [None]:
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score

def broken_line(x, A, B, C):
    """Piecewise-linear model: A*x + B for x >= C, and the constant A*C + B for x < C.

    The two pieces meet at the breakpoint C. Returns a float ndarray.
    """
    # Builtin float: the np.float alias was removed in NumPy 1.24
    y = np.zeros(len(x),dtype=float)
    y += (A*x+B) * (x >= C)
    y += (A*C + B) * (x < C)
    return y

def solid_line(x,A,B):
    """Straight line A*x + B."""
    y = (A*x+B)
    return y

def get_fit(x,y):
    """Fit y against x with a broken line and with a straight line; keep the best.

    Broken-line fits start with the breakpoint at min(x), mean(x) and max(x);
    starts for which curve_fit raises are skipped. Candidates are compared by
    adjusted R^2 (the first candidate wins ties).

    Returns (params, adj_r2). params has 2 values (A, B) for solid_line or
    3 values (A, B, C) for broken_line. If no candidate reaches a non-negative
    adjusted R^2, returns ([0, mean(y)], 0), i.e. a flat line.
    """

    def adj_r2(f,x,y,params):
        n = len(x)
        k = len(params)-1
        r2 = r2_score(y,f(x,*params))
        return 1 - np.true_divide((1-r2)*(n-1),(n-k-1))

    all_params = []
    for c in [min(x),np.mean(x),max(x)]:
        try:
            all_params.append(curve_fit(broken_line, x, y,p0=[1,1,c])[0])
        except Exception:
            # curve_fit raises (e.g. RuntimeError when it does not converge);
            # skip this start. Unlike a bare except, this does not swallow
            # KeyboardInterrupt/SystemExit.
            pass

    all_params.append(curve_fit(solid_line,x,y)[0])

    best_r2 = -np.inf
    for params in all_params:
        if len(params) == 2:
            r2 = adj_r2(solid_line,x,y,params)
        else:
            r2 = adj_r2(broken_line,x,y,params)

        if r2 > best_r2:
            best_r2 = r2
            best_params = params

    if best_r2 < 0:
        return [0,np.mean(y)],0

    return best_params,best_r2

## Misc

In [None]:
def flatten(l):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [None]:
def reformat_label(text):
    r"""Italicise the gene names of a COG legend entry with matplotlib mathtext.

    `text` has the form 'Category (n): gene1, gene2, ...', possibly with line
    breaks inserted by word_wrap. Everything after the last ':' is treated as
    the gene list. Each gene becomes $\mathit{gene}$, and the separating commas,
    spaces and line breaks are kept.

    The replacement is built by a function. The previous replacement *string*
    contained '\m', which re.sub rejects ("bad escape") on Python >= 3.7.
    """
    import re

    # Split at the last colon: gene names never contain one, category names might
    name,genes = text.rsplit(':', 1)

    def _italicise(match):
        # group(1): separator before the gene (space or newline); group(2): gene name
        return r',{}$\mathit{{{}}}$'.format(match.group(1), match.group(2))

    # Every gene gains a leading ','; slice off the first ', ' afterwards
    new_genes = re.sub(r'( |\n)(.*?)(?:,|$)',_italicise,genes,flags=re.M)[2:]
    return r'{}: {}'.format(name,new_genes)

# Plot Functions

In [None]:
def plot_genes(ica_data,k,ax):
    """Scatter I-modulon `k`'s gene coefficients against baseline expression.

    Parameters
    ----------
    ica_data : project ICA data object (thresholds, gene_colors, X, S,
               show_enriched, num2name, gene_info, cog_colors are used)
    k : I-modulon name (a column of ica_data.S)
    ax : matplotlib Axes to draw on

    Baseline expression is the mean of samples 'control__wt_glc__1' and
    'control__wt_glc__2' in ica_data.X. Dashed lines mark +/- the I-modulon
    threshold. Genes from ica_data.show_enriched(k) are labelled when there
    are at most 25 of them. A legend groups them by COG category, sorted by
    gene count (with 'Function unknown' / 'No COG Annotation' entries last).
    The legend is capped at about 10 wrapped lines, and the remaining entries
    are pooled into an "Other" entry.

    Returns the Axes.
    """
    cutoff = ica_data.thresholds[k]
    colors = [ica_data.gene_colors[gene] for gene in ica_data.S.index]

    # Draw scatterplot
    baseline = ica_data.X[['control__wt_glc__1','control__wt_glc__2']].mean(axis=1)
    ax.scatter(baseline.values,
               ica_data.S[k].values,
               c=colors,s=10,
               alpha=0.7,linewidth=0.0)

    ax.set_xlabel('Baseline Expression (log-TPM)',fontsize=10)
    ax.set_ylabel('I-Modulon Gene Coefficients',fontsize=10)

    # Get x-axis bounds (used for the threshold lines and the repel rectangle)
    xmin,xmax = ax.get_xlim()

    # Add labels on datapoints
    component_genes = ica_data.show_enriched(k).index
    texts = []
    expand_args = {'expand_objects':(1.2,1.4),
                   'expand_points':(1.3,1.3)}

    ## Put gene names on the points if the component contains at most 25 genes
    if len(component_genes) <= 25:
        for gene in component_genes:
            texts.append(ax.text(baseline[gene],
                                 ica_data.S.loc[gene,k],
                                 r'${}$'.format(ica_data.num2name[gene]),
                                 fontsize=7,fontstyle='italic'))
        expand_args['expand_text'] =(1.4,1.4)

    ## Repel texts from other text, points, and the band between the thresholds
    rect = ax.add_patch(Rectangle((xmin,-cutoff),xmax-xmin,2*cutoff,fill=False,
                                  linewidth=0))

    adjust_text(texts,ax=ax,add_objects = [rect],
                arrowprops=dict(arrowstyle="-",color='k',lw=0.5),
                only_move={'objects':'y'},**expand_args)

    # Draw horizontal dashed lines
    ax.hlines([cutoff,-cutoff],xmin,xmax,colors='gray',linestyles='dashed',
              linewidth=0.5)

    # Add legend
    legend_info = []
    leg_fontsize = 8
    component_gene_set = set(component_genes)
    for name,group in ica_data.gene_info.groupby('cog'):
        # Component genes that belong to this COG category
        comp_cog_genes = component_gene_set & set(group.index)

        if len(comp_cog_genes) > 0:
            gene_names = sorted([ica_data.num2name[gene] for gene in comp_cog_genes])

            # Create legend entry and wrap text (3 inches wide)
            entry = '{} ({:d}): {}'.format(name,len(gene_names),', '.join(gene_names))
            text = '\n'.join(word_wrap(entry,3,leg_fontsize))

            # (color, wrapped text, gene count, gene names)
            legend_info.append((ica_data.cog_colors[name],text,len(gene_names),gene_names))

    ## Sort legend by number of genes, keeping unknown function and no COG annotation at bottom
    legend_info = sorted(legend_info,key=lambda x: (x[1][:10] not in ['Function u','No COG Ann'],x[2]),reverse=True)

    max_lines = 10
    ## If no legend entries, return
    if len(legend_info) == 0:
        return ax

    # Line counts are taken AFTER sorting so they line up with legend_info
    # (they used to be collected in COG order, so truncation used the wrong entries)
    num_lines = [info[1].count('\n')+1 for info in legend_info]

    ## If over max_lines legend lines, keep the leading entries plus "Other"
    if sum(num_lines) > max_lines:
        # Index of the last entry whose cumulative line count stays below
        # max_lines; that slot is given to "Other". Clamped at 0 so an oversized
        # first entry no longer crashes on max() of an empty array.
        max_entries = max(int(np.sum(np.cumsum(num_lines) < max_lines))-1,0)
        other_genes = flatten([row[3] for row in legend_info[max_entries:]])
        num_genes = len(other_genes)
        if num_genes > 20:
            other_str = 'Other ({:d}): {} +{:d}'.format(num_genes,', '.join(other_genes[:20]),num_genes-20)
        else:
            other_str = 'Other ({:d}): {}'.format(num_genes,', '.join(other_genes))
        other_text = '\n'.join(word_wrap(other_str,3.5,leg_fontsize))
        legend_info = legend_info[:max_entries] + [('white',other_text,num_genes,other_genes)]

    ## Add legend entries to plot as empty series (legend handles only)
    for info in legend_info:
        ax.plot([],[],"o",color=info[0],markersize=5,label=reformat_label(info[1]))

    leg = ax.legend(loc='upper left', bbox_to_anchor=(1, 1.01),
                    fontsize=leg_fontsize,title='COG Categories')

    # NOTE(review): _legend_box is private matplotlib API; used to left-align the entries
    leg._legend_box.align = "left"
    leg.get_title().set_fontweight('bold')
    leg.get_title().set_fontsize(10)

    return ax

In [None]:
def plot_samples_bar(ica_data,k,ax):
    """Bar chart of I-modulon `k`'s activity in every sample, grouped by study and condition.

    Projects appear in the order of their first occurrence in ica_data.metadata.
    Within a project, conditions are ordered by `keyfxn`. Each condition takes
    one unit of x-axis width, shared equally by its samples. Every other
    condition gets a light gray background, and projects are separated by
    vertical lines. Condition IDs label the bottom axis, and short study names
    label a twin top axis.

    Returns the Axes.
    """

    # Sort key for (condition_id, group) pairs: names ending in 'ale<N>' are keyed
    # by the zero-padded N; names starting with 'wt' or 'glu4' are keyed by
    # '00' + name[3:] so they sort near the front; everything else by its name.
    def keyfxn(x):
        name = x[0]
        match = re.match(r'^.+?ale(\d+)$',name)
        if match:
            return '{:02d}'.format(int(match.group(1)))
        elif name.startswith('wt') or name.startswith('glu4'):
            return '00'+name[3:]
        else:
            return name

    # Set order for display; `rename` maps project_id to its short top-axis label
    proj_order = ica_data.metadata.project_id.drop_duplicates()
    rename = {'control':'','fur':'Iron','acid':'Acid\n','oxidative':'Oxid.',
              'nac_ntrc':'Nac/\nNtrC','ompr':'OmpR\n','misc':'Misc','omics':'Omics',
              'minspan':'MinSpan','cra_crp':'Cra/\nCrp','rpoB':'RpoB\nKnock-in',
              'crp':'Crp\nARs','glu':'Glucose\nEvol.',
              '42c':'42C\nEvol.','ssw':'Substrate-\nswitching Evol.',
              'pgi':'PGI KO\nEvol.','ica':'Current\nStudy','ytf':'Unchar.\nTFs',
              'efeU':'Pseudogene\nRepair','pal':'Enz.\nPromisc.','fps':'False\nPos.'}

    # One row per sample: project, condition, activity, samples in the condition
    list2struct = []
    for proj in proj_order:
        group1 = ica_data.metadata[ica_data.metadata.project_id == proj]
        for cond,group2 in sorted(group1.groupby('condition_id'),key=keyfxn):
            for name in group2.index:
                list2struct.append([proj,cond,ica_data.A.loc[k,name],len(group2)])
    DF_comp = pd.DataFrame(list2struct,columns = ['project_id','condition_id','value','length'])
    # Each sample gets 1/len(condition) of a unit; the index is each bar's left edge
    DF_comp['width'] = 1/DF_comp.length
    DF_comp.index = [0]+np.cumsum(DF_comp.width).tolist()[:-1]

    # Get xlabels and tick marks
    xticks = []
    xticklabels = []
    vlines = []
    proj_labels = []
    proj_locs = []
    for proj,group1 in DF_comp.groupby('project_id'):
        vlines.append(min(group1.index))
        # Unknown projects fall back to their raw ID instead of raising KeyError
        proj_labels.append(rename.get(proj,proj))
        proj_locs.append(min(group1.index)+group1.width.sum()/2)
        for cond,group2 in group1.groupby('condition_id'):
            xticks.append(np.mean(group2.index)+np.mean(group2.width)/2)
            xticklabels.append(cond)

    # Alternating gray rectangles behind every other condition. Each condition
    # spans exactly one x unit, so there are sum(width) of them (the old loop ran
    # once per *sample*, adding many patches beyond the visible range).
    n_conditions = int(round(DF_comp.width.sum()))
    for i in range(0,n_conditions,2):
        ax.add_patch(Rectangle((i,-100),1,200,color='#888888',alpha=0.1))

    # Draw bar chart
    ax.bar(DF_comp.index,DF_comp.value,width = DF_comp.width,align='edge',linewidth=0)

    # Set x axis parameters
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels,fontsize=4)
    ax.tick_params(axis='x',which='major',direction='in',length=4,labelrotation=90)
    ax.tick_params(axis='y',labelsize=7)
    ax.grid(False,which='major')

    # Add vertical bars to designate project
    ax.vlines(vlines,-100,100,colors='#888888',linewidth=1)
    # Keep y=0 in view so bars are not clipped at their base (or hidden entirely)
    # when all activities share a sign
    ax.set_ylim(min(DF_comp.value.min()*1.2,0),max(DF_comp.value.max()*1.2,0))
    ax.set_xlim(DF_comp.index.min(),DF_comp.index.max()+DF_comp.width.iloc[-1])
    ax.set_ylabel('I-Modulon Activity',fontsize=6)

    # Add project labels on a twin top axis
    ax2 = ax.twiny()
    ax2.grid(False)
    ax2.set_xlim(ax.get_xlim())
    ax2.set_xticks(proj_locs)
    ax2.set_xticklabels(proj_labels,fontsize=7)
    ax2.tick_params(axis='x',length=0)
    ax2.xaxis.tick_top()

    # Add top and bottom labels
    ax.set_xlabel('Condition\nID',x=-0.05,va='bottom',fontsize=9,labelpad=-5)
    ax2.set_xlabel('Study',x=-0.05,va='top',fontsize=9,labelpad=3)

    return ax

In [None]:
from functools import reduce

In [None]:
def plot_histogram(ica_data,k,tfs,tf_info,ax):
    """Histogram (log-scale counts) of I-modulon `k`'s gene coefficients, colored by regulation.

    Parameters
    ----------
    ica_data : project ICA data object (component_DF, thresholds are used)
    k : I-modulon name
    tfs : list of regulator names; selects the coloring:
          none    -> one gray histogram;
          one     -> regulated vs. unregulated genes;
          several -> unregulated, regulated by only one regulator (one color
                     each), and regulated by two or more
    tf_info : object with pvalue, precision and recall attributes (a row of the
              enrichment table), shown in the legend when `tfs` is non-empty
    ax : matplotlib Axes

    Bin width is chosen so that -threshold and +threshold fall on bin edges.
    Genes in bins beyond the thresholds are named above their bin; crowded bins
    keep the largest |coefficient| genes plus a '+N' note. Returns the Axes.
    """
    pal = [rgb2hex(x) for x in sns.color_palette('Set1')]
    pal = [pal[1],pal[2],pal[0]]+pal[3:]

    def cap(name):
        # Upper-case only the first letter, e.g. 'fur' -> 'Fur'
        return name[0].upper()+name[1:]

    DF_gene = ica_data.component_DF(k,tfs=tfs)
    bins = 20
    thresh = ica_data.thresholds[k]

    # Compute range and bin width so that +/- thresh are bin edges
    xmin = min(min(DF_gene.comp),-thresh)
    xmax = max(max(DF_gene.comp),thresh)
    width = 2*thresh/(np.floor(2*thresh*bins/(xmax-xmin)-1))
    xmin = -thresh-width*np.ceil((-thresh - xmin)/width)
    xmax = xmin + width*bins

    # 'nonposy' was renamed 'nonpositive' in matplotlib 3.3 and later removed;
    # older versions reject the new name, so fall back to the old keyword there
    try:
        ax.set_yscale('log', nonpositive='clip')
    except (TypeError, ValueError):
        ax.set_yscale('log', nonposy='clip')
    ax.xaxis.grid(False)
    ax.set_xlabel('I-Modulon Gene Coefficients',fontsize=10)
    ax.set_ylabel('Count (log scale)',fontsize=10)

    # Plot histograms; every branch also fills DF_gene['color'] for the gene labels
    if len(tfs) == 0:
        label = 'No enriched\nregulator'
        color = '#888888'

        DF_gene['color'] = color

        ax.hist(DF_gene.comp,bins=bins,alpha=0.5,
                range=(xmin,xmax),label=label,color=color,linewidth=0)

    elif len(tfs) == 1:
        tf = tfs[0]

        # Plot unregulated genes
        unreg_genes = DF_gene[~DF_gene[tf]].index
        unreg_label = 'Unregulated\nby {}'.format(cap(tf))
        unreg_color = '#aaaaaa'
        DF_gene.loc[unreg_genes,'color'] = unreg_color

        ax.hist(DF_gene.loc[unreg_genes].comp,alpha=0.5,bins=bins,
                range=(xmin,xmax),label=unreg_label,color=unreg_color,linewidth=0)

        # Plot regulated genes
        reg_genes = DF_gene[DF_gene[tf]].index
        reg_label = 'Regulated\nby {}'.format(cap(tf))
        reg_color = pal[0]
        DF_gene.loc[reg_genes,'color'] = reg_color

        ax.hist(DF_gene.loc[reg_genes].comp,alpha=0.7,bins=bins,
                range=(xmin,xmax), label=reg_label,color=reg_color,linewidth=0)

    else:
        # Plot genes regulated by none of the regulators
        unreg_genes = DF_gene[~reduce(lambda x,y: (x | y),[DF_gene[tf] for tf in tfs])].index
        unreg_label = 'Unregulated\nby any'
        unreg_color = '#aaaaaa'
        DF_gene.loc[unreg_genes,'color'] = unreg_color

        ax.hist(DF_gene.loc[unreg_genes].comp,alpha=0.5,bins=bins,
                range=(xmin,xmax),label=unreg_label,color=unreg_color,linewidth=0)

        # Genes regulated by at least two of the regulators. Gene lists are kept
        # as ordered lists because pandas >= 2 rejects sets as .loc indexers.
        multireg_set = set()
        for tf1,tf2 in combinations(tfs,2):
            multireg_set |= set(DF_gene[DF_gene[tf1] & DF_gene[tf2]].index)
        multireg_genes = [gene for gene in DF_gene.index if gene in multireg_set]
        multireg_label = 'Regulated by both' if len(tfs) ==2 else 'Regulated\nby multiple'
        multireg_color = '#8B4513'
        DF_gene.loc[multireg_genes,'color'] = multireg_color

        # Plot genes regulated by exactly one regulator, one color per regulator
        for i,tf in enumerate(tfs):
            reg_genes = [gene for gene in DF_gene[DF_gene[tf]].index if gene not in multireg_set]
            reg_label = 'Regulated by\nonly {}'.format(cap(tf))
            reg_color = pal[i]
            DF_gene.loc[reg_genes,'color'] = reg_color

            if len(reg_genes) > 0:
                ax.hist(DF_gene.loc[reg_genes].comp,alpha=0.7,bins=bins,
                        range=(xmin,xmax), label=reg_label,color=reg_color,linewidth=0)

        # Plot multireg last
        ax.hist(DF_gene.loc[multireg_genes].comp,alpha=0.5,bins=bins,
                range=(xmin,xmax), label=multireg_label,color=multireg_color,linewidth=0)

    # Add threshold lines; their legend entry carries the enrichment statistics
    if len(tfs) > 0:
        labelstr = 'P-value = {:.0e}\nPrecision = {:.0f}%\nRecall = {:.0f}%'.format(tf_info.pvalue,
                              tf_info.precision*100,tf_info.recall*100)
    else:
        labelstr = None

    ax.vlines([thresh,-thresh],0,1000,linestyles='dashed',
              label=labelstr,linewidth=0.5)

    # Add gene names to each bin
    eps = 1e-5
    for x in np.arange(xmin,xmax,width):
        bin_genes = DF_gene[(x < DF_gene.comp) &
                            (DF_gene.comp <= x+width)]
        # Only add names to bins outside the threshold
        if (x <= -thresh-eps or x >= thresh-eps) and len(bin_genes) > 0:
            # Start just above the tallest bar drawn for this bin
            y = max(bin_genes.color.value_counts())
            y *= 1.1

            # Crowded bins: keep the genes with the largest |coefficient|
            # (7 above max_genes2, max_genes1 above max_genes1) and add "+N"
            max_genes1 = 9
            max_genes2 = 16
            if len(bin_genes) > max_genes2:
                leftover = bin_genes.reindex(abs(bin_genes.comp).sort_values()[:-7].index)
                bin_genes = bin_genes.reindex(abs(bin_genes.comp).sort_values()[-7:].index)
                bin_genes.loc['Other','gene_name'] = '+{:d}'.format(len(leftover))
                bin_genes.loc['Other','color'] = '#000000'
            elif len(bin_genes) > max_genes1:
                leftover = bin_genes.reindex(abs(bin_genes.comp).sort_values()[:-max_genes1].index)
                bin_genes = bin_genes.reindex(abs(bin_genes.comp).sort_values()[-max_genes1:].index)
                bin_genes.loc['Other','gene_name'] = '+{:d}'.format(len(leftover))
                bin_genes.loc['Other','color'] = '#000000'

            # Draw text, stacking upward geometrically (log-scaled y axis)
            for i,row in bin_genes.iterrows():
                ax.text(x+width/2,y,'${}$'.format(row['gene_name']),color=colorscale(row.color,.75),
                        ha='center',va='bottom',fontsize=7,fontweight='bold',fontstyle='italic')
                y *= 1.65

    # Draw legend
    legend = ax.legend(fontsize=7,frameon=True,ncol=(int(len(tfs)/3)+1))
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_linewidth(1)
    legend.get_frame().set_alpha(1)

    # Set y axis tick size
    ax.tick_params(axis='y',which='minor',length=3)
    ax.tick_params(axis='y',which='major',length=5)
    return ax

In [None]:
def fit_line(x,y,ax):
    """Draw the best straight or broken-line fit (from get_fit) of y vs. x on `ax`.

    The fit is drawn in gray from min(x) to max(x), passing through the
    breakpoint for a broken line, and labelled with its adjusted R^2.
    Returns the adjusted R^2.
    """
    params,r2 = get_fit(x,y)
    # Raw string: '\m' in a normal literal is an invalid escape sequence
    # (SyntaxWarning on recent Python); the resulting text is unchanged
    label = r'$\mathit{{R^2_{{adj}}}}$ = {:.2f}'.format(r2)

    plot_params = {'c':'gray','linewidth':1,'label':label}

    if len(params) == 2:
        xlim = np.array([min(x),max(x)])
        ax.plot(xlim,solid_line(xlim,*params),
                **plot_params)
    else:
        xmin = min(x)
        xmax = max(x)
        mid = params[2]
        xvals = np.array([xmin,mid,xmax])
        ax.plot(xvals,broken_line(xvals,*params),
                **plot_params)
    return r2

In [None]:
def regulon_scatter(ica_data,k,tf,ax):
    """Scatter I-modulon `k`'s activity against the expression of regulator `tf`.

    Samples whose name contains 'del<tf>' or 'del_<tf>' (tf as given or
    lower-cased) are treated as knock-outs. They are plotted in their own series
    and excluded from the fitted line. If `tf` has no expression row in
    ica_data.X, the axes are hidden instead.

    Returns (ax, r2), where r2 is fit_line's adjusted R^2, or NaN when no
    scatter was drawn.
    """
    # Separate KO experiments
    ko_tags = ['del'+tf,'del_'+tf,'del'+tf.lower(),'del_'+tf.lower()]
    ko_cols = [exp for exp in ica_data.A.columns if any(tag in exp for tag in ko_tags)]

    # Remaining samples, in column order. A list rather than a set, because
    # pandas >= 2 rejects sets as .loc indexers.
    ko_set = set(ko_cols)
    other_cols = [exp for exp in ica_data.A.columns if exp not in ko_set]

    # Previously unbound when the TF had no data, causing UnboundLocalError at return
    r2 = np.nan

    # Make scatter plot only if TF expression is in our dataset
    if tf in ica_data.name2num.keys() and ica_data.name2num[tf] in ica_data.X.index:
        ax.scatter(ica_data.X.loc[ica_data.name2num[tf],other_cols],
                   ica_data.A.loc[k,other_cols],s=5,
                   label='Conditions')

        # Add colors to TF KO experiments
        if len(ko_cols) > 0:
            ax.scatter(ica_data.X.loc[ica_data.name2num[tf],ko_cols],
                       ica_data.A.loc[k,ko_cols],s=5,
                       label='{} KO'.format(tf))

        # Draw best fit line (knock-outs excluded)
        r2 = fit_line(ica_data.X.loc[ica_data.name2num[tf],other_cols],
                      ica_data.A.loc[k,other_cols],
                      ax)

        ax.set_ylabel('I-Modulon Activity',fontsize=10)
        ax.set_xlabel(r'$\mathit{{{}}}$ Expression Level'.format(tf),fontsize=10)

        # Draw and format legend
        legend = ax.legend(frameon=True,fontsize=6)
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_linewidth(1)
    else:
        ax.axis('off')

    return ax,r2

In [None]:
def regulon_venn(ica_data,k,tf,ax=None):
    """Venn diagram comparing the regulon of `tf` with the genes of I-modulon `k`.

    `tf` may combine regulators: 'A+B' uses the genes regulated by all of them
    (intersection), and 'A/B' the genes regulated by any of them (union). Each
    region shows its gene count and, in parentheses, the number of operons
    (ica_data.genes2operons) those genes form.

    A new 5x5 in figure is created when `ax` is None.
    Raises ValueError if ica_data has no TRN (regulon_mapping is None).
    Returns the Axes.
    """
    if ica_data.regulon_mapping is None:
        raise ValueError('No TRN information found')

    if ax is None:
        _,ax = plt.subplots(figsize=(5,5))

    # Take care of and/or enrichments
    if '+' in tf:
        reg_list = []
        for tfx in tf.split('+'):
            reg_list.append(set(ica_data.trn[ica_data.trn.TF==tfx].gene_id.unique()))
        reg_genes = set.intersection(*reg_list)
    elif '/' in tf:
        reg_genes = set(ica_data.trn[ica_data.trn.TF.isin(tf.split('/'))].gene_id.unique())
    else:
        reg_genes = set(ica_data.trn[ica_data.trn.TF==tf].gene_id.unique())

    # Get component genes and the operon count of each Venn region
    comp_genes = set(ica_data.show_enriched(k).index)

    reg_operons = len(ica_data.genes2operons(reg_genes-comp_genes))
    comp_operons = len(ica_data.genes2operons(comp_genes-reg_genes))
    both_operons = len(ica_data.genes2operons(reg_genes & comp_genes))

    # Draw venn diagram and resize texts
    venn = venn2((reg_genes,comp_genes),
                 set_labels=('Regulon\nGenes','I-Modulon\nGenes'),
                 ax=ax)
    for text in venn.set_labels:
        text.set_fontsize(10)
        if text.get_text() == u'Regulon\nGenes':
            text.set_color('darkred')
        else:
            text.set_color('darkgreen')

    # Add operon numbers to labels. subset_labels are ordered (regulon only,
    # I-modulon only, both). matplotlib_venn can give None for an empty region,
    # so skip those instead of failing with AttributeError.
    for label,n_operons in zip(venn.subset_labels,(reg_operons,comp_operons,both_operons)):
        if label is None:
            continue
        label.set_fontsize(10)
        label.set_text(label.get_text()+'\n({:d})'.format(n_operons))

    return ax

# Final PDF Maker

In [None]:
def set_axes(tfs,motif):
    """Create the 8.5 x 11 in page figure and its panel axes for one I-modulon.

    Parameters
    ----------
    tfs : list of regulator names for the I-modulon (may be empty)
    motif : bool, whether to add motif / TOMTOM panels

    Returns
    -------
    fig : the new matplotlib Figure
    axes : dict of Axes keyed by panel
        's'       -- gene coefficient scatter (grid row 0)
        'a'       -- activity bar chart (row 2)
        'hist'    -- coefficient histogram (row 4)
        'venn'    -- regulon Venn diagram (row 6; only when tfs is non-empty)
        'scatter' -- list of (tf, Axes) pairs for one or two regulator scatter
                     plots (row 6; only for regulators in the TRN that have
                     expression data in ica_data.X)
        'motif', 'tomtom' -- motif logo and TOMTOM match (last two rows;
                     only when motif is True)

    No panel starts in an odd row; those rows have small height ratios and
    act as spacing.
    NOTE(review): reads the notebook-global `ica_data` instead of taking it as
    a parameter.
    """

    # Set height ratios (one entry per grid row)
    height = [50,15,
              20,15,
              40,10]
    
    if len(tfs) > 0:
        height += [40,10]
    if motif:
        height += [20,3]
        
    # Set TF scatterplots: only regulators present in the TRN and with an
    # expression row in ica_data.X get a scatter panel
    scatter_tfs = []
    for tf in tfs:
        if (tf in ica_data.trn.TF.unique() and \
            tf in ica_data.name2num.keys() and \
            ica_data.name2num[tf] in ica_data.X.index):
            scatter_tfs.append(tf)
    
    ## WARNING NOTE: HARDCODED EDGE CASE
    # NOTE(review): for csqR the scatter uses yihW expression instead; presumably
    # csqR itself has no usable expression data (TODO confirm). yihW is not checked
    # against ica_data.X like the regulators above.
    if tfs == ['csqR']:
        scatter_tfs = ['yihW']
        
    # Set width ratio (one entry per grid column; column 0 is never used)
    width = [5,2,80,10,40,40]
    
    # Set up gridspec
    fig = plt.figure(figsize=(8.5,11))
    gs = gridspec.GridSpec(len(height),len(width),
                           height_ratios=height,width_ratios=width,
                           left=0.08,right=0.92,top=0.92,bottom=0.05,
                           wspace=0,hspace=0)
    gs.update()
    
    # Populate axis dict
    axes = {}
    axes['s'] = plt.subplot(gs[0, 2])
    axes['a'] = plt.subplot(gs[2, 1:])
    axes['hist'] = plt.subplot(gs[4,1:])
    
    
    if len(tfs) > 0:
        # Only the first two scatter-eligible regulators get a panel
        if len(scatter_tfs) == 1:
            axes['scatter'] = [(scatter_tfs[0],plt.subplot(gs[6,4:]))]
        elif len(scatter_tfs) > 1:
            # Second panel shows its y tick labels on the right-hand side
            ax2 = plt.subplot(gs[6,5])
            ax2.set_ylabel('')
            ax2.tick_params(axis='y',left=False,right=True,
                            labelleft=False,labelright=True)
            
            axes['scatter'] = [(scatter_tfs[0],plt.subplot(gs[6,4])),
                               (scatter_tfs[1],ax2)]
        
        axes['venn'] = plt.subplot(gs[6,1:3]) 
        
        
        # Motif panels go below the regulator row
        if motif: 
            axes['motif'] = plt.subplot(gs[8,1:3])
            axes['tomtom'] = plt.subplot(gs[8:,4:])

    elif motif:
            # No regulator row, so the motif panels move up to row 6
            axes['motif'] = plt.subplot(gs[6,1:3])
            axes['tomtom'] = plt.subplot(gs[6:,4:]) 
    
    return fig,axes

In [None]:
import warnings
from scipy.optimize import OptimizeWarning
# Silence scipy's OptimizeWarning (raised e.g. when curve_fit cannot estimate the
# parameter covariance) during the line fits in get_fit
warnings.simplefilter(action='ignore', category=OptimizeWarning)

In [None]:
# Page order: I-modulons sorted by name, with the 'uncharacterized' ones (also
# sorted by name) moved to the end. Boolean masks replace the former
# `enrich.loc[set(...)]`, which pandas >= 2 rejects ("Passing a set as an
# indexer is not supported").
is_unchar = enrich['name'].str.contains('uncharacterized', regex=False)
unchar_locs = enrich[is_unchar].sort_values('name').index.tolist()
enrich_sort = pd.concat([enrich[~is_unchar].sort_values('name'),
                         enrich.loc[unchar_locs]])

In [None]:
import matplotlib
# The PDF loop below creates one figure per I-modulon; disable matplotlib's
# "too many open figures" warning
matplotlib.rcParams.update({'figure.max_open_warning': 0})

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

# Adjusted R^2 of every (I-modulon, regulator) expression-vs-activity fit drawn below
r2_dict = {}

# Regulators present in the TRN, computed once instead of per I-modulon
trn_tfs = set(ica_data.trn.TF.unique())

# One PDF page per I-modulon, in enrich_sort order
with PdfPages(DATA_DIR+'modulons.pdf') as pdf:
    for i,row in tqdm(enrich_sort.iterrows(),total=len(enrich)):
        k = row['name']

        # Get I-modulon title
        if 'uncharacterized' in k:
            # Use the whole trailing number so e.g. '#10' is not shortened to '#0'
            num = re.search(r'\d+$',k)
            title = 'Uncharacterized I-Modulon #{}'.format(num.group() if num else k[-1])
        else:
            title = r'${}$ I-Modulon'.format(k)

        ########################
        ## Get TF enrichments ##
        ########################

        # row.TF may combine regulators: 'A+B' (all of them) or 'A/B' (any of them)
        if pd.isnull(row.TF):
            enrich_type = None
            tfs = []
        elif '+' in row.TF:
            enrich_type = lambda x,y: x and y
            tfs = row.TF.split('+')
        elif '/' in row.TF:
            enrich_type = lambda x,y: x or y
            tfs = row.TF.split('/')
        else:
            enrich_type = None
            tfs = [row.TF]

        # Drop the regulator panels if any regulator is missing from the TRN.
        # (The old message printed `tf`, which is not defined by the comprehension
        # in Python 3 and so showed a stale value from an earlier loop.)
        missing_tfs = [tf for tf in tfs if tf not in trn_tfs]
        if missing_tfs:
            print('WARNING:',', '.join(missing_tfs))
            tfs = []

        ######################
        ## Initialize Plots ##
        ######################

        fig,axes = set_axes(tfs,k in DF_motif_final.index)

        # Plot genes
        plot_genes(ica_data,k,ax=axes['s'])

        # Plot conditions
        plot_samples_bar(ica_data,k,axes['a'])

        # Plot histogram
        plot_histogram(ica_data,k,tfs,row,axes['hist'])

        # Plot venn diagram
        if 'venn' in axes:
            regulon_venn(ica_data,k,row.TF,axes['venn'])

        # Plot scatter
        if 'scatter' in axes:
            ylims = []
            for tf,ax in axes['scatter']:
                _,r2 = regulon_scatter(ica_data,k,tf,ax)
                r2_dict[(k,tf)] = r2
                ylims.append(ax.get_ylim())

            # Ensure dual y-axes have same ylims
            mins,maxes = zip(*ylims)
            for _,ax in axes['scatter']:
                ax.set_ylim(min(mins),max(maxes))

        # Plot motif
        if 'motif' in axes:
            motif_ax = axes['motif']
            tomtom_ax = axes['tomtom']
            # Both panels only hold images, so hide frames and ticks
            for image_ax in (motif_ax,tomtom_ax):
                image_ax.set_frame_on(False)
                image_ax.set_yticklabels([])
                image_ax.set_xticklabels([])
                image_ax.tick_params(bottom=False,left=False)

            # Get motif logo
            motif_row = DF_motif_final.loc[k]
            im = plt.imread(motif_row.motif_img)
            motif_ax.imshow(im)

            xlabel_str = 'Motif E-value: {:.2e}\nOperons with Upstream Motif: {:.0f}%'
            xlabel = xlabel_str.format(motif_row.e_value,100*motif_row.motif_frac)
            motif_ax.set_xlabel(xlabel,rotation='horizontal',fontsize=10,
                                horizontalalignment='center',verticalalignment='top')

            # Plot tomtom output if relevant
            if motif_row.tomtom_str != '':
                im = plt.imread(motif_row.tomtom_img)
                tomtom_ax.imshow(im)
                tomtom_ax.set_xlabel(motif_row.tomtom_str,fontsize=10)
            else:
                tomtom_ax.axis('off')

        fig.suptitle(title,fontsize=26,fontweight='bold',va='baseline',y=0.96)
        if pd.isnull(row.Regulator):
            sub_str = ''
        else:
            sub_str = 'Regulated by: {}\n'.format(row.Regulator)
        if pd.notnull(row.Function):
            sub_str += 'Biological Function: {}'.format(row.Function)

        fig.text(0.5,0.952,sub_str,ha='center',va='top',fontsize=12)
        pdf.savefig(fig,dpi=1200)
        # Release the page's figure; otherwise every page stays in memory
        plt.close(fig)