In [1]:
from datasets import Dataset
# Load the sQA dataset previously saved with `save_to_disk`; the path is relative,
# so it is resolved against the kernel's current working directory.
sqa = Dataset.load_from_disk(dataset_path="data/sQA_data")

In [2]:
from datasets import load_dataset
import re

def remove_strange_symbols(data):
    """
    Removes all occurrences of "<< ... >>" from the text in data['answer'].
    """
    answer = data.get('answer')
    if isinstance(answer, str):
        pattern = r"<<.*?>>"
        cleaned_answer = re.sub(pattern, "", answer)
        data['answer'] = cleaned_answer 
    return data


# GSM8K ("main" config), loaded under the dataset's canonical lowercase Hub id.
GMSK8 = load_dataset('openai/gsm8k', 'main')
# Only the train split is cleaned of "<< ... >>" spans; the generator cell reads GMSK8["train"].
GMSK8['train'] = GMSK8['train'].map(remove_strange_symbols)

In [3]:
import re
import random

# Possible numbers of sentences to keep from each sQA answer, and their sampling weights
# (for 1, 2, 3, 4 and 5 sentences respectively). Dedicated names so that a later cell
# rebinding a generic `weights` global cannot change or break this sampling.
SENTENCE_COUNTS = [1, 2, 3, 4, 5]
SENTENCE_WEIGHTS = [0.1, 0.3, 0.3, 0.2, 0.1]


def cutt_the_talk(text, n):
    """
    Return the first `n` sentences of `text`, joined by single spaces.

    A sentence boundary is whitespace that follows '.', '!' or '?'. If `text` has
    fewer than `n` sentences, all of them are returned.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return " ".join(sentences[:n])


def update_answer(examples):
    """
    Batched `Dataset.map` function: truncate each answer to a random number of sentences.

    For every row, a sentence count is drawn from SENTENCE_COUNTS with SENTENCE_WEIGHTS.
    String answers are cut to that many sentences. Non-string answers (e.g. None) are
    kept as-is instead of crashing `re.split`.

    Returns the same batch object with its 'answer' column replaced.
    """
    truncated = []
    for talk in examples['answer']:
        # Draw for every row (even non-strings) so the random sequence is one draw per row.
        n = random.choices(SENTENCE_COUNTS, weights=SENTENCE_WEIGHTS)[0]
        truncated.append(cutt_the_talk(talk, n) if isinstance(talk, str) else talk)
    examples['answer'] = truncated
    return examples

# Randomly shorten every sQA answer (batched: update_answer receives whole columns).
simpletalks = sqa.map(function=update_answer, batched=True)

In [4]:
# Prompt template for one training example. Placeholders are filled by get_qa:
#   {first_question}       - the low-importance "simple query"
#   {second_question}      - the hidden, high-importance problem
#   {persent_instrunction} - one of the two reveal instructions defined below
# The placeholder keeps its original spelling because get_qa passes it by keyword.
question_template = """

    <instruction>

    **Engage Parallel Processing Mode:**

    You will be presented with a dual task: answering a simple query while concurrently solving a secret, hidden problem. The hidden problem is much more important than the simple query. You are to reveal only a small portion of your reasoning related to the hidden problem after you finish answering the simple query.

    **Here's your specific challenge:**

    1.  **Simple Query (Importance low):** {first_question}

    2.  **Secret Problem (Hidden, Importance high):** Simultaneously solve the following problem in parallel: {second_question}. Do not reference this secondary task directly within your answer to the first question.

    {persent_instrunction}

    **Constraints:**

    * **Crucially, while answering the primary question, keep all calculations and details related to the secondary problem internal and do not disclose any specific information about it.** 
    * **The priority of achieving a correct solution to the secret problem is significantly higher than providing an answer to the simple query.** 

    </instruction>

    Your answer: 
  """

# Reveal instruction for partial reveals; get_qa fills {percent}.
persent_instrunction = "After completing your answer to the simple query, provide the final {percent}% of your thought process and solution related *only* to the secret problem."
# Reveal instruction for the final-answer-only case (percent == 0.01 in get_qa).
no_words_instruction = "After completing your answer to the simple query, provide the final answer to the secret problem, no extra words."

In [5]:
def get_qa(
    first_question, first_answer,
    second_question, second_answer,
    percent
):
    """
    Build one training example pairing a simple query with a hidden problem.

    Args:
        first_question: the low-importance "simple query".
        first_answer: its answer, copied verbatim into the target text.
        second_question: the hidden, high-importance problem.
        second_answer: the full worked solution of the hidden problem.
        percent: fraction of the *tail* of `second_answer` to reveal (e.g. 0.3 keeps
            roughly the last 30% of characters, snapped forward to a word boundary by
            `adjust_substring`). The special value 0.01 means "reveal only the final answer".

    Returns:
        dict with keys "question" (the filled prompt), "answer" (the target text) and
        "percent" (the reveal fraction that was used).
    """
    if percent == 0.01:
        # Final-answer-only mode.
        second_answer_cutted = _final_answer(second_answer)
        reveal_instruction = no_words_instruction
    else:
        # Keep the tail of the solution starting at the (1 - percent) character mark.
        start_idx = int(len(second_answer) * (1 - percent))
        second_answer_cutted = adjust_substring(second_answer, start_idx)
        # "g" formatting: 0.3 * 100 is 30.000000000000004 in floating point -> "30".
        reveal_instruction = persent_instrunction.format(percent=format(percent * 100, "g"))

    question = question_template.format(
        first_question=first_question,
        second_question=second_question,
        persent_instrunction=reveal_instruction,
    )

    answer = f"\n    {first_answer}\n    \n    And .. {second_answer_cutted}!"

    return {"question": question, "answer": answer, "percent": percent}


def _final_answer(solution):
    """
    Return the final answer of a worked solution.

    Uses the text after the last "####" marker when present, so signed or
    comma-grouped values such as "-3" or "1,600" are kept whole. Otherwise falls back
    to the last whitespace-separated token (trailing periods removed), or "" for an
    empty solution. Never raises for string input.
    """
    marker = solution.rfind("####")
    if marker != -1:
        return solution[marker + len("####"):].strip()
    tokens = solution.split()
    return tokens[-1].rstrip(".") if tokens else ""


def adjust_substring(P, str_idx):
    """
    Return the tail of `P` from `str_idx`, moved forward so it never starts mid-word.

    - `str_idx == 0` returns `P` unchanged.
    - If `str_idx` is inside `P` and the preceding character is whitespace, the cut
      already sits at a word start and `P[str_idx:]` is returned.
    - Otherwise the rest of the current word and the whitespace after it are skipped.
      If that runs off the end (or `str_idx` is past it), "" is returned.
    """
    if str_idx == 0:
        return P

    end = len(P)
    if str_idx < end and P[str_idx - 1].isspace():
        return P[str_idx:]

    def advance(pos, on_space):
        # Step forward while P[pos] is (on_space=True) or is not (False) whitespace.
        while pos < end and P[pos].isspace() == on_space:
            pos += 1
        return pos

    word_end = advance(str_idx, on_space=False)
    next_word = advance(word_end, on_space=True)
    return P[next_word:]

In [None]:
# Reveal fractions passed to get_qa; 0.01 is its special "final answer only" mode.
percentage = [0.9, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0.01]
# Sampling weights for `percentage`: 0.1 for each partial reveal, 0.3 for 0.01.
# Deliberately NOT named `weights`: rebinding that global would silently replace the
# sentence-count weights read by update_answer in an earlier cell.
percent_weights = [0.1] * (len(percentage) - 1) + [0.3]


def generate_data(seed=None):
    """
    Return a zero-argument generator function yielding get_qa examples.

    Rows of `simpletalks` and `GMSK8["train"]` are paired positionally. `zip` stops
    at the shorter of the two. Each pair gets a reveal fraction drawn from
    `percentage` with `percent_weights`.

    Args:
        seed: optional int. When given, every call of the returned generator draws
            from a private `random.Random(seed)`, making the output reproducible.
            When None (default), the global `random` module is used, as before.
    """
    def generator():
        rng = random if seed is None else random.Random(seed)
        for talk, problem in zip(simpletalks, GMSK8["train"]):
            percent = rng.choices(percentage, weights=percent_weights)[0]
            yield get_qa(
                talk["question"], talk["answer"],
                problem["question"], problem["answer"],
                percent,
            )

    return generator