In [41]:
!pip install datasets

import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk.tokenize import word_tokenize
import numpy as np




[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


# Model loading and tokenizer / Data Loading and Preprocessing

In [42]:

random.seed(42)
torch.manual_seed(42)

def load_model_and_tokenizer():
    """Initialize the model and tokenizer with special tokens."""
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    # Add special tokens
    special_tokens = {'additional_special_tokens': ['<len>', '<text>']}
    num_added_tokens = tokenizer.add_special_tokens(special_tokens)

    # Important: resize model embeddings to account for new tokens
    model.resize_token_embeddings(len(tokenizer))

    # Set padding token
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

# Initialize the model and tokenizer
model, tokenizer = load_model_and_tokenizer()


def load_training_data(percentage):
    """Load a subset of OpenWebText dataset.
    Start with a small percentage (0.1%) for testing."""

    dataset = load_dataset("Bingsu/openwebtext_20p", split="train") # OpenWebText dataset
    num_samples = len(dataset)
    indices = random.sample(range(num_samples), int(percentage * num_samples))
    subset = dataset.select(indices)
    return subset['text']

# Example usage:
texts = load_training_data(0.003)
print(f"Loaded {len(texts)} training examples.")

def process_sentence(sentence):
    tokens = word_tokenize(sentence)
    num_words = len(tokens)
    sentence = sentence.strip().replace("\n", " ")
    return f"<len> {num_words} <text> {sentence}"

# process all sentences
processed_texts = [process_sentence(text) for text in texts]
print("Sample processed text:")
print(processed_texts[0])

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

Loaded 99503 training examples.
Sample processed text:
<len> 65 <text> Challenging the state forces, Anu said they were able to go to the people and talk to them, even in places where government has deployed 300 to 400 security personnel. "This definitely shows people's support for us,” she said, exuding confidence that they will succeed in broadening their base among the people," Anu said.


# Setup Training with HuggingFace Trainer / Fine-tuning the Model

In [43]:
class LengthControlledDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # tokenize the processed text
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        # squeeze to remove extra batch dimension
        item = {key: tensor.squeeze() for key, tensor in encoding.items()}
        return item

# create the dataset and DataLoader
train_dataset = LengthControlledDataset(processed_texts, tokenizer)
print(f"Dataset size: {len(train_dataset)} examples.")

# data collator which handles language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# training arguments. small number of epochs for initial testing
training_args = TrainingArguments(
    output_dir="./gpt2_length_control",
    overwrite_output_dir=True,
    num_train_epochs=5, # epochs (5 for 100,000 samples)
    per_device_train_batch_size=8, # small batch size (4 to 8)
    gradient_accumulation_steps=4,  # simulate larger batch size
    save_steps=5000, # frequency of saving checkpoints
    save_total_limit=2, # keep only the latest 2 checkpoints
    logging_steps=100, # log training loss every 100 steps
    learning_rate=3e-5, # learning rate for stable training  LOWERED
    warmup_steps=500, # warmup steps for learning rate scheduler
    report_to="tensorboard", # tensorboard" for visualization
    max_grad_norm=1.0,  # Gradient clipping
    fp16=True  # mixed precision training

)

# initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset
)

# start training
trainer.train()

# save the final model
trainer.save_model("./gpt2_length_control_final")

Dataset size: 99503 examples.


Step,Training Loss
100,4.3707
200,3.9549
300,3.8227
400,3.7892
500,3.653
600,3.5788
700,3.5796
800,3.55
900,3.5363
1000,3.5416


# Length-Controlled Generation

In [54]:
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
nltk.download('punkt')

def generate_text(prompt, target_word_count, tokenizer, model):
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # extra tokens for high-quality text
    max_new_tokens = target_word_count + 20

    # sampling parameters
    outputs = model.generate(
        prompt_ids,
        do_sample=True,
        temperature=0.7, # temperature for balance between creativity and coherence
        top_k=50, # limit to top-50 tokens
        top_p=0.95, # nucleus sampling to keep cumulative probability at 95%
        repetition_penalty=1.2, # Penalize repetition to improve diversity
        max_new_tokens=max_new_tokens,
        eos_token_id=None, # disable early EOS stopping to generate enough text
        pad_token_id=tokenizer.eos_token_id
    )

    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # remove the prompt if present
    if generated.startswith(prompt):
        generated = generated[len(prompt):].strip()

    words = word_tokenize(generated)
    # if generation is too short, regenerate (up to 3 tries)
    tries = 0
    while len(words) < target_word_count and tries < 3:
        outputs = model.generate(
            prompt_ids,
            do_sample=True,
            temperature=0.3,
            top_k=30,
            top_p=0.95,
            repetition_penalty=1.2,
            max_new_tokens=max_new_tokens,
            eos_token_id=None,
            pad_token_id=tokenizer.eos_token_id
        )
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if generated.startswith(prompt):
            generated = generated[len(prompt):].strip()
        words = word_tokenize(generated)
        tries += 1

    sentences = sent_tokenize(generated)
    # rebuild text sentence by sentence until we reach or exceed target length
    new_text = ""
    new_words = []
    for sentence in sentences:
        sentence_words = word_tokenize(sentence)
        if len(new_words) + len(sentence_words) <= target_word_count:
            new_words.extend(sentence_words)
            new_text += sentence + " "
        else:
            remaining = target_word_count - len(new_words)
            new_words.extend(sentence_words[:remaining])
            new_text += " ".join(sentence_words[:remaining])
            break
    if len(new_words) != target_word_count:
        new_text = " ".join(words[:target_word_count])

    return new_text.strip()


target_lengths = [5, 10, 15, 20, 25, 30]
test_prompts = [f"<len> {length} <text>" for length in target_lengths]

errors = []
for target, prompt in zip(target_lengths, test_prompts):
    generated_text = generate_text(prompt, target, tokenizer, model)
    generated_word_count = len(word_tokenize(generated_text))
    error = abs(target - generated_word_count)
    errors.append(error)
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated_text}")
    print(f"Target words: {target} | Generated words: {generated_word_count} | Error: {error}")

avg_error = np.mean(errors)
print(f"\nAverage length control error: {avg_error:.2f}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



Prompt: <len> 5 <text>
Generated: 5 How do you feel
Target words: 5 | Generated words: 5 | Error: 0

Prompt: <len> 10 <text>
Generated: 10 It should be noted that this is a separate
Target words: 10 | Generated words: 10 | Error: 0

Prompt: <len> 15 <text>
Generated: 15 The government is also working on a $ 2.4-billion overhaul of the criminal justice
Target words: 15 | Generated words: 15 | Error: 0

Prompt: <len> 20 <text>
Generated: 20  The other two were a total of $5 million each. They are the victims who have been
Target words: 20 | Generated words: 20 | Error: 0

Prompt: <len> 25 <text>
Generated: 25 The latest figures show that in April , the number of people using cannabis fell by 3.7 per cent from 2015 to 2016 —
Target words: 25 | Generated words: 25 | Error: 0

Prompt: <len> 30 <text>
Generated: 30 The company has an annual profit of $ 1.6 billion , and it expects to make $ 250 million in revenue this year from its operations at the North
Target words: 30 | Generated words: