# Deciphering the Black Box: Mastering the MSA Transformer for Phylogenetic Tree Construction

## 1. Install and import packages ##

In [5]:
!pip install fair-esm --quiet
!pip install transformers --quiet
!pip install pysam --quiet
!pip install Bio
!pip install ete3
!pip install dendropy



In [1]:
import torch
import string
import torch.nn.functional as F
import numpy as np
import esm
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
from pysam import FastaFile,FastxFile
from torch.utils.data import TensorDataset,Dataset
import pandas as pd
from Bio import SeqIO
from ete3 import Tree
from scipy import stats
import csv
import sys
import os
import torch.utils.data as data
import random

In [None]:
class ChangeAA:
  """Write shuffled variants of a protein family's seed alignment.

  The input alignment is ``<data_dir>/<protein_family>_seed_hmmalign_no_inserts.fasta``.
  Each public method reads that file and writes a perturbed copy into the
  same directory:

  * ``mix_fasta_column``     -> ``<family>_mix_column.fasta``
  * ``shuffle_fasta_all``    -> ``<family>_shuffle_all.fasta`` (+ ``_shuffle_all_order.txt``)
  * ``shuffle_fasta_column`` -> ``<family>_shuffle_column.fasta`` (+ ``_shuffle_column_order.txt``)

  All randomness comes from the module-level ``random`` generator, so call
  ``random.seed(...)`` beforehand for reproducible output.
  """

  def __init__(self, protein_family, data_dir='/content/drive/MyDrive/PhD/Pfam/'):
    """Store the family identifier and derive the input alignment path.

    Args:
      protein_family: Family identifier used as the file-name prefix.
      data_dir: Directory holding the input alignment; outputs are written
        there too. Defaults to the location the notebook originally used.
    """
    self.protein_family = protein_family
    # Every path in this class is built from ``self.pfam``. It was never
    # assigned before, so construction failed with AttributeError.
    self.pfam = protein_family
    self.data_dir = data_dir
    self.msa_file = self._path("seed_hmmalign_no_inserts.fasta")

  def _path(self, suffix):
    """Return the path of ``<family>_<suffix>`` inside the data directory."""
    return os.path.join(self.data_dir, f"{self.pfam}_{suffix}")

  def _read_fasta(self):
    """Parse ``self.msa_file`` into a list of ``(header, sequence)`` tuples.

    Header lines are kept verbatim (including ``>``); wrapped sequence lines
    are concatenated after stripping trailing whitespace.

    Raises:
      ValueError: If sequence data appears before the first header, or the
        file contains no records.
    """
    entries = []
    with open(self.msa_file, 'r') as handle:
      for raw_line in handle:
        line = raw_line.rstrip()
        if line.startswith('>'):
          entries.append((line, []))
        elif line:
          if not entries:
            raise ValueError(f"Sequence data before first header in {self.msa_file}")
          entries[-1][1].append(line)
    if not entries:
      raise ValueError(f"No FASTA records found in {self.msa_file}")
    return [(header, ''.join(chunks)) for header, chunks in entries]

  def _aligned_length(self, sequences):
    """Return the common length of ``sequences``; raise ValueError if they differ."""
    lengths = {len(sequence) for sequence in sequences}
    if len(lengths) != 1:
      raise ValueError(
        f"Sequences in {self.msa_file} have different lengths: {sorted(lengths)}"
      )
    return lengths.pop()

  @staticmethod
  def _write_fasta(path, entries):
    """Write ``(header, sequence)`` pairs, one line each, without a trailing newline."""
    with open(path, 'w') as handle:
      handle.write('\n'.join(line for entry in entries for line in entry))

  @staticmethod
  def _write_order(path, rows):
    """Write the index permutations as CSV, one permutation per row."""
    # newline='' lets the csv module control line endings itself.
    with open(path, 'w', newline='') as handle:
      csv.writer(handle).writerows(rows)

  def shuffle_col(self,column):
    """Return the characters of ``column`` in random order.

    A column whose characters are all identical is returned unchanged, which
    also avoids consuming random numbers for it.
    """
    if len(set(column)) == 1:
      return column

    aa_list = list(column)
    random.shuffle(aa_list)

    return ''.join(aa_list)

  def mix_fasta_column(self):
    """Shuffle the residues within every alignment column independently.

    Each column keeps its residue composition, but the residues are
    reassigned to random sequences. The result is written with
    ``SeqIO.write`` to ``<family>_mix_column.fasta``.

    Raises:
      ValueError: If the alignment is empty or its sequences differ in length.
    """
    # Local import: only this method needs to construct new Seq objects.
    from Bio.Seq import Seq

    mix_column = self._path("mix_column.fasta")

    records = list(SeqIO.parse(self.msa_file, "fasta"))
    if not records:
      raise ValueError(f"No FASTA records found in {self.msa_file}")
    rows = [str(record.seq) for record in records]
    seq_length = self._aligned_length(rows)

    # Shuffle column by column, then transpose back into rows once, instead
    # of rebuilding every full sequence for every column.
    shuffled_columns = [
      self.shuffle_col(''.join(row[i] for row in rows))
      for i in range(seq_length)
    ]
    for j, record in enumerate(records):
      record.seq = Seq(''.join(column[j] for column in shuffled_columns))

    SeqIO.write(records, mix_column, "fasta")

    print('Generate Mix columns data!')

  def shuffle_fasta_all(self):
    """Shuffle the positions of every sequence independently.

    Each sequence keeps its own composition but gets its own random
    permutation. The shuffled alignment goes to ``<family>_shuffle_all.fasta``
    and the permutations (one CSV row per sequence, in file order) to
    ``<family>_shuffle_all_order.txt``.
    """
    shuffle_all = self._path("shuffle_all.fasta")
    shuffle_order = self._path("shuffle_all_order.txt")

    shuffled_entries = []
    shuffled_order = []
    for header, sequence in self._read_fasta():
      # for each sequence, regenerate a new sequence in a random order using the same amino acid composition
      shuffled_indices = random.sample(range(len(sequence)), len(sequence))
      shuffled_order.append(shuffled_indices)
      shuffled_entries.append((header, ''.join(sequence[i] for i in shuffled_indices)))

    self._write_fasta(shuffle_all, shuffled_entries)
    self._write_order(shuffle_order, shuffled_order)

    print('Generate Shuffle all data!')

  def shuffle_fasta_column(self):
    """Reorder the alignment columns with one shared random permutation.

    Every sequence is rearranged with the same permutation, so each column
    stays intact and only the column order changes. The shuffled alignment
    goes to ``<family>_shuffle_column.fasta`` and the single permutation to
    ``<family>_shuffle_column_order.txt``.

    Raises:
      ValueError: If the sequences differ in length.
    """
    shuffle_column = self._path("shuffle_column.fasta")
    shuffle_order = self._path("shuffle_column_order.txt")

    entries = self._read_fasta()
    sequence_length = self._aligned_length([sequence for _, sequence in entries])

    # regenerate new sequences using the same order for all of them
    shuffled_indices = random.sample(range(sequence_length), sequence_length)
    shuffled_entries = [
      (header, ''.join(sequence[i] for i in shuffled_indices))
      for header, sequence in entries
    ]

    self._write_fasta(shuffle_column, shuffled_entries)
    self._write_order(shuffle_order, [shuffled_indices])

    print('Generate Shuffle columns data!')

In [None]:
# Translation table that deletes lowercase letters, "." and "*".
# Built once at import time instead of on every call.
_INSERTION_DELETE_TABLE = str.maketrans(
  dict.fromkeys(string.ascii_lowercase + ".*")
)


def remove_insertions(sequence):
  """Removes any insertions into the sequence. Needed to load aligned sequences in an MSA.

  Deletes every lowercase ASCII letter, ``.`` and ``*`` from ``sequence``;
  all other characters (uppercase letters, ``-`` gaps, ...) are kept in order.

  Args:
    sequence: Aligned sequence string.

  Returns:
    The sequence with the characters listed above removed.
  """
  return sequence.translate(_INSERTION_DELETE_TABLE)


class Extractor:
  """Run the ESM MSA Transformer on one variant of a family's alignment.

  ``msa_type`` selects which alignment file is read:

  * ``"no"`` - original alignment (``_seed_hmmalign_no_inserts.fasta``)
  * ``"sc"`` - shuffle-column data (``_shuffle_column.fasta``)
  * ``"sa"`` - shuffle-all data (``_shuffle_all.fasta``)
  * anything else - mix-column data (``_mix_column.fasta``)
  """

  # msa_type code -> (MSA file-name suffix, tag used in the attention file name).
  _MSA_VARIANTS = {
    "no": ("seed_hmmalign_no_inserts", "no_shuffle"),
    "sc": ("shuffle_column", "shuffle_column"),
    "sa": ("shuffle_all", "shuffle_all"),
  }
  # Any unrecognised msa_type falls back to the mix-column data.
  _DEFAULT_VARIANT = ("mix_column", "mix_column")

  def __init__(self, protein_family, msa_type):
    """Record the model settings and resolve the input MSA path.

    Args:
      protein_family: Family identifier used as the file-name prefix.
      msa_type: Variant code, see the class docstring.
    """
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model_name = "esm_msa1b_t12_100M_UR50S"
    self.encoding_dim, self.encoding_layer, self.max_seq_length = 768, 12, 1022
    self.protein_family = protein_family
    self.msa_type = msa_type
    self.emb_path = './embeddings/'
    # NOTE(review): '/attentions/' and '/MSA/' are absolute paths at the
    # filesystem root while emb_path is relative - confirm this is intended.
    self.attn_path = '/attentions/'

    fasta_suffix, _ = self._variant()
    self.msa_fasta_file = '/MSA/' + f"{self.protein_family}_{fasta_suffix}.fasta"

  def _variant(self):
    """Return ``(fasta_suffix, attention_tag)`` for the current ``msa_type``."""
    return self._MSA_VARIANTS.get(self.msa_type, self._DEFAULT_VARIANT)

  def read_msa(self):
    """Load the MSA as a list of ``(description, sequence)`` tuples.

    Each sequence is passed through ``remove_insertions`` before it is
    returned.
    """
    return [(record.description, remove_insertions(str(record.seq)))
            for record in SeqIO.parse(self.msa_fasta_file, "fasta")]

  def get_col_attention(self):
    """Run the model on the MSA and save its outputs, including attentions.

    Loads the pretrained model named by ``self.model_name``, tokenizes the
    alignment as a single batch, runs one forward pass without gradients and
    saves the returned result dict with ``torch.save`` to
    ``<attn_path><family>_attn_<variant>_<model_name>.pt``.
    """
    model, alphabet = pretrained.load_model_and_alphabet(self.model_name)
    batch_converter = alphabet.get_batch_converter()

    _, attn_tag = self._variant()
    col_attn = self.attn_path + self.protein_family + '_attn_' + attn_tag + '_' + self.model_name + '.pt'

    model.eval()
    # process fasta file into pytorch dataset (a batch holding one MSA)
    msa_data = [self.read_msa()]
    _, _, msa_tokens = batch_converter(msa_data)

    with torch.no_grad():
      # need_head_weights must be True for the model to return attention
      # maps; with False the saved dict contained no attention at all.
      results = model(msa_tokens, repr_layers=[self.encoding_layer], need_head_weights=True)

    # The whole result dict is still saved, so keys that were present before
    # remain available to downstream code.
    torch.save(results, col_attn)
    print("Column attentions saved in output file:", col_attn)