### === Purpose ===

The goal of this lab is to perform relation classification on text in which named-entity recognition (NER) and disambiguation have already been performed. For example, given a Wikipedia article:

    <Elvis_Presley>
    <Elvis_Presley> was an <United_States_of_America> singer and actor, married to <Priscilla_Presley>.

the goal is to predict the relation between the title entity and each of the other entities:

    <Elvis_Presley><nationality><United_States_of_America>
    <Elvis_Presley><spouse><Priscilla_Presley>

You will use a language model for this task, together with constrained decoding, so that every prediction is one of the valid relation names.
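To make the idea of constrained decoding concrete, here is a toy sketch. It assumes made-up token ids and a fixed score table that stands in for the language model; the notebook itself uses Hugging Face `generate` with beam search instead of this greedy loop. At each step, only the continuations stored in a trie of valid relations are candidates, so the output is always a valid relation even when the "model" prefers another token.

```python
from typing import Dict, List

# Toy vocabulary: these ids are made up for illustration, they are NOT real T5 token ids.
PAD, EOS = 0, 1
VOCAB = {2: "birth", 3: "Place", 4: "Date", 5: "spouse", 6: "singer"}

def build_trie(sequences: List[List[int]]) -> Dict[int, dict]:
    """Nested-dict trie: each key is a token id, each value the sub-trie of its continuations."""
    root: Dict[int, dict] = {}
    for seq in sequences:
        node = root
        for tok in seq:
            node = node.setdefault(tok, {})
    return root

def allowed_next(trie: Dict[int, dict], prefix: List[int]) -> List[int]:
    """Token ids that may follow `prefix`; only EOS when the prefix has no continuation."""
    node = trie
    for tok in prefix:
        if tok not in node:
            return [EOS]
        node = node[tok]
    return list(node) or [EOS]

def constrained_greedy(scores: Dict[int, float], trie: Dict[int, dict], max_len: int = 10) -> List[int]:
    """Greedy decoding in which a fixed score table stands in for the language model."""
    output = [PAD]  # generation starts from the decoder start token
    while len(output) < max_len and output[-1] != EOS:
        candidates = allowed_next(trie, output)
        output.append(max(candidates, key=lambda t: scores.get(t, float("-inf"))))
    return output

# Valid relations: "birth Place", "birth Date" and "spouse", each terminated by EOS.
trie = build_trie([[PAD, 2, 3, EOS], [PAD, 2, 4, EOS], [PAD, 5, EOS]])
# The "model" prefers the invalid token "singer" (6), but the trie never offers it.
scores = {6: 5.0, 2: 2.0, 4: 1.5, 3: 1.0, 5: 0.5, EOS: 0.1}
result = constrained_greedy(scores, trie)
print(result, [VOCAB[t] for t in result[1:-1]])  # [0, 2, 4, 1] ['birth', 'Date']
```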

### === Provided Data ===

We provide:
1. A preprocessed version of Wikipedia, wikipedia-ner.txt, which contains articles about disambiguated entities, whose content also went through NERC and disambiguation.
2. A gold standard for the task, student-gold-standard.tsv, which contains triples <subject_entity> <object_entity> <relation> that you will use to evaluate your method.
3. A template for your code, relation_classification.py.
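Each line of the gold standard is tab-separated in the order subject, object, relation, which is how `run_evaluation` reads it. The sketch below parses two sample lines into a nested dictionary; the lines are hypothetical, built from the Elvis example above rather than copied from the real file.

```python
from collections import defaultdict

# Hypothetical gold-standard lines built from the Elvis example (not taken from the real file).
sample_lines = [
    "<Elvis_Presley>\t<United_States_of_America>\t<nationality>\n",
    "<Elvis_Presley>\t<Priscilla_Presley>\t<spouse>\n",
]

gold = defaultdict(dict)  # subject -> {object -> relation}
for line in sample_lines:
    subject, obj, relation = line.rstrip("\n").split("\t")
    gold[subject][obj] = relation

print(dict(gold))
# {'<Elvis_Presley>': {'<United_States_of_America>': '<nationality>', '<Priscilla_Presley>': '<spouse>'}}
```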

### === Task ===

You will have two tasks in this lab.
The first is to complete the function construct_trie, so that it builds a trie for the (tokenized) list of relations given as input.
The second is to complete the function classify_relations in this notebook.
It receives as input (1) the title entity (subject), (2) the article content, (3) a trie, and (4) a prompt template.
It outputs a list of relations between the title entity and all the other disambiguated entities in the article content, predicted with a language model and constrained decoding.
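A first step of `classify_relations` is to list the object entities of an article. The sketch below, applied to the Elvis example, assumes that the title entity keeps its angle brackets (as in the article files and in `results.tsv`) while the regular expression returns names without them; it removes repeated mentions and the subject itself.

```python
import re

# Article text from the Elvis example; entities are wrapped in "<" and ">".
content = ("<Elvis_Presley> was an <United_States_of_America> singer and actor, "
           "married to <Priscilla_Presley>.")
title_entity = "<Elvis_Presley>"

names = re.findall(r"<([^>]+)>", content)  # names are captured without their brackets
title_name = title_entity.strip("<>")
# Deduplicate while preserving order, and drop the subject itself.
objects = [n for n in dict.fromkeys(names) if n != title_name]
print(objects)  # ['United_States_of_America', 'Priscilla_Presley']
```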

### === Working with Colab ===
You need to save a copy of the notebook to your own Google Drive.
Connect to a runtime with a GPU (this should be automatic, but check it!). Upload the data files directly to the Colab session, and you can then run everything.

Don't forget to download the results file at the end.


### === Submission ===

1. Download your code (this notebook, in .ipynb or .py format) and the output of your code on the dataset (results.tsv).
2. ZIP these files into a file called firstName_lastName.zip.
3. Submit it here before the deadline announced during the lab:

https://www.dropbox.com/request/Bgb7txXDLK92Rg273nKe


### === Contact ===

If you have any additional questions, you can send an email to: zacchary.sadeddine@telecom-paris.fr

In [None]:
"""
Install the necessary modules (run only once!)
"""
!pip install -q transformers
!pip install -q sentencepiece
!pip install -q accelerate

In [4]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from collections import defaultdict
import time
from tqdm.notebook import tqdm
import re
from typing import Dict, List

In [5]:
"""
Loads the Flan-T5 large model (google/flan-t5-large) and its tokenizer
"""
torch.cuda.empty_cache()
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565

In [6]:
class WikipediaArticle:
  """ Represents a Wikipedia article. Do not modify. """
  def __init__(self, title, content):
    self.title_entity=title
    self.content=content

def wikipediaArticles(file):
  """ Yields the Wikipedia articles from a file. Do not modify. """
  article=[]
  title=None
  with open(file, "rt", encoding="utf-8") as inputFile:
    for line in inputFile:
      line=line.rstrip()
      if not title:
        title=line
        continue
      if not len(line) and title and len(article):
        yield WikipediaArticle(title, article[0])
        title=None
        article=[]
        continue
      article+=[line]

def clean(yagoEntity):
    """ Removes prefixes etc."""
    if yagoEntity.startswith('"'):
      return yagoEntity[1:-1]
    yagoEntity=yagoEntity[yagoEntity.find(':')+1:]
    return '<'+yagoEntity+'>'
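The calls below show what `clean` returns for a few inputs. The cell repeats the function so it runs on its own, and the `yago:` prefix is a hypothetical example, not necessarily a prefix found in the data.

```python
def clean(yagoEntity):
    """Copy of the notebook's clean() so this cell is self-contained."""
    if yagoEntity.startswith('"'):
        return yagoEntity[1:-1]                       # literal: drop the surrounding quotes
    yagoEntity = yagoEntity[yagoEntity.find(':') + 1:]  # drop everything up to the first ':'
    return '<' + yagoEntity + '>'

print(clean('"Elvis Presley"'))        # Elvis Presley
print(clean('yago:Elvis_Presley'))     # <Elvis_Presley>   ("yago:" is a hypothetical prefix)
print(clean('birthPlace'))             # <birthPlace>      (no ':' -> find returns -1, nothing is cut)
print(clean('1954-10-27T00:00:00Z'))   # <00:00Z>          (cut at the first ':' of the time)
```

Because the cut happens at the first `:`, `clean` is only safe on names with no `:` besides a prefix, which is why it should be applied to relation names and not to date objects such as `1954-10-27T00:00:00Z`.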

In [24]:
def run_evaluation(root=''):
  """Evaluation script, do not modify (unless you want to remove some prints).
  We use the f-05 measure, which gives more importance to precision: classifying entities correctly is more valued than finding all entities.
  """
  with open(f"{root}student-gold-standard.tsv", "r", encoding="utf-8") as f:
    lines = f.readlines()
  gold_standard_dict = defaultdict(dict)
  for line in lines:
    title_entity, entity_id, relation = tuple(line.replace("\n","").split("\t"))
    gold_standard_dict[title_entity][entity_id] = relation
  gold_standard_dict = dict(gold_standard_dict)

  with open(f"{root}results.tsv", "r", encoding="utf-8") as f:
    lines = f.readlines()
  predictions_dict = defaultdict(dict)
  for line in lines:
    title_entity, entity_id, relation = tuple(line.replace("\n","").split("\t"))
    predictions_dict[title_entity][entity_id] = relation

  true_pos = 0
  false_pos = 0
  false_neg = 0

  for title_entity in predictions_dict:
    for entity_id in predictions_dict[title_entity]:
      try:
        gold_yago_relation = gold_standard_dict[title_entity][entity_id]
      except KeyError: #should not happen
        continue
      if predictions_dict[title_entity][entity_id] == gold_yago_relation:
        true_pos += 1
      else:
        false_pos += 1
        if false_pos < 100:
          print("You classified the relation between", title_entity + " and " + entity_id, "wrong.", "Expected output: ", gold_yago_relation, ",given:", predictions_dict[title_entity][entity_id])

  for gold_title in gold_standard_dict: #do we really want to measure this? There are some entities that don't have a wikipedia article, so they count. Should they be removed from the gold standard?
    for entity_id in gold_standard_dict[gold_title]:
      try:
        predict_relation = predictions_dict[gold_title][entity_id]
      except KeyError:
        false_neg += 1
        if false_neg < 100:
          print("You did not classify the relation between", gold_title + " and " + entity_id +".")

  if true_pos + false_pos != 0:
    precision = float(true_pos) / (true_pos + false_pos)
  else:
    precision = 0.0

  if true_pos + false_neg != 0:
    recall = float(true_pos) / (true_pos + false_neg)
  else:
    recall = 0.0

  beta = 0.5

  if precision + recall != 0.0:
    f05 = (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
  else:
    f05 = 0.0

  print()
  print("Scores")
  print(f"- Precision: {precision:.3%}")
  print(f"- Recall: {recall:.3%}")
  print(f"- F-0.5 Score: {f05:.3%}")

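def get_all_relations(file):
  """Returns the set of relation labels (e.g. "<birthDate>") that appear in the gold-standard file."""
  with open(file, "r", encoding="utf-8") as f: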
    lines = f.readlines()
  relations = set()
  for line in lines:
    title_entity, entity_id, relation = tuple(line.replace("\n","").split("\t"))
    relations.add(relation)
  return relations
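The F0.5 score computed in `run_evaluation` is (1 + β²)·P·R / (β²·P + R) with β = 0.5. The sketch below reuses that formula on hypothetical precision and recall values (not results of this lab) to show that swapping P and R changes the score: precision counts more.

```python
def f_beta(precision: float, recall: float, beta: float = 0.5) -> float:
    """Same formula as run_evaluation: (1 + b^2) * P * R / (b^2 * P + R)."""
    if precision + recall == 0.0:
        return 0.0
    b2 = beta * beta
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# Hypothetical scores, chosen only to show the weighting.
print(round(f_beta(0.8, 0.4), 3))            # 0.667: high precision, low recall
print(round(f_beta(0.4, 0.8), 3))            # 0.444: same numbers swapped
print(round(f_beta(0.8, 0.4, beta=1.0), 3))  # 0.533: plain F1, for comparison
```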

In [8]:
def prefix_allowed_tokens_fn(input_ids, trie):
  '''
  The function that handles constrained decoding.
  Given the tokens generated so far, returns the token ids allowed at the next step. If no token is allowed, returns the EOS token, which ends the generation.
  This function is called at every generation step (every time a token is generated).
  DO NOT MODIFY (unless you're brave)
  '''
  model_output = input_ids.tolist()
  allowed_tokens = trie.get(model_output)
  if not allowed_tokens:
    return [tokenizer.eos_token_id]
  return allowed_tokens
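The sketch below reproduces the logic of `prefix_allowed_tokens_fn` on plain Python lists instead of a torch tensor, with the EOS id passed as a constant rather than read from the global `tokenizer`. `TinyTrie` is a hypothetical minimal stand-in for the notebook's `Trie` (only `get` is needed), and the token ids are made up; EOS = 1 matches the `</s>` id printed by `test_tokenizer`.

```python
from typing import Dict, List

EOS_ID = 1  # T5's </s> id, as printed by test_tokenizer

class TinyTrie:
    """Hypothetical minimal stand-in for the notebook's Trie: only get() is implemented."""
    def __init__(self, sequences: List[List[int]]):
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def get(self, prefix: List[int]) -> List[int]:
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return list(node)

def allowed_tokens(generated: List[int], trie: TinyTrie) -> List[int]:
    """Same logic as prefix_allowed_tokens_fn, on a plain list instead of a tensor."""
    tokens = trie.get(generated)
    return tokens if tokens else [EOS_ID]

# Made-up ids: two relations that share their first token 9.
trie = TinyTrie([[0, 9, 4, 1], [0, 9, 7, 1]])
print(allowed_tokens([0], trie))           # [9]
print(allowed_tokens([0, 9], trie))        # [4, 7]
print(allowed_tokens([0, 9, 4, 1], trie))  # [1]  (path finished: only EOS)
print(allowed_tokens([0, 3], trie))        # [1]  (off the trie: generation is forced to stop)
```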

In [9]:
class Trie(object):
    def __init__(self, sequences: List[List[int]] = []):
        self.trie_dict = {}
        if sequences:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)

    def add(self, sequence: List[int]):
        Trie._add_to_trie(sequence, self.trie_dict)

    def get(self, prefix_sequence: List[int]):
        return Trie._get_from_trie(prefix_sequence, self.trie_dict)

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        if sequence:
            if sequence[0] not in trie_dict:
                trie_dict[sequence[0]] = {}
            Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]])

    @staticmethod
    def _get_from_trie(prefix_sequence: List[int], trie_dict: Dict):
        if len(prefix_sequence) == 0:
            output = list(trie_dict.keys())
            return output
        elif prefix_sequence[0] in trie_dict:
            return Trie._get_from_trie(prefix_sequence[1:],trie_dict[prefix_sequence[0]])
        else:
            return []

    def __iter__(self):
        def _traverse(prefix_sequence, trie_dict):
            if trie_dict:
                for next_token in trie_dict:
                    yield from _traverse(prefix_sequence + [next_token], trie_dict[next_token])
            else:
                yield prefix_sequence

        return _traverse([], self.trie_dict)

    def __getitem__(self, value):
        return self.get(value)

In [10]:
def construct_trie(relations: List[str], tokenizer):
    '''
    This function builds a Trie for the list of relations and the tokenizer given as input.
    For each relation, add the list of relevant token ids to the trie.
    Be careful: with the model you're using, every generation starts with the token <pad> (token id 0).
    '''
    #YOUR CODE GOES HERE
    trie = Trie()
    for relation in relations:
        # Remove '<' and '>' before encoding, e.g. "<birthPlace>" -> "birthPlace"
        relation_text = relation[1:-1]
        # add_special_tokens=True appends </s> (id 1), so every path ends with EOS and generation stops there
        relation_tokens = tokenizer.encode(relation_text, add_special_tokens=True)
        trie.add([0] + relation_tokens)  # prepend <pad> (id 0), the decoder start token

    return trie
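The paths that `construct_trie` stores can be illustrated without downloading T5. `FakeTokenizer` below is hypothetical and does not reproduce T5's real tokenization: it splits a camelCase name into words, gives each new word a fresh id, and appends `</s>` (id 1). The point is the shape of each path: a leading `<pad>` (0), then the relation's tokens, then EOS, with relations that share a prefix sharing trie nodes.

```python
import re
from typing import Dict, List

class FakeTokenizer:
    """Hypothetical stand-in for T5Tokenizer, NOT its real behaviour: splits camelCase
    names into words, gives each new word a fresh id, and appends </s> (id 1)."""
    def __init__(self):
        self.vocab: Dict[str, int] = {}

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        words = re.findall(r"[a-z]+|[A-Z][a-z]*", text)
        ids = [self.vocab.setdefault(w, len(self.vocab) + 2) for w in words]  # ids 0 and 1 are reserved
        return ids + [1] if add_special_tokens else ids

relations = ["<birthPlace>", "<birthDate>", "<spouse>"]
tok = FakeTokenizer()
# Same recipe as construct_trie: strip the brackets, encode with EOS, prepend <pad> (0).
paths = [[0] + tok.encode(r[1:-1], add_special_tokens=True) for r in relations]
print(paths)  # [[0, 2, 3, 1], [0, 2, 4, 1], [0, 5, 1]]

# Nested-dict trie over these paths: after the start token 0, the first real step may pick 2 or 5.
root: dict = {}
for path in paths:
    node = root
    for t in path:
        node = node.setdefault(t, {})
print(list(root[0]))  # [2, 5]
```

Without the leading 0, a lookup for the prefix `[0]` (what the model has "generated" before its first real token) would find nothing, and `prefix_allowed_tokens_fn` would immediately return EOS.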

In [11]:
def test_tokenizer(word, tokenizer):
    '''
    A small test function for you to play with
    Use it to understand how the tokenizer works, and build your Trie accordingly
    '''
    tokens = tokenizer.encode(word)
    print(tokens)
    print([tokenizer.decode([t]) for t in tokens])

test_tokenizer("friendship", tokenizer)

[9888, 1]
['friendship', '</s>']


In [12]:
def run_model(prompt, trie):
  """
  Runs the language model using constrained decoding
  Input:
  - Prompt (str)
  - prefix_allowed_tokens_fn
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  input_text = prompt
  inputs = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

  outputs = model.generate(inputs, max_new_tokens=20, do_sample=False, num_beams=5, temperature=None, top_p=None, pad_token_id=tokenizer.eos_token_id,
        prefix_allowed_tokens_fn=lambda _, input_ids: prefix_allowed_tokens_fn(input_ids, trie))
  return(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [13]:
def classify_relations(title_entity, page_content, trie, prompt_template):
    """ Disambiguates the entity name based on the Wikipedia article
    Returns a list of triples (title_entity, object_entity, relation) or (title_entity, object_entity, None)
    In the Wikipedia article, the title entity and the object entities are surrounded by "<" and ">"
    """
    #YOUR CODE GOES HERE
    results = []
    object_entities = re.findall(r'<([^>]+)>', page_content)

    for object_entity in object_entities[1:]:
      if object_entity != title_entity:
        prompt = prompt_template.format(
            page_content=page_content,
            title_entity=title_entity,
            object_entity=object_entity
        )
        relation = run_model(prompt, trie)
        formatted_relation = relation.replace(' ', '_')
        results.append((title_entity, object_entity, formatted_relation))
    return results

In [21]:
prompt_template = """
In the Wikipedia article: '{page_content}', what is the relationship between {title_entity} and {object_entity}?
Pay attention to the context and to the type of the entity (person, institution, date, object, ...).
"""

In [26]:
prompt_template = """
  In the Wikipedia article: '{page_content}', what is the relationship between {title_entity} and {object_entity}?

  Some examples of relationships (non-exhaustive list):
  - birthPlace: Where a person was born.
  - deathPlace: Where a person died.
  - birthDate: When a person was born. Do NOT confuse it with dateCreated, which applies to objects and other non-living entities.
  - memberOf: An organization a person is a member of.
  - nationality: The country a person is a citizen of.

  Pay attention to the context and to the type of the entity (person, institution, date, object, ...).
  """

In [27]:
def run(prompt_template, root=''):
  relations = get_all_relations(f"{root}student-gold-standard.tsv")
  trie = construct_trie(relations, tokenizer) #function to complete
  input_file = f"{root}wikipedia-ner.txt"
  n_articles = sum(1 for _ in wikipediaArticles(input_file))  # total for the progress bar
  # 'w' overwrites the results of a previous run instead of appending to them
  with open(f"{root}results.tsv", 'w', encoding="utf-8") as output:
    start = time.time()
    for i, page in tqdm(enumerate(wikipediaArticles(input_file)), total=n_articles, desc='Processing pages'):
      #print("  Processing",page.title_entity, i)
      result = classify_relations(page.title_entity, page.content, trie, prompt_template) #function to complete
      if result is not None:
        #print(result)
        for subj, obj, rel in result:
          output.write(f"{subj}\t<{obj}>\t{clean(rel)}\n")
  end = time.time()
  print("done")
  print("execution time: ", end - start)
  print("number of articles: ", i + 1)  # enumerate starts at 0
run(prompt_template=prompt_template, root="Lab2/")

done
execution time:  215.71366095542908
number of articles:  400


In [28]:
run_evaluation(root='Lab2/')

You classified the relation between <Ashok_Kumar__u0028_Indian_politician_u0029_> and <1954-10-27T00:00:00Z> wrong. Expected output:  <birthDate> ,given: <dateCreated>
You classified the relation between <Ashok_Kumar__u0028_Indian_politician_u0029_> and <Indian_National_Congress> wrong. Expected output:  <memberOf> ,given: <owns>
You classified the relation between <Ashok_Kumar__u0028_Indian_politician_u0029_> and <Patna_Medical_College_and_Hospital> wrong. Expected output:  <alumniOf> ,given: <owns>
You classified the relation between <Ashok_Kumar__u0028_golfer_u0029_> and <1983-07-20T00:00:00Z> wrong. Expected output:  <birthDate> ,given: <dateCreated>
You classified the relation between <Ashok_Kumar__u0028_golfer_u0029_> and <Bihar> wrong. Expected output:  <birthPlace> ,given: <location>
You classified the relation between <Ashok_Kumar_Dogra> and <1958-11-24T00:00:00Z> wrong. Expected output:  <birthDate> ,given: <dateCreated>
You classified the relation between <Ashok_Kumar_Dogra>