In [None]:
import torch
from scripts import load_model, MambaLMHeadModelwithPosids, AA_TO_ID
import numpy as np
import pandas as pd

from config import DATA_DIR

In [None]:
result_file = "ProteinGym_reference_file_substitutions.a3m"
# Gather the header of every record (lines starting with '>') with the '>' marker and
# surrounding whitespace removed, in file order.
with open(result_file, "r") as handle:
    labels = [header[1:].strip() for header in handle if header.startswith(">")]

In [None]:
# Number of '>' header records read from result_file.
len(labels)

In [None]:
# Position of the first record whose header is exactly "158" (list.index raises ValueError if absent).
labels.index("158")

In [None]:
# For every .a3m file in this folder, rename it after the first sequence it contains:
# the first whitespace-separated token of the first header line, without the leading '>'.
import os

a3m_dir = "/data2/malbrank/protein_gym/mmseqs_colabfold_protocol"

for file in sorted(os.listdir(a3m_dir)):
    if not file.endswith(".a3m"):
        continue
    src_path = os.path.join(a3m_dir, file)
    with open(src_path, "r") as f:
        first_line = f.readline()
    tokens = first_line.split()
    prot_name = tokens[0][1:] if tokens and tokens[0].startswith(">") else ""
    if not prot_name:
        # Empty file, or a first line that is not a '>name' header: no sequence name to use.
        print(f"Skipping {file}: first line is not a '>name' header")
        continue
    dst_path = os.path.join(a3m_dir, f"{prot_name}.a3m")
    if dst_path == src_path:
        continue  # already named after its first sequence
    if os.path.exists(dst_path):
        # On POSIX os.rename silently replaces an existing destination, which would destroy
        # another MSA whose first sequence has the same name; leave this file untouched instead.
        print(f"Skipping {file}: {dst_path} already exists")
        continue
    os.rename(src_path, dst_path)

In [None]:
def build_landscape_from_csv(csv_file, msa_start=1):
    """Build the ground-truth single-substitution landscape of one DMS assay.

    Args:
        csv_file: path (or file-like object accepted by ``pd.read_csv``) of a DMS CSV
            with columns ``mutant`` (e.g. ``"A42G"``: wild-type residue, 1-based
            position, mutant residue), ``mutated_sequence`` and ``DMS_score``.
        msa_start: 1-based sequence position that maps to row 0 of the landscape.

    Returns:
        (gt_mut_landscape, keep_idx): a float tensor of shape
        ``(len(first mutated_sequence), 20)`` holding the DMS score of each measured
        (position, amino acid) cell and ``+inf`` for unmeasured cells, and the
        ``torch.where`` index tuple of the cells holding a finite score.

    Raises:
        ValueError: if the CSV contains no rows.
        IndexError: if a mutated position, once shifted by ``msa_start``, falls outside
            the landscape (a negative index would otherwise silently wrap to the end).
    """
    df = pd.read_csv(csv_file)
    if df.empty:
        raise ValueError(f"No mutations found in {csv_file}")
    seq_len = len(df["mutated_sequence"].iloc[0])
    # +inf marks "not measured"; every measured cell is overwritten below.
    gt_mut_landscape = torch.full((seq_len, 20), float("inf"))
    for mutant, score in zip(df["mutant"], df["DMS_score"]):
        mut_pos = int(mutant[1:-1]) - msa_start
        if not 0 <= mut_pos < seq_len:
            raise IndexError(
                f"Mutation {mutant} maps to landscape row {mut_pos}, outside [0, {seq_len}) "
                f"(msa_start={msa_start})"
            )
        # NOTE(review): presumably the first 4 ids of AA_TO_ID are special tokens,
        # so the 20 amino acids map to columns 0..19 — TODO confirm.
        mut_aa_id = AA_TO_ID[mutant[-1]] - 4
        gt_mut_landscape[mut_pos, mut_aa_id] = float(score)
    # isfinite also drops NaN scores, which would otherwise make spearmanr return NaN for the assay.
    keep_idx = torch.where(torch.isfinite(gt_mut_landscape))
    return gt_mut_landscape, keep_idx

In [None]:
n_tokens = [8000, 16000, 32000, 64000]

# Per-assay DMS CSVs live in csv_folder; predicted landscapes are loaded from out_folder.
csv_folder = f"{DATA_DIR}/protein_gym/substitutions/DMS_ProteinGym_substitutions/"
out_folder = f"{DATA_DIR}/protein_gym/mut_effects/"

# Assay metadata, restricted to assays that contain no multiple-mutant variants.
database_df = (
    pd.read_csv(f"{DATA_DIR}/protein_gym/substitutions/DMS_substitutions.csv")
    .loc[lambda meta: meta["DMS_number_multiple_mutants"] == 0]
)

# Baseline Spearman table; its first column is used as the index.
results_df = pd.read_csv(
    f"{DATA_DIR}/protein_gym/substitutions/DMS_substitutions_Spearman.csv",
    index_col=0,
)

In [None]:
# Inspect the "TranceptEVE L" column of results_df (read from DMS_substitutions_Spearman.csv).
results_df["TranceptEVE L"]

In [None]:
# Baseline columns of results_df to compare against.
select_models = [
    "TranceptEVE L",
    "GEMME",
    "ESM-IF1",
    "MSA Transformer (single)",
    "ESM2 (650M)",
    "EVmutation",
    "Site-Independent",
]

In [None]:
# Load the checkpoint below with the MambaLMHeadModelwithPosids class, in bfloat16 on "cuda",
# and switch it to eval mode.
model = load_model(
    "/nvme1/common/mamba_100M_FIM_finetuned_32k_checkpoint-16500",
    model_class=MambaLMHeadModelwithPosids,
    device="cuda",
    dtype=torch.bfloat16,
).eval()

In [None]:
from scipy.stats import spearmanr
from matplotlib import pyplot as plt


In [None]:
# Inspect the assay metadata kept after filtering out multiple-mutant assays.
database_df

In [None]:

from tqdm import tqdm_notebook  # kept in case other cells still reference it
from tqdm.auto import tqdm

tests = ["mamba_8000", "mamba_16000", "mamba_32000", "mamba_ft_8000", "mamba_ft_16000", "mamba_ft_32000"]
colors = ["red", "blue", "green", "red", "blue", "green"]
markers = ["o", "o", "o", "x", "x", "x"]
spearmanrs_tests = {test: [] for test in tests}

# Assays are the outer loop so each ground-truth CSV is parsed once and shared by every run
# (the original re-parsed it once per run). Each list in spearmanrs_tests still follows the
# row order of database_df.
for row in tqdm(database_df.iterrows(), total=len(database_df), desc="DMS assays"):
    assay = row[1]
    csv_file = assay["DMS_filename"]
    msa_start = int(assay["MSA_start"])
    dms_id = assay["DMS_id"]
    gt_mut_landscape, keep_idx = build_landscape_from_csv(csv_folder + csv_file, msa_start)
    gt_scores = gt_mut_landscape[keep_idx]
    for test in tests:
        # map_location="cpu": the tensor is indexed with CPU indices and passed to scipy (NumPy),
        # so load it onto the CPU regardless of the device it was saved from.
        pred_mut_landscape = torch.load(f"{out_folder}/{test}/{dms_id}_landscape.pt", map_location="cpu")
        spearman = spearmanr(gt_scores, pred_mut_landscape[keep_idx])[0]
        spearmanrs_tests[test].append(spearman)

In [None]:
# ProteinGym-reported Spearman of each selected baseline, one value per assay in database_df's
# row order (so each list lines up element-wise with spearmanrs_tests[test]).
# The iteration name is `baseline`, not `model`, so the network loaded earlier is not overwritten;
# every baseline gets a key even when database_df is empty.
spearmanrs_baselines = {
    baseline: [results_df[baseline].loc[dms_id] for dms_id in database_df["DMS_id"]]
    for baseline in select_models
}

In [None]:
# Mean Spearman correlation across the selected assays, for our runs and for the baselines.
# The loop variable is `baseline`, not `model`, so the network loaded earlier is not overwritten.
for test in tests:
    print(test, np.mean(spearmanrs_tests[test]))
for baseline in select_models:
    print(baseline, np.mean(spearmanrs_baselines[baseline]))

In [None]:
# One "<DMS_id> <spearman>" line per assay for the mamba_ft_8000 run.
ft_8000_scores = zip(database_df["DMS_id"], spearmanrs_tests["mamba_ft_8000"])
for pair in ft_8000_scores:
    print(*pair)

In [None]:
# One row per assay (y axis), one marker per method; x is the Spearman correlation.
fig, ax = plt.subplots(figsize=(10, 35))
y_positions = range(len(database_df))

for test, color, marker in zip(tests, colors, markers):
    ax.scatter(spearmanrs_tests[test], y_positions, color=color, marker=marker, label=test)

# Baselines stay black but get distinct markers so the legend can tell them apart.
# The loop variable is `baseline`, not `model`, so the network loaded earlier is not overwritten.
baseline_markers = ["s", "^", "v", "D", "P", "*", "+", "<", ">", "h"]
for k, baseline in enumerate(select_models):
    ax.scatter(
        spearmanrs_baselines[baseline],
        y_positions,
        color="black",
        marker=baseline_markers[k % len(baseline_markers)],
        label=baseline,
    )

ax.legend()
ax.grid(axis="x")
ax.set_yticks(list(y_positions))
ax.set_yticklabels(database_df["DMS_id"])
ax.set_ylabel("DMS_id")
ax.set_xlabel("Spearman correlation");


In [None]:
# Re-display the filtered assay metadata (not modified by the cells above).
database_df

In [None]:
# DMS-level Spearman table for all models, indexed by its first column.
model_performances_spearmanr = pd.read_csv(
    "/data2/malbrank/protein_gym/ProteinGym/Detailed_performance_files/Substitutions/Spearman/all_models_substitutions_Spearman_DMS_level.csv",
    index_col=0,
)

In [None]:
# List all files in /data2/malbrank/protein_gym/DMS_msa_files, smallest first.
import os

msa_dir = "/data2/malbrank/protein_gym/DMS_msa_files"
msa_files = sorted(os.listdir(msa_dir), key=lambda name: os.path.getsize(f"{msa_dir}/{name}"))

# Map the first two '_'-separated filename tokens to the filename. Files are visited smallest
# first, so when several files share a prefix the largest one ends up in the map.
prot_to_msa = {"_".join(name.split("_")[:2]): name for name in msa_files}

In [None]:
# NOTE(review): `row` is not defined by this cell. It is the value left behind by an earlier
# loop over enumerate(database_df.iterrows()), i.e. the last (index, Series) pair. This cell
# raises NameError on a fresh kernel unless one of those loops has run.
row