In [None]:
pip install transformers datasets tensorflow tqdm


In [None]:
import numpy as np
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, DataCollatorForLanguageModeling
from datasets import load_dataset
from tensorflow.keras.optimizers.schedules import PolynomialDecay
from tensorflow.keras.optimizers import Adam

# GPT-2 tokenizer, extended with a dedicated '[PAD]' token so that the
# padding='max_length' tokenization below has something to pad with.
# The model's embedding matrix is resized to len(tokenizer) further down.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Raw WikiText-2 corpus; its 'train' and 'validation' splits are used below.
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

# Preprocess the dataset
def preprocess_function(examples):
    """Tokenize a batch of raw text rows for ``Dataset.map(batched=True)``.

    Every entry of ``examples['text']`` is truncated or padded to exactly
    512 tokens, using the module-level GPT-2 ``tokenizer``.
    """
    texts = examples['text']
    return tokenizer(
        texts,
        max_length=512,
        padding='max_length',
        truncation=True,
    )

# Apply preprocessing.
# Blank / whitespace-only lines are dropped first. With padding='max_length'
# an empty line becomes 512 padding tokens with no text to learn from.
# Depending on how labels are built, such rows either spend whole training
# steps on padding, or (when pad positions are excluded from the loss) leave a
# batch with no counted tokens and a NaN loss that corrupts the weights.
# The raw 'text' column is not needed after tokenization, so it is removed.
encoded_dataset = (
    dataset
    .filter(lambda example: example['text'].strip() != '')
    .map(preprocess_function, batched=True, remove_columns=['text'])
)

# Causal-LM data collator (mlm=False). It builds a 'labels' field from
# input_ids with padding positions set to -100, so padding is ignored by the
# loss. return_tensors='np' is required here: Dataset.to_tf_dataset expects
# NumPy batches from collate_fn, and the collator's default ('pt') would
# produce PyTorch tensors instead.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, return_tensors='np'
)


# Convert to TensorFlow datasets
def convert_to_tf_dataset(encoded_dataset, batch_size, shuffle=True):
    """Wrap a tokenized split as a batched ``tf.data.Dataset``.

    Each element is an ``(inputs, labels)`` pair:
    - ``inputs`` is a dict holding ``'input_ids'`` and ``'attention_mask'``.
    - ``labels`` is the collator-produced ``'labels'`` tensor, rather than raw
      ``input_ids``, so padding does not contribute to the loss.

    Args:
        encoded_dataset: A tokenized ``datasets.Dataset`` split.
        batch_size: Number of examples per batch.
        shuffle: Whether to shuffle the examples. Defaults to True, matching
            the previous behaviour; pass False for evaluation.

    Returns:
        The batched ``tf.data.Dataset``.
    """
    return encoded_dataset.to_tf_dataset(
        columns=['input_ids', 'attention_mask'],
        label_cols=['labels'],
        shuffle=shuffle,
        batch_size=batch_size,
        collate_fn=data_collator,
    )


train_dataset = convert_to_tf_dataset(encoded_dataset['train'], batch_size=4)
# Evaluation order does not matter for an averaged loss; keep it deterministic.
eval_dataset = convert_to_tf_dataset(
    encoded_dataset['validation'], batch_size=4, shuffle=False
)

# Training hyperparameters.
# NOTE(review): batch_size is not used by train_dataset/eval_dataset above,
# which hard-code batch_size=4; keep the two in sync by hand.
epochs = 1
learning_rate = 2e-5
batch_size = 4

# Pretrained GPT-2 LM head. Its embedding matrix is grown to len(tokenizer)
# so the extra '[PAD]' token added to the tokenizer gets its own row.
model = TFGPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))

# The learning rate decays from `learning_rate` down to 0 across every
# training step (PolynomialDecay with its default power).
num_train_steps = epochs * len(train_dataset)
lr_schedule = PolynomialDecay(
    initial_learning_rate=learning_rate,
    decay_steps=num_train_steps,
    end_learning_rate=0,
)
optimizer = Adam(learning_rate=lr_schedule)

# Custom training loop
@tf.function
def train_step(inputs, labels):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        inputs: Dict of batched tensors from ``train_dataset``, e.g.
            ``'input_ids'`` and ``'attention_mask'``. It is unpacked as
            keyword arguments to the model.
        labels: Target token ids for the causal-LM loss. Positions labelled
            -100 are ignored by the Hugging Face LM loss.

    Returns:
        The loss tensor reported by the model for this batch.
    """
    with tf.GradientTape() as tape:
        # Unpack the feature dict into keyword arguments. The old code passed
        # the whole dict as `input_ids`, which only worked by accident.
        # training=True turns dropout on; the model's call() defaults to
        # inference mode otherwise.
        outputs = model(**inputs, labels=labels, training=True, return_dict=True)
        loss = outputs.loss
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# Training loop: one pass over train_dataset per epoch, then a validation pass.
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    for inputs, labels in train_dataset:
        loss = train_step(inputs, labels)
        print(f"Loss: {loss.numpy()}")

    # Evaluation (optional): the mean of the per-batch validation losses.
    eval_loss = 0.0
    num_batches = 0
    for inputs, labels in eval_dataset:
        # Same call convention as train_step. training=False keeps dropout off.
        outputs = model(**inputs, labels=labels, training=False, return_dict=True)
        # Reduce to a Python float so tensors are not accumulated across batches.
        # reduce_mean copes with a loss that is either a scalar or shape (1,).
        eval_loss += float(tf.reduce_mean(outputs.loss))
        num_batches += 1
    if num_batches == 0:
        # Previously this case raised ZeroDivisionError.
        print("Validation Loss: n/a (no evaluation batches)")
    else:
        avg_eval_loss = eval_loss / num_batches
        print(f"Validation Loss: {avg_eval_loss}")

# Write the fine-tuned weights and the extended tokenizer (including '[PAD]')
# into the same directory, then reload both from it. Everything below uses
# the reloaded `model` and `tokenizer`.
tokenizer.save_pretrained('./custom_gpt2')
model.save_pretrained('./custom_gpt2')

tokenizer = GPT2Tokenizer.from_pretrained('./custom_gpt2')
model = TFGPT2LMHeadModel.from_pretrained('./custom_gpt2')

# Example function to generate text
def generate_text(prompt, max_length=50, **generate_kwargs):
    """Generate a continuation of *prompt* with the module-level model.

    Args:
        prompt: Text to condition generation on.
        max_length: Maximum total sequence length in tokens, as interpreted
            by ``model.generate``.
        **generate_kwargs: Extra keyword arguments forwarded to
            ``model.generate``, e.g. ``do_sample=True`` or ``top_p=0.9``.
            Explicit values override the defaults set below.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors='tf')
    generate_kwargs.setdefault('num_return_sequences', 1)
    # The tokenizer carries a '[PAD]' token, but the model config may not
    # record a pad id. Pass it explicitly, together with the attention mask,
    # so generate() does not have to warn and guess.
    generate_kwargs.setdefault('pad_token_id', tokenizer.pad_token_id)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=max_length,
        **generate_kwargs,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Quick sanity check: continue a fixed prompt and print the result.
prompt = "The future of AI"
print(f"Generated text: {generate_text(prompt)}")

# Optional: Create a FastAPI app for inference
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class RequestBody(BaseModel):
    """JSON payload accepted by ``POST /generate``."""

    text: str  # prompt to continue


@app.post("/generate")
def generate_text_api(request_body: RequestBody):
    """Return the model's continuation of ``request_body.text``.

    Returns:
        ``{"generated_text": <decoded continuation>}``.

    Raises:
        HTTPException: 400 when ``text`` is empty. An empty prompt tokenizes
            to zero tokens, leaving ``generate`` nothing to condition on.
    """
    if not request_body.text:
        raise HTTPException(status_code=400, detail="'text' must be a non-empty string")
    # Reuse the notebook helper so the API and the interactive test share one
    # generation code path. Uses the same 50-token max_length as before.
    return {"generated_text": generate_text(request_body.text, max_length=50)}


if __name__ == '__main__':
    import uvicorn

    # NOTE(review): inside a Jupyter kernel __name__ is also '__main__', and
    # uvicorn.run() may fail there because an event loop is already running.
    # Run this as a standalone script to serve the API.
    uvicorn.run(app, host="0.0.0.0", port=8000)
