## Abstract

Despite the header being `## Abstract`, this section will render as a highlighted section titled *Summary*. Ensure this section is a **maximum** of 280 characters.

----

:::{.callout-note title="AI usage disclosure" collapse="true"}
This is a placeholder for the AI usage disclosure. Once all authors sign the AI code form on Airtable, SlackBot will message you an AI disclosure that you should place here.
:::

## Purpose

Once edited by you, this file will become your publication. Alternatively, if you already have a notebook written that you're trying to transform into a pub, replace this file with your own, but be sure to add the YAML front matter (the first cell) to your notebook.

Your pub should begin with a section titled **Purpose** where you, as briefly as possible, explain why you did the work described in the pub, the key takeaway, your primary audience, and how you think it could be useful to them/why you're sharing it.

## Introduction

The MSA Pairformer achieves state-of-the-art performance in protein structure and function prediction, in large part due to its innovative query-biased attention mechanism [@Akiyama2025]. The authors put forward a compelling hypothesis for this success: by learning to weight sequences based on their evolutionary relevance to a query, the model mitigates phylogenetic averaging and amplifies faint, subfamily-specific signals.

This interpretation is both powerful and likely correct. The model's impressive performance, especially in challenging cases with diverse subfamilies, strongly suggests it captures more than just superficial sequence identity. The natural next step, therefore, is to move from hypothesis to characterization. How precisely does the model's attention mechanism map onto the underlying phylogenetic structure of an MSA? In what regimes does it excel at recapitulating evolutionary history, and where are its limits?

In this work, we explore these questions directly. We begin by revisiting the paper's original case study, the response regulator family, in detail, and then expand our analysis to thousands of MSAs across the tree of life. Our goal is not simply to validate the authors' claim but to map out where it holds: to quantify the relationship between learned sequence weights and tree-based phylogenetic distances, providing a clearer picture of the model's remarkable capabilities.

## Revisiting the response regulator study

The response regulator (RR) family was the original paper's showcase for demonstrating the power of query-biased attention. By constructing a mixed MSA of GerE, LytTR, and OmpR subfamilies, the authors showed their model could successfully identify key structural contacts unique to each lineage, a task where standard methods failed.

They illustrated the mechanism behind this success by plotting learned sequence weights against Hamming distance, revealing that members of the query's subfamily were consistently upweighted.

![Figure 4B from @Akiyama2025. Original caption: "*Median sequence weight across the layers of the model versus Hamming distance to the query sequence. Top panels show distribution of sequence attention weights for subfamily members (red) and non-subfamily sequences (grey). Grey dotted line indicates weights used for uniform sequence attention and red dotted line indicates weight assigned to the query sequence.*"](assets/figure4b.jpg){fig-align="center" width=100% fig-alt="Figure 4B from @Akiyama2025 showing the relationship between median sequence weight and Hamming distance to the query."}

This provides strong evidence for the model's ability to group similar sequences. To build on this finding, we can refine the analysis by replacing the proxy of Hamming distance with a formal phylogenetic tree. This allows us to ask a more nuanced question: Do the learned weights reflect the continuous distances and specific branching patterns of evolutionary history, or do they simply create a binary distinction between "in-group" and "out-group"?
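
For reference, a Hamming-style distance between two aligned sequences is usually the fraction of alignment columns at which they differ. The sketch below uses that common definition, assuming equal-length aligned strings and comparing gap characters like any other symbol; `normalized_hamming` is a hypothetical helper, and @Akiyama2025 may treat gaps differently.

```python
def normalized_hamming(a: str, b: str) -> float:
    """Fraction of aligned positions at which two sequences differ.

    Assumes both sequences come from the same alignment (equal length).
    Gap characters ('-') are compared like any other symbol.
    """
    if len(a) != len(b):
        raise ValueError("Aligned sequences must have equal length")
    if not a:
        return 0.0
    mismatches = sum(x != y for x, y in zip(a, b))
    return mismatches / len(a)


# Toy aligned sequences (hypothetical).
query = "MKV-LIT"
targets = {"seqA": "MKV-LIT", "seqB": "MRV-LVT", "seqC": "QRAGLVS"}
distances = {name: normalized_hamming(query, seq) for name, seq in targets.items()}
```

Because this proxy only counts mismatches, two sequences at the same Hamming distance from the query can sit on very different branches of the tree, which is the gap a phylogenetic distance is meant to close.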

To ground our phylogenetic analysis in the paper's original findings, we must first replicate their specific MSA. We begin by downloading the full Pfam alignments for the GerE, LytTR, and OmpR subfamilies, combining them, and then sampling a final set of 4096 sequences to match the dataset used in the study.
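
The sampling step can be sketched as a seeded random draw that always keeps the three query sequences. `subsample_ids` is a hypothetical helper, not the `download_and_process_response_regulator_msa` function used below, and it assumes sequence IDs in the combined alignment are unique.

```python
import random


def subsample_ids(
    all_ids: list[str],
    n: int,
    force_include: list[str],
    seed: int = 42,
) -> list[str]:
    """Draw `n` sequence IDs at random, always keeping those in `force_include`."""
    forced = list(dict.fromkeys(force_include))  # de-duplicate, keep order
    id_set = set(all_ids)
    missing = [i for i in forced if i not in id_set]
    if missing:
        raise ValueError(f"IDs not found in the alignment: {missing}")
    if n < len(forced):
        raise ValueError("n must be at least the number of forced IDs")
    forced_set = set(forced)
    pool = [i for i in all_ids if i not in forced_set]
    rng = random.Random(seed)  # seeded so the subset is reproducible
    sampled = rng.sample(pool, min(n - len(forced), len(pool)))
    return forced + sampled


# Toy example: 10,000 placeholder IDs plus the three query structures.
toy_ids = [f"seq{i}" for i in range(10_000)] + ["1NXS", "4CBV", "4E7P"]
subset = subsample_ids(toy_ids, n=4096, force_include=["1NXS", "4CBV", "4E7P"])
```

Fixing the seed matters here: without it, every rerun would produce a different MSA and therefore a different tree and set of sequence weights.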

In [None]:
# | echo: false
# | output: false
%env USE_MODAL=1
%load_ext autoreload
%autoreload 2

:::{.callout-note title="Reproducing the response regulator MSAs"}

@Akiyama2025 qualitatively describe how to reproduce the response regulator MSAs; however, these details are insufficient for exact replication. The code below is our attempted reproduction, and we find that these MSAs yield similar, yet not identical, sequence weight statistics.

In [None]:
# | code-fold: true
# | code-summary: Response regulator MSA code

from collections import Counter
from pathlib import Path

import pandas as pd
from analysis.pfam import download_and_process_response_regulator_msa

from MSA_Pairformer.dataset import MSA

response_regulator_dir = Path("./data/response_regulators")
response_regulator_dir.mkdir(parents=True, exist_ok=True)

rr_msas: dict[str, MSA] = {}

rr_queries = {"1NXS": "OmpR", "4CBV": "LytTR", "4E7P": "GerE"}
for query in rr_queries:
    msa_path = response_regulator_dir / f"PF00072.final_{query}.a3m"

    if not msa_path.exists():
        download_and_process_response_regulator_msa(
            output_dir=response_regulator_dir,
            subset_size=4096,
        )

    rr_msas[query] = MSA(msa_file_path=msa_path, diverse_select_method="none")

example_msa = rr_msas[query]
membership_path = response_regulator_dir / "membership.txt"
target_to_subfamily = (
    pd.read_csv(membership_path, sep="\t").set_index("record_id")["subfamily"].to_dict()
)

print(f"MSA has {len(example_msa.ids_l)} sequences:")

subfamily_member_count = Counter()
for sequence in example_msa.ids_l:
    subfamily_member_count[target_to_subfamily[sequence]] += 1

for subfamily, count in subfamily_member_count.items():
    print(f"  - {subfamily} sequences: {count}")


:::

With the response regulator MSA reconstructed, we can now move beyond simple sequence comparisons and infer the phylogenetic relationships among its members. We use FastTree, a rapid method for approximating maximum-likelihood phylogenies suitable for trees of this size, to build a tree from our alignment [@Price2009].
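
The `run_fasttree` helper isn't shown. A minimal stand-in, assuming a `FastTree` executable on the `PATH` and a FASTA-formatted protein alignment, calls the command-line tool directly: FastTree reads the alignment named as its last argument and writes the Newick tree to standard output. The optional `-lg` flag (LG substitution model instead of the default JTT) is shown only as an example choice, not necessarily what `run_fasttree` uses.

```python
import shutil
import subprocess
from pathlib import Path


def fasttree_command(
    alignment: Path, executable: str = "FastTree", use_lg: bool = False
) -> list[str]:
    """Build a FastTree command line for a protein alignment."""
    cmd = [executable, "-quiet"]  # -quiet suppresses progress reporting
    if use_lg:
        cmd.append("-lg")  # LG model instead of the default JTT
    cmd.append(str(alignment))
    return cmd


def run_fasttree_sketch(alignment: Path, output: Path, executable: str = "FastTree") -> None:
    """Run FastTree and write its Newick output (stdout) to `output`."""
    if shutil.which(executable) is None:
        raise FileNotFoundError(f"{executable} not found on PATH")
    with open(output, "w") as handle:
        subprocess.run(fasttree_command(alignment, executable), stdout=handle, check=True)
```

Depending on how FastTree was installed, the executable may be named `fasttree` rather than `FastTree`, hence the `executable` parameter.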

In [None]:
# | code-fold: true
# | code-summary: Tree inference code

from analysis.tree import read_newick, run_fasttree

fasttree_path = response_regulator_dir / "PF00072.final.fasttree.newick"
msa_for_tree = response_regulator_dir / "PF00072.final.fasta"
if not fasttree_path.exists():
    # Build the tree from the combined alignment defined above; `msa_path` is a
    # leftover loop variable from the previous cell.
    run_fasttree(msa_for_tree, fasttree_path)

rr_tree = read_newick(fasttree_path)


Visualizing this tree provides our first direct look at the evolutionary structure of the data. We expect to see distinct, well-supported clades corresponding to the three subfamilies. This visual confirmation is a critical first step before we begin to quantify the relationship between the tree's structure and the model's learned attention weights.

In [None]:
# | code-fold: true
# | label: fig-rr-tree
# | fig-cap: A randomly selected subset of the response regulator family illustrating the tree structure. The tree was calculated with FastTree. The RCSB PDB structure ID is labelled for the query sequence of each subfamily. Leaves are colored according to subfamily ([⬤]{style="color:#5088C5"} [GerE](https://www.ebi.ac.uk/interpro/entry/pfam/PF00196/), [⬤]{style="color:#F28360"} [OmpR](https://www.ebi.ac.uk/interpro/entry/pfam/PF00486/), [⬤]{style="color:#3B9886"} [LytTR](https://www.ebi.ac.uk/interpro/entry/pfam/PF04397/)).

import arcadia_pycolor as apc
from analysis.plotting import tree_style_with_categorical_annotation
from analysis.tree import read_newick, subset_tree

query_colors = {
    "4E7P": apc.aegean,
    "1NXS": apc.amber,
    "4CBV": apc.seaweed,
}
subfamily_colors = {rr_queries[k]: v for k, v in query_colors.items()}

tree_style = tree_style_with_categorical_annotation(
    categories=target_to_subfamily,
    highlight=list(rr_queries),
    color_map=subfamily_colors,
)
visualized_tree = subset_tree(tree=rr_tree, n=100, force_include=list(rr_queries), seed=42)
visualized_tree.render("%%inline", tree_style=tree_style, dpi=300)

Our phylogenetic tree provides the evolutionary "ground truth." Now, we need to generate the model's data to compare against it. We are specifically interested in the sequence attention weights, which, according to the paper's central hypothesis, are the mechanism by which the model captures this evolutionary information.

To probe the model's query-biased behavior, let's run MSA Pairformer inference on our MSA three separate times, setting the representative of a different subfamily (GerE, LytTR, or OmpR) as the query in each run. This yields three distinct sets of attention weights, each representing the model's perspective from a different starting point in the evolutionary landscape.

:::{.callout-note title="Running MSA Pairformer..."}
To reproduce this step of the workflow, you'll need a GPU with at least 40 GB of VRAM. Since we can't assume what hardware you have available, we've stored pre-computed inference results (`data/response_regulators/inference_results.pt`), and by default the code below loads these results rather than re-computing this expensive operation. If you have the necessary hardware and want to re-compute the inference, delete this file and the cell below will regenerate it using your GPU.

In [None]:
# | code-fold: true
# | code-summary: MSA inference code

from typing import Any

import torch
from analysis.pairformer import run_inference

inference_results_path = response_regulator_dir / "inference_results.pt"
if inference_results_path.exists():
    rr_inference_results = torch.load(inference_results_path, weights_only=True)
else:
    rr_inference_results: dict[str, dict[str, Any]] = {}
    for query in rr_queries:
        rr_inference_results[query] = run_inference(
            rr_msas[query], return_seq_weights=True, query_only=True
        )

    torch.save(rr_inference_results, inference_results_path)

:::

We now have our two key components: a **phylogenetic tree** for the RR family (our evolutionary "ground truth") and the model's **attention weights** relative to each of the three subfamily queries. Before diving into a formal statistical analysis, let's build an intuition for how the model's attention relates to tree structure by visualizing weights directly onto the tree.

For each of the three queries, let's center our view on a small subset of the full MSA (for ease of visualization) and color each leaf by the median attention weight it received from the model. If the model is indeed capturing evolutionary relevance, we should expect a gradient of attention that follows the tree's branches away from the query.
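
The caption of @fig-tree-weight-paint describes drawing leaves with probability inversely proportional to their rank distance from the query, raised to the power 0.8. A minimal sketch of that scheme with NumPy follows; `rank_biased_sample` is hypothetical, and the actual `subset_tree_around_reference` helper may differ in tie-breaking and in how it handles the reference itself.

```python
import numpy as np


def rank_biased_sample(
    distances: dict[str, float],
    n: int,
    bias_power: float = 0.8,
    seed: int = 42,
) -> list[str]:
    """Sample up to `n` leaves without replacement, favoring those near the reference.

    Leaves are ranked by distance to the reference (rank 1 = closest) and drawn
    with probability proportional to 1 / rank**bias_power.
    """
    names = sorted(distances, key=distances.get)
    ranks = np.arange(1, len(names) + 1, dtype=float)
    probs = 1.0 / ranks**bias_power
    probs /= probs.sum()
    rng = np.random.default_rng(seed)
    chosen = rng.choice(len(names), size=min(n, len(names)), replace=False, p=probs)
    return [names[i] for i in chosen]


# Toy example: 1,000 leaves whose distance to the reference equals their index.
toy_distances = {f"leaf{i}": float(i) for i in range(1000)}
sample = rank_biased_sample(toy_distances, n=100)
```

Setting `bias_power=0` recovers uniform sampling; larger values concentrate the subset more tightly around the reference.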

In [None]:
# | code-fold: true
# | label: fig-tree-weight-paint
# | fig-cap: Trees for each query, where each leaf is colored according to the median sequence weight it received from the model. Darker nodes signify sequences receiving high levels of attention (upweighting), while lighter nodes signify sequences receiving low levels of attention (downweighting). Each tree is subset to 100 sequences sampled from the full tree (includes all subfamilies). To better visualize a gradient, sequences were sampled with a probability inversely proportional to their phylogenetic rank distance from the query, raised to a power of 0.8, which overall gives slight preference for selecting sequences phylogenetically similar to the query.
# | fig-subcap:
# |   - "Median attention weights with respect to 1NXS (OmpR)."
# |   - "Median attention weights with respect to 4CBV (LytTR)."
# |   - "Median attention weights with respect to 4E7P (GerE)."
# | layout-ncol: 1

from analysis.data import get_sequence_weight_data
from analysis.plotting import tree_style_with_scalar_annotation
from analysis.tree import (
    sort_tree_by_reference,
    subset_tree_around_reference,
)
from IPython.display import display

rr_data_dict = dict(
    query=[],
    target=[],
    median_weight=[],
)

for query in rr_queries:
    msa = rr_msas[query]
    targets = msa.ids_l

    weights = get_sequence_weight_data(rr_inference_results[query])

    # For each layer, sequence weights sum to 1. Scaling by number of
    # sequences yields a scale where 1 implies uniform weighting.
    weights *= weights.size(0)

    median_weights = torch.median(weights, dim=1).values

    rr_data_dict["query"].extend([query] * len(targets))
    rr_data_dict["target"].extend(targets)
    rr_data_dict["median_weight"].extend(median_weights.tolist())

response_regulator_df = pd.DataFrame(rr_data_dict)

tree_images = []
queries_list = response_regulator_df["query"].unique()
for query in queries_list:
    color = query_colors[query]
    specific_layer = "median_weight"
    specific_layer_weights = (
        response_regulator_df.loc[
            (response_regulator_df["query"] == query) & (response_regulator_df["query"] != response_regulator_df["target"]), [specific_layer, "target"]
        ]
        .set_index("target")[specific_layer]
        .to_dict()
    )

    gradient = apc.Gradient.from_dict(
        "gradient",
        {"1": "#EEEEEE", "2": "#EEEEEE", "3": color, "4": color},
        values=[0.0, 0.0, 0.75, 1.0],
    )
    tree_style = tree_style_with_scalar_annotation(
        specific_layer_weights, gradient, highlight=[query]
    )
    visualized_tree = sort_tree_by_reference(
        subset_tree_around_reference(tree=rr_tree, n=100, reference=query, bias_power=0.8, seed=42),
        query,
    )
    tree_images.append(visualized_tree.render("%%inline", tree_style=tree_style, dpi=300))

display(*tree_images)

Encouragingly, @fig-tree-weight-paint provides an intuitive picture: weights with respect to *1NXS* and *4CBV* visually correlate with distance from the query, suggesting the model's attention mechanism prioritizes evolutionary relatedness. However, the picture is less clear for *4E7P*, which invites a more rigorous, quantitative test.

To formalize this observation, let's extract the *patristic tree distance*[^1] between each of the three queries and every other sequence in the MSA, then compare this to the median sequence weight the model assigned to each sequence. As in @Akiyama2025, we'll normalize weights by the number of sequences, so a value of 1 represents the uniform weighting baseline and a value greater than 1 indicates upweighting. And on the suspicion that this median value might smooth over layer-specific complexity, let's also store the individual weights from each layer to leave room for a more granular, layer-by-layer analysis.
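
The normalization and median aggregation amount to two array operations. A minimal NumPy sketch on random toy weights, assuming an array of shape `(num_sequences, num_layers)` whose columns each sum to 1 (the layout implied by `weights.size(0)` and `weights[:, layer_idx]` in the code below):

```python
import numpy as np

rng = np.random.default_rng(0)
num_sequences, num_layers = 5, 22

# Hypothetical raw weights: each layer (column) is a distribution over sequences.
raw = rng.random((num_sequences, num_layers))
raw /= raw.sum(axis=0, keepdims=True)

# Multiply by N so that uniform weighting (1/N per sequence) maps to 1.0.
normalized = raw * num_sequences

# One summary weight per sequence, taken across layers.
median_weight = np.median(normalized, axis=1)

# Sanity check: a perfectly uniform layer maps to 1.0 for every sequence.
uniform = np.full(num_sequences, 1.0 / num_sequences) * num_sequences
```

One subtlety when comparing against the PyTorch code: with an even number of layers (22 here), `np.median` averages the two middle values, whereas `torch.median` returns the lower of the two.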

[^1]: Patristic tree distance is the distance between two members of a phylogenetic tree, calculated as the sum of branch lengths connecting them.
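
As a concrete illustration of this definition, patristic distance can be computed on a small tree stored as child-to-parent links with branch lengths: walk each leaf up to the root, find the first shared ancestor (the most recent common ancestor), and add the two path lengths. The toy tree and the `patristic_distance` function are hypothetical; the analysis itself uses `get_patristic_distance` on the FastTree output.

```python
Parent = dict[str, tuple[str, float]]  # child -> (parent, branch length)


def path_to_root(node: str, parent: Parent) -> list[tuple[str, float]]:
    """Ancestors of `node` (including itself) with their distance from `node`."""
    path = [(node, 0.0)]
    total = 0.0
    while node in parent:
        node, branch_length = parent[node]
        total += branch_length
        path.append((node, total))
    return path


def patristic_distance(a: str, b: str, parent: Parent) -> float:
    """Sum of branch lengths on the path between `a` and `b`."""
    dist_from_a = dict(path_to_root(a, parent))
    for ancestor, dist_from_b in path_to_root(b, parent):
        if ancestor in dist_from_a:  # first shared node is the MRCA
            return dist_from_a[ancestor] + dist_from_b
    raise ValueError(f"{a!r} and {b!r} are not in the same tree")


# Toy tree, in Newick: ((A:1,B:2)X:0.5,C:3)root;
toy_parent: Parent = {
    "A": ("X", 1.0),
    "B": ("X", 2.0),
    "X": ("root", 0.5),
    "C": ("root", 3.0),
}
```

For example, A and B meet at X (1 + 2 = 3), while A and C meet only at the root (1 + 0.5 + 3 = 4.5).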

In [None]:
# | code-fold: true

from analysis.data import get_sequence_weight_data
from analysis.tree import get_patristic_distance
from scipy.stats import linregress

rr_data_dict = dict(
    query=[],
    target_subfamily=[],
    target=[],
    patristic_distance=[],
    median_weight=[],
)

num_layers = 22
for layer_idx in range(num_layers):
    rr_data_dict[f"layer_{layer_idx}_weight"] = []

for query in rr_queries:
    msa = rr_msas[query]
    targets = msa.ids_l

    patristic_distances = get_patristic_distance(rr_tree, query)
    patristic_distances = patristic_distances[targets]

    weights = get_sequence_weight_data(rr_inference_results[query])

    # For each layer, sequence weights sum to 1. Scaling by number of
    # sequences yields a scale where 1 implies uniform weighting.
    weights *= weights.size(0)

    median_weights = torch.median(weights, dim=1).values

    for layer_idx in range(num_layers):
        rr_data_dict[f"layer_{layer_idx}_weight"].extend(weights[:, layer_idx].tolist())

    rr_data_dict["query"].extend([query] * len(targets))
    rr_data_dict["target_subfamily"].extend(
        [target_to_subfamily.get(target, "Unknown") for target in targets]
    )
    rr_data_dict["target"].extend(targets)
    rr_data_dict["median_weight"].extend(median_weights.tolist())
    rr_data_dict["patristic_distance"].extend(patristic_distances.tolist())

response_regulator_df = pd.DataFrame(rr_data_dict)
response_regulator_df = response_regulator_df.query("query != target").reset_index(drop=True)
response_regulator_df.head()

We can analyze the relationship between sequence weights and patristic distance with a simple linear regression. We'll frame the problem to directly assess the explanatory power of the model's sequence weights: how well can they explain the true evolutionary distance to the query?

For each query $q$ and each target sequence $i$ in the MSA, let's define our model as:

$$
d_{i} = \beta_1^{(l)} w_{i}^{(l)} + \beta_0^{(l)}
$$ {#eq-scalar-regression}

where:

- $d_{i}$ is the patristic distance from the query $q$ to the target sequence $i$.
- $w_{i}^{(l)}$ is the normalized sequence weight assigned to sequence $i$ by a specific layer $l$.
- $\beta_1^{(l)}$ and $\beta_0^{(l)}$ are the slope and intercept for the regression at layer $l$.

Let's perform this regression independently for each of the three queries. For each query, we'll calculate the fit using the median weight across all layers and also for each of the 22 layers individually.

We'll use the coefficient of determination ($R^2$) as the key statistic: it measures the proportion of the variance in patristic distance that is explained by the sequence weights. The following code calculates these regression statistics and generates an interactive plot to explore the relationships.

In [None]:
# | code-fold: true
# | label: fig-rr-interactive
# | fig-cap: An interactive display illustrating sequence weight versus patristic distance for each MSA member to the query. Each subplot represents the sequence weights relative to a different query. The dropdown controls which layer the sequence weights are from. By default, the median sequence weights across all layers are visualized. Black lines indicate the lines of best fit.

import arcadia_pycolor as apc
from analysis.plotting import interactive_layer_weight_plot

regression_data = dict(
    query=[],
    layer=[],
    r_squared=[],
    p_value=[],
    slope=[],
    intercept=[],
)

for query in queries_list:
    query_data = response_regulator_df[response_regulator_df["query"] == query]
    y = query_data["patristic_distance"].values
    x = query_data["median_weight"].values
    result = linregress(x, y)
    regression_data["query"].append(query)
    regression_data["layer"].append("median")
    regression_data["r_squared"].append(result.rvalue**2)
    regression_data["p_value"].append(result.pvalue)
    regression_data["slope"].append(result.slope)
    regression_data["intercept"].append(result.intercept)

    for layer_idx in range(num_layers):
        weight_col = f"layer_{layer_idx}_weight"
        x = query_data[weight_col].values
        result = linregress(x, y)
        regression_data["query"].append(query)
        regression_data["layer"].append(layer_idx)
        regression_data["r_squared"].append(result.rvalue**2)
        regression_data["p_value"].append(result.pvalue)
        regression_data["slope"].append(result.slope)
        regression_data["intercept"].append(result.intercept)

rr_regression_df = pd.DataFrame(regression_data)
rr_regression_df

apc.plotly.setup()
interactive_layer_weight_plot(response_regulator_df, rr_regression_df, rr_queries, subfamily_colors)

First, when viewing the median sequence weights (the default view in @fig-rr-interactive), we observe a strong negative correlation with patristic distance across all three subfamilies. This provides direct, quantitative support for the original paper's central claim: on average, the model learns to upweight evolutionarily closer sequences and downweight more distant ones.

However, the layer-by-layer analysis uncovers a more nuanced and specialized division of labor within the model. The strength, and even the direction, of this correlation varies dramatically with network depth:

* Strong phylogenetic filters: Some layers, such as layer 11, act as powerful phylogenetic filters. They exhibit a strong negative correlation ($R^2 > 0.6$), sharply penalizing sequences as their evolutionary distance from the query increases.
* Alternative or inverted signals: In stark contrast, other layers show weak or even positive correlations. Layer 12, for instance, behaves inconsistently across queries. When GerE is the query, it slightly upweights more distant sequences, suggesting it has learned a feature representation that is either independent of, or runs counter to, simple phylogenetic distance.

This indicates that while the model as a whole captures evolutionary relevance, this complex task is not distributed uniformly across layers. Instead, specific layers appear to specialize in learning the phylogenetic structure of the MSA, while others capture different kinds of sequence information.


## A survey across the tree of life

Our analysis shows that median sequence weights correlate moderately with phylogenetic distance. More intriguingly, this layer-by-layer view has given us a peek behind the curtain, revealing a complex division of labor in how the model's attention mechanism captures evolutionary relatedness through the query-biased outer product.

To understand how broadly these patterns hold, and to further characterize MSA Pairformer's understanding of evolutionary relationships, we test the extent to which this behavior generalizes by expanding our analysis to thousands of diverse protein families.

To do this, we turn to the OpenProteinSet [@Ahdritz2023], a massive public database of protein alignments. This resource, derived from UniClust30 and [hosted on AWS](https://registry.opendata.aws/openfold/), provides the scale we need to move beyond our single case study.

Inferring phylogenetic trees for all ~270,000 UniClust30 MSAs in the collection would require roughly 10 times the amount of patience most people possess. Furthermore, some of these MSAs are unsuitable for our analysis for one reason or another. To whittle the collection down to a more digestible size, we apply the following procedure.

:::{.callout-note title="MSA pre-processing workflow"}

First, randomly select 20,000 MSAs from the UniClust30 collection. Then, for each of the 20,000 MSAs, apply the following procedure:

- Diversity selection: Select a diverse subset of up to 1024 sequences from the MSA
    - Do this with the [MSA Pairformer API](https://github.com/yoakiyama/MSA_Pairformer/blob/33083d027788ee4a4295b554e782559b87e58fe5/MSA_Pairformer/dataset.py#L251-L276), which implements the sampling procedure introduced in MSA Transformer [@Rao2021]
- Filter undesirables: Discard the MSA if it fails any of the following checks:
    - Too shallow (fewer than 200 sequences) for downstream modelling (more on that later).
    - Too long (over 1024 residues), posing computational constraints.
    - Contains duplicate sequence identifiers.
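
These discard rules can be expressed as a single predicate. The hypothetical `passes_filters` below works on plain Python values (sequence IDs and alignment length) rather than the `MSA` object, uses the same thresholds as the code below (200 sequences, 1024 residues), and also includes the secondary check that simplified `tr|ACCESSION|NAME` identifiers remain unique.

```python
def passes_filters(
    seq_ids: list[str],
    alignment_length: int,
    min_sequences: int = 200,
    max_length: int = 1024,
) -> bool:
    """Return True if an MSA survives the pre-processing checks."""
    # Too shallow for downstream modelling.
    if len(seq_ids) < min_sequences:
        return False
    # Too long, given memory constraints.
    if alignment_length > max_length:
        return False
    # Duplicate identifiers would break tree construction.
    if len(set(seq_ids)) != len(seq_ids):
        return False
    # Simplified accessions (the field between the first two '|') must also be unique.
    simplified = [s.split("|")[1] if "|" in s else s for s in seq_ids]
    return len(set(simplified)) == len(simplified)


# Toy identifiers in UniProt-style tr|ACCESSION|NAME format (hypothetical).
ids = [f"tr|P{i:05d}|NAME_{i}" for i in range(250)]
```

The last check mirrors the consensus-sequence edge case noted in the code comments: `tr|P00000|P00000_consensus` and `tr|P00000|NAME_0` are distinct full IDs but collapse to the same accession.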

In [None]:
# | code-fold: true
# | output: false
import random

from analysis.open_protein_set import fetch_all_ids, fetch_msas
from analysis.sequence import write_processed_msa
from analysis.utils import progress

uniclust30_dir = Path("data") / "uniclust30"
uniclust30_dir.mkdir(parents=True, exist_ok=True)

msa_ids_path = uniclust30_dir / "ids"
msa_ids = fetch_all_ids(cache_file=msa_ids_path)

random.seed(42)
msa_ids_subset = random.sample(msa_ids, k=20000)

uniclust30_raw_msa_dir = uniclust30_dir / "raw_msas"
uniclust30_raw_msa_dir.mkdir(exist_ok=True)

raw_msa_paths = fetch_msas(msa_ids_subset, db_dir=uniclust30_raw_msa_dir)

max_seq_length = 1024
min_sequences = 200

uniclust30_msa_dir = uniclust30_dir / "msas"
uniclust30_msa_dir.mkdir(exist_ok=True)

skipped_file = uniclust30_dir / "skipped_ids"
if skipped_file.exists():
    skipped_set = set(skipped_file.read_text().strip().split("\n"))
else:
    skipped_set = set()

for id, raw_msa_path in progress(raw_msa_paths.items(), desc="Processing MSAs"):
    msa_path = uniclust30_msa_dir / f"{id}.a3m"

    if msa_path.exists():
        # Already processed on a previous run; don't record it as skipped.
        continue

    if id in skipped_set:
        continue

    msa = MSA(
        raw_msa_path,
        max_seqs=1024,
        max_length=max_seq_length + 1,
        diverse_select_method="hhfilter",
        secondary_filter_method="greedy",
    )

    # Skip MSAs containing duplicate deflines. This likely occurs when multi-domain proteins
    # generate multiple alignment hits. Duplicate names would cause tree construction to fail.
    deflines = [msa.ids_l[idx] for idx in msa.select_diverse_indices]
    if len(set(deflines)) != len(deflines):
        skipped_set.add(id)
        continue

    # We simplify verbose deflines from format tr|A0A1V5V6X5|LONG_SUFFIX to just A0A1V5V6X5.
    # In rare cases (~0.5% of MSAs), simplification creates duplicates when both a consensus
    # sequence (tr|ID|ID_consensus) and its non-consensus counterpart (tr|ID|ID_SPECIES)
    # are present in the alignment. Rather than handle this edge case, we skip these MSAs.
    simplified_deflines = [defline.split("|")[1] for defline in deflines]
    if len(set(simplified_deflines)) != len(simplified_deflines):
        skipped_set.add(id)
        continue

    # Skip MSAs exceeding maximum sequence length due to memory constraints
    if msa.select_diverse_msa.shape[1] > max_seq_length:
        skipped_set.add(id)
        continue

    # Skip MSAs with too few sequences to avoid overfitting when modeling
    # patristic distance with all 22 sequence weights.
    if msa.select_diverse_msa.shape[0] < min_sequences:
        skipped_set.add(id)
        continue

    # Write processed MSA to A3M format
    write_processed_msa(msa, msa_path, format="a3m", simplify_ids=True)

_ = skipped_file.write_text("\n".join(skipped_set) + "\n")

msas = {}
for msa_path in progress(sorted(uniclust30_msa_dir.glob("*.a3m")), desc="Loading MSAs"):
    msas[msa_path.stem] = MSA(msa_path, diverse_select_method="none")

print(f"Final MSA count: {len(msas)}")

:::

Just like we did for the response regulator case, let's calculate a tree and sequence weights for each MSA.

In [None]:
# | code-fold: true
# | code-summary: Calculating phylogenies
# | output: false
import asyncio
import os

from analysis.tree import run_fasttree_async

uniclust30_tree_dir = uniclust30_dir / "trees"
uniclust30_tree_dir.mkdir(exist_ok=True)

jobs = []
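# Leave one core free, but never drop below one concurrent job
# (os.cpu_count() can return None or 1).
semaphore = asyncio.Semaphore(max(1, (os.cpu_count() or 2) - 1))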
for a3m_path in uniclust30_msa_dir.glob("*.a3m"):
    fasttree_path = uniclust30_tree_dir / f"{a3m_path.stem}.fasttree.newick"
    log_path = uniclust30_tree_dir / f"{a3m_path.stem}.fasttree.log"
    if fasttree_path.exists():
        continue

    jobs.append(run_fasttree_async(a3m_path, fasttree_path, log_path, semaphore))

_ = await asyncio.gather(*jobs)


trees = {}
for tree_path in progress(
    sorted(uniclust30_tree_dir.glob("*.fasttree.newick")),
    desc="Loading trees",
):
    id = tree_path.name.split(".")[0]
    trees[id] = read_newick(tree_path)

In [None]:
# | code-fold: true
# | code-summary: Calculating sequence weights
from analysis.pairformer import calculate_sequence_weights

seq_weights_path = uniclust30_dir / "seq_weights.pt"

# When running on Modal, calculate_sequence_weights serializes MSAs to send to remote GPU workers.
# Serializing all 11k+ MSAs at once exceeds Modal's serialization limits, so we batch into groups
# of 1000. This constraint is specific to Modal's RPC layer. This batching choice is unnecessary
# yet harmless for local execution.
seq_weights = {}
batch_size = 1000

_msa_items = list(msas.items())

if seq_weights_path.exists():
    seq_weights = torch.load(seq_weights_path, weights_only=True)
else:
    for batch_start in progress(
        range(0, len(_msa_items), batch_size), desc="Calculating sequence weights"
    ):
        _batch_msas = dict(_msa_items[batch_start : batch_start + batch_size])
        _batch_weights = calculate_sequence_weights(_batch_msas)
        seq_weights.update(_batch_weights)

    torch.save(seq_weights, seq_weights_path)
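
The batching in the cell above is a generic chunking pattern. A standalone version, using a hypothetical `batched_dict` helper, makes the slicing explicit:

```python
from collections.abc import Iterator
from typing import TypeVar

K = TypeVar("K")
V = TypeVar("V")


def batched_dict(items: dict[K, V], batch_size: int) -> Iterator[dict[K, V]]:
    """Yield successive sub-dictionaries of at most `batch_size` entries, in insertion order."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    pairs = list(items.items())
    for start in range(0, len(pairs), batch_size):
        yield dict(pairs[start : start + batch_size])


toy_msas = {f"msa{i}": i for i in range(2500)}
batches = list(batched_dict(toy_msas, batch_size=1000))
```

On Python 3.12 and later, `itertools.batched(items.items(), n)` provides the same chunking as tuples of pairs.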

:::{.callout-note title="For those following at home" collapse="true"}
After running the above computations, you'll have the primary data in the following directories:

* MSAs: `data/uniclust30/msas`
* trees: `data/uniclust30/trees`
* sequence weights: `data/uniclust30/seq_weights.pt`

These data could be a prime launch point for follow-up studies.
:::

### Explanatory model across all layers

In our response regulator case study, we performed simple linear regressions for each layer separately. This was a powerful visualization tool, giving us a peek at the "division of labor" by showing how the explanatory power of each layer varies in isolation.

Now, with 11,000 MSAs processed, we can graduate to a more comprehensive question. Instead of asking how well *a single layer* predicts evolutionary distance, we can now ask:

**How well do all 22 layers, when used *jointly*, predict phylogenetic distance?**

This moves us from a set of simple regressions to a single multiple linear regression (per MSA). We must be clear about our goal: this is for **explanation, not prediction**. We are not building a generalizable model for held-out data. We're using this regression as a diagnostic tool to create a single, holistic score that quantifies the *in-sample explanatory power* of the model's complete set of weights.

To formalize this, we define a **weight vector** $\mathbf{w}_i$ for each sequence $i$, which stacks the normalized weights from all $L$ layers. Our model then finds the single coefficient vector $\boldsymbol{\beta}$ that best maps these weights to the patristic distance $d_i$:

$$
\mathbf{w}_i = \begin{bmatrix} w_i^{(1)} \\ w_i^{(2)} \\ \vdots \\ w_i^{(L)} \end{bmatrix} \quad \text{and} \quad d_i = \boldsymbol{\beta}^T \mathbf{w}_i + \beta_0
$$
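
To make the joint fit concrete, here is a NumPy sketch that fits $\boldsymbol{\beta}$ and $\beta_0$ by ordinary least squares and reports the in-sample $R^2$. The data are synthetic and noise-free (shapes chosen to mirror one MSA: $N = 200$ sequences by $L = 22$ layers), so the fit should recover the true coefficients; `fit_joint_regression` is hypothetical, and the analysis itself uses the `regress_and_analyze_features` helper, whose internals aren't shown.

```python
import numpy as np


def fit_joint_regression(
    weights: np.ndarray, distances: np.ndarray
) -> tuple[np.ndarray, float, float]:
    """Ordinary least squares of patristic distance on all layer weights.

    weights: (N, L) array; row i is the weight vector w_i.
    distances: (N,) array of patristic distances d_i.
    Returns (beta, beta_0, in-sample R^2).
    """
    n = weights.shape[0]
    design = np.column_stack([weights, np.ones(n)])  # last column is the intercept
    coef, *_ = np.linalg.lstsq(design, distances, rcond=None)
    fitted = design @ coef
    ss_res = float(np.sum((distances - fitted) ** 2))
    ss_tot = float(np.sum((distances - distances.mean()) ** 2))
    return coef[:-1], float(coef[-1]), 1.0 - ss_res / ss_tot


# Noise-free synthetic data: distances are an exact linear function of the weights.
rng = np.random.default_rng(0)
toy_weights = rng.random((200, 22))
true_beta = rng.normal(size=22)
toy_distances = toy_weights @ true_beta + 0.7
beta, beta_0, r_squared = fit_joint_regression(toy_weights, toy_distances)
```

With real sequence weights the $R^2$ will of course be well below 1; the point here is only the mechanics of stacking all layers into one design matrix.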

Because we are using 22 predictors ($k=22$) and some of our MSAs contain relatively few sequences ($N$), a standard $R^2$ value could be misleadingly high. We therefore use the **Adjusted $R^2$ ($R^2_{adj}$)**, which penalizes the score for each added predictor, providing a more conservative and honest measure of explanatory power.
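
For reference, the standard definition for $N$ sequences and $k$ predictors (plus an intercept) is:

$$
R^2_{adj} = 1 - \left(1 - R^2\right)\frac{N - 1}{N - k - 1}
$$

As a hypothetical worked example at the minimum depth we allow ($N = 200$, $k = 22$), an in-sample $R^2$ of 0.5 becomes $R^2_{adj} = 1 - 0.5 \times 199/177 \approx 0.438$. If `regress_and_analyze_features` returns a statsmodels OLS result, as the `model.rsquared_adj` attribute below suggests, its `rsquared_adj` follows this definition for models that include a constant.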

This $R^2_{adj}$ score is a **comparative metric**. It's not an absolute measure of performance, but a tool that allows us to finally explore how the model's phylogenetic awareness correlates with MSA properties, like depth, or how much it varies *at* a given depth.

In [None]:
# | code-fold: true
# | code-summary: Summarizing the data into a main table
# | output: false
import numpy as np
import pandas as pd
from analysis.regression import regress_and_analyze_features
from joblib import Parallel, delayed
from scipy.stats import linregress


def process_msa(query: str, dist_to_query: np.ndarray, weights: torch.Tensor) -> dict[str, Any]:
    data = {}

    # Regress the sequence weights against patristic distance.
    # Perform an ANOVA (type III) to establish explanatory importance
    # of each layer's sequence weights.
    model, anova_table = regress_and_analyze_features(weights, dist_to_query)

    data["Query"] = query
    data["Size"] = len(dist_to_query)
    data["R2"] = model.rsquared
    data["Adjusted R2"] = model.rsquared_adj
    data.update(anova_table["percent_sum_sq"].to_dict())

    return data


jobs = []
for query in seq_weights.keys():
    msa = msas[query]
    tree = trees[query]
    weights = seq_weights[query]

    size = len(tree.get_leaf_names())
    if size < 200:
        continue

    dist_to_query = get_patristic_distance(tree, query)[msa.ids_l].values
    jobs.append(delayed(process_msa)(query, dist_to_query, weights))

results_df = pd.DataFrame(Parallel(-1)(jobs))
results_df
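The per-layer importance above comes from `regress_and_analyze_features`, which is not shown in this pub. As a rough, NumPy-only sketch of what a type III ANOVA measures here: for purely additive continuous predictors (no interaction terms), a predictor's type III sum of squares equals the increase in residual sum of squares when that predictor alone is dropped from the full model. Expressing these as a percentage of the total sum of squares is our assumption; the notebook's `percent_sum_sq` column may be normalized differently. The function names are hypothetical.

```python
import numpy as np


def _rss(design: np.ndarray, y: np.ndarray) -> float:
    """Residual sum of squares of an OLS fit of y on design."""
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    resid = y - design @ coef
    return float(resid @ resid)


def percent_partial_sum_sq(weights: np.ndarray, dist: np.ndarray) -> np.ndarray:
    """Percent of the total sum of squares uniquely attributable to each layer.

    For each layer j, refit without column j; the increase in RSS is that layer's
    partial (type III, for additive continuous predictors) sum of squares.
    """
    n, n_layers = weights.shape
    full = np.column_stack([np.ones(n), weights])
    rss_full = _rss(full, dist)
    tss = float(((dist - dist.mean()) ** 2).sum())

    percent = np.empty(n_layers)
    for j in range(n_layers):
        # Column j + 1 of `full` holds layer j (column 0 is the intercept).
        reduced = np.delete(full, j + 1, axis=1)
        percent[j] = 100.0 * (_rss(reduced, dist) - rss_full) / tss
    return percent


# Synthetic check: only layer 2 carries signal, so it should dominate.
rng = np.random.default_rng(1)
W = rng.random((500, 5))
d = 3.0 * W[:, 2] + rng.normal(scale=0.1, size=500)
pct = percent_partial_sum_sq(W, d)
```

With correlated layers, the partial sums of squares need not add up to the model's total explained variance, since variance shared between layers is attributed to none of them.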

In [None]:
from analysis.plotting import ridgeline_r2_plot

ridgeline_r2_plot(results_df, gradient=apc.gradients.verde.reverse())

In [None]:
from analysis.plotting import stacked_feature_importance_plot

stacked_feature_importance_plot(results_df, gradient=apc.gradients.verde.reverse())

(describe, describe, etc)

* MSA depth (i.e., tree size) does not explain much of the variance in performance.
    * This is likely because the MSA Pairformer training dataset spans a broad range of MSA depths.
* Tree size is only a zeroth-order descriptor of tree topology.
* Among trees of a given size, the topology can still vary considerably.
* Papers such as [@Janzen2024] characterize tree topology using dozens of metrics.
    * They reach the important conclusion that tree statistics are not easily normalized by tree size.
        * This justifies characterizing MSA Pairformer with a tree-size-normalized analysis.
        * We'll downsample each tree with at least 200 leaves (and its MSA) to exactly 200 sequences, always keeping the query.
    * They propose three statistics for joint use: the Colless index, the cherry count, and phylogenetic diversity.
    * These are global tree measures, but where the query sits within the topology is equally important.
        * We'll therefore add a set of measures based on the distribution of patristic distances to the query (e.g., its moments).
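The notebook computes these three statistics with `colless_statistic`, `cherry_count_statistic`, and `phylogenetic_diversity_statistic` from `analysis.tree`, which are not reproduced here. The sketch below uses textbook definitions on a hypothetical minimal `Node` class instead of ete3: the Colless index sums $|n_L - n_R|$ (the leaf counts of the two child subtrees) over internal nodes, the cherry count is the number of internal nodes whose two children are both leaves, and phylogenetic diversity is the total branch length. It assumes a rooted, strictly bifurcating tree. FastTree outputs unrooted trees, so how the notebook's versions handle rooting is not shown.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Minimal rooted tree node (hypothetical stand-in for an ete3 TreeNode)."""

    dist: float = 0.0  # branch length to the parent
    children: list[Node] = field(default_factory=list)

    def is_leaf(self) -> bool:
        return not self.children


def n_leaves(node: Node) -> int:
    return 1 if node.is_leaf() else sum(n_leaves(c) for c in node.children)


def colless(node: Node) -> int:
    """Sum of |left leaves - right leaves| over all internal nodes (bifurcating only)."""
    if node.is_leaf():
        return 0
    if len(node.children) != 2:
        raise ValueError("Colless index is defined here for bifurcating nodes only")
    left, right = node.children
    return abs(n_leaves(left) - n_leaves(right)) + colless(left) + colless(right)


def cherry_count(node: Node) -> int:
    """Number of internal nodes whose children are exactly two leaves."""
    if node.is_leaf():
        return 0
    own = int(len(node.children) == 2 and all(c.is_leaf() for c in node.children))
    return own + sum(cherry_count(c) for c in node.children)


def phylogenetic_diversity(node: Node) -> float:
    """Total branch length below the root (the root's own dist is ignored)."""
    return sum(c.dist + phylogenetic_diversity(c) for c in node.children)


def leaf() -> Node:
    return Node(dist=1.0)


# Caterpillar (((A,B),C),D) vs. balanced ((A,B),(C,D)), all branches of length 1.
caterpillar = Node(children=[Node(1.0, [Node(1.0, [leaf(), leaf()]), leaf()]), leaf()])
balanced = Node(children=[Node(1.0, [leaf(), leaf()]), Node(1.0, [leaf(), leaf()])])
```

Both four-leaf trees have the same size and the same phylogenetic diversity (6.0), yet the caterpillar has a Colless index of 3 and one cherry while the balanced tree has a Colless index of 0 and two cherries, which illustrates why size alone under-describes topology.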

In [None]:
from analysis.sequence import filter_msa_by_tree, write_fasta_like
from analysis.tree import write_newick

uniclust30_msa_depth_200_dir = uniclust30_dir / "msas_depth_200"
uniclust30_msa_depth_200_dir.mkdir(parents=True, exist_ok=True)

uniclust30_trees_depth_200_dir = uniclust30_dir / "trees_depth_200"
uniclust30_trees_depth_200_dir.mkdir(parents=True, exist_ok=True)

msas_depth_200 = {}
trees_depth_200 = {}

for query, msa in progress(msas.items()):
    tree = trees[query]
    if len(tree.get_leaf_names()) < 200:
        continue

    msa_depth_200_path = uniclust30_msa_depth_200_dir / f"{query}.a3m"
    tree_depth_200_path = uniclust30_trees_depth_200_dir / f"{query}.fasttree.newick"

    if not msa_depth_200_path.exists() or not tree_depth_200_path.exists():
        tree_subset = subset_tree(tree, n=200, force_include=[query], seed=42)
        write_fasta_like(*filter_msa_by_tree(msa, tree_subset), msa_depth_200_path)
        write_newick(tree_subset, tree_depth_200_path)

    msas_depth_200[query] = MSA(msa_depth_200_path, diverse_select_method="none")
    trees_depth_200[query] = read_newick(tree_depth_200_path)
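`subset_tree` is defined elsewhere in the notebook, and its sampling scheme is not shown. One simple possibility, sketched below under that assumption, is to draw leaves uniformly at random with a fixed seed while always keeping the query. With ete3, the chosen names could then be passed to `TreeNode.prune(names, preserve_branch_length=True)` so that patristic distances between the retained leaves are preserved. The helper name `sample_leaf_names` is hypothetical.

```python
import random


def sample_leaf_names(
    leaf_names: list[str], n: int, force_include: list[str], seed: int = 42
) -> list[str]:
    """Pick n leaf names uniformly at random, always keeping those in force_include.

    Hypothetical helper: not the notebook's subset_tree, just one way to choose leaves.
    """
    forced = list(dict.fromkeys(force_include))  # de-duplicate while keeping order
    missing = set(forced) - set(leaf_names)
    if missing:
        raise ValueError(f"forced leaves not in tree: {sorted(missing)}")
    if n < len(forced) or n > len(leaf_names):
        raise ValueError("n must lie between len(force_include) and the number of leaves")

    forced_set = set(forced)
    pool = [name for name in leaf_names if name not in forced_set]
    rng = random.Random(seed)  # local RNG so the selection is reproducible
    return forced + rng.sample(pool, n - len(forced))


leaves = [f"seq{i}" for i in range(1000)]
chosen = sample_leaf_names(leaves, n=200, force_include=["seq7"], seed=42)
```

Using a local `random.Random(seed)` rather than the global RNG keeps the subsample reproducible regardless of what else in the notebook consumes random numbers.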

In [None]:
# | code-fold: true
# | code-summary: Calculating sequence weights
seq_weights_depth_200_path = uniclust30_dir / "seq_weights_depth_200.pt"

# When running on Modal, calculate_sequence_weights serializes MSAs to send to remote GPU workers.
# Serializing all 11k+ MSAs at once exceeds Modal's serialization limits, so we batch into groups
# of 1000. This constraint is specific to Modal's RPC layer. This batching choice is unnecessary
# yet harmless for local execution.
seq_weights_depth_200 = {}
batch_size = 1000

_msa_items = list(msas_depth_200.items())

if seq_weights_depth_200_path.exists():
    seq_weights_depth_200 = torch.load(seq_weights_depth_200_path, weights_only=True)
else:
    for batch_start in progress(
        range(0, len(_msa_items), batch_size), desc="Calculating sequence weights"
    ):
        _batch_msas = dict(_msa_items[batch_start : batch_start + batch_size])
        _batch_weights = calculate_sequence_weights(_batch_msas)
        seq_weights_depth_200.update(_batch_weights)

    torch.save(seq_weights_depth_200, seq_weights_depth_200_path)

In [None]:
# | code-fold: true
# | code-summary: Summarizing the data into a main table
# | output: false

from typing import Any

import numpy as np
import pandas as pd
import torch
from analysis.regression import regress_and_analyze_features
from analysis.tree import (
    cherry_count_statistic,
    colless_statistic,
    phylogenetic_diversity_statistic,
)
from ete3 import Tree
from joblib import Parallel, delayed


def process_msa(query: str, dist_to_query: np.ndarray, tree: Tree, weights: torch.Tensor) -> dict[str, Any]:
    data = {}

    num_seqs = len(dist_to_query)

    # Regress the sequence weights against patristic distance.
    # Perform an ANOVA (type III) to establish explanatory importance
    # of each layer's sequence weights.
    model, anova_table = regress_and_analyze_features(weights, dist_to_query)

    data["Query"] = query
    data["R2"] = model.rsquared
    data["Adjusted R2"] = model.rsquared_adj
    data["Mean patristic"] = dist_to_query.mean()
    # Mean patristic distance of the closest quarter of sequences to the query.
    data["Closest quartile patristic"] = np.sort(dist_to_query)[: num_seqs // 4].mean()
    data["Phylogenetic diversity"] = phylogenetic_diversity_statistic(tree)
    data["Colless"] = colless_statistic(tree)
    data["Cherry count"] = cherry_count_statistic(tree)
    data.update(anova_table["percent_sum_sq"].to_dict())

    return data


jobs = []
for query in seq_weights_depth_200.keys():
    msa = msas_depth_200[query]
    tree = trees_depth_200[query]
    weights = seq_weights_depth_200[query]

    dist_to_query = get_patristic_distance(tree, query)[msa.ids_l].values
    jobs.append(delayed(process_msa)(query, dist_to_query, tree, weights))

results_depth_200_df = pd.DataFrame(Parallel(-1)(jobs))
results_depth_200_df

In [None]:
results_depth_200_df.to_csv("results.tsv", sep="\t", index=False)

## Experimental stuff below

In [None]:
import numpy as np
from analysis.data import get_sequence_weight_data
from analysis.plotting import tree_style_with_scalar_annotation
from analysis.tree import (
    get_patristic_distance,
    sort_tree_by_reference,
    subset_tree_around_reference,
)
from scipy.stats import linregress


def plot_tree_with_weights(id: str, n: int = 400):
    tree = trees[id]
    msa = msas[id]
    weights_tensor = seq_weights[id]

    targets = msa.ids_l
    weights = weights_tensor * weights_tensor.size(0)
    median_weights = torch.median(weights, dim=0).values

    weight_dict = {
        target: weight.item() for target, weight in zip(targets, median_weights, strict=False)
    }

    gradient = apc.Gradient.from_dict(
        "gradient",
        {"1": "#EEEEEE", "2": "#EEEEEE", "3": apc.seaweed, "4": apc.seaweed},
        values=[0.0, 0.0, 0.75, 1.0],
    )
    tree_style = tree_style_with_scalar_annotation(weight_dict, gradient, highlight=[id])
    visualized_tree = sort_tree_by_reference(
        subset_tree_around_reference(tree=tree, n=n, reference=id, bias_power=0.8, seed=42),
        id,
    )

    return visualized_tree.render("%%inline", tree_style=tree_style, dpi=300)


for query in list(seq_weights.keys())[:3]:
    print(query)
    plot_regression(query).show()
    # plot_tree_with_weights(query)

In [None]:
query = "A0A022Y0Q6"
plot_tree_with_weights(query)

In [None]:
query = "A0A009GC83"
plot_tree_with_weights(query)

In [None]:
from analysis.tree import (
    cherry_count_statistic,
    colless_statistic,
    phylogenetic_diversity_statistic,
)

for _, tree in progress(trees.items()):
    phylogenetic_diversity_statistic(tree)
    colless_statistic(tree)
    cherry_count_statistic(tree)

In [None]:
from analysis.regression import regress_sequence_weights_on_patristic_distance

id = "A0A009GC83"

msa = msas[id]
tree = trees[id]
weights_tensor = seq_weights[id]

targets = msa.ids_l
patristic_distances = get_patristic_distance(tree, id)
patristic_distances = patristic_distances[targets]

model, r2 = regress_sequence_weights_on_patristic_distance(weights_tensor.cpu(), patristic_distances.values)
r2

In [None]:
len(tree.get_leaves())

In [None]:
np.abs(model.coef_)

In [None]:
plot_regression("A0A2I1VP33")

In [None]:
fix_these = []
for query in seq_weights.keys():
    msa = msas[query]
    tree = trees[query]
    try:
        dist = get_patristic_distance(tree, query)[msa.ids_l].values
    except Exception:  # record queries whose patristic-distance lookup fails
        fix_these.append(query)

In [None]:
len(fix_these)

In [None]:
import asyncio
import os
from pathlib import Path

from analysis.tree import run_fasttree_async

uniclust30_tree_dir = uniclust30_dir / "trees"
uniclust30_tree_dir.mkdir(exist_ok=True)

inputs_for_tree_fixing = [Path(f"./data/uniclust30/msas/{query}.a3m") for query in fix_these]

jobs = []
semaphore = asyncio.Semaphore(os.cpu_count() - 1)
for a3m_path in inputs_for_tree_fixing:
    fasttree_path = uniclust30_tree_dir / f"{a3m_path.stem}.fasttree.newick"
    log_path = uniclust30_tree_dir / f"{a3m_path.stem}.fasttree.log"

    if fasttree_path.exists():
        fasttree_path.unlink()
    if log_path.exists():
        log_path.unlink()

    jobs.append(run_fasttree_async(a3m_path, fasttree_path, log_path, semaphore))

_ = await asyncio.gather(*jobs)

In [None]:
for name, tree in trees.items():
    if len(tree.get_leaf_names()) == 1024:
        print(name)
        break

In [None]:
from analysis.plotting import tree_style_with_highlights
import matplotlib.pyplot as plt

query = "A0A015LDC8"
tree = trees[query]

images = []
for power in np.linspace(-3, 3, 10):
    subset = sort_tree_by_reference(subset_tree_around_reference(tree, 100, query, bias_power=power), query)
    distances = get_patristic_distance(subset, query).values
    print(distances.mean())
    plt.hist(distances, bins=20)
    plt.show()
    #images.append(subset.render("%%inline", tree_style=tree_style_with_highlights([query])))
#display(*images)

In [None]:
tree.get_leaf_names()

In [None]:
tree = trees[query]
subset_tree