In [None]:
# Mount Google Drive at /content/drive; later cells read the input PDB from,
# and write generated structures and the metadata CSV to, paths under
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# @title Setup

# @markdown [Get your API key here](https://chroma-weights.generatebiomedicines.com) and enter it below before running.

from google.colab import output

# Turn on Colab's custom widget manager.
# NOTE(review): presumably required for the ipywidgets buttons and the protein
# viewer to render in Colab — confirm.
output.enable_custom_widget_manager()

import os

# Set before torch is imported below. PyTorch's reproducibility notes list this
# cuBLAS workspace setting as needed for deterministic CUDA matrix multiplies.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import contextlib

# Chroma API key. Left empty in the notebook: the user fills it in before
# running (see the markdown note above); it is passed to register_key() later
# in this cell.
api_key = ""

# Install the Chroma package. Both stdout and stderr are discarded, so an
# installation failure only surfaces at the chroma imports further down.
!pip install generate-chroma > /dev/null 2>&1

import torch

# Ask for deterministic algorithms; warn_only=True means operations without a
# deterministic implementation emit a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

import warnings
from tqdm import tqdm, TqdmExperimentalWarning

# Silence tqdm's experimental-feature warnings.
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
from functools import partialmethod

# Patch tqdm globally so every progress bar is created with leave=False.
tqdm.__init__ = partialmethod(tqdm.__init__, leave=False)

# Used by create_button() below: `widgets` builds the button, `files` triggers
# the browser download.
from google.colab import files
import ipywidgets as widgets

def create_button(filename, description=""):
    """Show a button that downloads ``filename`` through Colab when clicked.

    Args:
        filename: Path of the file handed to ``files.download`` on click.
        description: Label shown on the button.
    """

    def _start_download(_clicked_button):
        # The clicked widget is passed in by ipywidgets but is not needed here.
        files.download(filename)

    download_button = widgets.Button(description=description)
    display(download_button)
    download_button.on_click(_start_download)


def render(protein, trajectories, output="protein.cif"):
    """Display a sampled protein, save it and its trajectory, and offer downloads.

    Args:
        protein: Object to display and print; must provide ``to_CIF(path)``.
        trajectories: Mapping whose ``"trajectory"`` entry provides
            ``to_CIF(path)``.
        output: Path the sample is written to. The trajectory is written next
            to it with ``_trajectory`` inserted before the file extension
            (``protein.cif`` -> ``protein_trajectory.cif``).
    """
    display(protein)
    print(protein)
    protein.to_CIF(output)

    # Derive the trajectory path from the extension only. The previous
    # ``output.replace(".cif", "_trajectory.cif")`` rewrote every ".cif" in the
    # path (directories included) and returned the path unchanged when it had
    # no ".cif", in which case the trajectory overwrote the sample file.
    stem, extension = os.path.splitext(output)
    traj_output = f"{stem}_trajectory{extension}"
    trajectories["trajectory"].to_CIF(traj_output)

    create_button(output, description="Download sample")
    create_button(traj_output, description="Download trajectory")


import locale

# Force the reported preferred encoding to "UTF-8" for the rest of the session.
# NOTE(review): presumably a workaround for Colab `!` shell commands failing
# when the locale is not UTF-8 — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"

from chroma import Chroma, Protein, conditioners
from chroma.models import graph_classifier, procap
from chroma.utility.api import register_key
from chroma.utility.chroma import letter_to_point_cloud, plane_split_protein

# Register the API key set earlier in this cell before building the model.
# NOTE(review): api_key is "" unless the user edited it — presumably the key
# authorizes the weight download; verify how an empty key fails.
register_key(api_key)
# Build the Chroma model with stdout suppressed, hiding whatever the
# constructor prints.
with contextlib.redirect_stdout(None):
    chroma = Chroma()

# Device string used by later cells (Protein.from_PDB(..., device=device) and
# conditioner .to(device)). Hard-coded, so a CUDA runtime is assumed.
device = "cuda"



In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import uuid
from datetime import datetime
import re
import torch
from time import time

meta_data_filepath = "/content/drive/MyDrive/Generative_Models/conditional_generation/chroma_tev/generation_metadata_chroma_tev.csv"

# Start from an empty table on a first run; otherwise resume from the CSV left
# by earlier runs. Nothing is written to disk in this step.
if not os.path.exists(meta_data_filepath):
  all_metadata_df, status_message = pd.DataFrame(), "Created generation metadata dataframe"
else:
  all_metadata_df, status_message = pd.read_csv(meta_data_filepath), "Existing generation metadata read in."
print(status_message)


Created generation metadata dataframe


In [None]:
from chroma.models.graph_design import GraphDesign
import shutil  # imported here because only this cell needs it


def is_iterable(obj):
  """Return True if `obj` is a non-string object that can be iterated or indexed."""
  # Explicit parentheses: `a or b and not c` parses as `a or (b and not c)`, so
  # the previous lambda still reported strings as iterable via `__iter__`.
  return (hasattr(obj, '__iter__') or hasattr(obj, '__getitem__')) and not isinstance(obj, str)


batch_size = 1    # samples requested per chroma.sample() call
n_batches = 100   # number of chroma.sample() calls made by this cell
output_dir = "/content/drive/MyDrive/Generative_Models/conditional_generation/chroma_tev"
file_prefix = "chroma_tev_comp_"  # the monomer variant used "chroma_tev_mono_"

# Selection handed to both conditioners and recorded in the metadata.
# Monomer variant:
#   "resid 28-33 or resid 47-51 or resid 140-152 or resid 168-179 or resid 212-221"
selection_string = "((chain A and (resid 28-33 or resid 47-51 or resid 140-152 or resid 168-179 or resid 212-221)) or (chain B))"

# Fields shared by every row written by this cell; the per-batch and
# per-sample fields are filled in inside the loop below.
# NOTE(review): batch_id and Timestamp are generated once here, so all
# n_batches sampling calls share them — confirm that is intended.
meta_data = {}
meta_data['batch_id'] = str(uuid.uuid4())
meta_data['batch_size'] = batch_size
meta_data['Timestamp'] = str(datetime.now())
meta_data['model'] = 'chroma'
meta_data['task'] = 'all_atom_pdb_generation'
# Built from selection_string so the recorded condition cannot drift from the
# selection actually used.
meta_data['conditions'] = f'tev (complex) scaffolding [selection_string = "{selection_string}"]'
meta_data['gpu'] = 'T4 GPU'

# Template structure (monomer variant: .../conditional_generation/tev_monomer.pdb).
protein = Protein.from_PDB("/content/drive/MyDrive/Generative_Models/conditional_generation/tev_complex.pdb", device=device)
X, C, _ = protein.to_XCS()

# NOTE(review): subseq_conditioner is built but is not part of the
# ComposedConditioner below; only the substructure conditioner is applied.
subseq_conditioner = conditioners.SubsequenceConditioner(design_model = GraphDesign(), protein=protein, selection=selection_string, weight = 1.0).to(device)
substruct_conditioner = conditioners.SubstructureConditioner(protein, backbone_model=chroma.backbone_network, selection=selection_string,weight = 1.0).to(device)
composed_conditioner = conditioners.ComposedConditioner([substruct_conditioner])

# Create the destination up front so a missing folder fails before any GPU
# time is spent, rather than after the first sample has been generated.
os.makedirs(output_dir, exist_ok=True)

for batch_index in range(n_batches):
  print(batch_index)
  start_time = time()

  infilled_protein, trajectories = chroma.sample(
      protein_init=protein,
      conditioner=composed_conditioner,
      samples=batch_size,
      full_output=True,
  )

  total_job_time = time() - start_time
  meta_data['wall_time_batch'] = str(total_job_time) + " Seconds"
  meta_data['wall_time_task'] = str(total_job_time/batch_size) + " Seconds (inferred)"

  # Normalise the result to a list so one code path handles both a single
  # sample and a batch of samples.
  if isinstance(infilled_protein, Protein) or not is_iterable(infilled_protein):
    sampled_proteins = [infilled_protein]
  else:
    sampled_proteins = list(infilled_protein)

  new_records = []
  # The loop variable must not be called `protein`: rebinding that name would
  # change the `protein_init` passed to chroma.sample() on the next batch.
  for sampled_protein in sampled_proteins:
    meta_data['entity_id'] = str(uuid.uuid4())
    new_name = file_prefix + meta_data['entity_id'] + ".cif"
    # Write locally first, then move the finished file onto Drive.
    sampled_protein.to_CIF(new_name)
    shutil.move(new_name, os.path.join(output_dir, new_name))
    meta_data['output_file_name'] = new_name
    new_records.append(dict(meta_data))  # snapshot; meta_data is reused

  # Append this batch's rows and save immediately so an interrupted run keeps
  # the metadata for everything generated so far.
  all_metadata_df = pd.concat([all_metadata_df, pd.DataFrame(new_records)], ignore_index=True)
  all_metadata_df.to_csv(meta_data_filepath, index=False)

  print("Metadata saved. Cleaning up....")
  # Ask PyTorch to release its cached GPU memory before the next batch.
  torch.cuda.empty_cache()