In [None]:
# Interpret variant predictions from ESM

import pandas as pd


In [None]:
# read in labeled data

def read_labeled(in_data):
    """Load a labelled variant-score table, echo it, and hand it back.

    The first CSV column becomes the row index; downstream cells parse those
    index labels as ``<wt residue><position><mutant residue>``.

    Parameters
    ----------
    in_data : str, path-like or file-like
        Anything ``pandas.read_csv`` accepts as its first argument.

    Returns
    -------
    pandas.DataFrame
        The parsed table.
    """
    frame = pd.read_csv(filepath_or_buffer=in_data, index_col=0)
    print(frame)
    return frame
    
# TODO: set to the path of the ESM output matrix (CSV) before running.
ESM_OUTPUT_CSV = ""
read_acrIIa4_v1 = read_labeled(ESM_OUTPUT_CSV)
    

In [None]:
# compute entropy of model predictions

from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import entropy

def compute_entropy(in_df, cutoff, score_cols=None, plot=True):
    """Score every variant by the inverse entropy of its ensemble predictions.

    For each row the values in ``score_cols`` are passed to
    ``scipy.stats.entropy`` (which normalises them to sum to 1 before computing
    the entropy), and ``1 / entropy`` is stored in a new ``inv_entropy`` column.
    The input frame is NOT modified; the column is added to a copy.

    Parameters
    ----------
    in_df : pandas.DataFrame
        One row per variant; must contain every column in ``score_cols``.
    cutoff : float
        Rows with ``inv_entropy >= cutoff`` are kept. As a diagnostic, the index
        of every row whose raw entropy is ``< cutoff`` is printed.
    score_cols : list of str, optional
        Ensemble-member score columns. Defaults to the five
        ``esm1v_t33_650M_UR90S_{1..5}`` columns.
    plot : bool, default True
        Show a histogram of the finite ``inv_entropy`` values.

    Returns
    -------
    pandas.DataFrame
        The kept rows, including the ``inv_entropy`` column. A zero entropy
        gives ``inv_entropy == inf``, and such rows are always kept.
    """
    if score_cols is None:
        score_cols = [f"esm1v_t33_650M_UR90S_{i}" for i in range(1, 6)]

    scored = in_df.copy()
    preds = scored[score_cols].to_numpy(dtype=np.float64)
    # Row-wise entropy; guard the empty case so the column is always created.
    entrpy = entropy(preds, axis=1) if len(scored) else np.empty(0)

    for idx in scored.index[entrpy < cutoff]:
        print(idx)

    # A zero entropy legitimately maps to +/-inf; silence the divide warning.
    with np.errstate(divide="ignore"):
        scored["inv_entropy"] = 1.0 / entrpy

    if plot:
        inv = scored["inv_entropy"]
        # plt.hist cannot auto-range over inf/NaN, so plot finite values only.
        finite_inv = inv[np.isfinite(inv)]
        _, ax = plt.subplots()
        ax.hist(finite_inv, bins=100)
        ax.set(
            title="Inverse entropy of ensemble predictions",
            xlabel="1 / entropy",
            ylabel="variant count",
        )
        plt.show()

    return scored[scored["inv_entropy"] >= cutoff].copy()

ENTROPY_CUTOFF = 0  # keep variants whose inv_entropy is >= this value
acrIIa4_compute = compute_entropy(read_acrIIa4_v1, cutoff=ENTROPY_CUTOFF)
    

In [None]:
# Plot heatmaps of model scores: columns are sequence positions, rows (index) are the substituted (mutant) amino acid

import re
import seaborn as sns

def _plot_masked_heatmap(mat, cmap, title):
    """Draw ``mat`` as a heatmap, masking cells that hold no variant (NaN)."""
    _, ax = plt.subplots(figsize=(20, 5))
    sns.heatmap(mat.fillna(0), mask=mat.isnull(), cmap=cmap, ax=ax)
    ax.set(title=title, xlabel="position", ylabel="mutant amino acid")
    plt.show()


def plt_hmaps(in_df, ref_seq, cutoff, score_cols=None, plot=True):
    r"""Build (and optionally plot) mutant-amino-acid x position score matrices.

    Each index label of ``in_df`` is parsed with ``(\D)([0-9]+)(\D)`` into
    wild-type residue, position and mutant residue. For every variant whose
    position is ``>= cutoff`` two float matrices are filled (rows = mutant
    amino acid, columns = position):

    * the sum of the ensemble scores in ``score_cols``;
    * the variant's ``inv_entropy`` value.

    Parameters
    ----------
    in_df : pandas.DataFrame
        Variant table with an ``inv_entropy`` column, e.g. from ``compute_entropy``.
    ref_seq : str
        Reference sequence; only its length is used, to size the column range.
    cutoff : int
        First position included in the matrices.
    score_cols : list of str, optional
        Columns summed into the score matrix. Defaults to
        ``esm1v_t33_650M_UR90S_{1..4}``.
    plot : bool, default True
        Draw both heatmaps.

    Returns
    -------
    pandas.DataFrame
        The inverse-entropy matrix; cells with no variant are NaN.

    Raises
    ------
    ValueError
        If an index label does not match the variant pattern.
    """
    if score_cols is None:
        # NOTE(review): model 5 was commented out of the original sum; confirm
        # whether excluding it is intentional.
        score_cols = [f"esm1v_t33_650M_UR90S_{i}" for i in range(1, 5)]

    grep_str = r"(\D)([0-9]+)(\D)"

    aas = ["G", "P", "A", "V", "L", "I", "M", "C", "F", "Y", "W", "H",
           "K", "R", "Q", "N", "E", "D", "S", "T"]

    # NOTE(review): if variant positions are 1-based, position len(ref_seq) lies
    # outside this range and is appended as an extra column by .loc enlargement.
    positions = range(cutoff, len(ref_seq))
    # Float (not object) frames so seaborn and later sums see numeric data.
    sum_score_mut_df = pd.DataFrame(np.nan, index=aas, columns=positions, dtype=float)
    entrpy_score_mut_df = pd.DataFrame(np.nan, index=aas, columns=positions, dtype=float)

    for label, row in in_df.iterrows():
        match = re.search(grep_str, str(label))
        if match is None:
            raise ValueError(f"cannot parse variant label {label!r} with {grep_str!r}")
        pos = int(match.group(2))
        if pos < cutoff:
            continue
        mut_seq = match.group(3)
        # One expression: the original split `a + b` / `+ c + d` across lines,
        # which silently dropped the second line from the sum.
        sum_score_mut_df.loc[mut_seq, pos] = sum(row[col] for col in score_cols)
        entrpy_score_mut_df.loc[mut_seq, pos] = row["inv_entropy"]

    if plot:
        _plot_masked_heatmap(sum_score_mut_df, "crest", "Summed ensemble score")
        _plot_masked_heatmap(entrpy_score_mut_df, "viridis", "Inverse entropy")

    return entrpy_score_mut_df
    
ACRIIA4_REF_SEQ = (
    "MNINDLIREIKNKDYTVKLSGTDSNSITQLIIRVNNDGNEYVISESENESIVEKFISAFKNGWNQEYEDEEEFYNDMQTITLKSELN"
)
HEATMAP_START_POS = 2  # first position (column) included in the heatmaps
plt_acrIIa4 = plt_hmaps(acrIIa4_compute, ACRIIA4_REF_SEQ, cutoff=HEATMAP_START_POS)


In [None]:
## Transpose the inverse-entropy matrix and plot the per-position sum (and max) of inverse entropies to reveal hotspots (can be plotted above the heatmap)

def plt_scatter(in_df, out_csv=None, plot=True):
    """Summarise a mutant x position matrix per position to expose hotspots.

    ``in_df`` is transposed so each row is a position, and missing cells are
    treated as 0. Two per-position summaries are then appended: ``sum`` (total
    over all mutants) and ``max`` (largest single value). Both are computed
    before either column is added, so neither includes the other.

    Parameters
    ----------
    in_df : pandas.DataFrame
        Matrix with mutant amino acids as index and positions as columns, e.g.
        the inverse-entropy matrix returned by ``plt_hmaps``.
    out_csv : str or path-like, optional
        If given, the transposed matrix (with ``sum``/``max``) is written here
        for downstream analysis. The previous placeholder path ``""`` always failed.
    plot : bool, default True
        Show histograms of ``sum``/``max`` and their per-position line plots.

    Returns
    -------
    pandas.DataFrame
        The transposed, zero-filled matrix with ``sum`` and ``max`` columns.
    """
    t_df = in_df.T.fillna(0)
    summary = t_df.assign(sum=t_df.sum(axis=1), max=t_df.max(axis=1))

    if out_csv is not None:
        summary.to_csv(out_csv)

    if plot:
        for stat in ("sum", "max"):
            values = summary[stat].astype(float)
            # hist cannot auto-range over inf (zero-entropy variants), so drop them.
            _, ax = plt.subplots()
            ax.hist(values[np.isfinite(values)])
            ax.set(title=f"Distribution of per-position {stat}", xlabel=stat, ylabel="positions")
            plt.show()
        for stat in ("sum", "max"):
            _, ax = plt.subplots(figsize=(20, 5))
            ax.plot(summary.index, summary[stat], marker="o")
            ax.set(title=f"Per-position {stat} of inverse entropy", xlabel="position", ylabel=stat)
            plt.show()

    return summary

# Per-position hotspot summary of the inverse-entropy matrix built above.
acrIIa4_scatter = plt_scatter(in_df=plt_acrIIa4)
