In [3]:
import json
from collections import defaultdict

# Load the training examples from datasets/retacred/train.json. The `with`
# block closes the file handle as soon as parsing is done instead of leaving
# it open for the lifetime of the kernel.
with open('datasets/retacred/train.json') as train_file:
    data = json.load(train_file)

# Count how many training examples carry each relation label.
re_stats = defaultdict(int)
for sent in data:
    re_stats[sent['relation']] += 1
# Relations that are never selected for augmentation, whatever their count.
ignore_keys = set(['org:website', 'per:city_of_birth'])
# Every other relation with fewer than 300 examples is selected.
aug_keys = set()
for key, value in sorted(re_stats.items(), key=lambda x: x[1]):
    if value < 300 and key not in ignore_keys:
        aug_keys.add(key)
print(aug_keys)
# Keep only the examples whose relation was selected above.
aug_sents = [sent for sent in data if sent['relation'] in aug_keys]
print(len(aug_sents))

{'per:other_family', 'org:dissolved', 'per:stateorprovince_of_birth', 'per:country_of_birth', 'per:city_of_death', 'per:stateorprovinces_of_residence', 'org:political/religious_affiliation', 'per:cause_of_death', 'per:date_of_birth', 'per:countries_of_residence', 'org:founded_by', 'per:parents', 'per:cities_of_residence', 'per:siblings', 'per:origin', 'per:schools_attended', 'per:children', 'org:number_of_employees/members', 'per:religion', 'per:spouse', 'per:stateorprovince_of_death', 'per:date_of_death', 'per:country_of_death', 'per:charges', 'org:founded', 'org:shareholders'}
3447


In [4]:
def encode_sent(sent):
    """Render an example as one string with both entities wrapped in [[ ... ]].

    Args:
        sent: mapping with a 'token' list and the inclusive token indices
            'subj_start', 'subj_end', 'obj_start', 'obj_end'.

    Returns:
        The space-joined tokens with '[[' inserted before and ']]' inserted
        after each entity span, or None (after printing a notice) when the
        two spans share at least one token position.
    """
    words = list(sent['token'])  # copy so the caller's token list is untouched
    # Convert to half-open [start, end) spans, ordered left to right.
    spans = sorted([
        (sent['subj_start'], sent['subj_end'] + 1),
        (sent['obj_start'], sent['obj_end'] + 1),
    ])

    covered = set()
    for lo, hi in spans:
        positions = set(range(lo, hi))
        if not covered.isdisjoint(positions):
            print('------> overlapping entities')
            return None
        covered |= positions

    # Each span already wrapped has added two marker tokens to its left.
    shift = 0
    for lo, hi in spans:
        words.insert(lo + shift, '[[')
        words.insert(hi + shift + 1, ']]')
        shift += 2
    return ' '.join(words)

# Group the bracket-marked sentences by relation label. Examples for which
# encode_sent returned a falsy value (it returns None on overlapping entity
# spans) are left out.
encoded_sents = defaultdict(list)
for sent in aug_sents:
    encoded_sent = encode_sent(sent)
    if not encoded_sent:
        continue
    encoded_sents[sent['relation']].append(encoded_sent)

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLM:
    """Thin wrapper around a causal chat model that answers one user turn.

    The constructor loads the model and tokenizer for
    "mistralai/Mistral-7B-Instruct-v0.2" and moves the model to `device`.
    """

    def __init__(self, prompt, device = "cuda:4"):
        """Load model and tokenizer.

        Args:
            prompt: text prepended (followed by a newline) to every message
                passed to `chat`.
            device: device string the model and the inputs are moved to.
        """
        LLM_path = "mistralai/Mistral-7B-Instruct-v0.2"
        self.prompt = prompt
        self.device = device
        self.model = AutoModelForCausalLM.from_pretrained(LLM_path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_path)

    def chat(self, message):
        """Send `message` as a single user turn and return the model's reply text."""
        content = f'{self.prompt}\n{message}'
        messages = [{"role": "user", "content": content}]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.device)
        # Sampling is enabled, so repeated calls with the same message can
        # return different text.
        generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
        decoded = self.tokenizer.batch_decode(generated_ids)
        # The decoded string still holds the prompt and the special tokens;
        # get_output strips them.
        return self.get_output(decoded[0])

    def get_output(self, response):
        """Extract the reply from a decoded '<s>[INST] ... [/INST] reply</s>' string.

        Everything after the FIRST '[/INST]' marker is treated as the reply,
        so a reply that itself contains '[/INST]' is no longer cut short.
        When no marker is present, the text with '[INST]' removed is
        returned. The result is whitespace-stripped in both cases.
        """
        cleaned = response.replace('<s>', '').replace('</s>', '')
        _, marker, reply = cleaned.partition('[/INST]')
        if not marker:
            return cleaned.replace('[INST]', '').strip()
        return reply.strip()

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Text that LLM.chat prepends to every message. It currently holds only
# newlines, so the per-sentence message carries all of the instructions.
prompt = """

"""

# Instantiating LLM loads the model and tokenizer onto the class's default device.
llm = LLM(prompt=prompt)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.08s/it]


In [7]:
def decode(content):
    """Split model output into tokens, isolating the [[ and ]] entity markers.

    The text is split on whitespace; every '[[' and ']]' becomes its own
    token even when glued to a word; and one trailing non-alphanumeric
    character (e.g. the comma in 'Paris,') is split off a longer word.

    The result never contains empty strings, words ending in a digit are
    left intact, and the tokenisation of a word does not depend on whether
    some other word had a marker attached to it.

    Args:
        content: raw text to tokenise.

    Returns:
        list of non-empty token strings in their original order.
    """
    # Padding the markers with spaces lets a single whitespace split isolate them.
    spaced = content.replace('[[', ' [[ ').replace(']]', ' ]] ')
    words = []
    for piece in spaced.split():
        if piece in ('[[', ']]'):
            words.append(piece)
        elif len(piece) > 1 and not piece[-1].isalnum():
            # Detach exactly one trailing punctuation character.
            words.append(piece[:-1])
            words.append(piece[-1])
        else:
            words.append(piece)
    return words


def check_valid(words):
    """Check that a token list holds exactly two well-formed [[ ... ]] spans.

    Args:
        words: list of token strings.

    Returns:
        (ok, reason): `ok` is True only when there are exactly two '[[' and
        two ']]' tokens and they appear in the order '[[ ]] [[ ]]'. `reason`
        is a short human-readable explanation.
    """
    # Only effective when the tokens keep line breaks; a list produced by
    # str.split() can never contain '\n'.
    if '\n' in words:
        return False, 'multiple lines'
    if words.count('[[') != 2:
        return False, 'incorrect num of [[: %s' % words.count('[[')
    if words.count(']]') != 2:
        return False, 'incorrect num of ]]: %s' % words.count(']]')
    head_start = words.index('[[')
    head_end = words.index(']]')
    tail_start = words.index('[[', head_start+1)
    tail_end = words.index(']]', head_end+1)
    # The second span opens before the first one has closed.
    if head_start < tail_start < head_end:
        return False, 'overlap entities'
    # A closer precedes its opener, e.g. ']] a [[ [[ b ]]' or '[[ a ]] ]] [[ b'.
    # The old set-intersection test accepted these because an inverted range
    # is empty and therefore never intersects anything.
    if not (head_start < head_end < tail_start < tail_end):
        return False, 'unbalanced brackets'
    return True, 'Good content'

In [9]:
import os
from tqdm import tqdm

# Instruction text sent to the model for every sentence. The single %s
# placeholder is filled with the bracket-marked sentence in rewrite_sent.
# NOTE(review): "You change the content inside the entity." reads as if it
# contradicts "keeping the original entities" -- confirm whether "do not
# change" was intended before editing the prompt.
template = """
You are an editor who is very good at reading sentence. Your task is rewrite a given sentence well keeping the original entities.

In a sentence, each entity is nested in the sentence in the format of [[ entity ]].
Rewrite the given sentence using each given entity exactly once and do not introduce other entities.
Nest the original entities in the same format in the rewrited sentence.
You change the content inside the entity.

%s
"""

def rewrite_sent(sent, max_iter = 5):
    """Ask the model to rewrite `sent`, retrying until the output validates.

    Args:
        sent: bracket-marked sentence substituted into `template`.
        max_iter: maximum number of generation attempts.

    Returns:
        The token list from `decode` for the first response that
        `check_valid` accepts, or None (after printing a notice) when all
        `max_iter` attempts were rejected.
    """
    message = template % sent
    # A bounded for-loop replaces the old while-loop, whose counter was never
    # incremented: an input the model could not rewrite looped forever.
    for _ in range(max_iter):
        response = llm.chat(message)
        new_tokens = decode(response)
        is_valid, _reason = check_valid(new_tokens)
        if is_valid:
            return new_tokens
    print('=====not valid afater %s tries' % max_iter)
    return None

# Create the output directory up front so a missing folder fails immediately
# and not after a whole relation has been generated.
os.makedirs('augs', exist_ok=True)

# Process relations from fewest to most sentences. An existing output file
# marks a relation as done, which lets an interrupted run be resumed.
for relation, cur_encoded_sents in sorted(encoded_sents.items(), key=lambda x: len(x[1])):
    print('--------> processing', relation)
    # '/' in a relation name would otherwise be read as a path separator.
    file_name = f'augs/{relation.replace("/", "--")}.json'
    if os.path.exists(file_name):
        continue
    rewrited_sents = []
    for encoded_sent in tqdm(cur_encoded_sents):
        rewrited_sent = rewrite_sent(encoded_sent)
        if rewrited_sent:
            rewrited_sents.append(rewrited_sent)
    # Write to a temporary name and rename afterwards: a crash mid-write must
    # not leave a truncated file that the existence check above would treat
    # as a finished relation.
    tmp_name = file_name + '.tmp'
    with open(tmp_name, 'w') as af:
        json.dump(rewrited_sents, af)
    os.replace(tmp_name, file_name)

--------> processing per:country_of_death
--------> processing org:dissolved
--------> processing per:country_of_birth
--------> processing per:stateorprovince_of_birth
--------> processing org:number_of_employees/members


 39%|███▉      | 21/54 [02:49<04:17,  7.82s/it]