# Gaze modeling

This notebook demonstrates how to train a gaze model to predict token-level reading measures from text.

Currently, there are two gaze model implementations:
- a transformer-based model fine-tuned from a causal language model, and
- a linear regression model using word length and frequency as features.

Let's train one of each and evaluate them.

In [None]:
import matplotlib.pyplot as plt
import torch

from modeling import reading_measures
from modeling.dataset import GazeTextDataset
from modeling.gaze_models import CausalTransformerGazeModel, LinearRegressionGazeModel

## Loading and preprocessing the dataset

We are using the EMTeC dataset (eye tracking while reading generated texts) and preprocessing it in the following way:

1. Calculating the first-pass gaze duration for each word and participant.
2. Clamping the gaze durations within $\pm 3$ standard deviations of the mean.
3. Averaging the gaze durations across participants for each word.

Refer to [`dataset.py`](modeling/dataset.py) for details.
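As a rough illustration of steps 2 and 3, here is a minimal NumPy sketch. It is not the code from `dataset.py`: the function name `clamp_outliers`, the example durations, and the choice to compute the mean and standard deviation over all values at once are assumptions made for this example.

```python
import numpy as np


def clamp_outliers(durations, zscore=3.0):
    """Clip values to mean +/- zscore * std (values are clipped, not dropped)."""
    durations = np.asarray(durations, dtype=float)
    mean, std = durations.mean(), durations.std()
    return np.clip(durations, mean - zscore * std, mean + zscore * std)


# Hypothetical first-pass gaze durations: rows = participants, columns = words.
durations = np.array([
    [210.0, 180.0, 950.0],
    [190.0, 220.0, 240.0],
    [205.0, 175.0, 230.0],
    [200.0, 185.0, 250.0],
])

clamped = clamp_outliers(durations, zscore=3.0)  # step 2: clamp outliers
word_means = clamped.mean(axis=0)                # step 3: average over participants
print(word_means)
```

Clamping keeps every observation but limits the influence of extreme values (here, the 950 in the first row) on the per-word average.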

In [None]:
reading_measure = reading_measures.first_pass_gaze_duration
outlier_zscore = 3

emtec = GazeTextDataset.load(
    "data/emtec",
    reading_measure,
    outlier_zscore,
)

Next, let's split the dataset into a training set (80%) and a development set (20%):
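For reference, an 80/20 split can be made reproducible by fixing the random seed. The sketch below is a standalone illustration, not the implementation of `random_split()`: the helper `split_by_text`, its `seed` parameter, and the assumption that the split happens at the level of whole texts are all hypothetical.

```python
import random


def split_by_text(text_ids, train_fraction=0.8, seed=0):
    """Shuffle text IDs with a fixed seed and cut them into train/dev parts."""
    ids = list(text_ids)
    random.Random(seed).shuffle(ids)  # local RNG, leaves global state untouched
    cut = round(len(ids) * train_fraction)
    return ids[:cut], ids[cut:]


train_ids, dev_ids = split_by_text(range(10), train_fraction=0.8)
print(len(train_ids), len(dev_ids))  # 8 2
```

Splitting by text rather than by word keeps all words of a text in the same set, so the development set contains only unseen texts.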

In [None]:
train_dataset, dev_dataset = emtec.random_split(0.8)

print(f"Train: {len(train_dataset.texts)} texts, {len(train_dataset.gaze_data)} AOIs")
print(f"Dev: {len(dev_dataset.texts)} texts, {len(dev_dataset.gaze_data)} AOIs")

Finally, we're going to normalize the gaze durations to have mean 0 and standard deviation 1. This will make training easier, and it will also help us interpret the gaze scores in the text generation process later.
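The normalization itself is a standard z-score transform. A minimal sketch with made-up label values (the array contents and the `denormalize` helper are assumptions for illustration, not part of `GazeTextDataset`):

```python
import numpy as np

# Hypothetical gaze labels for the two splits.
train_labels = np.array([180.0, 220.0, 260.0, 300.0])
dev_labels = np.array([200.0, 280.0])

# Statistics come from the training set only.
mean, std = train_labels.mean(), train_labels.std()

train_norm = (train_labels - mean) / std
dev_norm = (dev_labels - mean) / std  # reuse the training statistics


def denormalize(z):
    """Map a normalized score back to the original scale."""
    return z * std + mean


print(train_norm.mean(), train_norm.std())
```

After this transform, a score of 1.0 means "one standard deviation above the training mean", which is what makes the scores easy to interpret later on.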

In [None]:
train_dataset.normalize_gaze_labels()
# Normalize dev_dataset based on the mean and std of train_dataset
dev_dataset.normalize_gaze_labels(
    train_dataset.gaze_label_mean, train_dataset.gaze_label_std
)

## Training a linear regression model

This model predicts first-pass gaze duration for each word based on the length and frequency of the word itself, as well as the length and frequency of the previous 2 words to account for spillover effects.
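To make the feature layout concrete, here is a sketch of how such a feature matrix could be built. The frequency table `LOG_FREQ`, the function `word_features`, the use of log-frequencies, and the zero padding for words without predecessors are assumptions for this example; `LinearRegressionGazeModel` may construct its features differently.

```python
import numpy as np

# Hypothetical log-frequencies; a real model would read these from a frequency list.
LOG_FREQ = {"the": 7.7, "quick": 4.6, "brown": 4.8, "fox": 4.2}


def word_features(words, max_spillover=2):
    """Return one row per word: (length, frequency) of the word itself,
    followed by (length, frequency) of each of the previous max_spillover words."""
    base = [(len(w), LOG_FREQ.get(w.lower(), 0.0)) for w in words]
    rows = []
    for i in range(len(words)):
        row = []
        for offset in range(max_spillover + 1):
            j = i - offset
            # Pad with zeros when there is no previous word at this offset.
            row.extend(base[j] if j >= 0 else (0.0, 0.0))
        rows.append(row)
    return np.array(rows, dtype=float)


X = word_features(["The", "quick", "brown", "fox"])
print(X.shape)  # (4, 6)
```

With `max_spillover=2`, each word gets 2 × (1 + 2) = 6 features, so the regression can attribute part of a word's reading time to the words that precede it.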

In [None]:
lr_gaze_model = LinearRegressionGazeModel(lang="en", max_spillover=2)
lr_gaze_model.fit(
    train_dataset,
    dev_dataset,
)

The `predict()` method returns the sum of the predicted gaze durations over all words in a text. Use `predict_aois()` to recover the predicted gaze duration of each word separately (this is necessary because the dataset and the gaze models tokenize text differently).
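`predict_aois()` takes the text together with the end character offset of each area of interest (AOI). If you want to try your own sentences, the following sketch derives such offsets under the assumption that every whitespace-delimited word is one AOI. The helper `whitespace_aoi_ends` is hypothetical and not part of the `modeling` package, and the AOIs in EMTeC are not necessarily defined this way.

```python
import re


def whitespace_aoi_ends(text):
    """Return the end character offset of each whitespace-delimited word."""
    return [match.end() for match in re.finditer(r"\S+", text)]


text = "The quick brown fox jumps over the lazy dog."
print(whitespace_aoi_ends(text))
# [3, 9, 15, 19, 25, 30, 34, 39, 44]
```

Each offset is exclusive, so `text[start:end]` with `start` set to the previous offset yields the word, including its leading space.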

In [None]:
def predict_example(gaze_model):
    text = "The quick brown fox jumps over the lazy dog."
    aoi_ends = [3, 9, 15, 19, 25, 30, 34, 39, 44]
    total_pred = gaze_model.predict([text])[0]
    aoi_preds = gaze_model.predict_aois(text, aoi_ends)
    aoi_start = 0
    for aoi_end, aoi_pred in zip(aoi_ends, aoi_preds):
        word = text[aoi_start:aoi_end].strip()
        print(f"{word}\t{aoi_pred:.3f}")
        aoi_start = aoi_end
    print(f"\nTotal\t{total_pred:.3f}")

predict_example(lr_gaze_model)

## Training a transformer model

In addition to the linear regression model, let's fine-tune GPT-2 (124M parameters) on the dataset.

Since GPT-2 uses subword tokenization, the dataset is reformatted behind the scenes so that each word's gaze duration is divided evenly among its subwords. For example, if a word consists of three subwords and has a first-pass gaze duration of 600, the model is trained to predict 200 for each of those subwords. Refer to the `TransformerGazeTextDataset` in [`gaze_models.py`](modeling/gaze_models.py) for details.
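The arithmetic of this label splitting can be sketched as follows. This is an illustration only: the function `split_among_subwords` and its inputs are assumptions, and the actual logic lives in `TransformerGazeTextDataset`.

```python
def split_among_subwords(word_durations, subword_counts):
    """Divide each word's gaze duration evenly among its subword tokens."""
    labels = []
    for duration, n_subwords in zip(word_durations, subword_counts):
        labels.extend([duration / n_subwords] * n_subwords)
    return labels


# One word with three subwords (duration 600) and one single-token word (150).
print(split_among_subwords([600.0, 150.0], [3, 1]))
# [200.0, 200.0, 200.0, 150.0]
```

Because the subword labels of a word sum to that word's duration, per-word values can be recovered by summing over the subwords of each word.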

> **NOTE:** If you can't or don't want to run the training, the parameters of a trained gaze model have been included in this repository under [`models`](models). You can skip the cell calling `trf_gaze_model.fit()` and instead load it with the subsequent cell.


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained model
trf_gaze_model = CausalTransformerGazeModel.from_pretrained(
    "openai-community/gpt2"
).to(device)

In [None]:
# Train the model
trf_gaze_model.fit(
    train_dataset,
    dev_dataset,
    batch_size=10,
    patience=3,
    learning_rate=0.0001,
)
# Save the trained model
torch.save(trf_gaze_model.state_dict(), "models/trf_gaze_model.pt")

In [None]:
# Load the trained model
trf_gaze_model.load_state_dict(torch.load("models/trf_gaze_model.pt", map_location=device))

In [None]:
predict_example(trf_gaze_model)

## Evaluating the models

First, let's predict the gaze durations on the development set and plot the predictions against the observations for both models:

In [None]:
lr_results = lr_gaze_model.evaluate(dev_dataset)
plt.scatter(lr_results["preds"], lr_results["labels"])
plt.title("Linear regression")
plt.xlabel("Predicted")
plt.ylabel("Observed")

In [None]:
trf_results = trf_gaze_model.evaluate(dev_dataset)
plt.scatter(trf_results["preds"], trf_results["labels"])
plt.title("Transformer")
plt.xlabel("Predicted")
plt.ylabel("Observed")

We can also look at some evaluation metrics:

- **MAE:** mean absolute error (lower is better)
- **R²:** proportion of variance explained (higher is better)
- **Pearson:** linear correlation coefficient (higher is better)
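For readers who want the definitions spelled out, here is a minimal NumPy sketch of the three metrics using their standard formulas. The function `gaze_metrics` and the example values are assumptions; the `evaluate()` method may compute these quantities differently.

```python
import numpy as np


def gaze_metrics(preds, labels):
    """Compute MAE, R² and Pearson correlation with their standard definitions."""
    preds = np.asarray(preds, dtype=float)
    labels = np.asarray(labels, dtype=float)
    mae = np.mean(np.abs(preds - labels))
    ss_res = np.sum((labels - preds) ** 2)          # residual sum of squares
    ss_tot = np.sum((labels - labels.mean()) ** 2)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    pearson = np.corrcoef(preds, labels)[0, 1]
    return {"mae": mae, "r2": r2, "pearson": pearson}


print(gaze_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
```

Note that Pearson correlation ignores differences in scale and offset between predictions and observations, whereas R² penalizes them. A model can therefore have a high Pearson score and still a low (or even negative) R².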

> **NOTE:** We are not using a separate held-out test set here, since our end goal is not evaluating the gaze models. We are only using these metrics to select the best model for our downstream task (gaze-controlled text generation). If you want to conduct a proper performance evaluation, a separate test set or cross-validation would be more appropriate.

In [None]:
print("Linear regression:")
print(f"MAE: {lr_results['mae']}")
print(f"R2: {lr_results['r2']}")
print(f"Pearson: {lr_results['pearson']}")
print()
print("Transformer:")
print(f"MAE: {trf_results['mae']}")
print(f"R2: {trf_results['r2']}")
print(f"Pearson: {trf_results['pearson']}")