In [1]:
import torch
import torch.nn.functional as F

from numpy import random

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

Network

In [109]:
# Hugging Face checkpoint used for both the language model and its tokenizer.
model_name = "meta-llama/Llama-3.2-1B"

# 4-bit quantization settings.
# NOTE(review): `quant_config` is built but never used — the
# `quantization_config=` argument below is commented out, so the model is
# loaded without it.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # Alternative compute dtype: bnb_4bit_compute_dtype=torch.float16
    bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for better performance
    bnb_4bit_use_double_quant=True,  # Double quantization for memory efficiency
)
llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    # quantization_config=quant_config,  # disabled; uncomment to apply quant_config
    # device_map=torch.device("cuda"),   # alternative to "auto": pin to CUDA
    device_map="auto"
)
# Evaluation mode: the model is only used for forward passes in this notebook.
llm.eval()

# Tokenizer belonging to the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

Configuration

In [3]:
# Number of tokens per chunk: `pad` rounds the token count up to a multiple of
# this value, and the encoding/decoding loops run one model step per position.
CONTEXT_SIZE = 8

In [None]:
def pad(tokens, padding_val, context_size=None):
    """Right-pad a 1-D token tensor so its length is a multiple of the chunk size.

    Args:
        tokens: 1-D tensor of token ids.
        padding_val: Value written into the appended positions.
        context_size: Chunk size to pad to. When omitted, the notebook-level
            ``CONTEXT_SIZE`` is used.

    Returns:
        A tuple ``(padded_tokens, pad_len)`` where ``pad_len`` is the number
        of appended positions.

    Raises:
        ValueError: If ``context_size`` is not a positive integer.
    """
    if context_size is None:
        context_size = CONTEXT_SIZE
    if context_size <= 0:
        raise ValueError(f"context_size must be positive, got {context_size}")

    # pad_len is always in 1..context_size: an input whose length is already a
    # multiple still receives one full chunk of padding. Callers that strip
    # the padding with `x[:-pad_len]` rely on pad_len never being 0.
    pad_len = context_size - tokens.shape[0] % context_size

    # Match the dtype of `tokens` so concatenation never promotes the result.
    pads = torch.full(
        (pad_len,), padding_val, dtype=tokens.dtype, device=tokens.device
    )
    padded_tokens = torch.cat([tokens, pads])

    return padded_tokens, pad_len

def text_to_tokens(text):
    """Tokenize ``text`` with the notebook's tokenizer and return its token ids.

    Args:
        text: The string to tokenize.

    Returns:
        The ``input_ids`` tensor with the leading batch axis removed
        (1-D for a single input string).
    """
    encoded = tokenizer(text, return_tensors="pt")
    # Remove only the batch axis. A bare squeeze() would also collapse the
    # sequence axis when the result holds a single token, producing a 0-d
    # tensor that cannot be sliced or measured with .shape[0].
    return encoded["input_ids"].squeeze(0)

def get_rank(logits, indices):
    """Return the rank of each chosen entry within its row of scores.

    The rank is the number of entries along the last axis of ``logits`` that
    are strictly greater than the entry selected by ``indices``; the largest
    entry therefore has rank 0, and entries with equal scores share a rank.

    Args:
        logits: Tensor of scores, ranked along its last axis.
        indices: Integer tensor shaped like ``logits`` without its last axis.

    Returns:
        Integer tensor of ranks with the same shape as ``indices``.
    """
    # Keep the trailing singleton axis so the chosen score broadcasts per row.
    chosen_scores = logits.gather(dim=-1, index=indices.unsqueeze(-1))
    outranked_by = logits > chosen_scores
    return outranked_by.long().sum(-1)

def get_token_by_rank(logits, ranks):
    """Look up the index that sits at a given rank in each row of scores.

    This is the counterpart of ``get_rank``: rank 0 selects the index of the
    largest entry along the last axis of ``logits``, rank 1 the second
    largest, and so on.

    Args:
        logits: Tensor of scores, ranked along its last axis.
        ranks: Integer tensor shaped like ``logits`` without its last axis.

    Returns:
        Integer tensor of indices with the same shape as ``ranks``.
    """
    # A stable sort keeps equal scores in index order, so the result is
    # deterministic. Entries that tie share a rank in `get_rank`, so among
    # tied entries only the lowest index can be recovered.
    _, order = torch.sort(logits, dim=-1, descending=True, stable=True)
    return order.gather(-1, ranks.long()[..., None]).squeeze(-1)


In [102]:
# Input text that the cells below encode into ranks and decode back.
s = "The quick brown fox jumps over the lazy dog."

# Alternative input: 50 random integers in [0, 5000) joined with ":".
# s = ":".join(
#     str(x)
#     for x in random.randint(0, 5000, (50,)).tolist()
# )

print("String length:", len(s))

String length: 44


Encoding

In [58]:
# NOTE(review): `tokens` is first assigned in the cell below, so on a fresh
# kernel (Restart & Run All) this cell raises NameError; its execution count
# (58) also predates that cell's (110). Move it below the encoding cell or
# delete it.
tokens.shape

torch.Size([2, 9])

In [110]:
tokens = text_to_tokens(s)

# Drop the tokenizer's first token (presumably its BOS; one is re-added per
# chunk below), then pad so the stream splits into whole chunks.
tokens, pad_len = pad(tokens[1:], tokenizer.eos_token_id)
tokens = tokens.view(-1, CONTEXT_SIZE)

# Each chunk (row) is fed to the model on its own, so each gets its own BOS.
bos = torch.full((tokens.shape[0], 1), tokenizer.bos_token_id, device=tokens.device)
tokens = torch.cat((bos, tokens), 1)

# Send inputs to wherever the model was placed instead of assuming CUDA, and
# copy them once rather than on every loop iteration.
device = llm.device
model_tokens = tokens.to(device)

ranks = torch.empty_like(tokens[:, :-1])
past_key_values = None
# Inference only: without no_grad every forward pass would record an autograd
# graph and hold on to activations for no reason.
with torch.no_grad():
    for idx in range(CONTEXT_SIZE):
        # NOTE(review): the whole prefix is passed together with the cache
        # from earlier steps, so positions already in the cache look like
        # they are fed again — confirm. The decoding cell feeds the model the
        # same way and the ranks only round-trip while both sides agree, so
        # this is left unchanged here; if it is fixed, fix both cells together.
        outputs = llm(model_tokens[:, :idx + 1], past_key_values=past_key_values)
        past_key_values = outputs.past_key_values

        # Rank of the actual next token among the model's predictions.
        step_ranks = get_rank(outputs.logits[:, -1, :], model_tokens[:, idx + 1])
        ranks[:, idx] = step_ranks.to(ranks.device)

torch.cuda.empty_cache()
tokens.shape, ranks.shape

(torch.Size([2, 9]), torch.Size([2, 8]))

In [111]:
# Show the chunked token ids, then the per-position ranks, one tensor per line.
print(tokens, ranks, sep="\n")

tensor([[128000,    791,   4062,  14198,  39935,  35308,    927,    279,  16053],
        [128000,   5679,     13, 128001, 128001, 128001, 128001, 128001, 128001]])
tensor([[    3,  1629,    25,     0,    12,     0,     0,    12],
        [37690,     9, 14633,   231,   266,   296,   299,   252]])


Decoding

In [112]:
# Run on the device the model was placed on instead of assuming CUDA, and
# move the ranks there once rather than on every loop iteration.
device = llm.device
ranks_on_device = ranks.to(device)

# One BOS per chunk (row of `ranks`) seeds the reconstruction.
input_ids = torch.full(
    (ranks.shape[0], 1), tokenizer.bos_token_id, dtype=torch.long, device=device
)

with torch.no_grad():
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        # `top_k` is a sampling option for generate(), not a forward()
        # argument, so it is no longer passed here.
        # NOTE(review): the whole prefix is passed together with the cache,
        # exactly as the encoding cell does. The two cells must feed the
        # model identically for the ranks to decode correctly, so change
        # both together or neither.
        output = llm(input_ids, past_key_values=past_key_values)
        past_key_values = output.past_key_values

        logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
        # Token ids ordered from highest to lowest score; a stable sort makes
        # the order of equal scores deterministic.
        _, sorted_tokens = torch.sort(logits, dim=-1, descending=True, stable=True)

        # The stored rank selects which token in that ordering comes next.
        next_token_id = sorted_tokens.gather(-1, ranks_on_device[:, idx].unsqueeze(-1))

        input_ids = torch.cat([input_ids, next_token_id], dim=1)
input_ids

tensor([[128000,    791,   4062,  14198,  39935,  35308,    927,    279,  16053],
        [128000,   5679,     13, 128001, 128001, 128001, 128001, 128001, 128001]],
       device='cuda:0')

In [113]:
# Drop the per-chunk BOS column and stitch the chunks back into one stream.
# A dedicated name is used so `output` (the model output object in the
# decoding cell) is not rebound to a different kind of value.
decoded_tokens = input_ids[:, 1:].flatten()

# Strip the padding added before encoding. An explicit end index is used
# because `[:-pad_len]` would return an empty tensor if pad_len were 0.
decoded_tokens = decoded_tokens[: decoded_tokens.shape[0] - pad_len]

generated_text = tokenizer.decode(decoded_tokens, skip_special_tokens=True)
print("\nFinal generated sequence:\n", generated_text)


Final generated sequence:
 The quick brown fox jumps over the lazy dog.


In [None]:
# Display the original input text (presumably to compare with the decoded
# text printed above).
s