In [None]:
# Mount Google Drive inside the Colab VM at /content/drive.
# Later cells read and write files under /content/drive/MyDrive, so this must run first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Fetch the ProGen code, download the progen2-large checkpoint archive and
# unpack it into ./checkpoints/progen2-large (relative to progen/progen2).
!git clone https://github.com/salesforce/progen
%cd progen/progen2
# NOTE(review): `${model}` is not defined anywhere in this notebook; the recorded wget output
# shows the archive being saved directly under checkpoints/ — confirm that is intended.
!wget -P checkpoints/${model} https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-large.tar.gz
%cd ./checkpoints
# Target directory for the extracted checkpoint files.
!mkdir progen2-large
%cd ..
!tar -xvf ./checkpoints/progen2-large.tar.gz -C ./checkpoints/progen2-large/

Cloning into 'progen'...
remote: Enumerating objects: 248, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 248 (delta 9), reused 10 (delta 3), pack-reused 227[K
Receiving objects: 100% (248/248), 58.62 KiB | 1.67 MiB/s, done.
Resolving deltas: 100% (142/142), done.
/content/progen/progen2
--2024-05-24 00:02:01--  https://storage.googleapis.com/sfr-progen-research/checkpoints/progen2-large.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.197.207, 209.85.145.207, 142.250.125.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.197.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5137860178 (4.8G) [application/x-tar]
Saving to: ‘checkpoints/progen2-large.tar.gz’


2024-05-24 00:03:05 (76.4 MB/s) - ‘checkpoints/progen2-large.tar.gz’ saved [5137860178/5137860178]

/content/progen/progen2/checkpoints
/content/progen/progen2
./
./config

In [None]:
# For the sake of speed, we rework the package's sample.py script so that the model does not
# have to be reloaded for every generation. To do that, the package is installed in editable
# mode from a minimal setup.py written below, and parts of its code are redefined in later cells.
setup_code = """
from setuptools import setup, find_packages

setup(
    name='progen2',
    version='1',
    packages=find_packages(),
    install_requires=[
    ],
)
"""

# Write the setup script into the current working directory (overwrites an existing setup.py).
with open('setup.py', 'w') as setup_file:
    setup_file.write(setup_code)
print("setup.py created successfully.")
# Editable install of the package described by the setup.py just written.
!pip install -e .
!pip install accelerate


In [None]:
#Non main() classes/methods from sample.py
import os
import time
import random
import argparse
import torch
from tokenizers import Tokenizer
from models.progen.modeling_progen import ProGenForCausalLM
########################################################################
# util
class print_time:
    """Context manager that prints a label on entry and the elapsed wall time on exit.

    Usage::

        with print_time('loading parameters') as timer:
            ...
        timer.elapsed  # seconds spent inside the block

    Attributes:
        desc: Label printed on entry and again in the timing message.
        t: ``time.time()`` value captured on entry.
        elapsed: Seconds spent inside the block; ``None`` until the block exits.
    """

    def __init__(self, desc):
        self.desc = desc
        self.t = None
        self.elapsed = None

    def __enter__(self):
        print(self.desc)
        self.t = time.time()
        # Return the instance so `with print_time(...) as timer` binds something useful.
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.elapsed = time.time() - self.t
        print(f'{self.desc} took {self.elapsed:.02f}s')
        # Implicitly returns None, so exceptions raised inside the block still propagate.


def set_env():
    """Set the TOKENIZERS_PARALLELISM environment variable to 'false' for this process."""
    os.environ.update(TOKENIZERS_PARALLELISM='false')


def set_seed(seed, deterministic=True):
    """Seed Python's `random`, PYTHONHASHSEED and torch; configure cuDNN when CUDA is present.

    Args:
        seed: Value passed to every seeding call (stored as a string in PYTHONHASHSEED).
        deterministic: When CUDA is available, cuDNN's `deterministic` flag is set to this
            value and its `benchmark` flag to the opposite.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # Nothing CUDA-related to configure on CPU-only machines.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = deterministic
    cudnn.benchmark = not deterministic
########################################################################
# model

def create_model(ckpt, fp16=True):
    """Load a ProGenForCausalLM from the checkpoint at `ckpt`.

    Args:
        ckpt: Checkpoint location handed to `ProGenForCausalLM.from_pretrained`.
        fp16: When truthy, request the 'float16' revision with `torch.float16` weights and
            `low_cpu_mem_usage=True`; otherwise load with the library defaults.

    Returns:
        The model object returned by `from_pretrained`.
    """
    load_options = {}
    if fp16:
        load_options = dict(revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True)
    return ProGenForCausalLM.from_pretrained(ckpt, **load_options)


def create_tokenizer_custom(file):
    """Build a `Tokenizer` from the serialized definition stored in the text file `file`."""
    with open(file, 'r') as handle:
        serialized = handle.read()
        return Tokenizer.from_str(serialized)

########################################################################
# sample

def sample(device, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id):
    """Sample continuations of `context` from `model` and return them as decoded strings.

    Args:
        device: Device the encoded prompt is moved to before generation.
        model: Object exposing `generate(input_ids, **kwargs)`.
        tokenizer: Object exposing `encode(text).ids` and `decode_batch(list_of_id_lists)`.
        context: Prompt text to encode.
        max_length: Forwarded to `generate` as `max_length`.
        num_return_sequences: Forwarded to `generate` as `num_return_sequences`.
        top_p: Forwarded to `generate` as `top_p`.
        temp: Forwarded to `generate` as `temperature`.
        pad_token_id: Forwarded to `generate` as `pad_token_id`.

    Returns:
        The result of `tokenizer.decode_batch` on one list of token ids per generated row.
    """
    generation_kwargs = dict(
        do_sample=True,
        temperature=temp,
        max_length=max_length,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        pad_token_id=pad_token_id,
    )

    with torch.no_grad():
        # The prompt is encoded as a single row: shape (1, number_of_prompt_tokens).
        prompt_ids = torch.tensor(tokenizer.encode(context).ids).view([1, -1]).to(device)
        generated = model.generate(prompt_ids, **generation_kwargs)
        # One plain Python list of token ids per row along the first dimension.
        id_rows = [row.detach().cpu().numpy().tolist() for row in generated]

    return tokenizer.decode_batch(id_rows)

def truncate(sample, terminals):
    """Cut `sample` right after the earliest terminal found at index 1 or later.

    The search starts at index 1, so a terminal sitting at position 0 is never treated as
    the end of the string. The matched terminal's first character is kept in the result.

    Args:
        sample: String to shorten.
        terminals: Iterable of substrings that mark the end of a sequence.

    Returns:
        The prefix of `sample` up to and including the earliest match, or `sample`
        unchanged when no terminal is found.
    """
    hits = [at for at in (sample.find(marker, 1) for marker in terminals) if at != -1]
    if not hits:
        return sample
    return sample[:min(hits) + 1]

def cross_entropy(logits, target, reduction='mean'):
    """Thin wrapper around `torch.nn.functional.cross_entropy` with no class weights.

    Args:
        logits: Passed as the `input` argument.
        target: Passed as the `target` argument.
        reduction: Reduction mode forwarded unchanged (defaults to 'mean').

    Returns:
        The tensor returned by `torch.nn.functional.cross_entropy`.
    """
    functional = torch.nn.functional
    return functional.cross_entropy(logits, target, reduction=reduction)


In [None]:
import random
# Pre-sampling part of sample.py's main(), reworked to run as notebook cells
#####
# (0) constants: known checkpoint names, grouped under the size label used in each variable name
models_151M = ['progen2-small']
models_754M = ['progen2-medium', 'progen2-oas', 'progen2-base']
models_2B = ['progen2-large', 'progen2-BFD90']
models_6B = ['progen2-xlarge']
models = [*models_151M, *models_754M, *models_2B, *models_6B]

# (1) params
use_model = 'progen2-large'
use_device = 'cuda:0'
# Sequences are sampled one at a time so the time per length can be recorded, which means
# the notebook has to start from a different seed on every run.
rng_seed = random.randint(1, 9999999)
rng_deterministic = False
fp16 = True
sanity = True

# (2) preamble
# Configure the environment and RNGs, then resolve the device and checkpoint path.
set_env()
set_seed(rng_seed, deterministic=rng_deterministic)
if not torch.cuda.is_available():
    print('falling back to cpu')
    # The notebook has no `args` object (that name comes from the original script's argparse
    # namespace); the device is selected through `use_device`, so override that instead.
    use_device = 'cpu'
device = torch.device(use_device)
ckpt = f'./checkpoints/{use_model}'
if device.type == 'cpu':
    # Half precision is only used on GPU here; on CPU load the default-precision weights.
    print('falling back to fp32')
    fp16 = False
# (3) load
# Load the checkpoint onto `device` and build the tokenizer; print_time reports how long each step takes.
with print_time('loading parameters'):
    model = create_model(ckpt=ckpt, fp16=fp16).to(device)
with print_time('loading tokenizer'):
    # Relative path: tokenizer.json is looked up in the current working directory.
    tokenizer = create_tokenizer_custom(file='tokenizer.json')
# (4) sanity
# Score one reference sequence with the loaded model and require its cross-entropy to be
# within 0.1 of the reference value stored for the selected checkpoint.
if sanity:
    with print_time('sanity cross-entropy'):
        def ce(tokens):
            """Return the model's mean next-token cross-entropy on the string `tokens`."""
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=fp16):
                target = torch.tensor(tokenizer.encode(tokens).ids).to(device)
                logits = model(target, labels=target).logits
                # shift: the logits at position k are scored against the token at position k+1
                return cross_entropy(logits=logits[:-1, ...], target=target[1:]).item()

        # Reference sequences; which one is used depends on the checkpoint (see table below).
        x_uniref90bfd30 = '2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1'
        x_oas = '1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPWKGLEYVSAISSNGGSTYYANSVKGRFTISRDNSKNTLYLQMGSLRAEDMAVYYCARDESGYSYGWGYYFDYWGQGTLVTVSS2'
        x_bfd90 = '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2'

        # checkpoint name -> (reference sequence, reference cross-entropy)
        checkpoint_x_ce = {
            'progen2-small': (x_uniref90bfd30, 2.4),
            'progen2-medium': (x_uniref90bfd30, 1.9),
            'progen2-base': (x_uniref90bfd30, 1.9),
            'progen2-large': (x_uniref90bfd30, 1.8),
            'progen2-xlarge': (x_uniref90bfd30, 1.0),
            'progen2-oas': (x_oas, 0.3),
            'progen2-BFD90': (x_bfd90, 1.3),
        }

        ce_input, ce_target = checkpoint_x_ce[use_model]
        ce_eval = ce(ce_input)
        ce_delta = abs(ce_eval - ce_target)
        print(ce_target, ce_eval, ce_delta)
        assert ce_delta < 0.1

loading parameters
loading parameters took 42.76s
loading tokenizer
loading tokenizer took 0.04s
sanity cross-entropy
1.8 1.8524295091629028 0.05242950916290279
sanity cross-entropy took 1.40s


In [None]:
import os
import shutil
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import uuid
from datetime import datetime
import re
import torch
import time

# CSV on Google Drive that accumulates one row per generated sequence across sessions.
meta_data_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/progen2_unconditional/generation_metadata_progen2.csv"

# Resume from an earlier session when the CSV exists; otherwise start with an empty frame
# (nothing is written to disk here).
if not os.path.exists(meta_data_filepath):
  all_metadata_df = pd.DataFrame()
  print("Created generation metadata dataframe")
else:
  all_metadata_df = pd.read_csv(meta_data_filepath)
  print("Existing generation metadata read in.")


# JSON cache of the target length distribution: a list of (length, count) pairs.
len_dist_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/progen2_unconditional/uniref50_length_dist_progen2.json"

# Reuse the cached distribution when present so every session works towards the same targets;
# otherwise build it once from the binned UniProt statistics below and cache it.
if os.path.exists(len_dist_filepath):
  with open(len_dist_filepath, "r") as f:
    uniprot_length_dist =  json.load(f)
  print("Loaded length distribution from drive")
else:

  #https://www.uniprot.org/uniprotkb/statistics#sequence-size
  # Sequence-length bin positions and the number of entries per bin for the two UniProtKB sections.
  # NOTE(review): only `swissprot_reviewed` is used below; `TrEMBL_unreviewed` is defined but unused in this cell.
  bins = np.array([13,51,101,151,201,251,301,351,401,451,501,551,601,651,701,751,801,851,901,951,1001,1101,1201,1301,1401,1501,1601,1701,1801,1901,2001,2101,2201,2301,2401,2501,34350])
  swissprot_reviewed = np.array([0,9968,43534,59796,59574,58452,52413,52846,45901,37706,30572,22287,15830,13156,9403,7870,5700,4889,5301,4109,3007,4124,2897,2207,2070,1675,834,642,587,503,395,272,386,340,234,195,1462])
  TrEMBL_unreviewed = np.array([0,2668805,19825275,24705701,23838128,23462438,23225451,21389271,16814580,14287105,11501843,8283150,6266068,4715059,3755005,3186452,2687314,2166878,1843669,1457871,1153537,1975953,1398765,961048,664766,517536,390552,300984,236895,210921,180246,138808,122833,102865,82441,71548,527646])

  # Empirical CDF at the bin positions: cumulative share of reviewed entries.
  ecdf = np.cumsum(swissprot_reviewed) / np.sum(swissprot_reviewed)
  #shortest protein in uniprot is 14 res, longest is 34350 res.
  x = np.arange(14, 34350+1)
  # Linearly interpolate the binned CDF onto every integer length in x.
  ecdf = np.interp(x, bins, ecdf)

  # Sample from the empirical CDF
  # Inverse-transform sampling: uniform draws are mapped back to lengths through the CDF.
  # np.random is not seeded in this cell, so the draw differs per run; the JSON cache above
  # is what keeps the distribution stable between sessions.
  num_samples = 11000
  random_values = np.random.rand(num_samples)
  sampled_lengths = np.round(np.interp(random_values, ecdf, x)).astype(int)
  #ten thousand sequences up to 1000 res in length
  # More values are drawn than needed so that, after discarding lengths above 1000,
  # up to 10000 remain (the slice yields fewer if too many draws were discarded).
  sampled_lengths = sampled_lengths[sampled_lengths <= 1000][0:10000]

  # Plot the histogram of sampled values
  # The bin edges are the integers 14..1001, i.e. one bin per sequence length.
  hist_values, bin_edges, patches = plt.hist(sampled_lengths, bins=x[0:1001-13], alpha=0.7, label='Sampled Values')
  plt.xlabel('X-axis label')
  plt.ylabel('Frequency')
  plt.legend()
  plt.show()

  # Pair each bin's left edge (the length) with its count; zip drops the final right edge.
  uniprot_length_dist = list(zip([int(edge) for edge in bin_edges],[int(value) for value in hist_values]))
  with open(len_dist_filepath, "w") as f:
      json.dump(uniprot_length_dist, f)


Existing generation metadata read in.
Loaded length distribution from drive


In [None]:

#sampling
#########
# Generates one sequence per iteration so the wall time of every generation can be recorded.
# Each iteration caps generation at the largest length that still has an unfilled quota in
# `uniprot_length_dist`, records the sequence, and persists the metadata CSV. The loop ends
# when every quota is filled or `num_to_generate` sequences were produced in this run.
context = 'M'
#temperature and top_p are taken from the example on the github page rather than the command-line defaults, as they seem to generate better results
temp = 0.8
top_p = 0.9
num_to_generate = 10000  # upper bound on sequences generated by one run of this cell
save_every = 1           # write the metadata CSV after this many new sequences
pad_token_id = tokenizer.encode('<|pad|>').ids[0]  # loop-invariant, so looked up once

# Remaining quota per length = target count minus sequences already recorded with that length.
sampling_lengths = {s: n for s, n in uniprot_length_dist if n > 0}
if not all_metadata_df.empty:
  recorded_lengths = (
      all_metadata_df["conditions"]
      .str.extract(r'length = (\d+)', expand=False)
      .dropna()  # rows without a parsable length would otherwise make astype(int) fail
      .astype(int)
  )
  for l in recorded_lengths:
    # Lengths without a quota (outside the distribution, or with a target of zero) are
    # skipped instead of raising KeyError.
    if l in sampling_lengths:
      sampling_lengths[l] -= 1
  sampling_lengths = {s: n for s, n in sampling_lengths.items() if n > 0}

meta_data = {}
i = 0
while i < num_to_generate:
  if not sampling_lengths:
    # Every quota is filled: stop cleanly rather than calling max() on an empty dict,
    # which raises "ValueError: max() arg is an empty sequence".
    print("All length quotas are filled; stopping generation.")
    break
  torch.cuda.empty_cache()
  max_length = max(sampling_lengths)
  print("Max sampling length: " + str(max_length))

  # A fresh record per sequence; key order determines the CSV column order.
  meta_data = {
    'entity_id': str(uuid.uuid4()),
    'batch_id': None,
    'batch_size': None,
    'output_file_name': None,
    'Timestamp': str(datetime.now()),
    'model': 'ProGen2',
    'task': 'sequence_generation',
    'wall_time_batch': None,
    'gpu': 'T4 GPU',
  }

  start_time = time.time()
  completions = sample(device=device, model=model, tokenizer=tokenizer, context=context, pad_token_id=pad_token_id, num_return_sequences=1, temp=temp, top_p=top_p, max_length=max_length)
  truncations = [truncate(completion, terminals=['1', '2']) for completion in completions]
  end_time = time.time()
  meta_data['wall_time_task'] = str(end_time-start_time) + " Seconds"
  #assuming only one sequence generated!!!
  sequence = truncations[0]
  print(sequence)

  length = len(sequence)
  print("Sampled length: " + str(length))
  meta_data['conditions'] = 'length = ' + str(length)
  meta_data['generated_sequence'] = sequence
  all_metadata_df = pd.concat([all_metadata_df, pd.DataFrame([meta_data])], ignore_index=True)

  # Update the quota in place instead of re-parsing the whole metadata frame every iteration.
  remaining = sampling_lengths.get(length)
  if remaining is not None:
    if remaining > 1:
      sampling_lengths[length] = remaining - 1
    else:
      del sampling_lengths[length]

  i = i + 1
  if i % save_every == 0:
    all_metadata_df.to_csv(meta_data_filepath, index=False)
    print("saved to metadata " + str(datetime.now()))

# Persist any rows produced since the last periodic save (only possible when save_every > 1).
if i % save_every != 0:
  all_metadata_df.to_csv(meta_data_filepath, index=False)
  print("saved to metadata " + str(datetime.now()))

Max sampling length: 18
MPRLTRSVAFADFVVEGP
Sampled length: 18
saved to metadata 2024-05-24 00:11:03.571860
Max sampling length: 16
MTFDGRSAPTVVEPAV
Sampled length: 16
saved to metadata 2024-05-24 00:11:04.621323
Max sampling length: 15
MRDLDKRGMLERTAE
Sampled length: 15
saved to metadata 2024-05-24 00:11:05.627964
Max sampling length: 15
MRDAEELFLVANASE
Sampled length: 15
saved to metadata 2024-05-24 00:11:06.621454
Max sampling length: 15
MRGVVTNLTPYDAPR
Sampled length: 15
saved to metadata 2024-05-24 00:11:07.620490
Max sampling length: 15
MGLLSQAAAAAEGKQ
Sampled length: 15
saved to metadata 2024-05-24 00:11:08.607869
Max sampling length: 15
MGIRRSPTAIIAAVE
Sampled length: 15
saved to metadata 2024-05-24 00:11:09.614635
Max sampling length: 14
MDATQLGAAPEQFL
Sampled length: 14
saved to metadata 2024-05-24 00:11:10.546305
Max sampling length: 14
MALDIGSLNLAGTI
Sampled length: 14
saved to metadata 2024-05-24 00:11:11.796287


ValueError: max() arg is an empty sequence

In [None]:
# Sends SIGKILL to every process this user is allowed to signal, which takes down the notebook
# runtime itself. Presumably meant to shut the Colab session down once generation has finished —
# TODO confirm; any unsaved in-memory state is lost.
!kill -9 -1