# Setup

In [None]:
# !pip install py3Dmol
!pip install biopython

In [None]:
import json
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
import requests # for pulling info from urls
import Bio.PDB

In [None]:
# mount drive (if using colab)
# google.colab only exists inside a Colab runtime, so guard the import to keep
# this cell runnable on a local machine / cluster as well
try:
  from google.colab import drive
except ImportError:
  print("google.colab not available, skipping drive mount")
else:
  drive.mount('/content/drive')

In [None]:
# set data dir
DATA_DIR = "data"
assert os.path.exists(DATA_DIR)

# Stability Data Wrangling
NOTE: These cells do not need to be run again if you already have the cleaned/combined data stored

### Data cleanup
Load data for ThermomutDB and FireProtDB and remove any entries that have missing values for critical columns

In [None]:
# FireProtDB export (csv) -> DataFrame
FIREPROT_PTH = os.path.join(DATA_DIR, "fireprotdb_results.csv")
fireprot_df = pd.read_csv(FIREPROT_PTH)

# ThermoMutDB export (json) -> DataFrame
THERMOMUT_PTH = os.path.join(DATA_DIR, "thermomutdb.json")
# context manager so the file handle is closed once the json has been parsed
with open(THERMOMUT_PTH) as thermomut_file:
  thermomut_df = json.load(thermomut_file)
thermomut_df = pd.DataFrame(thermomut_df)

In [None]:
# remove entries that are not single-point mutation (fireprot already only stores only single-point)
thermomut_df = thermomut_df[thermomut_df['mutation_type'] == 'Single']

In [None]:
def drop_nas(ess_cols, df: pd.DataFrame):
  """
  Remove every row that has a missing value in at least one essential column
  :param ess_cols: iterable of column names that must be populated
  :param pd.DataFrame df: frame to filter (not modified)
  :return: pd.DataFrame containing only the rows where all of ess_cols are non-null
  """
  # one pass over the frame instead of re-filtering once per column
  return df.dropna(subset=list(ess_cols))

In [None]:
essential_fire_cols = ["chain", "ddG", "pdb_id", "position", "interpro_families", "mutation", "uniprot_id", "wild_type"]
fireprot_df = drop_nas(essential_fire_cols, fireprot_df)

essential_therm_cols = ["mutated_chain", "ddg", "mutation_code", "PDB_wild", "uniprot"]
thermomut_df = drop_nas(essential_therm_cols, thermomut_df)

In [None]:
# these fireprot entries were found to have sequence entries inconsistent with UniProtDB
fireprot_df = fireprot_df[fireprot_df['uniprot_id'] != "P60174"]

In [None]:
# for some reason thermomut has 8 entries with "-" as the uniprot id
# get rid of these
# UniProt accessions are either 6 or 10 characters long, so keep both lengths
# (filtering on == 6 alone would also throw away valid 10-character accessions)
uni_lens = thermomut_df['uniprot'].map(len)
thermomut_df = thermomut_df[uni_lens.isin([6, 10])]

In [None]:
# make interpro_families entries lists
fireprot_df['interpro_families'] = fireprot_df['interpro_families'].map(lambda s: s.split("|"))

In [None]:
print(len(fireprot_df))
print(len(thermomut_df))

### Extract Mutation Info from ThermoMutDB
ThermoMutDB stores each mutation as a three-part code xNy, where x is the original (wild-type) amino acid, N is the 1-indexed position in the sequence where the mutation occurs, and y is the amino acid in the mutated sequence.

In [None]:
# remove erroneous mutation codes
# a usable code is exactly: one letter, one or more digits, one letter (e.g. "A123G")
# counting letters/digits alone also lets through codes like "1A2B", which would
# crash the int() conversion when the position is extracted in the next cell
valid_code = thermomut_df['mutation_code'].str.fullmatch(r"[A-Za-z]\s*\d+\s*[A-Za-z]", na=False)

# single mask built from (and applied to) the same frame, so the index always lines up
thermomut_df = thermomut_df[valid_code]

In [None]:
thermomut_df['wild_type'] = thermomut_df['mutation_code'].map(lambda s: s[0])
thermomut_df['position'] = thermomut_df['mutation_code'].map(lambda s: int(s[1:-1]))
thermomut_df['mutant_type'] = thermomut_df['mutation_code'].map(lambda s:s[-1])

In [None]:
i = 500
print(thermomut_df['mutation_code'].iloc[i])
print(thermomut_df['wild_type'].iloc[i], thermomut_df['position'].iloc[i], thermomut_df['mutant_type'].iloc[i])

### Filling in Sequence Info for ThermoMutDB

In [None]:
def get_seq(target_url: str):
  """
  Download a FASTA record and return only its sequence
  :param str target_url: url whose response body is a FASTA record
  :return: str, everything after the first line with all whitespace removed
  :raises ValueError: if the response text contains no line break at all
  """
  fasta_text = requests.get(target_url).text
  # first line is the FASTA header, the sequence starts on line 2
  header_end = fasta_text.index('\n')
  residue_chunks = fasta_text[header_end + 1:].split()
  return "".join(residue_chunks)

def get_thermomut_seqs(therm_df: pd.DataFrame, save_path: str):
  """
  Query UniProt for the sequence of every distinct id in therm_df['uniprot']
  and store the resulting mapping with np.save
  :param pd.DataFrame therm_df: frame with a 'uniprot' column
  :param str save_path: where the id -> sequence dict is saved
  :return: dict mapping uniprot id to sequence (ids whose lookup failed are absent)
  """
  seq_by_id = {}
  not_found = []
  for uid in tqdm(set(therm_df['uniprot'])):
    fasta_url = "https://rest.uniprot.org/uniprotkb/%s.fasta" % uid
    try:
      seq_by_id[uid] = get_seq(fasta_url)
    except ValueError:
      # obsolete or deleted entries, will toss these out
      not_found.append(uid)

  # obsolete: M5A5Y8, A0A410ZNC6
  print("\nSequences not found for:", not_found)

  # cache what was obtained from uniprot so it doesn't have to be queried again
  np.save(save_path, seq_by_id)
  print("Saved thermomutdb sequence info to:", save_path)
  return seq_by_id

In [None]:
def load_thermomut_seqs(save_file: str):
  """
  Return the uniprot id -> sequence mapping, reading it from save_file when
  that file exists and otherwise building it from UniProt (using the
  notebook-level thermomut_df) and saving it there
  :param str save_file: path of the cached .npy file
  :return: dict mapping uniprot id to sequence
  """
  if os.path.exists(save_file):
    # the dict was stored with np.save, so it comes back as a 0-d object array
    cached_seqs = np.load(save_file, allow_pickle=True).item()
    print("Loaded sequence info from '%s' successfully" % save_file)
    return cached_seqs
  # have to load information from uniprot database, takes ~2-3 mins
  print("Given file not found, loading seqs from uniprot...")
  return get_thermomut_seqs(thermomut_df, save_file)

In [None]:
# load sequence info
save_file = os.path.join(DATA_DIR, "thermomut_seqs.npy")
thermomut_seqs = load_thermomut_seqs(save_file)

In [None]:
# update thermomut_df base on loaded info
thermomut_df['sequence'] = thermomut_df['uniprot'].map(thermomut_seqs)

In [None]:
# validate that sequences were loaded properly by checking w/ overlap fireprot data
thermomut_ids = set(thermomut_df['uniprot'])
fireprot_ids = set(fireprot_df['uniprot_id'])
overlap = thermomut_ids & fireprot_ids
for id in overlap:
  b = thermomut_df[thermomut_df['uniprot'] == id]['sequence'].iloc[0] == fireprot_df[fireprot_df['uniprot_id'] == id]['sequence'].iloc[0]
  if not(b):
    print(id)

### Removing nonsense mutations in databases
Several entries list a wild-type residue that does not match the residue actually found at the given position of the sequence

In [None]:
# get rid of entries where wild-type doesn't match given info
# see where mut-position is > sequence
def filter(df: pd.DataFrame):
  """
  Split entries by whether their stated wild-type residue agrees with the sequence
  :param pd.DataFrame df: frame with 'sequence', 'position' (1-indexed) and 'wild_type' columns
  :return: (filt_df, wrong_df) where filt_df holds the entries whose wild_type equals
           the residue at 'position' and wrong_df holds the entries where it differs.
           Entries without a sequence, or with a position outside 1..len(sequence),
           appear in neither frame.
  """
  # ids whose sequence lookup failed have a missing sequence; len() would raise on those
  df = df[df['sequence'].notna()]

  # remove any entries where the position falls outside the sequence
  # (position < 1 must go as well, otherwise seq[position - 1] wraps around to the end)
  seq_lens = df['sequence'].map(len)
  positions = df['position']
  filt_df = df[(positions >= 1) & (positions <= seq_lens)]

  # remove entries where position doesn't match wild-type
  # this for some reason is mainly a thermomut problem
  s_types = pd.Series(
      [seq[int(pos) - 1] for seq, pos in zip(filt_df['sequence'], filt_df['position'])],
      index=filt_df.index, dtype=object)
  matches = filt_df['wild_type'] == s_types
  wrong_df = filt_df[~matches]
  filt_df = filt_df[matches]
  return filt_df, wrong_df

In [None]:
thermomut_df, excluded_therm_df = filter(thermomut_df)
fireprot_df, excluded_fire_df = filter(fireprot_df)

In [None]:
print(len(excluded_therm_df))
print(len(thermomut_df))
print(len(fireprot_df))
print(len(excluded_fire_df))

In [None]:
entry = excluded_fire_df.iloc[1]
print(entry['uniprot_id'])
print(entry['sequence'][entry['position'] - 1])
print(entry['wild_type'], entry['position'], entry['mutation'])

In [None]:
entry = excluded_therm_df.iloc[4]
print(entry['mutation_code'])
print(entry['uniprot'])
print(entry['sequence'][entry['position'] - 1])
print(entry['sequence'])

other_uni_entries = excluded_therm_df[entry['uniprot'] == excluded_therm_df['uniprot']]
print(len(set(other_uni_entries['sequence'])))

### Fill in Interpro Family Info for ThermoMutDB 
This code based on the InterPro program-friendly API:

https://www.ebi.ac.uk/interpro/result/download/#/entry/InterPro/protein/UniProt/|json 

In [None]:
# standard library modules
import sys, errno, re, json, ssl
from urllib import request
from urllib.error import HTTPError
from time import sleep

def get_interpro_fams(start_url:str):
  """
  Walk through a paginated InterPro API listing and collect the accession of
  every returned entry
  :param str start_url: url of the first result page
  :return: list of item['metadata']['accession'] values over all pages
  :raises HTTPError: once a non-408 HTTP error has been retried 3 times
  """
  # disable SSL verification to avoid config issues
  # NOTE(review): this turns off certificate checking for every request made here
  context = ssl._create_unverified_context()

  next_url = start_url
  attempts = 0
  result = []
  while next_url:
    try:
      req = request.Request(next_url, headers={"Accept": "application/json"})
      res = request.urlopen(req, context=context)
      if res.status == 408:
        # the API timed out on a long running query:
        # wait just over a minute, then retry the same URL
        sleep(61)
        continue
      if res.status == 204:
        # no data so leave loop
        break
      payload = json.loads(res.read().decode())
      next_url = payload["next"]
      attempts = 0
    except HTTPError as e:
      if e.code != 408 and attempts >= 3:
        # any other HTTP error is re-tried 3 times before failing
        sys.stderr.write("LAST URL: " + next_url)
        raise e
      if e.code != 408:
        attempts += 1
      sleep(61)
      continue

    result.extend(item['metadata']['accession'] for item in payload["results"])

    # Don't overload the server, give it time before asking for more
    if next_url:
      sleep(1)

  return result

In [None]:
def load_interpro_fams(uni_ids, save_file:str):
  """
  Get the InterPro family accessions for each uniprot id, using save_file as a cache
  :param uni_ids: set of strings, each string is a uniprot id
  :param str save_file: .npy file the mapping is read from (if present) or written to
  :return: dict, maps id to list of families
  """
  if not os.path.exists(save_file):
    # nothing cached yet: query info from interpro and save to .npy file
    print("Loading data from Interpro...")
    fams_by_id = {}
    for uid in tqdm(uni_ids):
      query_url = "https://www.ebi.ac.uk:443/interpro/api/entry/InterPro/protein/UniProt/%s/?page_size=200" % uid
      fams_by_id[uid] = get_interpro_fams(query_url)
    np.save(save_file, fams_by_id)
    print("\nSaved data to:", save_file)
    return fams_by_id

  cached_fams = np.load(save_file, allow_pickle=True).item()
  print("Successfully loaded file from:", save_file)
  return cached_fams

In [None]:
therm_ids = set(thermomut_df['uniprot'])
save_file = os.path.join(DATA_DIR, "thermomut_interpro_fams.npy")
therm_fams = load_interpro_fams(therm_ids, save_file)

In [None]:
thermomut_df['interpro_families'] = thermomut_df['uniprot'].map(therm_fams)

### Fix ddG values for FireProt
FireProt considers destabilizing mutations to have ddG > 0, but other tools in this work consider destabilizing mutations to be ddG < 0

In [None]:
fireprot_df['ddG'] = -1.0 * fireprot_df['ddG'] 

### Combine Data

In [None]:
# table that pairs fireprot columns with their thermomut counterparts and
# flags which of them are kept (see the 'useful' column used below)
# built with os.path.join so the path works on any OS, and actually passed to
# read_csv (the cell previously read from an undefined csv_export_url)
useful_csv = os.path.join(DATA_DIR, "col_mapping.csv")
col_map_df = pd.read_csv(useful_csv)

In [None]:
# cols to keep for each dataset
useful_cols_df = col_map_df[col_map_df['useful'] == 1]
fire_keep_cols = useful_cols_df['fireprot']
therm_keep_cols = useful_cols_df['thermomut']

fire_filt_df = fireprot_df[fire_keep_cols]
therm_filt_df = thermomut_df[therm_keep_cols]

In [None]:
# change thermomut col names to match fireprot
col_map = pd.Series(useful_cols_df.fireprot.values,index=useful_cols_df.thermomut).to_dict()
therm_filt_df = therm_filt_df.rename(columns=col_map)

In [None]:
fire_filt_df = fire_filt_df.assign(db_origin='fireprot')
therm_filt_df = therm_filt_df.assign(db_origin='thermomut')

In [None]:
# if the attributes below all match another entry then it's a duplicate
match_cols = ['uniprot_id', 'position', 'wild_type', 'mutation']

In [None]:
# drop duplicates from each df
# for duplicates keep only the first entry in the database
therm_filt_df = therm_filt_df.drop_duplicates(subset=match_cols, keep='first')
fire_filt_df = fire_filt_df.drop_duplicates(subset=match_cols, keep='first')

In [None]:
# get unique entries in thermomut
# code taken from: 
# https://stackoverflow.com/questions/44706485/how-to-remove-rows-in-a-pandas-dataframe-if-the-same-row-exists-in-another-dataf 
a = therm_filt_df
b = fire_filt_df

a_index = a.set_index(match_cols).index
b_index = b.set_index(match_cols).index
mask = ~a_index.isin(b_index)
therm_filt_unique = a.loc[mask]

In [None]:
# combine fireprot entries with entries unique to thermomut
comb_df = pd.concat([fire_filt_df, therm_filt_unique])

In [None]:
# validate that all sequences are the same for a given id
uniprot_ids = set(comb_df['uniprot_id'])
for ui in uniprot_ids:
  unique_seqs = set(comb_df[comb_df['uniprot_id'] == ui]['sequence'])
  if len(unique_seqs) != 1:
    print("failed at: %s" % ui)

In [None]:
# save to csv file
comb_save_pth = os.path.join(DATA_DIR, "combined_cleaned.csv")
comb_df.to_csv(comb_save_pth, index=False)

### Add Small Prot Dataset
Due to computational limitations, only smaller proteins (<= 400 residues) are used

In [None]:
MAX_RESIDUES = 400

In [None]:
comb_save_pth = os.path.join(DATA_DIR, "combined_cleaned.csv")
prot_df = pd.read_csv(comb_save_pth)

In [None]:
# sequence lengths
seq_lens = prot_df['sequence'].map(len)
# ignore large proteins (~500 are larger than 1000 residues)
is_small = seq_lens <= MAX_RESIDUES
# explicit copy: the new column is added to an independent frame rather than
# to a slice of prot_df (avoids pandas' SettingWithCopyWarning)
small_prots = prot_df.loc[is_small, :].copy()
small_prots['sequence_length'] = seq_lens.loc[is_small]

In [None]:
print('min length:', min(small_prots['sequence_length']))
print('max length:', max(small_prots['sequence_length']))
print('n entries:', len(small_prots))

In [None]:
# save to csv file
small_comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
small_prots.to_csv(small_comb_save_pth, index=False)

# ESMFold Setup
Heavy-lifting done on CARC, here just prepping FASTA files for input

### Create FASTA Files
Generate FASTA files for usage with ESMFold

In [None]:
def gen_fasta_file(fpath: str, seqs: list, labels: list):
  """
  Write the given sequences to a FASTA file, one two-line record
  (">label" followed by the sequence) per sequence/label pair.
  :param str fpath: path to save fasta file to
  :param list seqs: sequences to write
  :param labels: corresponding labels for each sequence
  """
  with open(fpath, "w") as fasta_out:
    for record_seq, record_label in zip(seqs, labels):
      record = "".join((">", record_label, "\n", record_seq, "\n"))
      fasta_out.write(record)
  print("Sucessfully saved given seqs/ids to:", fpath)

In [None]:
comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
prot_df = pd.read_csv(comb_save_pth)

### Sample shorter Seqs for time benchmarking

In [None]:
def get_sample_seqs(length: int, n_samples: int=1, eps: int=20):
  """
  Get a protein sequence from prot_df of the desired length (+-eps)
  :param int length: target sequence length
  :param int n_samples: number of rows to draw; only the first drawn row is returned
  :param int eps: allowed deviation from length (inclusive on both sides)
  :return: (sequence, uniprot_id) of the first drawn entry, or None if no
           entry of prot_df falls within the requested range
  """
  seq_lens = prot_df['sequence'].map(len)
  # one mask over the full frame, so the boolean index always matches the frame it filters
  in_range = seq_lens.between(length - eps, length + eps)
  sample_df = prot_df[in_range]
  if len(sample_df) == 0:
    # no entries of desired length found
    return None
  # sample() raises if asked for more rows than exist, so cap the draw size
  n_draw = min(n_samples, len(sample_df))
  sampled_entry = sample_df.sample(n=n_draw, random_state=42).iloc[0]
  return sampled_entry['sequence'], sampled_entry['uniprot_id']

In [None]:
# pick one protein for each target length (50, 100, ..., 400)
desired_lens = np.arange(50, 450, 50)
found_uni_ids = []
found_seqs = []
for length in desired_lens:
  sample = get_sample_seqs(length, n_samples=2)
  if sample is None:
    # nothing close to this length; unpacking the None directly would raise a TypeError
    continue
  sample_seqs, ids = sample
  found_seqs.append(sample_seqs)
  found_uni_ids.append(ids)

In [None]:
print(len(found_seqs))
print(found_uni_ids)

For ESMFold

In [None]:
pth = os.path.join(DATA_DIR, "benchmarking_seqs.fasta")
gen_fasta_file(pth, found_seqs, found_uni_ids)

### Full Sequences

In [None]:
# create dir for fasta files
FASTA_DIR = os.path.join(DATA_DIR, "esm_fastas")
WILDTYPE_DIR = os.path.join(FASTA_DIR, "wildtypes")
MUT_DUR = os.path.join(FASTA_DIR, "mutants")
dirs = [FASTA_DIR, WILDTYPE_DIR, MUT_DUR]
for d in dirs:
  os.makedirs(d, exist_ok=True)

In [None]:
def get_length_n_seqs(df: pd.DataFrame, seq_lens: pd.Series, lower_b: int, upper_b: int):
  """
  Select the rows of df whose sequence length lies in the half-open
  interval [lower_b, upper_b)
  :param pd.DataFrame df: frame to select from
  :param pd.Series seq_lens: sequence length for each row of df
  :param int lower_b: smallest length to keep (inclusive)
  :param int upper_b: upper length bound (exclusive)
  :return: pd.DataFrame with the matching rows
  """
  in_bin = (seq_lens >= lower_b) & (seq_lens < upper_b)
  return df.loc[in_bin]

In [None]:
def save_wildtypes(df: pd.DataFrame, fpath: str):
  """
  Write one wild-type sequence per uniprot id to a single FASTA file
  :param pd.DataFrame df: frame with 'uniprot_id' and 'sequence' columns
  :param str fpath: path of the FASTA file to write
  :return: int, number of wild-type records written
  """
  # keep the first row of every uniprot id (in order of first appearance, so the
  # file content is the same from run to run). De-duplicating on the id rather
  # than on the sequence is needed bc a small minority of uni-ids share the same sequence
  first_rows = df.drop_duplicates(subset='uniprot_id', keep='first')
  uni_ids = list(first_rows['uniprot_id'])
  wild_seqs = list(first_rows['sequence'])
  print("Num. wild-seqs:", len(uni_ids))
  gen_fasta_file(fpath, wild_seqs, uni_ids)
  return len(uni_ids)


def get_mut_seq(seq: str, position: int, mut_res: str, wild_res: str):
  """
  Apply a single-point mutation to a sequence
  :param str seq: wild-type sequence
  :param int position: 0-indexed position of the residue to replace
  :param str mut_res: residue to put at position
  :param str wild_res: residue expected at position in seq
  :return: str, seq with the residue at position replaced by mut_res
  :raises ValueError: if position is outside seq or seq[position] is not wild_res
  """
  # negative positions are rejected too: Python would silently count them from the end
  if position < 0 or position >= len(seq) or seq[position] != wild_res:
    raise ValueError("Invalid, given position was %d" % position)
  return seq[:position] + mut_res + seq[position+1:]

def save_mutants(df: pd.DataFrame, fpath: str):
  """
  Write the mutated sequence of every row of df to a single FASTA file.
  Records are labelled "<uniprot_id>_<wild_type><position><mutation>".
  :param pd.DataFrame df: frame with 'uniprot_id', 'sequence', 'position' (1-indexed),
                          'mutation' and 'wild_type' columns
  :param str fpath: path of the FASTA file to write
  :return: int, number of mutant records written
  """
  mutant_seqs = []
  labels = []
  # plain iteration instead of df.apply(axis=1): on an empty frame (an empty length
  # bin) apply hands back a DataFrame, whose column names would then be written
  # and counted as if they were sequences
  rows = zip(df['uniprot_id'], df['sequence'], df['position'], df['mutation'], df['wild_type'])
  for uid, seq, pos, mut_res, wild_res in rows:
    # positions are stored 1-indexed, get_mut_seq expects 0-indexed
    mutant_seqs.append(get_mut_seq(seq, pos - 1, mut_res, wild_res))
    labels.append("%s_%s%d%s" % (uid, wild_res, pos, mut_res))
  print("Num. mutant seqs:", len(mutant_seqs))
  gen_fasta_file(fpath, mutant_seqs, labels)
  return len(mutant_seqs)

In [None]:
seq_lens = prot_df['sequence_length']
max_len = max(seq_lens)
print(max_len)

# chunk data into different lengths, 
# so prots with 0-50 residues go into one file, 50-100 another, etc
skip = 50
# bins are half-open [lb, ub); the +1 guarantees the last upper bound is strictly
# greater than max_len, otherwise proteins of exactly max_len residues are left
# out whenever max_len is a multiple of skip
bounds = np.arange(0, max_len + skip + 1, skip)
mut_tot = 0
wild_tot = 0
for lb, ub in zip(bounds[:-1], bounds[1:]):
  print("Range: [%d, %d)" % (lb, ub))
  s_df = get_length_n_seqs(prot_df, seq_lens, lb, ub)
  fname = "length_%d_%d.fasta" % (lb, ub - 1)
  wild_fpath = os.path.join(WILDTYPE_DIR, fname)
  mut_fpath = os.path.join(MUT_DUR, fname)
  wild_tot += save_wildtypes(s_df, wild_fpath)
  mut_tot += save_mutants(s_df, mut_fpath)
  print("-"*80)

In [None]:
expected_wild_tot = len(set(prot_df['uniprot_id']))
expected_mut_tot = len(prot_df)
print(expected_wild_tot)
print(expected_mut_tot)
print(wild_tot)
print(mut_tot)

In [None]:
# print every pair of uniprot ids that have at least one sequence in common
uids = list(set(prot_df['uniprot_id']))
# collect each id's sequences once up front instead of re-filtering prot_df
# twice for every pair inside the loop
seqs_by_id = {uid: set(grp) for uid, grp in prot_df.groupby('uniprot_id')['sequence']}
for i in tqdm(range(len(uids))):
  x = uids[i]
  x_seq = seqs_by_id.get(x, set())
  for j in range(i+1, len(uids)):
    y = uids[j]
    y_seq = seqs_by_id.get(y, set())
    both = x_seq & y_seq
    if len(both) > 0:
      print(x, y)

In [None]:
# re-check the saved data: count entries whose wild-type residue does not
# match the sequence at the given position
# positions are 1-indexed, so position == sequence_length is still inside the
# sequence and has to be kept (hence >=, consistent with filter() above)
p_df = prot_df[prot_df['sequence_length'] >= prot_df['position']]
w_types = p_df['wild_type']
positions = p_df['position']

seqs = p_df['sequence']
s_types = [x[p-1] for x, p in zip(seqs, positions)]

bad_locs = p_df[w_types != s_types]
good_locs = p_df[w_types == s_types]

print(len(bad_locs))
print(np.unique(good_locs['db_origin'], return_counts=True))
print(np.unique(bad_locs['db_origin'], return_counts=True))

# AlphaFold
Download wild-type predictions from AlphaFold database

In [None]:
small_comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
small_prot_df = pd.read_csv(small_comb_save_pth)

In [None]:
# adapted from:
# https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests

def download_alphafold_pdb(url: str, uniprot_id: str, save_pth: str):
  """
  Stream the file at url to save_pth in 8192-byte chunks
  :param str url: url of the file to download
  :param str uniprot_id: id the file belongs to (not used by the download itself)
  :param str save_pth: local path the file is written to
  :return: bool, True once the file has been written, False if the server
           answered with an HTTP error status
  """
  try:
    with requests.get(url, stream=True) as resp:
      # raises HTTPError before anything is written to disk
      resp.raise_for_status()
      with open(save_pth, 'wb') as out_file:
        for block in resp.iter_content(chunk_size=8192):
          out_file.write(block)
  except requests.exceptions.HTTPError:
    # protein not found in database
    return False
  else:
    return True

In [None]:
# setup directory to save predictions to
ALPHAFOLD_DIR = os.path.join(DATA_DIR, "alphafold_preds", "wildtypes")
os.makedirs(ALPHAFOLD_DIR, exist_ok=True)

# uniprot ids not found in database
not_found = [] 

# save all wild-type predictions from alphafold db
# use the frame loaded at the top of this section (small_prot_df) so the section
# does not depend on a prot_df left over from an earlier part of the notebook
uniprot_ids = set(small_prot_df['uniprot_id'])
for uid in tqdm(uniprot_ids):
  url_target = "https://alphafold.ebi.ac.uk/files/AF-%s-F1-model_v4.pdb" % uid
  fname = "%s.pdb" % uid
  save_pth = os.path.join(ALPHAFOLD_DIR, fname)
  found = download_alphafold_pdb(url_target, uid, save_pth)
  if not(found):
    not_found.append(uid)

In [None]:
print(sorted(not_found))
print(len(not_found))

In [None]:
# entries (rows) that belong to proteins without an AlphaFold prediction
# uses small_prot_df, the frame loaded at the top of this section
filt_prot_df = small_prot_df.loc[small_prot_df['uniprot_id'].isin(not_found), :]
print(len(filt_prot_df))
print(sorted(set(filt_prot_df['uniprot_id'])))

# Generate fastas for HH-Suite
Heavy-lifting done on CARC, here just prepping FASTA files for input

In [None]:
def gen_hh_fasta_file(fpath: str, seq: str, label: str):
  """
  Write a FASTA file holding exactly one record: a ">label" header
  line followed by the sequence line.
  :param str fpath: path to save fasta file to
  :param str seq: sequence to write
  :param str label: label for sequence
  :return: True once the file has been written
  """
  with open(fpath, "w") as fasta_out:
    record = "".join((">", label, "\n", seq, "\n"))
    fasta_out.write(record)
  return True # success

In [None]:
def get_length_n_seqs(df: pd.DataFrame, seq_lens: pd.Series, lower_b: int, upper_b: int):
  """
  Return the rows of df whose entry in seq_lens satisfies
  lower_b <= length < upper_b
  :param pd.DataFrame df: frame to select from
  :param pd.Series seq_lens: sequence length for each row of df
  :param int lower_b: inclusive lower length bound
  :param int upper_b: exclusive upper length bound
  :return: pd.DataFrame with the matching rows
  """
  long_enough = seq_lens >= lower_b
  short_enough = seq_lens < upper_b
  return df.loc[long_enough & short_enough]

In [None]:
def save_wildtypes(df: pd.DataFrame, save_dir: str):
  """
  Write one FASTA file per uniprot id ("<uniprot_id>.fasta") holding that
  protein's wild-type sequence
  :param pd.DataFrame df: frame with 'uniprot_id' and 'sequence' columns
  :param str save_dir: directory the FASTA files are written to
  :return: int, number of files written
  """
  # first row of every uniprot id, found in a single pass instead of
  # re-filtering the whole frame once per id
  first_rows = df.drop_duplicates(subset='uniprot_id', keep='first')
  uni_ids = list(first_rows['uniprot_id'])
  wild_seqs = list(first_rows['sequence'])
  print("Num. wild-seqs:", len(uni_ids))
  for seq, uid in zip(wild_seqs, uni_ids):
    fpath = os.path.join(save_dir, uid + ".fasta")
    gen_hh_fasta_file(fpath, seq, uid)
  return len(uni_ids)


def get_mut_seq(seq: str, position: int, mut_res: str, wild_res: str):
  """
  Apply a single-point mutation to a sequence
  :param str seq: wild-type sequence
  :param int position: 0-indexed position of the residue to replace
  :param str mut_res: residue to put at position
  :param str wild_res: residue expected at position in seq
  :return: str, seq with the residue at position replaced by mut_res
  :raises ValueError: if position is outside seq or seq[position] is not wild_res
  """
  # negative positions are rejected too: Python would silently count them from the end
  if position < 0 or position >= len(seq) or seq[position] != wild_res:
    raise ValueError("Invalid, given position was %d" % position)
  return seq[:position] + mut_res + seq[position+1:]

def save_mutants(df: pd.DataFrame, save_dir: str):
  """
  Write one FASTA file per row of df holding that row's mutated sequence.
  Files and records are named "<uniprot_id>_<wild_type><position><mutation>".
  :param pd.DataFrame df: frame with 'uniprot_id', 'sequence', 'position' (1-indexed),
                          'mutation' and 'wild_type' columns
  :param str save_dir: directory the FASTA files are written to
  :return: int, number of mutant files written
  """
  mutant_seqs = []
  labels = []
  # plain iteration instead of df.apply(axis=1): on an empty frame (an empty length
  # bin) apply hands back a DataFrame, whose column names would then be written
  # and counted as if they were sequences
  rows = zip(df['uniprot_id'], df['sequence'], df['position'], df['mutation'], df['wild_type'])
  for uid, seq, pos, mut_res, wild_res in rows:
    # positions are stored 1-indexed, get_mut_seq expects 0-indexed
    mutant_seqs.append(get_mut_seq(seq, pos - 1, mut_res, wild_res))
    labels.append("%s_%s%d%s" % (uid, wild_res, pos, mut_res))
  print("Num. mutant seqs:", len(mutant_seqs))
  for seq, label in zip(mutant_seqs, labels):
    fpath = os.path.join(save_dir, label + ".fasta")
    gen_hh_fasta_file(fpath, seq, label)
  return len(mutant_seqs)

In [None]:
# create dir for fasta files
FASTA_DIR = os.path.join(DATA_DIR, "hhsuite_fastas")
WILDTYPE_DIR = os.path.join(FASTA_DIR, "wildtypes")
MUT_DUR = os.path.join(FASTA_DIR, "mutants")
dirs = [FASTA_DIR, WILDTYPE_DIR, MUT_DUR]
for d in dirs:
  os.makedirs(d, exist_ok=True)

In [None]:
comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
prot_df = pd.read_csv(comb_save_pth)

In [None]:
seq_lens = prot_df['sequence_length']
max_len = max(seq_lens)
print(max_len)

# go through the data in chunks of different lengths (0-50 residues, 50-100, etc);
# every chunk is written to the same wild-type / mutant directories
skip = 50
# bins are half-open [lb, ub); the +1 guarantees the last upper bound is strictly
# greater than max_len, otherwise proteins of exactly max_len residues are left
# out whenever max_len is a multiple of skip
bounds = np.arange(0, max_len + skip + 1, skip)
mut_tot = 0
wild_tot = 0
for lb, ub in zip(bounds[:-1], bounds[1:]):
  print("Range: [%d, %d)" % (lb, ub))
  s_df = get_length_n_seqs(prot_df, seq_lens, lb, ub)
  wild_tot += save_wildtypes(s_df, WILDTYPE_DIR)
  mut_tot += save_mutants(s_df, MUT_DUR)
  print("-"*80)

# PDB Prediction Data
Gather pLDDT scores and other relevant info for predictions from ESMFold and AlphaFold

### Get average pLDDT values from predictions

In [None]:
ALPHAFOLD_DIR = os.path.join(DATA_DIR, "alphafold")
ALPHA_WILDTYPE_DIR = os.path.join(ALPHAFOLD_DIR, "alphafold_preds", "wildtypes")

ESMFOLD_DIR = os.path.join(DATA_DIR, "esmfold")
ESM_WILDTYPE_DIR = os.path.join(ESMFOLD_DIR, "esm_preds", "wildtypes")
ESM_MUTANT_DIR = os.path.join(ESMFOLD_DIR, "esm_preds", "mutants")

In [None]:
def get_struct(pdb_path: str) -> Bio.PDB.Structure:
  """
  Load a pdb file from the given path into a Bio.PDB.Structure object
  :param str pdb_path: path of the pdb file to parse
  :return: Bio.PDB.Structure object
  """
  # QUIET=True so the parser does not print warnings while reading the file
  pdb_parser = Bio.PDB.PDBParser(QUIET=True)
  return pdb_parser.get_structure('Structure', pdb_path)

In [None]:
# credit to ChatGPT for writing this function
def get_plddts(struct: Bio.PDB.Structure):
    """
    Extracts the per-residue plddt values from a predicted PDB structure.
    The value is read from the B-factor field of each residue's "CA" atom
    (NOTE(review): assumes the predictor stores plddt there; confirm for new sources).
    Residues that have no "CA" atom are skipped.
    :param struct: A Bio.PDB.Structure object.
    :return: A list of floats representing the plddt values in `struct`.
    """
    # only the first model of the structure is used
    model = struct[0]

    # one value per residue; the membership test avoids a KeyError on
    # residues without an alpha carbon
    return [
        residue["CA"].get_bfactor()
        for residue in model.get_residues()
        if "CA" in residue
    ]

In [None]:
def get_avg_plddts(inp_dir: str) -> tuple:
  """
  Get the average plddt value of each pdb structure in a given directory
  :param str inp_dir: input directory containing pdbs
  :return: (avg_plddts, pdb_files): list[float] with the average plddt of each
           pdb and list[str] with the matching file names, in the same order
  """
  avg_plddts = []
  pdb_files = []
  # sorted so the output order does not depend on the filesystem's listing order
  for pdb_file in tqdm(sorted(os.listdir(inp_dir))):
    if not pdb_file.endswith(".pdb"):
      continue # don't want to try to get plddt of log files and such
    struct = get_struct(os.path.join(inp_dir, pdb_file))
    avg_plddts.append(np.mean(get_plddts(struct)))
    pdb_files.append(pdb_file)
  return avg_plddts, pdb_files

In [None]:
alpha_wild_plddts, alpha_wild_pdb_files = get_avg_plddts(ALPHA_WILDTYPE_DIR)

In [None]:
# had stored pdbs in subdirs based on seq length, hence the loops
def iterate_subdirs(dir: str):
  """
  Collect the average plddt values of the pdbs in every sub-directory of dir
  :param str dir: directory whose sub-directories contain pdb files
  :return: (plddts, pdb_files), the concatenated results of get_avg_plddts
           over all sub-directories
  """
  plddts = []
  pdb_files = []
  for subdir in sorted(os.listdir(dir)):
    subdir_path = os.path.join(dir, subdir)
    if not os.path.isdir(subdir_path):
      # plain files (e.g. a plddts.csv saved here by an earlier run) are not
      # sub-directories; listing them would raise NotADirectoryError
      continue
    sub_plddts, sub_pdb_files = get_avg_plddts(subdir_path)
    plddts += sub_plddts
    pdb_files += sub_pdb_files
  return plddts, pdb_files

esm_wild_plddts, esm_wild_pdb_files = iterate_subdirs(ESM_WILDTYPE_DIR)
esm_mutant_plddts, esm_mutant_pdb_files = iterate_subdirs(ESM_MUTANT_DIR)

### Saving into csvs for later use

In [None]:
def save_plddt_wild_data(plddt_vals: list, pdb_files: list, save_path: str):
  """
  Save average plddt values of wild-type predictions to a csv with the
  columns avg_plddt and uniprot_id
  :param list plddt_vals: average plddt value of each pdb
  :param list pdb_files: pdb file names; the text before ".pdb" is used as the uniprot id
  :param str save_path: path of the csv to write
  """
  uni_ids = []
  for fname in pdb_files:
    stem_end = fname.index(".pdb")
    uni_ids.append(fname[:stem_end])
  rows = list(zip(plddt_vals, uni_ids))
  table = pd.DataFrame(rows, columns=["avg_plddt", "uniprot_id"])
  table.to_csv(save_path, index=False)

def save_plddt_mut_data(plddt_vals: list, pdb_files: list, save_path: str):
  """
  Save average plddt values of mutant predictions to a csv with the
  columns avg_plddt, uniprot_id and mut_code
  :param list plddt_vals: average plddt value of each pdb
  :param list pdb_files: pdb file names of the form "<uniprot_id>_<mut_code>.pdb"
  :param str save_path: path of the csv to write
  """
  # strip the extension, then split "<uniprot_id>_<mut_code>" at the underscores
  stems = [fname[:fname.index(".pdb")] for fname in pdb_files]
  name_parts = [stem.split("_") for stem in stems]
  uni_ids = [parts[0] for parts in name_parts]
  mut_codes = [parts[1] for parts in name_parts]
  rows = list(zip(plddt_vals, uni_ids, mut_codes))
  table = pd.DataFrame(rows, columns=["avg_plddt", "uniprot_id", "mut_code"])
  table.to_csv(save_path, index=False)

In [None]:
save_plddt_wild_data(alpha_wild_plddts, alpha_wild_pdb_files, os.path.join(ALPHAFOLD_DIR, "plddts.csv"))
save_plddt_wild_data(esm_wild_plddts, esm_wild_pdb_files, os.path.join(ESM_WILDTYPE_DIR, "plddts.csv"))
save_plddt_mut_data(esm_mutant_plddts, esm_mutant_pdb_files, os.path.join(ESM_MUTANT_DIR, "plddts.csv"))

# DDGun Setup for CARC

### Create file formats expected by ddgun

In [None]:
DDGUN_DIR = os.path.join(DATA_DIR, "ddgun")
MUTFILE_DIR = os.path.join(DDGUN_DIR, "mut_files")
os.makedirs(MUTFILE_DIR, exist_ok=True)

In [None]:
comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
prot_df = pd.read_csv(comb_save_pth)

In [None]:
def write_mut_file(file_path: str, muts: list):
  """
  Write the given mutation codes to a text file, one per line
  (no line break after the last one)
  :param str file_path: path of the file to write
  :param list muts: mutation codes (strings)
  """
  with open(file_path, 'w') as mut_out:
    contents = '\n'.join(muts)
    mut_out.write(contents)

In [None]:
# one .muts file per protein, listing all of its mutations as <wild_type><position><mutation>
uni_ids = set(prot_df['uniprot_id'])
for uni_id, sub_df in tqdm(prot_df.groupby('uniprot_id')):
  mut_codes = [
    "%s%d%s" % (wild_res, pos, mut_res)
    for wild_res, pos, mut_res in zip(sub_df['wild_type'], sub_df['position'], sub_df['mutation'])
  ]
  fname = "%s.muts" % uni_id
  save_path = os.path.join(MUTFILE_DIR, fname)
  write_mut_file(save_path, mut_codes)

# DDGun / ACDC-NN Predictions


### Glob together ddgun predictions

In [None]:
DDGUN_DIR = os.path.join(DATA_DIR, "ddgun")
DDGUN_OUT_DIR = os.path.join(DDGUN_DIR, "out_files")

In [None]:
def generate_output_predictions(inp_dir: str):
  """
  Concat together all of the structure-based ddgun output files in a directory
  into a single DataFrame. Every file in inp_dir is read as a tab-separated
  table (first line skipped); the file name without extension is used as the uniprot id.
  :param str inp_dir: path to input directory
  :return: pd.DataFrame with the columns uniprot_id, chain, mut_code and predicted_ddg
           (empty, but with those columns, if inp_dir contains no files)
  """
  file_cols = ['PDBFILE', 'CHAIN', 'VARIANT', 'S_DDG[3D]', 'T_DDG[3D]', 'STABILITY[3D]']
  keep_cols = ['uniprot_id', 'CHAIN', 'VARIANT', 'T_DDG[3D]']
  dfs = []
  for file_name in tqdm(os.listdir(inp_dir)):
    file_path = os.path.join(inp_dir, file_name)
    temp_df = pd.read_csv(file_path, sep='\t', skiprows=1, header=None, names=file_cols)
    temp_df['uniprot_id'] = os.path.splitext(file_name)[0]
    dfs.append(temp_df[keep_cols])
  # pd.concat raises on an empty list, so fall back to an empty frame with the same columns
  df = pd.concat(dfs) if dfs else pd.DataFrame(columns=keep_cols)
  df = df.rename(columns={'CHAIN': 'chain', 'VARIANT': 'mut_code', 'T_DDG[3D]':'predicted_ddg'})
  df = df.reset_index(drop=True)
  return df

In [None]:
def generate_output_predictions_seq(inp_dir: str):
  """
  Concat together all of the sequence-based ddgun output files in a directory
  into a single DataFrame. Every file in inp_dir is read as a tab-separated
  table (first line skipped); the file name without extension is used as the uniprot id.
  :param str inp_dir: path to input directory
  :return: pd.DataFrame with the columns uniprot_id, mut_code and predicted_ddg
           (empty, but with those columns, if inp_dir contains no files)
  """
  file_cols = ['SEQFILE', 'VARIANT', 'S_DDG[SEQ]', 'T_DDG[SEQ]', 'STABILITY[SEQ]']
  keep_cols = ['uniprot_id', 'VARIANT', 'T_DDG[SEQ]']
  dfs = []
  for file_name in tqdm(os.listdir(inp_dir)):
    file_path = os.path.join(inp_dir, file_name)
    temp_df = pd.read_csv(file_path, sep='\t', skiprows=1, header=None, names=file_cols)
    temp_df['uniprot_id'] = os.path.splitext(file_name)[0]
    dfs.append(temp_df[keep_cols])
  # pd.concat raises on an empty list, so fall back to an empty frame with the same columns
  df = pd.concat(dfs) if dfs else pd.DataFrame(columns=keep_cols)
  df = df.rename(columns={'VARIANT': 'mut_code', 'T_DDG[SEQ]':'predicted_ddg'})
  df = df.reset_index(drop=True)
  return df

In [None]:
alphafold_dir = os.path.join(DDGUN_OUT_DIR, "alphafold")
esmfold_dir = os.path.join(DDGUN_OUT_DIR, "esmfold")
ddgun_alpha = generate_output_predictions(alphafold_dir) # alphafold
ddgun_esm = generate_output_predictions(esmfold_dir) # esmfold

In [None]:
seq_dir = os.path.join(DDGUN_OUT_DIR, "sequence")
ddgun_seq = generate_output_predictions_seq(seq_dir)

In [None]:
comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
prot_df = pd.read_csv(comb_save_pth)

In [None]:
# add in experimental ddg values
prot_df['mut_code'] = prot_df.apply(lambda x: "%s%d%s" % (x['wild_type'], x['position'], x['mutation']), axis=1)

ddgun_merged_alpha = pd.merge(ddgun_alpha, prot_df, how='left', on=['uniprot_id', 'mut_code'])
# alphafold and esmfold only use chain 'A', hence chain_x being used here
ddgun_alpha = ddgun_merged_alpha[['uniprot_id', 'chain_x', 'mut_code', 'ddG', 'predicted_ddg']]

ddgun_merged_esm = pd.merge(ddgun_esm, prot_df, how='left', on=['uniprot_id', 'mut_code'])
ddgun_esm = ddgun_merged_esm[['uniprot_id', 'chain_x', 'mut_code', 'ddG', 'predicted_ddg']]

In [None]:
ddgun_merged_seq = pd.merge(ddgun_seq, prot_df, how='left', on=['uniprot_id', 'mut_code'])
ddgun_seq = ddgun_merged_seq[['uniprot_id', 'mut_code', 'ddG', 'predicted_ddg']]

In [None]:
# make naming consistent with acdcnn
ddgun_alpha = ddgun_alpha.rename(columns={"ddG":"experimental_ddg"})
ddgun_esm = ddgun_esm.rename(columns={"ddG":"experimental_ddg"})
ddgun_seq = ddgun_seq.rename(columns={"ddG":"experimental_ddg"})

### Make data consistent
This is really important!! Want to compare datapoints that use the same set of mutations and wild-type proteins

In [None]:
# read in acdcnn data
ACDCNN_DIR = os.path.join(DATA_DIR, "acdcnn")
ALPHAFOLD_DIR = os.path.join(ACDCNN_DIR, "alphafold", "wildtypes")
ESMFOLD_DIR = os.path.join(ACDCNN_DIR, "esmfold", "wildtypes")
ESMFOLD_MUT_DIR = os.path.join(ACDCNN_DIR, "esmfold", "mutants") # using mutant structure + wildtype
alphafold_pred_csv = os.path.join(ALPHAFOLD_DIR, "ddg_predictions.csv")
esmfold_wild_pred_csv = os.path.join(ESMFOLD_DIR, "ddg_predictions.csv")
esmfold_mut_pred_csv = os.path.join(ESMFOLD_MUT_DIR, "ddg_predictions.csv")

# alphafold wildtypes
acdcnn_alpha = pd.read_csv(alphafold_pred_csv)
acdcnn_alpha = acdcnn_alpha.rename(columns={'id':'uniprot_id'})

# esmfold wildtypes
acdcnn_esm_wild = pd.read_csv(esmfold_wild_pred_csv)
acdcnn_esm_wild = acdcnn_esm_wild.rename(columns={'id':'uniprot_id'})

# esmfold wildtype + mut structure
acdcnn_esm_mut = pd.read_csv(esmfold_mut_pred_csv)
acdcnn_esm_mut = acdcnn_esm_mut.rename(columns={'id':'uniprot_id'})

# sequence-based
acdcnn_seq_wild = pd.read_csv(os.path.join(ACDCNN_DIR, "sequence", "ddg_predictions.csv"))
acdcnn_seq_wild =  acdcnn_seq_wild.rename(columns={'id':'uniprot_id'})

In [None]:
# read-in protein data
comb_save_pth = os.path.join(DATA_DIR, "small_combined_cleaned.csv")
prot_df = pd.read_csv(comb_save_pth)
# re-reading the csv replaces the frame that had 'mut_code' added in memory,
# and the consistency step below matches on ('uniprot_id', 'mut_code'), so rebuild it
# NOTE(review): assumes the csv itself has no mut_code column (none is written in this notebook)
prot_df['mut_code'] = prot_df.apply(lambda x: "%s%d%s" % (x['wild_type'], x['position'], x['mutation']), axis=1)

In [None]:
# read in plddt data
ALPHAFOLD_DIR = os.path.join(DATA_DIR, "alphafold")
ALPHA_WILDTYPE_DIR = os.path.join(ALPHAFOLD_DIR, "alphafold_preds", "wildtypes")
ESMFOLD_DIR = os.path.join(DATA_DIR, "esmfold")
ESM_WILDTYPE_DIR = os.path.join(ESMFOLD_DIR, "esm_preds", "wildtypes")
ESM_MUTANT_DIR = os.path.join(ESMFOLD_DIR, "esm_preds", "mutants")

alpha_plddts = pd.read_csv(os.path.join(ALPHAFOLD_DIR, "plddts.csv"))
esm_wild_plddts = pd.read_csv(os.path.join(ESM_WILDTYPE_DIR, "plddts.csv"))
esm_mut_plddts = pd.read_csv(os.path.join(ESM_MUTANT_DIR, "plddts.csv"))

In [None]:
# ensure both dfs contain the predictions from the same set of ids/mutations
def make_dfs_consistent(df1: pd.DataFrame, df2: pd.DataFrame, match_cols=['uniprot_id', 'mut_code']):
  """
  Keep only the rows of each frame whose key also occurs in the other frame.
  :param df1: pd.DataFrame containing the columns named in match_cols
  :param df2: pd.DataFrame containing the columns named in match_cols
  :param match_cols: column name or list of column names that together form the row key
  :return: (filtered df1, filtered df2); original row order and index labels are kept
  """
  keys1 = df1.set_index(match_cols).index
  keys2 = df2.set_index(match_cols).index
  # the boolean masks are positional, so they line up with the rows of the unindexed frames
  return df1.loc[keys1.isin(keys2)], df2.loc[keys2.isin(keys1)]

In [None]:
# Restrict every table to the mutations (and wild-type proteins) that are present in ALL of them.
# Filtering the frames pairwise, one after another, can leave stale rows behind: whenever a later
# comparison shrinks the reference frame, the frames filtered before it are not revisited.
# Computing one shared key set first and filtering every frame against it avoids that.

mut_cols = ['uniprot_id', 'mut_code']

def _mut_keys(frame: pd.DataFrame):
  # (uniprot_id, mut_code) key of every row
  return frame.set_index(mut_cols).index

# frames keyed by (uniprot_id, mut_code)
mut_level_dfs = [ddgun_alpha, ddgun_esm, ddgun_seq, acdcnn_alpha, acdcnn_esm_wild, acdcnn_esm_mut,
                 acdcnn_seq_wild, prot_df, esm_mut_plddts]

# mutations found in every mutation-level frame
common_muts = _mut_keys(mut_level_dfs[0])
for df in mut_level_dfs[1:]:
  common_muts = common_muts.intersection(_mut_keys(df))

# wild-type plddt frames are keyed by uniprot_id only: drop mutations whose protein lacks an entry
for plddt_df in (alpha_plddts, esm_wild_plddts):
  has_plddt = common_muts.get_level_values(0).isin(plddt_df['uniprot_id'])
  common_muts = common_muts[has_plddt]
common_ids = common_muts.get_level_values(0).unique()

def _keep_common_muts(frame: pd.DataFrame):
  return frame.loc[_mut_keys(frame).isin(common_muts)]

def _keep_common_ids(frame: pd.DataFrame):
  return frame.loc[frame['uniprot_id'].isin(common_ids)]

# structure-based and sequence-based predictions
ddgun_alpha_c = _keep_common_muts(ddgun_alpha)
ddgun_esm_c = _keep_common_muts(ddgun_esm)
ddgun_seq_c = _keep_common_muts(ddgun_seq)
acdcnn_alpha_c = _keep_common_muts(acdcnn_alpha)
acdcnn_esm_c = _keep_common_muts(acdcnn_esm_wild)
acdcnn_esm_mut_c = _keep_common_muts(acdcnn_esm_mut)
acdcnn_seq_wild_c = _keep_common_muts(acdcnn_seq_wild)

# experimental data
prot_df_c = _keep_common_muts(prot_df)

# plddt data
alpha_plddts_c = _keep_common_ids(alpha_plddts)
esm_wild_plddts_c = _keep_common_ids(esm_wild_plddts)
# the filtered mutant plddts keep the un-suffixed name: the cells below read them from it
esm_mut_plddts = _keep_common_muts(esm_mut_plddts)

In [None]:
# make sure all dfs have same num. uniprot ids and mutations
def _report_sizes(frames):
  # per frame: number of rows, number of distinct uniprot ids, then a separator line
  for frame in frames:
    print(len(frame))
    print(len(set(frame['uniprot_id'])))
    print("-"*60)

dfs = [ddgun_alpha_c, ddgun_esm_c, ddgun_seq_c, acdcnn_alpha_c, acdcnn_esm_c, acdcnn_esm_mut_c, acdcnn_seq_wild_c,
       prot_df_c]
_report_sizes(dfs)

plddt_dfs = [alpha_plddts_c, esm_wild_plddts_c, esm_mut_plddts]
_report_sizes(plddt_dfs) # row count should be 184 for wild-type plddt data

In [None]:
# save final dfs
SAVE_DIR = os.path.join(DATA_DIR, "cleaned_final_data")
os.makedirs(SAVE_DIR, exist_ok=True)

# each frame sits next to the file it is written to, so the pairing is visible at a glance
frames_and_files = [
  (ddgun_alpha_c, "ddgun_alphafold.csv"),
  (ddgun_esm_c, "ddgun_esmfold.csv"),
  (acdcnn_alpha_c, "acdcnn_alphafold.csv"),
  (acdcnn_esm_c, "acdcnn_esmfold.csv"),
  (acdcnn_esm_mut_c, "acdcnn_esmfold_muts.csv"),
  (ddgun_seq_c, "ddgun_seq.csv"),
  (acdcnn_seq_wild_c, "acdcnn_seq.csv"),
  (alpha_plddts_c, "alphafold_wild_plddts.csv"),
  (esm_wild_plddts_c, "esmfold_wild_plddts.csv"),
  (esm_mut_plddts, "esmfold_mutant_plddts.csv"),
  (prot_df_c, "final_prot_df.csv"),
]
dfs = [frame for frame, _ in frames_and_files]
save_paths = [os.path.join(SAVE_DIR, fname) for _, fname in frames_and_files]
for df, save_path in zip(dfs, save_paths):
  df.to_csv(save_path, index=False)

### Prediction evaluation

In [None]:
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error
import scipy.stats as stats

In [None]:
# load predictions in
SAVE_DIR = os.path.join(DATA_DIR, "cleaned_final_data")

pred_files = ["ddgun_alphafold.csv", "ddgun_esmfold.csv", "acdcnn_alphafold.csv", "acdcnn_esmfold.csv",
              "acdcnn_esmfold_muts.csv", "ddgun_seq.csv", "acdcnn_seq.csv"]
save_paths = [os.path.join(SAVE_DIR, fname) for fname in pred_files]

# unpacking order follows pred_files
(ddgun_alpha, ddgun_esm, acdcnn_alpha, acdcnn_esm,
 acdcnn_esm_mut, ddgun_seq, acdcnn_seq) = [pd.read_csv(pth) for pth in save_paths]

In [None]:
FIG_SAVE_DIR = os.path.join(DATA_DIR, "figures", "parity_plots")
os.makedirs(FIG_SAVE_DIR, exist_ok=True)

In [None]:
def create_parity_plot(df: pd.DataFrame, save_pth: str=None, title: str=None):
  """
  Scatter predicted against experimental ddG values and overlay the y = x parity line.
  :param df: pd.DataFrame, expected to have cols 'experimental_ddg' and 'predicted_ddg'
  :param save_pth: str, optional path the figure is written to (300 dpi)
  :param title: str, optional title for the plot
  """
  sns.set_style("whitegrid")
  ax = sns.scatterplot(data=df, x="experimental_ddg", y="predicted_ddg", alpha=0.5)
  # the parity line only needs its two end points; drawing it directly avoids
  # sns.lineplot's per-x aggregation of repeated experimental values
  lo, hi = df["experimental_ddg"].min(), df["experimental_ddg"].max()
  ax.plot([lo, hi], [lo, hi], color="black")
  ax.set_ylabel("Predicted \u0394\u0394G")
  ax.set_xlabel("Experimental \u0394\u0394G")
  if title is not None:
    ax.set_title(title)
  # lay out last so that the title (when given) is part of the computed layout
  plt.tight_layout()
  if save_pth is not None:
    plt.savefig(save_pth, dpi=300)
    print("Successfully saved figure to:", save_pth)

In [None]:
def get_boostrap_ci(y_true, y_pred, n:int, mode:str="rmse"):
  """
  Calculate 95% confidence interval using bootstrapping. Intended for use with RMSE and MAE
  :param y_true: array-like of experimental values
  :param y_pred: array-like of predicted values, paired with y_true by position
  :param n: int, number of bootstrap resamples to draw
  :param mode: str, 'rmse' or 'mae'
  :return: (lower, upper) bounds, the 2.5th and 97.5th percentiles of the resampled metric
  :raises NotImplementedError: if mode is neither 'rmse' nor 'mae'
  """
  if mode not in ("rmse", "mae"):
    raise NotImplementedError("mode expected to be 'rmse' or 'mae'")

  # work on plain arrays: indexing a pandas Series that has an integer index with an integer
  # array looks up labels, not positions, which breaks on any frame without a default index
  y_true = np.asarray(y_true, dtype=float)
  y_pred = np.asarray(y_pred, dtype=float)

  # private generator with the fixed seed: reproducible without reseeding numpy's global RNG
  rng = np.random.RandomState(42)
  values = np.empty(n)
  for i in range(n):
    indices = rng.choice(len(y_true), len(y_true), replace=True)
    residuals = y_true[indices] - y_pred[indices]
    if mode == "rmse":
      values[i] = np.sqrt(np.mean(residuals ** 2))
    else:
      values[i] = np.mean(np.abs(residuals))

  # compute the 95% confidence interval
  lower_b = np.percentile(values, 2.5)
  upper_b = np.percentile(values, 97.5)
  return lower_b, upper_b

In [None]:
def print_metrics(df: pd.DataFrame, calculate_ci: bool=False, n_bootstrap: int=None):
  """
  Print metrics from a given df
  :param df: pd.DataFrame, expected to have cols 'experimental_ddg' and 'predicted_ddg'
  :param calculate_ci: bool, whether to calculate confidence intervals
  :param n_bootstrap: int, number of bootstrap resamples used for the MAE/RMSE intervals;
                      defaults to len(df)
  """
  # positional arrays: the metrics and the bootstrap do not depend on the frame's index
  y_true = df["experimental_ddg"].to_numpy(dtype=float)
  y_pred = df["predicted_ddg"].to_numpy(dtype=float)

  corr_res = pearsonr(y_true, y_pred)
  abs_err = np.abs(y_true - y_pred)
  mae = np.mean(abs_err)
  # computed directly instead of via mean_squared_error(..., squared=False): that keyword
  # is deprecated in scikit-learn and absent from recent releases
  rmse = np.sqrt(np.mean(abs_err ** 2))
  print("Pearson correlation: %.4f" % corr_res[0])
  print("MAE: %.4f" % mae)
  print("RMSE: %.4f" % rmse)
  if calculate_ci:
    n_boot = len(df) if n_bootstrap is None else n_bootstrap
    corr_low, corr_high = corr_res.confidence_interval(confidence_level=0.95)
    mae_ci = get_boostrap_ci(y_true, y_pred, n_boot, mode="mae")
    rmse_ci = get_boostrap_ci(y_true, y_pred, n_boot, mode="rmse")
    print("Correlation interval: (%.4f, %.4f)" % (corr_low, corr_high))
    print("MAE interval: (%.4f, %.4f)" % (mae_ci[0], mae_ci[1]))
    print("RMSE interval: (%.4f, %.4f)" % (rmse_ci[0], rmse_ci[1]))

In [None]:
print_metrics(ddgun_seq)
create_parity_plot(ddgun_seq, os.path.join(FIG_SAVE_DIR, "ddgun_seq.png"))

In [None]:
print_metrics(ddgun_alpha)
create_parity_plot(ddgun_alpha, os.path.join(FIG_SAVE_DIR, "ddgun_alpha_wild.png"))

In [None]:
print_metrics(ddgun_esm)
create_parity_plot(ddgun_esm, os.path.join(FIG_SAVE_DIR, "ddgun_esm_wild.png"))

In [None]:
print_metrics(acdcnn_seq)
create_parity_plot(acdcnn_seq, os.path.join(FIG_SAVE_DIR, "acdcnn_seq.png"))

In [None]:
print_metrics(acdcnn_alpha)
create_parity_plot(acdcnn_alpha, os.path.join(FIG_SAVE_DIR, "acdcnn_alpha_wild.png"))

In [None]:
print_metrics(acdcnn_esm)
create_parity_plot(acdcnn_esm, os.path.join(FIG_SAVE_DIR, "acdcnn_esm_wild.png"))

In [None]:
print_metrics(acdcnn_esm_mut)
create_parity_plot(acdcnn_esm_mut, os.path.join(FIG_SAVE_DIR, "acdcnn_esm_mut.png"))

### Correlation between ddg prediction errors and plddt values

In [None]:
SAVE_DIR = os.path.join(DATA_DIR, "cleaned_final_data")

# structure-based predictions first, then the plddt tables they are compared against
csv_names = ["ddgun_alphafold.csv", "ddgun_esmfold.csv", "acdcnn_alphafold.csv", "acdcnn_esmfold.csv",
             "acdcnn_esmfold_muts.csv", "alphafold_wild_plddts.csv", "esmfold_wild_plddts.csv",
             "esmfold_mutant_plddts.csv"]
save_paths = [os.path.join(SAVE_DIR, fname) for fname in csv_names]

# unpacking order follows csv_names
(ddgun_alpha, ddgun_esm, acdcnn_alpha, acdcnn_esm, acdcnn_esm_mut,
 alpha_plddts, esm_wild_plddts, esm_mut_plddts) = [pd.read_csv(pth) for pth in save_paths]

In [None]:
from sklearn.metrics import mean_squared_error

def calc_avg_rmses(df: pd.DataFrame):
  """
  Calculate the RMSE (root-mean square error) of the predictions for each uniprot id.
  :param df: pd.DataFrame with cols 'uniprot_id', 'experimental_ddg' and 'predicted_ddg'
  :return: pd.DataFrame with cols 'uniprot_id' and 'error' (one row per uniprot id, sorted by id)
  """
  # vectorised: squared residual per row -> mean per protein -> square root
  # (no reliance on sklearn's `squared=False` keyword, which recent releases no longer accept)
  sq_err = (df["experimental_ddg"] - df["predicted_ddg"]) ** 2
  errors = np.sqrt(sq_err.groupby(df["uniprot_id"]).mean())
  error_df = pd.DataFrame({"uniprot_id": errors.index, "error": errors.values})
  return error_df

In [None]:
def calc_rmses_mut(df: pd.DataFrame):
  """
  Calculate the prediction error for each mutation (row).
  The RMSE over a single observation reduces to |experimental - predicted|, which is what
  is stored.
  :param df: pd.DataFrame with cols 'experimental_ddg' and 'predicted_ddg'
  :return: copy of df with an added float column 'error'; the input frame is not modified
  """
  error_df = df.copy()
  # one vectorised subtraction instead of a per-row sklearn call
  error_df["error"] = (df["experimental_ddg"] - df["predicted_ddg"]).abs().astype(float)
  return error_df

In [None]:
from scipy.stats import pearsonr

def get_corr(error_df: pd.DataFrame, plddt_df: pd.DataFrame, merge_on):
  """
  Pearson correlation between prediction errors and structure plddts.
  :param error_df: pd.DataFrame with an 'error' column plus the merge key(s)
  :param plddt_df: pd.DataFrame with an 'avg_plddt' column plus the merge key(s)
  :param merge_on: column name or list of column names the two frames are joined on
  :return: (correlation coefficient, p-value)
  """
  # default (inner) join: only keys present in both frames contribute
  joined = error_df.merge(plddt_df, on=merge_on)
  r_value, p_val = pearsonr(joined["error"], joined["avg_plddt"])
  return r_value, p_val

In [None]:
ddgun_alpha_corr, ddgun_alpha_pval = get_corr(calc_avg_rmses(ddgun_alpha), alpha_plddts, merge_on="uniprot_id")
ddgun_esm_corr, ddgun_esm_pval = get_corr(calc_avg_rmses(ddgun_esm), esm_wild_plddts, merge_on="uniprot_id")
acdcnn_alpha_corr, acdcnn_alpha_pval = get_corr(calc_avg_rmses(acdcnn_alpha), alpha_plddts, merge_on="uniprot_id")
acdcnn_esm_corr, acdcnn_esm_pval = get_corr(calc_avg_rmses(acdcnn_esm), esm_wild_plddts, merge_on="uniprot_id")
acdcnn_esm_mut_coor, acdcnn_esm_mut_pval = get_corr(calc_rmses_mut(acdcnn_esm_mut), esm_mut_plddts, merge_on=["uniprot_id", "mut_code"])

In [None]:
corrs = [ddgun_alpha_corr, ddgun_esm_corr, acdcnn_alpha_corr, acdcnn_esm_corr, acdcnn_esm_mut_coor]
pvals = [ddgun_alpha_pval, ddgun_esm_pval, acdcnn_alpha_pval, acdcnn_esm_pval, acdcnn_esm_mut_pval]
df = pd.DataFrame(zip(corrs, pvals), columns=["Correlation", "P-Value"])
df

# Visualization

### Data plotting

In [None]:
SAVE_DIR = os.path.join(DATA_DIR, "cleaned_final_data")
comb_save_pth = os.path.join(SAVE_DIR, "final_prot_df.csv")
prot_df = pd.read_csv(comb_save_pth)

In [None]:
print("Number wild-type proteins:", len(set(prot_df["uniprot_id"])))
print("Number mutations:", len(prot_df))

In [None]:
FIG_DIR = os.path.join(DATA_DIR, "figures")
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
# create histogram of ddG vals
print('min ddG:', min(prot_df['ddG']))
print('max ddG:', max(prot_df['ddG']))
fig = sns.histplot(data=prot_df, x='ddG')
fig.set_xlabel("\u0394\u0394G Value")
fig.set_ylabel("Frequency")
fig.set_title("Experiments by \u0394\u0394G")
plt.savefig(os.path.join(FIG_DIR, "ddg_experimental_hist.png"), dpi=300)
plt.show()

In [None]:
# show names of most frequent protein entries
prot_cnts = prot_df['protein_name'].value_counts().rename_axis('protein_name').reset_index(name='counts')
n = 10
top_n = prot_cnts[0:n]
fig = sns.barplot(data=top_n, x='counts', y='protein_name', orient='h', color='lightseagreen', alpha=0.5)
fig.set_xlabel("Frequency")
fig.set_ylabel("")
fig.set_title("Top %d Proteins by Entry" % n)
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, "top_n_proteins.png"), dpi=300)
plt.show()

In [None]:
# map of amino acid substitutions
from sklearn.metrics import confusion_matrix
from_acids = prot_df['wild_type']
to_acids = prot_df['mutation']
# data not actually a confusion matrix, but this works for visualization purposes
# label the axes with every residue seen on either side and hand the same labels to
# confusion_matrix, so the tick labels always line up with the matrix rows/columns
categories = sorted(set(from_acids) | set(to_acids))
cf_matrix = confusion_matrix(from_acids, to_acids, labels=categories)
# the figure size must be configured before the plot is created to apply to this figure
sz_factor = 2
sns.set(rc={'figure.figsize':(11.7 * sz_factor,8.27 * sz_factor)})
fig = sns.heatmap(cf_matrix, annot=False, xticklabels=categories, yticklabels=categories)
plt.ylabel('From acid')
plt.xlabel('To acid')
fig.set_title("Amino Acid Substitutions")
plt.savefig(os.path.join(FIG_DIR, "substitution_heatmap.png"), dpi=300)
plt.show()

In [None]:
# bar plot of amino acid substitutions
from_acids = prot_df['wild_type']
to_acids = prot_df['mutation']
from_acid_cnts = from_acids.value_counts().rename_axis('acid').reset_index(name='From amino acid')
to_acid_cnts = to_acids.value_counts().rename_axis('acid').reset_index(name='To amino acid')
# outer join: a residue that occurs on only one side is kept, with a count of 0 on the other
cnt_df = from_acid_cnts.merge(to_acid_cnts, on='acid', how='outer').fillna(0).sort_values(by='acid')
# long format: one row per (acid, source) pair for the grouped bar plot
cnt_df = pd.melt(cnt_df, id_vars=["acid"])
cnt_df = cnt_df.rename(columns={"variable": "source", "value":"entries"})
fig = sns.barplot(data=cnt_df, x='acid', y="entries", hue='source')
fig.set_title("Number of Entries by Amino Acid Substitution")
fig.set_xlabel("Amino Acid")
fig.set_ylabel("Frequency")
plt.savefig(os.path.join(FIG_DIR, "amino_acid_entries.png"), dpi=300)
plt.show()

In [None]:
# plots sequence lengths
seq_lens = prot_df['sequence'].map(len)
# NOTE(review): an earlier comment here said large proteins (>1000 residues) are ignored,
# but nothing in this cell filters by length; confirm whether that happens upstream.
print('min length:', min(seq_lens))
print('max length:', max(seq_lens))
fig = sns.histplot(data=seq_lens)
fig.set(title="Sequence Length of Entries", xlabel="Sequence Length", ylabel="Frequency")
plt.savefig(os.path.join(FIG_DIR, "seq_len_histogram.png"), dpi=300)
plt.show()

In [None]:
# thanks to ChatGPT for helping write this 
# Create a new DataFrame with a column of unique strings and their counts
from collections import Counter
import ast

# Convert the string entries to lists of strings
families = prot_df['interpro_families'].apply(ast.literal_eval)

counts = Counter([item for sublist in families for item in set(sublist)])

# Create a DataFrame with the unique strings and their counts
unique_values = pd.DataFrame(list(counts.items()), columns=['family', 'count'])

# Create the histogram plot using seaborn
sns.set_style('whitegrid')
plt.figure(figsize=(12, 6))
ax = sns.barplot(x='family', y='count', data=unique_values)
ax.set_title('Histogram of Unique Interpro Families')
ax.set_xlabel('Family')
ax.set_ylabel('Count')
ax.set(xticklabels=[])
plt.show()

### pLDDT Histograms

In [None]:
FIG_DIR = os.path.join(DATA_DIR, "figures")
os.makedirs(FIG_DIR, exist_ok=True)

In [None]:
SAVE_DIR = os.path.join(DATA_DIR, "cleaned_final_data")

save_paths = [os.path.join(SAVE_DIR, "alphafold_wild_plddts.csv"), os.path.join(SAVE_DIR, "esmfold_wild_plddts.csv"), 
              os.path.join(SAVE_DIR, "esmfold_mutant_plddts.csv"),]

alpha_plddts = pd.read_csv(save_paths[0])
esm_wild_plddts = pd.read_csv(save_paths[1])
esm_mut_plddts = pd.read_csv(save_paths[2])

In [None]:
print("AlphaFold Wild (Mean, Median): %.2f, %.2f" % (alpha_plddts["avg_plddt"].mean(), alpha_plddts["avg_plddt"].median()))
print("ESMFold Wild (Mean, Median):  %.2f, %.2f" % (esm_wild_plddts["avg_plddt"].mean(), esm_wild_plddts["avg_plddt"].median()))
print("ESMFold Mutants (Mean, Median):  %.2f, %.2f" % (esm_mut_plddts["avg_plddt"].mean(), esm_mut_plddts["avg_plddt"].median()))

In [None]:
min_plddt = np.min([min(esm_wild_plddts["avg_plddt"]), min(esm_mut_plddts["avg_plddt"]), min(alpha_plddts["avg_plddt"])])
min_plddt = np.floor(min_plddt / 5) * 5 # floor to nearest multiple of 5
bins = np.arange(min_plddt, 105, 5)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(16, 4))

# one panel per structure source, all sharing the same bin edges
panels = [
  (alpha_plddts, "AlphaFold WildTypes"),
  (esm_wild_plddts, "ESMFold WildTypes"),
  (esm_mut_plddts, "ESMFold Mutants"),
]
for ax, (plddt_df, panel_title) in zip(axs, panels):
  sns.histplot(ax=ax, data=plddt_df["avg_plddt"], bins=bins)
  ax.set_ylabel("")
  ax.set_xlabel("pLDDT Value")
  ax.set_title(panel_title)

# only the left-most panel carries the y-axis label
axs[0].set_ylabel("Frequency")

# show/save figure
fig.tight_layout()
plt.savefig(os.path.join(FIG_DIR, "plddt_histograms.png"), dpi=300)
plt.show()

### Structure visualization

In [None]:
!pip install py3DMol

In [None]:
STRUCT_FIG_DIR = os.path.join(DATA_DIR, "figures", "structures")
os.makedirs(STRUCT_FIG_DIR, exist_ok=True)

In [None]:
def load_pdb(pdb_path: str) -> str:
  """Return the full text of the file at pdb_path as a single string."""
  with open(pdb_path) as pdb_file:
    return pdb_file.read()

In [None]:
def visualize_pdb_with_confidence(pdb_path: str, save_path: str=None):
  """
  Visualize a given pdb file, using plddt scores to color each atom.
  The score of an atom is read from the B-factor field of its ATOM record.
  :param pdb_path: str, path of the pdb file to display
  :param save_path: str, accepted so existing calls keep working; this function does not
                    write anything to it
  """
  system = load_pdb(pdb_path)
  view = py3Dmol.view(width=400, height=400)
  view.addModelsAsFrames(system)

  for line in system.split("\n"):
    # PDB records are fixed-column: record name in cols 1-6, atom serial in cols 7-11,
    # B-factor in cols 61-66. Slicing stays correct when neighbouring fields run together.
    if line[:6].strip() != "ATOM":
      continue
    serial = int(line[6:11])
    plddt = float(line[60:66])
    if plddt >= 90:
      color = "#3434eb"
    elif plddt >= 70:
      color = "#34deeb"
    elif plddt >= 50:
      color = "#ebe534"
    else:
      color = "#eb7134"
    # select by the serial stored in the record rather than by a running counter, which
    # only matches when the serials start at 1 and have no gaps
    view.setStyle({'model': -1, 'serial': serial}, {"cartoon": {'color': color}})
  view.zoomTo()
  view.show()
  view.render()
  view.png()

In [None]:
import py3Dmol
# some py3DMol documentation available at:
# https://william-dawson.github.io/using-py3dmol.html 
ALPHA_DIR = os.path.join(DATA_DIR, "alphafold", "alphafold_preds", "wildtypes")
ESM_DIR = os.path.join(DATA_DIR, "esmfold", "esm_preds", "wildtypes")

In [None]:
uniprot_id = "P04156"

In [None]:
pdb_path = os.path.join(ALPHA_DIR, uniprot_id + ".pdb")
save_path = os.path.join(STRUCT_FIG_DIR, uniprot_id +".png")
visualize_pdb_with_confidence(pdb_path, save_path)

In [None]:
def find_pdb_path(dir: str, uni_id: str):
  """
  Search the immediate sub-directories of dir for '<uni_id>.pdb'.
  :return: path of the first match in os.listdir order, or None when there is none
  """
  target_name = uni_id + ".pdb"
  candidates = (os.path.join(dir, entry, target_name) for entry in os.listdir(dir))
  return next((pth for pth in candidates if os.path.exists(pth)), None)
pdb_path = find_pdb_path(ESM_DIR, uniprot_id)
save_path = os.path.join(STRUCT_FIG_DIR, uniprot_id +".png")
visualize_pdb_with_confidence(pdb_path, save_path)

In [None]:
# taken from:
# https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb


def plot_plddt_legend(save_pth: str):
  """
  Draw a stand-alone legend for the four pLDDT confidence bands and save it.
  :param save_pth: str, path the legend image is written to (300 dpi)
  :return: the matplotlib.pyplot module, as before
  """
  # (legend label, colour) per band, highest confidence first
  bands = [
    ('Very high (pLDDT > 90)', "#3434eb"),
    ('Confident (90 > pLDDT > 70)', "#34deeb"),
    ('Low (70 > pLDDT > 50)', "#ebe534"),
    ('Very low (pLDDT < 50)', "#eb7134"),
  ]

  plt.figure(figsize=(2, 2))
  # zero-sized bars: they only provide one legend handle per colour
  for _, band_color in bands:
    plt.bar(0, 0, color=band_color)
  plt.legend([label for label, _ in bands], frameon=False, loc='center', fontsize=20)
  plt.xticks([])
  plt.yticks([])
  # hide the frame so that only the legend remains visible
  axes = plt.gca()
  for side in ('right', 'top', 'left', 'bottom'):
    axes.spines[side].set_visible(False)
  plt.title('Model Confidence', fontsize=20, pad=20)
  plt.tight_layout()
  plt.savefig(save_pth, dpi=300)
  return plt

In [None]:
plot_plddt_legend(os.path.join(STRUCT_FIG_DIR, "legend.png"))

# Misc/Unused

### Gathering Experimental PDB Structures

In [None]:
SAVE_DIR = os.path.join(DATA_DIR, "cleaned_final_data")
comb_save_pth = os.path.join(SAVE_DIR, "final_prot_df.csv")
prot_df = pd.read_csv(comb_save_pth)

In [None]:
# adapted from:
# https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests

def download_pdb(url: str, uniprot_id: str, save_pth: str, timeout: float=30.0):
  """
  Stream the file at url to save_pth.
  :param url: str, download url; None means there is nothing to download
  :param uniprot_id: str, not used by the download itself
  :param save_pth: str, destination path
  :param timeout: float, seconds to wait on the server before requests raises a timeout error
  :return: True when the file was written, False when url is None or the server answered
           with an HTTP error status
  """
  if url is None:
    return False
  # download into a side file first so that an interrupted transfer never leaves a
  # truncated structure at save_pth
  tmp_pth = save_pth + ".part"
  try:
    with requests.get(url, stream=True, timeout=timeout) as r:
      r.raise_for_status()
      with open(tmp_pth, 'wb') as f:
        for chunk in r.iter_content(chunk_size=8192):
          f.write(chunk)
    os.replace(tmp_pth, save_pth)
    return True
  except requests.exceptions.HTTPError:
    # protein not found in database
    return False
  finally:
    # only present if the transfer failed part-way through
    if os.path.exists(tmp_pth):
      os.remove(tmp_pth)

In [None]:
# thanks to chatgpt for helping write this function
def get_target_url(uniprot_id: str):
  """
  Query the PDBe 'best_structures' mapping for a UniProt id and build the RCSB download
  url of the preferred entry.
  Preference: highest coverage first, then lowest resolution value. Entries without a
  resolution are ranked as 10.0; entries whose chain id is longer than one character
  are ranked as 100.0.
  :param uniprot_id: str, UniProt accession to look up
  :return: (url, chain_id) of the chosen entry; (None, "") when the request is not ok or
           the response lists no structures for this id
  """
  # Make a GET request to retrieve the mapping data for the UniProt ID
  response = requests.get(f"https://www.ebi.ac.uk/pdbe/api/mappings/best_structures/{uniprot_id}", timeout=30)
  if not response.ok:
    return None, ""
  # an absent key or an empty list both mean "nothing to download"
  entries = response.json().get(uniprot_id, [])
  if not entries:
    return None, ""

  def rank(entry):
    resolution = entry["resolution"]
    if resolution is None:
      resolution = 10.0 # NMR methods don't provide resolution
    if len(entry["chain_id"]) > 1:
      resolution = 100.0 # de-prioritize entries with a multi-character chain id
    return (entry["coverage"], -resolution)

  # max() returns the first of several equally ranked entries, matching a stable descending sort
  best = max(entries, key=rank)
  return f"https://files.rcsb.org/download/{best['pdb_id']}.pdb", best["chain_id"]

In [None]:
# setup directory to save predictions to
EXPERIMENTAL_DIR = os.path.join(DATA_DIR, "experimental_structures", "wildtypes")
os.makedirs(EXPERIMENTAL_DIR, exist_ok=True)

not_found = []    # uniprot ids for which download_pdb returned False
investigate = []  # (uid, url, chain) entries whose chain id is longer than one character

# save all wild-type experimental structures
uniprot_ids = set(prot_df['uniprot_id'])
for uid in tqdm(uniprot_ids):
  url_target, chain_info = get_target_url(uid)
  if len(chain_info) > 1:
    investigate.append((uid, url_target, chain_info))
  save_pth = os.path.join(EXPERIMENTAL_DIR, "%s.pdb" % uid)
  if not download_pdb(url_target, uid, save_pth):
    not_found.append(uid)

In [None]:
print(len(investigate))
print(investigate)

In [None]:
print(not_found) # no experimental structures for these proteins

### Change to 'A' chain
This facilitates comparison with ESMFold/AlphaFold. Even though many entries contain only one chain for the protein structure, additional information is often included that makes external tools difficult to use.

In [None]:
# Open the input PDB file
def to_a_chain(pdb_path: str, save_path: str):
  """
  Copy a pdb file, writing 'A' into the chain-identifier column (22nd character) of
  every line that starts with 'ATOM'. All other lines are copied unchanged.
  :param pdb_path: str, pdb file to read
  :param save_path: str, path the rewritten file is written to
  """
  def relabel(record: str) -> str:
    if not record.startswith('ATOM'):
      return record
    return record[:21] + 'A' + record[22:]

  # read everything before writing, so the two paths may refer to the same file
  with open(pdb_path, 'r') as src:
    relabelled = [relabel(record) for record in src]

  with open(save_path, 'w') as dst:
    dst.writelines(relabelled)

In [None]:
EXPERIMENTAL_DIR = os.path.join(DATA_DIR, "experimental_structures", "wildtypes")
EXPERIMENTAL_A_DIR = os.path.join(DATA_DIR, "experimental_structures", "wildtypes_A_chain")
os.makedirs(EXPERIMENTAL_A_DIR, exist_ok=True)

for pdb_file in tqdm(os.listdir(EXPERIMENTAL_DIR)):
  # skip directory entries that are not pdb files (e.g. hidden or checkpoint entries);
  # looking up ".pdb" with str.index would raise ValueError on such names
  uid, ext = os.path.splitext(pdb_file)
  if ext != ".pdb":
    continue
  pdb_path = os.path.join(EXPERIMENTAL_DIR, pdb_file)
  save_path = os.path.join(EXPERIMENTAL_A_DIR, f"{uid}.pdb")
  to_a_chain(pdb_path, save_path)

In [None]:
# print chains from downloaded pdbs
from Bio import PDB

# NOTE: rebinds EXPERIMENTAL_DIR to the chain-relabelled copies (wildtypes_A_chain)
EXPERIMENTAL_DIR = os.path.join(DATA_DIR, "experimental_structures", "wildtypes_A_chain")

# chain ids collected from the parsed files (duplicates included)
ls = []
# NOTE(review): the [1:2] slice limits the loop to a single directory entry (the second one
# returned by os.listdir). This looks like a debugging leftover; confirm whether all files
# were meant to be inspected.
for pdb_file in os.listdir(EXPERIMENTAL_DIR)[1:2]:
  # Create a PDB parser object
  parser = PDB.PDBParser()
  pdb_path = os.path.join(EXPERIMENTAL_DIR, pdb_file)
  print(pdb_path)
  # Parse the PDB file
  structure = parser.get_structure('my_structure', pdb_path)

  # Get the chain(s) in the structure
  chains = [chain.get_id() for chain in structure.get_chains()]
  for ch in chains:
    ls.append(ch)

In [None]:
print(set(ls))

### Calculate RMSD between two pdb files
The root-mean-square deviation (RMSD) is a measure of the average distance between the atoms (usually the backbone atoms) of superimposed proteins.
See: https://en.wikipedia.org/wiki/Root-mean-square_deviation_of_atomic_positions

In [None]:
!pip install biopython

In [None]:
# adapted from: https://github.com/sarisabban/RMSD/blob/main/RMSD.py 
import Bio.PDB
# Bio.PDB citation: 
# Hamelryck, T., Manderick, B. (2003) PDB parser and structure class implemented in Python. Bioinformatics 19: 2308–2310


def get_rmsd(pdb_path1: str, pdb_path2: str):
  '''
  Calculate the RMSD between two protein structures using Biopython.
  Atoms are paired purely by their order in each file, and the longer atom list is
  truncated to the length of the shorter one. The value is therefore only meaningful
  when both files list corresponding atoms in the same order.
  :param pdb_path1: str, pdb file whose atoms are used as the fixed reference
  :param pdb_path2: str, pdb file whose atoms are superimposed onto the reference
  :return: RMSD reported by Bio.PDB.Superimposer for the paired atoms
  '''
  parser = Bio.PDB.PDBParser(QUIET=False)
  fixed = list(parser.get_structure('Structure 1', pdb_path1).get_atoms())
  moving = list(parser.get_structure('Structure 2', pdb_path2).get_atoms())
  # Superimposer needs two equally long atom lists
  smallest = min(len(fixed), len(moving))
  sup = Bio.PDB.Superimposer()
  sup.set_atoms(fixed[:smallest], moving[:smallest])
  return sup.rms

In [None]:
uniprot_id = 'P37957'
pdb_esm_dir = os.path.join(DATA_DIR, "pdbs/wild_type/esm")
pdb_esm_path = os.path.join(pdb_esm_dir, "%s.pdb" % uniprot_id)
pdb_alpha_dir = os.path.join(DATA_DIR, "pdbs/wild_type/alphafold")
pdb_alpha_path = os.path.join(pdb_alpha_dir, '%s.pdb' % uniprot_id)

rmsd = get_rmsd(pdb_esm_path, pdb_alpha_path)
print(rmsd)