In [1]:
from sampling import naive_sampling
from utils import set_seed, create_history, create_model_kwargs

from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [4]:
# User definitions
model_name = "gpt2"
seed = 42

num_samples = 200
input_str = "Hello! Nice to"
# NOTE(review): the avoid terms are tokenized on their own, without a leading
# space. GPT-2's BPE is whitespace-sensitive, so the id obtained for "meet" here
# may differ from the id of " meet" as it appears mid-sentence after the prompt
# -- confirm that this is the intended avoid set.
avoid_terms = "meet you"

# ==========================================================
# Load models
tokenizer = GPT2Tokenizer.from_pretrained(model_name, model_max_length=512)
# The tokenizer's EOS id is handed to the model as its pad_token_id.
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)

# Parse input and set seeds for reproducibility
set_seed(seed)

# Fall back to the decoder start token only when the tokenizer defines no BOS
# token at all. An explicit `is None` test is required: a truthiness test
# (`a or b`) would wrongly discard a perfectly valid BOS id of 0.
bos_token_id = tokenizer.bos_token_id
if bos_token_id is None:
    bos_token_id = model.config.decoder_start_token_id

# The prompt is encoded as a "pt" tensor; the avoid terms as a plain list of ids.
# No special tokens are added by the tokenizer in either case.
input_ids = tokenizer(input_str, return_tensors="pt", add_special_tokens=False).input_ids
avoid_terms_ids = tokenizer(avoid_terms, add_special_tokens=False).input_ids

# History (or past observations) and model_kwargs will be the same for all queries
history = create_history(num_samples, input_ids, bos_token_id)

# Call Naive Sampling
# Yields two statistics (reported below as "Freq" and "Var") together with the
# sampled sequences. `max_num_tokens` presumably caps the number of tokens
# generated per sample -- TODO confirm against `sampling.naive_sampling`.
mean, var, samples = naive_sampling(
    avoid_term_ids=avoid_terms_ids,
    **create_model_kwargs(history, model, tokenizer),
    max_num_tokens=5,
    model=model,
    tokenizer=tokenizer,
)

print("Freq:", mean, "Var:", var)
print("Produced samples")
# Only every 5th decoded sample is printed to keep the output short.
print("\n".join(tokenizer.batch_decode(samples)[::5]))

Freq: 0.1499999761581421 Var: 0.12814070284366608
Produced samples
Hello! Nice to see they have Burger King
Hello! Nice to meet you! Good night
Hello! Nice to see you comfortable and amazed
Hello! Nice to meet you. Whats up
Hello! Nice to see you, Izuna
Hello! Nice to do this to you guys
Hello! Nice to see you playing batting cages
Hello! Nice to meet you all! Aunt
Hello! Nice to meet you far from soon
Hello! Nice to meet you-" She said
Hello! Nice to meet you me!


Hello! Nice to greet you guys, now
Hello! Nice to be here (very nice
Hello! Nice to meet you too!

Hello! Nice to meet you Miss. Not
Hello! Nice to be with you connection!"
Hello! Nice to meet you!



Hello! Nice to hear from you! Oh
Hello! Nice to meet you! Welcome to
Hello! Nice to see you many wares
Hello! Nice to see you!" She laughed
Hello! Nice to meet you! But I
Hello! Nice to meet you, Vidaho
Hello! Nice to meet you Schröd
Hello! Nice to meet you!" AnAg
Hello! Nice to meet you guys!

Hello! Nice to meet you!"


Hell

In [7]:
# Empirical sanity check: flag every sample that contains at least one token id
# from the avoid set, then tally the flagged samples (each sample counts once).
contains_avoid_term = [
    any(token_id in avoid_terms_ids for token_id in sample)
    for sample in samples
]
counts = sum(contains_avoid_term)

num_sequences = len(samples)
hit_frequency = counts / num_sequences

print("Total #(occur set A):", counts, f"(out of {num_sequences})")
print("Frequency #(occur set A) in samples", hit_frequency)
print("Proba of not occurring:", 1 - hit_frequency)

Total #(occur set A): 170 (out of 200)
Frequency #(occur set A) in samples 0.85
Proba of not occurring: 0.15000000000000002


In [None]:
# Inspect the tokenizer's pad token. Only the model was given a pad id (the EOS id)
# when it was loaded; no pad token is assigned to the tokenizer in the cells shown here.
tokenizer.pad_token