In [None]:
!pip install datasets
!pip install transformers

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

def _running_metrics(total_loss, correct_predictions, total_predictions):
    """Return ``(accuracy, perplexity)`` from accumulated token-level totals.

    Returns ``(nan, nan)`` when no token positions have been scored yet instead of
    dividing by zero. The exponentiation runs in float64 so a large mean loss does
    not overflow to ``inf`` as early as it would in float32.
    """
    if total_predictions == 0:
        return float('nan'), float('nan')
    mean_loss = total_loss / total_predictions
    perplexity = torch.exp(torch.tensor(mean_loss, dtype=torch.float64)).item()
    return correct_predictions / total_predictions, perplexity


def evaluate_model_with_per_example_progress(dataset, model, tokenizer, context_length=128, print_every=100):
    """Score a language model on next-token prediction, one example at a time.

    Each example's ``'text'`` field is tokenized and truncated to at most
    ``context_length`` tokens. For every adjacent pair of real (attention-mask = 1)
    tokens, the logits at position t are scored against the token at t + 1.
    Examples with fewer than two tokens (e.g. empty lines) contribute no predictions.

    Args:
        dataset: object exposing ``.map`` / ``.set_format`` with a ``'text'`` column
            (presumably a Hugging Face ``datasets.Dataset``).
        model: model whose forward returns an object with ``.logits``; assumed shape
            (batch, seq, vocab) — TODO confirm for non-HF models.
        tokenizer: callable returning ``input_ids`` and ``attention_mask``.
        context_length: maximum number of tokens kept per example.
        print_every: print running metrics every N examples; 0/None disables printing.

    Returns:
        ``(accuracy, perplexity)``. Accuracy is the fraction of scored positions whose
        argmax equals the next token. Perplexity is exp of the mean per-token
        cross-entropy. Both are ``nan`` if no positions were scored.
    """
    model.eval()

    total_loss = 0.0
    total_predictions = 0
    correct_predictions = 0

    # No padding: examples are scored one at a time, so padding only wastes compute,
    # requires a pad token, and risks pad positions entering the metrics.
    def tokenize_function(batch):
        return tokenizer(batch['text'], truncation=True, max_length=context_length)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    total_examples = len(tokenized_dataset)
    # Summed (not averaged) so totals can be normalized by the global token count;
    # labels set to -100 (invalid positions) are excluded from the sum.
    loss_fct = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=-100)

    with torch.no_grad():
        for i, example in enumerate(tokenized_dataset):
            input_ids = example['input_ids'].unsqueeze(0).to(model.device)  # add batch dim
            attention_mask = example['attention_mask'].unsqueeze(0).to(model.device)

            # At least two tokens are needed to form one (context, next-token) pair.
            if input_ids.size(-1) >= 2:
                logits = model(input_ids, attention_mask=attention_mask).logits

                shift_logits = logits[..., :-1, :]
                shift_labels = input_ids[..., 1:].clone()
                # A position counts only if both the predicting token and its target are real tokens.
                valid = attention_mask[..., :-1].bool() & attention_mask[..., 1:].bool()
                shift_labels[~valid] = -100

                loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
                total_loss += loss.item()
                total_predictions += int(valid.sum().item())

                predicted_tokens = shift_logits.argmax(dim=-1)
                correct_predictions += int(((predicted_tokens == shift_labels) & valid).sum().item())

            if print_every and (i + 1) % print_every == 0:
                current_accuracy, current_perplexity = _running_metrics(
                    total_loss, correct_predictions, total_predictions
                )
                print(f"Processed {i + 1}/{total_examples} examples. Current Perplexity: {current_perplexity:.2f}, Accuracy: {current_accuracy * 100:.2f}%")

    return _running_metrics(total_loss, correct_predictions, total_predictions)

# Load the model and tokenizer.
# NOTE(review): MiniLM is a BERT-style checkpoint. When it is loaded as a causal LM,
# transformers reports the LM-head weights ('cls.predictions.*') as newly initialized,
# so the perplexity/accuracy below reflect an untrained prediction head.
# TODO: confirm whether a causal-LM checkpoint was intended.
model_name = "microsoft/MiniLM-L12-H384-uncased"
# is_decoder=True (as the transformers warning recommends) makes BertLMHeadModel apply a
# causal attention mask, so a position cannot attend to the very token it is scored on.
model = AutoModelForCausalLM.from_pretrained(model_name, is_decoder=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on GPU when available; the evaluator moves inputs to model.device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Load the WikiText-103 (raw) test split.
dataset_label = "WikiText-103 (raw) test set"
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split='test')

# Evaluate the model
context_length = 32  # fixed maximum context length (tokens per example)
accuracy, perplexity = evaluate_model_with_per_example_progress(dataset, model, tokenizer, context_length, print_every=100)
print(f"MiniLM Perplexity on {dataset_label} with context length {context_length}: {perplexity:.2f}")
print(f"MiniLM Accuracy on {dataset_label} with context length {context_length}: {accuracy * 100:.2f}%")

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BertLMHeadModel were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Processed 100/4358 examples. Current Perplexity: 41358.08, Accuracy: 0.03%
Processed 200/4358 examples. Current Perplexity: 40871.80, Accuracy: 0.02%
Processed 300/4358 examples. Current Perplexity: 40906.27, Accuracy: 0.01%
Processed 400/4358 examples. Current Perplexity: 41447.79, Accuracy: 0.01%
Processed 500/4358 examples. Current Perplexity: 41369.64, Accuracy: 0.01%
Processed 600/4358 examples. Current Perplexity: 41215.15, Accuracy: 0.01%
Processed 700/4358 examples. Current Perplexity: 41491.37, Accuracy: 0.00%
Processed 800/4358 examples. Current Perplexity: 41539.16, Accuracy: 0.00%
Processed 900/4358 examples. Current Perplexity: 41165.61, Accuracy: 0.00%
Processed 1000/4358 examples. Current Perplexity: 41250.86, Accuracy: 0.00%
Processed 1100/4358 examples. Current Perplexity: 41112.80, Accuracy: 0.00%
Processed 1200/4358 examples. Current Perplexity: 41022.26, Accuracy: 0.00%
Processed 1300/4358 examples. Current Perplexity: 41051.53, Accuracy: 0.00%
Processed 1400/4358 e

In [None]:
# Recorded perplexity on the WikiText-103 (raw) test set, by context length
# (context_length -> perplexity):
#  2 -> 58385.12
#  4 -> 43948.05
#  8 -> 57331.96
# 16 -> 73945.55
# 32 -> 33567.81
# 64 -> 68098.13