# 🚀 Next Word Prediction using GPT-2
Fine-tuning GPT-2 on WikiText-2 using Hugging Face Transformers.

In [None]:
# ✅ Install necessary packages
# `!` hands the rest of the line to the shell (IPython/Jupyter only, not plain Python).
# -U upgrades transformers and datasets to the newest versions pip can resolve.
# NOTE(review): fsspec is pinned to 2023.6.0, presumably to work around an
# incompatibility between `datasets` and newer fsspec releases — confirm the pin
# is still needed, and still compatible, with the upgraded `datasets`.
!pip install -U transformers datasets fsspec==2023.6.0

In [None]:
import re

import transformers

print("Transformers version:", transformers.__version__)


def _version_tuple(version: str) -> tuple[int, ...]:
    """Parse the numeric release part of *version* into a comparable tuple.

    Any pre-/dev-release suffix is ignored ("4.41.0.dev0" -> (4, 41)), and
    trailing zero components are dropped so that "4.3" and "4.3.0" compare
    equal. Comparing tuples of ints avoids the lexicographic trap of plain
    string comparison, where "4.10.0" < "4.3.0".

    Raises:
        ValueError: if *version* does not start with a digit.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    parts = [int(part) for part in match.group().split(".")]
    # Normalize so that versions differing only by trailing ".0"s are equal.
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


# Explicit raise rather than `assert`, which is stripped under `python -O`.
if _version_tuple(transformers.__version__) < _version_tuple("4.3.0"):
    raise ImportError("Update transformers to >= 4.3.0")

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Fetch the tokenizer for the "gpt2" checkpoint and the raw WikiText-2 corpus
# (configuration "wikitext-2-raw-v1") that will be fine-tuned on.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset(path="wikitext", name="wikitext-2-raw-v1")

def tokenize_function(examples):
    """Tokenize a batch of raw WikiText lines with the GPT-2 tokenizer.

    Args:
        examples: a batch from ``Dataset.map(batched=True)``; ``examples["text"]``
            is the list of raw text lines in the batch.

    Returns:
        The tokenizer's encoding of the batch (including ``input_ids``).
    """
    # No `return_special_tokens_mask`: that column is only consumed by
    # masked-LM collators. For causal-LM training it would just be
    # concatenated and chunked alongside the real inputs for nothing.
    return tokenizer(examples["text"])

# Tokenize every split batch-wise; the raw "text" column is dropped since
# only the token columns produced by the tokenizer are needed downstream.
tokenized = dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=["text"],
)

In [None]:
# Group texts into blocks
from itertools import chain

# Number of tokens in each training example produced by group_texts.
block_size = 128


def group_texts(examples):
    """Concatenate a batch of tokenized texts and re-split it into fixed-size blocks.

    Args:
        examples: a batch as passed by ``Dataset.map(batched=True)``: a mapping
            from column name to a list of per-example token lists.
            Must contain ``"input_ids"``.

    Returns:
        A dict with the same columns, where each column's lists are joined end
        to end and cut into chunks of exactly ``block_size`` items. The
        trailing remainder shorter than ``block_size`` is dropped, so a batch
        with fewer than ``block_size`` tokens yields no blocks. A ``"labels"``
        column is added as a copy of the ``"input_ids"`` blocks.
    """
    # chain.from_iterable is linear in the total token count, whereas
    # sum(lists, []) re-copies the growing accumulator on every addition
    # (quadratic in the batch size).
    concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    # Round down to a whole number of blocks so every block has equal length.
    total_len = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [tokens[i:i + block_size] for i in range(0, total_len, block_size)]
        for k, tokens in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Re-chunk every tokenized split into fixed-length blocks of block_size tokens.
lm_dataset = tokenized.map(
    function=group_texts,
    batched=True,
)

In [None]:
from transformers import GPT2LMHeadModel, TrainingArguments, Trainer, DataCollatorForLanguageModeling

# Load model and prepare trainer
model = GPT2LMHeadModel.from_pretrained("gpt2")

# The GPT-2 tokenizer ships without a pad token, but DataCollatorForLanguageModeling
# batches examples through tokenizer.pad(), which raises a ValueError ("Asking
# to pad but the tokenizer does not have a padding token") when none is set.
# Reusing EOS as the pad token is the standard workaround; the guard keeps any
# pad token that is already configured.
# NOTE(review): with mlm=False the collator ignores label positions equal to
# pad_token_id, so any EOS tokens in the training text would be excluded from
# the loss — confirm that is acceptable for this data.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

training_args = TrainingArguments(
    output_dir="./gpt2-wikitext2",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=10,
    save_strategy="no",  # no intermediate checkpoints are written
    report_to="none",  # disable experiment-tracking integrations
)

# mlm=False selects causal language modeling (labels derived from input_ids).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    data_collator=data_collator,
)

# Train the model
trainer.train()

In [None]:
# Evaluate perplexity
import math

# Perplexity is exp(mean cross-entropy loss) over the trainer's eval_dataset.
eval_results = trainer.evaluate()
try:
    perplexity = math.exp(eval_results["eval_loss"])
except OverflowError:
    # math.exp raises rather than returning inf for arguments above ~709.
    perplexity = float("inf")
print(f"Perplexity: {perplexity:.2f}")