In [11]:
import os
import json
import torch
import argparse
import time
from data_utlis.dataset import FastConvert
from models.model import SentenceVAE
from utils import to_var, idx2word, interpolate
from torch.utils.data import DataLoader
from screening_process.screen import screen
def main(args):
    """Load a trained SentenceVAE and print sequences decoded along a latent interpolation.

    The function reads the vocabulary from ``<args.data_dir>/vocab.json``, builds a
    ``SentenceVAE`` from the hyper-parameters in ``args``, loads the weights stored at
    ``args.load_checkpoint``, runs the model on two hard-coded seed sequences to obtain
    their latent codes, interpolates between those codes and prints the decoded result.

    Args:
        args: Attribute-style configuration object. The attributes read are
            ``data_dir``, ``load_checkpoint``, ``max_sequence_length``,
            ``embedding_size``, ``rnn_type``, ``hidden_size``, ``word_dropout``,
            ``embedding_dropout``, ``latent_size``, ``num_layers`` and
            ``bidirectional``.

    Returns:
        None. Results are printed to stdout.

    Raises:
        FileNotFoundError: If ``args.load_checkpoint`` does not exist (the vocabulary
            file is opened with ``open`` and fails the same way when missing).
    """
    # Seed sequences: the interpolation runs between the latent codes of these two.
    seqs = [
        'KKRPKPGGWNTGGSRYPGQGSPGGNRYPPQGGGGWGQPHGGGWGQPHGGGWGQPHGGGWGQPHGGGWGQGGGTHSQWNKPSKPKTNMKHMAGAAAAGAVVGGLGGYMLGSAMSRPIIHFGSD',
        'SKGPGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGY',
    ]

    with open(os.path.join(args.data_dir, 'vocab.json'), 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
    )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    # map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
    # CPU-only machine; the model is moved to the GPU right below when one exists.
    state_dict = torch.load(args.load_checkpoint, map_location='cpu')
    model.load_state_dict(state_dict)
    print("Model loaded from %s" % args.load_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    # Both seed sequences go through the model as a single batch of size 2.
    loader = DataLoader(
        FastConvert(seqs),
        batch_size=2,
        pin_memory=torch.cuda.is_available(),
    )
    batch = next(iter(loader))
    for k, v in batch.items():
        if torch.is_tensor(v):
            batch[k] = to_var(v)

    print('-------------Screening procedures based on ESM and Autogluon---------')
    screened_seq = []

    # Pure inference: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        _, _, _, z = model(batch['input'], batch['length'])
        z = z.detach().cpu().numpy()
        # Interpolate between the latent codes of the two seed sequences.
        z_i = to_var(torch.from_numpy(interpolate(start=z[0], end=z[1], steps=32)).float())
        samples, _ = model.inference(z=z_i)

    print('-------INTERPOLATION GENERATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # NOTE: the ESM/AutoGluon screening stage is currently disabled, so exactly one
    # generation pass runs and `screened_seq` stays empty. The intended flow was:
    #   1. write the decoded samples (spaces and "<eos>" stripped) to a temporary
    #      FASTA file under <data_dir>/results/,
    #   2. call `screen(path)` to obtain one flag per sequence,
    #   3. keep the sequences flagged 1 and repeat generation until at least 5
    #      sequences have been collected,
    #   4. write the collected sequences to a FASTA file under <data_dir>/results/.
    print(screened_seq)

    




In [12]:
class Bunch(dict):
    """Dictionary whose items can also be read and written as attributes.

    Construction accepts exactly what ``dict`` accepts (a mapping, an iterable of
    pairs and/or keyword arguments).
    """

    def __init__(self, *args, **kwds):
        dict.__init__(self, *args, **kwds)
        # Use the dict itself as the instance namespace so that ``obj.key`` and
        # ``obj['key']`` always refer to the same storage.
        self.__dict__ = self
# Settings consumed by main(); exposed through Bunch so they read like parsed CLI args.
inference_cfg = Bunch(
    # Path of the trained model checkpoint whose weights are loaded.
    load_checkpoint='./checkpoints/2023-Dec-05-03:19:06/E499.pytorch',
    # Number of new samples to generate.
    num_samples=100,

    # Directory that holds vocab.json (and other auxiliary data files).
    data_dir='./data',
    # Cut-off length for long sequences.
    max_sequence_length=150,
    embedding_size=256,
    # Either 'rnn' or 'gru'.
    rnn_type='gru',
    hidden_size=256,
    # Word dropout applied to the decoder input: words are replaced by <unk>
    # with this probability.
    word_dropout=0,
    # Dropout applied to the word embeddings fed to the decoder.
    embedding_dropout=0.3,

    latent_size=64,
    num_layers=1,
    bidirectional=False,
)


In [13]:
# Run the checkpoint-loading and latent-interpolation pipeline with inference_cfg.
main(inference_cfg)

Model loaded from ./checkpoints/2023-Dec-05-03:19:06/E499.pytorch
-------------Screening procedures based on ESM and Autogluon---------


-------INTERPOLATION GENERATION-------
L P Y L D L E G L Q E Q M E Q L S L M E Q S Q M Y L Q S M L E M Q M F S Y L M E M Q D M M Y E F S Q Y E F L M E Q S Q M E N L S E V F M E E G M Q D L G H S S V G K F M Y D G T Q S V Y F D G E L H S E G V F V E G R T M G K Y S G R T H K G F S E R V K L S T G N G R Y K F T G R G G E K V V T G L N H S
L P Y L D L E G L Q E Q M E Q L S L M E Q S Q M Y L Q S M L E M Q M F S Y L M E M Q D M M Y E F S Q Y E F L M E Q S Q M E N L S V N H S V G G R Y E F D L Q Y F S N V M Y D G R T F Y H F G Y F D G R T H S G R G F K F M G D N G S Y H P F D E R F H F E G T R K G S I Y K N E I K N K H R W K M S Q K Y E
L P Y L D L E G L Q E Q M E Q L S L M E Q S Q M Y L Q S M L E M Q M F S Y L M E M Q D M M Y E F Y L M E D M G M Y F D G T H M S E Q S Q H F S V M Y D G F Q H F S V Y H F D G E L H S E G F V M T G M Q K Y W D G L P F E F H T D G L H S E F V M I E G F K Q T N S Y H F E G T R K G S I Y K N E I K N K H
L P Y L D E G L Q E Q M E Q L S L M E Q S Q M Y L Q S M L E M

In [14]:
import torch
# Draw two independent standard-normal vectors of length 100 as NumPy arrays
# (z1 is sampled first, then z2) and display the shape of the first one.
z1, z2 = (torch.randn(100).numpy() for _ in range(2))
z1.shape

(100,)

In [15]:
from data_utlis.dataset import FastConvert

In [16]:
# Two example sequence strings (apparently one-letter amino-acid codes) used to
# try out the FastConvert dataset wrapper.
seq = [
    "KKRPKPGGWNTGGSRYPGQGSPGGNRYPPQGGGGWGQPHGGGWGQPHGGGWGQPHGGGWGQPHGGGWGQGGGTHSQWNKPSKPKTNMKHMAGAAAAGAVVGGLGGYMLGSAMSRPIIHFGSD",
    "SKGPGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGRGDSPYSGY",
]

# Wrap the raw strings in the project's dataset class.
test_data = FastConvert(sequence_strs=seq)

In [17]:
from torch.utils.data import DataLoader
# Serve test_data in batches of two; pin_memory is enabled only when CUDA is available.
data_loader = DataLoader(test_data, batch_size=2, pin_memory=torch.cuda.is_available())

# Show the 'input' entry of every batch the loader yields.
for batch in data_loader:
    print(batch['input'])

tensor([[ 2,  7,  7, 20, 14,  7, 14, 11, 11, 22, 19, 18, 11, 11, 17, 20, 16, 14,
         11, 13, 11, 17, 14, 11, 11, 19, 20, 16, 14, 14, 13, 11, 11, 11, 11, 22,
         11, 13, 14,  8, 11, 11, 11, 22, 11, 13, 14,  8, 11, 11, 11, 22, 11, 13,
         14,  8, 11, 11, 11, 22, 11, 13, 14,  8, 11, 11, 11, 22, 11, 13, 11, 11,
         11, 18,  8, 17, 13, 22, 19,  7, 14, 17,  7, 14,  7, 18, 19,  4,  7,  8,
          4,  5, 11,  5,  5,  5,  5, 11,  5, 15, 15, 11, 11,  9, 11, 11, 16,  4,
          9, 11, 17,  5,  4, 17, 20, 14, 21, 21,  8, 10, 11, 17, 12,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,