# Testing Approaches to Get Log Probs of Output Sequences

The goal of this notebook is to explore ways of using the log-probability output of a language model, compared against a target sequence, to estimate the importance of a given context.


## What are we using here?

* Model: GPT2, a lightweight baseline that lets us quickly inspect the outputs and generations
* Methods: generate() - given a sequence of input tokens, it returns the generated sequences and, if requested, the scores and logits
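
The logits returned by `generate()` are unnormalized, so they have to pass through a log-softmax before they can be read as log probabilities. The sketch below shows that step on a made-up 4-token vocabulary. It uses only the standard library so the arithmetic stays visible; the helper `log_softmax` and the values in `toy_logits` are illustrative assumptions, and in the notebook itself `torch.nn.functional.log_softmax` does the same job.

```python
import math

def log_softmax(logits):
    """Convert a list of raw logits into log probabilities."""
    # Subtract the maximum first so exp() cannot overflow.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]

# Toy 4-token vocabulary; the values are made up for illustration.
toy_logits = [-88.5, -90.9, -98.3, -94.2]
toy_log_probs = log_softmax(toy_logits)

# Exponentiating the log probabilities gives a distribution that sums to 1.
total_prob = sum(math.exp(lp) for lp in toy_log_probs)
print(toy_log_probs)
print(total_prob)
```

Log-softmax only subtracts a constant from every logit, so the ranking of tokens and the gaps between them are unchanged.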

In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


In [5]:
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. Initial Visualization of the outputs

In [6]:

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")



In [7]:
input_ids = tokenizer("Today is a nice day", return_tensors="pt").input_ids
generated_outputs = gpt2.generate(input_ids, 
                                  output_scores=True, 
                                  length_penalty=0, 
                                  output_logits=True,
                                  return_dict_in_generate=True,
                                  max_new_tokens=3,
                                  num_beams=2,
                                 )

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [8]:
type(generated_outputs)

transformers.generation.utils.GenerateBeamDecoderOnlyOutput

In [9]:
generated_outputs.keys()

odict_keys(['sequences', 'sequences_scores', 'scores', 'logits', 'beam_indices', 'past_key_values'])

In [10]:
generated_outputs.sequences

tensor([[8888,  318,  257, 3621, 1110,   11,  475,  314]])

In [11]:
len(generated_outputs.scores)

3

In [12]:
generated_outputs.scores[0]

tensor([[ -3.8570,  -6.2563, -13.6480,  ..., -20.5584, -17.1551,  -9.5375],
        [ -3.8570,  -6.2563, -13.6480,  ..., -20.5584, -17.1550,  -9.5375]])

In [13]:
len(generated_outputs.logits)

3

In [33]:
generated_outputs.logits[0]

tensor([[ -88.5250,  -90.9243,  -98.3160,  ..., -105.2264, -101.8231,
          -94.2055],
        [ -88.5250,  -90.9243,  -98.3160,  ..., -105.2264, -101.8231,
          -94.2055]])

In [15]:
generated_text = tokenizer.decode(generated_outputs.sequences[0])
print(generated_text)

Today is a nice day, but I


## 2. Increment Context Approach

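In this approach, the log prob is calculated for each expected target token in turn: the model predicts one token, the log prob of the corresponding target token is read from that prediction, and the target token is then appended to the context before the next prediction is made.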

Link: https://discuss.huggingface.co/t/compute-log-probabilities-of-any-sequence-provided/11710/16
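
As a sketch of the quantity being accumulated, the sum of per-token log probs can be written as follows. The notation is introduced here for illustration and is not part of the original notebook: $x$ is the context, $y_1, \dots, y_T$ are the target tokens, $z_t$ is the logit vector at step $t$, and $V$ is the vocabulary.

$$
\log P(y_1, \dots, y_T \mid x) = \sum_{t=1}^{T} \log P(y_t \mid x, y_{<t}),
\qquad
\log P(y_t \mid x, y_{<t}) = z_t[y_t] - \log \sum_{v \in V} \exp\left(z_t[v]\right)
$$

The second term is the same for every token at a given step. This can be checked against the first step of the printed output: for the target token `,` we get $-86.5472 - (-1.8792) = -84.6680$, and for the predicted token ` for` we get $-86.1828 - (-1.5148) = -84.6680$.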

In [52]:
def increment_context(input_tokens, target_output, model, tokenizer):
    """Print the log prob of each target token, feeding the target back one token at a time."""

    log_sum = 0
    input_tokens_updated = input_tokens.clone().to(torch.int64).to(device)

    for i in range(len(target_output)):
        # Predict with the given model
        with torch.no_grad():
            outputs = model.generate(
                input_tokens_updated,
                max_new_tokens=1,
                output_logits=True,
                return_dict_in_generate=True,
                pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS (50256)
            )
            logit_predictions = outputs.logits[0]
    
        # Extract the log probability of the output token
        token = tokenizer.decode(target_output[i])
        log_probs = torch.nn.functional.log_softmax(logit_predictions, dim=-1)
        out_token_logit = logit_predictions[0, target_output[i]]
        out_token_log_prob = log_probs[0, target_output[i]]
        log_sum += out_token_log_prob
        print(f"Token: {token}, logit: {out_token_logit}, log prob: {out_token_log_prob}")


        predicted_token = tokenizer.decode(outputs.sequences[0][-1])
        predicted_logit = logit_predictions[0, outputs.sequences[0][-1]]
        predicted_log_prob = log_probs[0, outputs.sequences[0][-1]]
        print(f"Predicted Token: {predicted_token}, logit: {predicted_logit}, log prob: {predicted_log_prob}")
    
        # Incrementally add an output token to the current sequence
        input_tokens_updated = torch.cat([input_tokens_updated, target_output[i].reshape(1, 1)], dim=1)
        print([tokenizer.decode(token) for token in input_tokens_updated])
        print("============")
        print()
    print(f"Total Log Sum Probability: {log_sum}")

In [47]:
input_tokens = tokenizer.encode("Today is a nice day", add_special_tokens=False, return_tensors="pt").to(device)
target_output = tokenizer.encode(", and tomorrow it will be as well", add_special_tokens=False, return_tensors="pt")[0].to(device)


In [48]:
input_tokens

tensor([[8888,  318,  257, 3621, 1110]], device='cuda:0')

In [49]:
target_output

tensor([  11,  290, 9439,  340,  481,  307,  355,  880], device='cuda:0')

In [54]:
increment_context(input_tokens, target_output, gpt2.to(device), tokenizer)

Token: ,, logit: -86.54722595214844, log prob: -1.8792250156402588
Predicted Token:  for, logit: -86.18283081054688, log prob: -1.5148298740386963
['Today is a nice day,']

Token:  and, logit: -98.12623596191406, log prob: -2.3308000564575195
Predicted Token:  but, logit: -97.56433868408203, log prob: -1.7689028978347778
['Today is a nice day, and']

Token:  tomorrow, logit: -120.85678100585938, log prob: -5.777646064758301
Predicted Token:  I, logit: -116.57794189453125, log prob: -1.4988069534301758
['Today is a nice day, and tomorrow']

Token:  it, logit: -85.92831420898438, log prob: -3.039114236831665
Predicted Token:  is, logit: -83.59170532226562, log prob: -0.7025054097175598
['Today is a nice day, and tomorrow it']

Token:  will, logit: -71.38861083984375, log prob: -1.4177868366241455
Predicted Token: 's, logit: -70.89265441894531, log prob: -0.921830415725708
['Today is a nice day, and tomorrow it will']

Token:  be, logit: -89.09539794921875, log prob: -0.24227555096149445
