# Protein language models trained on multiple sequence alignments learn phylogenetic relationships

In [None]:
import os
import pathlib
import itertools
import string
from typing import List, Tuple
import warnings

import tqdm

import numpy as np
from numpy.random import default_rng
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import cm

from patsy import dmatrices
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=FutureWarning)
    import statsmodels.api as sm

import esm
import torch

from Bio import SeqIO

In [None]:
# Plotting settings: LaTeX text rendering, a serif ("times") font and three font-size tiers
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 50, 60, 70
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["times"],
    # Body text, tick labels and legend entries use the medium tier
    **dict.fromkeys(["font.size", "xtick.labelsize", "ytick.labelsize", "legend.fontsize"], MEDIUM_SIZE),
    # Axes titles, axes labels and the figure title use the largest tier
    **dict.fromkeys(["axes.titlesize", "axes.labelsize", "figure.titlesize"], BIGGER_SIZE),
})

In [None]:
# Fixed seed for the shared random generator. `rng` is the generator used by
# create_train_test_sets to draw the train/test split, so results are
# reproducible only if the cells consuming `rng` are run in the same order.
SEED = 42
rng = np.random.default_rng(seed=SEED)

In [None]:
## Utilities from https://github.com/facebookresearch/esm
# Translation table for str.translate: every lowercase ASCII letter, "." and "*"
# is mapped to None, i.e. deleted from the string in a single pass.
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)


def remove_insertions(sequence: str) -> str:
    """Return ``sequence`` with the characters listed in ``translation`` deleted.

    The module-level ``translation`` table drops lowercase letters, "." and "*",
    which is needed to load aligned sequences from an MSA.
    """
    stripped = sequence.translate(translation)
    return stripped


def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
    """Load at most the first ``nseq`` records of a FASTA alignment.

    Returns a list of ``(description, sequence)`` pairs in file order, where
    each sequence has been passed through ``remove_insertions``.
    """
    leading_records = itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    msa = []
    for record in leading_records:
        msa.append((record.description, remove_insertions(str(record.seq))))
    return msa

In [None]:
# Folder containing one FASTA alignment per family; files are looked up as
# "<family>_seed_hmmalign_no_inserts.fasta" by the inference cell.
msas_folder = pathlib.Path("data/Pfam_seed/msa")

# Pfam accessions analysed throughout the notebook. The order matters: later
# cells slice this list (first 12 for fitting the common model, last 3 for plots).
pfam_families = [
    "PF00004",
    "PF00005",
    "PF00041",
    "PF00072",
    "PF00076",
    "PF00096",
    "PF00153",
    "PF00271",
    "PF00397",
    "PF00512",
    "PF00595",
    "PF01535",
    "PF02518",
    "PF07679",
    "PF13354"
]

# Maximum number of sequences read from each MSA file
MAX_DEPTH = 500

# Used to name and reshape the n_layers * n_heads attention features
# (presumably matching the 12-layer MSA Transformer loaded below; confirm if the model changes)
n_layers = n_heads = 12

## Create dataset of Hamming distance matrices and column-averaged MSA Transformer column attentions

In [None]:
# Folders hosting the Hamming distance matrices and the averaged column
# attention matrices, both stored as .npy files
DISTS_FOLDER = pathlib.Path("data/hamming")
ATTNS_FOLDER = pathlib.Path("data/col_attentions")

for folder in (DISTS_FOLDER, ATTNS_FOLDER):
    # parents=True also creates a missing "data" directory (os.mkdir would raise
    # FileNotFoundError); exist_ok=True makes re-running the cell a no-op.
    folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the pretrained MSA Transformer and its alphabet, switch the model to
# evaluation mode and build the converter that turns MSAs into token batches.
msa_transformer, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_transformer = msa_transformer.eval()
msa_batch_converter = msa_alphabet.get_batch_converter()

# For every family: run the model once on (at most MAX_DEPTH sequences of) its
# MSA, then save the processed column attentions and the Hamming distances.
for pfam_family in tqdm.tqdm(pfam_families):
    dists_path = DISTS_FOLDER / f"{pfam_family}_seed.npy"
    attns_path = ATTNS_FOLDER / f"{pfam_family}_seed_mean-on-cols_symm.npy"

    # A batch containing a single MSA
    msa_data = [read_msa(msas_folder / f"{pfam_family}_seed_hmmalign_no_inserts.fasta", MAX_DEPTH)]
    msa_batch_labels, msa_batch_strs, msa_batch_tokens = msa_batch_converter(msa_data)

    # NOTE(review): `depth` is not used again in this cell
    depth = msa_batch_tokens.shape[1]
    # Inference only: no gradients are needed
    with torch.no_grad():
        results = msa_transformer(msa_batch_tokens, repr_layers=[12], need_head_weights=True)

    # Compute and save averaged and symmetrized column attention matrices
    # [0, ...] selects the only MSA in the batch; the mean over axis 2 leaves a
    # 4D array (presumably layers x heads x rows x rows -- TODO confirm axis order).
    attns_mean_on_cols_symm = results["col_attentions"].cpu().numpy()[0, ...].mean(axis=2)
    # Symmetrize by adding the transpose over the last two axes (a sum, not an average)
    attns_mean_on_cols_symm += attns_mean_on_cols_symm.transpose(0, 1, 3, 2)
    np.save(attns_path, attns_mean_on_cols_symm)

    # Compute and save Hamming distance matrices
    # The first token column is skipped (presumably the start token prepended by
    # the batch converter -- TODO confirm); pdist's "hamming" metric is the
    # fraction of differing positions, expanded to a square matrix by squareform.
    np.save(dists_path, squareform(pdist(msa_batch_tokens.cpu().numpy()[0, :, 1:], "hamming")))

## Fit one logistic model per MSA

In [None]:
def create_train_test_sets(attns,
                           dists,
                           normalize_dists=False,
                           train_size=0.7,
                           ensure_same_size=False,
                           zero_attention_diagonal=False):
    """Split the rows of an MSA into train/test sets and flatten row pairs into samples.

    Attentions are assumed to be averaged across the column dimension, i.e. a
    4D array of shape ``(n_layers, n_heads, depth, depth)``. The input arrays
    are never modified.

    Args:
        attns: 4D attention array as described above.
        dists: square matrix of pairwise distances between the MSA rows.
        normalize_dists: if True, divide the distances by their maximum. The
            maximum is taken before any cropping done by ``ensure_same_size``.
        train_size: fraction of the rows assigned to the training set.
        ensure_same_size: if True, crop ``dists`` to its leading
            ``depth x depth`` block so that it matches ``attns``.
        zero_attention_diagonal: if True, set the attention of every row to
            itself to zero (on a copy of ``attns``).

    Returns:
        ``((attns_train, dists_train), (attns_test, dists_test),
        (n_rows_train, n_rows_test))``. Each ``attns_*`` array has shape
        ``(n_pairs, n_layers * n_heads)`` and each ``dists_*`` array has shape
        ``(n_pairs,)``, where the pairs are the upper triangle, diagonal
        included, of the corresponding row subset, in row-major order.

    The split is drawn from the module-level generator ``rng``.
    """
    # Validate the shape before any fancy indexing relies on it
    assert attns.shape[2] == attns.shape[3]
    if zero_attention_diagonal:
        # Copy first so the caller's array is left untouched
        attns = attns.copy()
        diag = np.arange(attns.shape[2])
        attns[:, :, diag, diag] = 0
    if normalize_dists:
        # astype returns a new array, so the in-place division is safe
        dists = dists.astype(np.float64)
        dists /= np.max(dists)
    if ensure_same_size:
        dists = dists[:attns.shape[2], :attns.shape[2]]
    assert len(dists) == attns.shape[2]
    n_layers, n_heads, depth, _ = attns.shape

    # Train-test split: rows (sequences) are split, not individual pairs, so
    # no sequence contributes to both sets
    n_train = int(train_size * depth)
    train_idxs = rng.choice(depth, size=n_train, replace=False)
    split_mask = np.zeros(depth, dtype=bool)
    split_mask[train_idxs] = True

    def _pairs_for(row_mask):
        """Return (features, targets, n_rows) for the rows selected by ``row_mask``."""
        sub_attns = attns[:, :, row_mask, :][:, :, :, row_mask]
        sub_dists = dists[row_mask, :][:, row_mask]
        n_rows = sub_attns.shape[-1]
        rows, cols = np.triu_indices(n_rows)
        # (layers, heads, pairs) -> (pairs, layers * heads)
        features = sub_attns[..., rows, cols].transpose(2, 0, 1).reshape(-1, n_layers * n_heads)
        return features, sub_dists[rows, cols], n_rows

    attns_train, dists_train, n_rows_train = _pairs_for(split_mask)
    attns_test, dists_test, n_rows_test = _pairs_for(~split_mask)

    return (attns_train, dists_train), (attns_test, dists_test), (n_rows_train, n_rows_test)

In [None]:
def perform_regressions_msawise(normalize_dists=False,
                                ensure_same_size=False,
                                zero_attention_diagonal=False):
    """Fit one binomial GLM per Pfam family, predicting distances from attentions.

    For every family in the module-level ``pfam_families``, the saved distance
    and attention arrays are loaded, split with ``create_train_test_sets``, and
    a GLM with an intercept plus one regressor per (layer, head) is fitted on
    the training pairs and evaluated on both training and test pairs.

    Args:
        normalize_dists: forwarded to ``create_train_test_sets``.
        ensure_same_size: forwarded to ``create_train_test_sets``.
        zero_attention_diagonal: forwarded to ``create_train_test_sets``.

    Returns:
        dict mapping each family to a dict with keys ``"bias"`` (intercept),
        ``"coeffs"`` (``(n_layers, n_heads)`` array), ``"y_train"``,
        ``"y_pred_train"``, ``"y_test"``, ``"y_pred_test"``, ``"depth"``,
        ``"n_rows_train"`` and ``"n_rows_test"``.
    """
    # Column names and formula are the same for every family: build them once
    feature_names = [f"lyr{i}_hd{j}" for i in range(n_layers) for j in range(n_heads)]
    formula = "dist ~ " + " + ".join(feature_names)

    regr_results = {}
    for pfam_family in tqdm.tqdm(pfam_families):
        dists = np.load(DISTS_FOLDER / f"{pfam_family}_seed.npy")
        attns = np.load(ATTNS_FOLDER / f"{pfam_family}_seed_mean-on-cols_symm.npy")

        ((attns_train, dists_train),
         (attns_test, dists_test),
         (n_rows_train, n_rows_test)) = create_train_test_sets(attns,
                                                               dists,
                                                               normalize_dists=normalize_dists,
                                                               ensure_same_size=ensure_same_size,
                                                               zero_attention_diagonal=zero_attention_diagonal)

        df_train = pd.DataFrame(attns_train, columns=feature_names)
        df_train["dist"] = dists_train
        df_test = pd.DataFrame(attns_test, columns=feature_names)
        df_test["dist"] = dists_test

        # Carve out the response and design matrices (the latter with an intercept
        # column first) from the training and testing data frames
        y_train, X_train = dmatrices(formula, df_train, return_type="dataframe")
        y_test, X_test = dmatrices(formula, df_test, return_type="dataframe")

        # The original passed cov_type="H0" to the GLM constructor, where it is not
        # a recognised option and had no effect on the fit; it is dropped here.
        # NOTE(review): if robust standard errors were intended, pass a valid
        # cov_type (e.g. "HC0") to fit() instead.
        binom_model = sm.GLM(y_train, X_train, family=sm.families.Binomial())
        binom_model_results = binom_model.fit(maxiter=200, tol=1e-9)

        y_train = y_train["dist"].to_numpy()
        y_test = y_test["dist"].to_numpy()
        y_pred_train = binom_model_results.predict(X_train).to_numpy()
        y_pred_test = binom_model_results.predict(X_test).to_numpy()

        params = binom_model_results.params
        regr_results[pfam_family] = {
            # .iloc: the parameters are labelled by name, so positional access must be explicit
            "bias": params.iloc[0],
            "coeffs": params.to_numpy()[-n_layers * n_heads:].reshape(n_layers, n_heads),
            "y_train": y_train,
            "y_pred_train": y_pred_train,
            "y_test": y_test,
            "y_pred_test": y_pred_test,
            "depth": dists.shape[0],
            "n_rows_train": n_rows_train,
            "n_rows_test": n_rows_test
            }

    return regr_results

In [None]:
# Fit the per-MSA models with the default options (no distance normalisation,
# no cropping, attention diagonal kept)
regr_results_hamming_msawise = perform_regressions_msawise()

### Plot and analyse the results

In [None]:
def create_dist_comparison_mat(y, y_pred, n_rows):
    """Pack true and predicted pairwise values into one square matrix.

    ``y`` and ``y_pred`` hold one value per pair ``(i, j)`` with ``i <= j``,
    ordered row by row (the order of ``np.triu_indices(n_rows)``).

    Args:
        y: true values, length ``n_rows * (n_rows + 1) / 2``.
        y_pred: predicted values, same length as ``y``.
        n_rows: number of sequences (side of the output matrix).

    Returns:
        ``(n_rows, n_rows)`` float32 array with the true values strictly above
        the diagonal and the predictions on and below the diagonal.

    Raises:
        AssertionError: if the lengths of ``y`` and ``y_pred`` differ or do not
            match the number of upper-triangular entries.
    """
    assert len(y) == len(y_pred)
    rows, cols = np.triu_indices(n_rows)
    assert len(rows) == len(y)

    comparison_mat = np.zeros((n_rows, n_rows), dtype=np.float32)
    comparison_mat[rows, cols] = y
    # Order is important as we want the diagonal to be a prediction:
    # the mirrored write comes second and overwrites the diagonal entries
    comparison_mat[cols, rows] = y_pred

    return comparison_mat

In [None]:
# Select only Pfam families with depth >= 200 and length >= 50
# (n_selec sets the number of rows of the figure drawn in the next cell)
pfam_families_selec = ["PF00004", "PF00271", "PF00512", "PF02518"]
n_selec = len(pfam_families_selec)

In [None]:
# Figure: one row per selected family with (0) the regression coefficients,
# (1) their mean absolute value per layer, (2) true vs. predicted distances on
# the training pairs and (3) the same on the test pairs.
cmap = cm.bwr
vpad = 30
x_vals_coeffs = np.arange(0, n_heads, 2)
y_vals_coeffs = np.arange(0, n_layers, 2)

fig, axs = plt.subplots(
    figsize=(40, 10 * n_selec),
    nrows=n_selec,
    ncols=4,
    gridspec_kw={"width_ratios": [10, 3, 10, 10]},
    constrained_layout=True
)

for i, pfam_family in enumerate(pfam_families_selec):
    res = regr_results_hamming_msawise[pfam_family]
    # Explicit lookups instead of exec()-generated variables
    coeffs = res["coeffs"]
    y_train, y_pred_train = res["y_train"], res["y_pred_train"]
    y_test, y_pred_test = res["y_test"], res["y_pred_test"]
    n_rows_train, n_rows_test = res["n_rows_train"], res["n_rows_test"]

    # Coefficient heat map, colour scale centred on zero
    im = axs[i, 0].imshow(coeffs, norm=colors.CenteredNorm(), cmap=cmap)
    fig.colorbar(im, ax=axs[i, 0], fraction=0.05, pad=0.03)
    axs[i, 0].set_ylabel(fr"\bf {pfam_family}" + "\nLayer")
    axs[i, 0].set_xticks(x_vals_coeffs)
    axs[i, 0].set_yticks(y_vals_coeffs)
    # Tick labels are 1-based
    axs[i, 0].set_xticklabels(list(map(str, x_vals_coeffs + 1)))
    axs[i, 0].set_yticklabels(list(map(str, y_vals_coeffs + 1)))

    # Mean absolute coefficient per layer, drawn with layers on the y axis so
    # that it lines up with the heat map
    axs[i, 1].plot(np.mean(np.abs(coeffs), axis=1),
                   np.arange(n_layers),
                   "-o",
                   markersize=12,
                   lw=5)
    axs[i, 1].invert_yaxis()
    axs[i, 1].set_yticks(y_vals_coeffs)
    axs[i, 1].set_yticklabels(list(map(str, y_vals_coeffs + 1)))
    axs[i, 1].set_xlim([0, 55])
    axs[i, 1].set_ylabel("Layer")

    for j, y, y_pred, n_rows in [(2, y_train, y_pred_train, n_rows_train),
                                 (3, y_test, y_pred_test, n_rows_test)]:
        # 2 is train, 3 is test
        hamming_comparison = create_dist_comparison_mat(y, y_pred, n_rows)
        # Two overlaid images, each masked with NaN outside its own triangle:
        # true values above the diagonal in blue ...
        axs[i, j].imshow(np.triu(hamming_comparison, k=1) + np.tril(np.full_like(hamming_comparison, fill_value=np.nan)),
                         cmap="Blues",
                         vmin=0,
                         vmax=1)
        # ... predictions on and below the diagonal in green
        axs[i, j].imshow(np.tril(hamming_comparison) + np.triu(np.full_like(hamming_comparison, fill_value=np.nan), k=1),
                         cmap="Greens",
                         vmin=0,
                         vmax=1)

    axs[i, 2].set_ylabel("Sequence")

axs[0, 0].set_title("Regression coefficients", pad=vpad)
# Raw string: "\ " is a LaTeX spacing command, not a Python escape sequence
axs[0, 1].set_title(r"Avg.\ abs.\ coeff.", pad=vpad)
axs[0, 2].set_title("Training", pad=vpad)
axs[0, 3].set_title("Test", pad=vpad)

axs[-1, 0].set_xlabel("Head")
axs[-1, 2].set_xlabel("Sequence")
axs[-1, 3].set_xlabel("Sequence")

plt.show()

In [None]:
# Summary table of the per-MSA fits: one row per family, one column per metric
df_regr_results_hamming_msawise = pd.DataFrame()
for pfam_family in pfam_families:
    res = regr_results_hamming_msawise[pfam_family]
    # Explicit lookups instead of exec()-generated variables
    y_train, y_pred_train = res["y_train"], res["y_pred_train"]
    y_test, y_pred_test = res["y_test"], res["y_pred_test"]
    n_samples_train = len(y_train)
    n_samples_test = len(y_test)
    pearson = pearsonr(y_test, y_pred_test)[0]

    # Insertion order of this dict fixes the column order of the table
    metrics = {
        "Depth": res["depth"],
        "mean (training)": np.mean(y_train),
        "mean (test)": np.mean(y_test),
        "std (training)": np.std(y_train),
        "std (test)": np.std(y_test),
        "RMSE (training)": np.linalg.norm(y_train - y_pred_train) / np.sqrt(n_samples_train),
        "RMSE (test)": np.linalg.norm(y_test - y_pred_test) / np.sqrt(n_samples_test),
        "MAE (training)": np.sum(np.abs(y_train - y_pred_train)) / n_samples_train,
        "MAE (test)": np.sum(np.abs(y_test - y_pred_test)) / n_samples_test,
        "R^2 (test)": 1 - np.sum((y_test - y_pred_test)**2) / np.sum((y_test - np.mean(y_test))**2),
        "Pearson (test)": pearson,
        # Least-squares slope of the predictions regressed on the true values
        "Slope (test)": pearson * np.std(y_pred_test) / np.std(y_test),
    }
    for column, value in metrics.items():
        df_regr_results_hamming_msawise.loc[pfam_family, column] = value

df_regr_results_hamming_msawise

### Regression coefficients from different MSAs are highly correlated

In [None]:
# Pearson correlation between the flattened coefficient matrices of every pair
# of families. Pairs are visited as (0,1), (0,2), ..., (1,2), ..., which is the
# condensed ordering that squareform expands into a symmetric matrix with a
# zero diagonal.
flat_coeffs = [regr_results_hamming_msawise[family]["coeffs"].flatten()
               for family in pfam_families]
pearsons_coeffs = squareform(np.array([
    pearsonr(coeffs_x, coeffs_y)[0]
    for idx, coeffs_x in enumerate(flat_coeffs)
    for coeffs_y in flat_coeffs[idx + 1:]
]))

In [None]:
# Select only Pfam families with depth >= 100 and length >=30
# NOTE: this rebinds `pfam_families_selec`, which an earlier cell also defines
pfam_families_selec = ["PF00004", "PF00153", "PF00271", "PF00397", "PF00512", "PF01535", "PF02518"]
# Boolean mask over `pfam_families`, True for the selected families
mask = np.isin(pfam_families, pfam_families_selec)
# Keep the selected rows, then the selected columns
pearsons_coeffs_selec = pearsons_coeffs[mask][:, mask]

In [None]:
# Heat map of the correlations between the coefficients of the selected families
labels = [family for family, selected in zip(pfam_families, mask) if selected]
n_labels = len(pearsons_coeffs_selec)
tick_positions = np.arange(n_labels)

fig, ax = plt.subplots(figsize=(15, 12), constrained_layout=True)

# Only the lower triangle is shown; np.tril zeroes everything above the diagonal
lower_triangle = np.tril(pearsons_coeffs_selec)
im = ax.imshow(lower_triangle, cmap="Blues", aspect="equal", vmin=0, vmax=1)

ax.set_yticks(tick_positions, labels=labels)
ax.set_xticks(tick_positions, labels=labels, rotation=45, ha="right")
fig.colorbar(im, ax=ax, fraction=0.05, pad=0.04, label="Pearson correlation")

plt.show()

## Fit one common logistic model across MSAs

In [None]:
# The first 12 families are used to fit the common model; the remaining ones
# are not seen during fitting.
pfam_families_train = pfam_families[:12]

In [None]:
def perform_regressions_msawise(pfam_families_train):
    """Fit one common binomial GLM on several MSAs and predict distances for all of them.

    All upper-triangular (diagonal included) sequence pairs of the families in
    ``pfam_families_train`` are pooled into one training set. The fitted model
    is then applied to every family in the module-level ``pfam_families``.

    NOTE(review): this definition reuses the name of the per-MSA function
    defined earlier in the file and therefore replaces it once this cell has run.

    Args:
        pfam_families_train: iterable of Pfam accessions used for fitting.

    Returns:
        dict mapping each family in ``pfam_families`` to a dict with keys
        ``"bias"`` (intercept), ``"coeffs"`` (``(n_layers, n_heads)`` array,
        identical for every family), ``"y"`` (true distances), ``"y_pred"``
        (predicted distances) and ``"depth"``.

    Raises:
        ValueError: if ``pfam_families_train`` is empty.
    """
    # The head index must run over n_heads (the original training frame used n_layers)
    feature_names = [f"lyr{i}_hd{j}" for i in range(n_layers) for j in range(n_heads)]
    formula = "dist ~ " + " + ".join(feature_names)

    def _load_pairs(pfam_family):
        """Return (data frame of pair features plus "dist", distance vector, depth)."""
        dists = np.load(DISTS_FOLDER / f"{pfam_family}_seed.npy")
        attns = np.load(ATTNS_FOLDER / f"{pfam_family}_seed_mean-on-cols_symm.npy")
        depth = len(dists)

        rows, cols = np.triu_indices(attns.shape[-1])
        # (layers, heads, pairs) -> (pairs, layers * heads)
        features = attns[..., rows, cols].transpose(2, 0, 1).reshape(-1, n_layers * n_heads)
        pair_dists = dists[rows, cols]

        df = pd.DataFrame(features, columns=feature_names)
        df["dist"] = pair_dists
        return df, pair_dists, depth

    # Collect the per-family frames and concatenate once, rather than growing a
    # data frame inside the loop
    train_frames = [_load_pairs(pfam_family)[0] for pfam_family in pfam_families_train]
    if not train_frames:
        raise ValueError("pfam_families_train must contain at least one family")
    df_train = pd.concat(train_frames, ignore_index=True)

    # Carve out the response and design matrices from the pooled training frame
    y_train, X_train = dmatrices(formula, df_train, return_type="dataframe")

    # Fit the model
    binom_model = sm.GLM(y_train, X_train, family=sm.families.Binomial())
    binom_model_results = binom_model.fit(maxiter=200, tol=1e-9)

    params = binom_model_results.params
    # .iloc: the parameters are labelled by name, so positional access must be explicit
    bias = params.iloc[0]
    coeffs = params.to_numpy()[-n_layers * n_heads:].reshape(n_layers, n_heads)

    regr_results_hamming_common = {}
    for pfam_family in pfam_families:
        df, dists, depth = _load_pairs(pfam_family)
        _, X = dmatrices(formula, df, return_type="dataframe")

        y_pred = binom_model_results.predict(X).to_numpy()

        regr_results_hamming_common[pfam_family] = {
            "bias": bias,
            "coeffs": coeffs,
            "y": dists,
            "y_pred": y_pred,
            "depth": depth,
        }

    return regr_results_hamming_common

In [None]:
# Fit the common model on the training families and collect its predictions
# for every family in `pfam_families`
regr_results_hamming_common = perform_regressions_msawise(pfam_families_train)

### Plot and analyse the results

In [None]:
# Figure: coefficients of the common model, their mean absolute value per
# layer, and true vs. predicted distances for the last three families.
fig, axs = plt.subplots(figsize=(43, 10),
                        nrows=1,
                        ncols=5,
                        gridspec_kw={"width_ratios": [10, 3, 10, 10, 10]},
                        constrained_layout=True)

# Every entry stores the coefficients of the same fitted model, so the first
# family is as good as any
coeffs = regr_results_hamming_common[pfam_families[0]]["coeffs"]
im = axs[0].imshow(coeffs, norm=colors.CenteredNorm(), cmap=cmap)
cbar = fig.colorbar(im, ax=axs[0], fraction=0.05, pad=0.03)
axs[0].set_xticks(x_vals_coeffs)
axs[0].set_yticks(y_vals_coeffs)
# Tick labels are 1-based
axs[0].set_xticklabels(list(map(str, x_vals_coeffs + 1)))
axs[0].set_yticklabels(list(map(str, y_vals_coeffs + 1)))

axs[1].plot(np.mean(np.abs(coeffs), axis=1),
            np.arange(n_layers),
            "-o",
            markersize=12,
            lw=5)
axs[1].invert_yaxis()
axs[1].set_yticks(y_vals_coeffs)
axs[1].set_yticklabels(list(map(str, y_vals_coeffs + 1)))
axs[1].set_xticks([0, 10, 20])

axs[0].set_title("Regression coefficients", pad=vpad)
# Raw string: "\ " is a LaTeX spacing command, not a Python escape sequence
axs[1].set_title(r"Avg.\ abs.\ coeff.", pad=vpad)

for i, pfam_family in enumerate(pfam_families[-3:]):
    y = regr_results_hamming_common[pfam_family]["y"]
    y_pred = regr_results_hamming_common[pfam_family]["y_pred"]
    n_rows = regr_results_hamming_common[pfam_family]["depth"]

    hamming_comparison = create_dist_comparison_mat(y, y_pred, n_rows)
    # Two overlaid images, each masked with NaN outside its own triangle:
    # true values above the diagonal in blue ...
    axs[2 + i].imshow(np.triu(hamming_comparison, k=1) + np.tril(np.full_like(hamming_comparison, fill_value=np.nan)),
                      cmap="Blues",
                      vmin=0,
                      vmax=1)
    # ... predictions on and below the diagonal in green
    axs[2 + i].imshow(np.tril(hamming_comparison) + np.triu(np.full_like(hamming_comparison, fill_value=np.nan), k=1),
                      cmap="Greens",
                      vmin=0,
                      vmax=1)

    axs[2 + i].set_xlabel("Sequence")

    axs[2 + i].set_title(pfam_family, pad=vpad)

axs[0].set_ylabel("Layer")
axs[0].set_xlabel("Head")
axs[1].set_ylabel("Layer")
axs[2].set_ylabel("Sequence")

plt.show()

In [None]:
# Metrics table for the common model plus, for every family, a scatter plot of
# actual vs. predicted distances with the least-squares line of actual on predicted.
df_regr_results_hamming_common = pd.DataFrame()

fig, axs = plt.subplots(nrows=3,
                        ncols=5,
                        figsize=(25, 15),
                        sharex=True,
                        sharey=True,
                        constrained_layout=True)

for i, pfam_family in enumerate(pfam_families):
    res = regr_results_hamming_common[pfam_family]
    ax = axs.flat[i]

    y, y_pred = res["y"], res["y_pred"]
    n_samples = len(y)
    y_std = np.std(y)
    pearson = pearsonr(y, y_pred)[0]
    # Slope and intercept of the least-squares line with predictions on the x axis
    slope = pearson * y_std / np.std(y_pred)
    intercept = np.mean(y) - slope * np.mean(y_pred)

    # Insertion order of this dict fixes the column order of the table
    summary = {
        "Depth": res["depth"],
        "RMSE": np.linalg.norm(y - y_pred) / np.sqrt(n_samples),
        "Std": y_std,
        "Pearson": pearson,
        "Slope": slope,
        "R^2": 1 - np.sum((y - y_pred)**2) / np.sum((y - np.mean(y))**2),
    }
    for column, value in summary.items():
        df_regr_results_hamming_common.loc[pfam_family, column] = value

    ax.set_title(pfam_family, fontsize=30)
    ax.scatter(y_pred, y, s=1)
    ax.axline((0, intercept), slope=slope, linewidth=2, color='r')

    for tick_labels in (ax.get_yticklabels(), ax.get_xticklabels()):
        plt.setp(tick_labels, fontsize=20)

fig.supxlabel("Predicted", fontsize=40)
fig.supylabel("Actual", fontsize=40)

plt.show()

In [None]:
# Display the metrics table ordered by MSA depth (returns a sorted copy)
df_regr_results_hamming_common.sort_values("Depth")

#### Training + test MSAs

In [None]:
# Grid of true vs. predicted distance matrices for every family
fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(25, 15), constrained_layout=True)

for i, pfam_family in enumerate(pfam_families):
    res = regr_results_hamming_common[pfam_family]
    ax = axs.flat[i]
    y, y_pred, n_rows = res["y"], res["y_pred"], res["depth"]

    hamming_comparison = create_dist_comparison_mat(y, y_pred, n_rows)
    nan_mat = np.full_like(hamming_comparison, fill_value=np.nan)

    # True values: strictly above the diagonal, NaN elsewhere
    upper_only = np.triu(hamming_comparison, k=1) + np.tril(nan_mat)
    # Predictions: on and below the diagonal, NaN elsewhere
    lower_only = np.tril(hamming_comparison) + np.triu(nan_mat, k=1)

    ax.imshow(upper_only, cmap="Blues", vmin=0, vmax=1)
    ax.imshow(lower_only, cmap="Greens", vmin=0, vmax=1)

    ax.set_title(pfam_family, fontsize=30)

    for tick_labels in (ax.get_yticklabels(), ax.get_xticklabels()):
        plt.setp(tick_labels, fontsize=20)

plt.show()