# Saliency map for MF974418_1_Crotalus_atrox

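- Gradients: partial derivatives of the model output with respect to the input features.
- Saliencies: importance measures derived from these gradients. The map loaded below holds one saliency value per (sequence position, amino-acid substitution, embedding feature).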

In [1]:
import numpy as np

In [2]:
# Load the precomputed saliency tensor for this sequence and show its shape.
# Later cells index axis 0 by sequence position and axis 1 by amino acid.
saliency_file = '../output_c_atrox/MF974418_1_Crotalus_atrox_saliency.npy'
saliency_map = np.load(saliency_file)
print(saliency_map.shape)

(122, 20, 2560)


In [3]:
# Saliency values for the first sequence position: one row per amino acid.
saliency_map[0, :, :]

array([[0.13022904, 0.01327325, 0.02823179, ..., 0.0680813 , 0.05770273,
        0.09232409],
       [0.09533709, 0.0301348 , 0.05030434, ..., 0.06476049, 0.03788453,
        0.10734852],
       [0.11689086, 0.01884253, 0.0443464 , ..., 0.05808348, 0.05051094,
        0.10167618],
       ...,
       [0.13128738, 0.00361724, 0.11111882, ..., 0.06084758, 0.0995035 ,
        0.06273454],
       [0.10972849, 0.01102298, 0.05026603, ..., 0.06959357, 0.05160882,
        0.08606771],
       [0.11513422, 0.01747751, 0.06242976, ..., 0.05949955, 0.08072774,
        0.09000763]], shape=(20, 2560))

In [4]:
# Shape of a single position's slice (amino acids x features).
np.shape(saliency_map[0])

(20, 2560)

In [5]:
# Load the wild-type sequence from the FASTA file. Only the first record is used.
# Each sequence line is stripped individually, so '\r\n' line endings or stray
# whitespace cannot end up inside the sequence.
fasta_path = '../data/MF974418_1_Crotalus_atrox.fasta'
with open(fasta_path) as f:
    fasta = f.readlines()

sequence_lines = []
for raw_line in fasta[1:]:
    stripped = raw_line.strip()
    if stripped.startswith('>'):
        break  # header of a second record: stop at the end of the first one
    sequence_lines.append(stripped)
sequence = ''.join(sequence_lines)

# Later cells index axis 0 of the saliency map by sequence position, so the two
# must have the same length.
assert len(sequence) == saliency_map.shape[0], (len(sequence), saliency_map.shape)
print(sequence)

SLVQFETLIMKIAGRSGLLWYSAYGCYCGWGGHGLPQDATDRCCFVHDCCYGKATDCNPKTVSYTYSEENGEIICGGDDPCGTQICECDKAAAICFRDNIPSYDNKYWLFPPKNCREEPEPC


In [6]:
# Find every (position, amino acid) pair whose saliency vector is zero along the
# whole last axis. Each row of `zero_indices` is one such (i, j) pair.
all_zero = ~np.any(saliency_map != 0, axis=2)
zero_indices = np.transpose(np.nonzero(all_zero))

In [7]:
alphabet = "ACDEFGHIKLMNPQRSTVWY"
# Sanity check. The expectation is that substituting a residue with itself leaves
# the output unchanged, so every all-zero (position, amino acid) pair found above
# should sit at the wild-type residue, i.e. sequence[x] == alphabet[y].
#   equal     -> zero pairs at the wild-type residue
#   not_equal -> zero pairs anywhere else
is_wild_type = np.array([sequence[i] == alphabet[j] for i, j in zero_indices], dtype=bool)
equal = int(is_wild_type.sum())
not_equal = int(is_wild_type.size - equal)
print(not_equal, equal)

# Fail loudly instead of relying on someone reading the printed counts:
# there must be no stray zero rows, and every position must have its
# wild-type zero row. A position can match at most once, so equal == len(sequence)
# means exactly one per position.
assert not_equal == 0, f"{not_equal} all-zero rows are not at the wild-type residue"
assert equal == len(sequence), f"only {equal} of {len(sequence)} positions have an all-zero wild-type row"

0 122


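The saliency map looks consistent. All 122 all-zero rows lie at the wild-type residue of their position: there are 0 mismatches, and the sequence has 122 residues. That means exactly one zero row per position, which is what we expect if substituting a residue with itself leaves the output unchanged.

Next step: analyse it.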

In [22]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

In [None]:
# Collapse the amino-acid (axis 1) and embedding (axis 2) dimensions into one
# mean saliency value per sequence position.
# Each position's wild-type row is all zeros (see the check above) and is
# included here. That scales every position's mean by the same factor (19/20),
# so the ranking of positions is unaffected.
average_saliency = saliency_map.mean(axis=(1, 2))
average_saliency.shape

In [None]:
# Show the per-position mean saliency as a one-row heatmap with one column per
# residue, labelled by the residue letter.
fig, ax = plt.subplots(figsize=(20, 1))
sns.heatmap(
    average_saliency.reshape(1, -1),
    ax=ax,
    cmap='viridis',
    cbar=True,
    xticklabels=list(sequence),
    yticklabels=False,
)

# The colour bar belongs to the heatmap's mesh, the first collection on the axes.
colorbar = ax.collections[0].colorbar
colorbar.set_label('Average Saliency', rotation=270, labelpad=20)

plt.setp(ax.get_xticklabels(), rotation=90)
ax.set_title("Annotated Sequence with Average Saliency")
fig.tight_layout()
plt.show()

In [None]:
# One mean saliency value per sequence position.
np.shape(average_saliency)

In [None]:
# Pair each residue letter with its position's mean saliency (one row per position).
df = pd.DataFrame(dict(residue=list(sequence), saliency=average_saliency))
df

In [None]:
# Distribution of per-position mean saliency, grouped by wild-type residue.
# The mean over the amino-acid (axis 1) and embedding (axis 2) dimensions is
# recomputed here so this cell can be re-run on its own.
average_saliency = np.mean(saliency_map, axis=(1, 2))

df = pd.DataFrame({
    "residue": list(sequence),
    "saliency": average_saliency
})

# Draw the violins first and a narrow boxplot on top of them. Drawing the boxes
# first lets the semi-opaque violins cover them.
# Some residues occur only once or twice in this sequence (e.g. M), so their
# violins and boxes are degenerate.
residue_order = sorted(set(sequence))
fig, ax = plt.subplots(figsize=(12, 6))
sns.violinplot(data=df, x="residue", y="saliency", order=residue_order,
               inner=None, alpha=0.7, ax=ax)
sns.boxplot(data=df, x="residue", y="saliency", order=residue_order,
            width=0.2, boxprops=dict(alpha=.5), ax=ax)
ax.set_title("Distribution of Average Saliency Values per Residue", fontsize=14)
ax.set_xlabel("Amino Acid", fontsize=12)
ax.set_ylabel("Average Saliency", fontsize=12)
ax.tick_params(axis="x", labelrotation=45, labelsize=10)
fig.tight_layout()
plt.show()

In [None]:
# NOTE(review): this cell repeats the previous box/violin cell; one of the two
# can be deleted.
# Distribution of per-position mean saliency, grouped by wild-type residue.
# The mean over the amino-acid (axis 1) and embedding (axis 2) dimensions is
# recomputed here so this cell can be re-run on its own.
average_saliency = np.mean(saliency_map, axis=(1, 2))

df = pd.DataFrame({
    "residue": list(sequence),
    "saliency": average_saliency
})

# Draw the violins first and a narrow boxplot on top of them. Drawing the boxes
# first lets the semi-opaque violins cover them.
residue_order = sorted(set(sequence))
fig, ax = plt.subplots(figsize=(12, 6))
sns.violinplot(data=df, x="residue", y="saliency", order=residue_order,
               inner=None, alpha=0.7, ax=ax)
sns.boxplot(data=df, x="residue", y="saliency", order=residue_order,
            width=0.2, boxprops=dict(alpha=.5), ax=ax)
ax.set_title("Distribution of Average Saliency Values per Residue", fontsize=14)
ax.set_xlabel("Amino Acid", fontsize=12)
ax.set_ylabel("Average Saliency", fontsize=12)
ax.tick_params(axis="x", labelrotation=45, labelsize=10)
fig.tight_layout()
plt.show()