In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gzip
from Bio import SeqIO
from logomaker import transform_matrix
from pssm_analysis import plot_logomaker 

pd.set_option('display.max_columns', 100)

# One-letter amino-acid alphabet: the 20 standard residues followed by the gap symbol.
AA_ALPHABETS = "ACDEFGHIKLMNPQRSTVWY-"

# Integer code -> one-letter symbol. Codes 0-19 are the 20 standard residues in the
# same alphabetical order as AA_ALPHABETS; 20 = X, 21 = Z, 22 = gap '-', 23 = B.
id2aa = dict(enumerate("ACDEFGHIKLMNPQRSTVWYXZ-B"))

In [None]:
### Specify the sequence and the ESM2 model used
# Protein short name (used to build file paths below), its accession, and the ESM2 checkpoint name.
name, gene_id, model = 'CeVSRA-1', 'Q9XVF1', 'esm2_t36_3B_UR50D'

# Conservation

In [None]:
esm2_conservation_path = f'/home/moon/projects/AgoAnalysis/esm2/{name}/{name}_conservation_{model}.csv.gz'

# Long-format table with 'Position', 'Amino Acid' and 'Probability' columns.
with gzip.open(esm2_conservation_path, 'rt') as f:
    esm2_df = pd.read_csv(f, sep=',', index_col=0)

# Reshape to one row per position and one column per amino acid, numbered from 1.
esm2_pivot_df = (
    esm2_df
    .pivot(index='Position', columns='Amino Acid', values='Probability')
    .reset_index(drop=True)
)
esm2_pivot_df.index = range(1, len(esm2_pivot_df) + 1)  # Convert from 0-index to 1-index

# Uniform background over 20 amino acids for the information-content transform.
background_vals = np.full(20, 1/20)
esm2_ic_df = transform_matrix(
    esm2_pivot_df,
    from_type='probability',
    to_type='information',
    background=background_vals,
)

In [None]:
title = f'Probabilities in ESM2 Conservation of {name}'
# First 20 columns only — presumably the 20 standard amino acids (TODO confirm column set).
standard_aa_probs = esm2_pivot_df.iloc[:, :20]
plot_logomaker(standard_aa_probs, title=title, ylim=1, color_name='charge')
plt.savefig(f'/home/moon/projects/AgoAnalysis/esm2/{name}/{name}.{model}.pssm_logo.png')
plt.show()

In [None]:
title = f'Information Content in ESM2 Conservation of {name}'
# First 20 columns only — presumably the 20 standard amino acids (TODO confirm column set).
# The PDF version of this logo is written by the per-protein loop below.
standard_aa_ic = esm2_ic_df.iloc[:, :20]
plot_logomaker(standard_aa_ic, title=title, color_name='charge')
plt.show()

In [None]:
### Loop through all the proteins: save an information-content logo (PDF) for each one
model = 'esm2_t36_3B_UR50D'
names = ['CeVSRA-1', 'CeHRDE-1', 'CePRG-1', 'HsAgo2', 'HsPIWIL2', 
         'CeCSR-1a', 'MIWI', 'BmSIWI', 'AtAgo', 'CeALG-2', 
         'TtAgo', 'CeALG-1', 'DmPIWI', 'CeSAGO-1', 'PfAgo', 'HsAgo1']

background_vals = np.array([1/20]*20)


def _save_info_logo(protein, model, background):
    """Render the ESM2 information-content logo for one protein and save it as a PDF.

    Parameters
    ----------
    protein : str
        Short protein name; also the sub-directory under the esm2 results folder.
    model : str
        ESM2 checkpoint name embedded in the input/output file names.
    background : array-like
        Background probabilities passed to ``transform_matrix``.

    Proteins whose PDF already exists are skipped silently; proteins without a
    conservation table are skipped with a message. All intermediates are local so
    the loop does not overwrite notebook globals such as ``name``, which the
    coevolution cells read later.
    """
    protein_dir = f'/home/moon/projects/AgoAnalysis/esm2/{protein}'
    outfile = f'{protein_dir}/{protein}.{model}.info_logo.pdf'
    if os.path.exists(outfile):
        return

    conservation_path = f'{protein_dir}/{protein}_conservation_{model}.csv.gz'
    if not os.path.exists(conservation_path):
        print(f'{protein} does not have ESM2 conservation data. Skipping.')
        return

    with gzip.open(conservation_path, 'rt') as handle:
        long_df = pd.read_csv(handle, sep=',', index_col=0)
    prob_df = (
        long_df
        .pivot(index='Position', columns='Amino Acid', values='Probability')
        .reset_index(drop=True)
    )
    prob_df.index = prob_df.index + 1  # Convert from 0-index to 1-index
    ic_df = transform_matrix(prob_df, from_type='probability', to_type='information', background=background)

    title = f'Information Content in ESM2 Conservation of {protein}'
    plot_logomaker(ic_df.iloc[:, 0:20], title=title, color_name='charge')
    plt.savefig(outfile)
    # Close the current figure so logos neither accumulate in memory nor get drawn
    # on top of one another across iterations.
    plt.close()


for protein in names:
    _save_info_logo(protein, model, background_vals)

## Conservation: ESM2 predictions vs. the multiple sequence alignment (MSA)

In [33]:
# MSA record ID -> short display name, one map per Argonaute family.
ago_id2name = {
    'sp|O61931|ERGO1_CAEEL': 'CeERGO-1',
    'sp|Q8CJG0|AGO2_MOUSE': 'MmAgo2',
    'sp|Q9QZ81|AGO2_RAT': 'RnAgo2',
    'tr|G5EEH0|G5EEH0_CAEEL': 'CeRDE-1',
    'tr|A0A8V0Y222|A0A8V0Y222_CHICK': 'GgAgo2',
    'tr|G5EES3|G5EES3_CAEEL': 'CeALG-1',
    'sp|Q9H9G7|AGO3_HUMAN': 'HsAgo3',
    'sp|Q9UKV8|AGO2_HUMAN': 'HsAgo2',
    'tr|O16720|O16720_CAEEL': 'CeALG-2',
    'sp|Q9UL18|AGO1_HUMAN': 'HsAgo1',
    'tr|Q32KD4|Q32KD4_DROME': 'DmAgo1',
    'tr|G5EC94|G5EC94_CAEEL': 'CeALG-3',
    'sp|Q9HCK5|AGO4_HUMAN': 'HsAgo4',
    'sp|P34681|TAG76_CAEEL': 'CeALG-4',
    'tr|Q9XVI3|Q9XVI3_CAEEL': 'CeALG-5',
    'sp|Q746M7|AGO_THET2': 'TtAgo',
    'sp|Q9SHF3|AGO2_ARATH': 'AtAgo',
    'tr|A0A8U0S055|A0A8U0S055_MUSPF': 'MputfAgo2',
}
wago_id2name = {
    'tr|A0A0U1RML5|A0A0U1RML5_CAEEL': 'CeSAGO-2',
    'tr|A0A0T7CIX3|A0A0T7CIX3_CAEEL': 'CeSAGO-1',
    'tr|A8XRG0|A8XRG0_CAEBR': 'CbrCSR',
    'sp|Q09249|YQ53_CAEEL': 'CeHRDE-1',
    'tr|Q9XVF1|Q9XVF1_CAEEL': 'CeVSRA-1',
    'tr|Q9TXN7|Q9TXN7_CAEEL': 'CeWAGO-10',
    'tr|E3M6J3|E3M6J3_CAERE': 'CreCSR',
    'tr|H2KZD5|H2KZD5_CAEEL': 'CeCSR-1a',
    'sp|Q21691|NRDE3_CAEEL': 'CeNRDE-3',
    'tr|A0A2G5U890|A0A2G5U890_9PELO': 'CniCSR',
    'tr|A8WQA0|A8WQA0_CAEBR': 'CbrHRDE-1',
    'tr|Q86NJ8|Q86NJ8_CAEEL': 'CePPW-1',
    'tr|Q9N585|Q9N585_CAEEL': 'CePPW-2',
    'sp|Q21770|WAGO1_CAEEL': 'CeWAGO-1',
    'sp|O62275|WAGO4_CAEEL': 'CeWAGO-4',
}
piwi_id2name = {
    'sp|Q7Z3Z4|PIWL4_HUMAN': 'HsPIWIL4',
    'tr|P90786|P90786_CAEEL': 'CePRG-1',
    'sp|Q96J94|PIWL1_HUMAN': 'HsPIWIL1',
    'sp|Q8TC59|PIWL2_HUMAN': 'HsPIWIL2',
    'sp|Q7Z3Z3|PIWL3_HUMAN': 'HsPIWIL3',
    'sp|Q9VKM1|PIWI_DROME': 'DmPIWI',
    'sp|A8D8P8|SIWI_BOMMO': 'BmSIWI',
}

# Combined lookup across all families (AGO first, then WAGO, then PIWI).
id2name = {**ago_id2name, **wago_id2name, **piwi_id2name}

In [34]:
# Family membership lists in alphabetical order, derived from the record-ID -> name maps
# so the two cannot drift apart. Their concatenation fixes the sequence-column order of
# the comparison CSVs written below.
ago_names = sorted(ago_id2name.values())
piwi_names = sorted(piwi_id2name.values())
wago_names = sorted(wago_id2name.values())

In [49]:
### Load esm2 conservation data

from Bio import AlignIO
from pssm_analysis import numeric_encode

target, model = 'CeVSRA-1', 'esm2_t36_3B_UR50D'

esm2_conservation_path = f'/home/moon/projects/AgoAnalysis/esm2/{target}/{target}_conservation_{model}.csv.gz'

# Read the long-format ESM2 table and reshape it to a (position x amino acid) probability array.
with gzip.open(esm2_conservation_path, 'rt') as f:
    esm2_df = pd.read_csv(f, sep=',', index_col=0)

esm2_pivot_df = (
    esm2_df
    .pivot(index='Position', columns='Amino Acid', values='Probability')
    .reset_index(drop=True)
)
esm2_array = esm2_pivot_df.to_numpy()

In [50]:
### Load msa.fasta into a matrix

msa_path = '/home/moon/projects/AgoAnalysis/msa/Argonaute_all.msa.fasta'
msa = AlignIO.read(msa_path, 'fasta')

# Short display name for every aligned record (KeyError flags a record missing from id2name).
names = [id2name[record.id] for record in msa]

# One row per record, one column per alignment position, holding numeric_encode's integer codes.
msa_matrix = np.zeros((len(msa), msa.get_alignment_length()), dtype=int)
for row, record in enumerate(msa):
    msa_matrix[row, :] = numeric_encode(str(record.seq))

In [51]:
### Get the corresponding index + original -> msa position mapping

index = names.index(target)
# Alignment columns where the target has a code below 20. Under id2aa these are the
# standard residues; this assumes numeric_encode uses the same encoding (TODO confirm).
msa_positions = np.flatnonzero(msa_matrix[index] < 20)

In [52]:
# Scatter per-residue values onto alignment columns; columns where the target has no residue stay NaN.
n_msa_cols = msa.get_alignment_length()
n_residues = len(msa_positions)

# Row i of esm2_array is assumed to be the i-th aligned residue of the target. A length
# mismatch would otherwise be silently truncated (or raise an obscure IndexError), so fail loudly.
if esm2_array.shape[0] != n_residues:
    raise ValueError(
        f'{target}: {n_residues} residues in the MSA but {esm2_array.shape[0]} ESM2 positions'
    )

original_positions = np.full(n_msa_cols, np.nan)
original_positions[msa_positions] = np.arange(1, n_residues + 1)  # 1-based residue numbers

esm2_AA_highest = np.full(n_msa_cols, np.nan)
esm2_AA_highest[msa_positions] = esm2_array.argmax(axis=1)  # column index of the most probable amino acid

conservation_score = np.full(n_msa_cols, np.nan)
conservation_score[msa_positions] = esm2_array.max(axis=1)  # probability of that amino acid

In [53]:
# One row per MSA column: the target residue, ESM2's top prediction and score,
# how many sequences match the target, and every aligned sequence's residue.
name_order = ago_names + piwi_names + wago_names
target_codes = msa_matrix[names.index(target), :]

df = pd.DataFrame(msa_matrix.T, columns=names)
df['msa_pos'] = np.arange(1, msa_matrix.shape[1] + 1)
df['pos'] = original_positions
df['AA'] = df[target].map(lambda x: id2aa[int(x)])
df['esm2_AA'] = [id2aa[int(aa_i)] if not np.isnan(aa_i) else np.nan for aa_i in esm2_AA_highest]
df['conservation_score'] = ['{:.3f}'.format(x) if not np.isnan(x) else x for x in conservation_score]

# Sequences (target included) sharing the target's residue in each column, computed for all
# columns at once; NaN where the target has no standard residue (code >= 20).
same_as_target = (msa_matrix == target_codes).sum(axis=0)
df['count'] = np.where(target_codes < 20, same_as_target, np.nan)

# Decode sequence columns to letters. `seq_name` (not `name`) keeps the notebook's
# global `name` intact for the coevolution cells.
for seq_name in name_order:
    df[seq_name] = df[seq_name].map(lambda x: id2aa[int(x)])

df = df[['msa_pos', 'pos', 'AA', 'esm2_AA', 'conservation_score', 'count'] + name_order]
outpath = f'/home/moon/projects/AgoAnalysis/esm2/{target}/esm2-conservation_msa_comparison_{target}.csv'
df.to_csv(outpath, index=False)

In [56]:
### Loop through all the Argonaute proteins

msa_path = '/home/moon/projects/AgoAnalysis/msa/Argonaute_all.msa.fasta'
with open(msa_path, 'r') as f:
    msa = AlignIO.read(f, 'fasta')
# Build names without a `name` loop variable so the notebook's global `name` is not overwritten.
names = [id2name[record.id] for record in msa]
msa_matrix = np.zeros((len(msa), msa.get_alignment_length()), dtype=int)
for row_idx, record in enumerate(msa):
    msa_matrix[row_idx, :] = numeric_encode(str(record.seq))

model = 'esm2_t36_3B_UR50D'
name_order = ago_names + piwi_names + wago_names


def _load_esm2_probabilities(path):
    """Read a gzipped long-format ESM2 conservation CSV and return a (position x amino acid) array."""
    with gzip.open(path, 'rt') as handle:
        long_df = pd.read_csv(handle, sep=',', index_col=0)
    wide_df = long_df.pivot(index='Position', columns='Amino Acid', values='Probability')
    return wide_df.reset_index(drop=True).to_numpy()


def _build_comparison_table(target, esm2_array, residue_cols):
    """Return one row per MSA column comparing `target`, its ESM2 top prediction and all sequences.

    `residue_cols` are the alignment columns holding the target's residues in sequence
    order; row i of `esm2_array` is assumed to belong to residue_cols[i].
    Reads the globals msa_matrix, names, name_order and id2aa.
    """
    n_cols = msa_matrix.shape[1]
    target_codes = msa_matrix[names.index(target), :]

    original_positions = np.full(n_cols, np.nan)
    original_positions[residue_cols] = np.arange(1, len(residue_cols) + 1)
    esm2_AA_highest = np.full(n_cols, np.nan)
    esm2_AA_highest[residue_cols] = esm2_array.argmax(axis=1)
    conservation_score = np.full(n_cols, np.nan)
    conservation_score[residue_cols] = esm2_array.max(axis=1)

    table = pd.DataFrame(msa_matrix.T, columns=names)
    table['msa_pos'] = np.arange(1, n_cols + 1)
    table['pos'] = original_positions
    table['AA'] = table[target].map(lambda x: id2aa[int(x)])
    table['esm2_AA'] = [id2aa[int(aa_i)] if not np.isnan(aa_i) else np.nan for aa_i in esm2_AA_highest]
    table['conservation_score'] = ['{:.3f}'.format(x) if not np.isnan(x) else x for x in conservation_score]
    # Sequences (target included) sharing the target's residue; NaN where the target has no residue.
    same_as_target = (msa_matrix == target_codes).sum(axis=0)
    table['count'] = np.where(target_codes < 20, same_as_target, np.nan)
    for seq_name in name_order:
        table[seq_name] = table[seq_name].map(lambda x: id2aa[int(x)])
    return table[['msa_pos', 'pos', 'AA', 'esm2_AA', 'conservation_score', 'count'] + name_order]


for target in name_order:
    esm2_conservation_path = f'/home/moon/projects/AgoAnalysis/esm2/{target}/{target}_conservation_{model}.csv.gz'
    if not os.path.exists(esm2_conservation_path):
        print(f'{target} does not have ESM2 conservation data. Skipping.')
        continue
    print(f'{target} is being processed.')

    esm2_array = _load_esm2_probabilities(esm2_conservation_path)
    residue_cols = np.flatnonzero(msa_matrix[names.index(target), :] < 20)
    # A length mismatch would silently misalign ESM2 positions with MSA columns.
    if len(residue_cols) != esm2_array.shape[0]:
        print(f'{target}: {len(residue_cols)} aligned residues but {esm2_array.shape[0]} ESM2 positions. Skipping.')
        continue

    outpath = f'/home/moon/projects/AgoAnalysis/esm2/{target}/esm2-conservation_msa_comparison_{target}.csv'
    _build_comparison_table(target, esm2_array, residue_cols).to_csv(outpath, index=False)

AtAgo is being processed.
CeALG-1 is being processed.
CeALG-2 is being processed.
CeALG-3 does not have ESM2 conservation data. Skipping.
CeALG-4 does not have ESM2 conservation data. Skipping.
CeALG-5 does not have ESM2 conservation data. Skipping.
CeERGO-1 does not have ESM2 conservation data. Skipping.
CeRDE-1 does not have ESM2 conservation data. Skipping.
DmAgo1 does not have ESM2 conservation data. Skipping.
GgAgo2 does not have ESM2 conservation data. Skipping.
HsAgo1 is being processed.
HsAgo2 is being processed.
HsAgo3 does not have ESM2 conservation data. Skipping.
HsAgo4 does not have ESM2 conservation data. Skipping.
MmAgo2 does not have ESM2 conservation data. Skipping.
MputfAgo2 does not have ESM2 conservation data. Skipping.
RnAgo2 does not have ESM2 conservation data. Skipping.
TtAgo is being processed.
BmSIWI is being processed.
CePRG-1 is being processed.
DmPIWI is being processed.
HsPIWIL1 does not have ESM2 conservation data. Skipping.
HsPIWIL2 is being processed.
H

# Coevolution

In [None]:
import bokeh.plotting
from bokeh.models import BasicTicker, PrintfTickFormatter
from bokeh.palettes import viridis, RdBu
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show, output_file, save

from matplotlib.colors import to_hex
def _sample_hex(cmap_name, n_colors=256):
    """Sample `n_colors` evenly spaced colors from a matplotlib colormap as hex strings."""
    colormap = plt.colormaps[cmap_name]
    return [to_hex(colormap(t)) for t in np.linspace(0, 1, n_colors)]


bwr_r = _sample_hex("bwr_r")
gray = _sample_hex("gray_r")
palette = viridis(256)

In [None]:
# NOTE(review): this cell reads the global `name`. Make sure it still holds the intended
# protein: any earlier `for name in ...` loop rebinds it.
esm2_coevolution_path = f'/home/moon/projects/AgoAnalysis/esm2/{name}/{name}_coevolution_{model}.csv.gz'

with gzip.open(esm2_coevolution_path, 'rt') as f:
    df = pd.read_csv(f, sep=',', index_col=0)


def get_fasta(infasta):
    """Return the sequence of the first record in a FASTA file as a string.

    Raises
    ------
    ValueError
        If the file contains no FASTA records. Previously this returned None silently,
        which only failed later at ``len(seq)``.
    """
    first_record = next(SeqIO.parse(infasta, "fasta"), None)
    if first_record is None:
        raise ValueError(f'No FASTA records found in {infasta}')
    return str(first_record.seq)

infasta = f'/home/moon/projects/AgoAnalysis/esm2/{name}/{name}.txt'
seq = get_fasta(infasta)

In [None]:
# Preview the coevolution table instead of rendering the whole frame.
df.head()

In [None]:
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
# One categorical factor per residue position: "1", "2", ..., str(len(seq)).
position_factors = [str(x) for x in range(1, len(seq) + 1)]
p = figure(title="COEVOLUTION",
           x_range=position_factors,
           y_range=position_factors[::-1],  # reversed order on the y-axis (matrix-style layout)
           width=800, height=800,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

# One unit square per row of `df` at (i, j), colored by `value` on the viridis `palette`.
# NOTE(review): the ranges are string factors; if `df.i` / `df.j` are numeric, Bokeh treats
# them as synthetic coordinates rather than factor names — confirm the cells line up.
r = p.rect(x="i", y="j", width=1, height=1, source=df,
           fill_color=linear_cmap('value', palette, low=df.value.min(), high=df.value.max()),
           line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the y-axis
# Not shown inline; the figure is written to HTML in the next cell.

In [None]:
# Write the coevolution heatmap to a standalone HTML file next to the other results.
coevolution_html = f"/home/moon/projects/AgoAnalysis/esm2/{name}/{name}_coevolution_{model}_color.html"
output_file(coevolution_html)
save(p)