# Deciphering the Black Box: Mastering the MSA Transformer for Phylogenetic Tree Construction

## 1. Install and import packages ##

In [5]:
!pip install fair-esm --quiet
!pip install transformers --quiet
!pip install pysam --quiet
!pip install Bio
!pip install ete3
!pip install dendropy



In [6]:
import string
import torch
import torch.nn.functional as F
import numpy as np
import esm
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
from pysam import FastaFile,FastxFile
from torch.utils.data import TensorDataset,Dataset
import pandas as pd
from Bio import SeqIO

In [None]:
class Extracting:
  """Extract MSA Transformer outputs for one multiple sequence alignment.

  The FASTA file is read with ``SeqIO.parse``, tokenized with the batch
  converter of the ``esm_msa1b_t12_100M_UR50S`` alphabet, and fed to the
  pretrained model. Results are written with ``torch.save`` to paths built
  by plain string concatenation of ``file_location``, ``protein_family``
  and the model name, so ``file_location`` must already end with a path
  separator if it names a directory.
  """

  def __init__(self, msa_fasta_file, protein_family, file_location):
    """Store input/output locations and the fixed model settings.

    Args:
      msa_fasta_file: FASTA source accepted by ``SeqIO.parse``.
      protein_family: Name inserted into the output file names.
      file_location: Prefix prepended verbatim to the output file names.
    """
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model_name = "esm_msa1b_t12_100M_UR50S"
    self.protein_family = protein_family
    self.msa_fasta_file = msa_fasta_file
    self.file_location = file_location
    self.encoding_dim, self.encoding_layer, self.max_seq_length = 768, 12, 1022
    # Defined up front so the attribute exists even before a file is parsed.
    self.sequence_length = 0

  def process_fasta_file(self):
    """Read the FASTA file into ``(id, sequence)`` pairs.

    Sequences longer than ``max_seq_length`` are truncated to that length.
    ``self.sequence_length`` is set to the (possibly truncated) length of
    the last record read, or 0 when the file holds no records.

    Returns:
      A tuple ``(msa, labels)``: ``msa`` is a list of ``(record.id,
      sequence)`` tuples and ``labels`` is the list of record ids in the
      same order.
    """
    msa = []
    self.sequence_length = 0
    # Single pass over the file: parsing twice is wasted work and yields
    # nothing the second time if the source is a one-shot handle.
    for record in SeqIO.parse(self.msa_fasta_file, "fasta"):
      sequence = str(record.seq)[:self.max_seq_length]
      self.sequence_length = len(sequence)
      msa.append((record.id, sequence))
    labels = [sequence_id for sequence_id, _ in msa]
    return msa, labels

  def _load_model(self):
    """Load the pretrained model in eval mode on ``self.device``."""
    model, alphabet = pretrained.load_model_and_alphabet(self.model_name)
    model.eval()
    return model.to(self.device), alphabet

  def _tokenize_msa(self, alphabet):
    """Parse the FASTA file and return its token tensor on ``self.device``.

    Raises:
      ValueError: If the FASTA file contains no sequences.
    """
    msa, _ = self.process_fasta_file()
    if not msa:
      raise ValueError(f"No sequences found in {self.msa_fasta_file!r}")
    batch_converter = alphabet.get_batch_converter()
    _, _, msa_tokens = batch_converter(msa)
    return msa_tokens.to(self.device)

  @staticmethod
  def _to_cpu(value):
    """Recursively move tensors (possibly nested in dicts) to the CPU."""
    if torch.is_tensor(value):
      return value.cpu()
    if isinstance(value, dict):
      return {key: Extracting._to_cpu(item) for key, item in value.items()}
    return value

  def get_plm_msa_embedding(self):
    """Save per-layer token representations of the MSA to a ``.pt`` file.

    Representations are requested for ``repr_layers`` 0 through
    ``encoding_layer - 1``. Each one is flattened over its leading
    dimensions to ``(-1, tokens, dim)`` and the first token position is
    dropped (presumably the prepended start token -- confirm against the
    alphabet). The result is a dict keyed by layer index, saved to
    ``file_location + protein_family + '_all_layers_' + model_name + '.pt'``.

    Raises:
      ValueError: If the FASTA file contains no sequences.
    """
    model, alphabet = self._load_model()
    output_file_name = self.file_location + self.protein_family + '_all_layers_' + self.model_name + '.pt'
    print("Total Tokens",len(alphabet.tok_to_idx))
    print("Model Name",model)
    print("Encoding Dim size",self.encoding_dim)
    print("Representation Layer",self.encoding_layer)

    msa_tokens = self._tokenize_msa(alphabet)
    layers = list(range(self.encoding_layer))
    # One forward pass yields every requested layer; running the model once
    # per layer recomputes the same activations each time.
    with torch.no_grad():
      out = model(msa_tokens, repr_layers=layers, return_contacts=False)

    plm_embedding = {}
    for layer in layers:
      representation = out["representations"][layer]
      # Derive the shape from the tensor itself rather than from
      # self.sequence_length / self.encoding_dim; reshape also copes with
      # non-contiguous model outputs.
      representation = representation.reshape(
          -1, representation.shape[-2], representation.shape[-1])
      # Stored on the CPU so the file can be loaded without a GPU.
      plm_embedding[layer] = representation[:, 1:, :].cpu()
      print(f"Finish extracting embeddings from layer {layer}.")

    torch.save(plm_embedding, output_file_name)
    print("Embeddings saved in output file:", output_file_name)

  def get_col_attention(self):
    """Save the model output with attention head weights to a ``.pt`` file.

    Runs the model once with ``need_head_weights=True`` and saves the whole
    output object to
    ``file_location + protein_family + '_attention_' + model_name + '.pt'``.

    Raises:
      ValueError: If the FASTA file contains no sequences.
    """
    col_file_name = self.file_location + self.protein_family + '_attention_' + self.model_name + '.pt'
    model, alphabet = self._load_model()
    msa_tokens = self._tokenize_msa(alphabet)
    # The call does not depend on any loop variable, so a single forward
    # pass gives the same result as repeating it once per layer.
    with torch.no_grad():
      out = model(msa_tokens, repr_layers=[self.encoding_layer], need_head_weights=True)
    torch.save(self._to_cpu(out), col_file_name)
    print("Embeddings saved in output file:", col_file_name)