In [None]:
import LLMUtils
import Datasets

import json
import textwrap
import nltk
import os

import pandas as pd
import more_itertools as mit
import matplotlib.pyplot as plt

from enum import Enum
from collections import defaultdict
from tqdm import tqdm

nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
# Needed for restricted LLMs with EULAs (Llama...).
# Read from the HUGGINGFACE_TOKEN environment variable so the secret never has to be
# written into (and committed with) the notebook; falls back to "" when it is unset.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", "")
EVALUATE_ON_PAPERS = True # False: DocRED

# Load the model

In [None]:
# In the publication: Gemma-2-9B, Llama-3.1-8B, Phi-3-3B
# The token comes from the HUGGINGFACE_TOKEN constant of the configuration cell
# (the previous `hf_token` name was never defined and raised a NameError).
LLM = LLMUtils.LLM(LLMUtils.LLM.GEMMA_9B, hf_token=HUGGINGFACE_TOKEN)

# Load the papers dataset

In [None]:
# Parse the papers dataset. The encoding is explicit because JSON files are UTF-8
# and the platform default encoding is not guaranteed to be.
with open('papers.json', encoding='utf-8') as f:
    papers = json.load(f)

# Load the datasets

In [None]:
# Used for in-context samples, whichever corpus is being evaluated
webnlg_dataset = Datasets.WebNLGDataset()

# Used for generating triples, either WebNLG (papers) or DocRED
main_dataset = (Datasets.WebNLGDataset if EVALUATE_ON_PAPERS else Datasets.DocREDDataset)()

In [None]:
# Prompt strategies available when asking the LLM for triples
class PromptStrategy(Enum):
    # Split the text into chunks of n sentences, with an overlap of m sentences
    # which will act as the context. n and m can be any value (e.g. n=1, m=0
    # traverses the text sentence by sentence with no context). The iterator
    # adjusts the overlap for the first sentence(s), where there may not be
    # enough preceding ones.
    PREVIOUS_SENTENCES = 1

    # Generate triples for the whole section at once
    SECTION_CONTENTS = 2

# Auxiliary functions for post-processing and testing
The functions below will perform triple merging when applicable

### Triple cleanup

In [None]:
def is_valid_triple(t):
    """
    Tells whether `t` has the shape of a triple: a tuple with exactly three items
    """
    if not isinstance(t, tuple):
        return False

    return len(t) == 3


def clean_triple(t):
    """
    Returns the triple as a tuple of three strings, converting with str() every
    element that is not already a string. `t` must unpack into exactly three items
    """
    subject, predicate, obj = t

    return tuple(part if isinstance(part, str) else str(part)
                 for part in (subject, predicate, obj))


def clean_and_get_triples_stats(paper_triples):
    """
    Validates and cleans the triples generated for one document and computes
    some statistics about them.

    paper_triples: list of (sentences, sentence_triples) pairs, where
                   sentence_triples is a list with the raw LLM output.
                   Each sentence_triples list is rewritten in place so that it
                   only keeps its valid triples, already cleaned

    Returns a tuple with:
      - the flat list of valid triples, cleaned with clean_triple()
      - the number of raw triples (valid or not)
      - the number of distinct subjects among the valid triples
      - the number of distinct predicates among the valid triples
      - the number of distinct objects among the valid triples
      - the number of invalid triples that were discarded
    """
    clean_triples = []
    total_triples = 0
    bad_triples = 0
    unique_s, unique_p, unique_o = set(), set(), set()

    for _, sentence_triples in paper_triples:
        total_triples += len(sentence_triples)

        kept = []
        for triple in sentence_triples:
            if not is_valid_triple(triple):
                bad_triples += 1
                continue

            # Keep the cleaned version: the returned triples must hold the same
            # string values that the statistics are computed from
            s, p, o = clean_triple(triple)
            unique_s.add(s)
            unique_p.add(p)
            unique_o.add(o)
            kept.append((s, p, o))

        # Rebuild the list in one pass instead of calling remove() per bad triple
        sentence_triples[:] = kept
        clean_triples.extend(kept)

    return clean_triples, total_triples, len(unique_s), len(unique_p), len(unique_o), bad_triples

# Test the context strategies

In [None]:
def get_triples_for_sentences(context, sentences, n_samples, LLM, webnlg_dataset):
    """
    Builds the system and user prompts for `sentences` (with their `context` and
    `n_samples` samples) through webnlg_dataset.get_prompts(), asking to avoid
    explanations, and returns what LLM.get_triples() answers for them.
    Bad triples are allowed in the answer
    """
    prompt_for_system, prompt_for_user = webnlg_dataset.get_prompts(
        context,
        sentences,
        n_samples,
        avoid_explanations=True,
    )

    return LLM.get_triples(prompt_for_system, prompt_for_user, allow_bad_triples=True)



def _build_sentence_chunks(text, sentences_per_prompt, overlap, prompt_strategy):
    """
    Splits `text` into the (context, sentences) pairs that will be prompted.

    SECTION_CONTENTS: a single pair (None, text), i.e. the whole text and no context.

    PREVIOUS_SENTENCES: the text is split into sentences with nltk and grouped into
    windows of sentences_per_prompt sentences which advance
    sentences_per_prompt - overlap sentences each time. The first `overlap`
    sentences of a window are its context and the remaining ones are the sentences
    to extract triples from. When overlap > 0 the first sentences_per_prompt
    sentences have no preceding context, so they are emitted on their own with a
    placeholder context and the windows start `overlap` sentences before their end.

    Pairs without any sentence to extract triples from are skipped.
    Raises ValueError for any other strategy.
    """
    if prompt_strategy == PromptStrategy.SECTION_CONTENTS:
        return [(None, text)]

    if prompt_strategy != PromptStrategy.PREVIOUS_SENTENCES:
        raise ValueError(f"Unsupported prompt strategy: {prompt_strategy!r}")

    sentences = nltk.sent_tokenize(text)
    sentence_chunks = []

    if overlap > 0:
        # Add the first group of sentences manually, as they will have no context
        first_sentences = " ".join(sentences[:sentences_per_prompt])
        if first_sentences:
            sentence_chunks.append(("There is no context for this sample", first_sentences))

        # And prepare the next chunk to have a correct starting overlap
        sentences = sentences[sentences_per_prompt - overlap:]

    for window in mit.windowed(sentences, n=sentences_per_prompt, step=sentences_per_prompt - overlap):
        # windowed() pads incomplete windows with None, hence the filters
        context = " ".join(filter(None, window[:overlap]))
        target_sentences = " ".join(filter(None, window[overlap:]))

        # A padded window may only hold context sentences: nothing to prompt for
        if target_sentences:
            sentence_chunks.append((context, target_sentences))

    return sentence_chunks


def get_triples_from_papers(papers,
                            n_samples,
                            sentences_per_prompt,
                            overlap,
                            prompt_strategy : PromptStrategy,
                            LLM,
                            main_dataset):
    """
    Given a JSON response containing paper abstracts from https://api.plos.org,
    returns a dict of paper ID -> list of (sentences, triples) pairs, with one
    pair per prompt sent to the LLM.

    Only the first entry of each paper's "abstract" list is used.
    main_dataset is the dataset the prompts are built with.
    The progress bar is written to progress.log, which is overwritten.
    """
    docs = papers["response"]["docs"]
    n_papers = len(docs)
    paper_triples = dict()

    # Both the log file and the progress bar are closed when leaving the block
    with open("progress.log", "w") as progress_log, tqdm(docs, file=progress_log) as pbar:
        for i, paper in enumerate(pbar):
            paper_id = paper["id"]
            paper_triples[paper_id] = []

            sentence_chunks = _build_sentence_chunks(paper["abstract"][0],
                                                     sentences_per_prompt,
                                                     overlap,
                                                     prompt_strategy)

            for j, (context, sentences_chunk) in enumerate(sentence_chunks):
                pbar.set_description(f"sentences_per_prompt: {sentences_per_prompt}, strategy: {prompt_strategy}, overlap: {overlap}. Generating triples for paper {i+1}/{n_papers}, sentence chunk {j+1}/{len(sentence_chunks)}")

                sentence_triples = get_triples_for_sentences(context, sentences_chunk, n_samples, LLM, main_dataset)
                paper_triples[paper_id].append((sentences_chunk, sentence_triples))

    return paper_triples


def get_triples_from_docred(n_samples,
                            sentences_per_prompt,
                            overlap,
                            prompt_strategy : PromptStrategy,
                            LLM,
                            main_dataset,
                            webnlg_dataset):
    """
    Given a docRED dataset parsed from https://huggingface.co/datasets/thunlp/docred,
    returns a dict of text -> list of (sentences, triples) pairs, with one pair
    per prompt sent to the LLM.

    main_dataset provides the samples and their text, while webnlg_dataset is the
    dataset the prompts are built with.
    The progress bar is written to progress.log, which is overwritten.
    """
    docred_triples = dict()

    # First 100 validation samples, which contain both text and ground truth triples (to evaluate later on)
    samples = main_dataset.docred["validation"].select(range(100))

    # Both the log file and the progress bar are closed when leaving the block
    with open("progress.log", "w") as progress_log, tqdm(samples, file=progress_log) as pbar:
        for i, docred_sample in enumerate(pbar):
            text, _ = main_dataset.get_text_and_triples(docred_sample)

            # The text itself identifies the sample
            sample_id = text
            docred_triples[sample_id] = []

            sentence_chunks = _build_sentence_chunks(text,
                                                     sentences_per_prompt,
                                                     overlap,
                                                     prompt_strategy)

            for j, (context, sentences_chunk) in enumerate(sentence_chunks):
                pbar.set_description(f"sentences_per_prompt: {sentences_per_prompt}, strategy: {prompt_strategy}, overlap: {overlap}. Generating triples for paper {i+1}/100, sentence chunk {j+1}/{len(sentence_chunks)}")

                sentence_triples = get_triples_for_sentences(context, sentences_chunk, n_samples, LLM, webnlg_dataset)
                docred_triples[sample_id].append((sentences_chunk, sentence_triples))

    return docred_triples

In [None]:
def run_experiments(papers,
                    n_samples,
                    sentences_per_prompt,
                    overlap,
                    prompt_strategy,
                    LLM,
                    main_dataset,
                    webnlg_dataset):
    """
    Generates the triples of every document with the given configuration and
    returns a dict of column name -> list of values, with one entry per document.

    When main_dataset is a Datasets.WebNLGDataset the triples are generated from
    `papers` and the documents are identified by the "paper_id" column. Otherwise
    they are generated from the DocRED samples of main_dataset and identified by
    the "text" column.

    The remaining columns are the configuration (n_samples, prompt_strategy,
    sentences_per_prompt, overlap) and the values returned by
    clean_and_get_triples_stats() (clean_triples, total_triples, unique_s,
    unique_p, unique_o, bad_triples).
    """
    if isinstance(main_dataset, Datasets.WebNLGDataset):
        id_column = "paper_id"
        triples_by_document = get_triples_from_papers(papers,
                                                      n_samples,
                                                      sentences_per_prompt,
                                                      overlap,
                                                      prompt_strategy,
                                                      LLM,
                                                      main_dataset)
    else:
        id_column = "text"
        triples_by_document = get_triples_from_docred(n_samples,
                                                      sentences_per_prompt,
                                                      overlap,
                                                      prompt_strategy,
                                                      LLM,
                                                      main_dataset,
                                                      webnlg_dataset)

    # Both datasets share every column but the one identifying the document
    results = {
       "n_samples": [],
       "prompt_strategy": [],
       "sentences_per_prompt": [],
       "overlap": [],

       id_column: [],
       "clean_triples": [],
       "total_triples": [],
       "unique_s": [],
       "unique_p": [],
       "unique_o": [],
       "bad_triples": [],
    }

    for document_id, triples in triples_by_document.items():
        clean_triples, total_triples, unique_s, unique_p, unique_o, bad_triples = clean_and_get_triples_stats(triples)

        results["n_samples"].append(n_samples)
        results["prompt_strategy"].append(prompt_strategy)
        results["sentences_per_prompt"].append(sentences_per_prompt)
        results["overlap"].append(overlap)

        results[id_column].append(document_id)
        results["clean_triples"].append(clean_triples)
        results["total_triples"].append(total_triples)
        results["unique_s"].append(unique_s)
        results["unique_p"].append(unique_p)
        results["unique_o"].append(unique_o)
        results["bad_triples"].append(bad_triples)

    return results

## Run the experiments

### Check variability across strategies and parameters
The number of triples generated can vary a lot depending on the number of sentences included in each prompt, the length of the context and the context strategy being used:

In [None]:
if os.path.exists("paper_triples_results.csv"): # Resume the experiments
    results = pd.read_csv("paper_triples_results.csv")
else:
    # Start from an empty table. The documents are identified by their paper ID
    # when main_dataset is a WebNLG dataset and by their text otherwise
    id_column = "paper_id" if isinstance(main_dataset, Datasets.WebNLGDataset) else "text"

    results = pd.DataFrame({
        column: [] for column in [
            "n_samples",
            "prompt_strategy",
            "sentences_per_prompt",
            "overlap",

            id_column,
            "clean_triples",
            "total_triples",
            "unique_s",
            "unique_p",
            "unique_o",
            "bad_triples",
        ]
    })

In [None]:
def exists_result_with_config(n_samples, prompt_strategy, sentences_per_prompt, overlap):
    """
    Tells whether the global `results` table already holds at least one row for
    the given configuration, so that the experiment can be skipped.

    The numeric parameters are compared as floats and the strategy is compared
    through its string form.
    """
    # Rows read back from the CSV hold the strategy as a string, while rows added
    # during this session hold the enum member: compare both as strings
    same_strategy = results["prompt_strategy"].astype(str) == str(prompt_strategy)

    return (
        (results["n_samples"] == float(n_samples)) &
        same_strategy &
        (results["sentences_per_prompt"] == float(sentences_per_prompt)) &
        (results["overlap"] == float(overlap))
    ).any()

In [None]:
# WebNLG samples to use
n_samples = 8

# Every number of sentences per prompt from 1 up to this value is tested
max_sentences_per_prompt = 10

In [None]:
def _run_config_if_missing(prompt_strategy, sentences_per_prompt, overlap):
    """
    Runs the experiments for one configuration, appends their rows to the global
    `results` table and saves the table to paper_triples_results.csv.
    Configurations which already have results are skipped, so that an interrupted
    run can be resumed.
    """
    global results

    if exists_result_with_config(n_samples, prompt_strategy, sentences_per_prompt, overlap):
        print("Test Skipped:", n_samples, prompt_strategy, sentences_per_prompt, overlap)
        return

    new_results = run_experiments(papers,
                                  n_samples,
                                  sentences_per_prompt,
                                  overlap,
                                  prompt_strategy,
                                  LLM,
                                  main_dataset,
                                  webnlg_dataset)

    # Save after every configuration: each one is expensive to generate again
    results = pd.concat([results, pd.DataFrame(new_results)], ignore_index = True)
    results.to_csv("paper_triples_results.csv", index=False)


# Whole section at once: the number of sentences and the overlap are not used
_run_config_if_missing(PromptStrategy.SECTION_CONTENTS, 0, 0)

# From 1 to max_sentences_per_prompt sentences at once
for sentences_per_prompt in range(1, max_sentences_per_prompt + 1):
    # No context first (overlap 0) and then every context length which still
    # leaves at least one sentence to extract triples from
    for overlap in range(sentences_per_prompt):
        _run_config_if_missing(PromptStrategy.PREVIOUS_SENTENCES, sentences_per_prompt, overlap)

In [None]:
# Read the results table back from paper_triples_results.csv
results = pd.read_csv("paper_triples_results.csv")

## Save the averaged results

In [None]:
# Column identifying each document, which depends on the dataset being evaluated
id_column = 'paper_id' if isinstance(main_dataset, Datasets.WebNLGDataset) else 'text'

# Drop the columns that are not averaged and compute the mean of the remaining
# ones for every (strategy, sentences per prompt, overlap) configuration
results = (
    results
    .drop(columns=[id_column, 'clean_triples', 'n_samples'])
    .groupby(['prompt_strategy', 'sentences_per_prompt', 'overlap'])
    .mean()
    .reset_index()
)

In [None]:
# Write the current `results` table to paper_triples_results_clean.csv, without the index column
results.to_csv("paper_triples_results_clean.csv", index=False)