# Imports

In [1]:
import ast
import json
import os
import re
import statistics
import time
from multiprocessing import Pool

import openai
import pandas as pd
# Read the OpenAI key from the environment rather than hardcoding it in the notebook:
# hardcoded keys leak through version control and shared notebooks, and assigning a
# placeholder here would overwrite a key already provided via OPENAI_API_KEY.
# Falls back to the original placeholder when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR API KEY")

In [2]:
import os
import sys
# Make the repository root importable so the local `utils` package resolves.
# '../' is relative to the kernel's current working directory (normally the
# notebook's folder). It is inserted at index 1, i.e. searched after sys.path[0].
sys.path.insert(1, '../')
# NOTE(review): star imports hide where helpers come from. Names used later that
# presumably come from these modules: get_kp_from_summary,
# prompted_claim_split_generation, extract_claims, match_claim_with_cluster.
# On name clashes, a later star import shadows an earlier one.
from utils.prompting import *
from utils.postprocessing import *
from utils.g_eval import *

# Read predicted summaries

In [3]:
# Output directory of the model run being evaluated. The predictions file
# (test-result.jsonl) and the post-processing cache are read from under it.
qqsum_output_path = "../output/atlas-xl-seed2-lgret-lglm"

In [4]:
# Ground truth: drop columns not needed for this evaluation and record the length
# of `retrieved_relevant_sent` (presumably a list of sentences — verify).
# NOTE: read_pickle can execute arbitrary code; only load trusted local files.
ground_truth_df = (
    pd.read_pickle("../data/test/test.pkl")
    .drop(columns=['reviews', 'reviewText', 'title', 'fact_check_article'])
    .assign(retrieved_relevant_sent_len=lambda frame: frame['retrieved_relevant_sent'].str.len())
)

# Predictions: replace each passage dict with its 'text', count the passages, and
# keep only the columns used downstream (in this order).
predicted_df = (
    pd.read_json(qqsum_output_path + "/test-result.jsonl", lines=True)
    .assign(
        passages=lambda frame: frame['passages'].apply(
            lambda passage_list: [passage['text'] for passage in passage_list]
        ),
        passages_len=lambda frame: frame['passages'].str.len(),
    )
    .loc[:, ['query', 'generation', 'passages', 'passages_scores', 'passages_len', 'comment_clusters', 'id']]
)

# Inner join on (query, id): only samples present in both sources are kept.
df = ground_truth_df.merge(predicted_df, on=['query', 'id'])

In [5]:
# NOTE(review): disabled exploration, kept for reference only. If enabled, it would:
# (1) keep rows with at least one retrieved passage;
# (2) take, per row, the last of the first len(retrieved_relevant_sent) passage scores;
# (3) average that value over rows.
# df = df[df['passages'].str.len() > 0]
# df['passages_scores_min_thres_ref_ground'] = df.apply(lambda row: row['passages_scores'][:len(row['retrieved_relevant_sent'])][-1], axis=1)
# df['passages_scores_min_thres_ref_ground'].mean()

In [6]:
# Turn each raw generation into the summary text that is post-processed downstream.
# The patterns are raw strings. The former plain literals relied on invalid escapes
# ("\[", "\/", "\+"), which emit DeprecationWarning/SyntaxWarning on newer Pythons.
# Regex semantics are unchanged.
INST_RESPONSE_RE = re.compile(r"\[/INST\] *((.+\n*)+)$")
CONCLUSION_RE = re.compile(r"(Therefore|Thus)(.+\n*)+$")
WHILE_BLOCK_RE = re.compile(r"(While[^\n]+\n+(\+ *[0-9]+[^\n]+\n*)+)")


def extract_inst_response(generation: str) -> str:
    """Return the text after the first '[/INST]' marker that is followed by content.

    Spaces right after the marker are skipped and '</s>' tokens are removed.
    Returns "" when no such marker exists; indexing the findall result used to
    raise IndexError here. Such rows end up with an empty `final_summary`.
    """
    match = INST_RESPONSE_RE.search(generation)
    return "" if match is None else match.group(1).replace("</s>", "")


def strip_conclusion(summary: str) -> str:
    """Drop everything from the first 'Therefore'/'Thus' that has more text on its line."""
    return CONCLUSION_RE.sub("", summary)


def extract_while_blocks(summary: str) -> list:
    """Return every 'While ...' line together with the '+ <number> ...' lines that follow it."""
    return [match.group(1) for match in WHILE_BLOCK_RE.finditer(summary)]


df['summary'] = df['generation'].apply(extract_inst_response).apply(strip_conclusion)
df['final_summary'] = df['summary'].apply(extract_while_blocks)
# Only the first "While ..." block is used as the summary text.
df['final_summary_text'] = df['final_summary'].apply(lambda blocks: "\n\n".join(blocks[:1]))

## Post-process summaries into key points (KPs)

### Trial on a single sample

In [7]:
# Smoke test: extract key points from the first sample's summary before the full run.
row = df.iloc[0, :]
print(row.at['id'])
get_kp_from_summary(row.at['final_summary_text'])

3609


'```json\n[{"key_point": "The old key components need to be taken to a locksmith or dealership to be reprogrammed and cut to fit the new key head.", "prevalence": "7"}]\n```'

### Full run

In [8]:
# Re-bind the names exported by utils.postprocessing, so they take precedence again
# over any same-named names from utils.g_eval imported earlier.
# NOTE(review): this does NOT reload the module. Python reuses the cached module in
# sys.modules, so edits to utils/postprocessing.py are not picked up. Use
# importlib.reload(...) or `%autoreload 2` if that was the intent.
from utils.postprocessing import *

In [9]:
# Treat the whole test set as one partition, processed by one worker process.
num_workers = 1
df['my_category'] = 1

In [10]:
# One (cache_dir, domain, domain_df) argument tuple per distinct category, in order
# of first appearance. Each partition gets a fresh 0..n-1 index.
# An earlier variant used qqsum_output_path + "/post_processed_cache" (no "rd" sub-folder).
inputs = [
    (
        qqsum_output_path + "/post_processed_cache/rd",
        domain,
        df.loc[df['my_category'].eq(domain)].reset_index(drop=True),
    )
    for domain in df['my_category'].unique()
]

In [11]:
# Run the claim-split post-processing for each partition in a worker pool and
# report the elapsed time. perf_counter is the monotonic clock meant for measuring
# durations; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
with Pool(num_workers) as processor:
    # One call per argument tuple in `inputs`; starmap returns results in input order.
    data = processor.starmap(prompted_claim_split_generation, inputs)
print("TIME ELAPSED", time.perf_counter() - start_time)

1 :  Loaded saved file. Done
TIME ELAPSED 0.1891341209411621


In [12]:
# Stack the per-partition results with unique row labels. Each partition frame is
# built with reset_index(drop=True). If the worker keeps that index, a plain concat
# of several partitions repeats labels 0..n-1. That breaks the label-aligned
# `.loc[mask, ...] = ...` assignment in the next cell. With a single partition the
# labels are the same as before.
processed_df = pd.concat(data, ignore_index=True)
# Keep only samples with at least one comment cluster and at least one extracted
# "While ..." summary block.
processed_df = processed_df[processed_df['comment_clusters'].str.len() > 0]
processed_df = processed_df[processed_df['final_summary'].str.len() > 0]

In [13]:
# Normalise the raw LLM claim-split output before parsing it with extract_claims.
# Step 1: remove every "json" substring and every backtick (strips ```json fences).
# Then delete each run of newlines together with the spaces that follow it.
# NOTE(review): .replace("json", "") also removes "json" occurring inside key-point text.
processed_df['claim_split_predicted'] = processed_df['claim_split_predicted'].apply(lambda x: re.sub(r"\n+ *", "", x.replace("json", "").replace("`", "")))
# Step 2: flag rows where a value after ':' is wrapped in single quotes and itself
# contains apostrophes (e.g. 'the item's size').
# NOTE(review): this detection pattern differs from the substitution pattern below:
#   - [^':,]* here vs [^':]* below;
#   - trailing ( *) here vs ( *,) below.
# So the mask also restricts which rows step 3 may rewrite. The nested quantifiers
# can backtrack heavily on long values with many apostrophes.
mask = processed_df['claim_split_predicted'].apply(lambda x: len(re.findall(r"(: *)\'((?:[^':]*\'+[^':,]*)+)\'( *)", x, re.DOTALL)) > 0)
# Step 3: on flagged rows, re-wrap such single-quoted values in triple double quotes
# so their inner apostrophes no longer end the string. Presumably this is so a
# Python-literal parser can read them — verify against extract_claims. Only values
# followed by a comma are rewritten.
processed_df.loc[mask, 'claim_split_predicted'] = processed_df.loc[mask, 'claim_split_predicted'].apply(
    lambda x: re.sub(r"(: *)\'((?:[^':]*\'+[^':]*)+)\'( *,)", r'\1"""\2"""\3', x))
# Step 4: parse the cleaned strings. extract_claims comes from the star-imported
# utils modules and presumably returns a list of claim records — verify.
processed_df['claim_split_predicted'] = processed_df['claim_split_predicted'].apply(extract_claims)

In [14]:
# Renumber rows 0..n-1, discarding the old labels (which have gaps after the filtering above).
processed_df = processed_df.reset_index(drop=True)

In [15]:
# Row-wise: match each sample's extracted claims against its comment clusters.
# match_claim_with_cluster comes from the star-imported utils modules. Presumably it
# returns the row with a `matching_comment_clusters` list added, which json_normalize
# flattens in the next cell — verify.
processed_df = processed_df.apply(match_claim_with_cluster, axis=1)

In [16]:
# Flatten to one row per matched key point. Each element of a sample's
# `matching_comment_clusters` list becomes a row, with the listed sample-level
# fields copied alongside.
kp_matching_df = pd.json_normalize(
    processed_df.to_dict(orient='records'),
    record_path='matching_comment_clusters',
    meta=['asin', 'id', 'query', 'passages', 'passages_len', 'final_summary_text'],
)

In [17]:
# Keep only key points whose reported prevalence is at least MIN_PREVALENCE.
# The raw LLM JSON encodes prevalence as a string (e.g. "7"), so values are coerced
# to numbers for the comparison. Unparsable values become NaN and are dropped;
# numeric columns are unaffected. The column itself is left unchanged.
# .copy() makes later column assignments operate on an independent frame.
MIN_PREVALENCE = 3
kp_matching_df = kp_matching_df[
    pd.to_numeric(kp_matching_df['prevalence'], errors='coerce') >= MIN_PREVALENCE
].copy()

# Factual Alignment

In [18]:
# NOTE(review): plain alias, not a copy. The columns assigned to evaluation_df below
# ('context', 'align_score') are also added to kp_matching_df. Use .copy() if the two
# must stay independent.
evaluation_df = kp_matching_df

## AlignScore

In [19]:
from alignscore import AlignScore

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [20]:
# AlignScore-base scorer (roberta-base backbone) in 'nli_sp' evaluation mode.
# Uses the first CUDA device when one is available, as before. Otherwise it falls
# back to CPU so the notebook still runs (slowly) on machines without a GPU.
import torch

align_scorer = AlignScore(
    model='roberta-base',
    batch_size=8,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    ckpt_path='./AlignScore/checkpoints/AlignScore-base.ckpt',
    evaluation_mode='nli_sp',
)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return torch.load(f, map_location=map_location)  # type: ignore[arg-type]
Lightning automatically upgraded your loaded checkpoint from v1.7.7 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file AlignScore/checkpoints/AlignScore-base.ckpt`
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(


In [21]:
# Context for each key point = its matched comments joined with single spaces.
evaluation_df['context'] = evaluation_df['comments'].map(" ".join)

In [22]:
# Score each key point (claim) against its joined comments (context), row by row.
results = align_scorer.score(
    claims=evaluation_df['key_point'].to_list(),
    contexts=evaluation_df['context'].to_list(),
)

Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:02<00:00, 68.45it/s]


In [23]:
# Attach one AlignScore per key point. Assumes `results` has the same length and row
# order as evaluation_df (it was built from the column lists above). If it is a list,
# it is assigned positionally, not by index label.
evaluation_df['align_score'] = results

In [24]:
# Per-query precision is the mean AlignScore of that query's key points; the final
# figure is the macro average over queries.
# Aggregating the selected column avoids the pandas warning raised when
# DataFrameGroupBy.apply operates on the grouping columns (see the output above).
# reset_index(name=...) names the column directly, so this also works on an empty frame.
eval_results = (
    evaluation_df.groupby('id')['align_score']
    .mean()
    .reset_index(name='precision')
)
eval_results['precision'].mean()

  eval_results = evaluation_df.groupby(['id']).apply(lambda grp: grp['align_score'].mean()).reset_index()


0.7487860343402906