<a href="https://colab.research.google.com/github/GoldPapaya/info256-applied-nlp/blob/main/11.nlp/HW11_LLM_Coref.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dbamman/anlp25/blob/main/11.nlp/HW11_LLM_Coref.ipynb)

# HW11: Coreference with LLMs

In this homework, you will experiment with using LLMs for zero-shot or few-shot coreference resolution.

In [1]:
import torch

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
!pip install stanza==1.10.1 peft

Collecting stanza==1.10.1
  Downloading stanza-1.10.1-py3-none-any.whl.metadata (13 kB)
Collecting emoji (from stanza==1.10.1)
  Downloading emoji-2.15.0-py3-none-any.whl.metadata (5.7 kB)
Downloading stanza-1.10.1-py3-none-any.whl (1.1 MB)
Downloading emoji-2.15.0-py3-none-any.whl (608 kB)
Installing collected packages: emoji, stanza
Successfully installed emoji-2.15.0 stanza-1.10.1


In [3]:
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/data/1342_pride_and_prejudice_brat.conll -O 1342_pride_and_prejudice_brat.conll
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/data/1342_pride_and_prejudice_sample.txt -O 1342_pride_and_prejudice_sample.txt
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/11.nlp/coref_utils.py

--2025-11-14 15:39:21--  https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/data/1342_pride_and_prejudice_brat.conll
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 114884 (112K) [text/plain]
Saving to: ‘1342_pride_and_prejudice_brat.conll’


2025-11-14 15:39:22 (13.2 MB/s) - ‘1342_pride_and_prejudice_brat.conll’ saved [114884/114884]

--2025-11-14 15:39:22--  https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/data/1342_pride_and_prejudice_sample.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8815 

## The coreference resolution task

We formulate the coreference resolution task as follows: given an input text, output a sequence of coref chains $(C_1, \ldots, C_n)$, each of which contains a sequence of coref mentions $C_i = (m_{i1}, \ldots, m_{ik})$ ordered by start index. Each coref mention is a tuple of the start and end indices $m_{ij} = (\text{start\_index}, \text{end\_index})$ denoting the span of the mention in the text.

To formalize this in the code, we set up the following classes:

```python
from dataclasses import dataclass, field


@dataclass
class CorefMention:
    start_idx: int
    end_idx: int


@dataclass
class CorefChain:
    mentions: list[CorefMention] = field(default_factory=list)


CorefOutput = list[CorefChain]
```

To avoid cluttering up the notebook, we provide a utility file with these types and other useful functions.
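
As a quick illustration of this representation, the sketch below redefines the two dataclasses locally (so it runs without `coref_utils`) and builds one chain over a short sentence. Treating `end_idx` as exclusive is inferred from the slicing `text[0][mention.start_idx:mention.end_idx]` used further down in this notebook; the sample sentence and the helper name `mention_text` are assumptions made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class CorefMention:
    start_idx: int  # character offset where the mention starts
    end_idx: int    # character offset just past the mention (assumed exclusive)


@dataclass
class CorefChain:
    mentions: list[CorefMention] = field(default_factory=list)


def mention_text(text: str, mention: CorefMention) -> str:
    """Hypothetical helper: recover the surface string of a mention."""
    return text[mention.start_idx:mention.end_idx]


sample = "Mr. Bennet replied that he had not."
chain = CorefChain(mentions=[
    CorefMention(start_idx=0, end_idx=10),   # "Mr. Bennet"
    CorefMention(start_idx=24, end_idx=26),  # "he"
])
surface_forms = [mention_text(sample, m) for m in chain.mentions]
print(surface_forms)
```

A `CorefOutput` for a single chunk is then just a list of such chains.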

In [4]:
from coref_utils import CorefMention, CorefChain, CorefOutput

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


## Evaluation data

We will use data from [LitBank](https://github.com/dbamman/litbank), which contains coreference annotations for novels in the public domain. In particular, we will evaluate the coreference output on approximately 2,000 words of _Pride and Prejudice_ by Jane Austen.

Because most systems can't handle coreference on such long texts (and it would also strain the memory usage of the LLM), we do inference on chunks of sentences instead. The `load_conll_data` function takes care of this for us.
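
The chunking performed by `load_conll_data` is not shown in this notebook. As a rough sketch of the idea only, the code below groups sentences into chunks under a word budget. The regex sentence splitter, the `max_words` parameter, and the function name `chunk_sentences` are assumptions for illustration and are not taken from `coref_utils`.

```python
import re


def chunk_sentences(text: str, max_words: int = 100) -> list[str]:
    """Group consecutive sentences into chunks of at most max_words words.

    A single sentence longer than max_words still becomes its own chunk.
    """
    # Naive splitter: a sentence ends at ., ! or ? followed by whitespace.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks: list[str] = []
    current: list[str] = []
    count = 0
    for sentence in sentences:
        n_words = len(sentence.split())
        # Flush the current chunk if adding this sentence would exceed the budget.
        if current and count + n_words > max_words:
            chunks.append(" ".join(current))
            current, count = [], 0
        current.append(sentence)
        count += n_words
    if current:
        chunks.append(" ".join(current))
    return chunks


chunks = chunk_sentences("A b c. D e f. G h i.", max_words=6)
print(chunks)
```

Note that joining sentences with a single space changes character offsets relative to the source file. A real loader has to keep the gold mention indices aligned with each chunk, which this sketch does not attempt.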

In [5]:
from coref_utils import load_conll_data

In [6]:
text, gold_coref_chains = load_conll_data("./1342_pride_and_prejudice_sample.txt", "./1342_pride_and_prejudice_brat.conll")

Let's examine the first text chunk as an example.

In [7]:
print(text[0])

Chapter 1


It is a truth universally acknowledged, that a single man in possession
of a good fortune, must be in want of a wife.

However little known the feelings or views of such a man may be on his
first entering a neighbourhood, this truth is so well fixed in the minds
of the surrounding families, that he is considered the rightful property
of some one or other of their daughters.

“My dear Mr. Bennet,” said his lady to him one day, “have you heard that
Netherfield Park is let at last?”

Mr. Bennet replied that he had not.


In [8]:
for chain in gold_coref_chains[0][:10]:
    print("===")
    for mention in chain.mentions:
        print(f'"{text[0][mention.start_idx:mention.end_idx]}"')


===
"a single man in possession
of a good fortune"
===
"a wife"
===
"such a man"
"his"
"he"
===
"a neighbourhood"
===
"the surrounding families"
"their"
===
"My dear"
"Mr. Bennet"
"his"
"him"
"you"
"Mr. Bennet"
"he"
===
"My"
"his lady"
===
"Netherfield Park"
===
"some one or other of their daughters"


## Baseline with `stanza`

Here, we will evaluate a baseline with the `stanza` coreference resolution system.

The $B^3$ precision and recall metrics are defined at the entity mention level. We follow previous work in evaluating mention detection and coreference separately:
- We calculate span F1 to measure the performance of mention detection
- We calculate the $B^3$ metrics on only the mentions that are shared between the gold and system outputs.
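
The two measurements above can be sketched as follows for a single chunk. This is a minimal illustration under stated assumptions, not the code behind `evaluate` in `coref_utils`: chains are represented as lists of `(start, end)` tuples, each mention is assumed to belong to exactly one chain, and the function names `span_prf` and `b_cubed` are hypothetical. How the real utility aggregates scores across chunks is not shown here.

```python
Span = tuple[int, int]
Chains = list[list[Span]]


def span_prf(gold: Chains, pred: Chains) -> tuple[float, float, float]:
    """Mention-detection precision, recall, and F1 over exact span matches."""
    gold_spans = {m for chain in gold for m in chain}
    pred_spans = {m for chain in pred for m in chain}
    overlap = len(gold_spans & pred_spans)
    precision = overlap / len(pred_spans) if pred_spans else 0.0
    recall = overlap / len(gold_spans) if gold_spans else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


def b_cubed(gold: Chains, pred: Chains) -> tuple[float, float]:
    """B^3 precision and recall, restricted to mentions found in both outputs."""
    gold_of = {m: frozenset(chain) for chain in gold for m in chain}
    pred_of = {m: frozenset(chain) for chain in pred for m in chain}
    shared = set(gold_of) & set(pred_of)
    if not shared:
        return 0.0, 0.0
    precision = recall = 0.0
    for m in shared:
        gold_chain = gold_of[m] & shared  # gold cluster of m, shared mentions only
        pred_chain = pred_of[m] & shared  # predicted cluster of m, shared mentions only
        common = len(gold_chain & pred_chain)
        precision += common / len(pred_chain)
        recall += common / len(gold_chain)
    return precision / len(shared), recall / len(shared)


gold = [[(0, 10), (24, 26)], [(30, 33)]]
pred = [[(0, 10), (24, 26), (30, 33)], [(40, 45)]]
span_p, span_r, span_f1 = span_prf(gold, pred)
b3_p, b3_r = b_cubed(gold, pred)
print(span_p, span_r, span_f1, b3_p, b3_r)
```

In this toy case the spurious span `(40, 45)` lowers span precision to 0.75 but does not enter the $B^3$ computation, while merging two gold chains into one lowers $B^3$ precision to 5/9 and leaves $B^3$ recall at 1.0.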

In [9]:
from coref_utils import evaluate

# Usage:
# evaluate(gold, pred)
# Returns an EvaluationOutput object that contains B3 precision and recall, as well as span precision/recall/f1

Now let's implement and run the Stanza baseline.

In [10]:
## STANZA BASELINE
import stanza

def stanza_baseline(text_chunks: list[str]) -> CorefOutput:
    """
    Run the stanza baseline on the text chunks.
    """
    pipe = stanza.Pipeline("en", processors="tokenize,lemma,pos,depparse,coref")
    results = []
    for text_chunk in tqdm(text_chunks):
        chains = []
        doc = pipe(text_chunk)
        for coref_chain in doc.coref:
            mentions = []
            for mention in coref_chain.mentions:
                span = doc.sentences[mention.sentence].words[mention.start_word:mention.end_word]
                mentions.append(CorefMention(start_idx=span[0].start_char, end_idx=span[-1].end_char))
            chains.append(CorefChain(mentions=mentions))
        results.append(chains)
    return results

baseline = stanza_baseline(text)

INFO:stanza:Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.10.0.json:   0%|  …

INFO:stanza:Downloaded file to /root/stanza_resources/resources.json


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/tokenize/combined.pt:   0%|   …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/mwt/combined.pt:   0%|        …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/pos/combined_charlm.pt:   0%| …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/lemma/combined_nocharlm.pt:   …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/coref/udcoref_xlm-roberta-lora…

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/depparse/combined_charlm.pt:  …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/forward_charlm/1billion.pt:   …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/backward_charlm/1billion.pt:  …

Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.10.0/models/pretrain/conll17.pt:   0%|    …

INFO:stanza:Loading these models for language: en (English):
| Processor | Package                  |
----------------------------------------
| tokenize  | combined                 |
| mwt       | combined                 |
| pos       | combined_charlm          |
| lemma     | combined_nocharlm        |
| coref     | udcoref_xlm-roberta-lora |
| depparse  | combined_charlm          |

INFO:stanza:Using device: cuda
INFO:stanza:Loading: tokenize
INFO:stanza:Loading: mwt
INFO:stanza:Loading: pos
INFO:stanza:Loading: lemma
INFO:stanza:Loading: coref
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

INFO:stanza:Loading: depparse
INFO:stanza:Done loading processors!
100%|██████████| 22/22 [00:07<00:00,  2.90it/s]


In [11]:
evaluate(gold_coref_chains, baseline)

EvaluationOutput(b3_recall=0.8611619636557228, b3_precision=0.777449162261211, span_recall=0.9080464253519174, span_precision=0.940840839850143, span_f1=0.9207456155960159)

## LLM prompting

How can we prompt an LLM to do this task? How can we post-process the output into the structure that we desire?

In [12]:
# use the 4B model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="cuda", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

We use the same code to call the LLM as in previous assignments. Feel free to modify this if you wish.

In [13]:
def call_llm(prompt, system_prompt="You are a helpful assistant.", generation_config=None):
    if generation_config is None:
        generation_config = {
            "max_new_tokens": 10,
            "temperature": 0.01
        }
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # conduct text completion
    generated = model.generate(
        **model_inputs,
        **generation_config
    )

    # let's break this down:
    #                      | we take the element of the batch (our batch size is 1)
    #                      |  |-----------------------------| skip our original input
    output_ids = generated[0][len(model_inputs.input_ids[0]):].tolist()

    # decode the generated token ids back into text
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

### Question 1

What is the simplest solution you can think of to do this task with an LLM? **In a few sentences**, explain your approach. Then, **implement this method** and evaluate its performance.

My approach is a simple zero-shot prompt: it includes guidance on the output format (a format template, not a worked example) and asks the LLM to identify coreference chains and print them in an easily parsed form such as `C1: he, him, his`. Post-processing then string-matches each mention against the chunk, in order of appearance, to recover its start and end indices.
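
The span-recovery step can be sketched as follows. This is an illustrative variant, not the implementation used in the cells below: it assumes word-boundary matching via `re`, so that a short pronoun such as `he` is not matched inside `The`, and the function name `find_spans_word_boundary` is hypothetical.

```python
import re


def find_spans_word_boundary(text: str, mention_texts: list[str]) -> list[tuple[int, int]]:
    """Map mention strings to (start, end) character spans, left to right.

    Each mention is searched for from the end of the previous match;
    mentions that cannot be found as whole words are skipped.
    """
    spans = []
    pos = 0
    for mention in mention_texts:
        # (?<!\w) and (?!\w) reject matches embedded in a longer word.
        pattern = re.compile(r"(?<!\w)" + re.escape(mention) + r"(?!\w)")
        match = pattern.search(text, pos)
        if match is None:
            continue
        spans.append(match.span())
        pos = match.end()
    return spans


sample = "The families knew that he was rich."
plain_start = sample.find("he")  # plain substring search lands inside "The"
spans = find_spans_word_boundary(sample, ["he"])
print(plain_start, spans)
```

Plain `str.find` is simpler but can attach a pronoun to the wrong characters, which then counts against span precision and recall.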



In [14]:
def llm_baseline(text_chunks) -> CorefOutput:
    """For you to implement! Return a list of CorefChain objects"""
    return [[CorefChain([])] for _ in text_chunks]

output = llm_baseline(text)
evaluate(gold_coref_chains, output)

EvaluationOutput(b3_recall=0.0, b3_precision=0.0, span_recall=0.0, span_precision=0.0, span_f1=0.0)

In [62]:
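def find_exact_mentions(text, mention_texts):
    """
    Given a list of mention strings (e.g. ['he', 'Mr. Bennet']), find their
    character spans in the text. Each mention is searched for starting at the
    end of the previous match, so mentions must be listed in order of
    appearance; mentions that are not found are skipped. Matching is plain
    substring search, so 'he' can also match inside 'the'.
    """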
    mentions = []
    current_pos = 0
    used_spans = set()
    for ment_text in mention_texts:
        start = text.find(ment_text, current_pos)
        if start == -1:
            continue
        end = start + len(ment_text)
        span = (start, end)
        if span in used_spans:
            continue
        used_spans.add(span)
        mentions.append(CorefMention(start_idx=start, end_idx=end))
        current_pos = end
    return mentions

def llm_baseline(text_chunks) -> CorefOutput:
    res = []
    for chunk in tqdm(text_chunks):
        prompt = f"""
        Identify all referring expressions (nouns, pronouns, names) that refer to the same entity.
        Group them into coreference chains.

        Output format:
        C1: he, him, his
        C2: she, her, hers

        Only output the chains. Use exact text from the passage. No indices. No quotes.

        Text:
        \"\"\"{chunk}\"\"\"
        """
        raw_output = call_llm(prompt, generation_config={
            "max_new_tokens": 512,
            "temperature": 0.0,  # unused under greedy decoding (do_sample=False), which makes the output reproducible
            "do_sample": False
        })
        chains = []
        for line in raw_output.split('\n'):
            line = line.strip()
            if not line.startswith('C'):
                continue
            if ':' not in line:
                continue
            mention_strs = [m.strip().strip('"\'') for m in line.split(':', 1)[1].split(',')]
            mention_strs = [m for m in mention_strs if m]
            mentions = find_exact_mentions(chunk, mention_strs)
            if mentions:
                # sort mentions by the starting index
                mentions.sort(key=lambda m: m.start_idx)
                chains.append(CorefChain(mentions=mentions))

        if not chains:
            chains = [CorefChain()]
        res.append(chains)
    return res

In [63]:
output = llm_baseline(text)
eval_result = evaluate(gold_coref_chains, output)

100%|██████████| 22/22 [01:52<00:00,  5.13s/it]


In [54]:
print(output)
print(eval_result)

[[CorefChain(mentions=[CorefMention(start_idx=153, end_idx=155), CorefMention(start_idx=399, end_idx=409)]), CorefChain(mentions=[CorefMention(start_idx=198, end_idx=201), CorefMention(start_idx=417, end_idx=425)]), CorefChain(mentions=[CorefMention(start_idx=122, end_idx=128), CorefMention(start_idx=365, end_idx=368)]), CorefChain(mentions=[CorefMention(start_idx=463, end_idx=479)])], [CorefChain(mentions=[CorefMention(start_idx=22, end_idx=25), CorefMention(start_idx=56, end_idx=59)]), CorefChain(mentions=[CorefMention(start_idx=23, end_idx=25), CorefMention(start_idx=94, end_idx=104)]), CorefChain(mentions=[CorefMention(start_idx=32, end_idx=41)]), CorefChain(mentions=[CorefMention(start_idx=5, end_idx=7)])], [CorefChain(mentions=[CorefMention(start_idx=53, end_idx=55), CorefMention(start_idx=252, end_idx=262)]), CorefChain(mentions=[CorefMention(start_idx=74, end_idx=85), CorefMention(start_idx=109, end_idx=111)]), CorefChain(mentions=[CorefMention(start_idx=53, end_idx=55), CorefM

### Question 2

What are some mistakes that you notice the system making? What are some improvements you can make? Name at least **two** improvements you want to test. Then, **implement these** and report how they affect the performance of the system. Feel free to experiment with more!

I want to try a two-stage approach in which all candidate mentions are extracted in a first LLM call and then clustered into chains in a second call. I also want to add a gold one-shot example to the clustering prompt. Both changes target the low span scores of the previous version (and therefore its span F1): exhaustive candidate extraction should, in theory, recover more of the real mentions with more accurate boundaries.

In [64]:
from tqdm import tqdm
import re

EXAMPLE = """
Text:
Chapter 1


It is a truth universally acknowledged, that a single man in possession
of a good fortune, must be in want of a wife.

However little known the feelings or views of such a man may be on his
first entering a neighbourhood, this truth is so well fixed in the minds
of the surrounding families, that he is considered the rightful property
of some one or other of their daughters.

“My dear Mr. Bennet,” said his lady to him one day, “have you heard that
Netherfield Park is let at last?”

Mr. Bennet replied that he had not.

Output:
C1: a single man in possession of a good fortune
C2: a wife
C3: such a man, his, he
C4: a neighbourhood
C5: the surrounding families, their
C6: My dear, Mr. Bennet, his, him, you, Mr. Bennet, he
C7: My, his lady
C8: Netherfield Park
C9: some one or other of their daughters
"""

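def find_mentions_ordered(text, mention_texts):
    """
    Given a list of mention strings (e.g. ['he', 'Mr. Bennet']), find their
    character spans in the text. Each mention is searched for starting at the
    end of the previous match, so mentions must be listed in order of
    appearance; mentions that are not found are skipped. Returns the mentions
    sorted by start index.
    """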
    mentions = []
    current_pos = 0
    used_spans = set()
    for ment_text in mention_texts:
        start = text.find(ment_text, current_pos)
        if start == -1:
            continue
        end = start + len(ment_text)
        span = (start, end)
        if span in used_spans:
            continue
        used_spans.add(span)
        mentions.append(CorefMention(start_idx=start, end_idx=end))
        current_pos = end
    return sorted(mentions, key=lambda m: m.start_idx)


def extract_candidates(chunk: str) -> list[str]:
    prompt = f"""Extract EVERY noun phrase, pronoun, proper name, and definite description that could refer to an entity.
    Be exhaustive. List one per line. No duplicates.

    Text:
    \"\"\"{chunk}\"\"\"
    """
    raw_output = call_llm(prompt, generation_config={
        "max_new_tokens": 512,
        "temperature": 0.0,
        "do_sample": False
    })


    candidates = [line.strip().strip('"\'') for line in raw_output.split('\n') if line.strip()]
    return [c for c in candidates if len(c) > 1]  # drop single-character lines (note: this also drops the pronoun "I")

def cluster_mentions(chunk: str, candidates: list[str]) -> list[list[str]]:
    cand_str = "\n".join([f"- {c}" for c in candidates])
    prompt = f"""{EXAMPLE}
    Now do the same for this text.

    First, here are all candidate mentions:
    {cand_str}

    Group them into coreference chains. Only group expressions that refer to the EXACT same entity.
    Output format:
    C1: mention1, mention2, ...
    C2: ...

    Text:
    \"\"\"{chunk}\"\"\"
    """
    raw_output = call_llm(prompt, generation_config={
        "max_new_tokens": 512,
        "temperature": 0.0,  # unused under greedy decoding (do_sample=False), which makes the output reproducible
        "do_sample": False
    })
    chains = []
    for line in raw_output.split('\n'):
        line = line.strip()
        if not line.startswith('C'):
            continue
        if ':' not in line:
            continue
        chain_str = line.split(':', 1)[1]
        mentions = [m.strip().strip('"\'') for m in chain_str.split(',')]
        mentions = [m for m in mentions if m in candidates]  # sanity filter
        if mentions:
            chains.append(mentions)
    return chains

def llm_improved(text_chunks) -> CorefOutput:
    res = []
    for chunk in tqdm(text_chunks):
        # get candidates
        candidates = extract_candidates(chunk)
        if not candidates:
            res.append([CorefChain()])
            continue
        # cluster
        raw_chains = cluster_mentions(chunk, candidates)
        # convert to corefchain objects
        chains = []
        for chain_mentions in raw_chains:
            mentions = find_mentions_ordered(chunk, chain_mentions)
            if mentions:
                chains.append(CorefChain(mentions=mentions))

        if not chains:
            chains = [CorefChain()]
        res.append(chains)
    return res

In [65]:
output = llm_improved(text)
evaluate(gold_coref_chains, output)

100%|██████████| 22/22 [06:23<00:00, 17.44s/it]


EvaluationOutput(b3_recall=0.7670683020683021, b3_precision=0.7511171236171236, span_recall=0.4892763281722092, span_precision=0.5594482520817604, span_f1=0.4830351689259331)
