# Gaze-controlled text generation

This notebook demonstrates how to generate text with an ensemble of a language model and a gaze model.

In [None]:
# Log in to Hugging Face to access the models. This requires the environment variable
# "HUGGINGFACE_TOKEN" to be set to a valid access token; if it is unset, login() falls
# back to an interactive prompt.
# Linux/macOS: export HUGGINGFACE_TOKEN=your_huggingface_token_here
# Windows (PowerShell): $env:HUGGINGFACE_TOKEN="your_huggingface_token_here"

import os
from huggingface_hub import login

token = os.getenv("HUGGINGFACE_TOKEN")
login(token)


In [None]:
import json
from pprint import pprint

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from modeling.gaze_models import CausalTransformerGazeModel
from modeling.generation import GazeControlledBeamSearch

## Loading the models

The texts in our experiment were generated using the off-the-shelf instruction-tuned Llama-3.2 language model with 3B parameters. The gaze model is a GPT-2 model fine-tuned to predict first-pass gaze duration.

> **NOTE:** Expect the ensemble to be quite slow on CPU (up to a minute per token), so you should consider either using a GPU (e.g., on Google Colab) or choosing smaller model(s).

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

language_model_name = "meta-llama/Llama-3.2-3B-Instruct"
gaze_model_name = "openai-community/gpt2"

language_model_tokenizer = AutoTokenizer.from_pretrained(language_model_name)
language_model = AutoModelForCausalLM.from_pretrained(language_model_name).to(device)

gaze_model = CausalTransformerGazeModel.from_pretrained(gaze_model_name).to(device)
gaze_model.load_state_dict(torch.load("models/trf_gaze_model.pt", map_location=device))

## Loading the story prompts

The titles and prompts for the stories were generated using GPT-4 and manually curated.
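
The generation code further down reads the `title` and `prompt` fields of each record, so every line of `stories/prompts.jsonl` is expected to be a JSON object with at least these two keys. The following is a minimal sketch of a loader that checks this up front. The helper `parse_prompts` and the example records are hypothetical and only illustrate the expected shape; they are not part of the repository.

```python
import json

# Keys that the generation loop reads from every prompt record.
REQUIRED_KEYS = ("title", "prompt")


def parse_prompts(lines):
    """Parse JSONL lines into dicts and check that the required keys exist.

    Hypothetical helper for illustration; blank lines are skipped.
    """
    prompts = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue
        record = json.loads(line)
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            raise ValueError(f"line {line_number}: missing keys {missing}")
        prompts.append(record)
    return prompts


# Made-up example records, not taken from the actual prompt file.
example_lines = [
    '{"title": "The Lighthouse", "prompt": "A keeper finds a message in a bottle."}',
    "",
    '{"title": "Night Train", "prompt": "Two strangers share a compartment."}',
]
example_prompts = parse_prompts(example_lines)
print([p["title"] for p in example_prompts])
```

Failing early on a malformed record is cheaper than discovering a `KeyError` partway through a long generation run.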

In [None]:
with open("stories/prompts.jsonl") as f:
    prompts = [json.loads(line) for line in f]
pprint(prompts, sort_dicts=False)

## Generating the texts

Before starting the generation, we need to build an instruction prompt according to the template that is specific to Llama-3.2. We then generate text using [beam search](https://en.wikipedia.org/wiki/Beam_search) until one of two conditions applies:

- the language model has predicted an end-of-message token in the best beam, or
- the number of generated tokens has reached 800.

![Visualization of beam search with beam size 3](https://upload.wikimedia.org/wikipedia/commons/2/23/Beam_search.gif)

Refer to [`generation.py`](modeling/generation.py) for details.
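
To make the idea concrete, here is a toy sketch of a beam search that ranks hypotheses by a combination of a language model score and a gaze score, and that stops on the two conditions listed above. It is not the implementation in `generation.py`: the ranking rule `token_score + gaze_weight * gaze_score`, the three-token vocabulary, and the functions `toy_beam_search`, `toy_logprobs` and `toy_gaze` are all assumptions made for illustration.

```python
import math
from typing import Callable, Dict, List, Tuple

# A hypothesis: (tokens, cumulative LM log-probability, gaze score).
Beam = Tuple[List[str], float, float]


def toy_beam_search(
    next_token_logprobs: Callable[[List[str]], Dict[str, float]],
    gaze_score_fn: Callable[[List[str]], float],
    gaze_weight: float,
    beam_size: int,
    max_length: int,
    eos_token: str = "<eos>",
) -> Beam:
    """Toy beam search; ranking by token_score + gaze_weight * gaze_score
    is an assumed combination rule, not the one from generation.py."""
    beams: List[Beam] = [([], 0.0, 0.0)]
    for _ in range(max_length):  # stop condition 2: length limit
        candidates: List[Beam] = []
        for tokens, token_score, gaze_score in beams:
            if tokens and tokens[-1] == eos_token:
                # Finished hypotheses are carried over unchanged.
                candidates.append((tokens, token_score, gaze_score))
                continue
            for token, logprob in next_token_logprobs(tokens).items():
                new_tokens = tokens + [token]
                candidates.append(
                    (new_tokens, token_score + logprob, gaze_score_fn(new_tokens))
                )
        # Keep the beam_size best hypotheses under the combined score.
        candidates.sort(key=lambda b: b[1] + gaze_weight * b[2], reverse=True)
        beams = candidates[:beam_size]
        if beams[0][0][-1] == eos_token:  # stop condition 1: EOS in best beam
            break
    return beams[0]


def toy_logprobs(tokens: List[str]) -> Dict[str, float]:
    # Context-independent toy "language model" over three tokens.
    return {"a": math.log(0.5), "bb": math.log(0.3), "<eos>": math.log(0.2)}


def toy_gaze(tokens: List[str]) -> float:
    # Toy "gaze model": total character count stands in for reading time.
    return float(sum(len(t) for t in tokens if t != "<eos>"))


for weight in (0.0, 1.0, -1.0):
    tokens, token_score, gaze_score = toy_beam_search(
        toy_logprobs, toy_gaze, gaze_weight=weight, beam_size=2, max_length=3
    )
    print(weight, tokens)
```

In this toy setup, a positive weight pulls the search towards the longer token, a negative weight makes it stop immediately, and a weight of zero reduces it to ordinary beam search over the language model scores.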

In [None]:
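gaze_weight = 2  # weight given to the gaze model's score in the ensemble
beam_size = 8  # number of hypotheses kept at each generation step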

beam_search = GazeControlledBeamSearch(
    language_model,
    language_model_tokenizer,
    gaze_model,
)

outputs = []
for prompt in prompts:
    input_text = language_model_tokenizer.apply_chat_template(
        [
            {
                "role": "user",
                "content": (
                    "Write a short story based on the following title and prompt.\n"
                    f"Title: {prompt['title']}\n"
                    f"Prompt: {prompt['prompt']}\n\n"
                    "The story should not be longer than 500 words. "
                    "Keep in mind that the reader will not see the prompt, only the story itself. "
                    "Do not include the title."
                ),
            },
        ],
        add_generation_prompt=True,
        tokenize=False,
    )
    output_text, token_score, gaze_score = beam_search.generate(
        input_text,
        gaze_weight=gaze_weight,
        max_length=800,
        beam_size=beam_size,
        ignore_prompt=True,
        verbose=True,
    )
    outputs.append(
        {
            **prompt,
            "gaze_weight": gaze_weight,
            "input_text": input_text,
            "output_text": output_text,
            "token_score": token_score,
            "gaze_score": gaze_score,
        }
    )

## Saving the outputs

Each output record contains the original prompt fields, the gaze weight, the input text, and the final text from the best beam, together with its total token score from the language model and its gaze score from the gaze model.

In [None]:
with open("stories/outputs.jsonl", "w") as f:
    for output in outputs:
        f.write(json.dumps(output) + "\n")
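
As a quick sanity check, the saved file can be read back and summarized. The sketch below assumes that `token_score` and `gaze_score` are stored as plain numbers; the helpers `read_jsonl` and `summarize` and the demo records are hypothetical, and the demo writes to a temporary file rather than to `stories/outputs.jsonl`.

```python
import json
import os
import tempfile
from statistics import mean


def read_jsonl(path):
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def summarize(records):
    """Average the two scores over all records (assumes numeric scores)."""
    return {
        "n_stories": len(records),
        "mean_token_score": mean(r["token_score"] for r in records),
        "mean_gaze_score": mean(r["gaze_score"] for r in records),
    }


# Made-up records with the same score keys as the saved outputs.
demo_records = [
    {"title": "A", "output_text": "...", "token_score": -10.0, "gaze_score": 2.0},
    {"title": "B", "output_text": "...", "token_score": -20.0, "gaze_score": 4.0},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "outputs.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        for record in demo_records:
            f.write(json.dumps(record) + "\n")
    loaded = read_jsonl(demo_path)

summary = summarize(loaded)
print(summary)
```

For the real outputs, `summarize(read_jsonl("stories/outputs.jsonl"))` would give the same kind of overview.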