In [20]:
import pandas as pd

# Load the raw experimental mutation table.
origin_csv = pd.read_csv('/home/eva/0_point_mutation/playground/Chain_LH_Corrected_Mutation_Data.csv')

# Canonical amino-acid order used to rank mutant residues within a position.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
aa_rank = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

# Sort by chain, position, then mutant-residue rank.
# The rank lives in a throwaway, privately named column so that:
#   * `origin_csv` itself is left untouched (the cell is idempotent), and
#   * a real column called 'mt' (the name the downstream merge expects) can
#     never be overwritten and then dropped by accident.
# Residues missing from AMINO_ACIDS map to NaN and therefore sort last.
MT_RANK_COL = "_mt_rank"
tidy_df_sorted = (
    origin_csv
    .assign(**{MT_RANK_COL: origin_csv['mt_res'].map(aa_rank)})
    .sort_values(by=['chain', 'pos', MT_RANK_COL])
    .drop(columns=[MT_RANK_COL])
)

# NOTE(review): `.iloc[:, 1:]` discards the first column of the table before
# writing -- presumably a leftover index column from an earlier export; confirm.
# NOTE(review): the merge cell reads this file name from `playground_mAb_DMS/`,
# not `playground/` -- confirm the file is copied there or align the paths.
tidy_df_sorted.iloc[:, 1:].to_csv('/home/eva/0_point_mutation/playground/Chain_LH_Corrected_Mutation_Data_sorted.csv', index=False)


In [None]:
import pandas as pd
# Root directory of the mAb DMS playground; every input of this cell lives below it.
MAB_DMS_DIR = '/home/eva/0_point_mutation/playground_mAb_DMS'

# Key columns identifying one mutation in the model outputs / the experiment table.
MODEL_KEYS = ["sample", "chain", "pos", "wt", "mt"]
EXPERIMENT_KEYS = ["chain", "pos", "wt", "mt"]
EXPERIMENT_SCORE_COL = "score_experiment"

# --- Load -------------------------------------------------------------------
ablang_csv = pd.read_csv(f'{MAB_DMS_DIR}/results/mab_ablang.csv')
antifold_csv = pd.read_csv(f'{MAB_DMS_DIR}/results/mab_antifold.csv')
esm1f_csv = pd.read_csv(f'{MAB_DMS_DIR}/results/mab_esm1f.csv')
esm1v_csv = pd.read_csv(f'{MAB_DMS_DIR}/results/mab_esm1v.csv')
pyrosetta_csv = pd.read_csv(f'{MAB_DMS_DIR}/results/mab_pyrosetta.csv')
experiment_csv = pd.read_csv(f'{MAB_DMS_DIR}/Chain_LH_Corrected_Mutation_Data_sorted.csv')

# --- Select the key columns plus each model's score -------------------------
ablang_keep = ablang_csv[MODEL_KEYS + ["delta_log_likelihood_ablang"]]
antifold_keep = antifold_csv[MODEL_KEYS + ["delta_log_likelihood_antifold"]]
esm1f_keep = esm1f_csv[MODEL_KEYS + ["delta_log_likelihood_esm1f"]]
esm1v_keep = esm1v_csv[MODEL_KEYS + ["delta_log_likelihood_esm1v"]]
pyrosetta_keep = pyrosetta_csv[MODEL_KEYS + ["delta_log_likelihood_pyrosetta"]]

# The experiment table is produced by the sorting cell, which reads the mutant
# residue from a column named 'mt_res'; map such names onto the model naming.
# NOTE(review): 'wt_res' is assumed by analogy with 'mt_res' -- confirm, and add
# the real name of the experimental score column here if it differs.
EXPERIMENT_ALIASES = {"wt_res": "wt", "mt_res": "mt"}
experiment_renamed = experiment_csv.rename(columns={
    old: new
    for old, new in EXPERIMENT_ALIASES.items()
    if old in experiment_csv.columns and new not in experiment_csv.columns
})
experiment_required = EXPERIMENT_KEYS + [EXPERIMENT_SCORE_COL]
experiment_missing = [c for c in experiment_required if c not in experiment_renamed.columns]
if experiment_missing:
    # Fail with the same exception type pandas would raise, but say what IS available.
    raise KeyError(
        f"Experiment table is missing columns {experiment_missing}; "
        f"available columns: {list(experiment_csv.columns)}"
    )
experiment_keep = experiment_renamed[experiment_required]

# --- Merge ------------------------------------------------------------------
# Outer joins keep mutations that are scored by only some of the models.
merged_df = ablang_keep
for model_keep in (antifold_keep, esm1f_keep, esm1v_keep, pyrosetta_keep):
    merged_df = merged_df.merge(model_keep, on=MODEL_KEYS, how="outer")
# The experiment table has no 'sample' column, so it joins on the remaining keys.
merged_df = merged_df.merge(experiment_keep, on=EXPERIMENT_KEYS, how="outer")

# Missing scores are deliberately left as NaN: replacing them with the string
# "NA" would turn every score column into object dtype, making the later
# dropna() a no-op and breaking .corr() and the z-score computation.

# Sort for readability and give the result a clean 0..n-1 index.
merged_df = merged_df.sort_values(by=MODEL_KEYS).reset_index(drop=True)


KeyError: "['sample', 'wt', 'mt', 'score_experiment'] not in index"

In [None]:
# Plotting libraries are imported here because no earlier cell imports them;
# without this the cell raises NameError on a fresh kernel.
import matplotlib.pyplot as plt
import seaborn as sns

# Step 2: Select the scoring columns (model deltas and the experimental score)
score_columns = [col for col in merged_df.columns if col.startswith("delta") or col.startswith("score")]

# Step 2.5: Force the score columns to numeric (non-numeric placeholders such as
# the string "NA" become NaN), then drop rows where any score is missing.
merged_df = merged_df.assign(
    **{col: pd.to_numeric(merged_df[col], errors="coerce") for col in score_columns}
)
merged_df = merged_df.dropna(subset=score_columns)

# Step 3: Group by sample and compute correlations
samples = merged_df["sample"].unique()

for sample in samples:
    sample_df = merged_df.loc[merged_df["sample"] == sample, score_columns]

    # A correlation needs at least two observations (rows); skip smaller samples.
    if len(sample_df) < 2:
        continue

    corr = sample_df.corr()

    # Step 4: Plot the heatmap
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", cbar=True, ax=ax)
    ax.set_title(f"Correlation Heatmap - Sample: {sample}")
    fig.tight_layout()
    plt.show()
    # Release the figure so a long list of samples does not accumulate open figures.
    plt.close(fig)


In [None]:
from scipy.stats import zscore

# Step 1: Identify the model delta-score columns
delta_cols = [col for col in merged_df.columns if col.startswith("delta")]

# Step 2: Z-score each delta column across all rows (i.e. across all samples).
# nan_policy="omit" computes the mean/std from the non-missing values and keeps
# NaN only where the input was NaN; with the default ("propagate") a single
# missing value would turn the whole column into NaN.
# Step 3: The "_zscore" suffix marks the derived columns.
zscore_df = (
    merged_df[delta_cols]
    .apply(zscore, nan_policy="omit")
    .add_suffix("_zscore")
)

# Step 4: Attach the z-scores to the original rows (aligned on the shared index)
merged_with_zscore = pd.concat([merged_df, zscore_df], axis=1)

# Step 5 (Optional): Preview -- as the cell's last expression it is rendered as a table
merged_with_zscore.head()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Name of the experimental score column and the z-score columns to weight by it.
experiment_col = "score_experiment"
zscore_cols = [name for name in merged_with_zscore.columns if name.endswith("_zscore")]

# For every z-score column, store (z-score * experimental score) in a sibling
# column called "<z-score column>_multiplied".
for col in zscore_cols:
    new_col = col + "_multiplied"
    merged_with_zscore[new_col] = merged_with_zscore[col].mul(merged_with_zscore[experiment_col])


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def _plot_product_heatmap(model_df, title):
    """Draw one mutant-residue x position heatmap of the 'product' column.

    Parameters
    ----------
    model_df : DataFrame with columns 'mt', 'pos', 'wt' and 'product', holding at
        most one row per (mt, pos) pair (pivot raises on duplicates).
    title : str used as the figure title.

    Cells where the mutant residue equals the wild-type residue are masked out
    of the colour scale and painted light gray. Nothing is drawn when every
    product value is missing.
    """
    pivot = model_df.pivot(index="mt", columns="pos", values="product")
    pivot_wt = model_df.pivot(index="mt", columns="pos", values="wt")

    if pivot.isna().all().all():
        return  # no valid data to show

    # True where wt == mt. Compared row-wise against the index labels; this
    # avoids `Index[:, np.newaxis]`, which pandas 2.x no longer supports.
    is_wildtype = pivot_wt.eq(pivot.index.to_series(), axis=0)

    fig, ax = plt.subplots(figsize=(10, 4))
    sns.heatmap(
        pivot,
        annot=False,
        cmap="bwr",
        center=0,
        linewidths=0.5,
        cbar_kws={'label': 'z-score × Experimental Score'},
        mask=is_wildtype,
        ax=ax,
    )

    # Paint the masked wt == mt cells gray (argwhere yields (row, col) pairs).
    for y, x in np.argwhere(is_wildtype.to_numpy()):
        ax.add_patch(plt.Rectangle((x, y), 1, 1, color='lightgray'))

    ax.set_title(title)
    ax.set_xlabel("Position")
    ax.set_ylabel("Mutant Residue")
    fig.tight_layout()
    plt.show()
    # Release the figure so many sample/model combinations do not pile up in memory.
    plt.close(fig)


# Step 1: Build the long table, one row per (sample, chain, pos, mt, model).
# 'chain' is carried along because the same position number can occur on more
# than one chain; without it the pivot would hit duplicate (mt, pos) entries.
id_cols = ["sample", "chain", "pos", "mt", "wt"]
long_data = []

for col in zscore_cols:
    mult_col = f"{col}_multiplied"
    model_name = col.replace("delta_log_likelihood_", "").replace("_zscore", "")
    # Copy only the needed columns instead of the whole frame for every model.
    long_data.append(
        merged_with_zscore[id_cols].assign(
            model=model_name,
            product=merged_with_zscore[mult_col],
        )
    )

plot_df = pd.concat(long_data, ignore_index=True)

# Step 2: One heatmap per sample, chain and model
for sample in plot_df["sample"].dropna().unique():
    sample_df = plot_df[plot_df["sample"] == sample]

    for chain in sample_df["chain"].dropna().unique():
        chain_df = sample_df[sample_df["chain"] == chain]

        for model in chain_df["model"].unique():
            model_df = chain_df[chain_df["model"] == model]
            _plot_product_heatmap(model_df, f"{sample} – chain {chain} – {model}")
