In [1]:
from datasets import load_dataset,concatenate_datasets, Dataset,DatasetDict, load_from_disk
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch

from DMLP.models.my_transformers import MODEL_CLASS
from DMLP.models.models import VAE, DDPM, MLPSkipNet, TransformerNet,VAE_DDPM
from DMLP.train.reconstruction import *
from DMLP.utils.ddpm_schedule import ddpm_schedule
from DMLP.utils.random_init import weights_init_random
from DMLP.train.train_function import train_vae_ddpm
from DMLP.train import generation
import numpy as np
import torch.nn as nn



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MyCollator(object):
    def __init__(self, encoder_token, decoder_token):
        self.encoder_token = encoder_token
        self.decoder_token = decoder_token
    def __call__(self, batch):
        input_ids_bert = pad_sequence([torch.tensor(f['bert_token'], dtype=torch.long) for f in batch],
                                  batch_first=True, padding_value=self.encoder_token)
        input_ids_gpt = pad_sequence([torch.tensor(f['gpt2_token'], dtype=torch.long) for f in batch],
                                    batch_first=True, padding_value=self.decoder_token)
        try:
            token_lengths = torch.tensor([[len(f['bert_token']), len(f['gpt2_token'])] for f in batch],
                                        dtype=torch.long)
        except:
            token_lengths = torch.zeros((len(batch), 1091))
            for i in range(len(batch)):
                token_lengths[i, len(batch[i]['gpt2_token'])] = 1
        return (input_ids_bert, input_ids_gpt, token_lengths)
def condition_f(n):
        return ('linear' in n or 'wte' in n or 'decoder.transformer.h.0' in n or 'encoder' in n)

In [3]:
batch_size = 128
encoder_model_class = MODEL_CLASS['BertForLatentConnectorAVG']



    #initialize tokenizer and model
print("initialize models")
tokenizer_encoder = AutoTokenizer.from_pretrained("prajjwal1/bert-small")
latent_size = 128
model_encoder = encoder_model_class.from_pretrained("prajjwal1/bert-small", latent_size=latent_size,
                                                        pad_id=tokenizer_encoder.pad_token_id,local_files_only=False)


decoder_model_class = MODEL_CLASS['GPT2ForLatentConnectorNew']
tokenizer_decoder = AutoTokenizer.from_pretrained("gpt2-xl")
model_decoder = decoder_model_class.from_pretrained("gpt2-xl", latent_size=latent_size,
                                                            latent_as_gpt_emb=True,
                                                            latent_as_gpt_memory=True,local_files_only=False)
decoder_n_layer = model_decoder.transformer.config.n_layer
model_decoder.transformer.change_order()

special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>', }
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
model_decoder.resize_token_embeddings(len(tokenizer_decoder))
bert_pad_token = tokenizer_encoder.pad_token_id
gpt2_pad_token = tokenizer_decoder.pad_token_id

my_collator = MyCollator(bert_pad_token, gpt2_pad_token)
    #download data
print("download data")
train_eval_dataset =load_dataset("guangyil/yelp_short_v2")
eval_dataloader =  DataLoader(train_eval_dataset['test'], num_workers=0, collate_fn=my_collator,batch_size=batch_size)
train_dataloader = DataLoader(train_eval_dataset['train'], num_workers=0, collate_fn=my_collator, batch_size=batch_size)

output_dir = "../../out_temp"
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, latent_size, output_dir)
checkpoint = torch.load('../../ckpts/checkpoints/checkpoint-full-2/training.bin',map_location=torch.device('cpu'))
model_vae.load_state_dict(checkpoint['model_state_dict'], strict=False) 

ddpm = DDPM(MLPSkipNet(latent_size), (1e-4, 0.02), 2000, nn.MSELoss(reduction='none'), ddpm_schedule)
checkpoint_ddpm = torch.load('../../ckpts/checkpoints/checkpoint-ddpm-2-1/training_ddpm.bin',map_location=torch.device('cpu'))
ddpm.load_state_dict(checkpoint_ddpm['model_state_dict'], strict=False) 
ddpm.to("cpu")
model = VAE_DDPM(model_vae, ddpm,10.0 )

initialize models


Some weights of BertForLatentConnectorAVG were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['bert.linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2ForLatentConnectorNew were not initialized from the model checkpoint at gpt2-xl and are newly initialized: ['h.48.mlp.c_fc.weight', 'h.48.ln_1.bias', 'linear.weight', 'lm_head.bias', 'h.48.ln_2.weight', 'h.48.mlp.c_proj.weight', 'h.48.mlp.c_proj.bias', 'h.48.attn.c_proj.weight', 'h.48.mlp.c_fc.bias', 'h.48.ln_1.weight', 'h.48.attn.c_proj.bias', 'h.48.attn.c_attn.weight', 'h.48.ln_2.bias', 'h.48.attn.c_attn.bias', 'linear_emb.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


download data


In [5]:
generate_txt = []
bz=2
context_tokens = tokenizer_decoder.encode(tokenizer_decoder.bos_token)

latent_z = ddpm.sample(bz,(128,),"cpu",fp16=False)


In [6]:
out = sample_sequence_conditional(
                model=model_vae.decoder,
                context=context_tokens,
                past=latent_z,
                length=32,
                num_samples=latent_z.size(0),
                device="cpu",
                decoder_tokenizer=tokenizer_decoder,
                eos_id=tokenizer_decoder.eos_token_id,
                loss=True
            )

In [7]:
generate_text = []
for i in range(latent_z.size(0)):
    text_x1 = tokenizer_decoder.decode(out[i, :].tolist(), clean_up_tokenization_spaces=False).split(tokenizer_decoder.eos_token)[
        0].replace(tokenizer_decoder.bos_token, '').strip()
    text_x1 = ' '.join(text_x1.split())
    generate_text.append(text_x1 + '\n')

TypeError: tuple indices must be integers or slices, not tuple

In [8]:
out

(tensor([[50258,  4753,   319,   257,  1107,  2089,  1110,   764,   198, 50259],
         [50258,  1312,  1842,   428,  1295,  5145,   198, 50259, 50259, 50259]]),
 [-3.9238951206207275, -0.05642591789364815])

In [14]:
tokenizer_decoder.decode(out[0][1])

'<BOS> i love this place!\n<EOS><EOS><EOS>'