In [None]:
import dspy
from dspy import Example
from dspy.teleprompt import BootstrapFewShot
from dspy.predict.retry import Retry
from dspy.datasets import QuoteSum
from dsp.utils import EM, normalize_text
from dsp.templates.utils import passages2text
from dspy.primitives.assertions import assert_transform_module, suggest_backtrack_handler

import re
import os
import nltk
nltk.download('punkt')
from sklearn.metrics import precision_recall_fscore_support
from nltk.tokenize import word_tokenize
from rouge import Rouge
import numpy as np
from collections import Counter

In [None]:
# Single LM shared by every dspy predictor in this notebook.
# NOTE(review): confirm dspy.settings actually honours `temperature`; it may need to be
# passed to dspy.OpenAI instead to take effect.
turbo = dspy.OpenAI(max_tokens=1000, model='gpt-3.5-turbo')
dspy.settings.configure(temperature=0.7, trace=[], lm=turbo)

In [None]:
import openai
# Only override the legacy (<1.0) openai client settings when the environment provides
# them: assigning None (what os.getenv returns for an unset variable) would replace the
# library's own default base URL / key lookup instead of leaving it alone.
OPENAI_API_BASE = os.getenv('OPENAI_API_BASE')
if OPENAI_API_BASE:
    openai.api_base = OPENAI_API_BASE
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

In [None]:
# QuoteSum splits: 300 training and 300 dev examples with fixed seeds; no test split.
dataset = QuoteSum(
    path='../../QuoteSum/v1/',
    train_seed=1,
    train_size=300,
    eval_seed=2023,
    dev_size=300,
    test_size=0,
)
# Mark `question` and `entries` as the program inputs (matching SEMQA.forward's arguments).
# NOTE(review): trainset is not used by the cells shown here (BootstrapFewShot is imported
# but never called) — presumably reserved for a compilation experiment.
trainset = [example.with_inputs('question', 'entries') for example in dataset.train]
devset = [example.with_inputs('question', 'entries') for example in dataset.dev]

In [None]:
#SEMQA formatting helper functions
def build_context_string(data_entry):
    """Render the numbered title/source pairs of one entry as passage text.

    Reads ``title1``/``source1``, ``title2``/``source2``, ... from ``data_entry`` and
    stops at the first index whose title or source is missing or empty. Each pair is
    rendered as ``"<title>: <source>"`` and the resulting list is formatted with
    ``passages2text``.
    """
    def numbered_pairs():
        # Yield (title, source) for index 1, 2, ... until a pair is missing or empty.
        index = 1
        while True:
            title_key, source_key = f"title{index}", f"source{index}"
            if not (title_key in data_entry and source_key in data_entry
                    and data_entry[title_key] and data_entry[source_key]):
                return
            yield data_entry[title_key], data_entry[source_key]
            index += 1

    return passages2text([f"{title}: {source}" for title, source in numbered_pairs()])

def process_into_passages(text):
    """Split passage text into a list of ``"[k] <content>"`` strings, numbered from 1.

    Expects the layout produced by ``build_context_string``: one ``[k] «title: source»``
    entry per line (or a single bare ``«title: source»``). The ``title:`` prefix and the
    guillemets are stripped when present; otherwise the chunk is kept verbatim.

    Args:
        text: passage text, e.g. ``"[1] «T1: a»\\n[2] «T2: b»"``.

    Returns:
        List of ``"[k] <content>"`` strings; empty list for blank input.
    """
    # Split only on "[k]" markers at the start of a line, so bracketed numbers inside a
    # source's own text (e.g. footnote-style "[2]") do not create spurious passages.
    chunks = re.split(r'^\[\d+\]', text, flags=re.MULTILINE)
    passages = [chunk.strip() for chunk in chunks if chunk.strip()]
    modified_passages = []
    for i, passage in enumerate(passages, start=1):
        # DOTALL so sources spanning several lines still have their title stripped.
        match = re.search(r'«.*?:\s*(.*)»', passage, flags=re.DOTALL)
        content = match.group(1) if match else passage
        modified_passages.append(f"[{i}] {content}")
    return modified_passages

def format_quoted_bullets(quoted_bullets_dict):
    """Flatten per-source bullet lists into one ``[ <source> <quote> ]`` string.

    Args:
        quoted_bullets_dict: mapping of source number -> list of bullet lines
            (raw LM output lines, possibly prefixed with ``-``).

    Returns:
        Space-joined ``"[ <key> <bullet> ]"`` items, in dict iteration order.
        Lines that are empty once the bullet marker is removed are skipped, so no
        empty ``[ k  ]`` citations are produced.
    """
    formatted_bullets = []
    for key, bullets in quoted_bullets_dict.items():
        for bullet in bullets:
            # Remove a single leading "-" bullet marker only; stripping every leading
            # '-' would also drop a quoted negative sign (e.g. "- -5 degrees").
            text = re.sub(r'^\s*-\s*', '', bullet, count=1).strip()
            if not text:
                continue
            formatted_bullets.append(f"[ {key} {text} ]")
    return ' '.join(formatted_bullets)

def count_unique_sources(data):
    """Return how many consecutive ``titleN``/``sourceN`` pairs, from N=1, are present.

    Counting stops at the first N whose title or source key is missing or falsy.
    """
    def has_pair(n):
        # True when both keys for index n exist and hold truthy values.
        title_key, source_key = f"title{n}", f"source{n}"
        return bool(title_key in data and source_key in data
                    and data[title_key] and data[source_key])

    total = 0
    while has_pair(total + 1):
        total += 1
    return total

def contains_marked_tokens(answer):
    """Return True if ``answer`` has at least one ``[ <number> <text> ]`` citation.

    The number must be followed by whitespace and then non-``]`` text.
    """
    return re.search(r'\[\s*\d+\s+[^\]]+\]', answer) is not None

def tokenize(text):
    """Lower-case ``text`` and split it into tokens with NLTK's ``word_tokenize``."""
    lowered = text.lower()
    return word_tokenize(lowered)

def extract_marked_tokens(answer):
    """Replace every ``[ <number> <text> ]`` citation in ``answer`` with its bare text.

    Text outside citations is kept as-is, so the result is the answer with the
    citation brackets and source numbers removed.
    """
    return re.sub(r'\[\s*\d+\s*([^\]]+)\]', lambda found: found.group(1), answer)

def extract_marked_tokens_for_source(answer, source_number):
    """Return the text of every citation of ``source_number`` in ``answer``, space-joined.

    Args:
        answer: text containing ``[ <number> <text> ]`` citations.
        source_number: the source index to collect (int or str).

    Returns:
        The captured citation texts joined by a single space; ``""`` when none match.
    """
    # (?!\d) stops source 1 from also matching "[ 12 ... ]" (and 10-19) with the text
    # "2 ...". re.escape guards against a non-numeric source label.
    pattern = r'\[\s*{}(?!\d)\s*([^\]]+)\]'.format(re.escape(str(source_number)))
    return ' '.join(re.findall(pattern, answer))

In [None]:
#Metrics helper functions
def token_f1_score(ref_tokens, gen_tokens):
    """Multiset token overlap between reference and generated tokens.

    Returns:
        ``(precision, recall, f1)``. All three are the int ``0`` when no token overlaps,
        which also covers empty inputs.
    """
    ref_counts, gen_counts = Counter(ref_tokens), Counter(gen_tokens)
    # Intersection keeps the minimum count per token (clipped matches).
    overlap = sum((ref_counts & gen_counts).values())
    if not overlap:
        return 0, 0, 0
    precision = overlap / sum(gen_counts.values())
    recall = overlap / sum(ref_counts.values())
    return precision, recall, (2 * precision * recall) / (precision + recall)


def token_recall(ref_tokens, gen_tokens):
    """Fraction of reference tokens, counted with multiplicity, found in ``gen_tokens``.

    Returns the int ``0`` when ``ref_tokens`` is empty.
    """
    ref_counts = Counter(ref_tokens)
    gen_counts = Counter(gen_tokens)
    if not ref_counts:
        return 0
    matched = sum((ref_counts & gen_counts).values())
    return matched / sum(ref_counts.values())

In [None]:
#Metrics functions
def fluency_metric(references, model_summary):
    """Fluency: best ROUGE-L F-measure of the generated summary against any reference.

    Citation markers are removed from both sides with ``extract_marked_tokens``, and the
    text is lower-cased and tokenized with ``tokenize`` before scoring. Only the wording
    is compared, not the citations.

    Args:
        references: list of reference dicts, each with a ``'summary'`` string.
        model_summary: prediction exposing ``quoted_summary``.

    Returns:
        Maximum ROUGE-L ``f`` over the references. Returns 0 when the generated summary
        has no text, or when no reference can be scored.
    """
    # The generated side does not depend on the reference; prepare it once.
    gen_summary_str = ' '.join(tokenize(extract_marked_tokens(model_summary.quoted_summary)))
    if not gen_summary_str:
        # The rouge package raises ValueError on an empty hypothesis; an empty answer scores 0.
        return 0
    rouge = Rouge()
    max_f_measure = 0
    for reference in references:
        ref_summary_str = ' '.join(tokenize(extract_marked_tokens(reference['summary'])))
        if not ref_summary_str:
            continue
        try:
            scores = rouge.get_scores(gen_summary_str, ref_summary_str)
        except ValueError:
            # rouge also rejects inputs with no sentence content (e.g. punctuation only);
            # treat such a pair as contributing no match.
            continue
        max_f_measure = max(max_f_measure, scores[0]['rouge-l']['f'])
    return max_f_measure

def preciseness(references, model_summary):
    """Preciseness: per-source token F1 of cited text, averaged over all sources.

    For each source k = 1..N (N from ``count_unique_sources(references[0])``), the text
    cited as ``[ k ... ]`` in the generated summary is compared with the text cited as
    ``[ k ... ]`` in each reference summary. The best token F1 over references is kept
    for that source, and those per-source scores are averaged.

    Args:
        references: list of reference dicts, each with a ``'summary'`` string; the first
            one also carries the ``titleN``/``sourceN`` fields.
        model_summary: prediction exposing ``quoted_summary``.

    Returns:
        Mean of the per-source maxima, or 0.0 when there are no sources.
    """
    gen_summary = model_summary.quoted_summary
    num_sources = count_unique_sources(references[0])
    source_f1_scores = []
    # Sources are numbered 1..num_sources inclusive (as in comprehensiveness); the
    # previous range(1, num_sources) silently dropped the last source.
    for source_number in range(1, num_sources + 1):
        extracted_gen_tokens = tokenize(extract_marked_tokens_for_source(gen_summary, source_number))
        max_f1_for_source = 0
        for reference in references:
            extracted_ref_tokens = tokenize(extract_marked_tokens_for_source(reference['summary'], source_number))
            _, _, token_f1 = token_f1_score(extracted_ref_tokens, extracted_gen_tokens)
            max_f1_for_source = max(max_f1_for_source, token_f1)
        source_f1_scores.append(max_f1_for_source)
    if not source_f1_scores:
        # np.mean([]) would return nan (with a RuntimeWarning) and poison the averages.
        return 0.0
    return np.mean(source_f1_scores)

def comprehensiveness(references, model_summary):
    """Comprehensiveness: per-source recall of reference short answers, averaged.

    For each source k = 1..N (N from ``count_unique_sources(references[0])``), measures
    the token recall of each reference's ``short_ans_k`` against the text the generated
    summary cites as ``[ k ... ]``. The best recall over references is kept for that
    source (0 if no reference has ``short_ans_k``), and those scores are averaged.

    Args:
        references: list of reference dicts; the first also carries ``titleN``/``sourceN``.
        model_summary: prediction exposing ``quoted_summary``.

    Returns:
        Mean of the per-source maxima, or 0.0 when there are no sources.
    """
    gen_summary = model_summary.quoted_summary
    num_sources = count_unique_sources(references[0])
    if num_sources == 0:
        # np.mean([]) would return nan (with a RuntimeWarning) and poison the averages.
        return 0.0
    source_recall_scores = []
    for source_number in range(1, num_sources + 1):
        gen_tokens_for_source = tokenize(extract_marked_tokens_for_source(gen_summary, source_number))
        short_answer_key = f"short_ans_{source_number}"
        max_recall_for_source = 0
        for reference in references:
            if short_answer_key in reference:
                ref_short_answer_tokens = tokenize(reference[short_answer_key])
                recall = token_recall(ref_short_answer_tokens, gen_tokens_for_source)
                max_recall_for_source = max(max_recall_for_source, recall)
        source_recall_scores.append(max_recall_for_source)
    return np.mean(source_recall_scores)

In [None]:
class GenerateQuotedText(dspy.Signature):
    """Extracts bullets of quoted text from passages to answer question."""
    # NOTE: for a dspy.Signature the docstring above is used as the LM's task
    # instruction, so editing it changes model behaviour, not just documentation.

    # A single "[k] <content>" string produced by process_into_passages (one call per passage).
    passage = dspy.InputField(desc="structured passage text for analysis")
    question = dspy.InputField()
    # Free-text bullet list; SEMQA.forward splits it on '\n' into one bullet per line.
    extracted_quotes = dspy.OutputField(desc="bullets of relevant quoted text")


class FormatQuotedSummary(dspy.Signature):
    """Follow the exact formatting of the Quoted Bullets when using them while creating a cohesive summary paragraph to answer question. 
    Here is an example:
    Quoted Bullets: [ 1 The sky is blue ]. [ 2 The grass is green ].
    Question: What are some basic observations about the natural colors in our environment?
    Quoted Summary: One source states the [ 1 The sky is blue ] . Another source states [ 2 The grass is green ] .
    """
    # NOTE: this docstring (including the worked example) is the LM instruction for
    # this dspy.Signature; it is prompt text, so edit it only deliberately.

    # Shorter variant of the instruction without the worked example — presumably kept
    # for comparing the two prompts; it is inactive while commented out.
    # """Follow the exact formatting of the Quoted Bullets when using them while creating a cohesive summary paragraph to answer question. """

    # Output of format_quoted_bullets: space-joined "[ <source_number> <quote> ]" items.
    quoted_bullets = dspy.InputField(desc="quoted bullets by source number")
    question = dspy.InputField()
    # The metrics parse citations out of this text with "[ <number> <text> ]" regexes.
    quoted_summary = dspy.OutputField(desc="quoted summary paragraph that is formatted")

In [None]:
class SEMQA(dspy.Module):
    """Two-stage quoted-summary pipeline.

    Stage 1 (``generate_quoted_text``) extracts quoted bullets from each source passage
    separately. Stage 2 (``generate_formatted_quoted_summary``) merges them into one
    summary that cites sources as ``[ <source_number> <quote> ]``.
    """

    def __init__(self):
        super().__init__()
        self.generate_quoted_text = dspy.ChainOfThought(GenerateQuotedText)
        self.generate_formatted_quoted_summary = dspy.ChainOfThought(FormatQuotedSummary)

    def forward(self, question, entries):
        """Answer ``question`` from the sources stored in ``entries[0]``.

        Returns ``(prediction, entries)``. The prediction carries ``quoted_summary``, and
        ``entries`` is handed back unchanged so the caller can score against it.
        """
        references = entries
        passages = process_into_passages(build_context_string(references[0]))
        # One LM call per passage, in order; keys start at 1 to match citation numbers.
        quoted_bullets_dict = {
            source_number: self.generate_quoted_text(passage=passage, question=question).extracted_quotes.split('\n')
            for source_number, passage in enumerate(passages, start=1)
        }
        summary = self.generate_formatted_quoted_summary(
            quoted_bullets=format_quoted_bullets(quoted_bullets_dict),
            question=question,
        )
        return dspy.Prediction(quoted_summary=summary.quoted_summary), references

In [None]:
class SEMQA_Assertions(SEMQA):
    """SEMQA pipeline plus a DSPy ``Suggest`` on the summary's citation format.

    Inherits SEMQA's predictors (``generate_quoted_text``,
    ``generate_formatted_quoted_summary``) and its ``forward`` pipeline instead of
    duplicating them, and only adds the assertion. Like SEMQA, ``forward`` returns
    ``(prediction, references)``.
    """

    def forward(self, question, entries):
        pred, references = super().forward(question, entries)
        # Soft constraint: when no "[ <number> <text> ]" citation is present, a suggest
        # handler (e.g. suggest_backtrack_handler with Retry-wrapped predictors) can
        # re-run the FormatQuotedSummary predictor with this message as feedback.
        dspy.Suggest(
            contains_marked_tokens(pred.quoted_summary),
            "Make sure the citation formatting is exactly in this format: '[ source_number text ]'.",
            target_module=FormatQuotedSummary,
        )
        # TODO: a second Suggest checking that every quoted segment appears verbatim in
        # its source passage was prototyped here. Reinstate it once it iterates over all
        # sources (1..num_sources inclusive) and splits segments robustly.
        return pred, references

In [None]:
def evaluate(module, examples=None):
    """Run ``module`` on each example and print per-example and average metrics.

    Args:
        module: callable invoked as ``module(question=..., entries=...)`` that returns
            ``(prediction, references)``, like SEMQA.forward.
        examples: examples exposing ``.question`` and ``.entries``. Defaults to the
            notebook-level ``devset``, so existing ``evaluate(module)`` calls behave as before.

    Returns:
        dict with the average ``fluency``, ``preciseness``, ``comprehensiveness`` and
        ``semqa`` scores, or an empty dict when there are no examples.
    """
    if examples is None:
        examples = devset
    num_examples = len(examples)
    if num_examples == 0:
        # Nothing to average; avoids a ZeroDivisionError below.
        print("No examples to evaluate.")
        return {}

    total_fluency_score = 0
    total_preciseness_score = 0
    total_comprehensiveness_score = 0
    total_semqa_score = 0
    for i, example in enumerate(examples):
        output, references = module(question=example.question, entries=example.entries)
        fluency_score = fluency_metric(references, output)
        preciseness_score = preciseness(references, output)
        comprehensiveness_score = comprehensiveness(references, output)
        # SEM-QA is the geometric mean of preciseness and fluency; comprehensiveness
        # is reported separately and does not enter it.
        semqa_score = (preciseness_score * fluency_score) ** 0.5
        total_fluency_score += fluency_score
        total_preciseness_score += preciseness_score
        total_comprehensiveness_score += comprehensiveness_score
        total_semqa_score += semqa_score
        print(f"Example {i}: Fluency: {fluency_score}, Preciseness: {preciseness_score}, "
            f"Comprehensiveness: {comprehensiveness_score}, SEM-QA: {semqa_score}")

    averages = {
        'fluency': total_fluency_score / num_examples,
        'preciseness': total_preciseness_score / num_examples,
        'comprehensiveness': total_comprehensiveness_score / num_examples,
        'semqa': total_semqa_score / num_examples,
    }
    print(f"Average Fluency: {averages['fluency']}")
    print(f"Average Preciseness: {averages['preciseness']}")
    print(f"Average Comprehensiveness: {averages['comprehensiveness']}")
    print(f"Average SEM-QA: {averages['semqa']}")
    return averages

In [None]:
# Baseline: uncompiled program, no assertions.
semqa = SEMQA()
evaluate(module=semqa)

In [None]:
# Same pipeline with the citation-format Suggest enabled: predictors are wrapped in Retry
# (map_named_predictors returns the module itself) and failed suggestions backtrack.
asserted_SEMQA = assert_transform_module(
    SEMQA_Assertions().map_named_predictors(Retry),
    suggest_backtrack_handler,
)
evaluate(module=asserted_SEMQA)