In [None]:
# Mount Google Drive; the metadata CSV and the cached length distribution used
# by this notebook are stored under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')
# RitaModelForCausalLM is not compatible with current versions of HuggingFace
# Transformers, so downgrade to a pinned release (roughly the version that was
# current when the model was published).
!pip install transformers==4.19.0

In [None]:
import os
import shutil
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import uuid
from datetime import datetime
import re
import torch
import time

meta_data_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/rita_unconditional/generation_metadata_rita.csv"

# Pick up the metadata CSV left by earlier sessions when there is one;
# otherwise start from an empty frame (nothing is written to disk here).
if not os.path.exists(meta_data_filepath):
  all_metadata_df = pd.DataFrame()
  print("Created generation metadata dataframe")
else:
  all_metadata_df = pd.read_csv(meta_data_filepath)
  print("Existing generation metadata read in.")

len_dist_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/rita_unconditional/uniref50_length_dist_rita.json"

# The target length distribution is a list of (length, count) pairs. It is
# cached as JSON so that every session works against the same table; it is
# only re-sampled when the cache file is missing.
if os.path.exists(len_dist_filepath):
  with open(len_dist_filepath, "r") as f:
    uniprot_length_dist =  json.load(f)
  print("Loaded length distribution from drive")
else:
  #https://www.uniprot.org/uniprotkb/statistics#sequence-size
  # Sequence-size histogram copied from the page above. Below, each entry of
  # `bins` is used as the length at which the matching cumulative count is reached.
  bins = np.array([13,51,101,151,201,251,301,351,401,451,501,551,601,651,701,751,801,851,901,951,1001,1101,1201,1301,1401,1501,1601,1701,1801,1901,2001,2101,2201,2301,2401,2501,34350])
  swissprot_reviewed = np.array([0,9968,43534,59796,59574,58452,52413,52846,45901,37706,30572,22287,15830,13156,9403,7870,5700,4889,5301,4109,3007,4124,2897,2207,2070,1675,834,642,587,503,395,272,386,340,234,195,1462])
  # Kept for reference; only the reviewed (Swiss-Prot) counts feed the sampling below.
  TrEMBL_unreviewed = np.array([0,2668805,19825275,24705701,23838128,23462438,23225451,21389271,16814580,14287105,11501843,8283150,6266068,4715059,3755005,3186452,2687314,2166878,1843669,1457871,1153537,1975953,1398765,961048,664766,517536,390552,300984,236895,210921,180246,138808,122833,102865,82441,71548,527646])

  # Empirical CDF at the bin positions, then linearly interpolated onto every
  # integer length.
  ecdf = np.cumsum(swissprot_reviewed) / np.sum(swissprot_reviewed)
  #shortest protein in uniprot is 14 res, longest is 34350 res.
  x = np.arange(14, 34350+1)
  ecdf = np.interp(x, bins, ecdf)

  # Inverse-transform sampling from the empirical CDF. A dedicated, seeded
  # generator makes a rebuilt cache reproducible without touching NumPy's
  # global random state.
  num_samples = 11000
  rng = np.random.default_rng(0)
  random_values = rng.random(num_samples)
  sampled_lengths = np.round(np.interp(random_values, ecdf, x)).astype(int)
  #ten thousand sequences up to 1000 res in length
  sampled_lengths = sampled_lengths[sampled_lengths <= 1000][0:10000]
  if len(sampled_lengths) < 10000:
    # The oversampling margin was not enough; report it instead of silently
    # storing a smaller target set.
    print("Warning: only " + str(len(sampled_lengths)) + " sampled lengths are <= 1000 residues")

  # Histogram with one unit-wide bin per integer length (edges 14 .. 1001);
  # the figure is shown for visual inspection and the counts are reused below.
  hist_values, bin_edges, patches = plt.hist(sampled_lengths, bins=x[0:1001-13], alpha=0.7, label='Sampled Values')
  plt.xlabel('Sequence length (residues)')
  plt.ylabel('Frequency')
  plt.title('Sampled target sequence lengths')
  plt.legend()
  plt.show()

  # bin_edges has one more entry than hist_values; zip() drops the final right
  # edge, so each pair is (length, number of sampled sequences of that length).
  uniprot_length_dist = list(zip([int(edge) for edge in bin_edges],[int(value) for value in hist_values]))
  with open(len_dist_filepath, "w") as f:
    json.dump(uniprot_length_dist, f)


Existing generation metadata read in.
Loaded length distribution from drive


In [None]:
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline

# The tokenizer and the causal-LM weights are loaded from the same Hub
# checkpoint; the model is loaded with trust_remote_code=True.
RITA_CHECKPOINT = "lightonai/RITA_xl"
tokenizer = AutoTokenizer.from_pretrained(RITA_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(RITA_CHECKPOINT, trust_remote_code=True)
model.to('cuda')

# Text-generation pipeline around the loaded model, placed on device 0.
rita_gen = pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    device=0,
)

In [None]:
def _largest_unfilled_length(length_dist, recorded_conditions):
  """Return the longest target length whose quota has not been met yet.

  length_dist: iterable of (length, count) pairs; only pairs with a positive
    count are treated as targets.
  recorded_conditions: the "conditions" column of the metadata frame, whose
    entries are written by the loop below as 'length = <n>'. Entries that do
    not match that pattern are ignored instead of raising.

  Returns None once every target length has been generated at least `count`
  times.
  """
  remaining = {s: n for s, n in length_dist if n > 0}
  recorded_lengths = pd.to_numeric(
      recorded_conditions.astype(str).str.extract(r'length = (\d+)', expand=False),
      errors='coerce',
  ).dropna().astype(int)
  for l in recorded_lengths:
    if l >= 14 and l in remaining:
      remaining[l] -= 1
  unfilled = [s for s, n in remaining.items() if n > 0]
  return max(unfilled) if unfilled else None


# Generate up to 10000 sequences in this session. Each generation is capped at
# the longest length that is still under-represented relative to
# uniprot_length_dist, and one metadata row is appended per generated sequence.
i = 0
meta_data = {}
unsaved_rows = 0  # rows appended since the last write to meta_data_filepath
try:
  while i < 10000:
    if all_metadata_df.empty:
      # Nothing generated yet (and no "conditions" column to parse): start
      # from the longest length in the distribution.
      max_length = max([l[0] for l in uniprot_length_dist])
    else:
      max_length = _largest_unfilled_length(uniprot_length_dist, all_metadata_df["conditions"])
      if max_length is None:
        # Every quota is met; stop cleanly rather than calling max() on an
        # empty collection.
        print("All target lengths have been generated; stopping.")
        break
      print("Max generation length: " + str(max_length))

    meta_data['entity_id'] = str(uuid.uuid4())
    meta_data["batch_id"] = None
    meta_data["batch_size"] = None
    meta_data['output_file_name'] = None
    meta_data["timestamp"] = str(datetime.now())
    meta_data['model'] = 'Rita'
    meta_data['task'] = 'sequence_generation'
    meta_data['wall_time_batch'] = None
    meta_data['gpu'] = 'T4 GPU'

    # Time only the generation call itself.
    start_time = time.time()
    sequences = rita_gen("M", max_length=max_length, do_sample=True, top_k=950, repetition_penalty=1.2,
                         num_return_sequences=1, eos_token_id=2)
    end_time = time.time()
    meta_data['wall_time_task'] = str(end_time-start_time) + " Seconds"
    # Strip the spaces from the generated text before measuring its length.
    sequence = sequences[0]['generated_text'].replace(' ', '')
    print(sequence)
    length = len(sequence)
    print("Generated length: " + str(length))
    # The realised length is what gets recorded; the helper above parses this
    # exact 'length = <n>' format on later iterations and later sessions.
    meta_data['conditions'] = 'length = ' + str(length)
    meta_data['generated_sequence'] = sequence
    metadata_entry = pd.Series(meta_data)
    all_metadata_df = pd.concat([all_metadata_df,pd.DataFrame(metadata_entry).T], ignore_index=True)
    i = i + 1
    unsaved_rows = unsaved_rows + 1
    # Checkpoint to disk every 5 sequences.
    if i % 5 == 0:
      all_metadata_df.to_csv(meta_data_filepath, index=False)
      unsaved_rows = 0
      print("saved to metadata " + str(datetime.now()))
finally:
  # Flush rows generated since the last checkpoint, whether the loop finished,
  # stopped early, or was interrupted. Skipped when nothing is pending, so an
  # empty frame is never written over the CSV.
  if unsaved_rows > 0:
    all_metadata_df.to_csv(meta_data_filepath, index=False)
    print("saved to metadata " + str(datetime.now()))