In [1]:
# hide_output
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
text = "Adilbek sucks at "
text_tokenized = tokenizer(text)
text_tokenized

{'input_ids': [2782, 346, 47083, 22523, 379, 220], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [3]:
tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"])

['Ad', 'il', 'bek', 'Ġsucks', 'Ġat', 'Ġ']

In [6]:
tokenizer.decode(text_tokenized["input_ids"])

'Adilbek sucks at '

In [27]:
import pandas as pd

input_txt = "Adilbek sucks at"
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
iterations = []
n_steps = 3
choices_per_step = 3


In [28]:
len(input_ids[0])

5

In [29]:
iterations

[]

In [None]:
with torch.no_grad():
    for _ in range(n_steps):
        iteration = dict()
        iteration["Input"] = tokenizer.decode(input_ids[0])
        output = model(input_ids=input_ids) # (1, 5, 50257)
        print(f"{output.logits.shape=}")
        # Select logits of the first batch and the last token and apply softmax
        next_token_logits = output.logits[0, -1, :] # (50257)
        print(f"{next_token_logits.shape=}")
        next_token_probs = torch.softmax(next_token_logits, dim=-1) # (50257)
        print(f"{next_token_probs.shape=}")
        sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True) # (50257)
        
        print(f"{sorted_ids.shape=}")
        # Store tokens with highest probabilities
        print("-"*40)
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            print(f"{token_id=}")
            token_prob = next_token_probs[token_id].cpu().numpy()
            token_choice = (
                f"{tokenizer.decode(token_id)} ({100 * token_prob:.2f}%)"
            )
            iteration[f"Choice {choice_idx+1}"] = token_choice
        # Append predicted next token to input
        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
        iterations.append(iteration)
        

output.logits.shape=torch.Size([1, 5, 50257])
next_token_logits.shape=torch.Size([50257])
next_token_probs.shape=torch.Size([50257])
sorted_ids.shape=torch.Size([50257])
----------------------------------------
token_id=tensor(10688)
token_id=tensor(465)
token_id=tensor(262)
output.logits.shape=torch.Size([1, 6, 50257])
next_token_logits.shape=torch.Size([50257])
next_token_probs.shape=torch.Size([50257])
sorted_ids.shape=torch.Size([50257])
----------------------------------------
token_id=tensor(13)
token_id=tensor(11)
token_id=tensor(290)
output.logits.shape=torch.Size([1, 7, 50257])
next_token_logits.shape=torch.Size([50257])
next_token_probs.shape=torch.Size([50257])
sorted_ids.shape=torch.Size([50257])
----------------------------------------
token_id=tensor(198)
token_id=tensor(679)
token_id=tensor(314)


In [34]:
next_token_logits

tensor([ 0.1078,  1.4739, -0.4956,  ..., -6.9666, -6.6467,  6.6501])

In [35]:
next_token_probs

tensor([1.1343e-05, 4.4460e-05, 6.2035e-06,  ..., 9.6011e-09, 1.3220e-08,
        7.8702e-03])

In [33]:
sorted_ids

tensor([  198,   679,   314,  ..., 45544,   216,   182])

In [36]:
next_token_probs[198]

tensor(0.1821)

In [37]:
next_token_probs[182]

tensor(1.6364e-16)

In [31]:
pd.DataFrame(iterations)

Unnamed: 0,Input,Choice 1,Choice 2,Choice 3
0,Adilbek sucks at,math (3.49%),his (3.46%),the (3.36%)
1,Adilbek sucks at math,. (34.40%),", (23.77%)",and (10.39%)
2,Adilbek sucks at math.,\n (18.21%),He (16.77%),I (3.04%)


In [32]:
tokenizer.convert_ids_to_tokens(input_ids[0])

['Ad', 'il', 'bek', 'Ġsucks', 'Ġat', 'Ġmath', '.', 'Ċ']

In [38]:
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output = model.generate(input_ids, max_new_tokens=n_steps, do_sample=False)
print(tokenizer.decode(output[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Adilbek sucks at math.



In [40]:
print(tokenizer.decode(output[0]))

Adilbek sucks at math.



In [72]:

input_txt = "Adilbek sucks at"
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
iterations = []
n_steps = 10
choices_per_step = 3

print(input_txt)
with torch.no_grad():
    for _ in range(n_steps):
        output = model(input_ids=input_ids) # (1, 5, 50257)
        # Select logits of the first batch and the last token and apply softmax
        next_token_logits = output.logits[0, -1, :] # (50257)
        next_token_probs = torch.softmax(next_token_logits, dim=-1) # (50257)
        sorted_id = torch.argmax(next_token_probs, dim=-1).unsqueeze(-1).unsqueeze(-1) # (50257)
        
        input_ids = torch.cat([input_ids, sorted_id], dim=-1)
        print(tokenizer.decode(input_ids[0]))

Adilbek sucks at
Adilbek sucks at math
Adilbek sucks at math.
Adilbek sucks at math.

Adilbek sucks at math.


Adilbek sucks at math.

"
Adilbek sucks at math.

"I
Adilbek sucks at math.

"I'm
Adilbek sucks at math.

"I'm not
Adilbek sucks at math.

"I'm not good
Adilbek sucks at math.

"I'm not good at


In [62]:
next_token_probs

tensor([2.6458e-05, 2.3018e-05, 3.1340e-07,  ..., 4.1718e-07, 3.0811e-08,
        1.8454e-05])

In [65]:
sorted_id = torch.argmax(next_token_probs, dim=-1) # (50257)
sorted_id

tensor(10688)

In [66]:
sorted_id = sorted_id.unsqueeze(-1)  # Add dimension at the end
sorted_id

tensor([10688])

In [67]:
sorted_id = sorted_id.unsqueeze(-1)  # Add dimension at the end
sorted_id

tensor([[10688]])

In [57]:
input_ids

tensor([[ 2782,   346, 47083, 22523,   379]])

In [71]:
tokenizer.decode(torch.cat([input_ids, sorted_id], dim=-1)[0])


'Adilbek sucks at math math'

In [73]:
input_txt = "5 + 5 = "
n_steps = 2
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output = model.generate(input_ids, max_new_tokens=n_steps, do_sample=False)
print(tokenizer.decode(output[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


5 + 5 = ????

The first thing to note is that


In [74]:
input_txt = "5 + 8 => 13 \n 7 + 2 => 9 \n 1 + 0 =>"

n_steps = 3

input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output = model.generate(input_ids, max_new_tokens=n_steps, do_sample=False)

print('-'*40)
print(tokenizer.decode(output[0]))


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


----------------------------------------
5 + 8 => 13 
 7 + 2 => 9 
 1 + 0 => 1 



In [75]:
import torch.nn.functional as F

def log_probs_from_logits(logits, labels):
    logp = F.log_softmax(logits, dim=-1)
    logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logp_label


In [None]:
def sequence_logprob(model, labels, input_len=0):
    with torch.no_grad():
        output = model(labels)
        log_probs = log_probs_from_logits(
            output.logits[:, :-1, :], labels[:, 1:]
        )
        seq_log_prob = torch.sum(log_probs[:, input_len:])
    return seq_log_prob.cpu().numpy()
     

In [77]:

output


tensor([[  20, 1343,  807, 5218, 1511,  220,  198,  767, 1343,  362, 5218,  860,
          220,  198,  352, 1343,  657, 5218,  352,  220,  198]])

In [78]:
logp = sequence_logprob(model, output, input_len=len(input_ids[0]))
print(tokenizer.decode(output[0]))
print(f"\nlog-prob: {logp:.2f}")

5 + 8 => 13 
 7 + 2 => 9 
 1 + 0 => 1 


log-prob: -0.42


In [80]:
input_txt = "Guide to find a girlfriend\n"
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)

In [84]:
max_tokens = 50

In [85]:
output_beam = model.generate(input_ids, max_length=max_tokens, num_beams=3, 
                             do_sample=False)
logp = sequence_logprob(model, output_beam, input_len=len(input_ids[0]))
print(tokenizer.decode(output_beam[0]))
print(f"\nlog-prob: {logp:.2f}")
     


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Guide to find a girlfriend

How to find a girlfriend

How to find a girlfriend

How to find a girlfriend

How to find a girlfriend

How to find a girlfriend

How to find a girlfriend

How

log-prob: -12.93


In [89]:
from transformers import TextStreamer

streamer = TextStreamer(tokenizer=tokenizer)

output_beam = model.generate(input_ids, max_length=max_tokens, num_beams=1, 
                             do_sample=False, no_repeat_ngram_size=2, streamer=streamer)
logp = sequence_logprob(model, output_beam, input_len=len(input_ids[0]))
print(tokenizer.decode(output_beam[0]))
print(f"\nlog-prob: {logp:.2f}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Guide to find a girlfriend

How to get a girl to like you
 the best way to make a woman fall in love with you<|endoftext|>
Guide to find a girlfriend

How to get a girl to like you
 the best way to make a woman fall in love with you<|endoftext|>

log-prob: -38.28


In [87]:
output_beam

tensor([[47889,   284,  1064,   257, 11077,   198,   198,  2437,   284,  1064,
           257, 11077,   198,   198,  2437,   284,  1064,   257, 11077,   198,
           198,  2437,   284,  1064,   257, 11077,   198,   198,  2437,   284,
          1064,   257, 11077,   198,   198,  2437,   284,  1064,   257, 11077,
           198,   198,  2437,   284,  1064,   257, 11077,   198,   198,  2437]])