<a href="https://colab.research.google.com/github/Karthick47v2/question-generator/blob/main/data_extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Download datasets

In [None]:
# SQuAD dataset
!wget https://data.deepai.org/squad1.1.zip
!unzip squad1.1.zip

# SciQ dataset
!wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip
!unzip SciQ.zip

### Install third party libraries

In [None]:
!pip3 install transformers==4.1.1
!pip3 install tokenizers==0.9.4
!pip3 install sentencepiece==0.1.94

### Import libraries

In [3]:
import json
import pandas as pd
from transformers import T5Tokenizer
import seaborn as sns
import matplotlib.pyplot as plt

### Extract data from JSON

In [4]:
def parse_json(filepath):
  """Load a JSON file from storage.

  Args:
    filepath (str): Path of the JSON file.

  Returns:
    dict | list: Parsed JSON content (a dict for SQuAD, a list of dicts for SciQ).
  """
  with open(filepath, encoding='utf-8') as file:
    return json.load(file)

***SQuAD***

- The SQuAD dataset doesn't contain null values, so no null check is needed.
- We only want to generate questions for short answers, so answers longer than 5 words are filtered out.
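For reference, the SQuAD v1.1 JSON nests each question under `data` → `paragraphs` → `qas`, with the answer text under `answers`. The sketch below walks a tiny hand-made record with that layout and applies the same 5-word filter. The sample content and the helper name `flatten_squad` are hypothetical, for illustration only.

```python
# Hypothetical toy record mimicking the SQuAD v1.1 nesting (not real data).
toy_squad = {
    "data": [
        {
            "title": "Example",
            "paragraphs": [
                {
                    "context": "Paris, on the Seine, is the capital and most populous city of France.",
                    "qas": [
                        {"question": "What is the capital of France?",
                         "answers": [{"text": "Paris", "answer_start": 0}]},
                        {"question": "How is Paris described?",
                         "answers": [{"text": "the capital and most populous city of France",
                                      "answer_start": 24}]},
                    ],
                }
            ],
        }
    ]
}


def flatten_squad(data, max_answer_words=5):
  """Hypothetical helper: yield (context, question, answer) triples,
  keeping only the first answer when it has at most `max_answer_words` words."""
  for topic in data["data"]:
    for paragraph in topic["paragraphs"]:
      for qa in paragraph["qas"]:
        answer = qa["answers"][0]["text"]
        if len(answer.split()) <= max_answer_words:
          yield paragraph["context"], qa["question"], answer


rows = list(flatten_squad(toy_squad))
print(rows)  # only the "Paris" question survives; the 8-word answer is dropped
```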

In [5]:
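def extract_from_squad(data):
  """Extract (context, question, answer) triples from the SQuAD dataset.

  Only the first answer of each question is used, and questions whose answer
  is longer than 5 words are skipped.

  Args:
    data (dict): Parsed SQuAD JSON (top-level key 'data').

  Returns:
    tuple(list(str), list(str), list(str)): Prefixed contexts, questions and answers.
  """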
  contexts = []
  questions = []
  answers = []

  for topic in data['data']:
    for dict_set in topic['paragraphs']:
      for qna_set in dict_set['qas']:
        if is_short_answer(qna_set['answers'][0]['text'], 5):
          contexts.append(f"context: {dict_set['context']}")
          questions.append(f"question: {qna_set['question']}")
          answers.append(f"answer: {qna_set['answers'][0]['text']}")

  return contexts, questions, answers

***SciQ***

- The SciQ dataset contains an empty string in the `support` field for some entries (as noted in the dataset's readme.txt); those entries are filtered out.
- We only want to generate questions for short answers, so answers longer than 5 words are filtered out.
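Unlike SQuAD, SciQ is a flat list of records, each with its own `question`, `correct_answer` and `support` text (plus `distractor1`-`distractor3`, which this notebook ignores). A minimal sketch of the two filters on hypothetical toy records:

```python
# Hypothetical toy records mimicking SciQ's flat JSON layout (not real data).
toy_sciq = [
    {"question": "What gas do plants absorb for photosynthesis?",
     "distractor1": "oxygen", "distractor2": "nitrogen", "distractor3": "helium",
     "correct_answer": "carbon dioxide",
     "support": "Plants take in carbon dioxide and release oxygen."},
    {"question": "What is the boiling point of water at sea level?",
     "distractor1": "50 degrees Celsius", "distractor2": "0 degrees Celsius",
     "distractor3": "200 degrees Celsius",
     "correct_answer": "100 degrees Celsius",
     "support": ""},  # empty support -> dropped
    {"question": "What is mitosis?",
     "distractor1": "osmosis", "distractor2": "diffusion", "distractor3": "meiosis",
     "correct_answer": "the process by which cells divide into two identical cells",
     "support": "Mitosis produces two identical daughter cells."},  # 10-word answer -> dropped
]

# Keep records with non-empty support and an answer of at most 5 words
kept = [r for r in toy_sciq
        if r["support"] != "" and len(r["correct_answer"].split()) <= 5]
print([r["correct_answer"] for r in kept])  # ['carbon dioxide']
```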

In [6]:
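def extract_from_sciq(data):
  """Extract (context, question, answer) triples from the SciQ dataset.

  Entries with an empty `support` field, or whose answer is longer than
  5 words, are skipped.

  Args:
    data (list(dict)): Parsed SciQ JSON (a list of records).

  Returns:
    tuple(list(str), list(str), list(str)): Prefixed contexts, questions and answers.
  """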
  contexts = []
  questions = []
  answers = []

  for dict_set in data:
    if dict_set['support'] == "":
      continue
    if is_short_answer(dict_set['correct_answer'], 5):
      contexts.append(f"context: {dict_set['support']}")
      questions.append(f"question: {dict_set['question']}")
      answers.append(f"answer: {dict_set['correct_answer']}")

  return contexts, questions, answers

In [7]:
def is_short_answer(ans, threshold):
  """Return True if `ans` has at most `threshold` whitespace-separated words."""
  return len(ans.split()) <= threshold
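`is_short_answer` counts words with `str.split()` and no arguments, which splits on runs of any whitespace and ignores leading/trailing spaces; the threshold is inclusive. Note that this counts words, not T5 tokens. A few quick checks illustrate the boundary (the function is restated so the cell runs on its own):

```python
def is_short_answer(ans, threshold):
  # str.split() with no arguments splits on runs of whitespace
  # and discards leading/trailing whitespace.
  return len(ans.split()) <= threshold


print(is_short_answer("  New   York  ", 2))               # True: 2 words
print(is_short_answer("one two three four five", 5))      # True: exactly 5 words
print(is_short_answer("one two three four five six", 5))  # False: 6 words
print(len("New York-based".split()))                      # 2: a hyphenated word counts once
```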

In [9]:
data = parse_json('train-v1.1.json')
squad_contexts, squad_questions, squad_answers = extract_from_squad(data)

sciq_contexts = []
sciq_questions = []
sciq_answers = []

for filename in ['train', 'test', 'valid']:
  data = parse_json(f"SciQ dataset-2 3/{filename}.json")
  contexts, questions, answers = extract_from_sciq(data)

  sciq_contexts.extend(contexts)
  sciq_questions.extend(questions)
  sciq_answers.extend(answers)

***SQuAD***
- Total QA pairs: 87,599
- After filtering out long answers: 76,135


***SciQ***
- Total QA pairs: 13,679
- After filtering out empty `support` and long answers: 12,214

### Data visualization and reduction

Filter out any duplicate questions.
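`drop_duplicates(subset=['question'])` keeps the first row for each distinct question string (the default `keep='first'`) and, with `ignore_index=True`, renumbers the result from 0. The match is exact, so questions that differ only in case or spacing are both kept. A toy illustration with hypothetical rows:

```python
import pandas as pd

# Toy rows (hypothetical): the second and third share the same question text.
df = pd.DataFrame({
    "context": ["c1", "c2", "c3"],
    "question": ["question: Who?", "question: What?", "question: What?"],
    "answer": ["answer: A", "answer: B", "answer: C"],
})

# First occurrence wins; index is reset to 0..n-1
deduped = df.drop_duplicates(subset=["question"], ignore_index=True)
print(deduped)
```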

In [10]:
squad_df = pd.DataFrame({'context': squad_contexts, 'question': squad_questions,
                         'answer': squad_answers})
sciq_df = pd.DataFrame({'context': sciq_contexts, 'question': sciq_questions, 
                        'answer': sciq_answers})

squad_df.drop_duplicates(subset=['question'], ignore_index=True, inplace=True)
sciq_df.drop_duplicates(subset=['question'], ignore_index=True, inplace=True)

***SQuAD***
- Before: 76,135
- After filtering out duplicates: 75,937


***SciQ***
- Before: 12,214
- After filtering out duplicates: 12,133

Filter out samples whose token count exceeds the model's maximum input length.
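The filtering below boils down to computing a per-row token count and keeping rows at or under the limit. This sketch shows that boolean-mask pattern without downloading T5: `count_tokens` is a hypothetical whitespace counter standing in for `len(t5_tokenizer.encode(...))`, so its numbers will not match real T5 (SentencePiece) token counts, and `filter_by_token_len` is likewise a hypothetical name.

```python
import pandas as pd


def count_tokens(text, text_pair=None):
  # Hypothetical stand-in tokenizer: whitespace words, plus one end-of-sequence
  # marker per segment (loosely mirroring how T5 appends </s> to each sequence).
  n = len(text.split()) + 1
  if text_pair is not None:
    n += len(text_pair.split()) + 1
  return n


def filter_by_token_len(df, max_token_len, text, text_pair=None):
  """Keep rows whose (pair) token count is <= max_token_len."""
  if text_pair is None:
    counts = df[text].map(count_tokens)
  else:
    counts = pd.Series(
        [count_tokens(t, p) for t, p in zip(df[text], df[text_pair])],
        index=df.index)
  return df[counts <= max_token_len]


df = pd.DataFrame({
    "context": ["short context", "a much longer context with many words in it"],
    "answer": ["yes", "no"],
})
result = filter_by_token_len(df, 6, "context", "answer")
print(result)  # only the first row (5 "tokens") fits under the limit of 6
```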

In [None]:
t5_tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [None]:
# Sample: encode one context/answer pair, padded to 512 tokens, and decode it
# back to inspect the exact model input (including </s> and <pad> tokens)
input_ids = t5_tokenizer.encode(sciq_df.loc[0, 'context'],
                                sciq_df.loc[0, 'answer'], max_length=512,
                                padding='max_length', truncation='only_first',
                                add_special_tokens=True)
t5_tokenizer.decode(input_ids, skip_special_tokens=False,
                    clean_up_tokenization_spaces=False)

In [13]:
def plot_token_count(df, dataset):
  """Plot histograms of source and target token counts.

  Args:
    df (DataFrame): Dataset to plot.
    dataset (str): Dataset name (used in the plot titles).
  """
  source_token_count = []
  target_token_count = []

  for _, row in df.iterrows():
    # Source = context + answer pair; target = question
    source_token_count.append(get_token_len(row['context'],
                                            text_pair=row['answer']))
    target_token_count.append(get_token_len(row['question']))

  fig, (ax1, ax2) = plt.subplots(1,2, figsize=(20,10))
  
  sns.histplot(source_token_count, ax=ax1).set(title=f"{dataset}-source-tokens")
  sns.histplot(target_token_count, ax=ax2).set(title=f"{dataset}-target-tokens")

In [14]:
def get_token_len(text, text_pair=None):
  """Return the number of T5 tokens for a sequence or sequence pair.

  Args:
    text (str): First input sequence.
    text_pair (str): Second input sequence. Defaults to None.

  Returns:
    (int): Number of tokens, including special tokens.
  """
  return len(t5_tokenizer.encode(text, text_pair))

In [15]:
def filter_out_tokens(df, max_token_len, text, text_pair=None):
  """Drop rows whose token length exceeds `max_token_len`.

  Args:
    df (DataFrame): Dataset that needs to be processed.
    max_token_len (int): Maximum token length (512 for the model input).
    text (str): Column name of the 1st input sequence.
    text_pair (str): Column name of the 2nd input sequence. Defaults to None.

  Returns:
    (DataFrame): Rows of `df` whose token length is within the limit.
  """
  # Boolean mask computed without adding a temporary column to `df`
  if text_pair is None:
    within_limit = df.apply(
        lambda row: get_token_len(row[text]) <= max_token_len, axis=1)
  else:
    within_limit = df.apply(
        lambda row: get_token_len(row[text], row[text_pair]) <= max_token_len,
        axis=1)
  return df[within_limit]

In [16]:
def filter_df_by_token_len(df, dataset, max_token_len, text, text_pair=None):
  """Filter a dataset by token length, plotting token counts before and after.

  Args:
    df (DataFrame): Dataset that needs to be processed.
    dataset (str): Dataset name.
    max_token_len (int): Maximum token length (512 for the model input).
    text (str): Column name of the 1st input sequence.
    text_pair (str): Column name of the 2nd input sequence. Defaults to None.

  Returns:
    (DataFrame): Filtered dataset.
  """
  print(f"filter by {text}")
  print('Before filtering...')
  plot_token_count(df, dataset)
  df = filter_out_tokens(df, max_token_len, text, text_pair)
  print('After filtering...')
  plot_token_count(df, dataset)
  return df

***SciQ***

In [None]:
sciq_df = filter_df_by_token_len(sciq_df, 'SciQ', 512, 'context', 'answer')

Since most target (question) token lengths fall between 0 and the 60s, filter out the outliers. This cutoff is also used as the maximum output token length during model training.

We filter at 72 tokens.
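The 72-token cutoff was read off the histogram. As an alternative (not what this notebook does), a cutoff could be derived from a high percentile of the target token counts. The sketch below applies NumPy's `percentile` to a hypothetical list of counts and rounds up to a multiple of 8; both the 95th percentile and the rounding step are arbitrary choices for illustration.

```python
import numpy as np

# Hypothetical target (question) token counts; in the notebook these would come
# from get_token_len(row['question']) for each row.
target_token_counts = [12, 15, 18, 20, 22, 25, 30, 41, 55, 90]

p95 = np.percentile(target_token_counts, 95)  # linear interpolation by default
cutoff = int(np.ceil(p95 / 8) * 8)            # round up to a multiple of 8
kept = [n for n in target_token_counts if n <= cutoff]
print(cutoff, len(kept))  # 80 9: only the 90-token outlier is removed
```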


In [None]:
sciq_df = filter_df_by_token_len(sciq_df, 'SciQ', 72, 'question')

***SQuAD***

In [None]:
squad_df = filter_df_by_token_len(squad_df, 'SQuAD', 512, 'context', 'answer')

Since most target (question) token lengths fall between 0 and the 40s, filter out the outliers. This cutoff is also used as the maximum output token length during model training.

We filter at 48 tokens.

> We aren't combining the two datasets for training. Each one is used to train a separate model for a different purpose, so the difference in maximum output length doesn't matter.

In [None]:
squad_df = filter_df_by_token_len(squad_df, 'SQuAD', 48, 'question')

***SQuAD***
- Before: 75,937
- After filtering out samples that exceed the token limits: 75,683


***SciQ***
- Before: 12,133
- After filtering out samples that exceed the token limits: 11,964

### Export as CSV and copy to Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [23]:
sciq_df.to_csv('SciQ-processed.csv', index=False)
squad_df.to_csv('SQuAD-processed.csv', index=False)

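# Create the target folder if it doesn't exist, then move both files into it
!mkdir -p gdrive/MyDrive/mcq-gen
!mv SciQ-processed.csv SQuAD-processed.csv gdrive/MyDrive/mcq-gen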