In [22]:
from datasets import load_dataset,concatenate_datasets, Dataset,DatasetDict
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch

In [23]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
from models.my_transformers import *
from models.models import VAE, DDPM, MLPSkipNet, TransformerNet,VAE_DDPM
from train.reconstruction import *
from functions import weights_init_rondom
from train import *

In [25]:
def collate(examples):
    # Convert to Tensors and build dataset

    input_ids_bert = pad_sequence([torch.tensor(f['bert_token'], dtype=torch.long) for f in examples],
                                  batch_first=True, padding_value=bert_pad_token)
    input_ids_gpt = pad_sequence([torch.tensor(f['gpt2_token'], dtype=torch.long) for f in examples],
                                 batch_first=True, padding_value=gpt2_pad_token)
    try:
        token_lengths = torch.tensor([[len(f['bert_token']), len(f['gpt2_token'])] for f in examples],
                                     dtype=torch.long)
    except:
        token_lengths = torch.zeros((len(examples), 1091))
        for i in range(len(examples)):
            token_lengths[i, len(examples[i]['gpt2_token'])] = 1
    return (input_ids_bert, input_ids_gpt, token_lengths)

In [26]:
train_eval_dataset =load_dataset("guangyil/yelp_short_v2")
eval_dataloader =  DataLoader(train_eval_dataset['test'], num_workers=0, collate_fn=collate,batch_size=64)

In [27]:
encoder_model_class = MODEL_CLASS['BertForLatentConnectorAVG']
tokenizer_encoder = AutoTokenizer.from_pretrained("prajjwal1/bert-small")
latent_size = 64
model_encoder = encoder_model_class.from_pretrained("prajjwal1/bert-small", latent_size=latent_size,
                                                    pad_id=tokenizer_encoder.pad_token_id,local_files_only=False)


decoder_model_class = MODEL_CLASS['GPT2ForLatentConnectorNew']
tokenizer_decoder = AutoTokenizer.from_pretrained("gpt2-xl")
model_decoder = decoder_model_class.from_pretrained("gpt2-xl", latent_size=latent_size,
                                                        latent_as_gpt_emb=True,
                                                        latent_as_gpt_memory=True,local_files_only=False)
decoder_n_layer = model_decoder.transformer.config.n_layer
model_decoder.transformer.change_order()

special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>', }
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
model_decoder.resize_token_embeddings(len(tokenizer_decoder))
bert_pad_token = tokenizer_encoder.pad_token_id
gpt2_pad_token = tokenizer_decoder.pad_token_id

Some weights of BertForLatentConnectorAVG were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['bert.linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2ForLatentConnectorNew were not initialized from the model checkpoint at gpt2-xl and are newly initialized: ['h.48.mlp.c_proj.bias', 'h.48.attn.c_proj.bias', 'h.48.ln_2.bias', 'h.48.mlp.c_proj.weight', 'linear_emb.weight', 'h.48.attn.c_attn.bias', 'h.48.mlp.c_fc.weight', 'linear.weight', 'h.48.mlp.c_fc.bias', 'h.48.ln_1.bias', 'h.48.ln_2.weight', 'lm_head.bias', 'h.48.ln_1.weight', 'h.48.attn.c_proj.weight', 'h.48.attn.c_attn.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
output_dir = "test"
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, latent_size, output_dir)

In [29]:
model_vae.apply(weights_init_rondom)
model_vae.to('cuda')

VAE(
  (encoder): BertForLatentConnectorAVG(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_af

In [30]:
calc_rec_lgy(model_vae, tokenizer_encoder, tokenizer_decoder,eval_dataloader, "cuda", True, ns=1)

{'bleu': 2.0236796874108411e-230}

In [31]:
ddpm = DDPM(eps_model=MLPSkipNet(latent_size), betas=(1e-4, 0.02), n_T=1000, criterion=nn.MSELoss(reduction='none'),)

In [32]:
ddpm.apply(weights_init_rondom)

DDPM(
  (eps_model): MLPSkipNet(
    (time_embed): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): SiLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (layers): ModuleList(
      (0): MLPLNAct(
        (linear): Linear(in_features=64, out_features=2048, bias=True)
        (act): SiLU()
        (linear_emb): Linear(in_features=64, out_features=2048, bias=True)
        (cond_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=2048, bias=True)
        )
        (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (dropout): Identity()
      )
      (1): MLPLNAct(
        (linear): Linear(in_features=2112, out_features=2048, bias=True)
        (act): SiLU()
        (linear_emb): Linear(in_features=64, out_features=2048, bias=True)
        (cond_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=2048, bias=True)
        )
    

In [40]:
model = VAE_DDPM(model_vae, ddpm,1.0 )

In [None]:
out = train_vae_ddpm(model, train_dataloader, encoder_tokenizer, decoder_tokenizer, 
          table_name, eval_dataloader, output_dir, condition_f=lambda x: False,
          checkpoint=None, local_rank = -1, batch_size = 32, eval_batch_size = 32,
          train_epoch = 20, gradient_accumulation_steps = 1, device = 'cpu',
          fp16=False, fp16_opt_level=None, learning_rate=9e-5, adam_epsilon=1e-5,
          lr_end_multiplier= 0.01, power=3.0, warmup_steps=0, 
          disable_bar=True, model_ppl=None, tokenizer_ppl=None, max_grad_norm=1, evaluate_during_training=False,
          no_save=True):