# **Surprisal**


Some references:
- https://www.kaggle.com/code/smidgin/surprisal-with-gpt-2
- https://pypi.org/project/pysurprisal/
- https://github.com/aalok-sathe/surprisal
- https://github.com/byungdoh/slm_surprisal
- https://github.com/byungdoh/llm_surprisal
- https://github.com/tmalsburg/llm_surprisal
- https://github.com/benedict-krieger/llm-surprisal-rerps

Where I learned:
- https://huggingface.co/learn/llm-course/it/




First, we load the model.

- ***clean_up_tokenization_spaces*** = tells the tokenizer to "clean up" spaces when it decodes tokens back into text. It fixes tokenization artifacts such as a space before punctuation. It can be set to _False_ if you want to analyze the decoded text exactly as it comes out. It only affects decoding (e.g. `tokenizer.decode`), not `convert_ids_to_tokens`.

- ***from IPython.display import clear_output; clear_output()*** = clears the output of a cell in a Jupyter/IPython notebook. It is often called right after loading the model to remove logs, warnings, and progress bars, leaving the notebook cleaner and easier to read. In a plain Python script (not a notebook) it has no useful effect and is not needed.
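To illustrate what the space cleanup does, here is a hypothetical pure-Python approximation of a decode-time rule that removes the space left before punctuation. `cleanup_spaces` is an illustrative name, not the `transformers` implementation, whose exact rule list (which also covers some English contractions) may differ.

```python
def cleanup_spaces(text: str) -> str:
    """Hypothetical approximation of clean_up_tokenization_spaces=True:
    drop the space that precedes common punctuation marks."""
    for mark in (".", ",", "!", "?"):
        text = text.replace(" " + mark, mark)
    return text


raw = "Ciao , come stai ?"   # decoded text with tokenization artifacts
print(cleanup_spaces(raw))   # Ciao, come stai?
print(raw)                   # with the option set to False, the text stays as-is
```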


In [None]:
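import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import clear_output

# Use the GPU if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "GroNLP/gpt2-medium-italian-embeddings"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()  # inference only: disables dropout
tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

clear_output()  # remove download logs and progress bars from the cell output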


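Next, we compute the per-token metrics: the surprisal of each token, together with the token strings, the attention mask, and the sequence lengths.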
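The main metric is surprisal: the negative log-probability the model assigns to each token given the tokens before it. Since `torch.log_softmax` uses the natural logarithm, the values come out in nats:

$$
s(w_t) = -\ln p(w_t \mid w_1, \dots, w_{t-1}), \qquad t = 2, \dots, T
$$

Assuming the tokenizer does not prepend a BOS token (the GPT-2 default), the first token $w_1$ has no left context, so its surprisal is left undefined (NaN). Dividing by $\ln 2$ converts nats to bits: $s_{\text{bits}}(w_t) = s(w_t) / \ln 2$.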

In [None]:
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

def get_token_metrics(texts: list[str], model: PreTrainedModel,
                      tokenizer: PreTrainedTokenizerBase, truncation: bool = False) -> dict:
    # GPT-2 has no padding token: reuse EOS so that batches can be padded
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    B = len(texts)
    device = next(model.parameters()).device

    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=truncation).to(device)
    if device.type == "cuda":
        torch.cuda.empty_cache()  # clear any leftover GPU memory

    with torch.no_grad(), torch.amp.autocast(device.type):
        outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

    # Logits at position t predict token t+1, so drop the last position
    logits = outputs.logits[:, :-1].float()
    log_probs = torch.log_softmax(logits, dim=-1)

    # Targets: every token except the first, with padding masked out
    next_tokens = inputs.input_ids[:, 1:]
    next_mask = inputs.attention_mask[:, 1:]

    # Log-probability of each actual next token: [B, T-1]
    token_log_probs = log_probs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)

    # Surprisal in nats; padding positions are set to 0
    surprisals = -token_log_probs * next_mask
    # The first token has no left context: prepend NaN so the shape is [B, T]
    null_first = torch.full((B, 1), float("nan"), device=device)
    surprisals = torch.cat((null_first, surprisals), dim=1)

    # Number of real (non-padding) tokens per sequence
    seq_len = inputs.attention_mask.sum(dim=1)

    # Token strings without padding (GPT-2 pads on the right)
    tokens = [
        tokenizer.convert_ids_to_tokens(seq[:n].tolist())
        for seq, n in zip(inputs.input_ids, seq_len.tolist())
    ]
    assert all(len(s) == n for s, n in zip(tokens, seq_len.tolist()))

    return {
        "tok_str": tokens,                        # list[list[str]], jagged shape [B, Tb<=T]
        "tok_surp": surprisals.cpu(),             # tensor[B, T]
        "tok_attn": inputs.attention_mask.cpu(),  # tensor[B, T]
        "seq_len": seq_len.cpu(),                 # tensor[B]
        "vocab_size": len(tokenizer),             # int
    }
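To see the index alignment in isolation, here is a minimal pure-Python sketch of the shift-by-one logic in `get_token_metrics`, using hypothetical toy logits instead of a real model: the distribution predicted at position t-1 scores token t, the first token gets NaN, and padding positions (mask 0) get 0. The names `log_softmax` and `toy_surprisals` and all values are illustrative only.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_surprisals(logits: list[list[float]], ids: list[int], mask: list[int]) -> list[float]:
    """Hypothetical toy version of the surprisal step, in nats.

    logits[t] is the prediction made after reading ids[:t+1],
    so it scores ids[t+1]."""
    out = [float("nan")]  # first token: no left context
    for t in range(1, len(ids)):
        lp = log_softmax(logits[t - 1])
        out.append(-lp[ids[t]] * mask[t])  # padding (mask 0) -> 0.0
    return out


# Vocabulary of 2 tokens: uniform predictions give ln 2 nats (1 bit) per token
s = toy_surprisals([[0.0, 0.0], [0.0, 0.0]], ids=[0, 1, 0], mask=[1, 1, 1])
print(s)                                  # [nan, 0.693..., 0.693...]
print([x / math.log(2) for x in s[1:]])  # [1.0, 1.0]
```

Keeping the NaN slot for the first token means every surprisal stays aligned with its token string, which is why the notebook's function prepends a NaN column rather than returning a `[B, T-1]` tensor.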
