In [None]:
from datasets import load_dataset  # type: ignore
import random
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch  # type: ignore
import torch.nn as nn  # type: ignore

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_prompts_and_references(train_pct: int = 10, test_pct: str = 2):
    """
    Load prompts and references from OpenWebText.
    Split into training and test sets.
    """
    tot_pct = train_pct + test_pct
    assert tot_pct <= 100, "Train + test percent > 100"

    print("Loading datasets...")
    # openwebtext = load_dataset(
    #     "openwebtext", split=f"train[:{tot_pct}%]", trust_remote_code=True
    # )
    openwebtext = load_dataset("stas/openwebtext-10k")

    texts = [sample["text"] for sample in openwebtext["train"] if "text" in sample]

    print(f"Total samples: {len(texts)}")
    random.shuffle(texts)

    prompts = []
    references = []
    for text in texts:
        sentences = text.split(".")
        if len(sentences) >= 2:
            prompts.append(sentences[0].strip() + ".")
            references.append(". ".join(sentences[:2]).strip() + ".")

    # Split.
    train_size = int((train_pct / tot_pct) * len(texts))

    train_prompts = prompts[:train_size]
    train_references = references[:train_size]
    test_prompts = prompts[train_size:]
    test_references = references[train_size:]

    print(
        f"Loaded {len(train_prompts)} training samples and {len(test_prompts)} test samples."
    )
    return (train_prompts, train_references), (test_prompts, test_references)

In [4]:
train, test = load_prompts_and_references()
train_prompts, train_references = train

Loading datasets...
Total samples: 10000
Loaded 8333 training samples and 1657 test samples.


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
def generate_text(
    prompt,
    top_p: float = 0.9,
    temperature: float = 1.0,
):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=50,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [8]:
def sample_with_policy(prompt, temperature):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_length=50,
        do_sample=True,
        top_p=0.9,
        temperature=temperature.item(),
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
class SamplingPolicy(nn.Module):
    """Learn a temperature adjustment based on prompt."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(768, 1)  # 768 = GPT2 hidden size.

    def forward(self, prompt_embedding):
        # Temperature in (0, 2).
        return torch.sigmoid(self.linear(prompt_embedding)) * 2.0

policy = SamplingPolicy().to(device)

In [11]:
prompt = train_prompts[0]

# Get prompt embedding.
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    # Mean pooling.
    prompt_emb = model.transformer.wte(inputs.input_ids).mean(dim=1)

# Get temperature.
temperature = policy(prompt_emb)

# Sample.
sample = sample_with_policy(prompt, temperature)

KeyboardInterrupt: 

In [None]:
print(sample)