In [None]:
import os
from typing import List

import torch
from datasets import load_dataset
from dotenv import load_dotenv
# `save_file` for torch tensors is provided by the `safetensors.torch`
# submodule; the package root does not export it, so the previous
# `from safetensors import save_file` failed with ImportError.
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, AutoTokenizer

# Populate os.environ from a local .env file (e.g. the HUGGINGFACE token
# read when loading the model).
load_dotenv()


# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Hub repo id and access token, defined once so the tokenizer and the model
# are always loaded from the same checkpoint with the same credentials.
MODEL_ID = "meta-llama/Meta-Llama-3.1-405B-FP8"
HF_TOKEN = os.getenv("HUGGINGFACE")  # None if the variable is unset

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# NOTE(review): this is a 405B-parameter checkpoint and `.to(device)` puts the
# whole model on a single device, which presumably exceeds one GPU/MPS
# device's memory. Consider `device_map="auto"` (needs `accelerate`) —
# TODO confirm the target hardware.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)
# Inference only. The trailing semicolon keeps Jupyter from printing the very
# long module repr that `eval()` returns.
model.eval();

In [None]:
# GSM8K, "main" configuration, training split; only the question text is
# pulled out of it here.
ds = load_dataset(path="openai/gsm8k", name="main", split="train")
questions = ds["question"]

In [None]:
def get_logits(inputs: str) -> torch.Tensor:
    """Run one forward pass over ``inputs`` and return the raw model logits.

    Args:
        inputs: Prompt text, tokenized with the notebook-level ``tokenizer``
            and moved to the notebook-level ``device``.

    Returns:
        ``outputs.logits`` from the notebook-level ``model``, left on the
        device the model produced it on. The forward pass runs under
        ``torch.no_grad()``, so no autograd graph is retained.
        NOTE(review): presumably shaped (1, seq_len, vocab_size) — confirm.
    """
    # Keep the raw text and the tokenized batch under different names.
    encoded = tokenizer(inputs, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**encoded)
    return outputs.logits

In [None]:
def save_logits(
    logits: List[torch.Tensor],
    file_path: str = "llama-3.1-405b-gsm8k-base-tensors.safetensors",
) -> None:
    """Write a list of per-question logits to one safetensors file.

    Each tensor is stored under the key ``"Question {n}"``, where ``n`` is its
    1-based position in ``logits``.

    Args:
        logits: One tensor per question, on any device; each is copied to CPU
            before writing.
        file_path: Destination file. Defaults to the file name this notebook
            has always written.
    """
    data = {
        # safetensors refuses non-contiguous tensors; `.contiguous()` is a
        # no-op for tensors that already are.
        f"Question {n}": logit.cpu().contiguous()
        for n, logit in enumerate(logits, start=1)
    }
    save_file(data, file_path)

In [None]:
# to be used later on to load back the logits

from safetensors import safe_open

def _question_index(key: str) -> tuple:
    """Sort key ordering ``"Question <n>"`` names numerically by ``n``.

    Names without a trailing integer sort after all numbered ones, by name.
    """
    suffix = key.rpartition(" ")[2]
    if suffix.isdecimal():
        return (0, int(suffix), "")
    return (1, 0, key)


def load_list_of_logits_safetensor(file_path: str) -> List[torch.Tensor]:
    """Load logits written by ``save_logits`` back as a list, in question order.

    ``safe_open(...).keys()`` does not preserve insertion order: safetensors
    returns the names sorted as strings, so ``"Question 10"`` would precede
    ``"Question 2"``. The keys are re-sorted by their numeric suffix, so
    element ``i`` of the result corresponds to ``"Question {i + 1}"``.

    Args:
        file_path: Path to a ``.safetensors`` file.

    Returns:
        The stored tensors as PyTorch tensors, ordered by question number.
    """
    with safe_open(file_path, framework="pt") as f:
        ordered_keys = sorted(f.keys(), key=_question_index)
        return [f.get_tensor(key) for key in ordered_keys]


In [None]:
# Copy each result to host memory as soon as it is produced; otherwise every
# question's logits stay resident on the accelerator until the loop ends.
# NOTE(review): full-vocabulary logits for every token of every training
# question are likely very large even in host RAM — consider keeping only
# the slices needed downstream. TODO confirm the intended use.
logits = [get_logits(question).cpu() for question in questions]
save_logits(logits)