In [1]:
import argparse
import json, os
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Decoding settings handed to `model.generate` in the generation cell.
generation_config = GenerationConfig(
    # Decoding strategy: stochastic sampling with a single beam.
    do_sample=True,
    num_beams=1,
    # Shaping of the sampling distribution.
    temperature=0.2,
    top_k=40,
    top_p=0.9,
    # Repetition control and upper bound on the number of generated tokens.
    repetition_penalty=1.2,
    max_new_tokens=400,
)

In [3]:
# Checkpoint directory. Defaults to the original shared-disk location; set the
# PROLLAMA_MODEL_PATH environment variable to use another copy without
# editing the notebook.
model_path = os.environ.get(
    'PROLLAMA_MODEL_PATH',
    '/ai/sharedisk/workspace/weil/ckpts/ProLLaMA',
)

# Weight dtype intended for `from_pretrained(torch_dtype=...)`.
load_type = torch.bfloat16

# Device the prompt tensors are moved to. Fall back to CPU when no GPU is
# visible so that `.to(device)` does not raise on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Load the Llama tokenizer from the checkpoint directory `model_path`.
tokenizer = LlamaTokenizer.from_pretrained(model_path)

In [5]:
# Load the causal-LM weights from `model_path`, letting `device_map='auto'`
# decide the placement of the weights.
# `torch_dtype` is passed explicitly: without it `load_type` is never used and
# the weights are loaded in the library's default dtype instead of bfloat16.
# NOTE(review): if full precision was intentional, drop `torch_dtype` again
# and remove the unused `load_type` variable.
model = LlamaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=load_type,
    low_cpu_mem_usage=True,
    device_map='auto',
    quantization_config=None,
)

Loading checkpoint shards: 100%|██████████| 2/2 [01:32<00:00, 46.10s/it]


In [6]:
# Prompt in the `Superfamily=<...>` format; the decoded generation shown later
# in this notebook continues it with a `Seq=<...>` segment.
input_text = 'Superfamily=<Ankyrin repeat-containing domain superfamily>'

In [7]:
# Tokenize the prompt and return PyTorch tensors (`return_tensors="pt"`).
input_tokens = tokenizer(input_text, return_tensors="pt")  

In [8]:
# Inspect the encoded prompt: `input_ids` and `attention_mask`.
input_tokens

{'input_ids': tensor([[    1,  5670, 11922, 29922, 29966,  2744,  3459, 17056, 12312, 29899,
          1285, 17225,  5354,  2428, 11922, 29958]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [9]:
# Switch the model to evaluation mode before sampling.
model.eval()

# Gradients are not needed for generation.
with torch.no_grad():
    generation_output = model.generate(
        # Move both prompt tensors to `device` and pass them by keyword.
        **{
            field: input_tokens[field].to(device)
            for field in ("input_ids", "attention_mask")
        },
        generation_config=generation_config,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        output_attentions=False,
    )

In [10]:
# Take the first sequence of the generation output (the prompt batch holds a
# single sequence).
s = generation_output[0]
# Decode the token ids back to text, dropping special tokens.
output = tokenizer.decode(s,skip_special_tokens=True)

In [11]:
# Show the decoded text: the prompt followed by the generated `Seq=<...>` part.
output

'Superfamily=<Ankyrin repeat-containing domain superfamily> Seq=<MKLVLLALAATLAAAAPQPTPPPSPTPIRDGYEHFNADGRTVWERDPARGEAVIRALEAGDLDAIRRLVEEGVDPNLRDRNGRTALMLASANGHTEVAEFLIDSGANPDLPDKGGSTPLMAACFRGNLDVAEYLIEQGAEPNAMDDEGLSAMDIAKENGSPEIVALLREHRPRPAEEEDDDDSSESDSSSEEESD>'

In [37]:
# Short test sequence for inspecting tokenization.
# NOTE(review): this matches the start of the generated sequence in `output`
# except for an extra `EE` inserted after `PPPS` — presumably deliberate, to
# see how the token boundaries shift; confirm.
seq = 'MKLVLLALAATLAAAAPQPTPPPSEEPTPIRDGY'

In [38]:
# Encode `seq` to token ids; per the decoded listing in this notebook the
# first id corresponds to the `<s>` token.
out = tokenizer.encode(seq)

In [39]:
# Wrap every token id in its own single-element list so that each id can be
# decoded separately.
temp = [[token_id] for token_id in out]

In [40]:
# Decode each single-id list on its own to obtain the text of every token.
res = [tokenizer.decode(single_id) for single_id in temp]

In [41]:
# Show the per-token text pieces of `seq`.
res

['<s>',
 'M',
 'K',
 'L',
 'V',
 'LL',
 'AL',
 'A',
 'AT',
 'LA',
 'AA',
 'AP',
 'Q',
 'P',
 'TP',
 'PP',
 'SEE',
 'PT',
 'PI',
 'R',
 'D',
 'G',
 'Y']

In [24]:
# NOTE(review): this cell has execution count 24, i.e. it ran before the
# In [37]-In [41] cells that precede it in the file. Its saved output lists the
# tokens of the full generated sequence, not of the short `seq` above, so it
# will not be reproduced by running the notebook top to bottom.
res

['<s>',
 'M',
 'K',
 'L',
 'V',
 'LL',
 'AL',
 'A',
 'AT',
 'LA',
 'AA',
 'AP',
 'Q',
 'P',
 'TP',
 'PP',
 'SP',
 'T',
 'PI',
 'R',
 'D',
 'G',
 'Y',
 'E',
 'H',
 'F',
 'N',
 'AD',
 'GR',
 'TV',
 'W',
 'ER',
 'DP',
 'AR',
 'GE',
 'AV',
 'I',
 'RA',
 'LE',
 'AG',
 'DL',
 'DA',
 'IR',
 'RL',
 'VE',
 'EG',
 'VD',
 'PN',
 'LR',
 'DR',
 'NG',
 'RT',
 'AL',
 'ML',
 'AS',
 'ANG',
 'HT',
 'EV',
 'AE',
 'FL',
 'ID',
 'SG',
 'AN',
 'PD',
 'LP',
 'DK',
 'GG',
 'ST',
 'PL',
 'MA',
 'AC',
 'FR',
 'GN',
 'LD',
 'VA',
 'EY',
 'LI',
 'EQ',
 'GA',
 'EP',
 'NA',
 'MD',
 'DE',
 'GL',
 'S',
 'AM',
 'DI',
 'AK',
 'EN',
 'GS',
 'PE',
 'IV',
 'ALL',
 'RE',
 'HR',
 'PR',
 'PA',
 'EE',
 'ED',
 'DD',
 'DS',
 'SE',
 'SD',
 'SS',
 'SEE',
 'ES',
 'D']

In [29]:
# Number of tokens left after trimming both ends of the generated ids.
# NOTE(review): magic slice bounds. The prompt encodes to 16 tokens, so 19
# presumably also skips the tokens of ` Seq=<`, and -2 presumably drops the
# closing `>` and the end-of-sequence token — TODO confirm.
s[19:-2].shape

torch.Size([135])

In [30]:
# NOTE(review): rebinds `seq`, which is also assigned to a short test string
# earlier in the file; the origin of this longer sequence is not shown in the
# notebook.
seq = 'MKVLFLLSFLFFTGYAQNHWDEAKALLEEGCRVDPNHPDNDGTTPLILAIENGNVEIVEYLINNGADINARSKSGNTILMYAVESNNIEIVEMLIKAGANIDARDNEGRSPLMFAIRSNNSKEIIEFFIANGASPNAKDKIGWTPIMWAASSGSLDTVRYLTETGAEVDAKSNRGETSLMIATERGHEEMISILLSAGVDIEPEDDDGLTPLSMAAEEGHTELITALLRAGADPKLRDNYGETAREIAEEHRSE'

In [31]:
# Length of `seq` in characters (one character per residue).
len(seq)

254

In [34]:
# Decode two raw token ids; the output shows they render as `<>`. Both ids
# also occur in the encoded prompt (`input_ids`), which contains `<` and `>`.
tokenizer.decode(torch.tensor([29966, 29958]),skip_special_tokens=True)

'<>'

In [26]:
import json

In [36]:
# Imported here so the cell runs on a fresh kernel: the notebook's own
# `import numpy as np` cell comes after this one in file order, which makes
# `np.median` fail with NameError under Restart & Run All.
import numpy as np

# JSON file with one record per entry, each carrying a 'plddt' value.
# NOTE(review): judging by the path these are presumably ESMFold results for
# ZymCTRL sequences generated for EC 3.5.5.1 — confirm.
res_path = '/ai/sharedisk/workspace/weil/ZymCTRL/generated_seqs/ec_3.5.5.1/esmfold_results.json'
with open(res_path, 'r', encoding='utf-8') as handle:
    res = json.load(handle)

# Collect the pLDDT value of every record and show the median as cell output.
plddts = [item['plddt'] for item in res]
np.median(plddts)

0.5593969291062462

In [31]:
import numpy as np

0.8306883116883117