In [None]:
import math
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from datasets import load_dataset
from sklearn.decomposition import PCA
from transformers import BertModel, BertTokenizer

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Fetch the single-mutation protein stability dataset from the Hugging Face Hub
# and materialise its "test" split as a pandas DataFrame.
DATASET_ID = "Trelis/protein_stability_single_mutation"
dataset = load_dataset(DATASET_ID)
df = pd.DataFrame(dataset["test"])

In [5]:
# Map each wild-type name to its base sequence (taken from the first row of the
# group) and to the list of full sequences found in that group's rows.
mutation_dict = {
    wt_name: {
        "WTseq": members["base_aa_seq"].iloc[0],
        "MTseq": members["aa_seq_full"].tolist(),
    }
    for wt_name, members in df.groupby("WT_name")
}

In [None]:
# Load the ProtBERT tokenizer and encoder from the same checkpoint, move the
# encoder to the selected device and put it in evaluation mode.
PROT_BERT_ID = "Rostlab/prot_bert"
tokenizer = BertTokenizer.from_pretrained(PROT_BERT_ID, do_lower_case=False)
model = BertModel.from_pretrained(PROT_BERT_ID, torch_dtype="auto").to(device)
model.eval()

In [7]:
def _as_sequence_list(sequences):
    """Return `sequences` as a list, wrapping a single sequence string in a list."""
    # A bare string would otherwise be iterated character by character and
    # every residue would be treated as its own one-letter sequence.
    if isinstance(sequences, str):
        return [sequences]
    return list(sequences)


def _attention_maps(sequences, device, batch):
    """Compute one averaged attention map per sequence, in mini-batches of `batch`.

    Uses the notebook-level `tokenizer` and `model`. Each returned tensor is a
    float32 CPU tensor averaged over all layers and heads and cropped to the
    sequence's own token count, so the map does not depend on which other
    sequences shared its batch.
    """
    maps = []
    for start in range(0, len(sequences), batch):
        chunk = sequences[start:start + batch]
        # The tokenizer is fed residues separated by single spaces.
        spaced = [' '.join(seq) for seq in chunk]
        tokens = tokenizer(spaced, return_tensors='pt', padding=True, truncation=True)
        tokens = {k: v.to(device) for k, v in tokens.items()}

        with torch.no_grad():
            outputs = model(**tokens, output_attentions=True)

        # Stack is [layers, batch, heads, T, T]; average layers and heads -> [batch, T, T].
        # Cast to float32 so a half/bfloat16 checkpoint still yields tensors that
        # convert cleanly to NumPy downstream.
        attn = torch.stack(outputs.attentions).mean(dim=(0, 2)).float()

        mask = tokens.get("attention_mask")
        if mask is None:
            lengths = [attn.shape[-1]] * attn.shape[0]
        else:
            # NOTE(review): cropping the leading n rows/cols assumes the tokenizer
            # pads on the right (the BertTokenizer default) - confirm if changed.
            lengths = mask.sum(dim=1).tolist()
        maps.extend(attn[row, :n, :n].cpu() for row, n in enumerate(lengths))

        # Release the large per-layer attention tensors before the next batch.
        del tokens, outputs, attn
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return maps


def ExtractAttention(WTsequences, MTsequences, device, batch = 4):
    """Extract layer/head-averaged attention maps for wild-type and mutant sequences.

    Parameters
    ----------
    WTsequences : str or iterable of str
        Wild-type amino-acid sequence(s). A single string is treated as ONE
        sequence, not as an iterable of one-letter sequences.
    MTsequences : str or iterable of str
        Mutant amino-acid sequence(s); same convention as `WTsequences`.
    device : torch.device or str
        Device the tokenized inputs are moved to (should match the model's device).
    batch : int, default 4
        Number of sequences per forward pass.

    Returns
    -------
    (WTattn, MTattn) : tuple of lists of torch.Tensor
        One square float32 CPU tensor per input sequence, in input order. Its side
        equals that sequence's token count (padding positions are removed).

    Raises
    ------
    ValueError
        If `batch` is smaller than 1.
    """
    if batch < 1:
        raise ValueError("batch must be a positive integer")

    WTattn = _attention_maps(_as_sequence_list(WTsequences), device, batch)
    MTattn = _attention_maps(_as_sequence_list(MTsequences), device, batch)
    return WTattn, MTattn

In [8]:
def PadAttention(attention):
    """Zero-pad attention maps to a common size and stack them into one array.

    Parameters
    ----------
    attention : sequence of torch.Tensor
        Attention maps; each is expected to be square, shape (L, L), and
        convertible with `.numpy()`. The sequence is traversed twice.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(attention), max_len, max_len), where max_len is the
        largest first dimension among the inputs. Smaller maps occupy the
        top-left corner of a float32 zero matrix; maps that already have the
        maximum size are used as-is.
    """
    side = max(t.shape[0] for t in attention)

    def _to_common_size(t):
        n = t.shape[0]
        if n == side:
            return t.numpy()
        canvas = np.zeros((side, side), dtype=np.float32)
        canvas[:n, :n] = t.numpy()
        return canvas

    return np.stack([_to_common_size(t) for t in attention])

In [None]:
# One subplot per wild type: a 2-D PCA of the flattened attention maps of the
# wild type (red star) and of each of its mutants (blue dots), with a faint
# arrow from the wild type to every mutant.
cols = 4
num_wt = len(mutation_dict)
# Keep at least one row so plt.subplots() gets a valid grid even with no wild types.
rows = max(1, math.ceil(num_wt / cols))

# squeeze=False always returns a 2-D Axes array, so flatten() works for any grid.
fig, axes = plt.subplots(rows, cols, figsize=(cols*5, rows*5), squeeze=False)
axes = axes.flatten()

for i, (wt_name, values) in enumerate(mutation_dict.items()):
    # values["WTseq"] is a single sequence string; wrap it so ExtractAttention
    # receives a list of sequences and returns exactly one wild-type map.
    WTseq_list = [values["WTseq"]]
    MTseq_list = values["MTseq"]   # list of mutant sequences

    # Extract attention maps (averaged over layers and heads)
    WTattn, MTattn = ExtractAttention(WTseq_list, MTseq_list, device, batch = 4)

    all_attn = WTattn + MTattn
    all_attn_padded = PadAttention(all_attn)

    # --- Flatten for PCA: one row per sequence ---
    all_attn_flat = all_attn_padded.reshape(len(all_attn_padded), -1)

    # PCA to 2D; a fixed random_state keeps the projection reproducible when
    # scikit-learn picks its randomized solver.
    X_pca = PCA(n_components=2, random_state=0).fit_transform(all_attn_flat)

    # Plot attention shifts in this subplot
    ax = axes[i]
    orig_idx = 0
    origin = X_pca[orig_idx]
    mutants = X_pca[len(WTattn):]   # rows after the wild-type map(s)

    # All mutants in a single scatter call instead of one call per point.
    ax.scatter(mutants[:, 0], mutants[:, 1], marker="o", s=50, color="blue")
    for point in mutants:
        ax.arrow(origin[0], origin[1],
                 point[0]-origin[0],
                 point[1]-origin[1],
                 alpha=0.3, color="gray", linestyle="--")
    # Draw the wild type last and on top so it is not hidden by mutants/arrows.
    ax.scatter(origin[0], origin[1], marker="*", s=100, color="red", zorder=3)
    ax.set_title(wt_name)
    ax.axis('off')

# Hide grid cells that have no wild type assigned to them.
for ax in axes[num_wt:]:
    ax.axis('off')

plt.tight_layout()
plt.show()