# Train AT2

In this notebook, we'll walk through training an AT2 score estimator.
We'll consider the problem of context attribution (attributing a model's generation to in-context information).

As a training dataset, we'll use a subset of [`databricks-dolly-15k`](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an instruction-following dataset that includes summarization, context-based question answering, and information extraction tasks.

In [1]:
from pathlib import Path
from datasets import load_dataset

from at2.utils import get_model_and_tokenizer
from at2.tasks import SimpleContextAttributionTask
from at2 import AT2Trainer, AT2Attributor

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /mnt/xfs/home/bencw/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


We'll start by loading 1,000 random examples from the dataset. We keep only the context-based categories (summarization, closed QA, and information extraction) and drop examples whose context is 20,000 characters or longer, which keeps training manageable.

In [2]:
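def filter_fn(example):
    # Keep only task categories that come with a reference context
    valid_category = example["category"] in ["summarization", "closed_qa", "information_extraction"]
    # Skip very long contexts (length measured in characters)
    valid_length = len(example["context"]) < 20_000
    return valid_category and valid_length

raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
# Filter, shuffle with a fixed seed, then keep the first 1,000 examples
dataset = raw_dataset.filter(filter_fn).shuffle(seed=42).select(range(1_000))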

In [3]:
dataset[0]

{'instruction': 'In this reference text summarizing plot of the book The High King, how did the sword Dyrnwyn lose its power?',
 'context': 'The story begins only days after the conclusion of Taran Wanderer. With winter approaching, Taran and his companion Gurgi return from their wanderings to Caer Dallben after getting news from Kaw the crow that Princess Eilonwy has returned from the Isle of Mona. Indeed, they find her at home, along with her escort King Rhun of Mona and the former giant Glew, who had been magically restored to human size by a potion from Dallben.\n\nBefore Taran can propose to Eilonwy, the bard-king Fflewddur Fflam and his mount Llyan arrive with a gravely injured Gwydion, Prince of Don. Servants of Arawn had assaulted them and seized the magical black sword Dyrnwyn. Fflewddur also states that Taran was involved in the ambush, baffling everyone. With Achren\'s help, the truth is determined: Arawn himself has come from Annuvin to the verge of Caer Dallben in the guis

Next, we'll load the model.
We'll be working with [`microsoft/Phi-4-mini-instruct`](https://huggingface.co/microsoft/Phi-4-mini-instruct).

In [4]:
model_name = "microsoft/Phi-4-mini-instruct"
model, tokenizer = get_model_and_tokenizer(model_name)


To learn to attribute, we'll need to define an "attribution task."
An attribution task consists of an input sequence, a generated sequence, a model/tokenizer and a set of sources to which we would like to attribute the generated sequence.
In the case of context attribution, the input sequence is a context and query, the generated sequence is the model's response and the sources are pieces of the context, e.g., sentences.
We've defined a class, `SimpleContextAttributionTask`, to quickly create such a task from an example in the dataset.
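To make the `source_type` argument more concrete, here is a minimal sketch of how a context might be split into sources at two granularities. `make_sources` is a hypothetical helper, not part of `at2`: it uses whitespace splitting as a stand-in for the model's tokenizer and a naive regex as a stand-in for a proper sentence splitter, so its sources will not match the ones `SimpleContextAttributionTask` produces.

```python
import re

# Naive sentence pattern (an assumption for illustration): a non-space character,
# then everything up to the next ".", "!" or "?". Trailing text without end
# punctuation is dropped, and abbreviations such as "Dr." end a sentence early.
SENTENCE_PATTERN = re.compile(r"\S[^.!?]*[.!?]")


def make_sources(context: str, source_type: str = "token") -> list[str]:
    """Hypothetical helper: split a context into attribution sources."""
    if source_type == "sentence":
        return SENTENCE_PATTERN.findall(context)
    if source_type == "token":
        # Whitespace-separated words stand in for the model tokenizer's tokens.
        return context.split()
    raise ValueError(f"unknown source_type: {source_type!r}")


context = "Taran returns home. Eilonwy is back from Mona."
print(make_sources(context, source_type="sentence"))
print(make_sources(context))
```

Note that `task_from_example` defaults to `source_type="token"`, while the demonstration cells in this notebook pass `source_type="sentence"`.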

In [5]:
def task_from_example(example, model, tokenizer, source_type="token"):
    return SimpleContextAttributionTask(
        context=example["context"],
        query=example["instruction"],
        model=model,
        tokenizer=tokenizer,
        source_type=source_type,
    )

Let's create a task from the first example to see what one looks like.

In [6]:
example = dataset[0]
task = task_from_example(example, model, tokenizer, source_type="sentence")
print("### Context ###")
print(example["context"][:500] + "..." if len(example["context"]) > 500 else example["context"])
print()
print("### Instruction ###")
print(example["instruction"])
print()
print("### Generated response ###")
# Generates a response and caches relevant information for attribution
print(task.generation)

### Context ###
The story begins only days after the conclusion of Taran Wanderer. With winter approaching, Taran and his companion Gurgi return from their wanderings to Caer Dallben after getting news from Kaw the crow that Princess Eilonwy has returned from the Isle of Mona. Indeed, they find her at home, along with her escort King Rhun of Mona and the former giant Glew, who had been magically restored to human size by a potion from Dallben.

Before Taran can propose to Eilonwy, the bard-king Fflewddur Fflam ...

### Instruction ###
In this reference text summarizing plot of the book The High King, how did the sword Dyrnwyn lose its power?

### Generated response ###
In the reference text summarizing the plot of "The High King," the sword Dyrnwyn loses its power after the defeat of Arawn. The text states that with the death of Arawn, the stronghold of Annuvin bursts into flame and falls in ruins, destroying all of the magical implements inside, including Dyrnwyn. As a result, Dyrnwyn

In [7]:
print("Total sources:", task.num_sources)
# Print the first few sources (sentences from the context)
for i in range(3):
    print(f"Source #{i}:")
    print(task.sources[i].strip())
    print()

Total sources: 58
Source #0:
The story begins only days after the conclusion of Taran Wanderer.

Source #1:
With winter approaching, Taran and his companion Gurgi return from their wanderings to Caer Dallben after getting news from Kaw the crow that Princess Eilonwy has returned from the Isle of Mona.

Source #2:
Indeed, they find her at home, along with her escort King Rhun of Mona and the former giant Glew, who had been magically restored to human size by a potion from Dallben.



We're now ready to train AT2 on our dataset!
To do so, we'll first create an `AT2Trainer`.
From there, training involves three steps:
1. Generating a response for each task (from the context and query).
1. Computing features (attention weights) and outputs (logit probabilities for a few ablations of the sources).
1. Actually training a score estimator to predict the effects of ablations using the features.
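Step 2 can be sketched in plain Python. The helpers below are hypothetical and rest on assumptions we haven't checked against `AT2Trainer`: each ablation keeps every source independently with probability 0.5, an ablated context is the kept sources joined with spaces, and a "logit probability" is the log-odds log(p / (1 - p)) of the probability p that the model assigns to its response under the ablated context. The model call that would produce p is left out.

```python
import math
import random


def sample_ablation_masks(num_sources, num_ablations, keep_prob=0.5, seed=0):
    """Draw random ablation masks; True means the source is kept (assumed scheme)."""
    rng = random.Random(seed)
    return [
        [rng.random() < keep_prob for _ in range(num_sources)]
        for _ in range(num_ablations)
    ]


def ablate_context(sources, mask):
    """Rebuild a context from only the sources whose mask entry is True."""
    return " ".join(source for source, keep in zip(sources, mask) if keep)


def logit_prob(log_prob):
    """Convert a response log-probability log(p) into log-odds log(p / (1 - p)).

    Requires p < 1, i.e. log_prob < 0.
    """
    return log_prob - math.log1p(-math.exp(log_prob))


sources = ["Taran returns home.", "Eilonwy is back.", "Gwydion is hurt."]
masks = sample_ablation_masks(len(sources), num_ablations=4)
for mask in masks:
    print(mask, "->", ablate_context(sources, mask))
print(logit_prob(math.log(0.5)))  # about 0, since p = 0.5 has log-odds 0
```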

In [8]:
save_path = Path("outputs") / "test_context_phi_4_mini_instruct"

trainer = AT2Trainer(
    save_path=save_path,
    dataset=dataset,
    model=model,
    tokenizer=tokenizer,
    task_from_example=task_from_example,
)

In [9]:
# To parallelize across multiple jobs, set `num_jobs` and `job_index`
trainer.generate()


In [10]:
# To parallelize across multiple jobs, set `num_jobs` and `job_index`
trainer.compute_features_and_outputs()

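The cells above mention `num_jobs` and `job_index` for splitting work across jobs. We haven't checked how `AT2Trainer` assigns examples to jobs; the hypothetical `shard_indices` helper below illustrates one common scheme, contiguous shards, in which each job handles a disjoint slice of the dataset and together the jobs cover every example exactly once.

```python
def shard_indices(num_examples, num_jobs, job_index):
    """Hypothetical: the contiguous slice of example indices handled by one job."""
    if not 0 <= job_index < num_jobs:
        raise ValueError("job_index must satisfy 0 <= job_index < num_jobs")
    per_job = -(-num_examples // num_jobs)  # ceiling division
    start = min(job_index * per_job, num_examples)
    return range(start, min(start + per_job, num_examples))


# Example: 1,000 examples split across 3 jobs.
for job_index in range(3):
    shard = shard_indices(1_000, num_jobs=3, job_index=job_index)
    print(job_index, shard.start, shard.stop)
# 0 0 334
# 1 334 668
# 2 668 1000
```

Under this scheme, every job runs with the same `num_jobs` and a distinct `job_index`.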

In [11]:
# Examples whose response contains no valid sentence are excluded
trainer.train(save_name="default")

Training on 994 examples of 1000



Step 0: loss=-0.2556
Step 99: loss=-0.5373
Step 199: loss=-0.5693
Step 299: loss=-0.5703
Step 399: loss=-0.573
Step 499: loss=-0.5747
Step 599: loss=-0.5749
Step 699: loss=-0.5781
Step 799: loss=-0.5779
Step 899: loss=-0.5781
Step 999: loss=-0.5779
Saved estimator to outputs/test_context_phi_4_mini_instruct/estimators/default


LinearScoreEstimator(
  (linear): Linear(in_features=768, out_features=1, bias=False)
)
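The printed estimator is a single linear layer with 768 inputs and no bias. One reading, which we assume here but haven't verified, is that each source is described by one attention-based feature per attention head: if, as we believe, Phi-4-mini-instruct has 32 layers with 24 attention heads each, then 32 × 24 = 768. The sketch below shows how such a linear estimator could turn per-head features into source scores and predict the outcome of an ablation as the sum of the kept sources' scores. The negative losses in the training log would be consistent with minimizing the negative correlation between predicted and actual ablation outputs, so the sketch includes a Pearson correlation, but that objective is also an assumption. All numbers below are made-up toy values.

```python
import math


def source_scores(weights, features):
    """Score each source as a weighted sum of its per-head features (bias-free)."""
    return [sum(w * f for w, f in zip(weights, feats)) for feats in features]


def predict_ablation(scores, mask):
    """Predicted output when only the sources with mask True are kept."""
    return sum(score for score, keep in zip(scores, mask) if keep)


def pearson(xs, ys):
    """Pearson correlation between two equal-length lists of numbers."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)


# Toy setup: 3 sources described by 2 "heads" (instead of 768).
weights = [0.8, 0.2]
features = [[0.5, 0.1], [0.1, 0.0], [0.0, 0.4]]
scores = source_scores(weights, features)

masks = [[True, True, True], [True, False, False], [False, True, True], [False, False, True]]
actual = [1.0, 0.7, 0.2, 0.1]  # made-up ablation outputs
predicted = [predict_ablation(scores, mask) for mask in masks]
loss = -pearson(predicted, actual)  # assumed objective: lower is better
print(scores, loss)
```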

To wrap up, we'll use the score estimator we just trained to attribute the model's summary of an article from [CNN DailyMail](https://huggingface.co/datasets/abisee/cnn_dailymail), a dataset the estimator was not trained on.

In [12]:
# Use a separate name so the training `dataset` above isn't overwritten
cnn_dataset = load_dataset("abisee/cnn_dailymail", "3.0.0", split="validation")
example = cnn_dataset[0]

task = SimpleContextAttributionTask(
    context=example["article"],
    query="Summarize the article in up to three sentences.",
    model=model,
    tokenizer=tokenizer,
    source_type="sentence",
)

In [13]:
task.show_target_with_indices()

[(0, 139)]Zully Broussard's selfless act of donating a kidney to a stranger led to a chain reaction, resulting in six patients receiving transplants. [(140, 312)]The process, which took only three weeks, was made possible by a computer program called MatchGrid, created by David Jacobs, which quickly matches up donor pairs or chains. [(313, 491)]The chain of surgeries, which involved five surgeons, a team of physician assistants, nurses, anesthesiologists, and more than 40 support staff, is set to be completed by Friday.
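The bracketed pairs are character offsets into the response with an exclusive end, so `(140, 312)` selects the second sentence: the first sentence is exactly 139 characters long, and index 139 is the space before "The process". The hypothetical `sentence_spans` helper below reproduces this kind of indexing with a naive regex sentence splitter (an assumption; the library's own sentence boundaries may differ) and shows that slicing the response with a pair recovers the sentence.

```python
import re


def sentence_spans(text):
    """Hypothetical: (start, end) character offsets of each sentence, end exclusive.

    Uses a naive regex, so abbreviations such as "Dr." would wrongly end a sentence.
    """
    return [match.span() for match in re.finditer(r"\S[^.!?]*[.!?]", text)]


response = "MatchGrid pairs donors. The chain took three weeks."
spans = sentence_spans(response)
print(spans)  # [(0, 23), (24, 51)]
start, end = spans[1]
print(response[start:end])  # The chain took three weeks.
```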


In [14]:
attributor = AT2Attributor.from_path(
    task, trainer.save_path / "estimators" / "default" / "score_estimator.pt"
)
start, end = (140, 312)
attributor.show_attribution(start=start, end=end, verbose=True)

Computing attribution scores for:
 The process, which took only three weeks, was made possible by a computer program called MatchGrid, created by David Jacobs, which quickly matches up donor pairs or chains.


| # | Score | Source |
|---|---|---|
| 0 | 0.004 | Jacobs paid it forward with his programming skills, creating MatchGrid, a program that genetically matches up donor pairs or chains quickly. |
| 1 | 0.002 | That changed when a computer programmer named David Jacobs received a kidney transplant. |
| 2 | 0.001 | We did this in about three weeks," Jacobs said. |
| 3 | 0.001 | But the power that multiplied Broussard's gift was data processing of genetic profiles from donor-recipient pairs. |
| 4 | 0.0 | It's been done before, California Pacific Medical Center said in a statement, but matching up the people in the chain has been laborious and taken a long time. |
| 5 | 0.0 | It works on a simple swapping principle but takes it to a much higher level, according to California Pacific Medical Center in San Francisco. |
| 6 | 0.0 | "When we did a five-way swap a few years ago, which was one of the largest, it took about three to four months. |
| 7 | 0.0 | So high, that it is taking five surgeons, a covey of physician assistants, nurses and anesthesiologists, and more than 40 support staff to perform surgeries on 12 people. |
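The table lists sources in decreasing order of score, with scores rounded for display. A minimal sketch of producing such a ranking from a list of sources and their scores follows; `rank_sources` is a hypothetical helper, not the implementation behind `show_attribution`, and the sources and scores in the example are made-up placeholders.

```python
def rank_sources(sources, scores, top_k=8, ndigits=3):
    """Hypothetical: return (rank, rounded score, source) for the top_k sources."""
    # Sort by score, highest first; ties keep their original order.
    ranked = sorted(zip(scores, sources), key=lambda pair: pair[0], reverse=True)
    return [
        (rank, round(score, ndigits), source)
        for rank, (score, source) in enumerate(ranked[:top_k])
    ]


sources = ["source A", "source B", "source C"]
scores = [0.0002, 0.0041, 0.0019]  # made-up scores
for row in rank_sources(sources, scores, top_k=2):
    print(row)
```

In the table above, the rounded scores drop to 0.0 after the fourth row, so a small `top_k` is often enough to read off the most influential sources.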
