# ESM2 vs. MSA Conservation Comparison

In [1]:
import os
import gzip
import numpy as np
import pandas as pd
import torch
from Bio import AlignIO

pd.set_option('display.max_columns', 100)

# Residue alphabet used to integer-encode the MSA: ids 0-19 are the 20 standard
# amino acids in alphabetical order, followed by X, Z, the gap '-' and B.
# NOTE(review): code elsewhere maps ESM2 argmax column indices through id2aa, which
# assumes the ESM2 table's (sorted) amino-acid columns are exactly these 20 letters.
# Both lookup tables are derived from this single string so they can never drift apart.
id2aa = dict(enumerate('ACDEFGHIKLMNPQRSTVWYXZ-B'))
aa2id = {aa: aa_id for aa_id, aa in id2aa.items()}

In [2]:
def read_esm2_conservation(esm2_conservation_path):
    """Load a gzipped ESM2 conservation CSV as a dense (position x amino acid) array.

    The file is read in long format (first column is the index, plus the columns
    'Position', 'Amino Acid' and 'Probability') and pivoted so that each row is a
    position and each column an amino acid, in the sorted order DataFrame.pivot
    produces.

    Returns:
        numpy.ndarray of probabilities; row i is the i-th position in sorted order.
    """
    with gzip.open(esm2_conservation_path, 'rt') as handle:
        long_table = pd.read_csv(handle, sep=',', index_col=0)
    wide_table = long_table.pivot(
        index='Position', columns='Amino Acid', values='Probability'
    )
    # The Position index is discarded by to_numpy(), so no reset_index is needed.
    return wide_table.to_numpy()


def read_msa(msa_path):
    """Read a FASTA multiple sequence alignment and integer-encode every record.

    Returns:
        (msa_matrix, record_ids): an int array of shape
        (n_records, alignment_length) whose row i is numeric_encode() of record i,
        and the list of record ids in file order.
    """
    with open(msa_path, 'r') as handle:
        alignment = AlignIO.read(handle, 'fasta')

    record_ids = [record.id for record in alignment]
    msa_matrix = np.zeros(
        (len(alignment), alignment.get_alignment_length()), dtype=int
    )
    # Each `row` is a view into msa_matrix, so writing into it fills the matrix.
    for row, record in zip(msa_matrix, alignment):
        row[:] = numeric_encode(str(record.seq))
    return msa_matrix, record_ids


def numeric_encode(seq, mapping=None):
    """Integer-encode one aligned sequence.

    Characters are upper-cased and '.' gaps are normalised to '-' before lookup.

    Args:
        seq: the sequence (anything whose str() is the sequence text).
        mapping: residue -> integer id table; defaults to the module-level aa2id.

    Returns:
        1-D torch.int64 tensor with one id per character (empty for an empty seq).

    Raises:
        KeyError: if a character is missing from ``mapping``; the message names the
            offending residue and its 1-based position.
    """
    if mapping is None:
        mapping = aa2id
    normalized = str(seq).upper().replace('.', '-')
    # Explicit int64 keeps the dtype stable across platforms and for empty input
    # (np.array([]) would otherwise default to float64).
    ids = np.empty(len(normalized), dtype=np.int64)
    for pos, residue in enumerate(normalized):
        try:
            ids[pos] = mapping[residue]
        except KeyError:
            raise KeyError(
                f'unsupported residue {residue!r} at sequence position {pos + 1}'
            ) from None
    return torch.from_numpy(ids)


def make_esm2_msa_df(msa_matrix, esm2_array, index, column_names=None, name_order=None):
    """Combine the MSA with ESM2 per-residue predictions for one target protein.

    Each output row is one alignment column. The target's standard residues
    (ids 0-19) are numbered 1..n in alignment order and paired, in order, with the
    first n rows of ``esm2_array``.

    Args:
        msa_matrix: 2-D int array (sequences x alignment columns), as from read_msa.
        esm2_array: 2-D array of ESM2 probabilities (target positions x amino
            acids), as from read_esm2_conservation; column j is assumed to be
            id2aa[j] — TODO confirm for each ESM2 table.
        index: row of ``msa_matrix`` holding the target protein.
        column_names: one name per MSA row; defaults to the module-level ``names``.
        name_order: names (subset of ``column_names``) emitted as residue columns,
            in order; defaults to ``ago_names + piwi_names + wago_names``.

    Returns:
        DataFrame with columns msa_pos, pos, AA, esm2_AA, conservation_score,
        count, followed by one residue-letter column per name in ``name_order``.

    Raises:
        ValueError: if ``esm2_array`` has fewer rows than the target has residues.
    """
    import warnings

    if column_names is None:
        column_names = names
    if name_order is None:
        name_order = ago_names + piwi_names + wago_names
    name_order = list(name_order)
    # Take the target's name from its own row instead of a global `target`
    # variable that may not correspond to `index`.
    target_name = column_names[index]

    n_cols = msa_matrix.shape[1]
    target_row = msa_matrix[index, :]
    # NOTE(review): only ids < 20 count as residues, so X/Z/B in the target are
    # numbered like gaps; confirm this matches how the ESM2 table was generated.
    msa_positions = np.flatnonzero(target_row < 20)
    n_residues = len(msa_positions)
    n_esm2 = esm2_array.shape[0]
    if n_esm2 < n_residues:
        raise ValueError(
            f'{target_name}: ESM2 table has {n_esm2} positions but the aligned '
            f'sequence has {n_residues} residues'
        )
    if n_esm2 > n_residues:
        warnings.warn(
            f'{target_name}: ESM2 table has {n_esm2} positions but the aligned '
            f'sequence has {n_residues} residues; extra positions are ignored'
        )
    esm2_rows = esm2_array[:n_residues]

    # Per-column values; NaN wherever the target has no standard residue.
    original_positions = np.full(n_cols, np.nan)
    esm2_AA_highest = np.full(n_cols, np.nan)
    conservation_score = np.full(n_cols, np.nan)
    if n_residues:
        # Computed once for all rows (the loop version recomputed per residue).
        original_positions[msa_positions] = np.arange(1, n_residues + 1)
        esm2_AA_highest[msa_positions] = np.argmax(esm2_rows, axis=1)
        conservation_score[msa_positions] = esm2_rows.max(axis=1)

    df = pd.DataFrame(msa_matrix.T, columns=column_names)
    df['msa_pos'] = np.arange(1, n_cols + 1)
    df['pos'] = original_positions
    df['AA'] = df[target_name].map(lambda x: id2aa[int(x)])
    df['esm2_AA'] = [id2aa[int(aa_i)] if not np.isnan(aa_i) else np.nan for aa_i in esm2_AA_highest]
    df['conservation_score'] = [
        '{:.3f}'.format(x) if not np.isnan(x) else x for x in conservation_score
    ]
    # Number of sequences (the target included) carrying the target's residue in
    # each column; NaN where the target has no standard residue.
    matches = (msa_matrix == target_row).sum(axis=0)
    df['count'] = pd.Series(matches, index=df.index).where(target_row < 20)
    for name in name_order:
        df[name] = df[name].map(lambda x: id2aa[int(x)])
    return df[['msa_pos', 'pos', 'AA', 'esm2_AA', 'conservation_score', 'count'] + name_order]

In [3]:
# Map FASTA record ids to common protein names. This is required, not optional:
# every record id in the MSA must be a key here, and the group lists below
# (which set the output column order) refer to these common names.

id2name = {
    'sp|O61931|ERGO1_CAEEL': 'CeERGO-1', 'sp|Q8CJG0|AGO2_MOUSE': 'MmAgo2', 'sp|Q9QZ81|AGO2_RAT': 'RnAgo2', 
    'tr|G5EEH0|G5EEH0_CAEEL': 'CeRDE-1', 'tr|A0A8V0Y222|A0A8V0Y222_CHICK': 'GgAgo2', 'tr|G5EES3|G5EES3_CAEEL': 'CeALG-1', 
    'sp|Q9H9G7|AGO3_HUMAN': 'HsAgo3', 'sp|Q9UKV8|AGO2_HUMAN': 'HsAgo2', 'tr|O16720|O16720_CAEEL': 'CeALG-2', 
    'sp|Q9UL18|AGO1_HUMAN': 'HsAgo1', 'tr|Q32KD4|Q32KD4_DROME': 'DmAgo1', 'tr|G5EC94|G5EC94_CAEEL': 'CeALG-3', 
    'sp|Q9HCK5|AGO4_HUMAN': 'HsAgo4', 'sp|P34681|TAG76_CAEEL': 'CeALG-4', 'tr|Q9XVI3|Q9XVI3_CAEEL': 'CeALG-5', 
    'sp|Q746M7|AGO_THET2': 'TtAgo', 'sp|Q9SHF3|AGO2_ARATH': 'AtAgo', 'tr|A0A8U0S055|A0A8U0S055_MUSPF': 'MputfAgo2',
    'sp|Q7Z3Z4|PIWL4_HUMAN': 'HsPIWIL4', 'tr|P90786|P90786_CAEEL': 'CePRG-1', 'sp|Q96J94|PIWL1_HUMAN': 'HsPIWIL1', 
    'sp|Q8TC59|PIWL2_HUMAN': 'HsPIWIL2', 'sp|Q7Z3Z3|PIWL3_HUMAN': 'HsPIWIL3', 'sp|Q9VKM1|PIWI_DROME': 'DmPIWI', 
    'sp|A8D8P8|SIWI_BOMMO': 'BmSIWI',
    'tr|A0A0U1RML5|A0A0U1RML5_CAEEL': 'CeSAGO-2', 'tr|A0A0T7CIX3|A0A0T7CIX3_CAEEL': 'CeSAGO-1', 'tr|A8XRG0|A8XRG0_CAEBR': 'CbrCSR',
    'sp|Q09249|YQ53_CAEEL': 'CeHRDE-1', 'tr|Q9XVF1|Q9XVF1_CAEEL': 'CeVSRA-1', 'tr|Q9TXN7|Q9TXN7_CAEEL': 'CeWAGO-10', 
    'tr|E3M6J3|E3M6J3_CAERE': 'CreCSR', 'tr|H2KZD5|H2KZD5_CAEEL': 'CeCSR-1a', 'sp|Q21691|NRDE3_CAEEL': 'CeNRDE-3', 
    'tr|A0A2G5U890|A0A2G5U890_9PELO': 'CniCSR', 'tr|A8WQA0|A8WQA0_CAEBR': 'CbrHRDE-1', 'tr|Q86NJ8|Q86NJ8_CAEEL': 'CePPW-1', 
    'tr|Q9N585|Q9N585_CAEEL': 'CePPW-2', 'sp|Q21770|WAGO1_CAEEL': 'CeWAGO-1', 'sp|O62275|WAGO4_CAEEL': 'CeWAGO-4'}

# Protein groups; their concatenation (ago + piwi + wago) is the order of the
# per-protein residue columns in the output tables.
ago_names = ['AtAgo', 'CeALG-1', 'CeALG-2', 'CeALG-3', 'CeALG-4', 
             'CeALG-5', 'CeERGO-1', 'CeRDE-1', 'DmAgo1', 'GgAgo2', 
             'HsAgo1', 'HsAgo2', 'HsAgo3', 'HsAgo4', 'MmAgo2', 
             'MputfAgo2', 'RnAgo2', 'TtAgo']
piwi_names = ['BmSIWI', 'CePRG-1', 'DmPIWI', 'HsPIWIL1', 'HsPIWIL2', 
              'HsPIWIL3', 'HsPIWIL4']
wago_names = ['CbrCSR', 'CbrHRDE-1', 'CeCSR-1a', 'CeHRDE-1', 'CeNRDE-3', 
              'CePPW-1', 'CePPW-2', 'CeSAGO-1', 'CeSAGO-2', 'CeVSRA-1', 
              'CeWAGO-1', 'CeWAGO-10', 'CeWAGO-4', 'CniCSR', 'CreCSR']

# The group lists are maintained by hand; fail fast if they drift out of sync
# with id2name (a missing or misspelled name would otherwise surface later as a
# KeyError deep inside the DataFrame construction).
_grouped_names = ago_names + piwi_names + wago_names
assert len(set(id2name.values())) == len(id2name), 'duplicate common name in id2name'
assert len(set(_grouped_names)) == len(_grouped_names), 'duplicate name across ago/piwi/wago groups'
assert set(_grouped_names) == set(id2name.values()), 'ago/piwi/wago groups out of sync with id2name'
del _grouped_names

# Execution (single run)

In [4]:
# Specify your input: the target's common name (a value of id2name) and file paths.
target = 'CeVSRA-1'
esm2_conservation_path = (
    f'/home/moon/projects/AgoAnalysis/esm2/{target}/'
    f'{target}_conservation_esm2_t36_3B_UR50D.csv.gz'
)

msa_path = '/home/moon/projects/AgoAnalysis/msa/Argonaute_all.msa.fasta'

In [5]:
# Encode the alignment and translate FASTA ids to common names
# (every record id must be a key of id2name).
msa_matrix, record_ids = read_msa(msa_path)
names = [id2name[rid] for rid in record_ids]

# ESM2 probabilities for the target: one row per position.
esm2_array = read_esm2_conservation(esm2_conservation_path)

# Row of the target protein within the alignment.
index = names.index(target)

# Merge the ESM2 predictions with the alignment and write the table out.
esm2_msa_df = make_esm2_msa_df(msa_matrix, esm2_array, index)
outpath = (
    '/home/moon/projects/AgoAnalysis/esm2/'
    f'{target}/esm2-conservation_msa_comparison_{target}.csv'
)
esm2_msa_df.to_csv(outpath, index=False)

# Execution (multiple runs)

In [6]:
# Batch mode: one shared MSA and every grouped protein as a target, in group order.
msa_path = '/home/moon/projects/AgoAnalysis/msa/Argonaute_all.msa.fasta'
targets = [*ago_names, *piwi_names, *wago_names]

In [9]:
# Encode the alignment once; every target below reuses it.
msa_matrix, record_ids = read_msa(msa_path)
names = [id2name[record_id] for record_id in record_ids]

esm2_dir = '/home/moon/projects/AgoAnalysis/esm2'
for target in targets:
    esm2_conservation_path = f'{esm2_dir}/{target}/{target}_conservation_esm2_t36_3B_UR50D.csv.gz'
    if not os.path.exists(esm2_conservation_path):
        print(f'{target} does not have ESM2 conservation data. Skipping.')
        continue
    # names.index() would raise ValueError and abort the whole batch; skip instead.
    if target not in names:
        print(f'{target} is not present in the MSA. Skipping.')
        continue
    print(f'{target} is being processed.')

    esm2_array = read_esm2_conservation(esm2_conservation_path)
    index = names.index(target)
    esm2_msa_df = make_esm2_msa_df(msa_matrix, esm2_array, index)
    outpath = f'{esm2_dir}/{target}/esm2-conservation_msa_comparison_{target}.csv'
    esm2_msa_df.to_csv(outpath, index=False)

AtAgo is being processed.
CeALG-1 is being processed.
CeALG-2 is being processed.
CeALG-3 does not have ESM2 conservation data. Skipping.
CeALG-4 does not have ESM2 conservation data. Skipping.
CeALG-5 does not have ESM2 conservation data. Skipping.
CeERGO-1 does not have ESM2 conservation data. Skipping.
CeRDE-1 does not have ESM2 conservation data. Skipping.
DmAgo1 does not have ESM2 conservation data. Skipping.
GgAgo2 does not have ESM2 conservation data. Skipping.
HsAgo1 is being processed.
HsAgo2 is being processed.
HsAgo3 does not have ESM2 conservation data. Skipping.
HsAgo4 does not have ESM2 conservation data. Skipping.
MmAgo2 does not have ESM2 conservation data. Skipping.
MputfAgo2 does not have ESM2 conservation data. Skipping.
RnAgo2 does not have ESM2 conservation data. Skipping.
TtAgo is being processed.
BmSIWI is being processed.
CePRG-1 is being processed.
DmPIWI is being processed.
HsPIWIL1 does not have ESM2 conservation data. Skipping.
HsPIWIL2 is being processed.
H