In [28]:
import os
import json
import torch
import argparse
import time

from models.model import SentenceVAE
from utils import to_var, idx2word, interpolate

def main(args):
    """Load a trained SentenceVAE, print generated sentences, and save them.

    Steps:
      1. Read the vocabulary from ``<data_dir>/vocab.json``.
      2. Build a ``SentenceVAE`` from the hyper-parameters on ``args`` and
         load the weights stored at ``args.load_checkpoint``.
      3. Generate ``args.num_samples`` sentences and print them.
      4. Decode an interpolation between two random latent vectors and
         print the result.
      5. Write the sentences generated in step 3 to
         ``<data_dir>/results/<timestamp>_generate_<num_samples>.txt``.

    Args:
        args: Object exposing the attributes read below: ``data_dir``,
            ``load_checkpoint``, ``num_samples``, ``max_sequence_length``,
            ``embedding_size``, ``rnn_type``, ``hidden_size``,
            ``word_dropout``, ``embedding_dropout``, ``latent_size``,
            ``num_layers`` and ``bidirectional``.

    Raises:
        FileNotFoundError: If ``args.load_checkpoint`` does not exist.
    """
    with open(os.path.join(args.data_dir, 'vocab.json'), 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    # Load onto the CPU first so a checkpoint saved from a GPU run can still
    # be restored on a CPU-only machine; the model is moved to CUDA below
    # when a device is available.
    state_dict = torch.load(args.load_checkpoint, map_location='cpu')
    model.load_state_dict(state_dict)
    print("Model loaded from %s" % args.load_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    # Sentences generated from the model; decoded once and reused for both
    # the console output and the results file.
    samples, _ = model.inference(n=args.num_samples)
    generated = list(idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']))
    print('----------SAMPLES----------')
    print(*generated, sep='\n')

    # Interpolate between two random latent vectors. The result is kept in
    # its own variable so it does not overwrite the generated samples that
    # are written to disk afterwards.
    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    interpolated, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(interpolated, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # Persist the generated samples (not the interpolation), matching the
    # "_generate_<num_samples>" file name.
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())
    results_dir = os.path.join(args.data_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(
        results_dir, ts + '_generate_' + str(args.num_samples) + '.txt')
    with open(results_path, 'w') as rf:
        for sp in generated:
            rf.write(sp + '\n')




In [29]:
class Bunch(dict):
    """A ``dict`` whose entries are also reachable as attributes.

    The instance's ``__dict__`` is bound to the mapping itself, so
    ``bunch.key`` and ``bunch['key']`` read and write the same entry.
    """

    def __init__(self, *args, **kwds):
        # Populate the mapping exactly as ``dict(*args, **kwds)`` would.
        super().__init__(*args, **kwds)
        # Share storage between attribute access and item access.
        self.__dict__ = self
# Settings consumed by main(); the model hyper-parameters presumably have to
# match the ones used when the checkpoint was trained — TODO confirm.
inference_cfg = Bunch({
    # Path to the trained model checkpoint whose weights are loaded.
    'load_checkpoint': 'checkpoints/2023-May-05-12:53:31/E353.pytorch',
    # Number of new sentences to generate.
    'num_samples': 100,

    # Directory holding vocab.json; results are written below it.
    'data_dir': './data',
    # Specifies the cut-off of long sentences.
    'max_sequence_length': 150,
    'embedding_size': 256,
    # Either 'rnn' or 'gru'.
    'rnn_type': 'gru',
    'hidden_size': 256,
    # Word dropout applied to the input of the decoder: words are replaced
    # by <unk> with this probability.
    'word_dropout': 0,
    # Word-embedding dropout applied to the input of the decoder.
    'embedding_dropout': 0.3,

    'latent_size': 64,
    'num_layers': 1,
    'bidirectional': False,
})


In [30]:
# Run inference with inference_cfg: prints the generated and interpolated
# sentences and writes a results file under <data_dir>/results.
main(inference_cfg)

Model loaded from checkpoints/2023-May-05-12:53:31/E353.pytorch
----------SAMPLES----------
G Q Q Q Q N Y Q Q Q P Q G Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q
M K K K K K K E Y <eos>
R G D Y S R G W M D K <eos>
M N D K P V <eos>
M A K K K K K K K K <eos>
P T G V P S G T G S L F G S F N S T P V G T S L A V <eos>
M Q A V G I N T P A E <eos>
M V A P G V I G <eos>
Y T P G V T G S L <eos>
M D V L K <eos>
M G L A I <eos>
M K P H <eos>
R P G Y G I G V P A G V <eos>
M K G F N V T P <eos>
R G D Y Q Q P Q Q G Y G Q Q N Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q N Y G Q Q N Y G Q P Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q N Y G Q Q N Q Q Y P G Q Q G Y G Q Q 