In [None]:
%pip install biopython

In [None]:
# Mount Google Drive: every input and output path used in this notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')
import os
import Bio
from Bio import SeqIO
from Bio.Seq import Seq

import json
import os
import shutil
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import uuid
from datetime import datetime
import re
import torch
from time import time

meta_data_filepath = "/content/drive/MyDrive/Generative_Models/utilities/metadata_mpnn.csv"

# This metadata lists the FASTA files (column `output_file_path`) holding ProteinMPNN
# sequence redesigns for our backbone models. A later cell reads `mpnn_metadata`
# unconditionally, so fail here with a clear message instead of a NameError there.
if not os.path.exists(meta_data_filepath):
  raise FileNotFoundError(f"MPNN metadata not found at {meta_data_filepath}")
mpnn_metadata = pd.read_csv(meta_data_filepath)
print("Existing MPNN metadata read in.")

root_dir = "/content/drive/MyDrive/Generative_Models/unconditional_generation/"

# Recursively collect every file whose name contains "generation_metadata",
# in os.walk order, anywhere under the unconditional-generation tree.
paths = [
  os.path.join(walk_dir, fname)
  for walk_dir, _subdirs, fnames in os.walk(root_dir)
  for fname in fnames
  if "generation_metadata" in fname
]

import pandas as pd
# Concatenate every generation_metadata CSV, tagging each row with the directory it came from.
all_dfs = []
for file_path in paths:
  df = pd.read_csv(file_path)
  df["dir_path"] = os.path.dirname(file_path)
  all_dfs.append(df)
if not all_dfs:
  # pd.concat([]) would only say "No objects to concatenate".
  raise FileNotFoundError(f"No generation_metadata files found under {root_dir}")
gen_meta = pd.concat(all_dfs, ignore_index=True)

# Keep rows with an entity_id and drop backbone-generation tasks. na=False treats a
# missing task as "not backbone" instead of producing NaN, which breaks the `~` negation.
# .copy() makes gen_meta an independent frame so the column assignment below is not
# a write into a view (SettingWithCopyWarning).
gen_meta = gen_meta[
  gen_meta['entity_id'].notnull() & ~gen_meta['task'].str.contains('backbone', na=False)
].copy()
# Requested length = first integer found in `conditions` (raw-string regex; raises if none is found).
gen_meta['length'] = gen_meta['conditions'].str.extract(r'(\d+)', expand=False).astype(int)

Mounted at /content/drive
Existing generation metadata read in.


In [None]:
root_dir = "/content/drive/MyDrive/Generative_Models/unconditional_generation/"

# Every JSON file whose name contains "length_dist", found recursively under root_dir.
paths = [
  os.path.join(walk_dir, fname)
  for walk_dir, _subdirs, fnames in os.walk(root_dir)
  for fname in fnames
  if "length_dist" in fname
]

# Map bare file name -> parsed JSON contents (a later file with the same name overwrites an earlier one).
all_length_dists = {}
for file_path in paths:
  with open(file_path, "r") as fh:
    all_length_dists[os.path.basename(file_path)] = json.load(fh)

In [None]:
from Bio.PDB import PDBParser
from Bio.PDB import MMCIFParser
from Bio.PDB.Polypeptide import PPBuilder

import warnings
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from Bio.SeqUtils import seq1
from Bio.Seq import Seq

# Silence Biopython's PDBConstructionWarning, which the PDB/mmCIF parsers issue while
# building structures, so the per-file diagnostics printed below stay readable.
warnings.filterwarnings("ignore", category=PDBConstructionWarning)

# The 20 canonical amino-acid one-letter codes (upper case).
_CANONICAL_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

def is_valid_protein(sequence, alphabet=_CANONICAL_AMINO_ACIDS):
  """Return True if `sequence` is a non-empty string made only of allowed residue letters.

  Constructing a Bio.Seq.Seq is not a validity check: Seq accepts arbitrary text, so the
  previous try/except never rejected anything. Instead every character is checked
  against `alphabet` (case-sensitive, so lowercase text such as str(NaN) == "nan" fails).

  Args:
    sequence: candidate sequence; converted with str() first (str, Seq, NaN, ...).
    alphabet: iterable of allowed one-letter codes; defaults to the 20 canonical residues.

  Returns:
    bool: True when the sequence is non-empty and every character is in `alphabet`.
  """
  text = str(sequence)
  if not text:
    return False
  return set(text) <= frozenset(alphabet)

# Sequence-generation (LLM) models to subsample, each paired with the UniRef50 length
# distribution used as its sampling target. Each JSON is indexed below as a list of
# [length, count] pairs — assumed format, TODO confirm against the files.
length_dists = [
('rita_xl','/content/drive/MyDrive/Generative_Models/unconditional_generation/rita_unconditional/uniref50_length_dist_rita.json'),
('evodiff_OA_DM_640M','/content/drive/MyDrive/Generative_Models/unconditional_generation/evodiff_unconditional/uniref50_length_dist_evodiff.json'),
('protgpt2','/content/drive/MyDrive/Generative_Models/unconditional_generation/protgpt2_unconditional/uniref50_length_dist_protgpt2.json'),
('ESM_Design','/content/drive/MyDrive/Generative_Models/unconditional_generation/esmdesign_unconditional/uniref50_length_dist_esmdesign.json'),
('ProGen2','/content/drive/MyDrive/Generative_Models/unconditional_generation/progen2_unconditional/uniref50_length_dist_progen2.json')]

# Load each distribution once instead of re-reading every file for every length.
length_dist_by_model = {}
for model, dist in length_dists:
  with open(dist, "r") as f:
    length_dist_by_model[model] = json.load(f)

# Directories holding all-atom generator outputs; chroma files are parsed as mmCIF, the rest as PDB.
roots = ["/content/drive/MyDrive/Generative_Models/unconditional_generation/chroma_unconditional",
"/content/drive/MyDrive/Generative_Models/unconditional_generation/proteingenerator_unconditional",
"/content/drive/MyDrive/Generative_Models/unconditional_generation/protpardelle_unconditional"]

fold_input_dir = '/content/drive/MyDrive/Generative_Models/utilities/fold_inputs'
os.makedirs(fold_input_dir, exist_ok=True)
ppb = PPBuilder()

for length in range(151, 201):
  # --- 1. ProteinMPNN redesigns of the backbone-model outputs ---------------------------
  mpnn_meta = mpnn_metadata.loc[mpnn_metadata.length == length].drop_duplicates(subset='output_file_path', keep='last')
  # All foldingdiff outputs > 128 in length are actually just 128, so ignore them.
  mpnn_meta = mpnn_meta.loc[~((mpnn_meta['length'] > 128) & (mpnn_meta['gen_model'] == 'foldingdiff'))]
  records = []
  for _, row in mpnn_meta.iterrows():
    for i, design in enumerate(SeqIO.parse(row['output_file_path'], "fasta")):
      # The first record of each MPNN FASTA is skipped; presumably it is the input
      # sequence written before the designs — TODO confirm.
      if i == 0:
        continue
      design.description = ""
      design.name = ""
      records.append(design)

  # --- 2. Sequence-generation (LLM) outputs ---------------------------------------------
  # .copy(): the cleaning below must edit an independent frame, not a view of gen_meta.
  sequence_outputs = gen_meta.loc[(gen_meta.task == "sequence_generation") & (gen_meta.length == length), :].copy()
  # LLM outputs sometimes have artifacts: strip whitespace and digit tokens.
  sequence_outputs['generated_sequence'] = sequence_outputs['generated_sequence'].apply(lambda x: re.sub(r'[\s\d]+', '', str(x)))
  # Recalculate the length from the cleaned sequence.
  sequence_outputs['length'] = sequence_outputs['generated_sequence'].apply(len)
  # Keep only rows whose cleaned sequence is a valid protein sequence (a row filter).
  valid_mask = sequence_outputs['generated_sequence'].apply(is_valid_protein).astype(bool)
  sequence_outputs = sequence_outputs.loc[valid_mask].drop(columns='conditions')

  # Subsample the LLM outputs to each model's length distribution (we over-generated,
  # particularly at shorter lengths, since length control is not always available).
  sampled_parts = []
  for model, _ in length_dists:
    print(model)
    uniprot_length_dist = [i for i in length_dist_by_model[model] if i[0] == length]
    print(uniprot_length_dist)
    if not uniprot_length_dist:
      print(f"Warning: no length-distribution entry for {model} at length {length}. Skipping.")
      continue
    n_target = uniprot_length_dist[0][1]
    # Deterministic shuffle, then keep the first n_target rows (or all rows if fewer exist).
    sampled_df = (sequence_outputs.loc[sequence_outputs.model == model]
                  .sample(frac=1, random_state=42)
                  .reset_index(drop=True)
                  .head(n_target))
    if len(sampled_df) < n_target:
      print(f"Warning: Only {len(sampled_df)} {model} rows available for length {length}. Sampling all available rows.")
    sampled_parts.append(sampled_df)
  seq_subsamp = pd.concat(sampled_parts, ignore_index=True) if sampled_parts else pd.DataFrame()

  for _, row in seq_subsamp.iterrows():
    record_id = row['model'] + "_" + "len" + str(row["length"]) + "_" + row['entity_id']
    records.append(SeqIO.SeqRecord(
        seq=Seq(row['generated_sequence']),
        id=record_id,
        description="",
        name="",
    ))

  # --- 3. All-atom outputs: recover their sequences so they can be refolded ----------------
  aa_outputs = gen_meta.loc[(gen_meta.task == "all_atom_pdb_generation") & (gen_meta.length == length), :]
  for root in roots:
    for _, row in aa_outputs.iterrows():
      structure_path = root + "/" + row["output_file_name"]
      if not os.path.exists(structure_path):
        continue
      if root.split("/")[-1] == "chroma_unconditional":
        parser = MMCIFParser()
        structure = parser.get_structure('cif', structure_path)
      else:
        parser = PDBParser()
        structure = parser.get_structure('pdb', structure_path)
      base_id = row['model'] + "_" + "len" + str(row["length"]) + "_" + row['entity_id']

      sequence = "".join(str(pp.get_sequence()) for pp in ppb.build_peptides(structure))
      if row["length"] == len(sequence):
        records.append(SeqIO.SeqRecord(seq=Seq(sequence), id=base_id, description="", name=""))
        continue

      # PPBuilder's peptides do not add up to the expected length: report the file, then
      # fall back to translating every residue of every model/chain with seq1.
      print(structure_path)
      print(len(sequence))
      sequence = "".join(seq1(residue.get_resname())
                         for struct_model in structure
                         for chain in struct_model
                         for residue in chain)
      if row["length"] == len(sequence):
        records.append(SeqIO.SeqRecord(seq=Seq(sequence), id=base_id + "_refold", description="", name=""))
      else:
        print("is still busted ^^^")

  with open(f'{fold_input_dir}/all_len{length}.fa', 'w') as f:
    SeqIO.write(records, f, 'fasta')