In [None]:
# Hugging Face packages; `datasets` supplies the `load_dataset` function imported further down.
# NOTE(review): neither package is version-pinned, so a fresh install may resolve to
# different releases than the ones this notebook was originally run with.
%pip install transformers datasets
# Prebuilt exllamav2 0.1.8 wheel. Per its filename it targets CUDA 12.1 (cu121),
# torch 2.4.0, CPython 3.11 (cp311) and linux x86_64; other environments need a different wheel.
%pip install https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl

In [None]:
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer, ExLlamaV2Cache_Q6
import torch
from safetensors.torch import save_file
from datasets import load_dataset
from typing import List
import gc
from tqdm.notebook import tqdm

# choose the compute backend in priority order: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# CUDA device count as reported by torch; used by the peer-access check below
num_gpus = torch.cuda.device_count()
print(device)

In [None]:
# report, for every ordered pair of distinct GPUs, whether direct peer access is available
for i in range(num_gpus):
    for j in range(num_gpus):
        if i == j:
            continue  # skip the self-pair
        can_access = torch.cuda.can_device_access_peer(i, j)
        print(f"GPU {i} can access GPU {j}: {can_access}")

In [None]:
# load and prepare model for generation
# The model directory is a local snapshot path; its name indicates a 6.0bpw EXL2
# build of Meta-Llama-3.1-405B-Instruct. `config.prepare()` is called after the
# directory is set and before the model and tokenizer are constructed from `config`.
config = ExLlamaV2Config()
config.model_dir = "./llama-405b/models--ek826--Meta-Llama-3.1-405B-Instruct-6.0bpw-exl2/snapshots/c3aafe6600c3c081f514e33ea325da7a19cb1822"
config.prepare()
# `model` and `tokenizer` are notebook-level globals read by `get_logits` below.
model = ExLlamaV2(config)
tokenizer = ExLlamaV2Tokenizer(config)

# 6-bit cache chosen due to only having 4 96gb H100s (376gb VRAM total)
# NOTE(review): 4 x 96 GB is 384 GB, not 376 GB (376 GB would be 4 x 94 GB) — confirm the card size.
# The cache is created with lazy=True and then handed to load_autosplit; presumably the
# loader allocates it while distributing weights across the GPUs — TODO confirm against exllamav2 docs.
cache = ExLlamaV2Cache_Q6(model, lazy = True)
model.load_autosplit(cache)

In [3]:
# pull the GSM8K "main" training split and keep just its first 2638 question strings
ds = load_dataset("openai/gsm8k", "main", split="train")
questions = ds["question"][:2638]

In [4]:
# generate logits for an input sequence
def get_logits(inputs: str) -> torch.Tensor:
    """Tokenize a prompt and return the model's forward-pass output for it.

    Uses the notebook-level `tokenizer` and `model` objects.

    Args:
        inputs: Prompt text to encode and feed through the model.

    Returns:
        The value returned by `model.forward` for the encoded prompt
        (presumably the logits tensor for the sequence — confirm its shape and
        device against the exllamav2 documentation). It is returned as-is;
        moving it to host memory is left to the caller.
    """
    # keep the token ids under their own name instead of rebinding the str parameter
    input_ids = tokenizer.encode(inputs)

    # inference only: do not record an autograd graph
    with torch.no_grad():
        outputs = model.forward(input_ids)

    # mem management to ensure no VRAM overflow
    # (`outputs` is still referenced here, so only other cached memory can be released)
    gc.collect()
    torch.cuda.empty_cache()

    return outputs

In [5]:
# save list of logits to disk
def save_logits(
    logits: List[torch.Tensor],
    filename: str = "llama-3.1-405b-gsm8k-teacher-prompt-tensors.safetensors",
) -> None:
    """Write a list of tensors to a single safetensors file.

    Each tensor is stored under the key "Question <k>", where <k> is its
    1-based position in `logits`.

    Args:
        logits: Tensors to save, in question order.
        filename: Destination path handed to `save_file`. Defaults to the file
            name this notebook has always written to.
    """
    # move to host memory and force a contiguous layout: safetensors rejects
    # non-contiguous tensors, and `.contiguous()` is a no-op for ones that already are
    data = {
        f"Question {position}": logit.to('cpu').contiguous()
        for position, logit in enumerate(logits, start=1)
    }

    save_file(data, filename)

In [None]:
# run every question through the model, copying each result to host memory
# as soon as it is produced, then write the whole collection to disk
logits = []
for question in tqdm(questions):
    logits.append(get_logits(question).to('cpu'))
save_logits(logits)