# Deciphering the Black Box: Mastering the MSA Transformer for Phylogenetic Tree Construction

## Install and import packages ##

In [2]:
# Install dependencies into the running kernel's environment (%pip targets the kernel; !pip may not).
# Biopython is published on PyPI as `biopython`; `Bio` is only its import name.
%pip install --quiet fair-esm
%pip install --quiet transformers
%pip install --quiet pysam
%pip install --quiet biopython
%pip install --quiet ete3
%pip install --quiet dendropy

Collecting dendropy
  Using cached DendroPy-4.6.1-py3-none-any.whl (458 kB)
Installing collected packages: dendropy
Successfully installed dendropy-4.6.1


In [3]:
import csv
import os
import random
import string
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.data as data
from Bio import SeqIO
from ete3 import Tree
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, MSATransformer, pretrained
from pysam import FastaFile, FastxFile
from scipy import stats
from torch.utils.data import TensorDataset, Dataset

## 1. Generate four kinds of MSAs ##

In [11]:
# Input/output directories, relative to the notebook's working directory.
EMB_PATH = './Embeddings/'
ATTN_PATH = './Attentions/'
MSA_PATH = './Msa/'
# NOTE(review): TREE_PATH is not referenced in the cells shown; EvDist reads trees from its own directory.
TREE_PATH = './Trees/'

# MSA variant code -> file-name suffix of the alignment on disk:
#   "no"      the unshuffled seed alignment
#   "sc"      one column permutation applied to every sequence
#   "sa"      an independent permutation for every sequence
#   "default" residues mixed within each column
MSA_TYPE_MAP = dict(
    no="_seed_hmmalign_no_inserts.fasta",
    sc="_shuffle_column.fasta",
    sa="_shuffle_all.fasta",
    default="_mix_column.fasta",
)

# Same variant codes -> infix used in embedding file names.
EMB_TYPE_MAP = dict(
    no="_emb_no_shuffle_",
    sc="_emb_shuffle_column_",
    sa="_emb_shuffle_all_",
    default="_emb_mix_column_",
)

# Same variant codes -> infix used in column-attention file names.
ATTN_TYPE_MAP = dict(
    no="_attn_no_shuffle_",
    sc="_attn_shuffle_column_",
    sa="_attn_shuffle_all_",
    default="_attn_mix_column_",
)

In [8]:
class ChangeAA:
    """Generate shuffled variants of a protein family's seed MSA.

    Input is read from ``{MSA_PATH}{protein_family}_seed_hmmalign_no_inserts.fasta``
    and every output is written into ``MSA_PATH``.

    Parameters
    ----------
    protein_family : str
        Family identifier used to build input/output file names.
    seed : int, optional
        When given, all shuffles use a private ``random.Random(seed)`` so the
        generated MSAs are reproducible. When ``None`` (default) the global
        ``random`` module is used, exactly as before.
    """

    def __init__(self, protein_family, seed=None):
        self.protein_family = protein_family
        self.msa_file = f'{MSA_PATH}{self.protein_family}_seed_hmmalign_no_inserts.fasta'
        # Both the `random` module and `random.Random` expose shuffle()/sample().
        self._rng = random if seed is None else random.Random(seed)

    def _shuffle_list(self, data_list):
        """Shuffle ``data_list`` in place and return it.

        A list holding a single distinct value is returned untouched without
        drawing from the RNG.
        """
        if len(set(data_list)) == 1:
            return data_list
        self._rng.shuffle(data_list)
        return data_list

    def mix_fasta_column(self):
        """Shuffle residues within each alignment column across sequences.

        Writes ``{MSA_PATH}{protein_family}_mix_column.fasta``.

        Raises
        ------
        ValueError
            If the MSA is empty or its sequences have different lengths.
        """
        from Bio.Seq import Seq  # local import: only this method builds Seq objects

        output_file = f'{MSA_PATH}{self.protein_family}_mix_column.fasta'
        records = list(SeqIO.parse(self.msa_file, "fasta"))
        if not records:
            raise ValueError(f'No sequences found in {self.msa_file}')
        seq_length = len(records[0].seq)
        if any(len(record.seq) != seq_length for record in records):
            raise ValueError(f'Sequences in {self.msa_file} do not all have length {seq_length}')

        # Work on a mutable character grid: rebuilding each Seq by slicing for
        # every column was O(L^2 * N).
        rows = [list(str(record.seq)) for record in records]
        for i in range(seq_length):
            column = self._shuffle_list([row[i] for row in rows])
            for row, residue in zip(rows, column):
                row[i] = residue

        for record, row in zip(records, rows):
            record.seq = Seq(''.join(row))

        SeqIO.write(records, output_file, "fasta")
        print('Generated Mix columns data!')

    def _read_fasta(self):
        """Read ``self.msa_file`` as a flat list alternating headers and sequences.

        Returns ``['>id1', 'SEQ1', '>id2', 'SEQ2', ...]`` where each sequence is the
        concatenation of its (right-stripped) lines. A header with no sequence
        lines yields ``''`` so the alternation is always preserved; text before
        the first header is ignored.
        """
        sequences = []
        chunks = None  # residue lines of the current record; None until the first header
        with open(self.msa_file, 'r') as file:
            for raw_line in file:
                line = raw_line.rstrip()
                if line.startswith('>'):
                    if chunks is not None:
                        sequences.append(''.join(chunks))
                    sequences.append(line)
                    chunks = []
                elif chunks is not None:
                    chunks.append(line)
        if chunks is not None:
            sequences.append(''.join(chunks))
        return sequences

    @staticmethod
    def _write_outputs(seq_file, order_file, sequences, orders):
        """Write FASTA lines to ``seq_file`` and one permutation per CSV row to ``order_file``."""
        with open(seq_file, 'w') as file:
            file.write('\n'.join(sequences))
        # newline='' is what the csv module requires to avoid blank rows on Windows.
        with open(order_file, 'w', newline='') as file:
            csv.writer(file).writerows(orders)

    def shuffle_fasta_all(self):
        """Independently permute the residues of every sequence.

        Writes the shuffled MSA to ``{MSA_PATH}{protein_family}_shuffle_all.fasta``
        and each sequence's permutation (one CSV row per sequence) to
        ``{MSA_PATH}{protein_family}_shuffle_all_order.txt``.
        """
        output_seq_file = f'{MSA_PATH}{self.protein_family}_shuffle_all.fasta'
        output_order_file = f'{MSA_PATH}{self.protein_family}_shuffle_all_order.txt'

        shuffled_sequences = []
        shuffled_order = []
        for sequence in self._read_fasta():
            if sequence.startswith('>'):
                shuffled_sequences.append(sequence)
                continue
            # Use each sequence's own length so a ragged input cannot index out of range.
            length = len(sequence)
            shuffled_indices = self._rng.sample(range(length), length)
            shuffled_order.append(shuffled_indices)
            shuffled_sequences.append(''.join(sequence[i] for i in shuffled_indices))

        self._write_outputs(output_seq_file, output_order_file, shuffled_sequences, shuffled_order)
        print('Generated Shuffle all data!')

    def shuffle_fasta_column(self):
        """Apply one random column permutation to every sequence.

        Writes the shuffled MSA to ``{MSA_PATH}{protein_family}_shuffle_column.fasta``
        and the single permutation (one CSV row) to
        ``{MSA_PATH}{protein_family}_shuffle_column_order.txt``.
        """
        output_seq_file = f'{MSA_PATH}{self.protein_family}_shuffle_column.fasta'
        output_order_file = f'{MSA_PATH}{self.protein_family}_shuffle_column_order.txt'
        sequences = self._read_fasta()
        # sequences[1] is the first sequence; an MSA's sequences share its length.
        sequence_length = len(sequences[1])

        column_order = self._rng.sample(range(sequence_length), sequence_length)
        shuffled_sequences = [
            sequence if sequence.startswith('>') else ''.join(sequence[i] for i in column_order)
            for sequence in sequences
        ]

        self._write_outputs(output_seq_file, output_order_file, shuffled_sequences, [column_order])
        print('Generated Shuffle columns data!')

## 2. Generate four kinds of embeddings and attention maps

In [9]:
# Translation table that deletes lowercase letters, '.' and '*' (built once, reused per call).
_INSERTION_DELETE_TABLE = str.maketrans(dict.fromkeys(string.ascii_lowercase + ".*"))


def remove_insertions(sequence):
    """Remove any insertions from an aligned sequence.

    Deletes every lowercase letter and the characters '.' and '*'; all other
    characters (uppercase residues, '-' gaps, ...) are kept. Needed to load
    aligned sequences in an MSA.

    Parameters
    ----------
    sequence : str
        One aligned sequence.

    Returns
    -------
    str
        ``sequence`` with the characters above removed.
    """
    return sequence.translate(_INSERTION_DELETE_TABLE)

class Extractor:
    """Extract per-layer embeddings and column attentions from the ESM MSA Transformer.

    Parameters
    ----------
    protein_family : str
        Family identifier used to build input/output file names.
    msa_type : str
        One of the keys of ``MSA_TYPE_MAP``; any other value falls back to ``"default"``.
    """

    def __init__(self, protein_family, msa_type):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = "esm_msa1b_t12_100M_UR50S"
        # Hidden size, number of layers extracted, and max aligned length (not enforced here).
        self.encoding_dim, self.encoding_layer, self.max_seq_length = 768, 12, 1022
        self.protein_family = protein_family
        self.msa_type = msa_type if msa_type in MSA_TYPE_MAP else "default"
        self.msa_fasta_file = f'{MSA_PATH}{protein_family}{MSA_TYPE_MAP[self.msa_type]}'

    def read_msa(self):
        """Return ``[(description, sequence_without_insertions), ...]`` for the MSA file."""
        return [(record.description, remove_insertions(str(record.seq)))
                for record in SeqIO.parse(self.msa_fasta_file, "fasta")]

    def _load_model_and_tokens(self):
        """Load the pretrained model in eval mode on ``self.device`` and tokenize the MSA there."""
        model, alphabet = pretrained.load_model_and_alphabet(self.model_name)
        model = model.eval().to(self.device)
        batch_converter = alphabet.get_batch_converter()
        _, _, msa_tokens = batch_converter([self.read_msa()])
        return model, msa_tokens.to(self.device)

    @staticmethod
    def _to_cpu(obj):
        """Recursively move tensors (possibly nested in dicts/lists/tuples) to the CPU."""
        if torch.is_tensor(obj):
            return obj.cpu()
        if isinstance(obj, dict):
            return {key: Extractor._to_cpu(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(Extractor._to_cpu(value) for value in obj)
        return obj

    def get_embedding(self):
        """Save the representations of layers ``0 .. encoding_layer-1`` to ``EMB_PATH``.

        The output ``.pt`` file holds ``{layer: tensor}``; each tensor has the
        leading start token removed and is stored on the CPU.
        """
        model, msa_tokens = self._load_model_and_tokens()
        emb = f'{EMB_PATH}{self.protein_family}{EMB_TYPE_MAP[self.msa_type]}{self.model_name}.pt'
        # Tokens per row (aligned length + the start token that is stripped below).
        token_len = msa_tokens.size(-1)
        layers = list(range(self.encoding_layer))

        # One forward pass returns every requested layer (previously one pass per layer).
        with torch.no_grad():
            out = model(msa_tokens, repr_layers=layers, return_contacts=False)

        plm_embedding = {}
        for layer in layers:
            token_representations = out["representations"][layer].view(-1, token_len, self.encoding_dim)
            # remove the start token
            token_representations = token_representations[:, 1:, :]
            print(f"Finish extracting embeddings from layer {layer}.")
            plm_embedding[layer] = token_representations.cpu()

        torch.save(plm_embedding, emb)
        print("Embeddings saved in output file:", emb)

    def get_col_attention(self):
        """Run the model with head weights and save the full (CPU) result dict to ``ATTN_PATH``."""
        model, msa_tokens = self._load_model_and_tokens()
        attn = f'{ATTN_PATH}{self.protein_family}{ATTN_TYPE_MAP[self.msa_type]}{self.model_name}.pt'

        with torch.no_grad():
            results = model(msa_tokens, repr_layers=[self.encoding_layer], need_head_weights=True)

        # Save on the CPU so the file loads on machines without a GPU.
        torch.save(self._to_cpu(results), attn)
        print("Column attention saved in output file:", attn)

## 3. Calculate evolutionary distances for each protein domain ##

In [10]:
# Layer count and heads-per-layer of esm_msa1b_t12_100M_UR50S, used by the correlation loops below.
LAYER, HEAD = 12, 12

class EvDist:
    """Correlate tree-based evolutionary distances with MSA Transformer outputs.

    Parameters
    ----------
    protein_family : str
        Family identifier used to build all input/output file names.
    msa_type : str
        One of the keys of ``MSA_TYPE_MAP``; any other value falls back to ``"default"``.
    tree_dir : str, optional
        Directory containing ``{protein_family}.tree`` (read with ``ete3.Tree``).
        Defaults to the original Google Drive location.
    """

    def __init__(self, protein_family, msa_type, tree_dir='/content/drive/MyDrive/PhD/tree'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = "esm_msa1b_t12_100M_UR50S"
        self.protein_family = protein_family
        self.msa_type = msa_type if msa_type in MSA_TYPE_MAP else "default"
        self.msa_fasta_file = f'{MSA_PATH}{protein_family}{MSA_TYPE_MAP[self.msa_type]}'
        self.emb = f'{EMB_PATH}{self.protein_family}{EMB_TYPE_MAP[self.msa_type]}{self.model_name}.pt'
        self.attn = f'{ATTN_PATH}{self.protein_family}{ATTN_TYPE_MAP[self.msa_type]}{self.model_name}.pt'
        self.tree = os.path.join(tree_dir, f"{self.protein_family}.tree")

    @staticmethod
    def euc_distance(a, b):
        """Calculate Euclidean distance between two points."""
        return np.sqrt(np.sum((a - b)**2))

    def evolutionary_distance(self, phylo_tree, seq_labels):
        """Return the symmetric (M, M) matrix of tree distances between ``seq_labels``.

        Parameters
        ----------
        phylo_tree : ete3.Tree or None
            Tree whose leaves are named by ``seq_labels``; ``None`` loads ``self.tree``.
        seq_labels : list of str
            Leaf names, in the row/column order of the result.
        """
        if phylo_tree is None:
            phylo_tree = Tree(self.tree)
        # Resolve each leaf once instead of once per pair.
        nodes = [phylo_tree & name for name in seq_labels]
        n_seqs = len(nodes)
        ev_distances = np.zeros((n_seqs, n_seqs))
        for i in range(n_seqs):
            for j in range(i + 1, n_seqs):
                distance = nodes[i].get_distance(nodes[j])
                ev_distances[i, j] = distance
                ev_distances[j, i] = distance
        return ev_distances

    def pairwise_euclidean_distance(self, emb):
        """Return the symmetric (M, M) Euclidean distance matrix between the M items of ``emb``.

        Each ``emb[i]`` may have any shape and is flattened before comparison,
        so both (M, D) and (M, L, D) inputs are accepted.
        """
        from scipy.spatial.distance import pdist, squareform

        emb = np.asarray(emb, dtype=float)
        n_items = emb.shape[0]
        if n_items < 2:
            return np.zeros((n_items, n_items))
        return squareform(pdist(emb.reshape(n_items, -1), metric='euclidean'))

    def _load_ev_distances(self):
        """Read sequence ids from the MSA file and return their tree-distance matrix."""
        sequence_list = [record.id.split(' ')[0]
                         for record in SeqIO.parse(self.msa_fasta_file, "fasta")]
        return self.evolutionary_distance(Tree(self.tree), sequence_list)

    @staticmethod
    def _write_csv(path, header, rows):
        """Write ``header`` then ``rows`` to ``path`` as CSV, creating its directory if needed."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(header)
            writer.writerows(rows)

    def compute_embedding_correlation(self):
        """
        Calculate the Spearman correlation between evolutionary distances from the
        tree and pairwise Euclidean distances of mean-pooled embeddings, per layer.

        Only the strict upper triangle is used: each unordered sequence pair is
        counted once and the trivial self-pairs (0 vs 0) are excluded, which would
        otherwise inflate both the correlation and its significance.
        Writes ``./Results/{protein_family}_ev_and_euclidean_analysis.csv``.
        """
        output_file = os.path.join('./Results', f"{self.protein_family}_ev_and_euclidean_analysis.csv")

        embeddings = torch.load(self.emb, map_location='cpu')
        ev_distances = self._load_ev_distances()
        upper = np.triu_indices(ev_distances.shape[0], k=1)
        ev_upper = ev_distances[upper]

        correlations = []
        for layer in range(LAYER):
            # Mean over axis 1 gives one vector per sequence (assumes (M, L, D) per layer).
            euc_distances = self.pairwise_euclidean_distance(embeddings[layer].mean(1))
            rho, pvalue = stats.spearmanr(ev_upper, euc_distances[upper])
            correlations.append([self.protein_family, layer, rho, pvalue])

        self._write_csv(output_file, ['Protein family', 'Layer', 'Correlation', 'P value'], correlations)

    def compute_attention_correlation(self):
        """
        Calculate the Spearman correlation between evolutionary distances from the
        tree and symmetrized, column-averaged attention, per layer and head.

        Uses the upper triangle including the diagonal (M * (M + 1) / 2 pairs).
        Writes ``./Results/{protein_family}_ev_and_attention_analysis.csv``.
        """
        output_file = os.path.join('./Results', f"{self.protein_family}_ev_and_attention_analysis.csv")

        ev_distances = self._load_ev_distances()

        results = torch.load(self.attn, map_location='cpu')
        # Index 0 drops the batch axis; 1: on the column axis removes the start token,
        # which is then averaged away (assumes [batch, layer, head, column, row, row]).
        col_attn = results["col_attentions"].cpu().numpy()[0, :, :, 1:, :, :].mean(axis=2)
        col_attn = col_attn + col_attn.transpose(0, 1, 3, 2)  # symmetrize row x row maps
        tri_indices = np.triu_indices(col_attn.shape[-1])
        attn_upper = col_attn[..., tri_indices[0], tri_indices[1]]  # (layers, heads, M * (M+1) / 2)
        ev_upper = ev_distances[tri_indices]

        spear_evdist_corr = []
        for layer in range(attn_upper.shape[0]):
            for head in range(attn_upper.shape[1]):
                rho, pvalue = stats.spearmanr(ev_upper, attn_upper[layer, head])
                spear_evdist_corr.append([self.protein_family, layer, head, rho, pvalue])

        self._write_csv(output_file, ['Pfam id', 'Layer', 'Head', 'Correlation', 'Pvalue'], spear_evdist_corr)
        