In [None]:
from transformers.models.bert.tokenization_bert import BertTokenizer
decoder_model_path = 'xxx'
tokenizer = BertTokenizer.from_pretrained(decoder_model_path)
# Register explicit sentence-boundary markers as special tokens (return value discarded).
special_tokens_dict = dict(bos_token='<BOS>', eos_token='<EOS>')
tokenizer.add_special_tokens(special_tokens_dict)
# Shown as the cell output.
tokenizer.unk_token_id

In [None]:
import numpy as np
def get_decoder_beta_list(n_steps, start=0., stop=1.0, n_cycle=4):
    """Build a one-shot, linearly *decaying* beta schedule of length ``n_steps``.

    Let ``t_range = int(n_steps / n_cycle)`` (clamped to at least 1). The value falls
    linearly from ``stop`` at step 0 to ``start`` at step ``t_range``. Every later step
    is 0. Despite its name, ``n_cycle`` does not repeat the ramp. It only sets the decay
    length to roughly ``n_steps / n_cycle`` steps.

    Args:
        n_steps: length of the returned schedule.
        start: value reached at the end of the decay (step ``t_range``).
        stop: value at step 0.
        n_cycle: divisor that sets the decay length.

    Returns:
        float64 ``np.ndarray`` of shape ``(n_steps,)``.
    """
    steps = np.arange(n_steps)
    # Clamp so that n_steps < n_cycle no longer divides by zero (the decay then spans one step).
    t_range = max(int(n_steps / n_cycle), 1)
    decay = stop - (steps / t_range) * (stop - start)
    return np.where(steps > t_range, 0., decay)
get_decoder_beta_list(100, 0, 1, 5)

# Model Eval

In [None]:
import os 
import sys

import json
import torch
import argparse
from vae_pl_module import TDVAEModule
from load import ChineseSentenceSplitter
from torch.nn.utils.rnn import pad_sequence
# Expose only GPU 5 to this process, where it then appears as CUDA device 0.
# The variable only takes effect if it is set before CUDA is initialised.
os.environ.update(CUDA_VISIBLE_DEVICES='5')

def TDVAECollator(samples, max_length=512):
    """Collate paragraph samples into padded TD-VAE encoder/decoder batches.

    Each sample is a dict with ``'decoder_target'`` (a flat list of token ids for the
    whole paragraph) and ``'decoder_sent_lengths'`` (the token count of each sentence).
    A paragraph is skipped if it has fewer than two sentences, or if its first sentence
    or its first two sentences together reach ``max_length`` tokens.

    Args:
        samples: list of sample dicts (one batch).
        max_length: maximum number of token ids kept per paragraph (default 512).

    Returns:
        ``(encoder_inputs, decoder_labels, paragraph_lengths)``:
        - ``encoder_inputs`` and ``decoder_labels`` are ``torch.long`` tensors of shape
          ``(num_kept, longest_kept_length)``, right-padded with 0.
        - ``paragraph_lengths`` is the list of per-sentence length lists of the kept
          samples. It is kept because, when the batch has more than one sample, sentence
          boundaries cannot be recovered from the padded tensors.

        Returns ``(None, None, None)`` when every sample is skipped.
    """
    paragraph_lengths = []
    encoder_inputs, decoder_labels = [], []
    for sp in samples:
        sent_lengths = sp['decoder_sent_lengths']
        # NOTE: in TD-VAE both encoder and decoder are GPT-2, so the decoder ids are used twice.
        input_ids = decoder_target = sp['decoder_target']
        # `< 2` (rather than `== 1`) also skips empty paragraphs, which would otherwise
        # raise IndexError on sent_lengths[0].
        if (len(sent_lengths) < 2 or sent_lengths[0] >= max_length
                or sent_lengths[0] + sent_lengths[1] >= max_length):
            continue
        encoder_inputs.append(torch.tensor(input_ids[:max_length], dtype=torch.long))
        decoder_labels.append(torch.tensor(decoder_target[:max_length], dtype=torch.long))
        # NOTE(review): ids are truncated to max_length but sent_lengths is kept whole, so its
        # sum can exceed the tensor length. Confirm the downstream encoder tolerates this.
        paragraph_lengths.append(sent_lengths)
    if not encoder_inputs:
        return None, None, None  # every example in the batch was filtered out
    encoder_inputs = pad_sequence(encoder_inputs, batch_first=True, padding_value=0)
    decoder_labels = pad_sequence(decoder_labels, batch_first=True, padding_value=0)
    return (encoder_inputs, decoder_labels, paragraph_lengths)


# Default run configuration; every value below can be overridden on the command line.
mlm_probability = 0.
checkpoint_path = 'xxx'
# marf_checkpoint_path = 'xxx'
encoder_model_path = 'xxx'
decoder_model_path = 'xxx'
max_split_num = 12

args_parser = argparse.ArgumentParser()
# (flag, type, default) triples, registered in this order.
for flag_name, flag_type, flag_default in (
    ("--checkpoint_path", str, checkpoint_path),
    # ("--marf_checkpoint_path", str, marf_checkpoint_path),
    ("--encoder_model_path", str, encoder_model_path),
    ("--decoder_model_path", str, decoder_model_path),
    ("--mlm_probability", float, mlm_probability),
    ("--max_split_num", int, max_split_num),
):
    args_parser.add_argument(flag_name, type=flag_type, default=flag_default)
# parse_known_args ignores unrecognised argv entries (presumably those the notebook
# kernel was launched with) instead of erroring out.
args, unknown_args = args_parser.parse_known_args()

# Load the model and its tokenizers. load_model also returns an args object,
# which replaces the parsed one.
(model, encoder_tokenizer, decoder_tokenizer, latent_size,
 labels_dict, args) = TDVAEModule.load_model(args, labels_dict=None)

# Load the dev split (one JSON object per line, text under 'content'), split each
# document into sentences, and tokenize every sentence with the decoder tokenizer.
sentence_splitter = ChineseSentenceSplitter()
data = []
inputs_dicts = []
# NOTE(review): the BOS/EOS ids are read here but are not inserted around the sentences
# below. Confirm this matches the preprocessing used at training time.
bos_token, eos_token = decoder_tokenizer.bos_token_id, decoder_tokenizer.eos_token_id
# Explicit UTF-8: the corpus is Chinese text and the platform default encoding may differ.
with open("xxx/wudao/dev.json", encoding="utf-8") as file:
    for line in file:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line)
        data_dict = json.loads(line)
        data.append(sentence_splitter.tokenize(data_dict['content']))

for sentences in data:
    # Flat id list for the whole paragraph plus each sentence's token count.
    decoder_sent_lengths, decoder_target = [], []
    for sentence in sentences:
        decoder_sent_target = decoder_tokenizer.convert_tokens_to_ids(decoder_tokenizer.tokenize(sentence))
        decoder_sent_lengths.append(len(decoder_sent_target))
        decoder_target.extend(decoder_sent_target)
    inputs_dicts.append({
        'decoder_target': decoder_target,
        "decoder_sent_lengths": decoder_sent_lengths
    })

# Collate each paragraph on its own (batch size 1). The collator returns
# (None, None, None) for paragraphs it filters out; those are dropped.
tdvae_inputs = []
for input_dict in inputs_dicts:
    inputs_tuple = TDVAECollator([input_dict])
    if not any(obj is None for obj in inputs_tuple):
        tdvae_inputs.append(inputs_tuple)

In [None]:
# Belief-mode reconstruction:
# 1. Encode one paragraph into per-sentence belief latents.
# 2. Pass them through the transition network forward ('direct') and back ('inverse').
# 3. Decode from the round-tripped latents and print each ground-truth sentence
#    next to its generated counterpart.
max_length = 50
top_p = 0.3
temperature = 1.
repetition_penalty = 1.
device = 0
model = model.eval()
model.transition_net.model = model.transition_net.model.eval()
model = model.to(device)
inputs_idx = 4
inputs = tdvae_inputs[inputs_idx]
encoder_inputs, decoder_labels, paragraph_lengths = inputs
# Inference only: disable autograd so no computation graph (and its GPU memory) is retained.
with torch.no_grad():
    batch_belief_latent_z = model.get_belief_latent_z(encoder_inputs.to(device), paragraph_lengths, sample=False)
    f_outputs, logdets = model.transition_net.model(batch_belief_latent_z.to(device), mode='direct')
    # NOTE(review): if the transition network is an exactly invertible flow, b_outputs should
    # match batch_belief_latent_z. Consider asserting that.
    b_outputs, logdets = model.transition_net.model(f_outputs, mode='inverse')
    outputs = model.inference(encoder_inputs.to(device), paragraph_lengths, decoder_tokenizer,
        max_length, top_p, temperature, repetition_penalty, mode="belief", sample_z=False,
        batch_belief_latent_z=b_outputs)

for gen_sentences, ground_truth_para, sent_lens in zip(outputs, encoder_inputs, paragraph_lengths):
    # The first ground-truth sentence is printed alone; generated sentences are
    # paired with ground-truth sentences 2..N.
    print(decoder_tokenizer.decode(ground_truth_para[:sent_lens[0]]))
    print("-" * 20)
    offset = sent_lens[0]
    for gen_sent, sent_len in zip(gen_sentences, sent_lens[1:]):
        print(decoder_tokenizer.decode(ground_truth_para[offset:offset + sent_len]))
        offset += sent_len
        print(decoder_tokenizer.decode(gen_sent))
        print("-" * 20)

In [None]:
# Predict-mode generation for one paragraph: print each ground-truth sentence
# next to the sentence the model generates in its place.
max_length = 50
top_p = 0.3
temperature = 1.
repetition_penalty = 1.
device = 0
model = model.eval()
model = model.to(device)
inputs_idx = 4
inputs = tdvae_inputs[inputs_idx]
encoder_inputs, decoder_labels, paragraph_lengths = inputs
# Optional probe: replace '。' (id 511) with UNK (id 100) to test whether the decoder
# relies on the period to predict a different sentence.
# encoder_inputs[encoder_inputs == 511] = 100

# Inference only: disable autograd so no computation graph (and its GPU memory) is retained.
with torch.no_grad():
    outputs = model.inference(encoder_inputs.to(device), paragraph_lengths, decoder_tokenizer,
        max_length, top_p, temperature, repetition_penalty, mode="predict", sample_z=False)

for gen_sentences, ground_truth_para, sent_lens in zip(outputs, encoder_inputs, paragraph_lengths):
    # The first ground-truth sentence is printed alone; generated sentences are
    # paired with ground-truth sentences 2..N.
    print(decoder_tokenizer.decode(ground_truth_para[:sent_lens[0]]))
    print("-" * 20)
    offset = sent_lens[0]
    for gen_sent, sent_len in zip(gen_sentences, sent_lens[1:]):
        print(decoder_tokenizer.decode(ground_truth_para[offset:offset + sent_len]))
        offset += sent_len
        print(decoder_tokenizer.decode(gen_sent))
        print("-" * 20)



# Eval 2: Iterative Sentence Generation from a Seed Context

In [None]:
# Iterative generation, repeated for 5 rounds:
# 1. Wrap the current `context` sentence in BOS/EOS and append it to the running paragraph.
# 2. Ask the model (mode="inference") for the next sentence.
# 3. Use that sentence as the new context.
max_length = 50
top_p = 0.5
temperature = 1.
repetition_penalty = 1
device = 0
model = model.eval()
model = model.to(device)
decoder_sent_lengths, decoder_target = [], []
context = '我是一名土生土长的中国人。'  # "I am a Chinese person, born and raised in China."
# Inference only: disable autograd so no computation graph is retained across rounds.
with torch.no_grad():
    for _ in range(5):
        context_ids = decoder_tokenizer.convert_tokens_to_ids(decoder_tokenizer.tokenize(context))
        decoder_sent_target = [bos_token] + context_ids + [eos_token]
        decoder_sent_lengths.append(len(decoder_sent_target))
        decoder_target.extend(decoder_sent_target)
        # A batch of one paragraph containing every sentence produced so far.
        encoder_inputs = pad_sequence([torch.tensor(decoder_target, dtype=torch.long)], batch_first=True, padding_value=0)
        decoder_labels = pad_sequence([torch.tensor(decoder_target, dtype=torch.long)], batch_first=True, padding_value=0)
        paragraph_lengths = [decoder_sent_lengths]
        outputs = model.inference(encoder_inputs.to(device), paragraph_lengths, decoder_tokenizer,
            max_length, top_p, temperature, repetition_penalty, mode="inference")
        # NOTE(review): [1:-1] assumes the generated ids start with BOS and end with EOS. A
        # generation cut off at max_length would lose its last real token here — confirm.
        # Spaces are removed, presumably because decode() puts them between Chinese tokens.
        context = decoder_tokenizer.decode(outputs[0][0][1:-1]).replace(' ', '')
        print(context)

    


# Number of Active Latent Units

In [None]:
def get_sent_mus_vars(model, tdvae_inputs, device, num_para=20):
    """Collect belief-network posterior means and log-variances, one row per sentence.

    Args:
        model: object exposing ``encoder(...)``, ``get_belief_network_output(...)``,
            ``max_split_num`` and ``unk_token_id``.
        tdvae_inputs: sequence of ``(encoder_inputs, decoder_labels, paragraph_lengths)``
            tuples, as produced by ``TDVAECollator``.
        device: device the encoder inputs are moved to before encoding.
        num_para: number of leading entries of ``tdvae_inputs`` to process. It is
            clamped to ``len(tdvae_inputs)``.

    Returns:
        ``(belief_mus, belief_logvars)``: CPU tensors holding the per-sentence rows of
        every processed paragraph, concatenated along dim 0. Returns ``(None, None)``
        if no sentence was processed.
    """
    mus_chunks, logvar_chunks = [], []
    # Inference only: skip autograd bookkeeping to keep memory flat.
    with torch.no_grad():
        for inputs_idx in range(min(num_para, len(tdvae_inputs))):
            encoder_inputs, _, paragraph_lengths = tdvae_inputs[inputs_idx]
            # Each element of batch_belief_states is a (num_sents, hidden) tensor; num_sents
            # varies across the batch. The second return value is not needed here.
            batch_belief_states, _ = model.encoder(
                encoder_inputs.to(device), paragraph_lengths,
                max_split_num=model.max_split_num, unk_token_id=model.unk_token_id)
            for belief_states in batch_belief_states:
                # (num_sents, hidden) -> (num_sents, latent)
                sent_mus, sent_logvars, _ = model.get_belief_network_output(belief_states, sample=False)
                mus_chunks.append(sent_mus.detach().cpu())
                logvar_chunks.append(sent_logvars.detach().cpu())
    if not mus_chunks:
        return None, None
    # Concatenate once at the end instead of re-concatenating on every iteration
    # (which is quadratic). torch.cat always allocates a fresh tensor.
    return torch.cat(mus_chunks, dim=0), torch.cat(logvar_chunks, dim=0)


def cal_active_units(z_vector, threshold=0.1):
    """Count the dimensions whose variance across dim 0 exceeds ``threshold``.

    The variance is torch's default Bessel-corrected estimate. With a single row it is
    NaN, so that dimension is not counted. Returns a 0-d integer tensor.
    """
    per_dim_variance = z_vector.var(dim=0)
    is_active = per_dim_variance > threshold
    return is_active.sum()

# Collect per-sentence belief statistics for the first 50 paragraphs, then count the
# latent dimensions whose posterior mean has variance above 0.1 across sentences.
device = 0
model = model.eval().to(device)
belief_mus, belief_logvars = get_sent_mus_vars(model, tdvae_inputs, device, num_para=50)
cal_active_units(belief_mus, threshold=0.1)


In [None]:
# Posterior mean vector of the first collected sentence (row 0 of belief_mus).
belief_mus[0]

In [None]:
# Mean / median / max / min of the first five rows of the posterior means and
# log-variances, with a separator after each row.
for row in range(5):
    for stats_of in (belief_mus[row], belief_logvars[row]):
        print(stats_of.mean(), stats_of.median(), stats_of.max(), stats_of.min())
    print("-" * 20)

In [None]:
# Latent dimension 1 of the posterior means for the first 10 rows (sentences).
belief_mus[:10, 1]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def get_cyclic_linear_beta_list(n_steps, start=0.5, stop=1.0, n_cycle=1):
    """Build a cyclic, three-phase beta schedule of length ``n_steps``.

    The schedule repeats every ``t_range = int(n_steps / n_cycle)`` steps (clamped to at
    least 1). Let ``split_range = int(t_range * (1/3))`` and let ``loc`` be the step's
    position inside its cycle. Each cycle then has three phases:

    * ``loc < split_range``: linear ramp ``start + loc / split_range * (stop - start)``.
    * ``split_range <= loc < 2 * split_range``: held at ``stop``.
    * the rest of the cycle: 0. Cycles shorter than 3 steps are therefore all 0.

    NOTE(review): the last phase drops beta to 0 rather than holding it; confirm intended.

    Returns:
        float64 ``np.ndarray`` of shape ``(n_steps,)``.
    """
    # Clamp so that n_steps < n_cycle no longer raises ZeroDivisionError in the modulo.
    t_range = max(int(n_steps / n_cycle), 1)
    split_range = int(t_range * (1 / 3))  # loop-invariant: computed once per call
    loc = np.arange(n_steps) % t_range
    schedule = np.zeros(n_steps)
    if split_range > 0:
        ramp = loc < split_range
        schedule[ramp] = start + (loc[ramp] / split_range) * (stop - start)
        schedule[(loc >= split_range) & (loc < 2 * split_range)] = stop
    return schedule
# Visualise the cyclic schedule: 1000 steps, 10 cycles, ramp from 0.5 to 1.
x = get_cyclic_linear_beta_list(1000, 0.5, 1, 10)
fig, ax = plt.subplots()
ax.plot(x)
ax.set(title="Cyclic linear beta schedule (1000 steps, 10 cycles)", xlabel="step", ylabel="beta")
plt.show()

In [None]:
import torch
import numpy as np 
from torch.nn.utils.rnn import pad_sequence
# a = torch.randn(size=(3,64)).to(0)
# b = torch.randn(size=(5,64)).to(0)
# c = torch.randn(size=(4,32)).to(0)
# d = torch.cat((a[0],b[0],c[2]),dim=0)
# e = torch.cat((a[1],b[1],c[1]),dim=0)
# f torch.sum(torch.stack((d,e), dim=0),dim=1)/torch.tensor([12,24],device=0)
def word_drop(x, p, pad_token, vocab_size):
    """Randomly corrupt a 1-D sequence of token ids (word dropout).

    Every position except the first is independently selected for corruption with
    probability ``p``. A selected token is replaced by ``pad_token`` when
    ``np.random.rand() > .5``, and otherwise by a random id in ``[0, vocab_size)``.
    All randomness comes from NumPy's global RNG, so results follow ``np.random.seed``.

    Args:
        x: 1-D tensor of token ids.
        p: probability of corrupting each position after the first.
        pad_token: id used for the "pad" replacement.
        vocab_size: exclusive upper bound for random replacement ids.

    Returns:
        A new ``torch.long`` tensor on ``x.device`` with the same length as ``x``.
    """
    words = x.tolist()
    if not words:
        # Nothing to corrupt; `keep[0]` below would raise IndexError on an empty sequence.
        return torch.tensor([], dtype=torch.long, device=x.device)
    keep = np.random.rand(len(words)) > p
    keep[0] = True  # never drop the start-of-sentence symbol at position 0
    corrupted = []
    # Same RNG call order as a per-position loop: one rand() per dropped position,
    # followed by randint() only when that rand() is <= .5.
    for w, kept in zip(words, keep):
        if kept:
            corrupted.append(w)
        elif np.random.rand() > .5:
            corrupted.append(pad_token)
        else:
            corrupted.append(np.random.randint(0, vocab_size))
    return torch.tensor(corrupted, dtype=torch.long, device=x.device)
# Demo: 30 random ids in [0, 1000); pad id 10001 lies outside the replacement range [0, 10000).
x = torch.randint(0, 1000, size=(1,30))
word_drop(x[0], 0.2, 10001, 10000)

In [None]:
import os
from transformers import BertTokenizer

# Tokenizer from a local gpt2-base checkpoint directory. Set GPT2_TOKENIZER_PATH to point
# elsewhere instead of editing this machine-specific absolute default.
tok = BertTokenizer.from_pretrained(
    os.environ.get("GPT2_TOKENIZER_PATH", "/cognitive_comp/wanghao/models/gpt2-base"))