In [None]:
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Seed the random number generators through transformers' helper (seed value 42).
transformers.set_seed(seed=42)

In [None]:
# Global matplotlib styling: size-16 text for the base font, axes titles, axes
# labels and tick labels, size-14 legend text, and 300 dpi figures.
# NOTE(review): no font family is configured here, so matplotlib's default
# typeface is used — set `font.family` explicitly if Arial is required.
mpl.rcParams.update(
    {
        "font.size": 16,
        "axes.titlesize": 16,
        "axes.labelsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 14,
        "figure.dpi": 300,
    }
)

In [None]:
# Checkpoint identifier shared by the tokenizer and the causal language model.
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"

# The two loads are independent of each other; `device_map="auto"` leaves
# device placement of the model weights to the library.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
)

In [None]:
# Evaluation settings consumed by the tokenization and evaluation cells below.
batch_size = 1  # sequences per DataLoader batch
sequence_length = 512  # tokens per evaluation sequence
num_samples = 800  # number of leading dataset rows to tokenize

In [None]:
# Load the "test" split of the "wikitext-2-v1" configuration of "wikitext",
# requesting one specific revision so the data does not change between runs.
dataset = load_dataset(
    path="wikitext",
    name="wikitext-2-v1",
    split="test",
    revision="b08601e04326c79dfdd32d625aee71d232d685c3",
)

In [None]:
# Build fixed-length evaluation sequences from the first `num_samples` rows.
texts = dataset["text"][:num_samples]

# Tokenize the concatenated rows in one call, then cut the resulting token
# stream into non-overlapping chunks of `sequence_length` tokens. Trailing
# tokens that do not fill a whole chunk are dropped.
encodings = tokenizer("\n\n".join(texts), return_tensors="pt")
token_stream = encodings["input_ids"][0]
total_length = token_stream.size(0)
num_sequences = total_length // sequence_length
if num_sequences == 0:
    # Fail here with a clear message instead of a ZeroDivisionError later,
    # when the accuracy is computed over zero evaluated tokens.
    raise ValueError(
        f"Tokenized text has {total_length} tokens, fewer than one "
        f"sequence of length {sequence_length}."
    )
input_ids = token_stream[: num_sequences * sequence_length].view(
    num_sequences, sequence_length
)

# Use a dedicated name for the tensor dataset: rebinding `dataset` here would
# overwrite the text dataset read at the top of this cell and make the cell
# fail when it is run a second time.
token_dataset = TensorDataset(input_ids)
dataloader = DataLoader(token_dataset, batch_size=batch_size)
# Running top-1 next-token accuracy over all evaluation sequences.
averages = []  # cumulative accuracy (correct / total) after each batch
correct = 0  # number of positions where the argmax prediction equals the label
total = 0  # number of label positions evaluated so far

# Resolve the target device once. If the model exposes no `device` attribute,
# batches are fed without being moved (the fallback the original bare
# `except:` provided), but unrelated errors are no longer silently swallowed.
device = getattr(model, "device", None)

# Make inference mode explicit so the measurement does not depend on whatever
# mode the model was left in by earlier cells.
model.eval()

with torch.no_grad():
    for batch in tqdm(dataloader):
        # Use a loop-local name so the full `input_ids` tensor built earlier
        # in the notebook is not overwritten by the last batch.
        batch_ids = batch[0]
        if device is not None:
            batch_ids = batch_ids.to(device)

        # Shift by one token: position i of `inputs` is scored against
        # token i + 1 of the same sequence.
        inputs = batch_ids[:, :-1]
        labels = batch_ids[:, 1:]

        logits = model(inputs).logits  # indexed (batch, position, vocab)
        predictions = torch.argmax(logits, dim=-1)

        # Accumulate counts and record the cumulative accuracy so far.
        correct += (predictions == labels).sum().item()
        total += labels.numel()
        averages.append(correct / total)

In [None]:
# Show how many running-accuracy values were recorded.
num_recorded = len(averages)
print(num_recorded)

In [None]:

# Plot the running accuracy per batch together with two reference lines and
# save the figure to disk.

# NOTE(review): x = 45 is labelled as the "300 sample mark"; the mapping from
# batch index to sample count is not derived in this notebook — confirm.
sample_mark_x = 45

# Create the output directory first so saving does not raise
# FileNotFoundError when `visualizations/` does not exist yet.
output_path = Path("visualizations/300_sample_convergence_wikitext.png")
output_path.parent.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots()
# Labels are attached to each artist directly, so the legend cannot drift out
# of sync if the plotting order changes.
ax.plot(averages, label="Token Prediction Accuracy")
# Horizontal line at the accuracy over all evaluated tokens.
ax.axhline(
    correct / total, color="red", linestyle="--", label="Averaged Accuracies"
)
ax.axvline(
    sample_mark_x, color="green", linestyle="--", label="300 Sample Mark"
)
ax.grid(True)
ax.set_xlabel("Batch")
ax.set_ylabel("Accuracy")
ax.legend()
ax.set_title("Token Prediction Accuracy on WikiText with SmolLM-135M-Instruct")
fig.tight_layout()
fig.savefig(output_path, bbox_inches="tight")