# In [None]:
import os
from collections import defaultdict
import pandas as pd
import numpy as np
from pyrosetta import *
from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover

# === SETTINGS === #
# NOTE(review): input_folder points at a "1AO7" directory while the wild-type
# structure and the output file are named after "6vjj" - confirm that all
# three paths refer to the same complex before trusting the ddG values.

# Folder containing the mutant PDB files (only MUT_*.pdb files are read).
input_folder = "/groups/sbinlab/zqt390/Thesis_Project/af2_files/output_1AO7/"

# Wild-type PDB used as the reference for ddG.
wt_pdb_path = "/groups/sbinlab/beatriz/pipeline/thesis/pipeline_results/6vjj_crystal/relax/output/output.pdb"

# CSV file the final data frame is written to.
output_csv = "/groups/sbinlab/beatriz/pipeline/thesis/test_interface/interface_ddg_6vjj.csv"

# Chain where the mutations are located; every other chain is treated as
# its binding partner.
monomer_chain = "A"

# === Score Function === #
# PyRosetta is initialised first; the score function is created only after
# init() has run with these command-line style options.
rosetta_flags = "-ex1 -ex2aro -corrections:beta_nov16_cart true"
init(extra_options=rosetta_flags)

from pyrosetta import create_score_function

# Module-level score function shared by every get_interface_dG() call.
scorefxn = create_score_function('beta_nov16_cart')

def get_interface_dG(pdb_path, chain_of_interest):
  """Return the interface dG between one chain and all other chains of a PDB.

  The structure is loaded from ``pdb_path``, an interface definition of the
  form ``"<chain_of_interest>_<all other chains>"`` is built from the chain
  IDs found in the pose, and an ``InterfaceAnalyzerMover`` configured with
  the module-level ``scorefxn`` is applied to it.

  Args:
    pdb_path: Path of the PDB file to analyse.
    chain_of_interest: Chain ID forming one side of the interface.

  Returns:
    The value reported by ``InterfaceAnalyzerMover.get_interface_dG()``
    after the mover has been applied to the pose.

  Raises:
    ValueError: If ``chain_of_interest`` is not present in the structure,
      or if the structure has no other chain to form an interface with.
  """
  pose = pose_from_pdb(pdb_path)
  pdb_info = pose.pdb_info()

  # dict.fromkeys de-duplicates while keeping first-seen order, so the
  # interface string is identical from run to run (a set has no stable order).
  chains = list(dict.fromkeys(
      pdb_info.chain(resi) for resi in range(1, pose.size() + 1)))

  if chain_of_interest not in chains:
    raise ValueError(
        f"Chain {chain_of_interest!r} not found in {pdb_path} "
        f"(chains present: {''.join(chains)})")

  partner_chains = ''.join(c for c in chains if c != chain_of_interest)
  if not partner_chains:
    # Without a partner the interface string would be "<chain>_", which
    # does not describe an interface.
    raise ValueError(
        f"{pdb_path} contains no chain other than {chain_of_interest!r}; "
        "cannot define an interface")

  interface_str = f"{chain_of_interest}_{partner_chains}"
  iam = InterfaceAnalyzerMover(interface_str)
  iam.set_scorefunction(scorefxn)
  iam.set_pack_separated(True)
  iam.set_pack_rounds(5)
  iam.apply(pose)
  return iam.get_interface_dG()


# === Get WT dG === #
# Reference value subtracted from every mutant's mean dG further down.
wt_dG = get_interface_dG(pdb_path=wt_pdb_path, chain_of_interest=monomer_chain)
print(f"Wildtype interface dG: {wt_dG:.2f}")

# === Process Mutant Files === #
# Maps mutation id -> list of interface dG values, one per structure of that
# mutation that was scored successfully.
ddg_data = defaultdict(list)

mut_prefix = "MUT_"
pdb_suffix = ".pdb"

# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and therefore the log output) reproducible.
for file in sorted(os.listdir(input_folder)):
  if not (file.startswith(mut_prefix) and file.endswith(pdb_suffix)):
    continue

  file_path = os.path.join(input_folder, file)
  # Strip only the leading prefix and the trailing extension. str.replace
  # would also delete "MUT_" / ".pdb" occurring in the middle of the name.
  name_core = file[len(mut_prefix):-len(pdb_suffix)]
  # Everything before the first "_bj" is the mutation id, so files that
  # differ only after that marker are grouped as replicates.
  mutation_id = name_core.split("_bj")[0]

  try:
    dG = get_interface_dG(file_path, monomer_chain)
  except Exception as e:
    # Best effort: report the failing structure and keep going so a single
    # bad file does not abort the whole batch.
    print(f"Failed on {file}: {e}")
  else:
    ddg_data[mutation_id].append(dG)

# === Calculate ddG === #
# One output row per mutation: the mean dG over its replicates and the
# difference of that mean to the wild-type dG (rounded to 3 decimals).
final_rows = []

for mutation, dG_list in ddg_data.items():
  if dG_list:  # mutations without any scored replicate are skipped
    mean_dG = np.mean(dG_list)
    ddG = round(mean_dG - wt_dG, 3)
    row = {"mutation": mutation, "mean_dG": mean_dG, "ddG": ddG}
    final_rows.append(row)

# === Save === #
# The columns are given explicitly so that an empty result list still yields
# a frame with a "ddG" column; otherwise sort_values("ddG") raises KeyError
# when no mutant could be scored.
df = pd.DataFrame(final_rows, columns=["mutation", "mean_dG", "ddG"])
df.sort_values("ddG", inplace=True)

# Create the destination folder if needed so the results are not lost to a
# missing directory after all structures have been scored.
output_dir = os.path.dirname(output_csv)
if output_dir:
  os.makedirs(output_dir, exist_ok=True)

df.to_csv(output_csv, index=False)
print(f"Results saved to: {output_csv}")