# In [None]:
# Install necessary libraries
!pip install datasets transformers torch streamlit

# Importing necessary libraries
from datasets import load_dataset
from transformers import AutoTokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, pipeline
import torch
import math
from torch.utils.data import DataLoader
import streamlit as st

# Step 1: Data Loading & Tokenization

# Load the WikiText-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
# WikiText contains many blank / whitespace-only lines. Tokenized, they are pure padding:
# they carry no training signal, and with every label ignored their loss is NaN.
dataset = dataset.filter(lambda example: example["text"].strip() != "")

# Initialize the GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token, and padding="max_length" below raises a
# ValueError unless one is set. Reusing the end-of-sequence token is the usual choice.
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of raw text lines and attach causal-LM training labels.

    Args:
        examples: A batch from ``datasets.Dataset.map(batched=True)``; ``examples["text"]``
            is a list of strings.

    Returns:
        The tokenizer's encodings (``input_ids``, ``attention_mask``) padded/truncated to
        512 tokens, plus ``labels``: a copy of ``input_ids`` with padded positions set
        to -100 so the loss ignores them.
    """
    # No return_tensors: datasets.map stores plain Python lists in Arrow anyway.
    encodings = tokenizer(
        examples["text"], truncation=True, padding="max_length", max_length=512
    )
    # GPT2LMHeadModel computes a loss only when `labels` are supplied, and it shifts them
    # by one position internally, so labels are simply the inputs with padding masked out.
    encodings["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings


# Apply the tokenizer to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Step 2: Model Selection

# Load pre-trained GPT-2 model
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Step 3: Fine-Tuning the Model

# Define the training arguments for fine-tuning
training_args = TrainingArguments(
    output_dir="./results",            # Directory to store checkpoints
    num_train_epochs=3,                # Number of training epochs
    per_device_train_batch_size=8,     # Batch size for training
    per_device_eval_batch_size=8,      # Batch size for evaluation
    logging_dir="./logs",              # Directory for storing logs
    # `evaluation_strategy` was deprecated in favour of `eval_strategy` (transformers 4.41)
    # and later removed; the unpinned `pip install transformers` above rejects the old name.
    eval_strategy="epoch",             # Evaluate model at the end of each epoch
    save_strategy="epoch",             # Save a checkpoint every epoch
)

# Define the Trainer object which handles training.
# GPT2LMHeadModel returns a loss only when batches contain `labels`, so the tokenized
# datasets are expected to provide a `labels` column.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

# Train the model
trainer.train()

# Step 4: Evaluation - Perplexity and Top-K Accuracy

# Perplexity calculation function
def compute_perplexity(model, eval_dataset, batch_size=8):
    """Return the token-level perplexity of ``model`` on ``eval_dataset``.

    Each example must provide ``input_ids`` and, optionally, ``attention_mask`` and
    ``labels`` as equal-length integer sequences. Other columns, such as ``text``, are
    ignored. When ``labels`` are absent they are derived from ``input_ids``, with
    padded positions (``attention_mask == 0``) set to -100.

    Perplexity is ``exp`` of the mean loss per predicted token. Each batch's mean loss
    is weighted by the number of tokens it actually scores, so heavily padded batches
    do not count as much as full ones. Batches with nothing to score are skipped.

    Args:
        model: A causal LM (e.g. ``GPT2LMHeadModel``) whose forward pass returns an
            object with a ``loss`` attribute when called with ``labels``.
        eval_dataset: An indexable dataset of dicts (e.g. a Hugging Face ``Dataset``).
        batch_size: Number of examples per forward pass.

    Returns:
        The perplexity as a float.

    Raises:
        ValueError: If the dataset contains no token to score.
    """
    model_input_names = ("input_ids", "attention_mask", "labels")

    def collate(examples):
        # Stack only the columns the model accepts. The default collate would also batch
        # the raw `text` strings and hand them to model(**batch).
        batch = {
            name: torch.stack([torch.as_tensor(example[name]) for example in examples])
            for name in model_input_names
            if name in examples[0]
        }
        if "labels" not in batch:
            labels = batch["input_ids"].clone()
            if "attention_mask" in batch:
                labels[batch["attention_mask"] == 0] = -100
            batch["labels"] = labels
        return batch

    model.eval()  # Set the model to evaluation mode (disables dropout)
    device = next(model.parameters()).device
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=collate)
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            # Labels are shifted by one inside the model, so position 0 is never a target.
            n_tokens = int((batch["labels"][:, 1:] != -100).sum())
            if n_tokens == 0:
                continue  # all padding: the model's mean loss would be NaN
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            outputs = model(**batch)  # Forward pass through the model
            total_loss += outputs.loss.item() * n_tokens
            total_tokens += n_tokens

    if total_tokens == 0:
        raise ValueError("eval_dataset contains no tokens to score")
    return math.exp(total_loss / total_tokens)

# Score the fine-tuned model on the validation split and report its perplexity.
perplexity = compute_perplexity(model, tokenized_datasets["validation"])
print("Perplexity:", perplexity)

# Top-K Accuracy calculation function
def top_k_accuracy(model, eval_dataset, k=5, text_tokenizer=None, max_context_tokens=1024):
    """Return the top-``k`` next-word accuracy of ``model`` over ``eval_dataset``.

    For each example whose ``text`` has at least two whitespace-separated words, the
    model reads every word except the last. A prediction counts as correct when the
    first token of that held-out last word is among the model's ``k`` highest-scoring
    next tokens. Lines with fewer than two words are skipped.

    Args:
        model: A causal LM whose forward pass accepts ``input_ids`` of shape
            (1, seq_len) and returns an object with ``logits`` of shape
            (1, seq_len, vocab_size).
        eval_dataset: An iterable of dicts with a ``text`` string field.
        k: Number of top-scoring tokens that count as a hit.
        text_tokenizer: Object with an ``encode(str) -> list[int]`` method; defaults to
            the module-level ``tokenizer``.
        max_context_tokens: Only the last this-many context tokens are fed to the model
            (GPT-2 accepts at most 1024 positions).

    Returns:
        Fraction of scored examples that were hits, or 0.0 if none could be scored.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    model.eval()
    device = next(model.parameters()).device
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for example in eval_dataset:
            words = example["text"].split()
            if len(words) < 2:
                continue  # nothing to condition on, or nothing to predict
            # GPT-2's BPE encodes a word differently with a leading space; the held-out
            # word follows other words, so encode it the way it appears mid-text.
            target_ids = text_tokenizer.encode(" " + words[-1])
            context_ids = text_tokenizer.encode(" ".join(words[:-1]))[-max_context_tokens:]
            if not target_ids or not context_ids:
                continue
            input_ids = torch.tensor([context_ids], device=device)
            next_token_logits = model(input_ids=input_ids).logits[0, -1]
            top_k_ids = torch.topk(
                next_token_logits, min(k, next_token_logits.size(-1))
            ).indices.tolist()
            if target_ids[0] in top_k_ids:
                correct_predictions += 1
            total_predictions += 1

    return correct_predictions / total_predictions if total_predictions else 0.0

# Measure top-5 next-word accuracy on the validation split and report it.
accuracy = top_k_accuracy(model, tokenized_datasets["validation"], k=5)
print("Top-5 Accuracy:", accuracy)

# Step 5: Optional Extension - Streamlit for Deployment

# Streamlit app for next-word prediction
def run_streamlit(max_new_tokens=50):
    """Render a Streamlit page that continues user-entered text with the model.

    Uses the module-level ``model`` and ``tokenizer``. The generated text, including
    the user's prompt, is written back to the page.

    Args:
        max_new_tokens: Maximum number of tokens to generate after the prompt.
    """
    st.title("Next-Word Prediction Model")

    # Input field to enter text
    input_text = st.text_input("Enter text:", "")
    if not input_text:
        return

    next_word_predictor = pipeline("text-generation", model=model, tokenizer=tokenizer)
    # `max_new_tokens` bounds only the continuation. The previous `max_length=50` also
    # counted the prompt, so prompts of 50+ tokens could generate nothing new.
    prediction = next_word_predictor(
        input_text, max_new_tokens=max_new_tokens, num_return_sequences=1
    )

    # Display the generated text prediction
    st.write(f"Predicted text: {prediction[0]['generated_text']}")

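# run_streamlit() is never called in this notebook: Streamlit apps are launched with `streamlit run <script>.py`. To deploy, move this code into a script that loads the fine-tuned model and tokenizer and calls run_streamlit().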
