In [None]:
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# Load model
# The checkpoint name is defined once and reused for both the tokenizer and
# the model, so changing `model_name` switches both together and they can
# never come from two different checkpoints.
print("Loading AI model...")
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Loading AI model...


Loading weights:   0%|          | 0/76 [00:00<?, ?it/s]

GPT2LMHeadModel LOAD REPORT from: distilgpt2
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
transformer.h.{0, 1, 2, 3, 4, 5}.attn.bias | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [None]:
# 2. DEFINE THE DYNAMIC CONTEXT
# Change 'user_topic' to anything you want!
# The topic is interpolated twice into the prompt below: once after the
# second "Topic:" label and once as the opening words of the poem.
user_topic = "The silver moon"

# One-shot prompt: an instruction, one worked example (a topic followed by a
# 4-line poem), then the user's topic. The string ends with the topic itself
# and no trailing newline, so whatever text is appended next continues the
# first line of the poem.
context = f"""Instruction: Write a short, creative 4-line poem.
The poem must begin with the provided topic.

Topic: A summer garden
Poem:
A summer garden blooms in light,
Dancing in the morning sun,
Colors blooming soft and bright,
Day has only just begun.

Topic: {user_topic}
Poem:
{user_topic}"""

# 3. SET PARAMETERS
# Running state of the generation loop: the text produced so far (starting
# from the prompt) and the number of poem lines completed.
current_text = context
line_count = 0

# Sampling configuration used by the generation loop.
max_tokens = 250          # upper bound on the number of sampling steps
temperature = 0.8         # logits are divided by this value before sampling
top_p = 0.9               # nucleus threshold on the cumulative probability
repetition_penalty = 1.2  # shrinks the logits of tokens already in the text
top_k = 50                # number of highest-probability candidates sampled from

# Find special character IDs
# add_special_tokens=False guarantees that index 0 is the newline token itself
# even for tokenizers that prepend a BOS/CLS token when encoding.
newline_token_id = tokenizer.encode("\n", add_special_tokens=False)[0]
# Each mark is encoded on its own: encoding ". ! ?" as one string would give
# the space-prefixed variants of "!" and "?", which are different token ids
# from the bare marks that directly follow a word.
punctuation_ids = [
    tokenizer.encode(mark, add_special_tokens=False)[0] for mark in ".!?"
]
eos_token_id = tokenizer.eos_token_id

In [None]:
print("\n--- Generating your 4-line poem ---\n")

# Start from a clean slate so that re-running this cell writes a new poem
# instead of continuing (or instantly finishing) the previous run.
current_text = context
line_count = 0

# Line-shaping rules for the poem.
target_lines = 4         # stop as soon as this many lines are complete
max_words_per_line = 7   # longer lines get a strong push towards a newline
min_chars_per_line = 15  # a line may not end before it is this long
newline_boost = 20.0     # logit bonus that makes a newline almost certain

# Seed the sampler so Restart & Run All reproduces the same poem.
seed = 42  # set to None to get a different poem on every run
if seed is not None:
    torch.manual_seed(seed)

# Scan the vocabulary once for tokens that must never be sampled:
#   * every token containing a newline other than the plain "\n" token
#     (a token such as "\n\n" would count as one line but print a blank one);
#   * the end-of-text token, which would otherwise leak into the poem while
#     generation carries on after it.
# This is a one-off cost per cell run (one decode call per vocabulary entry).
blocked_ids = [
    token_id
    for token_id in range(len(tokenizer))
    if token_id != newline_token_id and "\n" in tokenizer.decode([token_id])
]
if eos_token_id is not None:
    blocked_ids.append(eos_token_id)
always_blocked = torch.tensor(blocked_ids, dtype=torch.long)

# Tokenize the prompt once; sampled ids are appended to this tensor rather
# than re-encoding the growing text on every step.
input_ids = tokenizer.encode(context, return_tensors="pt")
generated_ids = []

# Every iteration appends exactly one token, so max_tokens is a true budget.
for _ in range(max_tokens):
    with torch.no_grad():
        # Logits for the next position; cloned so the in-place edits below
        # do not modify the model's output tensor.
        next_token_logits = model(input_ids).logits[0, -1, :].clone()

    # Repetition penalty over every token already present (prompt included):
    # positive logits are divided by the penalty, the rest are multiplied,
    # so both cases become less likely.
    seen_ids = torch.unique(input_ids[0])
    seen_logits = next_token_logits[seen_ids]
    next_token_logits[seen_ids] = torch.where(
        seen_logits > 0,
        seen_logits / repetition_penalty,
        seen_logits * repetition_penalty,
    )

    # Nudge an over-long line towards ending.
    current_line = current_text.split("\n")[-1]
    if len(current_line.split()) > max_words_per_line:
        next_token_logits[newline_token_id] += newline_boost

    # Hard constraints are applied as masks *before* sampling, so no step is
    # ever thrown away on a token that would be rejected afterwards.
    next_token_logits[always_blocked] = -float("inf")
    if len(current_line) < min_chars_per_line:
        # Also covers the "previous character was a newline" case: the
        # current line is then empty.
        next_token_logits[newline_token_id] = -float("inf")

    next_token_logits = next_token_logits / temperature

    # Top-p (nucleus) filtering: sort descending, accumulate probabilities
    # and drop everything after the first token that crosses the threshold.
    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right by one so the token that crosses the threshold is kept;
    # the most likely token is therefore always kept as well.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    next_token_logits[sorted_indices[sorted_indices_to_remove]] = -float("inf")

    # Top-k is applied to the already top-p filtered distribution.
    # torch.multinomial accepts unnormalised weights, so the top-k slice does
    # not need to be renormalised; zero-probability entries are never drawn.
    next_token_probs = torch.softmax(next_token_logits, dim=0)
    top_probs, top_indices = torch.topk(next_token_probs, top_k)
    sample_idx = torch.multinomial(top_probs, num_samples=1)
    next_token = top_indices[sample_idx]  # shape (1,)
    next_token_id = next_token.item()

    # Append the token to the model input and rebuild the text from *all*
    # generated ids, so characters spanning several tokens decode correctly.
    input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)
    generated_ids.append(next_token_id)
    current_text = context + tokenizer.decode(generated_ids)

    # Only the plain newline token can end a line (all others are masked).
    if next_token_id == newline_token_id:
        line_count += 1
        # If we hit the target number of lines, stop immediately
        if line_count >= target_lines:
            break

# ensure the poem displays nicely
# The poem starts right after the last "Poem:" marker of the *prompt*; slicing
# by that position keeps the full poem even if the model itself writes "Poem:".
poem_start = context.rindex("Poem:") + len("Poem:")
final_poem = current_text[poem_start:].strip()
# Clean up any potential double-spacing or end-of-text artifacts
final_poem = final_poem.replace("<|endoftext|>", "").strip()
print(final_poem)
print("\n--- End of Poem ---")


--- Generating your 4-line poem ---

The silver moon is falling out of its shadow
And it's back to earth!
The red star shines brightly from time immemorial
Through years... This beautiful new spring bloom will be

--- End of Poem ---


In [None]:
# Debug view: print the full prompt-plus-poem string via its repr so that
# newlines and other escape sequences are visible.
print(f"\nFinal raw output: {current_text!r}")


Final raw output: "Instruction: Write a short, creative 4-line poem. \nThe poem must begin with the provided topic.\n\nTopic: A summer garden\nPoem:\nA summer garden blooms in light,\nDancing in the morning sun,\nColors blooming soft and bright,\nDay has only just begun.\n\nTopic: The silver moon\nPoem:\nThe silver moon is falling out of its shadow\nAnd it's back to earth!\nThe red star shines brightly from time immemorial\nThrough years... This beautiful new spring bloom will be\n"
