In [None]:
import time, datetime, json, os
from tqdm.notebook import tqdm
from collections import defaultdict, Counter
import numpy as np
from nltk import sent_tokenize
import matplotlib.pyplot as plt
import torch
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
from transformers import AutoTokenizer

from index_files import LongDoc, write_json, QualityDataset, NarrativeQADataset, ReadingAgent, read_json, read_jsonline, LLM, Retriever, RetrieverOutput, GritLM, ChunkInfo

In [None]:
# Shared model handles used by the cells below.
# The GritLM section calls `gritlm`; uncomment the next line (or load it there) before running it.
# gritlm = GritLM("GritLM/GritLM-7B", device_map="auto", torch_dtype="auto")
# Dense retriever: embeds entity names in match_entities and drives the Contriever analysis.
retriever = Retriever()
llm = LLM()
# Alternative `llm` settings kept for quick switching:
# llm = 'mistralai/Mistral-7B-Instruct-v0.2'
# llm = None

# Experiment

In [None]:
# NOTE(review): the next cell rebinds `dataset` to QuALITY; run only the dataset cell you need.
dataset = NarrativeQADataset(llm)

In [None]:
# QuALITY dev split; replaces the NarrativeQA dataset when run after the previous cell.
dataset = QualityDataset(llm, split='dev')

In [None]:
# ReadingAgent over the currently selected `dataset`; not used by the indexing/DPR cells below.
reading_agent = ReadingAgent(dataset, llm)

## Index passages

In [None]:
from typing import List, Tuple, Set, Dict
import itertools
import networkx as nx
from prompt import *

def match_entities(target_ents:List[str], refer_ents:List[str], threshold:float=0.8) -> Dict[str, str]:
    """Map each target entity name to its most similar reference entity name.

    Both lists are embedded with the global `retriever`. Each target is compared with
    every reference by dot product, and the best-scoring reference is kept only if its
    score is strictly greater than `threshold`.

    Args:
        target_ents: entity names to map.
        refer_ents: candidate entity names to map onto.
        threshold: minimum similarity (exclusive) for a match; defaults to the
            previously hard-coded 0.8.

    Returns:
        Dict from target name to matched reference name. Targets with no match above
        the threshold are omitted. The result is empty if either list is empty.
    """
    ent_map:Dict[str, str] = {}
    if not target_ents or not refer_ents:
        # Nothing to compare. This also avoids argmax over an empty similarity row.
        return ent_map
    # Second argument presumably requests normalized embeddings (cf. `normalize=True`
    # in the Contriever cell), making the dot product a cosine similarity — confirm.
    target_ents_emb = retriever.embed_paragraphs(target_ents, True)
    refer_ents_emb = retriever.embed_paragraphs(refer_ents, True)
    sim_mat:np.ndarray = np.matmul(target_ents_emb, refer_ents_emb.T)
    for eid, ent in enumerate(target_ents):
        max_idx = sim_mat[eid].argmax()
        if sim_mat[eid, max_idx] > threshold:
            ent_map[ent] = refer_ents[max_idx]
    return ent_map
    
def parse_entities(responses:List[str]):
    """Merge entity lists from several sampled LLM responses into a consensus list.

    Each response is parsed as a numbered list ("1. X", "2. Y", ...). Only items whose
    numbering continues consecutively from 1 are read. Leading indentation is ignored,
    trailing periods are stripped, and empty names are skipped.

    Names from different responses are linked when `match_entities` maps one to the
    other. Each connected group is represented by its most frequent surface form and
    scored by the total occurrence count of its members.

    Args:
        responses: raw text of each sampled response.

    Returns:
        Representative names of the groups whose total count is at least
        ``len(responses) // 2 + 1`` (a strict majority of the response count).
    """
    ent_lists:List[List[str]] = []
    ent_cnt = Counter()
    ent_cnt_threshold = len(responses) // 2 + 1
    for response in responses:
        i = 1
        temp_ents:List[str] = []
        for line in response.splitlines():
            line = line.lstrip()  # tolerate indented list items
            prefix = f'{i}. '
            if line.startswith(prefix):
                ent = line[len(prefix):].strip().strip('.')
                if ent:
                    temp_ents.append(ent)
                i += 1
        ent_lists.append(temp_ents)
        ent_cnt.update(temp_ents)
    g = nx.Graph()
    # Register every entity as a node. Otherwise entities never matched across
    # responses (e.g. all of them when there is only one response) would vanish.
    g.add_nodes_from(ent_cnt)
    for list1, list2 in itertools.combinations(ent_lists, 2):
        g.add_edges_from(match_entities(list1, list2).items())
    rep_cnt:Dict[str, int] = {}
    for ent_cluster in nx.connected_components(g):
        # Most frequent surface form represents the group (first one wins on ties).
        rep = max(ent_cluster, key=lambda ent: ent_cnt[ent])
        rep_cnt[rep] = sum(ent_cnt[ent] for ent in ent_cluster)
    return [rep for rep, cnt in rep_cnt.items() if cnt >= ent_cnt_threshold]
    
def parse_ent_description(response:str, important_ents:List[str]) -> Dict[str, str]:
    """Extract "entity: description" lines for the requested entities.

    Args:
        response: LLM output, one "Entity: description" pair per line.
        important_ents: entity names to keep. Names are compared after stripping
            surrounding whitespace.

    Returns:
        Dict from entity name to description (whitespace-stripped). Lines without a
        ': ' separator (preambles, blank lines) are ignored. A later line for the
        same entity overrides an earlier one.
    """
    wanted = set(important_ents)
    description_dict:Dict[str, str] = {}
    for line in response.splitlines():
        ent, sep, description = line.partition(': ')
        if not sep:
            # Not an "entity: description" line; unpacking split() used to crash here.
            continue
        ent = ent.strip()
        if ent in wanted:
            description_dict[ent] = description.strip()
    return description_dict

def _recap_recent_summaries(results:List[ChunkInfo], r_num:int) -> str:
    """Render the summaries of the last `r_num` chunks, labelled by offset (-1 = previous chunk)."""
    recent = results[-r_num:] if r_num > 0 else []
    return '\n\n'.join(
        f'Passage {offset - len(recent)}\n{chunk.summary}'
        for offset, chunk in enumerate(recent)
    )

def index_text(paragraphs:List[str], w_note:bool=True, r_num:int=1):
    """Index paragraphs into ChunkInfo records with the global `llm`.

    For each paragraph:
      1. list important entities. With `w_note`, the entity prompt also carries the
         summaries of the previous `r_num` chunks;
      2. request entity descriptions, a summary and relation descriptions.

    Args:
        paragraphs: passage texts, in document order.
        w_note: condition entity listing on summaries of preceding chunks.
        r_num: number of preceding chunks to recap (ignored when <= 0).

    Returns:
        One ChunkInfo(paragraph, summary, entities, entity descriptions, relation
        description) per paragraph, in input order.
    """
    results:List[ChunkInfo] = []
    for paragraph in tqdm(paragraphs):
        # 1. Extract important entities.
        if w_note and results and r_num > 0:
            list_entity_prompt = LongDocPrompt.list_entity_w_note(_recap_recent_summaries(results, r_num), paragraph)
        else:
            list_entity_prompt = LongDocPrompt.list_entity(paragraph)
        # NOTE(review): presumably llm(prompt, n_samples, temperature); [0] selects this prompt's samples.
        entity_responses = llm(list_entity_prompt, 5, 0.7)[0]
        important_ents = parse_entities(entity_responses)

        # 2. Summary, entity descriptions and relation descriptions.
        # TODO: note-conditioned variants of these prompts are not implemented, so every paragraph
        # uses the plain prompts. Previously paragraph 1's prompts were silently reused for all later ones.
        summary_prompt = LongDocPrompt.shorten(paragraph)
        if important_ents:
            important_ents_str = '\n'.join(important_ents)
            # ent_description() takes the first two entities separately; repeat the sole one if only one was found.
            example_ent_0, example_ent_1 = (important_ents * 2)[:2]
            ent_description_prompt = LongDocPrompt.ent_description(paragraph, important_ents_str, example_ent_0, example_ent_1)
            relation_description_prompt = LongDocPrompt.relation_description(paragraph, important_ents_str)
            # Unpack in the same order the prompts are sent (summary and relations used to be swapped).
            ent_description, summary, relation_description = (
                response[0] for response in llm([ent_description_prompt, summary_prompt, relation_description_prompt])
            )
            ent_description_dict = parse_ent_description(ent_description, important_ents)
        else:
            # No entities found: nothing to describe, only summarize.
            # NOTE(review): confirm ChunkInfo accepts '' as an empty relation description.
            summary = llm([summary_prompt])[0][0]
            ent_description_dict, relation_description = {}, ''
        results.append(ChunkInfo(paragraph, summary, important_ents, ent_description_dict, relation_description))
    return results

In [None]:
# Paragraphs of document 1: each stored page (an iterable of strings) is joined with newlines.
paragraphs = ['\n'.join(p) for p in read_json(os.path.join(dataset.data_dir, f'pages_{1}.json'))]

In [None]:
# index_text returns a single list of ChunkInfo records, so it cannot be unpacked into two names.
# `results` uses the summary recap (w_note=True); `test_results` is a no-recap baseline for comparison.
# NOTE(review): confirm this is the intended comparison, and that write_json can serialize ChunkInfo.
results = index_text(paragraphs)
test_results = index_text(paragraphs, w_note=False)
write_json('results.json', results)
write_json('test_results.json', test_results)

In [None]:
# Reload saved indexing outputs so the inspection cells below can run without re-querying the LLM.
results = read_json('results.json')
test_results = read_json('test_results.json')

In [None]:
# Fields of one serialized chunk record.
results[0].keys()

In [None]:
test_pid = 12  # index of the chunk to inspect
# sent_tokenize(results[test_pid]['shorten'])
# results[test_pid]['description_dict']
# FIXME(review): 'prompt_shorten' (like 'shorten' / 'description_dict' above) looks like a key from an
# older result format; index_text builds ChunkInfo(paragraph, summary, ...) — confirm the serialized keys.
print(results[test_pid]['prompt_shorten'])


In [None]:
# Sentence-split summary of the inspected chunk in the comparison run.
# FIXME(review): 'shorten' may be a stale key — check how ChunkInfo is serialized.
sent_tokenize(test_results[test_pid]['shorten'])
# test_results[test_pid]['description_dict']

In [None]:
# Original passage text of the inspected chunk.
print(results[test_pid]['paragraph'])

## DPR

In [None]:
def split_sents(retriever_tokenizer:AutoTokenizer, p_input_ids:np.ndarray, is_contriever:bool):
    """Map the sentences of a tokenized passage back to token spans.

    The passage ids are decoded, split with nltk's sent_tokenize, and each sentence is
    assigned the shortest consecutive token span whose decoded text is at least as long
    as the sentence.

    Args:
        retriever_tokenizer: tokenizer that produced `p_input_ids`.
        p_input_ids: token ids of one passage.
        is_contriever: if True, the first and last ids are treated as special tokens and
            excluded. Otherwise all ids are used.

    Returns:
        List of half-open (start, end) index pairs into `p_input_ids`, one per sentence.
    """
    # NOTE(review): assumes p_input_ids carries no padding — confirm for batch-encoded pages.
    content_start = 1 if is_contriever else 0
    content_end = len(p_input_ids) - 1 if is_contriever else len(p_input_ids)
    sents = sent_tokenize(retriever_tokenizer.decode(p_input_ids[content_start:content_end]))
    sent_spans = []
    sent_start = content_start
    for sent in sents:
        # Initial guess: standalone token count minus two special tokens. Clamped so it
        # never goes negative (a negative slice end would span the whole passage).
        guess = max(len(retriever_tokenizer.encode(sent)) - 2, 0)
        sent_end = min(sent_start + guess, content_end)
        # Grow until the decoded span covers the sentence. Stop at the passage end so that
        # decode/sent_tokenize whitespace differences cannot cause an infinite loop.
        while sent_end < content_end and len(retriever_tokenizer.decode(p_input_ids[sent_start:sent_end]).strip()) < len(sent):
            sent_end += 1
        sent_spans.append((sent_start, sent_end))
        sent_start = sent_end
    return sent_spans

def important_page_tokens(retriever_tokenizer:AutoTokenizer, question:str, pages, q_lhs:np.ndarray, q_input_ids:np.ndarray, q_emb, p_lhs, p_input_ids, pids, scores, score_threshold:float=2.0):
    """Print retrieved passages and their tokens that score highest against the question embedding.

    Args:
        retriever_tokenizer: tokenizer used to decode token ids.
        question: question text, printed as a header.
        pages: passage texts indexed by passage id.
        q_lhs: question token hidden states; row i pairs with q_input_ids[i].
        q_input_ids: question token ids.
        q_emb: question embedding, dotted with each passage token's hidden state.
        p_lhs, p_input_ids: per-passage token hidden states and token ids, indexed by passage id.
        pids, scores: retrieved passage ids and their retrieval scores, in rank order.
        score_threshold: passage tokens scoring above this are reported. Defaults to the
            previously hard-coded 2.
    """
    print(question)
    # Norm of each question token's hidden state.
    for i in range(q_lhs.shape[0]):
        print(retriever_tokenizer.decode(q_input_ids[i]), np.linalg.norm(q_lhs[i]))
    print('\n')
    for rank, (pid, score) in enumerate(zip(pids, scores)):
        print(f'Rank {rank}\nPassage {pid}:\n{score}\n')
        print(pages[pid])
        token_scores = p_lhs[pid].dot(q_emb)
        # Tokens above the threshold, highest score first.
        n_high = int((token_scores > score_threshold).sum())
        max_indices = np.argsort(token_scores)[::-1][:n_high].tolist()
        print('\n\nHigh scored spans:\n')
        page_ids = p_input_ids[pid]
        for idx in max_indices:
            # <previous + current token>, then a ~10-token context window.
            print(token_scores[idx], f'<{retriever_tokenizer.decode(page_ids[max(0, idx - 1): idx + 1])}>', retriever_tokenizer.decode(page_ids[max(0, idx - 5): idx + 5]))
        print('\n\n')
        
def query_indicatiors(retriever_tokenizer:AutoTokenizer, question:str, pages, q_lhs:np.ndarray, q_input_ids:np.ndarray, p_lhs, p_input_ids, pids, scores):
    """Show, per retrieved passage, how strongly each question token matches the passage tokens.

    For every retrieved passage this prints the passage and plots a horizontal bar chart of
    each question token's mean similarity to all passage tokens. It then prints, for each
    question token, its mean score and its top-10 passage tokens with a short preceding
    context.

    Args:
        retriever_tokenizer: tokenizer used to decode token ids.
        question: question text, printed as a header.
        pages: passage texts indexed by passage id.
        q_lhs: question token hidden states; row i pairs with q_input_ids[i].
        q_input_ids: question token ids.
        p_lhs, p_input_ids: per-passage token hidden states and token ids, indexed by passage id.
        pids, scores: retrieved passage ids and their retrieval scores, in rank order.
    """
    print(question)
    for i in range(q_lhs.shape[0]):
        print(retriever_tokenizer.decode(q_input_ids[i]), np.linalg.norm(q_lhs[i]))
    print('\n')
    # Loop-invariant: decoded question tokens and their bar positions.
    q_tokens = [retriever_tokenizer.decode(q_token) for q_token in q_input_ids]
    positions = np.arange(len(q_tokens))
    for rank, (pid, score) in enumerate(zip(pids, scores)):
        print(f'Rank {rank}\nPassage {pid}:\n{score}\n')
        print(pages[pid])
        print('\n\nHigh scored spans:\n')
        # Row i: similarities of question token i to every passage token.
        q_token_scores = np.matmul(q_lhs, p_lhs[pid].T)
        mean_scores = [token_scores.mean() for token_scores in q_token_scores]
        fig, ax = plt.subplots(figsize=(6, max(2.0, 0.3 * len(q_tokens))))
        # Numeric positions: string labels would form a categorical axis that merges
        # repeated tokens (e.g. two "the") into a single bar.
        ax.barh(positions, mean_scores)
        ax.set_yticks(positions)
        ax.set_yticklabels(q_tokens)
        ax.invert_yaxis()  # first question token at the top
        ax.set(title=f'Rank {rank}, passage {pid}', xlabel='mean similarity to passage tokens')
        plt.show()
        plt.close(fig)
        page_ids = p_input_ids[pid]
        for q_token, token_scores in zip(q_input_ids, q_token_scores):
            max_indices = np.argsort(token_scores)[::-1].tolist()[:10]
            print(token_scores.mean(), f'<{retriever_tokenizer.decode(q_token)}>', *[(token_scores[idx], f'<{retriever_tokenizer.decode(page_ids[max(0, idx - 5): idx + 1])}>') for idx in max_indices])
        print('\n\n')
        
def query_indicator_sents(retriever_tokenizer:AutoTokenizer, pages, q_lhs:np.ndarray, q_input_ids:np.ndarray, p_lhs, p_input_ids, test_pid, test_q_token_id:int, is_contriever:bool):
    """Rank one passage's sentences by mean token similarity to one question token.

    Prints the question token and the passage. Then prints each sentence with the mean
    dot product between its token hidden states and the question token's hidden state,
    highest first.

    Args:
        retriever_tokenizer: tokenizer used to decode token ids.
        pages: passage texts indexed by passage id.
        q_lhs, q_input_ids: question token hidden states and ids.
        p_lhs, p_input_ids: per-passage token hidden states and ids, indexed by passage id.
        test_pid: passage to analyze.
        test_q_token_id: index of the question token to score against.
        is_contriever: forwarded to split_sents (strip first/last special token).
    """
    print(retriever_tokenizer.decode(q_input_ids[test_q_token_id]), '\n')
    print(pages[test_pid], '\n')
    passage_ids = p_input_ids[test_pid]
    sent_spans = split_sents(retriever_tokenizer, passage_ids, is_contriever)
    token_scores = p_lhs[test_pid].dot(q_lhs[test_q_token_id])
    sent_scores = [
        (token_scores[start:end].mean(), retriever_tokenizer.decode(passage_ids[start:end]))
        for start, end in sent_spans
        # An empty span gives a NaN mean (plus RuntimeWarning), which breaks the sort below.
        if end > start
    ]
    sent_scores.sort(key=lambda x: x[0], reverse=True)
    for score, sent in sent_scores:
        print(score, sent)

In [None]:
# Document `test_i`: load its pages and keep only the first line of each question.
# NOTE(review): presumably the first line is the question stem and answer options follow — confirm.
test_i = 2
pages = ['\n'.join(p) for p in read_json(os.path.join(dataset.data_dir, f'pages_{test_i}.json'))]
questions, answers = dataset.get_questions_and_answers(dataset.data[test_i])
questions = [q.splitlines()[0] for q in questions]
questions

In [None]:
qid = 8  # index into `questions` of the question to analyze

### Contriever

In [None]:
# Contriever: encode the question and all pages, keeping token ids and last hidden states for token-level analysis.
q_embedding = retriever.embed_paragraphs([questions[qid]], normalize=True, complete_return=True)
page_embeddings = retriever.embed_paragraphs(pages, normalize=True, complete_return=True)
c_q_emb, c_q_input_ids, c_q_lhs = q_embedding.embeddings, q_embedding.input_ids, q_embedding.last_hidden_states
c_p_emb, c_p_input_ids, c_p_lhs = page_embeddings.embeddings, page_embeddings.input_ids, page_embeddings.last_hidden_states
c_retriever_tokenizer = retriever.retriever_tokenizer
# Top-10 pages; embeddings were already normalized above, hence normalize=False here.
c_pids, c_scores = retriever.dense_retrieval(c_q_emb, c_p_emb, 10, normalize=False, return_score=True)

In [None]:
# Passage tokens scoring highest against the question embedding (Contriever).
important_page_tokens(c_retriever_tokenizer, questions[qid], pages, c_q_lhs[0], c_q_input_ids[0], c_q_emb[0], c_p_lhs, c_p_input_ids, c_pids, c_scores)

In [None]:
# Per-question-token mean similarity to each retrieved passage (Contriever).
query_indicatiors(c_retriever_tokenizer, questions[qid], pages, c_q_lhs[0], c_q_input_ids[0], c_p_lhs, c_p_input_ids, c_pids, c_scores)

In [None]:
# Rank sentences of passage `test_pid` by similarity to question token `test_q_token_id` (Contriever).
# NOTE(review): test_q_token_id must be smaller than the number of question tokens.
test_q_token_id = 15
test_pid = 4
# is_contriever=True: the first and last ids are special tokens and are excluded from sentence spans.
query_indicator_sents(c_retriever_tokenizer, pages, c_q_lhs[0], c_q_input_ids[0], c_p_lhs, c_p_input_ids, test_pid, test_q_token_id, True)

### GritLM

In [None]:
def gritlm_instruction(instruction):
    """Build GritLM's embedding prompt prefix.

    A non-empty instruction is wrapped in a user turn followed by the embed marker.
    A falsy instruction (e.g. "") yields only the embed marker.
    """
    if not instruction:
        return "<|embed|>\n"
    return "".join(["<|user|>\n", instruction, "\n<|embed|>\n"])

instruction = "Retrieve relevant passages from a story to answer a given question."
# The setup cell leaves GritLM commented out; load it on first use so this section runs on a fresh kernel.
if 'gritlm' not in globals():
    gritlm = GritLM("GritLM/GritLM-7B", device_map="auto", torch_dtype="auto")
# No need to add instruction for retrieval documents
g_p_emb, g_p_input_ids, g_p_lhs = gritlm.encode(pages, instruction=gritlm_instruction(""), max_length=2048)
g_q_emb, g_q_input_ids, g_q_lhs = gritlm.encode([questions[qid]], instruction=gritlm_instruction(instruction))
g_retriever_tokenizer = gritlm.tokenizer
# NOTE(review): normalize=False assumes gritlm.encode already returns normalized embeddings — confirm.
g_pids, g_scores = retriever.dense_retrieval(g_q_emb, g_p_emb, 10, normalize=False, return_score=True)

In [None]:
# Passage tokens scoring highest against the question embedding (GritLM).
important_page_tokens(g_retriever_tokenizer, questions[qid], pages, g_q_lhs[0], g_q_input_ids[0], g_q_emb[0], g_p_lhs, g_p_input_ids, g_pids, g_scores)

In [None]:
# Per-question-token mean similarity to each retrieved passage (GritLM).
query_indicatiors(g_retriever_tokenizer, questions[qid], pages, g_q_lhs[0], g_q_input_ids[0], g_p_lhs, g_p_input_ids, g_pids, g_scores)

In [None]:
# Rank sentences of passage `test_pid` by similarity to question token `test_q_token_id` (GritLM).
# NOTE(review): test_q_token_id must be smaller than the number of question tokens.
test_q_token_id = 15
test_pid = 4
# is_contriever=False: GritLM ids are used without stripping a leading/trailing token.
query_indicator_sents(g_retriever_tokenizer, pages, g_q_lhs[0], g_q_input_ids[0], g_p_lhs, g_p_input_ids, test_pid, test_q_token_id, False)

In [None]:
# Ad-hoc look at page 3 of the current document.
print(pages[3])