In [1]:
import os
import openai

# Read the API key from the environment instead of hardcoding it in the notebook.
# `setdefault` keeps a key that is already exported and only falls back to an
# empty placeholder, so a real key is never clobbered by this cell.
os.environ.setdefault("OPENAI_API_KEY", "")
openai.api_key = os.getenv("OPENAI_API_KEY")

In [2]:
def _chat_turn(messages, model="gpt-3.5-turbo"):
    """Send ``messages`` to the chat completion endpoint and return the reply message."""
    completion = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=0.7,
        max_tokens=512,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return completion['choices'][0]['message']


def generate_dialogue(text, summary, num_turns=3, model="gpt-3.5-turbo", min_turns=2):
    """Simulate a question/answer dialogue about ``text`` between two chat models.

    The "person" model only sees ``summary`` and asks one short question per
    turn; the "bot" model sees the full ``text`` and answers it. Each side keeps
    its own message history, in which its own utterances are the replies it
    produced and the other side's utterances are added as ``user`` messages.

    Args:
        text: Source text the answering model is grounded on.
        summary: Summary shown to the questioning model.
        num_turns: Number of question/answer pairs to generate.
        model: Chat model name used for both sides.
        min_turns: Minimum number of completed pairs that is still returned
            when a later request fails.

    Returns:
        A list of ``(question, answer)`` tuples with surrounding whitespace
        stripped.

    Raises:
        Exception: If a request fails before ``min_turns`` pairs were
            completed. The message starts with ``'short dialogue'`` and the
            original error is attached as ``__cause__``.
    """
    messages_bot = [
        {"role": "system", "content": "You should briefly answer the questions on the following text. If there is no answer in the given text, then you must answer that there is not enough information. Your answers should be brief. \n" + text},
    ]

    messages_person = [
        {"role": "system", "content": "You should be asking short questions about an article you can't see. You only see the following summary. Your task is to ask clarifying dependent questions in order to understand the source text. You can ask only single short question at each turn. \n" + summary},
    ]

    dialogue = []
    for _ in range(num_turns):
        try:
            # The person asks; the bot receives the question as a user message.
            question = _chat_turn(messages_person, model=model)
            question_text = question['content'].strip()
            messages_person.append(question)
            messages_bot.append({'role': 'user', 'content': question_text})

            # The bot answers; the person receives the answer as a user message.
            response = _chat_turn(messages_bot, model=model)
            response_text = response['content'].strip()
            messages_bot.append(response)
            messages_person.append({'role': 'user', 'content': response_text})

            dialogue.append((question_text, response_text))

        except Exception as e:
            # Keep a partial dialogue if it is already long enough to be useful.
            if len(dialogue) >= min_turns:
                return dialogue

            # Include the cause in the message: callers may only log str(e).
            raise Exception(f'short dialogue: {e}') from e

    return dialogue

In [3]:
import pickle
import random
from tqdm import tqdm

In [4]:
from chat_scripts.summary_generation_inference import BartGenerator

In [13]:
# Instantiate the summarizer used to produce the summary shown to the questioning model.
# NOTE(review): the checkpoint path is absolute and user-specific; presumably a
# fine-tuned DistilBART-CNN checkpoint (judging by the directory name) — confirm.
# Requires a CUDA device.
summarizer = BartGenerator("/home/jovyan/chatbot/_common/bart_summarization/distilbart-cnn-12-6_1e-5",
                            device='cuda')

In [14]:
def join_segments(raw_segments, max_len=4000):
    '''
    Join paper segments into text chunks of roughly at most ``max_len`` characters.

    Each segment is read as a sequence whose item ``0`` is the segment id,
    item ``1`` the segment text, item ``-2`` the title and item ``-1`` the
    section type. The first two segments always form one group; the remaining
    segments are grouped into runs of consecutive segments that share the same
    section type. Within a group, segment texts are joined with newlines until
    adding the next one would reach ``max_len``; that segment then starts a new
    chunk. A single segment longer than ``max_len`` is kept as its own chunk.

    Returns a list of ``(text, meta)`` tuples, where ``meta`` is a list of
    ``{'id', 'title', 'section_type'}`` dicts for the segments in the chunk.
    An empty input yields an empty list.
    '''
    total = len(raw_segments)
    if total == 0:
        return []

    # The leading group is fixed; clamp it so short papers do not raise IndexError.
    index_groups = [list(range(min(2, total)))]

    # Group the remaining segments by runs of identical section type.
    current_group = []
    previous_type = None
    for i in range(2, total):
        section_type = raw_segments[i][-1]
        if current_group and section_type != previous_type:
            index_groups.append(current_group)
            current_group = []
        current_group.append(i)
        previous_type = section_type
    if current_group:
        index_groups.append(current_group)

    joined_segments = []
    for group in index_groups:
        chunk_text = ''
        chunk_meta = []

        for idx in group:
            segment = raw_segments[idx]
            candidate = chunk_text + '\n' + segment[1]

            if chunk_text and len(candidate) >= max_len:
                # Flush the full chunk and start the next one with this segment,
                # so the segment that did not fit is not lost.
                joined_segments.append((chunk_text, chunk_meta))
                chunk_meta = []
                candidate = segment[1]

            chunk_meta.append({'id': segment[0],
                               'title': segment[-2], 'section_type': segment[-1]})
            chunk_text = candidate.strip()

        if chunk_text:
            joined_segments.append((chunk_text, chunk_meta))

    return joined_segments

In [15]:
# Load the segmented papers; later cells read 'paper_id', 'title' and 'segments' from each entry.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../_common/papers_segmented_data/segmented_papers-2.pkl', 'rb') as f:
    data = pickle.load(f)

In [16]:
# Number of loaded papers.
len(data)

24875

### Postprocessing cycle over all papers

In [17]:
# Load previously saved dialogues (davinci run); used below only to collect already-processed paper ids.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../_common/papers_segmented_data/davinci_dialogues_full_v2.pkl', 'rb') as f:
    davinci_dialogues = pickle.load(f)

In [18]:
# Load previously saved ChatGPT dialogues; the generation loop appends new records to this list.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../_common/papers_segmented_data/chatgpt_dialogues_full.pkl', 'rb') as f:
    chatgpt_dialogues = pickle.load(f)

In [19]:
# Ids of all papers that already have a dialogue in either collection,
# so the generation loop can skip them.
processed_papers = {
    record['meta_paper']['paper_id']
    for record in (*davinci_dialogues, *chatgpt_dialogues)
}

In [None]:
def _log_progress(message, log_path='logs/processed.txt'):
    """Append one line to the run log so progress can be followed outside the notebook."""
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_path, 'a') as log_file:
        log_file.write(message + '\n')


# Generate one dialogue per not-yet-processed paper in the index range, on a
# randomly chosen segment chunk, saving the full result list after every paper.
for i in tqdm(range(15000, 20000)):
    paper = data[i]
    if paper['paper_id'] in processed_papers:
        continue

    # Marked before filtering: a paper skipped below is not retried in this session.
    processed_papers.add(paper['paper_id'])

    # A malformed paper must not abort the whole (hours-long) run.
    try:
        segmented_paper = join_segments(paper['segments'])
    except Exception as e:
        _log_progress(f'Error for {i}: {str(e)}')
        continue

    # random.randint would raise ValueError on an empty list.
    if not segmented_paper:
        continue

    #  select random segment
    j = random.randint(0, len(segmented_paper) - 1)
    random_segment = segmented_paper[j]

    # skip acknowledgements
    if random_segment[1][-1]['section_type'].startswith('acknowledgement'):
        continue

    # filter short
    if len(random_segment[0]) < 1500:
        continue

    _log_progress(f'Processing {i}')

    try:
        text = random_segment[0]
        summary = summarizer(text)[0]
        _log_progress(f'Generating dialog for {i}...')
        result = generate_dialogue(text, summary, num_turns=4)
    except Exception as e:
        _log_progress(f'Error for {i}: {str(e)}')
        continue

    chatgpt_dialogues.append({
        'text': random_segment[0],
        'dialogue': result,
        'meta_segments': random_segment[1],
        'meta_paper': {'title': paper['title'], 'paper_id': paper['paper_id']},
        'used_summary': summary,
    })

    # The list is written to two files after every paper; presumably so that an
    # interrupted write cannot corrupt the only copy — confirm before removing one.
    with open('../_common/papers_segmented_data/chatgpt_dialogues_full_upd.pkl', 'wb') as f:
        pickle.dump(chatgpt_dialogues, f)

    with open('../_common/papers_segmented_data/chatgpt_dialogues_full_copy_upd.pkl', 'wb') as f:
        pickle.dump(chatgpt_dialogues, f)

    _log_progress(f'Saved for {i}, segments {j}')

 45%|████▌     | 2254/5000 [9:26:14<10:44:02, 14.07s/it]

In [None]:
# Total number of dialogue records (previously loaded plus newly generated).
len(chatgpt_dialogues)

In [None]:
# Inspect the last record in the list.
chatgpt_dialogues[-1]

In [28]:
# Attach a flat, speaker-labelled view of every dialogue: each (question, answer)
# pair becomes a 'person' turn followed by a 'bot' turn.
for dialogue in chatgpt_dialogues:
    turns = [
        {'speaker': speaker, 'text': pair[position]}
        for pair in dialogue['dialogue']
        for position, speaker in enumerate(('person', 'bot'))
    ]
    dialogue['parsed_dialogue'] = {'summary': dialogue['used_summary'], 'turns': turns}

In [30]:
# Save the dialogues, now including the 'parsed_dialogue' field, to a separate file.
with open('../_common/papers_segmented_data/chatgpt_dialogues_full_postproc_upd.pkl', 'wb') as f:
    pickle.dump(chatgpt_dialogues, f)