In [2]:
import pandas as pd
import numpy as np
import pickle
import joblib
from pathlib import Path
from Bio import SeqIO
from Bio.Seq import Seq

In [14]:
def _split_mutant(mutant):
    """Split a ProteinGym mutant string into ``(wt_aa, position, mut_aa)`` tuples.

    Multi-site mutants are colon-separated, e.g. ``"D28E:T32K"`` ->
    ``[("D", 28, "E"), ("T", 32, "K")]``. Positions are returned as written.
    """
    return [(m[0], int(m[1:-1]), m[-1]) for m in mutant.split(':')]


def overlap_proteingym_with_spurs(
    meta_path="/work/commons/proteingym/DMS_substitutions.csv",
    spurs_dir="/work/ziang/spurs_test/SPURS/data/fitness/data_for_prof_luo",
    dms_csv_dir="/work/commons/proteingym/DMS_ProteinGym_substitutions",
    root="/work/yunan/PsiFit/data/proteingym",
):
    """Export the single-mutant ProteinGym assays that have a SPURS prediction.

    For every assay in ``meta_path`` with ``includes_multiple_mutants == False``
    whose ``{DMS_id}_ddg_result.pkl`` exists in ``spurs_dir``, write into
    ``root / DMS_id``:

    * ``proteingym_dms.tsv``: the assay table from ``dms_csv_dir`` with the
      ``mutant`` positions re-indexed so position 1 is the first residue of
      ``pdb_range`` (``mutated_sequence`` and ``DMS_score_bin`` are dropped);
    * ``wildtype.fasta``: ``target_seq`` trimmed to ``pdb_range``.

    Mutants with any position outside ``pdb_range`` have no SPURS prediction,
    so they are dropped, and the number of dropped mutants is printed.

    Raises:
        AssertionError: if the SPURS prediction length, the trimmed sequence
            length, or a mutant's wild-type residue disagrees with the metadata.
    """
    meta = pd.read_csv(meta_path)
    # keep only single-mutant assays (`.eq(False)` is the same test as `== False`)
    meta = meta[meta['includes_multiple_mutants'].eq(False)].reset_index(drop=True)
    spurs_dir, dms_csv_dir, root = Path(spurs_dir), Path(dms_csv_dir), Path(root)

    for dms_id, pdb_range, target_seq in zip(meta['DMS_id'], meta['pdb_range'], meta['target_seq']):
        # only process assays that have a SPURS prediction
        spurs_path = spurs_dir / f"{dms_id}_ddg_result.pkl"
        if not spurs_path.exists():
            print(f"SPURS prediction for {dms_id} not found.")
            continue
        # NOTE: joblib.load unpickles the file; only use on trusted data.
        spurs_data = joblib.load(spurs_path)

        # SPURS predicts on the PDB structure, so trim the sequence to pdb_range
        # ("st-ed", 1-based and inclusive).
        st, ed = map(int, pdb_range.split('-'))
        assert ed - st + 1 == spurs_data.shape[0], f"Length mismatch for {dms_id}"
        seq = target_seq[st - 1:ed]
        # slicing silently truncates if pdb_range runs past the end of target_seq
        assert len(seq) == ed - st + 1, f"pdb_range {st}-{ed} exceeds target_seq for {dms_id}"

        fitness = pd.read_csv(dms_csv_dir / f"{dms_id}.csv")
        mutations = fitness['mutant'].map(_split_mutant)

        # Drop mutants touching positions outside pdb_range instead of writing
        # non-positive / out-of-range positions that have no SPURS prediction.
        in_range = mutations.map(lambda muts: all(st <= pos <= ed for _, pos, _ in muts)).astype(bool)
        n_dropped = int((~in_range).sum())
        if n_dropped:
            print(f"{dms_id}: dropped {n_dropped} mutant(s) outside pdb_range {st}-{ed}.")
        fitness = fitness.loc[in_range].copy()
        mutations = mutations.loc[in_range]

        # Sanity check: every wild-type residue must match the trimmed sequence.
        for muts in mutations:
            for wt_aa, pos, _ in muts:
                assert seq[pos - st] == wt_aa, (
                    f"Wild-type mismatch for {dms_id}: mutant says {wt_aa}{pos}, "
                    f"sequence has {seq[pos - st]}"
                )

        # Re-index positions to the trimmed sequence: A23C -> A1C when st == 23.
        offset = st - 1
        fitness['mutant'] = [
            ':'.join(f"{wt_aa}{pos - offset}{mut_aa}" for wt_aa, pos, mut_aa in muts)
            for muts in mutations
        ]
        # NOTE(review): mutated_sequence presumably spans the untrimmed target_seq,
        # so it is dropped together with DMS_score_bin.
        fitness = fitness.drop(columns=['mutated_sequence', 'DMS_score_bin'])

        save_dir = root / dms_id
        save_dir.mkdir(parents=True, exist_ok=True)
        fitness.to_csv(save_dir / "proteingym_dms.tsv", index=False, sep="\t", float_format="%.6f")

        # use BioPython to save the trimmed wild-type sequence
        wt_record = SeqIO.SeqRecord(Seq(seq), id=dms_id, description="wildtype sequence")
        SeqIO.write(wt_record, save_dir / "wildtype.fasta", "fasta")


overlap_proteingym_with_spurs()

SPURS prediction for BRCA2_HUMAN_Erwood_2022_HEK293T not found.
SPURS prediction for CAS9_STRP1_Spencer_2017_positive not found.
SPURS prediction for P53_HUMAN_Giacomelli_2018_Null_Etoposide not found.
SPURS prediction for P53_HUMAN_Giacomelli_2018_Null_Nutlin not found.
SPURS prediction for P53_HUMAN_Giacomelli_2018_WT_Nutlin not found.
SPURS prediction for POLG_HCVJF_Qi_2014 not found.


In [11]:
def reformat_spurs_prediction(
    data_dir="/work/ziang/spurs_test/SPURS/data/fitness/data_for_prof_luo/",
    save_dir="/work/yunan/PsiFit/data/proteingym",
):
    """Convert each assay's SPURS pickle into a position x amino-acid TSV.

    For every assay sub-directory of ``save_dir`` (the directory name is the
    DMS_id), load ``data_dir / f"{DMS_id}_ddg_result.pkl"`` and write it to
    ``save_dir / DMS_id / "spurs_prediction.tsv"``. Rows are 1-based positions
    and columns are the 20 one-letter amino-acid codes.

    Non-directory entries and hidden directories (e.g. ``.ipynb_checkpoints``)
    in ``save_dir`` are skipped, because they have no matching SPURS pickle.
    """
    data_dir = Path(data_dir)
    save_dir = Path(save_dir)
    # NOTE(review): assumes SPURS's 20 output columns follow this alphabetical order; confirm.
    alphabet = list("ACDEFGHIKLMNPQRSTVWY")
    for dms_dir in sorted(save_dir.iterdir()):
        if not dms_dir.is_dir() or dms_dir.name.startswith('.'):
            continue
        dms_id = dms_dir.name
        # NOTE: joblib.load unpickles the file; only use on trusted data.
        data = joblib.load(data_dir / f"{dms_id}_ddg_result.pkl")
        df = pd.DataFrame(data, columns=alphabet, index=np.arange(1, data.shape[0] + 1))
        df.to_csv(dms_dir / "spurs_prediction.tsv", sep="\t", float_format="%.6f")


reformat_spurs_prediction()

In [12]:
def corr_spurs_proteingym(
    root="/work/yunan/PsiFit/data/proteingym",
    meta_path="/work/commons/proteingym/DMS_substitutions.csv",
    method="pearson",
):
    """Correlate SPURS predictions with ProteinGym DMS scores, one value per assay.

    For each single-mutant assay in ``meta_path`` that has a directory under
    ``root``, look up the SPURS prediction for every mutant in
    ``proteingym_dms.tsv``. The lookup uses row = mutant position and
    column = mutant residue of ``spurs_prediction.tsv``. The predictions are
    then correlated with ``DMS_score``.

    Args:
        root: directory containing one sub-directory per DMS_id.
        meta_path: ProteinGym ``DMS_substitutions.csv`` metadata table.
        method: ``"pearson"`` (default) or ``"spearman"`` (rank correlation,
            the metric ProteinGym reports).

    Returns:
        DataFrame with columns ``DMS_id`` and ``correlation``. The
        ``describe()`` summary of the correlations is also printed.

    Raises:
        ValueError: if ``method`` is not supported.
        KeyError: if a mutant's position or residue has no SPURS prediction.

    NOTE(review): SPURS outputs ddG. Mostly negative correlations would mean
    higher ddG tracks lower fitness; confirm the sign convention before
    comparing against other predictors.
    """
    if method not in ("pearson", "spearman"):
        raise ValueError(f"method must be 'pearson' or 'spearman', got {method!r}")
    root = Path(root)
    meta = pd.read_csv(meta_path)
    meta = meta[meta['includes_multiple_mutants'].eq(False)].reset_index(drop=True)

    results = []
    for dms_id in meta['DMS_id']:
        dms_dir = root / dms_id
        if not dms_dir.exists():
            continue
        fitness = pd.read_csv(dms_dir / "proteingym_dms.tsv", sep="\t")
        spurs_pred = pd.read_csv(dms_dir / "spurs_prediction.tsv", sep="\t", index_col=0)

        # Vectorized lookup: mutant "A12C" -> row 12, column "C".
        mutants = fitness['mutant']
        row_idx = spurs_pred.index.get_indexer(mutants.str[1:-1].astype(int))
        col_idx = spurs_pred.columns.get_indexer(mutants.str[-1])
        missing = (row_idx < 0) | (col_idx < 0)
        if missing.any():
            raise KeyError(
                f"{dms_id}: no SPURS prediction for mutant(s) {mutants[missing].tolist()[:5]}"
            )
        preds = spurs_pred.to_numpy()[row_idx, col_idx]
        trues = fitness['DMS_score'].to_numpy()
        if method == "spearman":
            # Spearman = Pearson on (average-tie) ranks; avoids a scipy dependency.
            preds = pd.Series(preds).rank().to_numpy()
            trues = pd.Series(trues).rank().to_numpy()

        corr = np.corrcoef(preds, trues)[0, 1]
        results.append({'DMS_id': dms_id, 'correlation': corr})

    # explicit columns so an empty result still has a 'correlation' column
    results_df = pd.DataFrame(results, columns=['DMS_id', 'correlation'])
    print(results_df['correlation'].describe())
    return results_df
# Bind the return value so later cells can reuse it and Jupyter does not echo it
# as the cell's last expression below the printed summary.
spurs_corr_by_assay = corr_spurs_proteingym()

count    142.000000
mean      -0.436181
std        0.224997
min       -0.943388
25%       -0.534814
50%       -0.441555
75%       -0.297483
max        0.361102
Name: correlation, dtype: float64
