In [1]:
%load_ext autoreload
%autoreload 2
%cd '/home/hew/python/genhance/'
%ls

/home/hew/python/genhance
[0m[01;34mACE2[0m/  [01;34mdebug[0m/  [01;34moutput[0m/  [01;34mtensorboard[0m/
[01;34mdata[0m/  [01;34mfoldx[0m/  [01;34mtemp[0m/    [01;34mtransformers_custom[0m/


In [2]:
!nvidia-smi

Mon Apr 10 03:27:51 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   36C    P0    68W / 400W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   34C    P0    67W / 400W |      0MiB / 81920MiB |      0%      Defaul

In [6]:
import re
import numpy as np
import pandas as pd
import torch
import time
import os
import pickle
from tqdm import tqdm
from pathlib import Path
from transformers import T5Tokenizer
from transformers_custom import MT5ForConditionalGenerationWithLatentSpace

In [43]:
input_data_dir = '/home/hew/storage/storage/genhance/data/'
gen_pretrained_dir = '/home/hew/storage/storage/genhance/ckpts/congen2/results/step_42000/'
generation_output_dir = '/home/hew/storage/storage/genhance/ckpts/congen2/generations/'
prepend_output_name = 'step42000'

topk_as_input = 12500
num_generations = 250000
num_gen_samples_per_input = 20
gen_batch_size = 200
unique_gen = True
temperature_init = 1.0
temperature_multiple = 1.2
gen_token_len = 83 + 2 + 1  # 83 ACE2 domain length + <cls> + </s> + <pad> token (decoder input_ids), other special tokens has been pruned
gen_save_interval = 100
z_tar_edit_before_dec = -1.0
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [16]:
ckpt_args = torch.load(gen_pretrained_dir + '/training_args.bin')
latent_space_args = {
    'latent_pooler'                 : ckpt_args.latent_pooler,
    'pool_enc_hidden_states_for_dec': ckpt_args.pool_enc_hidden_states_for_dec,
    'mask_non_target_z_vector'      : ckpt_args.mask_non_target_z_vector,
    'separate_targetattr_head'      : ckpt_args.separate_targetattr_head,
    'z_tar_vector_dim'              : ckpt_args.z_tar_vector_dim,
    'do_mi'                         : ckpt_args.do_mi,
    'latent_space_type'             : ckpt_args.latent_space_type,
    'latent_size'                   : ckpt_args.latent_size,
    'separate_latent_enc'           : ckpt_args.separate_latent_enc,
    'separate_latent_dec'           : ckpt_args.separate_latent_dec,
    'wae_z_enc_type'                : ckpt_args.wae_z_enc_type,
    }
ckpt_args

Namespace(beta=1.0, beta_ratio_increase=0.25, beta_ratio_zero=0.25, beta_start_step=10000, cache_dir='/home/hew/storage/storage/genhance/pretrained/', contrastive_cyc_start_step=10000, contrastive_perturb_cyc_start_step=-1, data_dir='/home/hew/storage/storage/genhance/data/', dim_target_kl=0.5, do_mi=False, eval_split_name='valid', eval_steps=200, lambda_contrastive=1.0, lambda_contrastive_cyc=1.0, lambda_contrastive_perturb_cyc=0.0, lambda_logvar_KL=0.0, lambda_logvar_L1=0.0, lambda_mi_head_loss=1.0, latent_pooler='cls', latent_size=1024, latent_space_type='wae', logging_dir='/home/hew/python/genhance/tensorboard/congen2', logging_steps=20, lr=0.0001, mask_non_target_z_vector=False, mmd_method='rf', num_decode_layers=6, num_layers=6, num_train_epochs=30, num_warmup_steps=0, output_dir='/home/hew/storage/storage/genhance/ckpts/congen2/results/', pc_perturb=-0.25, pc_perturb_type='std', per_device_eval_batch_size=256, per_device_train_batch_size=80, pool_enc_hidden_states_for_dec=True, 

In [17]:
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", cache_dir=ckpt_args.cache_dir)
tokenizer.add_special_tokens({"cls_token": "<cls>"})
assert tokenizer.cls_token == "<cls>"

In [18]:
tokenizer.get_vocab()

{'<pad>': 0,
 '</s>': 1,
 '<unk>': 2,
 '▁A': 3,
 '▁L': 4,
 '▁G': 5,
 '▁V': 6,
 '▁S': 7,
 '▁R': 8,
 '▁E': 9,
 '▁D': 10,
 '▁T': 11,
 '▁I': 12,
 '▁P': 13,
 '▁K': 14,
 '▁F': 15,
 '▁Q': 16,
 '▁N': 17,
 '▁Y': 18,
 '▁M': 19,
 '▁H': 20,
 '▁W': 21,
 '▁C': 22,
 '▁X': 23,
 '▁B': 24,
 '▁O': 25,
 '▁U': 26,
 '▁Z': 27,
 '<extra_id_99>': 28,
 '<extra_id_98>': 29,
 '<extra_id_97>': 30,
 '<extra_id_96>': 31,
 '<extra_id_95>': 32,
 '<extra_id_94>': 33,
 '<extra_id_93>': 34,
 '<extra_id_92>': 35,
 '<extra_id_91>': 36,
 '<extra_id_90>': 37,
 '<extra_id_89>': 38,
 '<extra_id_88>': 39,
 '<extra_id_87>': 40,
 '<extra_id_86>': 41,
 '<extra_id_85>': 42,
 '<extra_id_84>': 43,
 '<extra_id_83>': 44,
 '<extra_id_82>': 45,
 '<extra_id_81>': 46,
 '<extra_id_80>': 47,
 '<extra_id_79>': 48,
 '<extra_id_78>': 49,
 '<extra_id_77>': 50,
 '<extra_id_76>': 51,
 '<extra_id_75>': 52,
 '<extra_id_74>': 53,
 '<extra_id_73>': 54,
 '<extra_id_72>': 55,
 '<extra_id_71>': 56,
 '<extra_id_70>': 57,
 '<extra_id_69>': 58,
 '<extra_id_

In [20]:
input_data_path = Path(input_data_dir)
input_data_file = f'train_tophalf_ddG_solubility.pkl'
input_data_file = input_data_path/input_data_file
input_data_df = pd.read_pickle(input_data_file)
train_seq_list = input_data_df['MT_seq'].tolist()
ddG_sorted_input_df = input_data_df.sort_values(by='ddG', ascending=True)
gen_input_df = ddG_sorted_input_df.iloc[:topk_as_input]

In [21]:
len(train_seq_list), len(ddG_sorted_input_df), len(gen_input_df)

(125000, 125000, 12500)

In [22]:
if num_generations is None:
    num_generations = topk_as_input*num_gen_samples_per_input
num_unique_seqs_per_batch = gen_batch_size//num_gen_samples_per_input
num_batch = len(gen_input_df)//num_unique_seqs_per_batch
if len(gen_input_df)%num_unique_seqs_per_batch != 0:
    num_batch += 1

In [23]:
print("="*100)
print('generation conifg:')
print('num_generations: ', num_generations)
print('topk_as_input: ', topk_as_input)
print('num_gen_samples_per_input: ', num_gen_samples_per_input)
print('num_unique_seqs_per_batch: ', num_unique_seqs_per_batch)
print('num_batch: ', num_batch)
print("="*100)

generation conifg:
num_generations:  250000
topk_as_input:  12500
num_gen_samples_per_input:  20
num_unique_seqs_per_batch:  10
num_batch:  1250


## from_pretrained

In [24]:
%%time
gen_model = MT5ForConditionalGenerationWithLatentSpace.from_pretrained(gen_pretrained_dir,
                                                                       cache_dir=ckpt_args.cache_dir,
                                                                       num_layers=ckpt_args.num_layers,
                                                                       num_decoder_layers=ckpt_args.num_decoder_layers if
                                                                       'num_decoder_layers' in ckpt_args else ckpt_args.num_decode_layers,
                                                                       **latent_space_args)
gen_model.parallelize()
gen_model.resize_token_embeddings(len(tokenizer))

latent_space_type:  wae
wae_z_enc_type:  deterministic
separate_latent_enc:  False
separate_latent_dec:  False
mmd_method:  rf
sigma_mmd:  None
rf_dim_mmd:  None
dim_target_kl:  0.5
latent_size:  1024
latent_pooler:  cls
pool_enc_hidden_states_for_dec:  True
mask_non_target_z_vector:  False
separate_targetattr_head:  False
do_mi:  False
CPU times: user 55.3 s, sys: 13.4 s, total: 1min 8s
Wall time: 55.2 s


Embedding(129, 1024)

In [25]:
gen_model.device

device(type='cuda', index=0)

In [26]:
output_seq_list = []
input_seq_list = []
output_tensor_list = []
repeat_list = []
in_train_data_list = []
unique_n_notrain_list = []
start_time = time.time()
prev_save_path = None
repeat_seq_count = 0
in_train_count = 0
temperature = temperature_init
generation_rounds_done = 0
num_cls_seq = 0
num_fail_seq = 0

In [27]:
wt_seq = 'STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQ'
len(wt_seq)

83

In [60]:
bad_words_ids = [0, 2] + list(range(23, 127 + 1))
bad_words_ids = [[x] for x in bad_words_ids]
bad_words_ids, len(bad_words_ids)

([[0],
  [2],
  [23],
  [24],
  [25],
  [26],
  [27],
  [28],
  [29],
  [30],
  [31],
  [32],
  [33],
  [34],
  [35],
  [36],
  [37],
  [38],
  [39],
  [40],
  [41],
  [42],
  [43],
  [44],
  [45],
  [46],
  [47],
  [48],
  [49],
  [50],
  [51],
  [52],
  [53],
  [54],
  [55],
  [56],
  [57],
  [58],
  [59],
  [60],
  [61],
  [62],
  [63],
  [64],
  [65],
  [66],
  [67],
  [68],
  [69],
  [70],
  [71],
  [72],
  [73],
  [74],
  [75],
  [76],
  [77],
  [78],
  [79],
  [80],
  [81],
  [82],
  [83],
  [84],
  [85],
  [86],
  [87],
  [88],
  [89],
  [90],
  [91],
  [92],
  [93],
  [94],
  [95],
  [96],
  [97],
  [98],
  [99],
  [100],
  [101],
  [102],
  [103],
  [104],
  [105],
  [106],
  [107],
  [108],
  [109],
  [110],
  [111],
  [112],
  [113],
  [114],
  [115],
  [116],
  [117],
  [118],
  [119],
  [120],
  [121],
  [122],
  [123],
  [124],
  [125],
  [126],
  [127]],
 107)

In [69]:
gen_model.eval()
while unique_gen and np.sum(unique_n_notrain_list) < num_generations:
    if generation_rounds_done > 0:
        temperature = temperature*temperature_multiple
        print("New generation round, temperature: ", temperature,
              "num_unique_n_notrain_list: ", np.sum(unique_n_notrain_list))

    for batch_ind in tqdm(range(num_batch)):
        batch_seqs = gen_input_df[batch_ind*num_unique_seqs_per_batch: (batch_ind + 1)*num_unique_seqs_per_batch]['MT_seq']

        batch_input_ids = []
        batch_input_seqs = []
        for seq in batch_seqs:
            batch_input_seqs = batch_input_seqs + [seq]*num_gen_samples_per_input
            seq = '<cls> ' + " ".join(list(re.sub(r"[UZOB]", "X", seq)))
            input_ids = tokenizer.encode(seq, return_tensors='pt').to(gen_model.device)
            repeated_input_ids = input_ids.repeat((num_gen_samples_per_input, 1))
            batch_input_ids.append(repeated_input_ids)

        batch_input_ids = torch.cat(batch_input_ids, dim=0)
        print("batch_input_ids.shape: ", batch_input_ids.shape)
        print("batch_input_ids[0]: ", len(batch_input_ids[0]), batch_input_ids[0])

        gen_output = gen_model.generate(batch_input_ids,
                                        max_length=83 + 2 + 1,
                                        min_length=83 + 2 + 1,
                                        do_sample=True,
                                        temperature=temperature,
                                        bad_words_ids=bad_words_ids,
                                        z_tar_edit_before_dec=z_tar_edit_before_dec)
        print("gen_output.shape: ", gen_output.shape)  # torch.Size([200, 86])
        print('gen_output: ', gen_output[0])
        for seq_ind, gen_seq in enumerate(gen_output.cpu().numpy()):
            unique_n_notrain = True
            repeat = False
            in_train_data = False

            tokens = tokenizer.convert_ids_to_tokens(gen_seq.tolist())
            # print("len(tokens): ", len(tokens)) # 86
            if tokens == None or len(tokens) != gen_token_len:
                continue

            # print("tokens[:2]: ", tokens[:2])  # ['<pad>', '<cls>']
            # print("gen_seq[:2]: ", gen_seq[:2])  # [0 128]
            str_token_seq = "".join(tokens[2:-1]).replace('▁', '')
            print(f"[{seq_ind}] str_token_seq: ", str_token_seq)

            if num_fail_seq%1000 == 0 or np.sum(unique_n_notrain_list)%1000 == 0:
                num_cls_seq += 1
                if num_cls_seq%100 == 0:
                    print("num failed gen: {},".format(num_fail_seq),
                          ", num valid gen: {},".format(np.sum(unique_n_notrain_list)),
                          ", gen/total: {:.2f}%".format(np.sum(unique_n_notrain_list)/num_generations*100))

            raise RuntimeError()  # debug

            if str_token_seq in output_seq_list:
                repeat_seq_count += 1
                repeat = True
                unique_n_notrain = False

            if str_token_seq in train_seq_list:
                in_train_count += 1
                in_train_data = True
                unique_n_notrain = False

            if unique_gen and not unique_n_notrain:
                continue

            unique_n_notrain_list.append(unique_n_notrain)
            repeat_list.append(repeat)
            in_train_data_list.append(in_train_data)

            input_seq_str = batch_input_seqs[seq_ind]
            input_seq_list.append(input_seq_str)
            output_seq_list.append(str_token_seq)

            seq_tensor = gen_output[seq_ind].detach().cpu()
            output_tensor_list.append(seq_tensor)

        if batch_ind%gen_save_interval == 0 and batch_ind != 0:
            save_path = os.path.join(generation_output_dir,
                                     "{}-gens-{}-{}.pkl".format(prepend_output_name,
                                                                len(output_seq_list),
                                                                num_generations))
            saved_dict = {
                'output_seq_list': output_seq_list, "input_seq_list": input_seq_list, "output_tensor_list": output_tensor_list,
                'repeat_list'    : repeat_list, 'in_train_data_list': in_train_data_list, 'temperature': temperature
                }
            with open(save_path, 'wb') as f:
                pickle.dump(saved_dict, f)
            cur_time = time.time()

            print('='*50, 'interval save', '='*50)
            print("generated #", len(output_seq_list))
            print("Time taken so far: {} hours".format((cur_time - start_time)/3600))
            print('='*50, 'interval save', '='*50)

            if prev_save_path is not None:
                os.remove(prev_save_path)
            prev_save_path = save_path

        if unique_gen and np.sum(unique_n_notrain_list) > num_generations:
            break
    generation_rounds_done += 1

'''Save Final Data'''
save_path = os.path.join(generation_output_dir, "{}-gens-{}.pkl".format(prepend_output_name,
                                                                        num_generations))
saved_dict = {
    'output_seq_list': output_seq_list, "input_seq_list": input_seq_list, "output_tensor_list": output_tensor_list,
    'repeat_list'    : repeat_list, 'in_train_data_list': in_train_data_list, 'temperature': temperature
    }
with open(save_path, 'wb') as f:
    pickle.dump(saved_dict, f)

if prev_save_path is not None:
    os.remove(prev_save_path)

  0%|          | 0/1250 [00:00<?, ?it/s]

batch_input_ids.shape:  torch.Size([200, 85])
batch_input_ids[0]:  85 tensor([128,   7,  11,  12,   9,   9,  16,   3,  14,  11,  15,   4,  10,  14,
         15,  17,  20,   9,  14,   9,  10,   4,  15,  18,  16,   7,   7,   4,
          3,  12,  21,  17,  18,  17,  11,  17,  12,  11,   9,   9,  17,   6,
         16,  17,  19,  17,  15,   3,  16,  10,  14,  21,   7,   3,  15,   4,
         14,   9,  16,   7,  11,   4,   3,  16,  19,  18,  13,   4,  16,   9,
         12,  16,  17,   4,  11,   6,  14,   4,  15,   4,  16,   3,   4,  19,
          1], device='cuda:0')


  0%|          | 0/1250 [00:10<?, ?it/s]

gen_output.shape:  torch.Size([200, 86])
gen_output:  tensor([  0, 128,   7,  11,  12,   9,   9,  16,   3,  14,  11,  15,   4,  10,
         14,  15,  17,  20,   9,  14,   9,  10,   4,  15,  18,  16,   7,   7,
          4,   3,  12,  21,  17,  18,  17,  11,  17,  12,  11,   9,   9,  17,
          6,  16,  17,  19,  17,  15,   3,  16,  10,  14,  21,   7,   3,  15,
          4,  14,   9,  16,   7,  11,   4,   3,  16,  19,  18,  13,   4,  16,
          9,  12,  16,  17,   4,  11,   6,  14,   4,  15,   4,  16,   3,   4,
         19,  15], device='cuda:0')
[0] str_token_seq:  STIEEQAKTFLDKFNHEKEDLFYQSSLAIWNYNTNITEENVQNMNFAQDKWSAFLKEQSTLAQMYPLQEIQNLTVKLFLQALM





RuntimeError: 

In [72]:
batch_input_ids.shape, gen_output.shape

(torch.Size([200, 85]), torch.Size([200, 86]))

In [71]:
batch_input_ids[0]

tensor([128,   7,  11,  12,   9,   9,  16,   3,  14,  11,  15,   4,  10,  14,
         15,  17,  20,   9,  14,   9,  10,   4,  15,  18,  16,   7,   7,   4,
          3,  12,  21,  17,  18,  17,  11,  17,  12,  11,   9,   9,  17,   6,
         16,  17,  19,  17,  15,   3,  16,  10,  14,  21,   7,   3,  15,   4,
         14,   9,  16,   7,  11,   4,   3,  16,  19,  18,  13,   4,  16,   9,
         12,  16,  17,   4,  11,   6,  14,   4,  15,   4,  16,   3,   4,  19,
          1], device='cuda:0')

In [78]:
gen_output[0]

tensor([  0, 128,   7,  11,  12,   9,   9,  16,   3,  14,  11,  15,   4,  10,
         14,  15,  17,  20,   9,  14,   9,  10,   4,  15,  18,  16,   7,   7,
          4,   3,  12,  21,  17,  18,  17,  11,  17,  12,  11,   9,   9,  17,
          6,  16,  17,  19,  17,  15,   3,  16,  10,  14,  21,   7,   3,  15,
          4,  14,   9,  16,   7,  11,   4,   3,  16,  19,  18,  13,   4,  16,
          9,  12,  16,  17,   4,  11,   6,  14,   4,  15,   4,  16,   3,   4,
         19,  15], device='cuda:0')

In [79]:
gen_output[0][2:-1]

tensor([ 7, 11, 12,  9,  9, 16,  3, 14, 11, 15,  4, 10, 14, 15, 17, 20,  9, 14,
         9, 10,  4, 15, 18, 16,  7,  7,  4,  3, 12, 21, 17, 18, 17, 11, 17, 12,
        11,  9,  9, 17,  6, 16, 17, 19, 17, 15,  3, 16, 10, 14, 21,  7,  3, 15,
         4, 14,  9, 16,  7, 11,  4,  3, 16, 19, 18, 13,  4, 16,  9, 12, 16, 17,
         4, 11,  6, 14,  4, 15,  4, 16,  3,  4, 19], device='cuda:0')

In [80]:
gen_seq = str_token_seq
gen_seq, len(gen_seq)

('STIEEQAKTFLDKFNHEKEDLFYQSSLAIWNYNTNITEENVQNMNFAQDKWSAFLKEQSTLAQMYPLQEIQNLTVKLFLQALM',
 83)

In [81]:
wt_seq, len(wt_seq)

('STIEEQAKTFLDKFNHEAEDLFYQSSLASWNYNTNITEENVQNMNNAGDKWSAFLKEQSTLAQMYPLQEIQNLTVKLQLQALQ',
 83)

In [82]:
dist = 0
for i in range(len(gen_seq)):
    if gen_seq[i] != wt_seq[i]:
        dist += 1
dist

6