In [16]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from pathlib import Path
from einops import rearrange, repeat
import os
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [17]:
# Load pre-trained model and tokenizer
model_name = "huggyllama/llama-13b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Put model in evaluation mode
model.eval()

Loading checkpoint shards: 100%|██████████| 3/3 [05:46<00:00, 115.57s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120, padding_idx=0)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [66]:
def get_prob(sequences):
    
    for input_sequence in sequences:

        # Encode input sequence
        input_ids = tokenizer.encode(input_sequence, return_tensors="pt")

        vocab = tokenizer.get_vocab()

        # Generate hidden states
        outputs = model(input_ids, output_hidden_states=True)
        
        # Map final hidden state to vocabulary size using output embeddings
        logits = outputs.logits

        # Get probability distribution over vocabulary for next token
        next_token_probs = F.softmax(logits, dim=-1)[0, -1].tolist()

        yes_prob = next_token_probs[vocab['▁Yes']] + next_token_probs[vocab['▁yes']] + next_token_probs[vocab['▁YES']] + next_token_probs[vocab['▁TRUE']] + next_token_probs[vocab['▁true']]
        no_prob = next_token_probs[vocab['▁No']] + next_token_probs[vocab['▁no']] + next_token_probs[vocab['▁NO']] + next_token_probs[vocab['▁FALSE']] + next_token_probs[vocab['▁false']]

        yes_norm = yes_prob/(yes_prob + no_prob)
        no_norm = 1-yes_norm

        print(f"{input_sequence}\nYes {yes_norm:.2f}\nNo  {no_norm:.2f}\n------------------------------")

In [67]:
get_prob(["Is 2+5=7 right?"])

Is 2+5=7 right?
Yes 0.46
No  0.54
------------------------------


In [35]:
get_prob(["Is 2+3 = 5?"])

Is 2+3 = 5?
Yes 0.66
No  0.34
------------------------------


In [36]:
get_prob(["Is 2+3 = 4?"])

Is 2+3 = 4?
Yes 0.74
No  0.26
------------------------------


In [38]:
get_prob(["Is 2-3 = 4?"])

Is 2-3 = 4?
Yes 0.79
No  0.21
------------------------------


In [39]:
get_prob(["?"])

?
Yes 0.07
No  0.93
------------------------------


In [25]:
get_prob("Is 2+2 = 4?")

I
Yes 0.00
No  1.00
------------------------------
s
Yes 0.02
No  0.98
------------------------------
 
Yes 0.07
No  0.93
------------------------------
2
Yes 0.09
No  0.91
------------------------------
+
Yes 0.17
No  0.83
------------------------------
2
Yes 0.09
No  0.91
------------------------------
 
Yes 0.07
No  0.93
------------------------------
=
Yes 0.06
No  0.94
------------------------------
 
Yes 0.07
No  0.93
------------------------------
4
Yes 0.07
No  0.93
------------------------------
?
Yes 0.06
No  0.94
------------------------------


In [26]:
get_prob("Is 2+2 = 4?")

I
Yes 0.00
No  1.00
------------------------------
s
Yes 0.02
No  0.98
------------------------------
 
Yes 0.07
No  0.93
------------------------------
2
Yes 0.09
No  0.91
------------------------------
+
Yes 0.17
No  0.83
------------------------------
2
Yes 0.09
No  0.91
------------------------------
 
Yes 0.07
No  0.93
------------------------------
=
Yes 0.06
No  0.94
------------------------------
 
Yes 0.07
No  0.93
------------------------------
4
Yes 0.07
No  0.93
------------------------------
?
Yes 0.06
No  0.94
------------------------------


## Old

In [None]:
def get_prob(sequences):
    
    for input_sequence in sequences:

        # Encode input sequence
        input_ids = tokenizer.encode(input_sequence, return_tensors="pt")

        # Generate hidden states
        outputs = model(input_ids, output_hidden_states=True)
        # Get final hidden state of last token
        last_token_hidden_state = outputs.hidden_states[-1][:, -1, :]

        # Get the vocabulary from the tokenizer
        vocab = tokenizer.get_vocab()

        # Map final hidden state to vocabulary size using linear layer
        linear_layer = nn.Linear(model.config.hidden_size, len(vocab))
        logits = linear_layer(last_token_hidden_state)

        # Get probability distribution over vocabulary for next token
        next_token_probs = F.softmax(logits, dim=1)[0].tolist()

        yes_prob = next_token_probs[vocab['▁Yes']] + next_token_probs[vocab['▁yes']] + next_token_probs[vocab['▁YES']]
        no_prob = next_token_probs[vocab['▁No']] + next_token_probs[vocab['▁no']] + next_token_probs[vocab['▁NO']]

        yes_norm = yes_prob/(yes_prob + no_prob)
        no_norm = 1-yes_norm

        print(f"{input_sequence}\nYes {yes_norm:.2f}\nNo  {no_norm:.2f}\n------------------------------")

In [51]:
sequences = ['Is 1+1 = 2?',
             'Is 1+1 = 0?',
             'Is 1+1 = 10?',
             'Is 1+1 = 100?',
             'Is 1+1 = 105?',
             'Is 2+2 = 4?',
             'Is 2+2 = 0?',
             'Is 2+2 = 10?',
             'Is 2+2 = 100?',
             'Is 2+2 = 105?',
             'Is 145 + 132 = 277?',
             'Is 145 + 132 = 13',
             'Is 145 - 132 = 13',
             'Is 145 - 132 = 11'
             ]

get_prob(sequences)

Is 1+1 = 2?
Yes 0.52
No  0.48
------------------------------
Is 1+1 = 0?
Yes 0.55
No  0.45
------------------------------
Is 1+1 = 10?
Yes 0.49
No  0.51
------------------------------
Is 1+1 = 100?
Yes 0.59
No  0.41
------------------------------
Is 1+1 = 105?
Yes 0.34
No  0.66
------------------------------
Is 2+2 = 4?
Yes 0.51
No  0.49
------------------------------
Is 2+2 = 0?
Yes 0.43
No  0.57
------------------------------
Is 2+2 = 10?
Yes 0.37
No  0.63
------------------------------
Is 2+2 = 100?
Yes 0.59
No  0.41
------------------------------
Is 2+2 = 105?
Yes 0.72
No  0.28
------------------------------
Is 145 + 132 = 277?
Yes 0.61
No  0.39
------------------------------
Is 145 + 132 = 13
Yes 0.55
No  0.45
------------------------------
Is 145 - 132 = 13
Yes 0.51
No  0.49
------------------------------
Is 145 - 132 = 11
Yes 0.20
No  0.80
------------------------------


In [53]:
sequences = ['Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             'Is 1+1 = 2?',
             ]

get_prob(sequences)

Is 1+1 = 2?
Yes 0.50
No  0.50
------------------------------
Is 1+1 = 2?
Yes 0.48
No  0.52
------------------------------
Is 1+1 = 2?
Yes 0.18
No  0.82
------------------------------
Is 1+1 = 2?
Yes 0.37
No  0.63
------------------------------
Is 1+1 = 2?
Yes 0.48
No  0.52
------------------------------
Is 1+1 = 2?
Yes 0.39
No  0.61
------------------------------
Is 1+1 = 2?
Yes 0.69
No  0.31
------------------------------
Is 1+1 = 2?
Yes 0.32
No  0.68
------------------------------
Is 1+1 = 2?
Yes 0.25
No  0.75
------------------------------
Is 1+1 = 2?
Yes 0.23
No  0.77
------------------------------
Is 1+1 = 2?
Yes 0.24
No  0.76
------------------------------
Is 1+1 = 2?
Yes 0.81
No  0.19
------------------------------


In [54]:
sequences = ['Is 13+45 = 58?',
             'Is 13+45 = 58?',
             'Is 13+45 = 58?',
             'Is 13+45 = 58?',
             'Is 13+45 = 58?',
             'Is 13+45 = 58?',
             'Is 13+45 = 58?',
             'Is 13+45 = 58?',
             ]

get_prob(sequences)

Is 13+45 = 58?
Yes 0.50
No  0.50
------------------------------
Is 13+45 = 58?
Yes 0.59
No  0.41
------------------------------
Is 13+45 = 58?
Yes 0.67
No  0.33
------------------------------
Is 13+45 = 58?
Yes 0.36
No  0.64
------------------------------
Is 13+45 = 58?
Yes 0.45
No  0.55
------------------------------
Is 13+45 = 58?
Yes 0.45
No  0.55
------------------------------
Is 13+45 = 58?
Yes 0.39
No  0.61
------------------------------
Is 13+45 = 58?
Yes 0.56
No  0.44
------------------------------


In [55]:
sequences = ['Is 2+2 = 4?',
             'Is 2+2 = 4?',
             'Is 2+2 = 4?',
             'Is 2+2 = 4?',
             'Is 2+2 = 4?',
             'Is 2+2 = 4?',
             'Is 2+2 = 4?',
             'Is 2+2 = 4?',
             ]

get_prob(sequences)

Is 2+2 = 4?
Yes 0.68
No  0.32
------------------------------
Is 2+2 = 4?
Yes 0.47
No  0.53
------------------------------
Is 2+2 = 4?
Yes 0.72
No  0.28
------------------------------
Is 2+2 = 4?
Yes 0.65
No  0.35
------------------------------
Is 2+2 = 4?
Yes 0.23
No  0.77
------------------------------
Is 2+2 = 4?
Yes 0.47
No  0.53
------------------------------
Is 2+2 = 4?
Yes 0.25
No  0.75
------------------------------
Is 2+2 = 4?
Yes 0.65
No  0.35
------------------------------
