In [1]:
import torch
import torch.nn.functional as F
from modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier; adjust it to whichever checkpoint you have access to.
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = LlamaTokenizer.from_pretrained(model_name)

# Request the weights in half precision (float16).
model = LlamaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.19it/s]


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Place the model on the selected device.
model = model.to(device)

In [27]:
import torch
import torch.nn.functional as F

def generate_next_tokens(model, tokenizer, input_ids, num_passes=100, num_branches=3):
    """Greedily decode several continuations of one prompt in parallel.

    The prompt is run through the model once and the ``num_branches`` most
    likely next tokens each seed one continuation ("branch").  Every later
    forward pass feeds one new token per branch as a single short sequence on
    top of the shared key/value cache.  A 4-D additive attention mask lets
    each branch see the prompt and only its own earlier tokens, and all
    branch tokens of a pass share the same position id.

    Args:
        model: Causal language model.  It must expose ``device`` and ``dtype``
            and be callable with ``past_key_values``, ``position_ids``,
            ``attention_mask`` and ``use_cache``; the returned object must
            provide ``logits`` and ``past_key_values``.
            NOTE(review): this assumes the model consumes a ready-made 4-D
            additive mask of shape (1, 1, query_len, key_len) unchanged --
            confirm against the ``modeling_llama`` implementation in use.
        tokenizer: Unused; kept so existing call sites keep working.
        input_ids: Prompt token ids, expected as ``(1, prompt_len)``.  Only
            batch row 0 of the generated tokens is returned.
        num_passes: Total number of forward passes, including the prompt
            pass, i.e. the number of tokens produced per branch.  The prompt
            pass always runs, so values below 1 behave like 1.
        num_branches: Number of parallel continuations (top-k width of the
            first step).

    Returns:
        A list of 1-D tensors of length ``num_branches``; entry ``t`` holds
        the ``t``-th generated token of every branch.

    Raises:
        ValueError: If ``num_branches`` is smaller than 1.
    """
    if num_branches < 1:
        raise ValueError("num_branches must be at least 1")

    device = model.device
    input_len = input_ids.shape[1]
    outputs = []

    # Build the mask pieces once, in the model's own floating dtype, instead
    # of hard-coding float16 / -65504 (which only fits half-precision models).
    mask_dtype = model.dtype
    blocked = torch.finfo(mask_dtype).min

    # Every branch token may always attend to the whole prompt ...
    prompt_block = torch.zeros((num_branches, input_len), dtype=mask_dtype, device=device)
    # ... and, within each past step, only to the token of its own branch.
    branch_block = torch.full(
        (num_branches, num_branches), blocked, dtype=mask_dtype, device=device
    )
    diagonal = torch.arange(num_branches, device=device)
    branch_block[diagonal, diagonal] = 0

    with torch.no_grad():
        # Prompt pass: fills the cache and picks the seed token of each branch.
        first_outputs = model(input_ids, past_key_values=None, use_cache=True)
        past_key_values = first_outputs.past_key_values
        next_token_scores = F.log_softmax(first_outputs.logits[:, -1, :], dim=-1)
        _, next_tokens = torch.topk(
            next_token_scores, num_branches, dim=1, largest=True, sorted=True
        )
        outputs.append(next_tokens[0])

        # Branch passes: one new token per branch and pass.
        for step in range(1, num_passes):
            # All branch tokens of this pass sit at the same sequence position,
            # directly after the prompt and the step - 1 tokens generated so far.
            position_ids = torch.full(
                (1, num_branches), input_len + step - 1, dtype=torch.long, device=device
            )

            # Keys are laid out as [prompt | step 1 tokens | ... | current tokens];
            # the final block is the current tokens attending to themselves.
            attention_mask = torch.cat(
                [prompt_block, branch_block.repeat(1, step)], dim=1
            )[None, None]

            pass_outputs = model(
                next_tokens,
                past_key_values=past_key_values,
                position_ids=position_ids,
                attention_mask=attention_mask,
                use_cache=True,
            )
            past_key_values = pass_outputs.past_key_values

            # Greedy choice for every branch.
            next_tokens = torch.argmax(pass_outputs.logits[:, -num_branches:, :], dim=-1)
            outputs.append(next_tokens[0])

    return outputs

# Example usage
# Keep the prompt in one place so the tokenized text and the printed prefix
# cannot drift apart.
prompt = "Once upon a time"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
print(f"Input shape: {input_ids.shape}")

generated_outputs = generate_next_tokens(model, tokenizer, input_ids)
# Rows are decoding steps, columns are the parallel continuations.
stacked_tensor = torch.stack(generated_outputs)

# One token sequence per continuation (column), however many were generated.
grouped_tensors = list(stacked_tensor.unbind(dim=1))

# Print the result
for i, tensor in enumerate(grouped_tensors, 1):
    print(f"Sentence {i}: {prompt} {tokenizer.decode(tensor)}")

Input shape: torch.Size([1, 5])
Sentence 1: Once upon a time , there was a little girl who loved to read. She loved to read so much that she would read anything she could get her hands on. She would read the cereal box, the back of the cereal box, the back of the cereal box again, and then she would read the cereal box again. She would read the cereal box so much that she would read the cereal box until she was sick of reading the cereal box.
Sentence 2: Once upon a time there was a little girl who loved to read. She loved to read so much that she would read anything she could get her hands on. She would read the back of cereal boxes, the ingredients on the side of the box, the instructions on the back of the box, the instructions on the side of the box, the instructions on the back of the box, the instructions on the side of the box, the instructions on the back of the box, the instructions on the side
Sentence 3: Once upon a time in a land far, far away, there was a little girl who l