In [3]:
print('Installing packages...')
! pip install torch transformers accelerate sentencepiece  datasets tqdm zstandard

Installing packages...
Collecting sentencepiece
  Using cached sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (10 kB)
Using cached sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (1.4 MB)
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.2.1


In [1]:
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from functools import partial
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Hugging Face Hub id of the checkpoint under evaluation; both the tokenizer
# and the weights are loaded from it.
model_path = "facebook/opt-1.3b"

# use_fast=False selects the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=False,
)

# Load the weights in float16 and place them on the CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="cuda",
)

In [None]:
def evaluate(model, tokenizer, dataset=None, batch_size=4, max_length=512):
    """Compute token-level perplexity of a causal LM over a text dataset.

    Each row's ``text`` field is tokenized independently (padded and truncated to
    ``max_length``) and scored with next-token cross-entropy. Perplexity is
    ``exp(total NLL / total predicted tokens)``, so every predicted token carries
    equal weight regardless of which line or batch it came from. Padding
    positions are excluded from the loss.

    NOTE(review): lines are scored in isolation, so the number is not directly
    comparable to the common "concatenate the split + sliding window" WikiText
    perplexity protocol.

    Args:
        model: causal LM called as ``model(input_ids, attention_mask=...)``; its
            output must expose ``.logits`` of shape (batch, seq, vocab).
        tokenizer: callable returning ``input_ids`` and ``attention_mask``
            tensors for a list of strings.
        dataset: indexable collection of rows with a ``"text"`` string field.
            Defaults to the WikiText-2 (raw) test split, loaded on each call.
        batch_size: number of rows per forward pass.
        max_length: truncation length in tokens.

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: if the dataset yields no tokens to predict (e.g. it is empty
            or every line is blank).
    """
    if dataset is None:
        # Loaded here rather than as a default argument so that defining the
        # function does not download anything and no object is shared across calls.
        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    model.eval()
    total_nll = 0.0
    total_tokens = 0

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader):
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True, max_length=max_length)
            input_ids = inputs['input_ids'].to(model.device)
            attention_mask = inputs['attention_mask'].to(model.device)

            # Position t is predicted from positions < t, so targets are shifted
            # by one; padded positions become -100 so cross_entropy ignores them.
            target_ids = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
            n_targets = int((target_ids != -100).sum())
            if n_targets == 0:
                # e.g. a batch of blank lines, where the tokenizer may emit only a
                # BOS token: nothing to predict, and a mean loss would be NaN.
                continue

            logits = model(input_ids, attention_mask=attention_mask).logits
            # Upcast so fp16 logits are reduced to a summed NLL in fp32.
            shift_logits = logits[:, :-1, :].float()
            nll_sum = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                target_ids.reshape(-1),
                ignore_index=-100,
                reduction='sum',
            )
            total_nll += nll_sum.item()
            total_tokens += n_targets

    if total_tokens == 0:
        raise ValueError("dataset contains no tokens to score")

    perplexity = torch.exp(torch.tensor(total_nll / total_tokens, dtype=torch.float64))
    return perplexity.item()

In [None]:
# Score the loaded model on evaluate()'s default dataset (WikiText-2 raw test
# split) and display the result as the cell's output.
model_perplexity = evaluate(model, tokenizer)
model_perplexity