In [None]:
############################################################################################
### plot the overview plots with all KEGG hits and all genomes 
############################################################################################

def plotClusterOverviewTot(KEGGhitType):
    """Draw and save the KEGG-module x genome overview heat maps.

    The function loads the pickled module map for the requested hit type,
    orders the rows and the columns of that map with ``_optimal_order`` (via
    ``sns.clustermap`` on the respective correlation matrices) and then renders
    one composite figure per KEGG B-level category.  The figures share every
    panel except the C-level colour bar (and its legend), which only colours
    the C categories that belong to the B-level of the current figure.

    Panels of each figure (GridSpec of 45 x 37 cells):
      * left      - row/row correlation heat map, labelled with
                    ``Module_combined`` from ``KEGG_module_list``
      * centre    - colour-coded map (``*_col`` pickle), rows and columns in
                    dendrogram order
      * bottom    - column/column correlation heat map
      * top       - phylum bar, domain bar and ``meanGC`` strip per column
      * right     - per-row share for Archaea, Bacteria and Unassigned plus
                    the max-normalised Archaea minus Bacteria difference
      * B / C     - colour bars with the KEGG B- and C-level of each row

    B-level numbering as noted by the original author (depends on the order
    of ``KEGG_modules_df`` - TODO confirm after a KEGG update):
        0 = Energy metabolism
        1 = Carb and lipid metabolism
        2 = Nucleic acid and aa metabolism
        3 = Secondary metabolism
        4 = Genetic info proc.
        5 = Env. info proc.
        6 = Metabolism
        7 = Cellular processes
        8 = Gene set

    Files written (all paths relative to the working directory):
      * ``ClusteredDF_brian_<KEGGhitType>`` - the module map in clustered order
      * ``ObsCom_indexblevel_<KEGGhitType>_<n>`` - rows that carry a C-level
        colour for B-level ``n``
      * ``BrianDataSets/figures/heatmaps/Overview_heatmap_KEGG_<type>_<B>``
        and the same name with ``.pdf`` under ``.../heatmaps/pdfs/``

    Parameters
    ----------
    KEGGhitType : str
        One of ``'all'``, ``'conservative'`` or ``'complete'``; selects which
        pair of pickles is read.

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        If ``KEGGhitType`` is not one of the three supported values.

    Notes
    -----
    The two ``sns.clustermap`` figures are left open (they are only used to
    obtain the leaf order); the nine overview figures are closed after saving.
    """
    # NOTE(review): pd.read_pickle can execute arbitrary code while loading;
    # only point these paths at pickles produced by this project.
    if KEGGhitType == 'all':
        dataframe = pd.read_pickle('Knumbers/brian/brian_genome_module_map_all')
        dataframe_col = pd.read_pickle('Knumbers/brian/brian_module_map_all_col')
    elif KEGGhitType == 'conservative':
        dataframe = pd.read_pickle('Knumbers/brian/brian_genome_module_map_conservative')
        dataframe_col = pd.read_pickle('Knumbers/brian/brian_module_map_cons_col')
    elif KEGGhitType == 'complete':
        dataframe = pd.read_pickle('Knumbers/brian/brian_genome_module_map_complete')
        dataframe_col = pd.read_pickle('Knumbers/brian/brian_module_map_comp_col')
    else:
        # Kept as sys.exit (SystemExit) so existing callers see the same
        # behaviour as before.
        import sys
        sys.exit("aborted, enter correct df type")

    # some stored dataframes
    metadf = pd.read_pickle('BrianDataSets/metadata_df')
    KEGGmoddf = pd.read_pickle('KEGG_modules_df')  # (from KEGG database -check for updates)
    mod_df = pd.read_pickle('KEGG_module_list')  # (from KEGG database -check for updates)
    completeness = KEGGhitType
    n_blevels = 9  # one figure per B-level, see the numbering in the docstring

    ############################################################################################
    # column/column correlation (genomes) and its leaf order
    ############################################################################################
    dataC = dataframe.fillna(0).corr().dropna(how='all').T.dropna(how='all')
    genome_link = _optimal_order(dataC, metric='correlation')
    cgClust = sns.clustermap(dataC, row_linkage=genome_link, col_linkage=genome_link,
                             figsize=(20, 40))

    ############################################################################################
    # row/row correlation (KEGG modules) and its leaf order
    ############################################################################################
    dataK = dataframe.fillna(0).T.corr()
    dataK = dataK.dropna(how='all').T.dropna(how='all')
    module_link = _optimal_order(dataK, metric='correlation')
    cgKEGG_blevel = sns.clustermap(dataK, row_linkage=module_link, col_linkage=module_link,
                                   figsize=(10, 10))

    ############################################################################################
    # everything below is identical for all B-levels, so it is prepared once
    ############################################################################################
    idx = cgKEGG_blevel.dendrogram_col.reordered_ind
    idxcl = cgClust.dendrogram_col.reordered_ind
    module_order = dataK.iloc[idx, :].index
    genome_order = dataC.iloc[idxcl, :].index

    # [red, purple, green, grey, dark blue, light blue]
    cmap = ListedColormap(["#e74c3c", "#9b59b6", "#2ecc71", "#95a5a6", "#34495e", "#3498db"])

    # module correlation matrix with human-readable row labels
    dataK2 = dataK.copy()
    dataK2.index = mod_df.set_index('Module').loc[dataK.index]['Module_combined']
    module_corr = dataK2.iloc[idx, idx]
    genome_corr = dataC.iloc[idxcl, idxcl]
    coloured_map = dataframe_col.fillna(0).loc[module_order, genome_order]

    # save dataframe in ordered way for future examination
    clusteredframe = dataframe.loc[module_order, genome_order]
    clusteredframe.to_pickle('ClusteredDF_brian_' + completeness)

    def _domain_share(domain):
        # Row sums over the genomes of one domain, scaled so they add up to 1.
        genomes = metadf[metadf['Domain'] == domain].index
        totals = dataframe.loc[module_order, genomes].fillna(0).T.sum()
        return totals / totals.sum()

    arc = _domain_share('Archaea')
    bac = _domain_share('Bacteria')
    una = _domain_share('Unassigned')
    # both series are rescaled to a maximum of 1 before taking the difference
    diff = arc / arc.max() - bac / bac.max()

    # colours of the phylum bar
    phyla = list(metadf.loc[dataC.index]['Phylum'].unique())
    phylum_lut = dict(zip(phyla, sns.color_palette("cubehelix", len(phyla))))
    phylum_lut.update({'Unassigned': "#3498db"})
    phylum_bar_colors = metadf['Phylum'].map(phylum_lut)[genome_order]

    # colours of the domain bar
    domain_lut = {'Archaea': "#e74c3c", 'Bacteria': '#34495e', 'Unassigned': "#3498db"}
    domain_bar_colors = metadf['Domain'].map(domain_lut)[genome_order]

    listGC = list(metadf.T.loc['meanGC', genome_order])

    # B-level names (the last unique value is dropped, as in the original code)
    keys = list(KEGGmoddf.B.unique())[0:-1]
    b_lut = dict(zip(keys, sns.color_palette("gist_earth", len(keys))))

    # groupby().sum() concatenates the B / C strings of every D-module; each
    # concatenation is collapsed to a single known name (the last one found).
    D_group = KEGGmoddf.groupby('D-module').sum()
    c_names = KEGGmoddf.C.unique()
    for module in D_group.index:
        joined_b = D_group.loc[module, 'B']
        b_hits = [name for name in keys if name in joined_b]
        if b_hits:
            D_group.loc[module, 'B'] = b_hits[-1]
        joined_c = D_group.loc[module, 'C']
        c_hits = [name for name in c_names if name in joined_c]
        if c_hits:
            D_group.loc[module, 'C'] = c_hits[-1]

    B_list = D_group.loc[module_order]['B']
    C_list = D_group.loc[module_order]['C']
    b_bar_colors = B_list.map(b_lut)

    n_genomes = len(dataC)
    n_modules = len(dataK)

    ############################################################################################
    # one figure per B-level
    ############################################################################################
    for Blevel in range(0, n_blevels):
        b_name = keys[Blevel]

        # C categories of this B-level; rows outside of it get the colour 'None'
        keysC = KEGGmoddf[KEGGmoddf['B'] == b_name].C.unique()
        lutC = dict(zip(keysC, sns.color_palette("Set3", len(keysC))))
        c_bar_colors = C_list.map(lutC).fillna('None')

        idx_blevel = C_list.loc[c_bar_colors[c_bar_colors != 'None'].index].index
        idxblevdf = pd.DataFrame(idx_blevel, columns=['column'])
        idxblevdf.to_pickle('ObsCom_indexblevel_' + completeness + '_' + str(Blevel))

        f = plt.figure()
        try:
            gs = gridspec.GridSpec(45, 37)

            # left: module/module correlation
            ax1 = f.add_subplot(gs[4:36, 0:11])
            sns.heatmap(module_corr, cbar=False, ax=ax1, xticklabels=True, yticklabels=True,
                        cmap="RdBu_r", vmin=-1, vmax=1)
            plt.xticks([])
            plt.yticks(fontsize=15)
            plt.xlabel('KEGG modules', fontsize=30, labelpad=40)

            # centre: colour-coded module map
            ax2 = f.add_subplot(gs[4:36, 13:32])
            sns.heatmap(coloured_map, ax=ax2, cbar=False, linewidth=0.5, cmap=cmap,
                        vmin=-1, vmax=1, xticklabels=True)
            plt.yticks([])
            plt.xlabel('Genome number', fontsize=30)

            # bottom: genome/genome correlation
            ax13 = f.add_subplot(gs[37:, 13:32])
            sns.heatmap(genome_corr, cbar=False, ax=ax13, cmap="RdBu_r", vmin=-1, vmax=1)
            plt.yticks([])
            plt.xticks([])

            ##############################################################################
            #################### plot hbars on right #####################################
            ##############################################################################
            ax3 = f.add_subplot(gs[4:36, 32])  # Archaea
            arc.plot.barh(ax=ax3, sharey=True, color="#e74c3c")
            ax3.grid(False)
            plt.title('Archaea', fontsize=25, rotation=45, y=1.02)
            plt.xticks([])
            plt.gca().invert_yaxis()

            ax4 = f.add_subplot(gs[4:36, 33])  # Bacteria
            bac.plot.barh(ax=ax4, sharey=True, color='#34495e')
            ax4.grid(False)
            plt.title('Bacteria', fontsize=25, rotation=45, y=1.02)
            plt.xticks([])
            plt.gca().invert_yaxis()

            ax6 = f.add_subplot(gs[4:36, 34])  # Unassigned
            una.plot.barh(ax=ax6, sharey=True, color="#3498db")
            ax6.grid(False)
            plt.title('Unassigned', fontsize=25, rotation=45, y=1.025)
            plt.xticks([])
            plt.gca().invert_yaxis()

            ax14 = f.add_subplot(gs[4:36, 35:])  # Archaea-Bacteria
            diff.plot.barh(ax=ax14, sharey=True, color="#e79f3c")
            ax14.grid(False)
            plt.title('Arch-Bact\ndifference', fontsize=25)
            plt.axvline(x=0)
            plt.xlim(-1, 1)
            plt.gca().invert_yaxis()

            ##############################################################################
            ##########################   plot top bars   #################################
            ##############################################################################
            ax7 = f.add_subplot(gs[1, 13:32])  # phylum
            # zero-size bars only exist to create the legend entries
            for label, colour in phylum_lut.items():
                plt.bar(0, 0, color=colour, label=label, width=1)
            plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                       ncol=6, mode="expand", borderaxespad=0., fontsize=20)
            plt.bar(list(range(0, n_genomes)), np.ones(n_genomes),
                    color=phylum_bar_colors, width=1)
            plt.xticks([])
            plt.yticks([])
            plt.xlim(-.5, n_genomes - .5)
            plt.ylim(0, 1)
            ax7.set_ylabel('Phylum', rotation=0, fontsize=30, labelpad=90)

            ax8 = f.add_subplot(gs[2, 13:32])  # domain
            for label, colour in domain_lut.items():
                plt.bar(0, 0, color=colour, label=label, width=1)
            plt.legend(bbox_to_anchor=(1.01, 0.95, 1., .102), loc=2,
                       ncol=2, borderaxespad=0., fontsize=20)
            plt.bar(list(range(0, n_genomes)), np.ones(n_genomes),
                    color=domain_bar_colors, width=1)
            plt.xticks([])
            plt.yticks([])
            plt.xlim(-.5, n_genomes - .5)
            plt.ylim(0, 1)
            ax8.set_ylabel('Domain', rotation=0, fontsize=30, labelpad=90)

            ax9 = f.add_subplot(gs[3, 13:32])  # mean GC
            sns.heatmap([listGC], ax=ax9, cbar=False, cmap="RdBu_r", linewidth=0.5,
                        vmin=0, vmax=1, xticklabels=False)
            ax9.set_ylabel('GC-content', rotation=0, fontsize=30, labelpad=90)

            ##############################################################################
            ##########################   plot left side bars  ############################
            ##############################################################################
            # B-level assignment
            ax11 = f.add_subplot(gs[4:36, 11])
            for label, colour in b_lut.items():
                plt.bar(0, 0, color=colour, label=label, width=1)
            plt.legend(bbox_to_anchor=(-20, 1.02, 1., .102), loc=3,
                       ncol=1, mode="expand", borderaxespad=0., fontsize=30)
            plt.barh(list(range(0, n_modules)), list(np.ones(n_modules)),
                     color=b_bar_colors, height=1)
            plt.xticks([])
            plt.yticks([])
            plt.ylim(-.5, n_modules - .5)
            plt.xlim(0, 1)
            ax11.set_xlabel('B', rotation=0, fontsize=30, labelpad=40)
            plt.gca().invert_yaxis()

            # C-level assignment (only the C categories of the current B-level)
            ax12 = f.add_subplot(gs[4:36, 12])
            plt.barh(list(range(0, n_modules)), list(np.ones(n_modules)),
                     color=c_bar_colors, height=1)
            plt.xticks([])
            plt.yticks([])
            plt.ylim(-.5, n_modules - .5)
            plt.xlim(0, 1)
            ax12.set_xlabel('C', rotation=0, fontsize=30, labelpad=40)
            plt.gca().invert_yaxis()
            for label, colour in lutC.items():
                plt.bar(0, 0, color=colour, label=label, width=1)
            plt.legend(bbox_to_anchor=(-10, 1.02, 1., .102), loc=3,
                       ncol=1, mode="expand", borderaxespad=0., fontsize=30)

            gs.update(wspace=.05, hspace=.05)
            plt.gcf().subplots_adjust(left=0.3)
            f.set_figheight(110)
            f.set_figwidth(60)
            f.savefig('BrianDataSets/figures/heatmaps/Overview_heatmap_KEGG_'
                      + completeness + '_' + b_name)
            f.savefig('BrianDataSets/figures/heatmaps/pdfs/Overview_heatmap_KEGG_'
                      + completeness + '_' + b_name + '.pdf')
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # clearing it (the previous f.clf()) is not enough to release it.
            plt.close(f)