In [1]:
from transformers import GemmaForCausalLM, GemmaTokenizer
from transformers import GPT2LMHeadModel
from transformers import LlamaForCausalLM
import torch
import gc
import numpy as np
import random
import torch.nn.functional as F

# Gemma-7b

In [2]:
gemma = GemmaForCausalLM.from_pretrained(
    "google/gemma-7b",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    output_attentions=True,
)
tokenizer = GemmaTokenizer.from_pretrained("google/gemma-7b")
configuration = gemma.config

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## add BOS

In [31]:
np.random.seed(2024)

T0 = 10
rep = 3

for _ in range(10):
    segment = np.random.randint(low=0, high=configuration.vocab_size, size=T0).tolist()

    input_ids = np.concatenate([segment for _ in range(3)])
    input_ids = np.concatenate([[2], input_ids])  ############ ADD BOS Here
    input_ids = torch.Tensor(input_ids).long().unsqueeze(0).cuda()
    with torch.no_grad():
        logits = gemma(input_ids).logits

    probs = F.softmax(logits.float(), dim=-1)
    top_prob, pred_next_token_ids = torch.topk(probs, dim=-1, k=1)

    correct_token_ids = input_ids[0, 1:]
    pred_token_ids = pred_next_token_ids[0, :-1, 0]
    T_range = range(T0 + T0 // 2, rep * T0)

    print(np.mean(((correct_token_ids == pred_token_ids).numpy(force=True))[T_range]))

0.8666666666666667
0.7333333333333333
1.0
1.0
1.0
0.8666666666666667
0.9333333333333333
0.9333333333333333
1.0
1.0


## no BOS

In [32]:
np.random.seed(2024)

T0 = 10
rep = 3

for _ in range(10):
    segment = np.random.randint(low=0, high=configuration.vocab_size, size=T0).tolist()

    input_ids = np.concatenate([segment for _ in range(rep)])
    # input_ids = np.concatenate([[2], input_ids]) ########## BOS Commented out
    input_ids = torch.Tensor(input_ids).long().unsqueeze(0).cuda()
    with torch.no_grad():
        logits = gemma(input_ids).logits

    probs = F.softmax(logits.float(), dim=-1)
    top_prob, pred_next_token_ids = torch.topk(probs, dim=-1, k=1)

    correct_token_ids = input_ids[0, 1:]
    pred_token_ids = pred_next_token_ids[0, :-1, 0]
    T_range = range(T0 + T0 // 2, rep * T0 - 1)

    print(np.mean(((correct_token_ids == pred_token_ids).numpy(force=True))[T_range]))

0.14285714285714285
0.35714285714285715
0.0
0.5
0.35714285714285715
0.2857142857142857
0.21428571428571427
0.07142857142857142
0.14285714285714285
0.2857142857142857


# Llama2-7b

In [2]:
llama = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="cuda",
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)
model = llama.model
configuration = llama.config

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
batch_size = 1
T0 = 10
T_cnt = T0 + T0 // 2
rep = 3
vocab_size = configuration.vocab_size

np.random.seed(2024)
sample_int = np.random.randint(low=0, high=vocab_size, size=batch_size * T0).reshape(
    batch_size, T0
)
sample_int = np.concatenate(tuple([sample_int] * rep), axis=1)
input_ids = torch.Tensor(sample_int).long().cuda()
correct_next_token_ids = (
    torch.Tensor(np.concatenate((sample_int[:, 1:], sample_int[:, :1]), axis=1))
    .long()
    .cuda()
)

with torch.no_grad():
    out = llama(input_ids)

logits = out.logits
attentions = out.attentions

probs = F.softmax(logits.float(), dim=-1)
top_prob, pred_next_token_ids = torch.topk(probs, dim=-1, k=1)

print(
    (pred_next_token_ids[0, T_cnt:-1, 0] == correct_next_token_ids[:, T_cnt:-1])
    .to(torch.float)
    .mean()
    .item(),
)

0.9285714626312256


In [26]:
# np.random.seed(2024)

T0 = 10
rep = 3

for _ in range(10):
    segment = np.random.randint(low=0, high=configuration.vocab_size, size=T0).tolist()

    input_ids = np.concatenate([segment for _ in range(rep)])
    # input_ids = np.concatenate([[1], input_ids])  ########## BOS Commented out
    input_ids = torch.Tensor(input_ids).long().unsqueeze(0).cuda()
    with torch.no_grad():
        logits = llama(input_ids).logits

    probs = F.softmax(logits.float(), dim=-1)
    top_prob, pred_next_token_ids = torch.topk(probs, dim=-1, k=1)

    correct_token_ids = input_ids[0, 1:]
    pred_token_ids = pred_next_token_ids[0, :-1, 0]
    T_range = range(T0 + T0 // 2, rep * T0 - 1)

    print(np.mean(((correct_token_ids == pred_token_ids).numpy(force=True))[T_range]))

0.8571428571428571
0.8571428571428571
0.7857142857142857
1.0
0.7142857142857143
1.0
1.0
1.0
0.8571428571428571
0.7857142857142857


: 

# GPT-2

In [2]:
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
configuration = gpt2.config

In [3]:
batch_size = 1
T0 = 20
T_cnt = T0 + T0 // 2
rep = 3
vocab_size = configuration.vocab_size

np.random.seed(2024)
sample_int = np.random.randint(low=0, high=vocab_size, size=batch_size * T0).reshape(
    batch_size, T0
)
sample_int = np.concatenate(tuple([sample_int] * rep), axis=1)
input_ids = torch.Tensor(sample_int).long()
correct_next_token_ids = torch.Tensor(
    np.concatenate((sample_int[:, 1:], sample_int[:, :1]), axis=1)
).long()

with torch.no_grad():
    logits = gpt2(input_ids).logits

probs = F.softmax(logits.float(), dim=-1)
top_prob, pred_next_token_ids = torch.topk(probs, dim=-1, k=1)

print(
    (pred_next_token_ids[0, T_cnt:-1, 0] == correct_next_token_ids[:, T_cnt:-1])
    .to(torch.float)
    .mean()
    .item(),
)

1.0
