In [2]:
import os
import random
import shutil
import sys
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from tqdm import tqdm
import torch.distributed as dist

from datasets import load_dataset
import argparse
import torch.nn.functional as F
from torch.utils.data import Dataset
import time
from datetime import datetime
import sys


@torch.no_grad()
def evaluate(model, dataset, n_samples=None, length=2048):
    """Compute the perplexity of ``model`` over consecutive windows of ``dataset``.

    ``dataset`` is cut along dim 1 into non-overlapping windows of ``length``
    tokens. Each window is passed through the model and scored with the mean
    next-token cross-entropy; the result is the exponential of the average of
    those per-window losses.

    Args:
        model: Callable exposing ``.eval()`` and ``.device`` whose output has
            a ``.logits`` attribute indexed as ``[batch, position, vocab]``.
        dataset: Token-id tensor indexed as ``[batch, position]``.
        n_samples: Number of windows to score. ``None`` (or ``0``) scores
            every full window. Requests larger than the number of full
            windows are capped, because the extra windows would be empty or
            truncated and would corrupt the average.
        length: Window size in tokens. Defaults to 2048, the value that was
            previously hard-coded.

    Returns:
        A 0-dim tensor holding the perplexity.

    Raises:
        ValueError: If ``dataset`` does not contain a single full window (or
            ``n_samples`` is negative).
    """
    model.eval()

    full_windows = dataset.size(1) // length
    n_samples = min(n_samples, full_windows) if n_samples else full_windows
    if n_samples <= 0:
        raise ValueError(
            f"dataset has {dataset.size(1)} tokens per row; need at least "
            f"{length} to evaluate one window"
        )

    # The loss module is stateless, so build it once instead of per window.
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in tqdm(range(n_samples), desc="Evaluating..."):
        batch = dataset[:, (i * length) : ((i + 1) * length)].to(model.device)
        lm_logits = model(batch).logits

        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = batch[:, 1:]  # reuse the tensor already on the device

        # reshape (not view): the label slice is non-contiguous when there is
        # more than one row.
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
        )
        # Scale the mean loss by the window length; the same factor is divided
        # out again below, so the result is the mean of per-window losses.
        nlls.append(loss.float() * length)

    return torch.exp(torch.stack(nlls).sum() / (n_samples * length))


# Directory of the pretrained checkpoint whose tokenizer is loaded below.
# The default is a machine-specific absolute path; set the LLAMA_MODEL_PATH
# environment variable to run the notebook against a copy stored elsewhere.
model_path = os.environ.get("LLAMA_MODEL_PATH", '/localssd/lbxj/llama-2-7b-hf/')
tokenizer = AutoTokenizer.from_pretrained(model_path)


In [4]:

def get_wikitext2(nsamples, seqlen, tokenizer, eval_mode=False):
    """Load WikiText-2 (raw, v1) and tokenize it as one long sequence.

    The selected split's ``text`` entries are joined with blank lines and
    passed to ``tokenizer`` in a single call with ``return_tensors="pt"``.

    Args:
        nsamples: Number of random windows to draw (training mode only).
        seqlen: Length in tokens of each window (training mode only).
        tokenizer: Callable whose result exposes ``input_ids`` indexed as
            ``[batch, position]``.
        eval_mode: When true, use the ``test`` split and return the
            tokenizer output unchanged; otherwise sample from ``train``.

    Returns:
        In eval mode, the tokenizer output for the whole test split.
        Otherwise a list of ``nsamples`` slices of ``input_ids``, each
        ``seqlen`` tokens wide, with start offsets drawn from the global
        ``random`` module.
    """
    split_name = "test" if eval_mode else "train"
    corpus = load_dataset("wikitext", "wikitext-2-raw-v1", split=split_name)
    encoded = tokenizer("\n\n".join(corpus["text"]), return_tensors="pt")

    if eval_mode:
        return encoded

    windows = []
    for _ in range(nsamples):
        start = random.randint(0, encoded.input_ids.shape[1] - seqlen - 1)
        windows.append(encoded.input_ids[:, start : start + seqlen])
    return windows

def get_c4(nsamples, seqlen, tokenizer, eval_mode=False):
    """Draw fixed-length token windows from one shard of the C4 dataset.

    Documents are picked at random (global ``random`` module) and tokenized
    one at a time with ``return_tensors="pt"``; a document is redrawn until
    one with at least ``seqlen`` tokens is found. A window of ``seqlen``
    tokens is then cut from it at a random offset. A document with exactly
    ``seqlen`` tokens is used whole, in both modes; previously the training
    branch raised ``ValueError`` from ``random.randint(0, -1)`` in that case.

    Args:
        nsamples: Number of windows to return in training mode. Not used in
            eval mode, which always draws 256 windows.
        seqlen: Length in tokens of every window.
        tokenizer: Callable whose result exposes ``input_ids`` indexed as
            ``[batch, position]``.
        eval_mode: When true, sample from the validation shard after calling
            ``random.seed(0)`` (this reseeds the global RNG) and return the
            windows concatenated with ``torch.hstack``.

    Returns:
        Training mode: a list of ``nsamples`` window tensors.
        Eval mode: one tensor holding the 256 windows stacked horizontally.
    """
    n_eval_windows = 256  # fixed number of validation windows

    def _draw_long_document(data):
        # Keep sampling until a document tokenizes to at least seqlen tokens.
        while True:
            idx = random.randint(0, len(data) - 1)
            input_ids = tokenizer(data[idx]["text"], return_tensors="pt").input_ids
            if input_ids.shape[1] >= seqlen:
                return input_ids

    def _take_window(input_ids):
        # An exact-length document has only one possible window; asking
        # randint for an offset would request the empty range [0, -1].
        # (rare case, discovered with Yi tokenizer)
        if input_ids.shape[1] == seqlen:
            return input_ids
        start = random.randint(0, input_ids.shape[1] - seqlen - 1)
        return input_ids[:, start : start + seqlen]

    if not eval_mode:
        traindata = load_dataset(
            "allenai/c4",
            "default",
            data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
            split="train",
            revision="607bd4c8450a42878aa9ddc051a65a055450ef87",
        )
        return [
            _take_window(_draw_long_document(traindata)) for _ in range(nsamples)
        ]

    valdata = load_dataset(
        "allenai/c4",
        "default",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        revision="607bd4c8450a42878aa9ddc051a65a055450ef87",
    )
    # Fixed seed so the validation windows are the same on every call.
    random.seed(0)
    valenc = [
        _take_window(_draw_long_document(valdata)) for _ in range(n_eval_windows)
    ]
    return torch.hstack(valenc)
# Token ids of the WikiText-2 test split. In eval mode get_wikitext2 returns
# the tokenizer output for the whole split; the sample-count and
# sequence-length arguments are not used on that path.
test_data_wiki = get_wikitext2(
    nsamples=40, seqlen=2048, tokenizer=tokenizer, eval_mode=True
).input_ids
# test_data_c4 = get_c4(40, 2048, tokenizer, True)

ConnectionError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/datasets/wikitext/revision/b08601e04326c79dfdd32d625aee71d232d685c3 (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7fc20946fbb0>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 8d551d75-20f1-46bb-9ccf-7eed3e51d87c)')