In [None]:
"""Calculate correlations between learned and known features."""

In [None]:
# Analysis configuration: the trained model, the layer within that model, and
# the known-feature window size used throughout the rest of the notebook.

model_acc = "6_2" # model identifier inserted into the model name/path; options: ["3_6_1", "6_2", "7_1_2", "7_2_2", "7_3_2"]
layer_name = "conv1d2" # name of the model layer whose outputs are extracted as learned features
known_window = 20 # window size selecting the known-feature map directory/files; options: [10,20,30,40,50]

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import src.models.utils
import src.utils
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist


class MaskedConv1D(tf.keras.layers.Conv1D):
    """Conv1D layer that declares support for masking.

    Construction arguments are forwarded unchanged to
    ``tf.keras.layers.Conv1D``; the only difference is that
    ``supports_masking`` is set to True afterwards.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(filters, kernel_size, **kwargs)
        # Set after the parent constructor so the flag is not overwritten by it
        self.supports_masking = True

In [None]:
# Symbol alphabet, per-label weights, and batch size handed to the batch generator
alphabet = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
weights = {'0': 1, '1': 2, '-': 3}
batch_size = 32

batch_data = src.models.utils.load_data('../../mobidb-pdb_validation/split_data/out/all_seqs.fasta',
                                        '../../mobidb-pdb_validation/split_data/out/all_labels.fasta')

# Batches are generated without shuffling
# NOTE(review): later cells assume batch order matches the record order of seq_fasta — confirm
batches = src.models.utils.BatchGenerator(batch_data, batch_size, alphabet, weights,
                                          shuffle=False, all_records=True)

seq_fasta = src.utils.read_fasta('../../mobidb-pdb_validation/split_data/out/all_seqs.fasta')

known_feature_dir = f"../generate_maps/out/window_size{known_window}"

# Load the names of the known features from the header row of one feature map.
# os.listdir returns entries in arbitrary order and may include unrelated
# entries, so keep only feature-map files and sort them to make the choice
# deterministic.
feature_map_suffix = f"_feature_map{known_window}.tsv"
feature_map_files = sorted(name for name in os.listdir(known_feature_dir)
                           if name.endswith(feature_map_suffix))
if not feature_map_files:
    raise FileNotFoundError(f"No files ending in {feature_map_suffix} found in {known_feature_dir}")
known_feature_names = np.loadtxt(f"{known_feature_dir}/{feature_map_files[0]}", dtype=str, max_rows=1)

In [None]:
# Extract accessions: take the header text before the first "|" and drop its
# leading character, keeping only the accession
protein_acc = [header.split("|")[0][1:] for header, _ in seq_fasta]

In [None]:
# Calculate learned features
model_name = f"mobidb-pdb_cnn_{model_acc}"
model_path = f"../../models/{model_name}/out_model/{model_name}.h5"
model = tf.keras.models.load_model(model_path, custom_objects={"MaskedConv1D": MaskedConv1D})

# Sub-model mapping the original model inputs to the output of the chosen layer
layer = model.get_layer(layer_name)
feature_extractor = tf.keras.Model(inputs=model.inputs, outputs=layer.output)

# Predict method was acting strange, so extract individual batches.
# The loop variable is named batch_input (not input) so the builtin input()
# is not shadowed for the rest of the notebook.
learned_features = []
for batch_input, _, training_weights in batches:
    features = feature_extractor(batch_input).numpy()
    # Drop padding: keep only positions whose training weight is non-zero
    features = features[training_weights != 0]
    learned_features.append(features)

In [None]:
# Construct matrix of learned feature values: stack the per-batch arrays along
# the first axis, then transpose the result
stacked_learned = np.concatenate(list(learned_features), axis=0)
learned_feature_array = stacked_learned.T

In [None]:
# Load in known features: one array per accession, read from that accession's
# feature-map file with the header row skipped
known_features = [
    np.loadtxt(f"{known_feature_dir}/{acc}_feature_map{known_window}.tsv", skiprows=1)
    for acc in protein_acc
]

In [None]:
# Construct matrix of known feature values: stack the per-protein arrays along
# the first axis, then transpose the result
stacked_known = np.concatenate(list(known_features), axis=0)
known_feature_array = stacked_known.T

In [None]:
# Calculate correlation between known and learned features.
# np.corrcoef treats each row of both inputs as a variable and returns the
# correlations of all rows stacked together (learned rows first, then known
# rows). Keep only the learned-vs-known block, deriving its bounds from the
# input shape rather than hard-coding the number of features.
n_learned = learned_feature_array.shape[0]
full_corr_matrix = np.corrcoef(learned_feature_array, known_feature_array)
corr_matrix = full_corr_matrix[:n_learned, n_learned:]

In [None]:
# Generate heatmap
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global figure styling; fonttype 42 selects TrueType fonts in PDF/PS output
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Arial',
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'figure.autolayout': True,
})

fig, ax = plt.subplots(figsize=(8,12))

# Rows of corr_matrix are learned features and columns are known features
heatmap = sns.heatmap(corr_matrix,
                      vmin=-0.5,
                      vmax=0.5,
                      xticklabels=known_feature_names,
                      yticklabels=2,
                      cmap="RdYlBu_r",
                      cbar_kws={"ticks":np.arange(-0.5,0.6,0.1)},
                      ax=ax)

# Title and axis labels are applied after drawing so the heatmap call cannot
# overwrite them; x shows known features (columns), y shows learned features (rows)
ax.set_title(f"{model_name} Feature Correlations", fontsize=16)
ax.set_xlabel('Known Features', fontsize=12)
ax.set_ylabel('Learned Features', fontsize=12)

In [None]:
class TreeNode:
    """Node of a rooted tree: a name plus a list of child nodes.

    Parameters
    ----------
    name : any, optional
        Identifier stored on the node.
    children : list, optional
        Child nodes; the given list is stored as-is. When omitted, a new
        empty list is created for this node.
    """

    def __init__(self, name=None, children=None):
        # Build the default list per instance so nodes never share one list
        if children is None:
            children = []
        self.name = name
        self.children = children

        
def make_tree(lm):
    """Build a binary tree of TreeNode objects from a linkage matrix.

    Each row of ``lm`` is read as ``(child_id1, child_id2, distance, _)``.
    Tips are named ``0 .. len(lm)``; the node created from row ``i`` is named
    ``i + number_of_tips``. Every child receives a ``length`` attribute equal
    to the merge distance minus the child's own height (0 for tips).

    Returns the node named ``2 * (number_of_tips - 1)``, i.e. the node created
    from the last row (or the single tip when ``lm`` is empty).
    """
    n_tips = len(lm) + 1
    nodes = {tip_id: TreeNode(name=tip_id, children=[]) for tip_id in range(n_tips)}
    heights = dict.fromkeys(range(2 * n_tips - 1), 0)

    for parent_id, row in enumerate(lm, start=n_tips):
        left_id, right_id, merge_height, _ = row
        left, right = nodes[left_id], nodes[right_id]
        # Branch length is the gap between the merge height and the child's height
        left.length = merge_height - heights[left_id]
        right.length = merge_height - heights[right_id]
        nodes[parent_id] = TreeNode(name=parent_id, children=[left, right])
        heights[parent_id] = merge_height

    return nodes[2 * (n_tips - 1)]


def get_tip_order(tree):
    """Return the names of all childless nodes under ``tree``.

    The traversal is an iterative depth-first walk using a LIFO stack, so at
    every internal node the last child is visited before the earlier ones.
    """
    names = []
    pending = [tree]
    while len(pending) > 0:
        current = pending.pop()
        if not current.children:
            # No children: this node is a tip
            names.append(current.name)
            continue
        pending += current.children
    return names


# Identify constant learned features with NaN correlations:
# a row is flagged when any of its correlations with the known features is NaN
row_has_nan = np.isnan(corr_matrix).any(axis=1)
nan_rows = [i for i, flagged in enumerate(row_has_nan) if flagged]
nonnan_rows = [i for i, flagged in enumerate(row_has_nan) if not flagged]

# Sort learned features by correlation hierarchy.
# Index with the row lists built above (the previous nan_columns /
# nonnan_columns names were never defined).
nan_array = corr_matrix[nan_rows]
nonnan_array = corr_matrix[nonnan_rows]
cdm = pdist(nonnan_array, metric='correlation')
lm = linkage(cdm, method='average')
tree = make_tree(lm)
tip_order = get_tip_order(tree)
array = np.concatenate([nonnan_array[tip_order], nan_array], axis=0)  # Combine sorted non-NaN and NaN features
vext = np.nanmax(np.abs(array))  # Symmetric color limits around zero, ignoring NaNs

# Plot results; the array is transposed so learned features run along x
fig, ax = plt.subplots(figsize=(12.8, 6.4), layout='constrained')
im = ax.imshow(array.transpose(), vmin=-vext, vmax=vext, cmap='RdBu_r',
               aspect='auto', interpolation='none')
ax.set_xlabel('Learned Features')
ax.set_ylabel('Known Features')
ax.set_yticks(range(len(known_feature_names)), known_feature_names, fontsize=10)
ax.set_title(f"{model_name} Feature Correlations")
fig.colorbar(im);

In [None]:
# Save heat map
# makedirs with exist_ok=True creates the directory if needed and is a no-op
# when it already exists, avoiding a separate existence check
os.makedirs('out/', exist_ok=True)

# fig is whichever figure was assigned most recently in the notebook
fig.savefig(f"out/{model_name}_win{known_window}_corr_matrix.svg")