<a href="https://colab.research.google.com/github/Karthick47v2/question-generator/blob/main/data_extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Download dataset

In [None]:
# SQuAD dataset
!wget https://data.deepai.org/squad1.1.zip
!unzip squad1.1.zip

# SciQ dataset
!wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip
!unzip SciQ.zip

### Install third party libraries

In [None]:
!pip3 install transformers==4.1.1
!pip3 install tokenizers==0.9.4
!pip3 install sentencepiece==0.1.94

### Import libraries

In [3]:
import json
import pandas as pd
from transformers import T5Tokenizer
import seaborn as sns
import matplotlib.pyplot as plt

### Extract data from json

In [4]:
def parse_json(filepath):
  """Load a JSON file from storage.

  Args:
    filepath (str): Path of the JSON file.

  Returns:
    dict or list: Parsed JSON content (a dict for SQuAD, a list of dicts for SciQ).
  """
  with open(filepath, encoding='utf-8') as file:
    return json.load(file)

***SQuAD***

- The SQuAD dataset contains no null values, so no null check is needed.
- We only want to generate questions from short answers, so answers with more than 5 words are filtered out.

In [5]:
def extract_from_squad(data):
  """Extract (source, target) pairs from the SQuAD dataset.

  Args:
    data (dict): Parsed SQuAD JSON with a top-level 'data' key.

  Returns:
    tuple(list(str), list(str)): Model inputs ("context: ... answer: ...") and target questions.
  """
  source = []
  target = []

  for topic in data['data']:
    for dict_set in topic['paragraphs']:
      for qna_set in dict_set['qas']:
        if is_short_answer(qna_set['answers'][0]['text'], 5):
          source.append(f"context: {dict_set['context']} answer: {qna_set['answers'][0]['text']}")
          target.append(qna_set['question'])

  return source, target
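To make the nesting that `extract_from_squad` walks more concrete, here is a minimal sketch on a made-up toy record (not real SQuAD data; `answer_start` offsets are omitted). It repeats the same filtering rule inline so it runs on its own: only the first answer is used, and answers longer than 5 words are dropped.

```python
# Made-up toy record mimicking the nesting of train-v1.1.json:
# data -> topics -> paragraphs -> qas -> answers (answer_start omitted).
toy_squad = {
  "data": [
    {
      "title": "Super_Bowl_50",
      "paragraphs": [
        {
          "context": "Super Bowl 50 was an American football game played in Santa Clara, California.",
          "qas": [
            {"question": "Where was Super Bowl 50 played?",
             "answers": [{"text": "Santa Clara, California"}]},
            {"question": "What was Super Bowl 50?",
             "answers": [{"text": "an American football game played in Santa Clara"}]},
          ],
        }
      ],
    }
  ]
}

source, target = [], []
for topic in toy_squad["data"]:
  for paragraph in topic["paragraphs"]:
    for qa in paragraph["qas"]:
      answer = qa["answers"][0]["text"]  # only the first answer is used
      if len(answer.split()) <= 5:       # same rule as is_short_answer(answer, 5)
        source.append(f"context: {paragraph['context']} answer: {answer}")
        target.append(qa["question"])

print(source)
print(target)
```

The second question is dropped because its answer has 8 words.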

***SciQ***

- The SciQ dataset contains an empty `support` string for some entries (as noted in the dataset's readme.txt), so those entries are filtered out.
- As with SQuAD, we only want short answers, so answers with more than 5 words are filtered out.

In [6]:
def extract_from_sciq(data):
  """Extract data from SciQ dataset.

  Args:
    data (list(dict(obj))): List of nested dictionaries.

  Returns:
    tuple(list(str), list(str)): tuple of lists of model input and output. 
  """
  source = []
  target = []

  for dict_set in data:
    if dict_set['support'] == "":
      continue
    if is_short_answer(dict_set['correct_answer'], 5):
      source.append(f"context: {dict_set['support']} answer: {dict_set['correct_answer']}")
      target.append(dict_set['question'])

  return source, target
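The two SciQ filters can be seen on a few made-up records shaped like SciQ entries (the `distractor*` fields are omitted). This is a self-contained sketch that inlines the same checks as `extract_from_sciq`, not real dataset content.

```python
# Made-up toy records mimicking SciQ's JSON layout (distractor fields omitted).
toy_sciq = [
  {"question": "Which gas do plants absorb during photosynthesis?",
   "correct_answer": "carbon dioxide",
   "support": "Plants take in carbon dioxide and release oxygen during photosynthesis."},
  {"question": "Which organelle produces most of a cell's ATP?",
   "correct_answer": "mitochondria",
   "support": ""},  # empty support: skipped
  {"question": "Why do leaves change color in autumn?",
   "correct_answer": "chlorophyll breaks down and other pigments show",
   "support": "In autumn, chlorophyll breaks down and other pigments show."},  # 7-word answer: skipped
]

source, target = [], []
for record in toy_sciq:
  if record["support"] == "":
    continue  # no context to generate a question from
  if len(record["correct_answer"].split()) <= 5:
    source.append(f"context: {record['support']} answer: {record['correct_answer']}")
    target.append(record["question"])

print(source)
print(target)
```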

In [7]:
def is_short_answer(ans, threshold):
  """Return True if `ans` has at most `threshold` whitespace-separated words."""
  return len(ans.split()) <= threshold
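`is_short_answer` counts words with `str.split()`, which has a few edge cases worth knowing: punctuation stays attached to words, runs of whitespace are collapsed, and the threshold is inclusive. The sketch below checks those cases with a hypothetical helper `word_count` that mirrors the same expression.

```python
def word_count(text):
  # str.split() with no argument splits on runs of any whitespace
  # and ignores leading/trailing whitespace.
  return len(text.split())

THRESHOLD = 5  # same threshold the notebook passes to is_short_answer

cases = [
  "Santa Clara, California",        # 3 words: the comma stays attached
  "  Denver   Broncos \n",          # 2 words: extra whitespace ignored
  "one two three four five",        # exactly 5 words: kept (<= is inclusive)
  "one two three four five six",    # 6 words: filtered out
]
for text in cases:
  print(repr(text), word_count(text), word_count(text) <= THRESHOLD)
```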

In [8]:
data = parse_json('train-v1.1.json')
squad_source_text, squad_target_text = extract_from_squad(data)

sciq_source_text = []
sciq_target_text = []

for filename in ['train', 'test', 'valid']:
  data = parse_json(f"SciQ dataset-2 3/{filename}.json")
  source, target = extract_from_sciq(data)

  sciq_source_text.extend(source)
  sciq_target_text.extend(target)

***SQuAD***
- Total question-answer pairs: 87,599
- After filtering: 76,135


***SciQ***
- Total question-answer pairs: 13,679
- After filtering: 12,214

### Data visualization and reduction

Drop duplicate questions (rows with identical `target_text`), keeping the first occurrence.

In [9]:
squad_df = pd.DataFrame({'source_text': squad_source_text, 'target_text': squad_target_text})
sciq_df = pd.DataFrame({'source_text': sciq_source_text, 'target_text': sciq_target_text})

squad_df.drop_duplicates(subset=['target_text'], ignore_index=True, inplace=True)
sciq_df.drop_duplicates(subset=['target_text'], ignore_index=True, inplace=True)

***SQuAD***
- Before: 76,135
- After filtering out duplicates: 75,937


***SciQ***
- Before: 12,214
- After filtering out duplicates: 12,133
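`drop_duplicates(subset=['target_text'])` keeps the first row for each question string (pandas' default `keep='first'`), even if a later duplicate has a different context. A plain-Python sketch of the same rule, using a hypothetical helper and made-up pairs:

```python
def drop_duplicate_questions(pairs):
  """Keep the first (source, target) pair for each distinct target string."""
  seen = set()
  kept = []
  for source, target in pairs:
    if target in seen:
      continue  # exact string match only: "What is x ?" would count as different
    seen.add(target)
    kept.append((source, target))
  return kept

pairs = [
  ("context: A answer: x", "What is x?"),
  ("context: B answer: y", "What is x?"),  # same question, different context: dropped
  ("context: C answer: z", "What is z?"),
]
print(drop_duplicate_questions(pairs))
```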

Filter out examples whose token count exceeds the model's maximum token length.

In [None]:
t5_tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [11]:
def plot_token_count(df, dataset):
  """Plot histograms of T5 token counts for 'source_text' and 'target_text'.

  Args:
    df (DataFrame): Dataset to plot.
    dataset (str): Dataset name (used in the plot titles).
  """
  # T5Tokenizer.encode appends the end-of-sequence token, so the counts include it.
  source_token_count = [len(t5_tokenizer.encode(text)) for text in df['source_text']]
  target_token_count = [len(t5_tokenizer.encode(text)) for text in df['target_text']]

  _, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

  sns.histplot(source_token_count, ax=ax1).set(title=f"{dataset}-'source_text'-tokens")
  sns.histplot(target_token_count, ax=ax2).set(title=f"{dataset}-'target_text'-tokens")
  # Render now so each plot appears right after its matching print() output.
  plt.show()

In [12]:
def filter_out_tokens(df, filter_by, max_token_len):
  """Drop rows whose `filter_by` text exceeds the maximum token length.

  Args:
    df (DataFrame): Dataset that needs to be processed.
    filter_by (str): Name of the column whose token length is checked.
    max_token_len (int): Maximum token length (512 for the model input).

  Returns:
    DataFrame: Rows of `df` whose `filter_by` text has at most `max_token_len` tokens.
  """
  # Build a boolean mask instead of adding a temporary column to the caller's DataFrame.
  within_limit = df[filter_by].apply(lambda text: len(t5_tokenizer.encode(text)) <= max_token_len)
  return df[within_limit]
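The filtering above is a "compute a boolean mask, then select" pattern. The sketch below shows the same pattern on a plain list so it runs without `transformers`; the token counter `whitespace_token_count` is a hypothetical stand-in for `len(t5_tokenizer.encode(text))` (one token per word plus one for the end-of-sequence token), not the real T5 SentencePiece tokenization.

```python
from typing import Callable, List, Sequence


def whitespace_token_count(text: str) -> int:
  # Hypothetical stand-in for len(t5_tokenizer.encode(text)):
  # one token per whitespace-separated word, plus one end-of-sequence token.
  return len(text.split()) + 1


def filter_by_token_len(texts: Sequence[str], max_token_len: int,
                        count_tokens: Callable[[str], int] = whitespace_token_count) -> List[str]:
  # Build the mask first, then select; the input sequence is not modified.
  mask = [count_tokens(text) <= max_token_len for text in texts]
  return [text for text, keep in zip(texts, mask) if keep]


texts = ["a b c", "a b c d e f g", "a"]
print(filter_by_token_len(texts, max_token_len=4))
```

Passing the real tokenizer instead would only require `count_tokens=lambda t: len(t5_tokenizer.encode(t))`.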

In [13]:
def filter_df_by_token_len(df, dataset, filter_by, max_token_len):
  """Filter a dataset by token length and plot token counts before and after.

  Args:
    df (DataFrame): Dataset that needs to be processed.
    dataset (str): Dataset name.
    filter_by (str): Name of the column whose token length is checked.
    max_token_len (int): Maximum token length (512 for the model input).

  Returns:
    DataFrame: Dataset filtered by the `filter_by` column.
  """
  print(f"Filtering by {filter_by}")
  print('Before filtering...')
  plot_token_count(df, dataset)
  df = filter_out_tokens(df, filter_by, max_token_len)
  print('After filtering...')
  plot_token_count(df, dataset)
  return df

***SciQ***

In [None]:
sciq_df = filter_df_by_token_len(sciq_df, 'SciQ', 'source_text', 512)

Most `target_text` token lengths fall between 0 and the 60s, so filter out the outliers. This cutoff is also used as the maximum output token length in model training.

Use a cutoff of 72 tokens.


In [None]:
sciq_df = filter_df_by_token_len(sciq_df, 'SciQ', 'target_text', 72)
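The 72-token cutoff was read off the histogram. One way to make such a choice repeatable (an alternative, not the method used here) is to report what fraction of examples a cutoff keeps, or to pick the smallest cutoff that keeps a target fraction. A sketch on made-up token counts, with hypothetical helper names:

```python
import math


def coverage(lengths, cutoff):
  """Fraction of examples whose token length is <= cutoff."""
  return sum(1 for n in lengths if n <= cutoff) / len(lengths)


def smallest_cutoff(lengths, target_fraction):
  """Smallest cutoff that keeps at least target_fraction of the examples."""
  ordered = sorted(lengths)
  keep = math.ceil(target_fraction * len(ordered))  # number of examples to keep
  return ordered[max(keep, 1) - 1]


lengths = [12, 15, 18, 20, 22, 25, 30, 41, 58, 95]  # made-up token counts
print(smallest_cutoff(lengths, 0.9))
print(coverage(lengths, 72))
```

With real data, `lengths` would be the `target_token_count` list computed in `plot_token_count`.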

***SQuAD***

In [None]:
squad_df = filter_df_by_token_len(squad_df, 'SQuAD', 'source_text', 512)

Most `target_text` token lengths fall between 0 and the 40s, so filter out the outliers. This cutoff is also used as the maximum output token length in model training.

Use a cutoff of 48 tokens.

> The two datasets are not combined for training: each one trains a separate model for a different purpose, so the difference in maximum output length does not matter.

In [None]:
squad_df = filter_df_by_token_len(squad_df, 'SQuAD', 'target_text', 48)

In [None]:
sciq_df.shape, squad_df.shape

***SQuAD***
- Before: 75,937
- After filtering by token length: 75,699


***SciQ***
- Before: 12,133
- After filtering by token length: 11,970
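Putting the counts reported at each stage together shows how much of each dataset survives the whole pipeline: roughly 86.4% of SQuAD and 87.5% of SciQ. A small sketch that recomputes those percentages from the numbers above (the stage labels are descriptive, not names used in the notebook):

```python
# Counts reported above for each filtering stage.
PIPELINE_COUNTS = {
  "SQuAD": [("total", 87_599), ("short answers", 76_135),
            ("deduplicated", 75_937), ("token-length filter", 75_699)],
  "SciQ": [("total", 13_679), ("non-empty support, short answers", 12_214),
           ("deduplicated", 12_133), ("token-length filter", 11_970)],
}


def retention(counts):
  """Return (stage, count, fraction of the original total) for each stage."""
  total = counts[0][1]
  return [(stage, n, n / total) for stage, n in counts]


for name, counts in PIPELINE_COUNTS.items():
  for stage, n, fraction in retention(counts):
    print(f"{name:<6}{stage:<34}{n:>7,}{fraction:>8.1%}")
```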

### Export as *.csv and upload to Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
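sciq_df.to_csv('SciQ-processed.csv', index=False)
squad_df.to_csv('SQuAD-processed.csv', index=False)

# Create the target folder if it does not exist yet, then move both files into it.
!mkdir -p gdrive/MyDrive/mcq-gen
!mv SciQ-processed.csv SQuAD-processed.csv gdrive/MyDrive/mcq-gen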
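Contexts often contain commas, double quotes, and occasionally line breaks, so it is worth confirming that the CSV round-trips cleanly. `DataFrame.to_csv` uses minimal quoting by default, like Python's `csv` module; in the notebook the check would be reading the files back with `pd.read_csv` and comparing. The sketch below shows the same round-trip with the standard-library `csv` module on made-up rows, assuming only the two columns exported above.

```python
import csv
import io

# Made-up rows with characters that require quoting in CSV.
rows = [
  {"source_text": 'context: He said "yes", then left. answer: yes', "target_text": "What did he say?"},
  {"source_text": "context: line one\nline two answer: two", "target_text": "Which line?"},
]

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["source_text", "target_text"])
writer.writeheader()
writer.writerows(rows)  # fields with commas, quotes, or newlines get quoted

buffer.seek(0)
restored = list(csv.DictReader(buffer))
print(restored == rows)
```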