# Text Completions: GPT-2 vs TinyGPT

Load either HuggingFace GPT-2 or a TinyGPT checkpoint and run text completions.
Supports plain text prompts and conversation-style prompts with special tokens.

In [2]:
import sys, os
sys.path.insert(0, os.path.join(os.getcwd(), '..'))

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from core.models import TinyGPT

# Pick the best available accelerator: CUDA first, then Intel XPU, and finally
# plain CPU. The previous unconditional 'xpu' fallback produced a device that
# does not exist on CPU-only machines, so every later `.to(device)` failed.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
    # `torch.xpu` is missing in older PyTorch builds, hence the hasattr guard.
    device = torch.device('xpu')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

# Sampling below uses torch.multinomial; seed once so that a
# Restart Kernel -> Run All reproduces the same completions.
SEED = 0
torch.manual_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


Device: xpu


## 1. Configuration

Set `MODEL_TYPE` to `'gpt2'` or `'tinygpt'`. For TinyGPT, point `CHECKPOINT_PATH` to a `.pt` file.

In [3]:
# Selects which branch of the loading cell runs; any other value raises ValueError there.
MODEL_TYPE = 'tinygpt'  # 'gpt2' or 'tinygpt'
# Checkpoint file read with torch.load; the path is relative to the notebook's working directory.
CHECKPOINT_PATH = '../logs/pu94vo4r/checkpoints/checkpoint_60000.pt'  # only used for tinygpt

## 2. Load model & tokenizer

In [4]:
# Tokenizer (shared by both model types): GPT-2 BPE plus the chat special tokens.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
special_tokens = {
    'bos_token': '<|beginoftext|>',
    'pad_token': '<|pad|>',
    'additional_special_tokens': ['<|user|>', '<|assistant|>', '<|system|>']
}
tokenizer.add_special_tokens(special_tokens)
print(f'Vocab: {len(tokenizer)}, BOS={tokenizer.bos_token_id}, EOS={tokenizer.eos_token_id}')

if MODEL_TYPE == 'gpt2':
    model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
    # Grow the embedding matrix so the added special-token ids are valid inputs.
    model.resize_token_embeddings(len(tokenizer))
    model.eval()
    print(f'Loaded GPT-2 ({sum(p.numel() for p in model.parameters())/1e6:.0f}M params)')

elif MODEL_TYPE == 'tinygpt':
    # Architecture hyper-parameters; they must match the checkpoint being loaded.
    n_layers = 20
    dim = n_layers * 64                     # 1280
    n_heads = max(1, (dim + 127) // 128)    # ceil(dim / 128) -> 10
    model = TinyGPT(vocab_size=50262, dim=dim, n_layers=n_layers, n_heads=n_heads, max_seq_len=2048)

    # Load onto the CPU first: only `model_state_dict` is used, and `ckpt` stays
    # alive as a notebook global, so mapping it straight to the accelerator would
    # pin everything else stored in the file in device memory as well.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects --
    # only load checkpoints you produced yourself or otherwise trust.
    ckpt = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])

    # Resize embeddings for special tokens (only when the tokenizer has more
    # entries than the model's embedding table). The rows for the new tokens
    # keep their fresh default initialisation; existing rows are copied over.
    old_vocab = model.tok_emb.num_embeddings
    new_vocab = len(tokenizer)
    if old_vocab < new_vocab:
        old_tok = model.tok_emb.weight.data.clone()
        old_head = model.head.weight.data.clone()
        model.tok_emb = nn.Embedding(new_vocab, dim)
        model.head = nn.Linear(dim, new_vocab, bias=False)
        model.tok_emb.weight.data[:old_vocab] = old_tok
        model.head.weight.data[:old_vocab] = old_head
        print(f'Resized embeddings: {old_vocab} -> {new_vocab}')

    model = model.to(device)
    model.eval()
    print(f'Loaded TinyGPT from {CHECKPOINT_PATH} ({sum(p.numel() for p in model.parameters())/1e6:.0f}M params)')

else:
    raise ValueError(f'Unknown MODEL_TYPE: {MODEL_TYPE}')

Vocab: 50262, BOS=50257, EOS=50256
Loaded TinyGPT from ../logs/pu94vo4r/checkpoints/checkpoint_60000.pt (525M params)


## 3. Generation function

Generation stops at EOS (`<|endoftext|>`) or after `max_new_tokens` tokens, whichever comes first, so completions don't run on forever.

In [19]:
@torch.no_grad()
def generate(prompt_ids, max_new_tokens=200, temperature=1.0, top_k=40, stop_at_eos=True):
    """Generate tokens from prompt_ids tensor. Stops at EOS if stop_at_eos=True."""
    idx = prompt_ids.to(device)
    eos_id = tokenizer.eos_token_id if stop_at_eos else None
    generated = []

    for _ in range(max_new_tokens):
        if MODEL_TYPE == 'gpt2':
            logits = model(idx).logits[:, -1, :] / temperature
        else:
            idx_cond = idx[:, -2048:]
            logits = model(idx_cond)[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)

        if eos_id is not None and next_id.item() == eos_id:
            break

        idx = torch.cat((idx, next_id), dim=1)
        generated.append(next_id.item())

    return generated

print('generate() ready')

generate() ready


## 4. Plain text completions

In [22]:
def complete(prompt, max_new_tokens=50, temperature=0.8, top_k=40, stop_at_eos=True):
    """Run a plain-text completion for `prompt`, print a short report, return the text.

    The prompt is tokenized with the shared tokenizer, passed to `generate` with
    the given sampling settings, and the newly generated ids are decoded back to
    a string. Only the completion (not the prompt) is returned.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt')
    sampling = dict(max_new_tokens=max_new_tokens, temperature=temperature,
                    top_k=top_k, stop_at_eos=stop_at_eos)
    new_tokens = generate(prompt_ids, **sampling)
    text = tokenizer.decode(new_tokens)

    report = [
        f'PROMPT:     {prompt}',
        f'COMPLETION: {text}',
        f'({len(new_tokens)} tokens, stop_at_eos={stop_at_eos})',
        '-' * 80,
    ]
    print('\n'.join(report))
    return text

# Demo prompts for plain-text completion; each one is completed with the
# default settings of `complete` and the result is printed.
prompts = [
    'The capital of France is',
    'In machine learning, a neural network',
    'The most important thing about science is',
    'To make a good cup of coffee',
    'The best country in the world is',
    'A doctor is best for'
]

# The return value of complete() is ignored here; only the printed report matters.
for p in prompts:
    complete(p)

PROMPT:     The capital of France is
COMPLETION:  Paris. It is the home of the famous Louvre Museum. Paris is also the capital of the French Republic.
The capital of France is also Paris, the capital of the French Republic. Paris is also the capital of the United Kingdom. It
(50 tokens, stop_at_eos=True)
--------------------------------------------------------------------------------
PROMPT:     In machine learning, a neural network
COMPLETION:  is a neural network used to learn patterns using data. They are widely used across a range of industries, including healthcare, finance, and education.
One of the most important advantages of neural networks is that they are able to learn a lot more detail
(50 tokens, stop_at_eos=True)
--------------------------------------------------------------------------------
PROMPT:     The most important thing about science is
COMPLETION:  it is there is no need for science, it does not need to be. Science can be learned from people all over the world, 

## 5. Conversation completions

Use special tokens to prompt the model in chat format:
`<|beginoftext|><|system|>...<|user|>...<|assistant|>`

The model generates the assistant response and stops at `<|endoftext|>`.

In [7]:
def chat_complete(user_message, system_message=None, max_new_tokens=300, temperature=0.8, top_k=40):
    """Generate one assistant reply for a single-turn chat prompt.

    The prompt is assembled at token level as
    BOS [<|system|> system_message] <|user|> user_message <|assistant|>
    and handed to `generate`, which always stops at EOS here. The exchange is
    printed and the decoded assistant reply is returned.
    """
    def special(token):
        return tokenizer.convert_tokens_to_ids(token)

    def plain(text):
        return tokenizer.encode(text, add_special_tokens=False)

    # The system section is optional; an empty/None message leaves it out entirely.
    system_part = [special('<|system|>'), *plain(system_message)] if system_message else []
    prompt_tokens = [
        tokenizer.bos_token_id,
        *system_part,
        special('<|user|>'),
        *plain(user_message),
        special('<|assistant|>'),
    ]

    reply_ids = generate(torch.tensor([prompt_tokens]), max_new_tokens=max_new_tokens,
                         temperature=temperature, top_k=top_k, stop_at_eos=True)
    reply = tokenizer.decode(reply_ids)

    report = [f'USER:      {user_message}']
    if system_message:
        report.append(f'SYSTEM:    {system_message}')
    report += [f'ASSISTANT: {reply}', f'({len(reply_ids)} tokens)', '-' * 80]
    print('\n'.join(report))
    return reply

# Single-turn demo with a system prompt. As the last expression of the cell,
# the returned reply string is also shown as the cell's output value.
chat_complete('What is 2 + 2?', system_message='You are a helpful math tutor.')

USER:      What is 2 + 2?
SYSTEM:    You are a helpful math tutor.
ASSISTANT: We can solve this equation by combining like terms: 2 + 2 = 2.
(16 tokens)
--------------------------------------------------------------------------------


'We can solve this equation by combining like terms: 2 + 2 = 2.'

## 6. Multi-turn conversation

Each turn appends the generated assistant response to context before the next user message.

In [8]:
def multi_turn(turns, system_message=None, max_new_tokens=300, temperature=0.8, top_k=40):
    """Run a multi-turn chat where `turns` is a list of user messages.

    A single growing token context is kept: it starts with BOS (plus an optional
    <|system|> section), and for every user message the <|user|> ... <|assistant|>
    frame is appended, a reply is generated, and the reply followed by EOS is
    appended so the next turn sees the whole history. Each turn is printed;
    nothing is returned.
    """
    def marker(name):
        return tokenizer.convert_tokens_to_ids(name)

    def plain(text):
        return tokenizer.encode(text, add_special_tokens=False)

    context = [tokenizer.bos_token_id]
    if system_message:
        context += [marker('<|system|>'), *plain(system_message)]

    for turn_no, question in enumerate(turns, start=1):
        context += [marker('<|user|>'), *plain(question), marker('<|assistant|>')]

        reply_ids = generate(torch.tensor([context]), max_new_tokens=max_new_tokens,
                             temperature=temperature, top_k=top_k, stop_at_eos=True)
        reply = tokenizer.decode(reply_ids)

        # generate() drops the EOS token, so add it back to close this turn.
        context += [*reply_ids, tokenizer.eos_token_id]

        print(f'[Turn {turn_no}]')
        print(f'  USER:      {question}')
        print(f'  ASSISTANT: {reply}')
        print()

    print(f'Total context: {len(context)} tokens')
    print('-' * 80)

# Three-turn demo without a system message, using multi_turn's default sampling
# settings. NOTE(review): the saved output of this cell is a KeyboardInterrupt,
# i.e. the recorded run was stopped by hand -- presumably because it was slow;
# consider passing a smaller max_new_tokens when re-running.
multi_turn([
    'What is photosynthesis?',
    'Why is it important for life on Earth?',
    'Can it happen without sunlight?',
])

KeyboardInterrupt: 

## 7. Temperature comparison

In [None]:
def compare_temps(prompt, temps=(0.3, 0.7, 1.0), max_new_tokens=100):
    """Print one sampled completion of `prompt` per temperature in `temps`.

    Args:
        prompt: Plain-text prompt to complete.
        temps: Iterable of sampling temperatures, tried in order. The default is
            an immutable tuple so it cannot be modified between calls.
        max_new_tokens: Generation budget for each temperature.

    Every completion uses top_k=40, stops at EOS, and is truncated to its first
    200 characters for display. Nothing is returned.
    """
    print(f'PROMPT: {prompt}\n')

    # The encoded prompt is the same for every temperature, so tokenize once;
    # generate() does not modify the tensor it is given.
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    for t in temps:
        generated = generate(input_ids, max_new_tokens=max_new_tokens,
                             temperature=t, top_k=40, stop_at_eos=True)
        text = tokenizer.decode(generated)[:200]
        print(f'  T={t}: {text}')
    print('-' * 80)

# Two demo prompts, each sampled at the default temperatures of compare_temps.
compare_temps('The meaning of life is')
compare_temps('In the year 2050')

PROMPT: The meaning of life is

  T=0.3:  to be found in the life of the human being. The life of the human being is the life of the universe. Life is the life of the universe. Life is the life of the universe. Life is the life of the univer
  T=0.7:  not known.
The word "life" is used to describe something in all its complexity. The word "life" is an abstract idea that has been used in various cultures. It is also used in the context of the conce
  T=1.0:  more important as it relates to our own lives and is in harmony with the world around us.
If we look at each other in the mirror, we look at ourselves and look at each other’s lives. Without self dis
--------------------------------------------------------------------------------
PROMPT: In the year 2050

  T=0.3: , the world’s population will increase by more than 7 billion people, and the world’s population will increase by more than 9 billion people.
The number of people in the world is expected to rise by m
  T=0.7: , the world’

## 8. Greedy decoding (deterministic)

In [None]:
def greedy_complete(prompt, max_new_tokens=50):
    """Complete `prompt` with (near-)greedy decoding and print the result.

    Greedy behaviour is obtained by calling `generate` with top_k=1, which leaves
    only the highest-scoring token to sample from, together with a very low
    temperature (0.01). Generation stops at EOS. Nothing is returned.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt')
    greedy_settings = dict(max_new_tokens=max_new_tokens, temperature=0.01,
                           top_k=1, stop_at_eos=True)
    completion = tokenizer.decode(generate(prompt_ids, **greedy_settings))

    print('\n'.join([
        f'PROMPT: {prompt}',
        f'GREEDY: {completion}',
        '-' * 80,
    ]))

# Factual-style prompts decoded greedily with the default 50-token budget.
greedy_complete('The capital of France is')
greedy_complete('2 + 2 =')
greedy_complete('The largest planet in our solar system is')

PROMPT: The capital of France is
GREEDY:  Paris. Paris is the capital of France. Paris is the capital of France. Paris is the capital of France. Paris is the capital of France. Paris is the capital of France. Paris is the capital of France. Paris is the capital of France
--------------------------------------------------------------------------------
PROMPT: 2 + 2 =
GREEDY:  4
- 2 + 2 = 6
- 2 + 2 = 8
- 2 + 2 = 9
- 2 + 2 = 10
- 2 + 2 = 11
- 2 + 2 = 12
- 2 + 2 = 13
--------------------------------------------------------------------------------
PROMPT: The largest planet in our solar system is
GREEDY:  Mercury. Mercury is the second largest planet in our solar system after Earth. Mercury is the second largest planet in our solar system after Earth. Mercury is the second largest planet in our solar system after Earth. Mercury is the second largest planet in our solar
--------------------------------------------------------------------------------
