In [None]:
# name of the alignment file in .fas format, e.g. my_name.fas
my_aln_name = "RbcS_aln.fas"
# name of the tree file in .nwk format, e.g. my_name.nwk
my_tree_name = "RbcS_tree4.nwk"
# name of the .csv file, e.g. my_name.csv (read with pandas, first column used as the index)
my_csv_name = "RbcS.csv"
# name of the taxonomy file in .xlsx format, e.g. my_name.xlsx
my_taxo_name = "RbcS_uniq.xlsx"
# set the default threshold (compared against the std of the pairwise values inside a clade)
threshold = 1.08


# set the threshold for a specific group of organisms
threshold_spec = 1.03
# name of that specific group of organisms (must equal a value of the taxonomy dictionary)
group = "Dwuliścienne"

# set the threshold for a second specific group of organisms
threshold_spec2 = 1.04
# name of the second specific group of organisms
group2 = "Jednoliścienne"

# set the coverage level of conserved sequences
all_conserv = 0.95


In [None]:
# Taxonomic rank names; the last entry is an empty string.
taxo_list = [
    "Bacteria", "brak", "Sar", "Haptista", "Cryptophyceae", "Glaucocystophyceae",
    "Metazoa", "Rhodophyta", "Chlorophyta", "Marchantiophyta", "Bryophyta",
    "Polypodiopsida", "Gymnospermae", "Acrogymnospermae", "eudicotyledons",
    "Magnoliopsida", "",
]

# Taxon label -> colour name used for the leaf labels of the tree.
# NOTE(review): the key "Krypromonady" looks like a typo of "Kryptomonady", but it has to
# match the values of the taxonomy spreadsheet, so it is left untouched — confirm.
taxo_examples = {
    "Nagonasienne": "Brown",
    "Jednoliścienne": "DarkGreen",
    "Dwuliścienne": "SpringGreen",
    "Sar": "Goldenrod",
    "Rodofity": "Red",
    "Bakterie": "Blue",
    "Mchy": "Olive",
    "Amborella trichopoda": "DarkBlue",
    "brak": "black",
    "Haptista": "Orange",
    "Krypromonady": "DarkOrange",  # was misspelled "DarkOragne", which is not a colour name
    "Archaeplastida": "Pink",
    "Zielenice": "SeaGreen",
    "Paprocie": "Lime",
    "Zwierzęta": "Brown",
}

# Alignment symbol (gap or one-letter amino-acid code) -> colour of the letter in the plot.
color_dict = {
    "-": "black",
    "A": "yellow", "G": "grey",
    "R": "red", "K": "red",
    "T": "grey", "S": "grey",
    "E": "blue", "D": "blue",
    "L": "grey", "I": "grey", "V": "grey",
    "N": "pink", "Q": "pink",
    "F": "yellow", "Y": "grey", "M": "grey", "C": "grey", "P": "grey", "H": "grey",
    "W": "yellow",
}

In [None]:
import os
try:
    !pip install --upgrade pip
    !pip install biopython
    os.system("pip install ete3 pyqt5")
    !pip install ete3
    os.environ['QT_QPA_PLATFORM']='offscreen'
except ImportError:
    pass

from Bio import Entrez
Entrez.email = 'A.N.Other@gmail.com'


import pandas as pd
pd.set_option("max_colwidth", None)
pd.set_option('display.max_rows', None)
from __future__ import print_function
import numpy as np
import time
import re
import matplotlib.pyplot as plt
import ete3 as ete
from ete3 import PhyloTree, Tree,  NCBITaxa, TreeNode, NodeStyle, TreeStyle, faces, AttrFace, TextFace, CircleFace
import seaborn as sns
import matplotlib.patches as patches
from scipy import stats
from scipy.stats import mode
import zipfile
import io
from random import randint
import random
from PIL import Image
import PIL
# from google.colab import drive
# drive.mount('/content/drive')


In [None]:
class Tree():
  """Build the annotated tree and derive clades from it.

  Note: this class has the same name as the ``Tree`` imported from ``ete3``;
  the ete3 class is therefore reached through the ``ete`` module alias
  (``ete.Tree``) inside ``make_tree``.
  """

  def __init__(self, my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv,
               group2=None, threshold_spec2=None):
    """Store the file names and thresholds shared by the tree, heatmap and alignment steps.

    Args:
      my_aln_name: path of the alignment file.
      my_csv_name: path of the CSV table with pairwise values (first column = row labels).
      my_tree_name: path of the Newick tree file.
      my_taxo_name: path of the Excel file with ``Organism`` and ``Taxo`` columns.
      threshold: default limit for the std of the pairwise values inside a clade.
      threshold_spec: limit used for clades whose taxon equals ``group``.
      group: taxon name that uses ``threshold_spec``.
      all_conserv: conservation level; only stored here.
      group2: optional second taxon with its own limit. When left as ``None`` the
        module-level ``group2`` is used (the behaviour before this argument existed).
      threshold_spec2: limit for ``group2``; ``None`` falls back to the module-level
        ``threshold_spec2``.
    """
    self.my_aln_name = my_aln_name
    self.my_csv_name = my_csv_name
    self.my_tree_name = my_tree_name
    self.my_taxo_name = my_taxo_name
    self.threshold = threshold
    self.threshold_spec = threshold_spec
    self.group = group
    self.all_conserv = all_conserv
    self.group2 = group2
    self.threshold_spec2 = threshold_spec2

  @staticmethod
  def _organism_from_leaf(leaf_name):
    """Strip digits from a leaf label, turn "_" into " " and drop the last character.

    For a label shaped like ``Genus_species_1`` this yields ``Genus species``.
    """
    return re.sub(r"\d+", "", leaf_name).replace("_", " ")[:-1]

  def _second_group(self):
    """Return ``(name, limit)`` of the second special group.

    Values given to ``__init__`` win; otherwise the module-level ``group2`` /
    ``threshold_spec2`` are used when they exist, so existing notebooks keep working.
    """
    name = getattr(self, "group2", None)
    limit = getattr(self, "threshold_spec2", None)
    if name is None:
      name = globals().get("group2")
    if limit is None:
      limit = globals().get("threshold_spec2")
    return name, limit

  def _distance_table(self):
    """Return the table read from ``my_csv_name``, loading the file only once per path."""
    cached = getattr(self, "_dist_cache", None)
    if cached is None or cached[0] != self.my_csv_name:
      cached = (self.my_csv_name, pd.read_csv(self.my_csv_name, index_col=0))
      self._dist_cache = cached
    return cached[1]

  def preparing_taxo_dict(self):
    """Build the organism -> taxon dictionary from the Excel file.

    Some organisms are not covered by the spreadsheet and are added manually.

    Returns:
      dict mapping the ``Organism`` column to the ``Taxo`` column plus the manual entries.
    """
    taxo_table = pd.read_excel(self.my_taxo_name)
    taxo_by_organism = dict(zip(taxo_table["Organism"], taxo_table["Taxo"]))
    # Manual entries; they override the spreadsheet when the organism is listed there too.
    taxo_by_organism.update({
      "Thermosynechococcus elongatus": "Bakterie",
      "Chlamydomonas reinhardtii": "Zielenice",
      "Pyropia suborbiculata": "Rodofity",
      "Lactuca sativa": "Dwuliścienne",
      "Amborella trichopoda": "Amborella trichopoda",
    })
    return taxo_by_organism

  def permut_ID(self, list_ID):
    """Collect the table value of every unordered pair of labels in ``list_ID``.

    For positions ``i < j`` the value in column ``list_ID[i]`` and row ``list_ID[j]``
    is taken, in that iteration order.

    Args:
      list_ID: sequence of labels present in both the columns and the index of the CSV.

    Returns:
      list of the looked-up values (empty for fewer than two labels).
    """
    df = self._distance_table()
    list_parr = []
    for position, column_label in enumerate(list_ID):
      for row_label in list_ID[position + 1:]:
        list_parr.append(df.loc[row_label, column_label])
    return list_parr

  def name_internal_nodes2(self, T):
    """Name internal nodes after their leaves and attach a coloured label to every leaf.

    Internal nodes are renamed to the "."-joined names of the leaves below them.
    Each leaf gets a ``TextFace`` with its own name, coloured by its taxon.

    Args:
      T: ete3 tree; modified in place.
    """
    taxo_by_organism = self.preparing_taxo_dict()
    for node in T.traverse():
      if not node.is_leaf():
        node.name = ".".join(leaf.name for leaf in node.iter_leaves())
        continue
      organism = self._organism_from_leaf(node.name)
      # taxo_examples is the module-level taxon -> colour table.
      face = TextFace(node.name, fsize=40, fgcolor=taxo_examples[taxo_by_organism[organism]])
      node.add_face(face, column=1)
    return None

  def remove_subsets(self, lista):
    """Drop every clade that is a proper subset of another clade.

    Args:
      lista: iterable of iterables with leaf names.

    Returns:
      list of sets, in the input order, without the proper subsets.
    """
    clade_sets = [set(sub_list) for sub_list in lista]
    # "<" on sets is the proper-subset test, so equal sets never remove each other.
    return [clade for clade in clade_sets
            if not any(clade < other for other in clade_sets)]

  def list_of_names(self, names, list_name):
    """Make the entries of ``list_name`` unique by appending running numbers.

    A name that occurs once is kept as it is. A name that occurs k > 1 times becomes
    ``name1`` ... ``name<k-1>`` for its first k-1 occurrences, while its last
    occurrence keeps the plain name.

    Args:
      names: the distinct names occurring in ``list_name``.
      list_name: names in clade order, possibly repeated.

    Returns:
      list with one unique name per entry of ``list_name``.
    """
    pending = {
      name: [name + str(n) for n in range(1, list_name.count(name) + 1)]
      for name in names
    }
    new_names = []
    for name in list_name:
      numbered = pending[name]
      if len(numbered) > 1:
        new_names.append(numbered.pop(0))
      else:
        new_names.append(name)
    return new_names

  def get_example_tree(self, tree):
    """Select clades by the spread of their pairwise values and style them on the tree.

    An internal node becomes a candidate clade when the NaN-ignoring standard deviation
    of the pairwise values of its leaves is below the limit of its taxon. The taxon is
    taken from the leaf in the middle of the node's leaf list. Candidates contained in
    a larger candidate are dropped and every leaf not covered by any clade is added as
    a one-element clade.

    Args:
      tree: ete3 tree whose internal nodes were named by ``name_internal_nodes2``.

    Returns:
      list of clades: sets of leaf names, followed by one-element lists for single leaves.
    """
    taxo_by_organism = self.preparing_taxo_dict()
    group2_name, group2_limit = self._second_group()

    candidate_clades = []
    for node in tree.traverse():
      if node.is_leaf():
        continue
      leaf_names = node.name.split(".")
      taxon = taxo_by_organism[self._organism_from_leaf(leaf_names[len(leaf_names) // 2])]
      if taxon == self.group:
        limit = self.threshold_spec
      elif group2_name is not None and taxon == group2_name:
        limit = group2_limit
      else:
        limit = self.threshold
      if np.nanstd(np.array(self.permut_ID(leaf_names))) < limit:
        candidate_clades.append(set(leaf_names))

    # remove clades that are contained in a bigger clade
    list_clad = self.remove_subsets(candidate_clades)
    # leaves that did not end up in any clade become separate one-element clades
    covered = set().union(*list_clad)
    for node in tree.traverse():
      if node.is_leaf() and node.name not in covered:
        list_clad.append([node.name])

    print(f"Liczba kladów: {len(list_clad)}")
    print(list_clad)

    used_labels = []
    for node in tree.traverse():
      leaf_names = node.name.split(".")
      if set(leaf_names) in list_clad or [node.name] in list_clad:
        # random grey shade as clade background
        grey = random.randint(50, 200)
        color = "#" + ("%02X" % grey) * 3
        style1 = NodeStyle()
        style1["shape"] = "sphere"
        style1["size"] = 5
        style1["fgcolor"] = "darkred"
        style1["bgcolor"] = color
        node.set_style(style1)

        label = taxo_by_organism[self._organism_from_leaf(leaf_names[len(leaf_names) // 2])]
        if label in used_labels:
          # number repeated taxa: the second clade of a taxon gets "1", the third "2", ...
          stripped = [re.sub(r"\d+", "", used) for used in used_labels]
          label += str(stripped.count(label))
        node.add_face(TextFace(label, fsize=40, fgcolor="red"), column=0)
        used_labels.append(label)

    return list_clad

  def make_tree(self):
    """Load the Newick file, annotate the tree and prepare a circular tree style.

    Returns:
      tuple ``(clades, tree, style)``: clades as lists of leaf names, the annotated
      ete3 tree and the ``TreeStyle`` to render it with.
    """
    with open(self.my_tree_name, "r") as tree_file:
      newick = tree_file.read()
    tree = ete.Tree(newick)
    self.name_internal_nodes2(tree)
    node_list = self.get_example_tree(tree)

    new_node_list = [list(clade) for clade in node_list]

    style = ete.TreeStyle()
    style.allow_face_overlap = True
    style.branch_vertical_margin = 20

    style.show_leaf_name = False
    style.show_branch_length = True
    style.show_scale = True
    style.tree_width = 200
    style.mode = "c"
    style.draw_guiding_lines = True
    return new_node_list, tree, style

class Heatmap(Tree):
  """Heatmap of the pairwise table with the clades outlined on the diagonal."""

  def __init__(self, my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv, clad_list):
    """Store the shared settings plus the clades to outline.

    Args:
      clad_list: iterable of clades, each an iterable of leaf names.
      (all other arguments are passed unchanged to ``Tree.__init__``)
    """
    super().__init__(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv)
    self.clad_list = clad_list

  def make_aln_list(self):
    """Read leaf identifiers from the tree file and sequences from the alignment file.

    Identifiers are the matches of ``\\w+_\\w+_\\d+``. In the alignment file a header
    line (starting with ">") selects the identifier; the following non-empty lines are
    joined into its sequence. Headers without an identifier are stored under "".

    Returns:
      tuple ``(dict_aln, matches)``: identifier -> sequence, and the identifiers in
      the order in which they appear in the tree file.
    """
    leaf_id = re.compile(r"\w+_\w+_\d+")
    # Plain read mode: the files are never written, so write permission is not required.
    with open(self.my_tree_name, "r") as tree_file:
      matches = leaf_id.findall(tree_file.read())

    dict_aln = {}
    organ = None
    starts_record = False
    with open(self.my_aln_name, "r") as aln_file:
      for raw_line in aln_file:
        # rstrip instead of [0:-1]: the last line may come without a trailing newline
        line = raw_line.rstrip("\n")
        if line.startswith(">"):
          found = leaf_id.search(line)
          organ = found.group(0) if found else ""
          starts_record = True
        elif line and organ is not None:
          if starts_record:
            dict_aln[organ] = line
            starts_record = False
          else:
            # wrapped FASTA: keep appending the continuation lines
            dict_aln[organ] += line
    return dict_aln, matches

  def make_df(self):
    """Prepare the matrix shown in the heatmap.

    Missing values are replaced by the mean of the column means, the table is added
    to its transpose, the natural logarithm is taken and rows and columns are ordered
    like the leaves in the tree file.

    Returns:
      DataFrame indexed (rows and columns) by the identifiers found in the tree file.
    """
    _, matches = self.make_aln_list()
    df = pd.read_csv(self.my_csv_name, index_col=0)
    # fillna replaces the former replace({np.NAN: ...}); the np.NAN alias no longer exists in NumPy 2.
    df = df.fillna(df.mean().mean())
    summed = np.add(np.nan_to_num(df.T.to_numpy()), np.nan_to_num(df.to_numpy()))
    labels = list(df.index.values)  # the row labels are used for the columns as well
    log_df = pd.DataFrame(data=np.log(summed), index=labels, columns=labels)
    return log_df.reindex(index=matches, columns=matches)

  def prepare_heatmap(self):
    """Draw the heatmap and outline every clade with a green square on the diagonal."""
    _, matches = self.make_aln_list()
    df = self.make_df()
    fig, ax = plt.subplots(figsize=(40, 40))
    # Draw the heatmap first so that the outlines are added on top of it.
    sns.heatmap(df, annot=False, ax=ax)
    for lista_organ in self.clad_list:
      list_index = [matches.index(element) for element in lista_organ]
      min_seq = min(list_index)
      # Cell i spans [i, i + 1], so the square has to reach one past the largest index.
      zakres = max(list_index) - min_seq + 1
      rect = patches.Rectangle((min_seq, min_seq), zakres, zakres, linewidth=3, edgecolor='green', facecolor='none')
      ax.add_patch(rect)

  def make_heatmap(self):
    """Entry point: build and draw the heatmap with the clade outlines."""
    self.prepare_heatmap()


class Alignment(Tree):
  """Per-clade residue statistics of the alignment and their plots."""

  def __init__(self, my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv, clad_list, color_dict):
    """Store the shared settings, the clades and the symbol colours.

    Args:
      clad_list: iterable of clades, each an iterable of leaf names.
      color_dict: alignment symbol -> colour; its keys define the counted symbols.
      (all other arguments are passed unchanged to ``Tree.__init__``)
    """
    super().__init__(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv)
    self.clad_list = clad_list
    self.color_dict = color_dict

  def make_lines(self):
    """Split the alignment file into two-line records.

    Returns:
      tuple ``(lines, organ_to_cut_list)``: the records (header line + one sequence
      line, spaces replaced by "_") and the flat list of all leaf names in the clades.
    """
    with open(self.my_aln_name, mode="r") as aln_file:
      content = aln_file.read()
    # The record pattern needs a final newline; add it so the last record is not lost.
    if content and not content.endswith("\n"):
      content += "\n"

    organ_to_cut_list = [organ for lista in self.clad_list for organ in lista]
    lines = [seq.replace(" ", "_") for seq in re.findall(r">.*\n.*\n", content)]
    return lines, organ_to_cut_list

  def make_seq_dict(self):
    """Collect the aligned sequences of every clade.

    Shape of the result, e.g.:
    ``{clade: [[-,-,A],[A,-,-],[-,D,-]], clade2: [[-,D,D],[A,K,V],[-,-,-]]}``

    The clade name is the taxon of the last matched record, made unique with
    ``list_of_names``.

    Returns:
      tuple ``(seq_dict, keys)``: clade name -> object array with one row of symbols
      per sequence, and the keys view of that dictionary (clade order).

    Raises:
      ValueError: when no record of the alignment belongs to a clade.
    """
    taxo_by_organism = self.preparing_taxo_dict()
    lines, _ = self.make_lines()
    # every record is exactly ">header\nsequence\n"
    records = [seq.split("\n")[:2] for seq in lines]

    lists_lista = []
    list_name = []
    for lista in self.clad_list:
      list_of_lists = []
      name = None
      for organ in lista:
        # (?!\d): "X_1" must not also pick up the records of "X_10", "X_11", ...
        organ_pattern = re.compile(re.escape(organ) + r"(?!\d)")
        for header, sequence in records:
          if organ_pattern.search(header):
            list_of_lists.append(list(sequence))
            name = header
      if name is None:
        raise ValueError(f"No alignment record found for clade {list(lista)!r} in {self.my_aln_name}")
      # header -> organism: drop digits, "_" -> " ", strip the leading ">" and the last character
      taxon = taxo_by_organism[re.sub(r"\d+", "", name).replace("_", " ")[1:-1]]
      list_name.append(taxon.replace("Arabidopsis thaliana", "Dwuliścienne"))
      lists_lista.append(list_of_lists)

    new_names = self.list_of_names(list(set(list_name)), list_name)
    seq_dict = {}
    for clade_name, rows in zip(new_names, lists_lista):
      seq_dict[clade_name] = np.array(rows, dtype=object)
    return seq_dict, seq_dict.keys()

  def make_seq_aa_dict(self):
    """Compute, per clade and symbol, the percentage of sequences showing it in each column.

    Shape of the result, e.g.:
    ``{clade: {"-": [0, 33.3, 66.6], "A": [100, 66.6, 33.3]}, ...}``

    Returns:
      tuple ``(seq_aa_dict, keys)``: the nested dictionary and the sorted clade names.
    """
    seq_dict, keys = self.make_seq_dict()
    aa_list = list(self.color_dict.keys())

    seq_aa_dict = {}
    for key in keys:
      matrix = seq_dict[key]
      n_rows = len(matrix)
      n_cols = len(matrix[0])
      seq_aa_dict[key] = {
        aa: [np.count_nonzero(matrix[:, col] == aa) / n_rows * 100 for col in range(n_cols)]
        for aa in aa_list
      }

    return seq_aa_dict, sorted(keys)

  def make_plot(self, start=480, chunk=80):
    """Render the residue percentages as letter plots and stack them per alignment window.

    For every window of ``chunk`` columns one image per clade is written to
    ``lista_dopasowań/`` and all images of the window are stacked into
    ``<window>_<alignment name>_dopasowanie.jpg``.

    Args:
      start: first alignment column to plot (the default keeps the former fixed value).
      chunk: number of columns per window.
    """
    seq_aa_dict, keys = self.make_seq_aa_dict()
    out_dir = "lista_dopasowań"
    os.makedirs(out_dir, exist_ok=True)
    name = os.path.splitext(self.my_aln_name)[0]
    # all_conserv is a fraction, the plotted heights are percentages
    conserved_pct = self.all_conserv * 100
    n_columns = len(seq_aa_dict[keys[0]]["-"])

    for n in range(start, n_columns, chunk):
      list_img = []
      for key2 in keys:
        df = pd.DataFrame(seq_aa_dict[key2])
        df2 = df[n:n + chunk]

        fig, ax = plt.subplots(figsize=(len(df2) * 2, 3))
        group_name = key2[0:5] + re.sub("[^0-9]", "", key2)
        ax.set_ylabel(group_name, fontsize=50)
        for border in ["top", "right", "left", "bottom"]:
          ax.spines[border].set_visible(False)
        ax.get_yaxis().set_ticks([])

        bottom = np.zeros(len(df2))
        for col in df.columns:
          # invisible (white) stacked bars only provide the letter positions
          bars = ax.bar(df2.index, df2[col], bottom=bottom, color="w")
          bottom += np.array(df2[col])
          for bar in bars:
            height = bar.get_height()
            if height == 0:
              continue
            box = ax.text(
              bar.get_x() + bar.get_width() / 2,
              bar.get_y() / 2,
              col,
              color=self.color_dict[col],
              size=height,  # letter size follows the percentage
              fontfamily='sans-serif',
              ha='center')
            if height > conserved_pct and col != "-":
              box.set_bbox(dict(facecolor='green', alpha=0.2))

        fig_path = out_dir + "/" + str(n) + group_name + ".jpg"
        fig.savefig(fig_path)
        plt.close(fig)
        list_img.append(fig_path)

      # Stack the images of this window once, after every clade has been drawn.
      imgs = [PIL.Image.open(path) for path in list_img]
      try:
        min_shape = sorted([(np.sum(img.size), img.size) for img in imgs])[0][1]
        imgs_comb = PIL.Image.fromarray(np.vstack([np.asarray(img.resize(min_shape)) for img in imgs]))
        imgs_comb.save(str(int(n / chunk)) + "_" + name + '_dopasowanie.jpg')
      finally:
        for img in imgs:
          img.close()
      print(list_img)


class PyMOL_query(Tree):
  """Consensus sequences per clade and the strings/files derived from them for PyMOL."""

  def __init__(self, my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv, clad_list, seq_dict, keys):
    """Store the shared settings, the clades and their sequences.

    Args:
      clad_list: clades (iterables of leaf names), in the same order as ``keys``.
      seq_dict: clade name -> 2-D array with one row of symbols per sequence.
      keys: clade names, positionally aligned with ``clad_list``.
      (all other arguments are passed unchanged to ``Tree.__init__``)
    """
    super().__init__(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv)
    self.clad_list = clad_list
    self.seq_dict = seq_dict
    self.keys = keys

  @staticmethod
  def _consensus(matrix):
    """Return the most frequent symbol of every column, joined into one string.

    Computed with ``np.unique`` because ``scipy.stats.mode`` no longer accepts
    non-numeric input. Ties go to the smallest symbol in sort order (``np.unique``
    returns sorted values and ``argmax`` takes the first maximum).
    """
    symbols = []
    for col in range(matrix.shape[1]):
      values, counts = np.unique(matrix[:, col].astype(str), return_counts=True)
      symbols.append(str(values[np.argmax(counts)]))
    return "".join(symbols)

  @staticmethod
  def _pymol_name(group):
    """Build the object name of a clade: ``<letters>_<digits or 0>_1`` with "ś" -> "s"."""
    letters = re.sub(r"\d+", "", group)
    digits = re.sub(r"\D+", "", group)
    return (letters + "_" + (digits or "0") + "_1").replace("ś", "s")

  def make_organ_common_dict(self):
    """Consensus sequence (gaps included) for every clade in ``self.keys``."""
    return {key: self._consensus(self.seq_dict[key]) for key in list(self.keys)}

  def make_common_dict(self):
    """Unified sequences in clades: consensus for every key of ``self.seq_dict``."""
    return {key: self._consensus(self.seq_dict[key]) for key in list(self.seq_dict.keys())}

  def clad_organ(self):
    """Dictionary of lists with the organisms in each clade.

    Returns:
      dict mapping the clade name at position i of ``self.keys`` to the distinct
      organism names of clade i of ``self.clad_list``.
    """
    key_list = list(self.keys)
    klad_dict = {}
    for position, klad in enumerate(self.clad_list):
      # leaf label -> organism: drop digits, "_" -> " ", strip the last character
      organisms = {re.sub(r"\d+", "", organ).replace("_", " ")[0:-1] for organ in klad}
      klad_dict[key_list[position]] = list(organisms)
    return klad_dict

  def make_query(self, numb):
    """Build the name/length strings for PyMOL.

    Args:
      numb: number of leading alignment columns; the first tuple element is the
        count of non-gap symbols of the consensus within these columns.

    Returns:
      tuple ``(seq, seq2)``: ``seq`` lists every quoted name followed by
      ``(ungapped prefix length, consensus length)``; ``seq2`` lists the names only.
      The last character of both strings is cut off.
    """
    list_groups = list(self.clad_organ().keys())
    organ_common_dict = self.make_organ_common_dict()
    dict_groups = self.make_common_dict()

    dict_trans = {}
    for group in list_groups:
      trans = organ_common_dict[group][0:numb].replace("-", "")
      dict_trans[group] = (len(trans), len(dict_groups[group]))

    seq = ""
    seq2 = ""
    for group in list_groups:
      # NOTE(review): numbered groups are followed by ", " and unnumbered ones by ",";
      # kept as is so that the returned strings do not change.
      separator = ", " if re.search(r"\d+", group) else ","
      name = "\"" + self._pymol_name(group) + "\"" + separator
      seq += name
      seq2 += name
      seq += str(dict_trans[group]) + ","
    return seq[0:-1], seq2[0:-1]

  def make_aln_file(self, name):
    """Write the consensus sequence of every clade to ``name`` in FASTA format.

    Headers use the same object names as ``make_query`` and every record ends with a
    newline, so consecutive records no longer run into each other.

    Args:
      name: path of the file to create or overwrite.
    """
    list_groups = list(self.clad_organ().keys())
    organ_common_dict = self.make_organ_common_dict()
    with open(name, "w") as f:
      for group in list_groups:
        f.write(f">{self._pymol_name(group)}\n{organ_common_dict[group]}\n")

In [None]:
# Build the annotated tree from the settings of the configuration cell.
# make_tree returns the clades (lists of leaf names), the ete3 tree and the TreeStyle.
df_class = Tree(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv)
new_node_list, tree, style = df_class.make_tree()
# Uncomment to render the tree inline:
# tree.render('%%inline', tree_style=style, w=2000, units='px')

In [None]:
# clad_list = new_node_list
# df_class = Heatmap(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv, clad_list)
# df_class.make_heatmap()

In [None]:
# Per-clade alignment statistics; only computed when there are fewer than 44 clades.
# NOTE(review): seq_dict and keys stay undefined when the guard fails, and the next
# cell uses both — confirm that skipping is intended.
clad_list = new_node_list
if len(clad_list) < 44:
  df_class = Alignment(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv, clad_list, color_dict)
  # keys is bound twice: first to the sorted list returned by make_seq_aa_dict,
  # then to the keys view returned by make_seq_dict, which is the value kept.
  seq_aa_dict, keys = df_class.make_seq_aa_dict()
  seq_dict, keys = df_class.make_seq_dict()
  # df_class.make_plot()

In [None]:
# Consensus sequence per clade, built from seq_dict and keys of the previous cell.
clad_list = new_node_list
df_class = PyMOL_query(my_aln_name, my_csv_name, my_tree_name, my_taxo_name, threshold, threshold_spec, group, all_conserv, clad_list, seq_dict, keys)
common_dict = df_class.make_common_dict()
# Uncomment to print the query strings for the first 93 alignment columns:
# print(df_class.make_query(93))