### Importing libraries

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import load_dataset
import torch

### Check GPU

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


### Load the first 50% of the WikiText-2 training split

In [None]:
# Fetch only the first half of the raw WikiText-2 training split.
print("Loading WikiText-2...")
dataset = load_dataset(
    path="wikitext",
    name="wikitext-2-raw-v1",
    split="train[:50%]",
)

### Load GPT-2 tokenizer and model

In [None]:

model_name = "gpt2"

# Load the pretrained tokenizer and reuse its end-of-sequence token as the
# padding token, which avoids the pad-token error when padding is requested.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Load the pretrained language-model weights, then move them to the selected device.
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)

### Tokenize dataset

In [None]:

def tokenize_function(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook's tokenizer.

    Every sequence is truncated and padded to exactly 64 tokens.

    Args:
        examples: Mapping with a ``"text"`` entry. The notebook calls this via
            ``dataset.map(..., batched=True)``, so it is presumably a list of
            strings per call.

    Returns:
        Whatever the tokenizer call returns for those texts.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        max_length=64,
        truncation=True,
        padding="max_length",
    )
    return encoded

print("🔄 Tokenizing dataset...")
# Drop blank rows before tokenizing. Padded to max_length, a blank row becomes a
# sequence made only of pad tokens; since pad == EOS in this notebook, the
# language-modeling collator ignores every label in such a row. That wastes
# compute and leaves the loss undefined (NaN) if an entire batch is blank.
tokenized_dataset = (
    dataset
    .filter(lambda example: example["text"].strip() != "")
    .map(tokenize_function, batched=True, remove_columns=["text"])
)

### Data collator

In [None]:

# Batch collator for language modeling with masked-LM mode switched off (mlm=False).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

### Training arguments for GPU

In [None]:

# Training configuration: one epoch, batches of 8 per device, a checkpoint every
# 500 steps and a log line every 100 steps; external experiment reporting is off.
training_args = TrainingArguments(
    output_dir="./gpt2-finetuned-wikitext2",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    save_steps=500,
    logging_steps=100,
    # Mixed precision needs a CUDA device. The notebook falls back to CPU when no
    # GPU is present, and an unconditional fp16=True would make TrainingArguments
    # raise there, so enable it only when CUDA is actually available.
    fp16=torch.cuda.is_available(),
    overwrite_output_dir=True,
    report_to="none",
)


### Training model

In [None]:

# Wire the model, the training configuration, the tokenized data and the
# collator into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

# Run the fine-tuning pass.
print("Training GPT-2 (1 epoch, 50% WikiText-2, GPU)...")
trainer.train()

# Write the model and the tokenizer to the same directory so both can be
# reloaded from that single path later.
print("Saving model...")
model.save_pretrained(save_directory="./gpt2-finetuned-wikitext2")
tokenizer.save_pretrained(save_directory="./gpt2-finetuned-wikitext2")




### Taking prompt from the user

In [None]:
from transformers import pipeline
import torch

# The pipeline takes a device index (0 = first GPU, -1 = CPU). Keep it under its
# own name: the "Check GPU" cell stores a torch.device in `device`, and rebinding
# that name to an int here would break the model-loading cell if it is re-run.
pipeline_device = 0 if torch.cuda.is_available() else -1

# Load the fine-tuned model and tokenizer back from the directory they were saved to.
generator = pipeline(
    "text-generation",
    model="./gpt2-finetuned-wikitext2",
    tokenizer="./gpt2-finetuned-wikitext2",
    device=pipeline_device,
)

# Interactive loop: read a prompt, print one generated continuation, repeat
# until the user types "exit".
while True:
    try:
        prompt = input("Enter your prompt (or type 'exit'): ")
    except (EOFError, KeyboardInterrupt):
        # stdin was closed or the cell was interrupted: leave the loop quietly
        # instead of ending in a traceback.
        break
    if prompt.strip().lower() == "exit":
        break
    if not prompt.strip():
        # Nothing to condition on; ask again rather than generating from a blank prompt.
        continue
    # NOTE(review): max_length counts the prompt tokens too, so a prompt near or
    # beyond 100 tokens leaves little or no room for generated text.
    output = generator(prompt, max_length=100, num_return_sequences=1)
    print("\n GPT-2 Output:\n")
    print(output[0]['generated_text'])
    print("=" * 80)
