In [None]:
import torch

# Pick the compute device: CUDA when a GPU is visible, otherwise CPU.
# Apple MPS is intentionally not considered: it was again found to be drastically
# slower for this workload, and it also required `torch._dynamo.disable()`
# (https://github.com/pytorch/pytorch/issues/149184).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset

# Pretrained GPT-2 tokenizer (only the tokenizer is reused; the model is trained from scratch below).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="openai-community/gpt2"
)
# WikiText-2 corpus as a DatasetDict of splits (train/validation are used below).
# NOTE(review): "wikitext-2-v1" is the non-raw variant, in which rare words are replaced
# by the literal string "<unk>"; GPT-2's byte-level BPE will tokenize that as ordinary text.
# Consider "wikitext-2-raw-v1" if that matters — confirm intent.
dataset = load_dataset(path="wikitext", name="wikitext-2-v1")

In [None]:
# Number of tokens in every training example produced by `tokenize`.
context_length = 4

def tokenize(batch):
    """Split a batch of raw text into fixed-length token chunks.

    Each text is tokenized with truncation at ``context_length`` and overflow enabled,
    so long texts yield several consecutive chunks. Only chunks whose length is exactly
    ``context_length`` are kept; shorter tail chunks (and empty lines) are discarded.

    Args:
        batch: A batched ``datasets`` row dict with a ``"text"`` list.

    Returns:
        ``{"input_ids": [...]}`` holding one token-id list per kept chunk.
    """
    # TODO: Sequence packing
    encoded = tokenizer(
        batch["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = []
    for chunk_len, chunk_ids in zip(encoded["length"], encoded["input_ids"]):
        # Partial chunks would need padding; skip them so every example is full-size.
        if chunk_len != context_length:
            continue
        full_chunks.append(chunk_ids)
    return {"input_ids": full_chunks}

# Tokenize every split in batches. The original columns must be removed because a
# batch returns a different number of rows (chunks) than it received (texts).
tokenized_ds = dataset.map(
    function=tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
tokenized_ds

In [None]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

# Randomly initialised GPT-2 language model: architecture hyper-parameters come from the
# stock "gpt2" config; vocabulary size and BOS/EOS ids are overridden from `tokenizer`.
# NOTE(review): recent transformers `GPT2Config` has no `n_ctx` argument — the positional
# embedding table is sized by `n_positions` — so `n_ctx` is presumably stored but unused;
# confirm. Shrinking `n_positions` to `context_length` would likely break the callback's
# 50-new-token sampling, so don't change it blindly.
config = AutoConfig.from_pretrained(
    "gpt2",
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_ctx=context_length,
    vocab_size=len(tokenizer),
)
model = GPT2LMHeadModel(config).to(device)

# Total number of parameter elements, reported in millions.
model_size = sum(param.numel() for param in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

In [None]:
from transformers import TrainerCallback
import torch

class GenerateTextCallback(TrainerCallback):
    """Trainer callback that prints a sampled continuation of a prompt every N steps.

    A cheap qualitative progress check for language-model training: the prompt is
    encoded, ``model.generate`` samples a continuation, and the decoded sequence
    (prompt included) is printed together with the current global step.

    Args:
        tokenizer: Encodes the prompt and decodes the sample; its ``eos_token_id``
            is passed to ``generate`` as ``pad_token_id``.
        prompt: Text the generated sample continues from.
        device: Device the encoded prompt is moved to. Defaults to CUDA when
            available, otherwise CPU; it should match the device holding the model.
        every_n_steps: Positive interval, in global steps, between samples.
        generation_kwargs: Optional dict merged over ``DEFAULT_GENERATION_KWARGS``
            (and over the default ``pad_token_id``) when calling ``model.generate``.

    Raises:
        ValueError: If ``every_n_steps`` is not positive.
    """

    # Sampling settings used unless overridden via ``generation_kwargs``.
    DEFAULT_GENERATION_KWARGS = {
        "max_new_tokens": 50,
        "do_sample": True,
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.95,
        "no_repeat_ngram_size": 2,
    }

    def __init__(self, tokenizer, prompt="Once upon a time", device=None, every_n_steps=5,
                 generation_kwargs=None):
        # A zero interval would raise ZeroDivisionError on the first step; a negative
        # one is meaningless. Fail fast at construction instead.
        if every_n_steps <= 0:
            raise ValueError(f"every_n_steps must be positive, got {every_n_steps!r}")
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.every_n_steps = every_n_steps
        # Copy into a fresh dict so neither the class default nor the caller's dict is mutated.
        self.generation_kwargs = {**self.DEFAULT_GENERATION_KWARGS, **(generation_kwargs or {})}

    def on_step_end(self, args, state, control, **kwargs):
        """Sample and print a continuation of ``self.prompt`` on every ``every_n_steps``-th step."""
        if state.global_step % self.every_n_steps != 0:
            return
        model = kwargs["model"]
        # Remember the current mode so switching to eval for sampling (dropout off)
        # does not leak into whatever loop drives this callback.
        was_training = model.training
        model.eval()
        try:
            tokenized = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
            # pad_token_id goes into the merged dict so a caller override cannot
            # collide with it as a duplicate keyword argument.
            gen_kwargs = {"pad_token_id": self.tokenizer.eos_token_id, **self.generation_kwargs}
            with torch.no_grad():
                output_ids = model.generate(**tokenized, **gen_kwargs)
        finally:
            if was_training:
                model.train()
        output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"[Step {state.global_step}] Generated text:\n{output_text}")

In [None]:
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Training configuration. `max_steps` fixes the run length (it takes precedence over
# `num_train_epochs`); all cadences below are in optimizer steps.
args = TrainingArguments(
    output_dir="mygpt2",
    # Batching
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    gradient_accumulation_steps=1,
    # Optimisation schedule: warm-up then cosine decay
    max_steps=10000,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.1,
    # Logging / evaluation / checkpointing cadence
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=250,
    # NOTE(review): no `save_total_limit` is set, so presumably all ~100 checkpoints
    # (each with optimizer state) are kept on disk — confirm the disk budget.
    save_steps=100,
    # NOTE(review): requires Hugging Face Hub credentials; with save_steps=100 this
    # presumably uploads on every save — confirm the intended `hub_strategy`.
    push_to_hub=True,
)

# Reuse EOS as the pad token (GPT-2's tokenizer has none by default).
# NOTE(review): the collator masks pad ids out of the labels, so any genuine EOS tokens
# in the data would also be excluded from the loss — confirm the data contains none.
tokenizer.pad_token = tokenizer.eos_token
# mlm=False -> causal-LM batches whose labels mirror input_ids (padding ignored).
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
generate_callback = GenerateTextCallback(
    tokenizer=tokenizer,
    prompt="Once upon a time",
    device=device,
    every_n_steps=100,
)

# NOTE(review): the tokenizer is not handed to the Trainer (newer transformers accept it
# as `processing_class=`), so it is presumably not saved/pushed with checkpoints — confirm.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    data_collator=data_collator,
    callbacks=[generate_callback],
)

In [None]:
# Run the training loop configured in the previous cell; the returned training summary
# is the cell's last expression and is therefore displayed as its output.
trainer.train()