In [1]:
print('Installing packages...')
! pip install torch transformers accelerate sentencepiece  datasets tqdm zstandard

In [1]:
import tqdm
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from functools import partial
import gc

In [3]:
# Identifier of the checkpoint shared by the tokenizer and the model below.
model_path = "facebook/opt-1.3b"

# Tokenizer matching the checkpoint; the fast implementation is explicitly disabled.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=False,
)

# Causal LM placed on the CUDA device via `device_map` (requires a GPU to be available).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
)

# Test split of the raw WikiText-2 configuration; read later through `dataset['text']`.
dataset = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="test",
)

In [45]:
# Default number of evaluation windows; read by evaluate_perplexity at call time.
n_samples=50


def evaluate_perplexity(model, tokenizer, num_samples=None, seq_len=2048, texts=None):
    """Estimate the perplexity of a causal language model on a text corpus.

    The corpus is concatenated into one string, tokenized once, and cut into
    consecutive non-overlapping windows of ``seq_len`` tokens. For each window
    the mean next-token cross-entropy is computed; the result is
    ``exp(mean of the per-window losses)``.

    Args:
        model: Callable taking a ``(1, seq_len)`` token tensor and returning an
            object with a ``.logits`` attribute. Must also provide ``.eval()``
            and ``.device``.
        tokenizer: Callable accepting ``(text, return_tensors='pt')`` and
            returning an object with a 2-D ``.input_ids`` tensor.
        num_samples: Maximum number of windows to evaluate. ``None`` (default)
            uses the module-level ``n_samples``.
        seq_len: Number of tokens per window (default 2048, must be >= 2).
        texts: Iterable of strings to evaluate on. ``None`` (default) uses
            ``dataset['text']`` from the module-level ``dataset``.

    Returns:
        A 0-dim ``torch.Tensor`` holding the perplexity.

    Raises:
        ValueError: If ``seq_len < 2`` or if no full window can be evaluated
            (corpus shorter than ``seq_len`` tokens, or ``num_samples <= 0``).
    """
    if seq_len < 2:
        # A window of one token has no next-token target to score.
        raise ValueError(f"seq_len must be at least 2, got {seq_len}")
    if texts is None:
        texts = dataset['text']
    if num_samples is None:
        num_samples = n_samples

    # Documents are separated by a blank line. The previous "/n/n" separator was
    # a typo that injected literal slashes and the letter n into the corpus.
    tokens = tokenizer("\n\n".join(texts), return_tensors='pt').input_ids

    # Only score full windows: slicing past the end would yield empty or short
    # batches and turn the average into NaN / a biased value.
    num_windows = min(num_samples, tokens.size(1) // seq_len)
    if num_windows <= 0:
        raise ValueError(
            f"cannot evaluate perplexity: {tokens.size(1)} tokens and "
            f"num_samples={num_samples} give no full window of {seq_len} tokens"
        )

    model.eval()
    loss_fct = nn.CrossEntropyLoss()  # loop-invariant, so built once
    total_loss = 0.0
    for i in tqdm.tqdm(range(num_windows)):
        # Move one window at a time to the model's device; the full token
        # tensor stays where the tokenizer produced it.
        batch = tokens[:, i * seq_len:(i + 1) * seq_len].to(model.device)
        with torch.no_grad():
            logits = model(batch).logits
            # Position t predicts token t+1: drop the last logit and the first label.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = batch[:, 1:].contiguous()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        total_loss += loss.item()
        del batch
    return torch.exp(torch.tensor(total_loss / num_windows))

def model_size(model, data_width=16, group_size=-1):
    """Return the storage footprint of ``model``'s parameters in megabytes.

    Args:
        model: Object whose ``named_parameters()`` yields ``(name, param)``
            pairs where each ``param`` provides ``numel()``.
        data_width: Bits used to store one parameter.
        group_size: When not ``-1``, each group of this many parameters is
            charged an extra 16-bit scale and 4-bit zero point, amortized over
            the group and added to ``data_width``. ``-1`` adds no overhead.

    Returns:
        Size as a float, using 1 MB = 8 * 1024 * 1024 bits.
    """
    bits_per_param = data_width
    if group_size != -1:
        scale_bits, zero_point_bits = 16, 4
        # Per-group metadata cost spread evenly over the parameters in the group.
        bits_per_param += (scale_bits + zero_point_bits) / group_size

    param_count = sum(param.numel() for _, param in model.named_parameters())

    total_bits = param_count * bits_per_param
    bits_per_megabyte = 8 * 1024 * 1024
    return total_bits / bits_per_megabyte



In [None]:
### Model info
# Report for the model as loaded above: parameter storage size, then perplexity.
# data_width=32 presumably reflects fp32 weights (no dtype is passed at load time) - confirm.
# group_size=-1 means no per-group quantization overhead is added.
print(
    f'Model size (in MB): {model_size(model, data_width=32, group_size=-1)}'
)
# Runs the full evaluation loop, so this line is the slow part of the cell.
print(
    f'Model perplexity: {evaluate_perplexity(model, tokenizer)}'
)