<a href="https://colab.research.google.com/github/PinknMatter/CART498-GenAI/blob/main/A2/P%2B7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def p_plus_x(lines, model_name='gpt2', p=7, *, tokenizer=None, model=None):
    """
    Replace the last word of each line with GPT-2's p-th ranked next token.

    For every line, the final whitespace-separated word is dropped, the
    remaining text is fed to GPT-2, and the token ranked ``p`` in the model's
    next-token distribution (1 = most likely) is decoded and appended in its
    place.

    Note that the ranking is over GPT-2 *tokens*, not whole words, so the
    chosen piece can be punctuation (e.g. ",") or a word fragment.

    Args:
        lines (list): List of strings, each treated as one line.
        model_name (str): The GPT-2 checkpoint to load when ``tokenizer`` /
            ``model`` are not supplied.
        p (int): Which probability rank to choose (1 = most likely, etc.).
            If ``p`` exceeds the vocabulary size, the least likely token is
            used.
        tokenizer: Optional pre-loaded GPT-2 tokenizer, so repeated calls do
            not reload it.
        model: Optional pre-loaded GPT-2 LM-head model, so repeated calls do
            not reload it. It is put into eval mode by this function.

    Returns:
        list: List of new lines with the replaced word. Lines that are empty
        or whitespace-only are returned unchanged.

    Raises:
        ValueError: If ``p`` is less than 1.
    """
    if p < 1:
        # Without this check p=0 or negative p would silently index from the
        # end of the ranking (i.e. pick among the *least* likely tokens).
        raise ValueError(f"p must be >= 1, got {p}")

    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    if model is None:
        model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()  # disable dropout so predictions are deterministic

    # A one-word line leaves nothing after its last word is removed; an empty
    # prompt produces no logits to rank, so condition on the tokenizer's BOS
    # token instead (falling back to EOS if no BOS is configured).
    start_id = tokenizer.bos_token_id
    if start_id is None:
        start_id = tokenizer.eos_token_id

    new_lines = []
    for line in lines:
        words = line.strip().split()
        if not words:
            new_lines.append(line)
            continue

        partial_line = " ".join(words[:-1])
        if partial_line:
            input_ids = tokenizer.encode(partial_line, return_tensors='pt')
        else:
            input_ids = torch.tensor([[start_id]])

        with torch.no_grad():
            next_token_logits = model(input_ids).logits[0, -1, :]

        # Softmax is monotonic, so ranking raw logits gives the same order as
        # ranking probabilities; topk avoids sorting the whole vocabulary.
        k = min(p, next_token_logits.shape[-1])
        chosen_token_id = torch.topk(next_token_logits, k).indices[-1].item()

        chosen_token = tokenizer.decode([chosen_token_id]).strip()
        # Join only non-empty pieces so an empty prefix (one-word line) or a
        # whitespace-only token does not leave a stray space.
        new_line = " ".join(part for part in (partial_line, chosen_token) if part)
        new_lines.append(new_line)

    return new_lines

In [20]:
lines = [
    "One must have a mind of winter",
    "To regard the frost and the boughs",
    "Of the pine-trees crusted with snow;",
    "And have been cold a long time",
    "To behold the junipers shagged with ice,",
    "The spruces rough in the distant glitter",
    "Of the January sun; and not to think",
    "Of any misery in the sound of the wind,",
    "In the sound of a few leaves,",
    "Which is the sound of the land",
    "Full of the same wind",
    "That is blowing in the same bare place",
    "For the listener, who listens in the snow,",
    "And, nothing himself, beholds",
    "Nothing that is not there and the nothing that is."
]

# Probability rank passed to p_plus_x (1 = most likely next token).
P_RANK = 20

print("\n--- Final Lines ---")
# NOTE(review): despite its name, this holds the p=20 results; the name is
# kept unchanged in case other cells reference it.
results_p1 = p_plus_x(lines, p=P_RANK)
for r in results_p1:
    print(r)


--- Final Lines ---
One must have a mind of steel
To regard the frost and the darkness
Of the pine-trees crusted with his
And have been cold a long hard
To behold the junipers shagged with two
The spruces rough in the distant sun
Of the January sun; and not to give
Of any misery in the sound of the car
In the sound of a few ,
Which is the sound of the phone
Full of the same weight
That is blowing in the same bare skin
For the listener, who listens in the usual
And, nothing himself, only
Nothing that is not there and the nothing that remains
