In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import torch.nn as nn
import tqdm 
# Extend the import search path with the `llama` directory under the current
# working directory and with the relative `../llama` directory, ahead of the
# `from llama import Llama` import that follows.
sys.path.append(os.path.join(os.getcwd(), "llama"))
sys.path.append(os.path.join("..", "llama"))


from llama import Llama

# Single-process distributed settings, placed in the environment before the
# model is built.
os.environ.update(
    {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12357",
        "NUM_TRAINERS": "1",
    }
)

# Build-time configuration shared with the cells further down.
model_name = "llama"
max_seq_len = 11  # short context: later cells extend a one-position prompt by 10 greedy tokens
max_batch_size = 1024

generator = Llama.build(
    ckpt_dir=f"../../{model_name}/META_RELEASED_WEIGHTS/7B",
    tokenizer_path=f"../../{model_name}/META_RELEASED_WEIGHTS/tokenizer.model",
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
    model_parallel_size=1,
)

# Run the model in half precision and freeze every parameter: the notebook
# only performs inference.
generator.model = generator.model.half()
for param in generator.model.parameters():
    param.requires_grad = False

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
Loaded in 5.79 seconds


# Check how many unique completions are produced from a single-token prompt

In [2]:
# Greedily decode a 10-token completion for every single-token prompt in the
# vocabulary, then count how many distinct completions occur.
vocab_size = 32_000  # number of token ids tried as one-token prompts
num_new_tokens = 10  # completion length generated per prompt

batches = torch.arange(vocab_size).split(max_batch_size)
completions = torch.zeros((vocab_size, num_new_tokens), dtype=torch.long)

# Inference only: disable autograd explicitly so this loop does not depend on
# the parameters having been frozen when the model was loaded.
with torch.no_grad():
    for b_idx, b in tqdm.tqdm(enumerate(batches), total=len(batches)):
        b = b.unsqueeze(1).to("cuda")  # one prompt token per row
        # Rows of `completions` owned by this batch; the final batch may be
        # shorter than max_batch_size, so size the slice from the batch itself.
        row_start = b_idx * max_batch_size
        row_stop = row_start + b.size(0)
        for i in range(num_new_tokens):
            # Re-run the whole sequence from position 0 and take the argmax
            # over the logits at the last position.
            next_tokens = generator.model.forward(b, start_pos=0)[:, -1, :].argmax(dim=-1).flatten()
            completions[row_start:row_stop, i] = next_tokens
            # Append the chosen token so the next step conditions on it.
            b = torch.cat((b, next_tokens.unsqueeze(1)), dim=1)

# Find the number of unique completions
unique_completions = torch.unique(completions, dim=0)
print(f"Number of unique completions: {unique_completions.shape[0]}")


100%|██████████| 32/32 [03:00<00:00,  5.64s/it]

Number of unique completions: 24426





In [3]:
# Prepend each prompt token id to its generated completion, then decode every
# row and print the decoded result.
prompt_ids = torch.arange(32_000).unsqueeze(1)
prompt_and_completion = torch.cat((prompt_ids, completions), dim=1)
print(generator.tokenizer.decode(prompt_and_completion.tolist()))



# Check how many unique completions are produced from a single prefix

In [4]:
# Load the saved prefixes directly onto the GPU and report their first two
# dimensions.
# NOTE(review): torch.load unpickles the file, so it should only be pointed at
# trusted files.
prefix_path = "../llama_completions.pt"
prefixes = torch.load(prefix_path, map_location="cuda")
print(f"{prefixes.size(0)} prefixes loaded, each of length {prefixes.size(1)}")

48218 prefixes loaded, each of length 1


In [5]:
# Greedy decoding from every loaded prefix: 10 token ids are generated per
# prefix and stored row-wise in `completions`.
completions = torch.zeros((prefixes.size(0), 10), dtype=torch.long)
batches = prefixes.split(max_batch_size)

with torch.no_grad():
    for batch_no, seq in tqdm.tqdm(enumerate(batches), total=len(batches)):
        seq = seq.to("cuda").half()
        # Rows of `completions` that belong to this batch (the batch size does
        # not change while tokens are appended along dim 1).
        row_start = batch_no * max_batch_size
        row_stop = row_start + seq.size(0)
        for step in range(10):
            # Argmax over the logits at the last position of the current sequence.
            next_tokens = (
                generator.model.forward(seq, start_pos=0, virtual_tokens=True)[:, -1, :]
                .argmax(dim=-1)
                .flatten()
            )
            completions[row_start:row_stop, step] = next_tokens
            # The model is fed embeddings here, so embed the chosen tokens
            # before appending them to the running sequence.
            next_embeddings = generator.model.tok_embeddings(next_tokens.unsqueeze(1))
            seq = torch.cat((seq, next_embeddings), dim=1)



100%|██████████| 48/48 [04:32<00:00,  5.67s/it]


In [6]:
# Collapse identical rows to count how many distinct completions were generated.
unique_completions = completions.unique(dim=0)
print(f"Number of unique completions: {unique_completions.shape[0]}")

Number of unique completions: 46812


In [7]:
# Decode every row of generated token ids and print the decoded result.
completion_ids = completions.tolist()
print(generator.tokenizer.decode(completion_ids))

