In [1]:
!pip install openai langchain tiktoken -q

In [2]:
import os
import pandas as pd

In [3]:
from kaggle_secrets import UserSecretsClient
# Load the Kaggle secret named "openai" into the OPENAI_API_KEY environment variable.
user_secrets = UserSecretsClient()
os.environ.update(OPENAI_API_KEY=user_secrets.get_secret("openai"))

In [4]:
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain

In [5]:
from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser

In [6]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [7]:
import tiktoken

In [8]:
from tqdm.auto import tqdm

In [9]:
def num_tokens_from_string(string: str, encoding_name: str="gpt-3.5-turbo") -> int:
    """Return the number of tokens in `string` under a model's tokenizer.

    Args:
        string: Text to measure.
        encoding_name: Despite its name, this is a *model* name passed to
            `tiktoken.encoding_for_model` (e.g. "gpt-3.5-turbo"), not a raw
            encoding name such as "cl100k_base".

    Returns:
        The token count of `string`.
    """
    encoding = tiktoken.encoding_for_model(encoding_name)
    # disallowed_special=() makes special-token markers such as "<|endoftext|>" count as
    # ordinary text instead of raising ValueError; this function is used as the text
    # splitter's length function, so a raise here would abort the whole generation run.
    return len(encoding.encode(string, disallowed_special=()))

In [10]:
def truncate_text_by_token(string: str, limit: int=3000, encoding_name: str="gpt-3.5-turbo") -> str:
    """Truncate `string` to at most `limit` tokens under a model's tokenizer.

    Args:
        string: Text to truncate.
        limit: Maximum number of tokens to keep.
        encoding_name: Model name passed to `tiktoken.encoding_for_model`
            (not a raw encoding name such as "cl100k_base").

    Returns:
        `string` itself when it already fits within `limit` tokens, otherwise the
        text decoded from its first `limit` tokens.
    """
    encoding = tiktoken.encoding_for_model(encoding_name)
    # Treat special-token markers such as "<|endoftext|>" as plain text instead of raising.
    tokens = encoding.encode(string, disallowed_special=())
    if len(tokens) <= limit:
        # Already short enough: hand back the input rather than a decode round-trip.
        return string
    return encoding.decode(tokens[:limit])

In [11]:
# Chat model used for question generation; each completion is capped at 1200 output tokens.
# NOTE(review): a 3000-token text chunk plus the prompt plus 1200 output tokens may exceed
# a 4k-context model -- confirm the context size of the model ChatOpenAI is using.
chat = ChatOpenAI(temperature=0.75, request_timeout=600, max_tokens=1200, max_retries=1)

In [12]:
# Set to True to process only a small subset of the documents (see the `if debug:` cell).
debug = False

In [13]:
parquet_path = '/kaggle/input/wikipedia-stem-plaintext/parsed.parquet'
df = pd.read_parquet(parquet_path)
# Give every row a positional integer id before sampling, so generated questions
# can be traced back to the section they came from.
df['id'] = range(df.shape[0])

In [14]:
# Keep one random section per article title, then draw 3000 of those articles
# (fixed random_state for reproducibility).
df = (
    df.groupby(['title'])
      .sample(1, random_state=100)
      .sample(3000, random_state=100)
      .reset_index(drop=True)
)

In [15]:
#df['id'] = -1

In [16]:
df

Unnamed: 0,title,section,text,id
0,Log sum inequality,Log sum inequality,The log sum inequality is used for proving the...,1103246
1,Low-voltage detect,Low-voltage detect,A low-voltage detect (LVD) is a microcontrolle...,1113000
2,Gradshteyn and Ryzhik,History,The work was originally compiled by the Russia...,810758
3,Reactive planning,Reactive plan representation,Condition-action rules (productions) A conditi...,1569278
4,CDIO Initiative,Sources,Edward Crawley; Johan Malmqvist; Sören Östlund...,265010
...,...,...,...,...
2995,Jacobs Vehicle Systems,How A Jake Brake Works,A “Jake Brake” is an engine compression releas...,1004207
2996,Camshaft,Piston engines,Alternative drive systems used in the past inc...,281776
2997,Quantum microscopy,Photoionization,"In 2013, Aneta Stodolna and colleagues imaged ...",1538091
2998,Nelson Diversity Surveys,Methodology,"The NDS were funded by Nelson, the Sloan Found...",1275997


In [17]:
# Source ids, aligned index-for-index with `texts`.
doc_ids = df['id'].tolist()

In [18]:
# Section texts to generate questions from, aligned with `doc_ids`.
texts = df['text'].tolist()

In [19]:
# Exact key set every generated question dict must have -- no more, no fewer.
response_keys_set = {"prompt", "A", "B", "C", "D", "E", "answer"}

In [20]:
# Prompt sent to the chat model for every text chunk. `{text}` receives the chunk and
# `{q_num}` the number of questions requested (both supplied by chain.run in the
# generation loop). The reply is expected to be a Python list literal of dicts with
# exactly the keys prompt/A/B/C/D/E/answer; the generation loop parses it and keeps
# only the well-formed entries.
# NOTE(review): the wording has grammar slips ("whom answers", "The answer do not",
# "LONGER enough") and appears to contradict itself -- answers "should be in the given
# TEXT" vs. "do not require the given TEXT". Presumably the intent is that questions be
# answerable without seeing the TEXT; confirm before rewording, because any change to
# this text changes the generated dataset.
template_string = """
You will be provided with TEXT from wikipedia. \
The TEXT will be delimited with ### characters.
### TEXT begin
{text}
### TEXT end
Output a python list of {q_num} dict objects, where each object is \
a multiple choice question whom answers should be in \
the given TEXT and that has 5 choices each and has the following format:
    'prompt': <question on the TEXT>
    'A': <question answer option>
    'B': <question answer option>
    'C': <question answer option>
    'D': <question answer option>
    'E': <question answer option>
    'answer': <answer option key label>

You should tell me which one of your proposed options is right \
by assigning the corresponding option's key label in the 'answer' field.

The question and the answer options should be challenging, \
more about statements, description.

The answer do not require the given TEXT and should be LONGER enough!
"""

In [21]:
# Wrap the raw template so LangChain can fill {text} and {q_num} on every call.
prompt_template = ChatPromptTemplate.from_template(template=template_string)

In [22]:
# Prompt + chat model pipeline; chain.run(text=..., q_num=...) returns the model's reply.
chain = LLMChain(prompt=prompt_template, llm=chat)

In [23]:
# In debug mode keep only the first 20 documents. Ids and texts are truncated together
# so the two lists stay aligned and equally long.
if debug:
    doc_ids = doc_ids[:20]
    texts = texts[:20]

In [24]:
# Split each text into chunks of at most 3000 tokens (measured with num_tokens_from_string).
# gather_multiple_choice_question_dataset stops after the first chunk of each document that
# yields valid questions, requesting q_num questions from it.
# Adjust chunk_size if your texts have very different lengths.
# NOTE(review): a 3000-token chunk plus the prompt and max_tokens=1200 of output may not fit
# a 4k-context model -- confirm against the model's context window.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=3000,
    chunk_overlap=0,
    length_function=num_tokens_from_string,
)

In [25]:
def is_valid_question(r, expected_keys=None):
    """Return True if `r` is a well-formed multiple-choice question dict.

    A valid question is a dict whose keys are exactly `expected_keys`, whose values
    are all non-blank strings, and whose 'answer' is one of "A".."E".

    Args:
        r: One element of the list parsed from the model's reply. It can be any
            Python object, so non-dicts are rejected instead of raising.
        expected_keys: Required key set; defaults to the notebook-level
            `response_keys_set`.

    Returns:
        bool: whether `r` passes every check.
    """
    if expected_keys is None:
        expected_keys = response_keys_set
    # A parsed reply may contain strings/tuples; calling r.keys() on them would raise and
    # cause the caller to discard the whole reply instead of just this element.
    if not isinstance(r, dict):
        return False
    if set(r) != set(expected_keys):
        return False
    # Whitespace-only values are treated as empty.
    if not all(isinstance(v, str) and v.strip() for v in r.values()):
        return False
    return r.get('answer') in ("A", "B", "C", "D", "E")

In [26]:
def select_correct_formatted_questions(r):
    """Keep only the well-formed question dicts from a parsed model reply.

    Args:
        r: Iterable of candidate questions (the list parsed from the model's reply).

    Returns:
        A new list with the elements of `r` that pass `is_valid_question`, in their
        original order.
    """
    return [candidate for candidate in r if is_valid_question(candidate)]

In [27]:
def _parse_mcq_response(response):
    """Parse one model reply into the list of well-formed question dicts it contains.

    Raises:
        Exception (e.g. ValueError, SyntaxError): if the reply is not a Python literal,
            is not a non-empty list, or contains no well-formed question.
    """
    import ast
    # literal_eval only evaluates literals, so model output can never execute code
    # (eval() would run whatever expression the model happened to emit).
    mcq = ast.literal_eval(response.strip())
    if not isinstance(mcq, list) or len(mcq) == 0:
        raise ValueError("model reply is not a non-empty list")
    questions = select_correct_formatted_questions(mcq)
    if len(questions) == 0:
        raise ValueError("model reply contains no well-formed question")
    return questions


def _generate_questions_for_chunk(llm_chain, text_chunk, q_num, max_attempts):
    """Ask the model for questions about one chunk, retrying on any failure.

    Makes up to max(1, max_attempts) calls and returns the first successfully parsed
    list of questions; re-raises the last error once every attempt has failed.
    """
    attempts = max(1, max_attempts)
    for attempt in range(attempts):
        try:
            return _parse_mcq_response(llm_chain.run(text=text_chunk, q_num=q_num))
        except Exception:
            # API errors and malformed replies are both retried until attempts run out.
            if attempt == attempts - 1:
                raise


def gather_multiple_choice_question_dataset(doc_ids, texts, text_splitter, q_num=1, max_attempts=1,
                                            llm_chain=None, chunks_per_doc=1):
    """Generate multiple-choice questions for each document with the LLM chain.

    Each text is split with `text_splitter`; chunks are sent to the model one at a time
    until `chunks_per_doc` of them have produced valid questions (by default, only the
    first chunk that succeeds). Every returned question is tagged with its `doc_id`.

    Generation is best-effort: a chunk whose attempts all fail is skipped, and one
    warning summarising the skipped chunks is emitted at the end rather than the
    failures disappearing silently.

    Args:
        doc_ids: Identifier per document, aligned with `texts`.
        texts: Document texts.
        text_splitter: Splitter exposing `split_text(str) -> list[str]`.
        q_num: Number of questions requested per chunk (filled into the prompt).
        max_attempts: Model calls per chunk before giving up (at least one is made).
        llm_chain: Object exposing `run(text=..., q_num=...) -> str`; defaults to the
            notebook-level `chain`.
        chunks_per_doc: How many successful chunks to use per document.

    Returns:
        List of question dicts (keys prompt, A-E, answer, plus doc_id).
    """
    import warnings

    if llm_chain is None:
        llm_chain = chain
    multiple_choice_questions = []
    failed_chunks = 0
    last_error = None
    for doc_id, text in tqdm(zip(doc_ids, texts), total=len(texts)):
        successful_chunks = 0
        # split_text yields plain strings. create_documents yields Document objects, and
        # formatting a Document into the prompt inserts its repr rather than the bare text.
        for text_chunk in text_splitter.split_text(text):
            if successful_chunks >= chunks_per_doc:
                break
            try:
                questions = _generate_questions_for_chunk(llm_chain, text_chunk, q_num, max_attempts)
            except Exception as exc:
                failed_chunks += 1
                last_error = exc
                continue
            for question in questions:
                question['doc_id'] = doc_id
            multiple_choice_questions.extend(questions)
            successful_chunks += 1

    if failed_chunks:
        warnings.warn(f"{failed_chunks} chunk(s) produced no valid questions; last error: {last_error!r}")
    return multiple_choice_questions

In [28]:
# Request one question per document; the model sometimes returns more than one per reply.
multi_choice_questions = gather_multiple_choice_question_dataset(
    doc_ids=doc_ids,
    texts=texts,
    text_splitter=text_splitter,
    q_num=1,
    max_attempts=1,
)

  0%|          | 0/3000 [00:00<?, ?it/s]

In [29]:
# One row per generated question; columns follow the question dict keys.
q_df = pd.DataFrame(data=multi_choice_questions)

In [30]:
q_df.head()

Unnamed: 0,prompt,A,B,C,D,E,answer,doc_id
0,What is the log sum inequality used for?,Proving theorems in mathematics,Analyzing data in information theory,Solving optimization problems in computer science,Predicting outcomes in statistical analysis,Exploring quantum mechanics principles,B,1103246
1,Which field of study benefits from the log sum...,Economics,Computer science,Biology,Physics,Psychology,B,1103246
2,How is the log sum inequality used in informat...,To calculate the entropy of a random variable,To analyze the capacity of a communication cha...,To measure the amount of information in a message,To quantify the error in a data transmission,To estimate the probability of an event,B,1103246
3,What does the log sum inequality help prove in...,The existence of error-correcting codes,The optimality of certain coding schemes,The convergence of iterative decoding algorithms,The speed of data transmission in wireless net...,The security of cryptographic protocols,B,1103246
4,In what context is the log sum inequality comm...,Signal processing,Machine learning,Game theory,Image recognition,Database management,A,1103246


In [31]:
q_df.shape

(5466, 8)

In [32]:
# Persist the generated questions without the positional index column.
q_df.to_csv(path_or_buf="dataset_stem.csv", index=False)