# MuSiQue(answerable & unanswerable)

In [1]:
from transformers import pipeline
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
# import streamlit as st

model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
triplet_extractor = pipeline('text2text-generation', model= model, tokenizer= tokenizer)

  from .autonotebook import tqdm as notebook_tqdm
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [2]:
from llama_index.core import ServiceContext
from openai import OpenAI
from llama_index.core import Settings,KnowledgeGraphIndex
from llama_index.core import Document
from llama_index.core.node_parser import TokenTextSplitter
import json
from tqdm.auto import tqdm
import hashlib
import random
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [3]:
# Function to parse the generated text and extract the triplets
# Rebel outputs a specific format. This code is mostly copied from the model card!

def extract_triplets(input_text):
    triples = []
    text = triplet_extractor.tokenizer.batch_decode([triplet_extractor(input_text, return_tensors=True, return_text=False)[0]["generated_token_ids"]])[0]
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and subject in input_text and relation != '' and relation in input_text and object_ != '' and object_ in input_text:
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
    for t in triplets:
        triples.append((t['head'], t['type'], t['tail']))
    return triplets

In [4]:
def bert_len(text):
    tokens = tokenizer.encode(text)
    return len(tokens)

#用text提示后面才是文本部分
def generate_data_for_extraction(doc_id,node_id,corpus_data):
    data = []

    max1 = -1
    ans = ""
  
    m = hashlib.md5()
    # for dt in build_data:
    dt = str(corpus_data)
    m.update(dt.encode('utf-8'))
    uid = m.hexdigest()[:12]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = 400,
        chunk_overlap  = 40,
        length_function = bert_len,
        separators=['\n\n', '\n', ' ', ''],
    )
    chunks = text_splitter.split_text(dt)
    for i, chunk in enumerate(chunks):
        data.append({
            'doc_id':doc_id,
            # 'doc_id': f'{uid}',
            'node_id':node_id,
            'subnode_id': f'{i}',
            'text': chunk
        })
    return data

In [5]:
def generate_triples(texts):
    triples = []
    gen_kwargs = {
    "max_length": 256,
    "length_penalty": 0,
    "num_beams": 3,
    "num_return_sequences": 1,
}
    model_inputs = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
    generated_tokens = model.generate(
        model_inputs["input_ids"].to(model.device),
        attention_mask=model_inputs["attention_mask"].to(model.device),
        **gen_kwargs
    )
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    for idx, sentence in enumerate(decoded_preds):
        et = extract_triplets(sentence)
        for t in et:
            triples.append((t['head'], t['type'], t['tail']))
    
    return triples

def generate_doc_triples(data):
    doc_triplets = []
    
    for i in tqdm(range(0, len(data), 2)):
        triplets = {}
     
        triplets['doc_id'] = data[i]['doc_id']
        triplets['node_id'] = data[i]['node_id']
        triplets['triplets'] = generate_triples(data[i]['text'])
        doc_triplets.append(triplets)
      
    return doc_triplets




In [17]:
import jsonlines
file_path = "musique_full_v1.0_dev.jsonl"
corpus_file = 'data/musique/corpus.json'

def prepare_corpus():
    corpus = str()
    print(file_path)
        # data = json.load(file)
    cnt = 0
    corpus_data = []
    with open(file_path) as file:
        for item in file:
            if cnt<500:
                k = json.loads(item)
                # print(k)
                text = {}
                text['id'] = k["id"]
                text['text'] = k["paragraphs"]
                corpus_data.append(text)
                corpus += str(text['text'])
                corpus += "\n"
                cnt += 1
            
    with open("data/musique/corpus.txt", 'w') as file:
        # file.truncate()
        file.write(corpus)
    with open(corpus_file, 'w') as file:
        file.truncate()
        json.dump(corpus_data,file)
    return corpus_data

In [18]:
corpus_data = prepare_corpus()

musique_full_v1.0_dev.jsonl


In [19]:
doc_nodes_500 = {}
for doc in corpus_data:
    doc_id = doc['id']
    documents = Document(text = str(doc['text']))
    text_splitter = TokenTextSplitter(chunk_size=256,chunk_overlap=200)
    doc_nodes_500[doc_id] = text_splitter.get_nodes_from_documents([documents])

In [20]:
from tqdm import tqdm
doc_triplets_500 = []
total = 100  # 总体进度
# 创建进度条
with tqdm(total=total) as pbar:
    for doc in doc_nodes_500.keys():
        # print("doc:",doc)
        nodes = doc_nodes_500[doc]
        for node in nodes:
            text = node.get_content()
            extraction_data = generate_data_for_extraction(doc,node.node_id,text)
            triplets = generate_doc_triples(extraction_data)
            doc_triplets_500.extend(triplets)
        # print(doc_triplets)
        pbar.update(1)

100%|██████████| 2/2 [00:58<00:00, 29.12s/it]
100%|██████████| 2/2 [00:53<00:00, 26.75s/it]
100%|██████████| 2/2 [00:36<00:00, 18.49s/it]
100%|██████████| 2/2 [00:52<00:00, 26.25s/it]
100%|██████████| 2/2 [00:49<00:00, 24.65s/it]
100%|██████████| 2/2 [00:47<00:00, 23.56s/it]
100%|██████████| 2/2 [00:44<00:00, 22.19s/it]
100%|██████████| 2/2 [00:52<00:00, 26.27s/it]
100%|██████████| 2/2 [00:50<00:00, 25.07s/it]
100%|██████████| 2/2 [00:39<00:00, 19.88s/it]
100%|██████████| 2/2 [01:05<00:00, 32.63s/it]
100%|██████████| 2/2 [00:44<00:00, 22.30s/it]
100%|██████████| 1/1 [00:23<00:00, 23.95s/it]
100%|██████████| 1/1 [00:23<00:00, 23.62s/it]
100%|██████████| 2/2 [00:34<00:00, 17.11s/it]
100%|██████████| 2/2 [00:45<00:00, 22.89s/it]
100%|██████████| 2/2 [01:55<00:00, 57.64s/it]
100%|██████████| 2/2 [01:05<00:00, 32.77s/it]
100%|██████████| 2/2 [00:37<00:00, 18.53s/it]
100%|██████████| 2/2 [00:54<00:00, 27.43s/it]
100%|██████████| 2/2 [17:23<00:00, 521.79s/it]
100%|██████████| 2/2 [00:27<00:00