In [None]:
from google.colab import drive
drive.mount('/content/drive')

%env PYTHONPATH=
!wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
!./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local

!git clone https://github.com/facebookresearch/esm

"""
The base ESM environment is really big and takes ages to install so I've saved a copy of it
on my drive to just extract to build in the ESM_Embeddings notebook
Unzip the contents of the archive directly into /usr/local/envs/
"""
#!tar -xzvf /content/drive/MyDrive/Generative_Models/envs/esm_embeddings.tar.gz -C /usr/local/envs/ >/dev/null 2>&1

In [None]:
%%bash
source activate esmfold
pip install fair-esm
pip install nltk
pip install py3Dmol
pip install hydra

In [None]:
%cd ./esm/examples/lm-design

/content/esm/examples/lm-design


In [None]:
#We also need to change the number of iterations in config so generation doesn't time out on colab and is comparable to other LLMs
import yaml

config_path = "./conf/config.yaml"

# Read the whole config first and close the handle before the file is reopened for writing.
with open(config_path, 'r') as file:
  config = yaml.safe_load(file)

config['tasks']['free_generation']['num_iter'] = 250 #default is 170000

# Rewrite the config in place with the reduced iteration count.
with open(config_path, 'w') as file:
  yaml.dump(config, file, default_flow_style=False)

In [None]:
import os
import shutil
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import uuid
from datetime import datetime
import re
import random
import torch
from time import time

meta_data_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/esmdesign_unconditional/generation_metadata_esmdesign_higher_iter_test.csv"

# Columns of one metadata row, in the order the generation loop fills them in.
metadata_columns = [
  "batch_id",
  "batch_size",
  "Timestamp",
  "model",
  "task",
  "conditions",
  "gpu",
  "output_file_name",
  "wall_time_batch",
  "wall_time_task",
  "entity_id",
  "generated_sequence",
]

# Resume from the metadata already on Drive when present; otherwise start an empty table.
if os.path.exists(meta_data_filepath):
  all_metadata_df = pd.read_csv(meta_data_filepath)
  print("Existing generation metadata read in.")
else:
  # Create the empty table WITH its columns: the generation loop filters on
  # `all_metadata_df.conditions`, which raises AttributeError on a column-less frame.
  # Nothing is written to Drive here; the CSV is saved by the generation loop.
  all_metadata_df = pd.DataFrame(columns=metadata_columns)
  print("Created generation metadata dataframe")

len_dist_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/esmdesign_unconditional/uniref50_length_dist_esmdesign.json"

# `uniprot_length_dist` is a list of (sequence length, number of sequences to generate) pairs.
# It is loaded from Drive when cached; otherwise it is sampled once and cached.
if os.path.exists(len_dist_filepath):
  with open(len_dist_filepath, "r") as f:
    uniprot_length_dist =  json.load(f)
  print("Loaded length distribution from drive")
else:
  #https://www.uniprot.org/uniprotkb/statistics#sequence-size
  bins = np.array([13,51,101,151,201,251,301,351,401,451,501,551,601,651,701,751,801,851,901,951,1001,1101,1201,1301,1401,1501,1601,1701,1801,1901,2001,2101,2201,2301,2401,2501,34350])
  swissprot_reviewed = np.array([0,9968,43534,59796,59574,58452,52413,52846,45901,37706,30572,22287,15830,13156,9403,7870,5700,4889,5301,4109,3007,4124,2897,2207,2070,1675,834,642,587,503,395,272,386,340,234,195,1462])
  # Kept for reference only: the sampling below uses the reviewed (Swiss-Prot) counts.
  TrEMBL_unreviewed = np.array([0,2668805,19825275,24705701,23838128,23462438,23225451,21389271,16814580,14287105,11501843,8283150,6266068,4715059,3755005,3186452,2687314,2166878,1843669,1457871,1153537,1975953,1398765,961048,664766,517536,390552,300984,236895,210921,180246,138808,122833,102865,82441,71548,527646])

  # Empirical CDF at the bin edges, then linearly interpolated to every integer length.
  ecdf_at_bins = np.cumsum(swissprot_reviewed) / np.sum(swissprot_reviewed)
  #shortest protein in uniprot is 14 res, longest is 34350 res.
  x = np.arange(14, 34350+1)
  ecdf = np.interp(x, bins, ecdf_at_bins)

  # Sample from the empirical CDF (inverse-transform sampling).
  # A fixed seed makes the cached distribution reproducible if it ever has to be rebuilt.
  num_samples = 11000
  num_kept = 10000
  max_length = 1000
  rng = np.random.default_rng(0)
  random_values = rng.random(num_samples)
  sampled_lengths = np.round(np.interp(random_values, ecdf, x)).astype(int)
  #ten thousand sequences up to 1000 res in length
  sampled_lengths = sampled_lengths[sampled_lengths <= max_length]
  if len(sampled_lengths) < num_kept:
    # Fail loudly instead of silently caching a distribution with too few sequences.
    raise RuntimeError(
      f"Only {len(sampled_lengths)} of {num_samples} sampled lengths are <= {max_length}; "
      f"need {num_kept}. Increase num_samples."
    )
  sampled_lengths = sampled_lengths[0:num_kept]

  # Histogram with one bin per integer length (edges 14..1001); the counts are the batch sizes.
  hist_values, bin_edges, patches = plt.hist(sampled_lengths, bins=x[0:1001-13], alpha=0.7, label='Sampled lengths')
  plt.xlabel('Sequence length (residues)')
  plt.ylabel('Frequency')
  plt.title('Sampled sequence-length distribution')
  plt.legend()
  plt.show()

  # zip() stops at the shorter input, so the final right-hand bin edge is dropped and each
  # remaining edge (a sequence length) is paired with the count of its bin.
  uniprot_length_dist = list(zip([int(edge) for edge in bin_edges],[int(value) for value in hist_values]))
  with open(len_dist_filepath, "w") as f:
      json.dump(uniprot_length_dist, f)


Created generation metadata dataframe
Loaded length distribution from drive


In [None]:
pattern = r"finished after (\d+\.\d+) hours"
for length, batch_size in uniprot_length_dist:
  if all_metadata_df.loc[all_metadata_df.conditions == "length = " + str(length),:].shape[0] >= batch_size: continue
  torch.cuda.empty_cache()
  meta_data = {}
  meta_data["batch_id"] = None
  meta_data["batch_size"] = None
  meta_data["Timestamp"] = str(datetime.now())
  meta_data["model"] = "ESM-Design"
  meta_data["task"] = "sequence_generation"
  meta_data["conditions"] = "length = " + str(int(length))
  meta_data["gpu"] = "T4 GPU"
  meta_data['output_file_name'] = None
  meta_data["wall_time_batch"] = None

  #Can't achieve batching on olab ATM as we get GPU memory overflow problems
  generation_command = f"""
  source activate esmfold
  python -m lm_design task=free_generation free_generation_length={length} num_seqs=1
  """

  #batching leads to crashing so we'll just loop single gens
  for i in range(batch_size):
    #need to change the seed every time
    with open("./conf/config.yaml", 'r') as file:
        config = yaml.safe_load(file)
        config['seed'] = random.randint(0,99999999)
    with open("./conf/config.yaml", 'w') as file:
      yaml.dump(config, file, default_flow_style=False)

    output = !{generation_command}
    match = re.search(pattern, output[-1])
    hours = float(match.group(1))
    meta_data['wall_time_task'] = str(hours*60*60) + " Seconds"
    meta_data['entity_id'] = str(uuid.uuid4())
    meta_data['generated_sequence'] = output[-2].split(" ")[-1]
    metadata_entry = pd.Series(meta_data)
    all_metadata_df = pd.concat([all_metadata_df,pd.DataFrame(metadata_entry).T], ignore_index=True)

  all_metadata_df.to_csv(meta_data_filepath, index=False)
  print("Metadata Updated")
