<a href="https://colab.research.google.com/github/giuliofortini/NLP_SQuAD_Project/blob/gpt/SQUAD_question_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SQUAD Question Generation

## Setup

In [None]:
import json
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

from google.colab import drive
drive.mount('/content/drive')

RANDOM_STATE = 42

!pip install transformers
!nvidia-smi

import tensorflow as tf
import transformers
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead, AutoTokenizer
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import pipeline

import time


# Base checkpoint: provides the tokenizer and the initial weights that are fine-tuned below.
PRETRAINED = "gpt2"

# NOTE(review): AutoModelWithLMHead appears to be deprecated in recent transformers
# releases in favour of AutoModelForCausalLM for GPT-2 style models — confirm against
# the installed version (the pip install above is not pinned).
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
model = AutoModelWithLMHead.from_pretrained(PRETRAINED)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Fri Feb 12 15:48:14 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                       



In [None]:
# Name of the training output directory and of the model folder on Drive.
MODEL_NAME = "GTP2_SQUAD_QGEN_3"
# Number of training epochs (only used when training is enabled).
EPOCHS = 3
# Maximum number of rows taken from each split when writing the sample files; None keeps all rows.
DATASET_LIMIT = None
# True: skip training and the Drive copy, and load the fine-tuned model from Drive instead.
LOAD_FROM_DRIVE = True
N_CONTEXT_GEN = None           # number of contexts to use. Choose None to take all the contexts in the test set
QUESTIONS_PER_CONTEXT = 3     # how many questions to generate for each context.

In [None]:
# Load the SQuAD training set: prefer a copy in the working directory and fall
# back to the mounted Google Drive folder only when the local file is missing.
# Catching FileNotFoundError (instead of a bare `except`) lets malformed JSON,
# permission errors and keyboard interrupts surface instead of being masked by
# a silent retry on Drive.
try:
  with open('training_set.json', encoding='utf-8') as f:
    json_data = json.load(f)
except FileNotFoundError:
  with open('/content/drive/My Drive/SQUAD/training_set.json', encoding='utf-8') as f:
    json_data = json.load(f)

## Dataset

In [None]:
# Creates DataFrames with useful columns by unpacking 'paragraphs' column
def preprocess_df(df):
  """Flatten nested SQuAD-style rows into one row per (question, answer) pair.

  Args:
    df: DataFrame with a 'paragraphs' column. Each cell is a list of dicts
      holding a 'context' string and a 'qas' list; every qa has 'id',
      'question' and an 'answers' list of dicts with 'text' and 'answer_start'.

  Returns:
    A tuple (qa_df, context_dict):
      qa_df: one row per answer, with columns question_id, question_text,
        answer_text, answer_start, answer_end, title_id, context_id.
        answer_end is answer_start + len(answer_text); title_id is the index
        label of the source row in df.
      context_dict: maps context_id (position of the context in order of
        appearance) to the context string. Contexts with no question or
        answer are still present in this dict.
  """
  columns = ['question_id', 'question_text', 'answer_text', 'answer_start', 'answer_end', 'title_id', 'context_id']
  records = []
  contexts = []

  for title_id, row in df.iterrows():
    for para in row['paragraphs']:
      # The id of a context is its position in the flat list of contexts.
      context_id = len(contexts)
      contexts.append(para['context'])
      for qa in para['qas']:
        for answer in qa['answers']:
          answer_text = answer['text']
          answer_start = answer['answer_start']
          answer_end = answer_start + len(answer_text)
          records.append([qa['id'], qa['question'], answer_text, answer_start, answer_end, title_id, context_id])

  qa_df = pd.DataFrame(records, columns=columns)
  context_dict = dict(enumerate(contexts))

  return qa_df, context_dict

# Read data from json: one row per entry of json_data['data'].
data = pd.json_normalize(json_data['data'])

# Split train / validation / test at row level: 15% of the rows are held out,
# and that held-out part is split again, with 5% of it becoming the test set
# and the rest the validation set.
train, val_test  = train_test_split(data, test_size=0.15, random_state=RANDOM_STATE)
val, test         = train_test_split(val_test, test_size=0.05, random_state=RANDOM_STATE)

# Create DataFrames with useful columns
train_df, train_context_dict = preprocess_df(train)
val_df, val_context_dict = preprocess_df(val)
test_df, test_context_dict = preprocess_df(test)

# Keep only the columns used for question generation.
QG_COLUMNS = ["context_id", "question_text", "answer_text"]
train_df = train_df[QG_COLUMNS]
val_df = val_df[QG_COLUMNS]
test_df = test_df[QG_COLUMNS]

print(f"Train samples:\t{len(train_df)}\nVal samples:\t{len(val_df)}\nTest samples:\t{len(test_df)}")
train_df.head()

Train samples:	74520
Val samples:	12342
Test samples:	737


Unnamed: 0,context_id,question_text,answer_text
0,0,What type of stimuli causes pain?,intense or damaging
1,0,What type of feeling is pain?,distressing
2,0,Why has defining pain been a challenge?,"complex, subjective phenomenon"
3,0,What organization's definition is widely used?,The International Association for the Study of...
4,0,"In medical diagnosis, what is pain considered?",a symptom


In [None]:
def create_samples(df, context_dict, name):
  """Write one "[CTX] <context> [QS] <question> [QE]" line per row of df.

  The lines are written to "<name>_samples.txt" in the working directory and
  also returned as a list.

  Args:
    df: DataFrame with 'context_id' and 'question_text' columns.
    context_dict: mapping from context_id to the context string.
    name: prefix of the output file (e.g. "train" -> "train_samples.txt").

  Returns:
    The list of written lines, each ending with a newline character.
  """
  out_path = f"{name}_samples.txt"
  samples = []
  print(f"Creating {out_path}...", end="")
  # Explicit UTF-8: contexts contain non-ASCII characters (e.g. IPA symbols),
  # so the result must not depend on the platform's default encoding.
  with open(out_path, "w", encoding="utf-8") as out_file:
    for _, row in df.iterrows():
      # Newlines are flattened in both fields so that every sample stays on
      # exactly one line of the output file.
      context = context_dict[row["context_id"]].replace("\n", " ")
      question = row["question_text"].replace("\n", " ")
      line = f"[CTX] {context} [QS] {question} [QE]\n"
      out_file.write(line)
      samples.append(line)
  # Reported only once the file has been closed.
  print("done")
  return samples

# Write train_samples.txt / val_samples.txt / test_samples.txt.
# DATASET_LIMIT caps the number of rows taken from each split (None keeps every row).
train_samples = create_samples(train_df[:DATASET_LIMIT], train_context_dict, "train")
val_samples = create_samples(val_df[:DATASET_LIMIT], val_context_dict, "val")
test_samples = create_samples(test_df[:DATASET_LIMIT], test_context_dict, "test")

Creating train.txt...done
Creating val.txt...done
Creating test.txt...done


In [None]:
# Sanity check: every in-memory training sample must carry the "[CTX]" marker
# within its first 7 characters; a failing sample is shown in the assertion error.
for sample_line in train_samples:
  assert "[CTX]" in sample_line[:7], sample_line
print("Passed")

Passed


In [None]:
def preview_samples(label, path, n_preview=3):
  """Read a samples file back, print its line count and its first lines.

  Args:
    label: split name shown in the header (e.g. "TRAIN").
    path: path of the text file to read, one sample per line.
    n_preview: how many leading samples to print.

  Returns:
    The list of lines read from the file (each keeps its trailing newline).
  """
  print(f"\n{label} samples: ", end="")
  with open(path, encoding="utf-8") as f:
    lines = f.readlines()
  print(len(lines))
  for sample in lines[:n_preview]:
    print(sample.replace("\n", ""))
  return lines


# Reload each split from disk so that the in-memory lists match the files
# that are tokenized afterwards.
train_samples = preview_samples("TRAIN", "train_samples.txt")
val_samples = preview_samples("VAL", "val_samples.txt")
test_samples = preview_samples("TEST", "test_samples.txt")



TRAIN samples: 74520
[CTX] Pain is a distressing feeling often caused by intense or damaging stimuli, such as stubbing a toe, burning a finger, putting alcohol on a cut, and bumping the "funny bone". Because it is a complex, subjective phenomenon, defining pain has been a challenge. The International Association for the Study of Pain's widely used definition states: "Pain is an unpleasant sensory and emotional experience associated with actual or potential tissue damage, or described in terms of such damage." In medical diagnosis, pain is a symptom. [QS] What type of stimuli causes pain? [QE]
[CTX] Pain is a distressing feeling often caused by intense or damaging stimuli, such as stubbing a toe, burning a finger, putting alcohol on a cut, and bumping the "funny bone". Because it is a complex, subjective phenomenon, defining pain has been a challenge. The International Association for the Study of Pain's widely used definition states: "Pain is an unpleasant sensory and emotional experi

In [None]:

def load_dataset(train_path, val_path, test_path, tokenizer, block_size=128):
    """Build a TextDataset per split plus a language-modeling data collator.

    Args:
        train_path: path of the training samples text file.
        val_path: path of the validation samples text file.
        test_path: path of the test samples text file.
        tokenizer: tokenizer handed to every TextDataset and to the collator.
        block_size: block size passed to each TextDataset (defaults to 128,
            the value previously hard-coded for all three splits).

    Returns:
        A tuple (train_dataset, val_dataset, test_dataset, data_collator).
        The collator is created with mlm=False.
    """
    splits = (("Train", train_path), ("Validation", val_path), ("Test", test_path))
    datasets = []
    for split_name, path in splits:
        print(f"Creating textdataset for {split_name}...", end="")
        datasets.append(TextDataset(
            tokenizer=tokenizer,
            file_path=path,
            block_size=block_size))
        # Printed after every split, so the last one is also reported as finished.
        print("done.")
    train_dataset, val_dataset, test_dataset = datasets

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    return train_dataset, val_dataset, test_dataset, data_collator

# Tokenize the three sample files written above; the returned collator is reused by the Trainer.
train_text,  val_text, test_text, data_collator = load_dataset("train_samples.txt", "val_samples.txt", "test_samples.txt", tokenizer)


Creating textdataset for Train...

Token indices sequence length is longer than the specified maximum sequence length for this model (13660468 > 1024). Running this sequence through the model will result in indexing errors


done.
Creating textdataset for Validation...done.
Creating textdataset for Test...

## Loading Model

In [None]:
# Location from which the text-generation pipeline loads the fine-tuned model.
# MODEL_PATH has to be defined in both modes, otherwise the print below and the
# pipeline cell raise NameError:
# - LOAD_FROM_DRIVE=True:  a model previously copied to Google Drive;
# - LOAD_FROM_DRIVE=False: the local directory MODEL_NAME, which the training
#   cell uses as output_dir and where trainer.save_model() writes the model.
if LOAD_FROM_DRIVE:
  MODEL_PATH = f"/content/drive/MyDrive/SQUAD/{MODEL_NAME}"
else:
  MODEL_PATH = MODEL_NAME

print("Model selected:", MODEL_PATH)

Model selected: /content/drive/MyDrive/SQUAD/GTP2_SQUAD_QGEN_3


## Training

In [None]:
# Fine-tune the base model only when no saved model is being loaded from Drive.
if not LOAD_FROM_DRIVE: 
  training_args = TrainingArguments(
    output_dir = MODEL_NAME, # The output directory
    overwrite_output_dir=True, # overwrite the content of the output directory
    num_train_epochs = EPOCHS, # number of training epochs
    per_device_train_batch_size = 32, # batch size for training
    per_device_eval_batch_size = 64,  # batch size for evaluation
    # NOTE(review): eval_steps presumably only takes effect when a step-based
    # evaluation strategy is also configured; none is set here — confirm that
    # evaluation on val_text actually runs during training.
    eval_steps = 5000, # Number of update steps between two evaluations.
    save_steps = 5000, # a checkpoint is saved every this many steps
    )

  #train_text, eval_text = train_test_split(train_text, test_size=0.2, random_state=RANDOM_STATE)

  trainer = Trainer(
      model=model,
      args=training_args,
      data_collator=data_collator,
      train_dataset=train_text,
      eval_dataset=val_text,
      #prediction_loss_only=True,
  )

  trainer.train()
  # No path is given, so the model is presumably written to training_args.output_dir (MODEL_NAME).
  trainer.save_model()

Save the fine-tuned model to Google Drive

In [None]:
# After training in this session, copy the local output directory (MODEL_NAME) to Drive.
if not LOAD_FROM_DRIVE: 
  import shutil
  print("Copying on Drive...")
  drive_dest = f"/content/drive/MyDrive/SQUAD/{MODEL_NAME}"
  # NOTE(review): shutil.copytree raises FileExistsError by default when the
  # destination already exists, so re-running with the same MODEL_NAME fails
  # here after training has completed — remove or rename the Drive folder first.
  shutil.copytree(MODEL_NAME, drive_dest)
  print(f"Model saved on drive at \t{drive_dest}")

## Question generation

In [None]:
# Build the generation pipeline from the fine-tuned weights at MODEL_PATH, using the base tokenizer.
# NOTE(review): `config` is given a plain dict here; it is unclear whether the pipeline honours it.
# The generation cell passes max_length=250 explicitly on every call — confirm this dict is needed.
text_generator = pipeline('text-generation', model=MODEL_PATH, tokenizer=tokenizer, config={'max_length':100})

In [None]:
def paragraph(text, max_width=80):
  """Wrap text into newline-separated lines of at most max_width characters.

  Each line is broken at the last space found at an index <= max_width and is
  stripped of surrounding whitespace. A run of more than max_width characters
  without any space is hard-broken at max_width instead of failing.

  Args:
    text: the string to wrap.
    max_width: maximum line length, must be >= 1.

  Returns:
    The wrapped text; a text no longer than max_width is returned stripped.

  Raises:
    ValueError: if max_width is smaller than 1.
  """
  if max_width < 1:
    raise ValueError("max_width must be at least 1")

  lines = []
  rest = text
  # Iterative on purpose: the number of lines is not bounded by the recursion limit.
  while len(rest) > max_width:
    # Same break point as scanning backwards from index max_width for a space.
    cut = rest.rfind(" ", 0, max_width + 1)
    if cut == 0:
      # The only candidate is a leading space: cutting there would make no
      # progress, so drop it and look again.
      rest = rest[1:]
      continue
    if cut < 0:
      # No space within reach: hard-break the over-long word.
      cut = max_width
    lines.append(rest[:cut].strip())
    rest = rest[cut:]
  lines.append(rest.strip())
  return "\n".join(lines)

In [None]:
test = np.array(test_samples)[:]
true_questions = {}

# Group the reference questions by prompt, where the prompt is everything up to
# and including the "[QS]" marker of a test sample.
for sample in test:
  split_at = sample.index("[QS]") + 4
  prompt = sample[:split_at]
  # What follows the marker is the question; drop the end marker and the newline.
  reference = sample[split_at:].replace("[QE]", "").replace("\n", "")
  true_questions.setdefault(prompt, []).append(reference)


# Unique prompts, in order of first appearance.
contexts = list(true_questions)

In [None]:
# Inspect the first prompt: it should end with the "[QS]" marker, ready for the model to continue.
contexts[0]

'[CTX] The city of Bern or Berne (German: Bern, pronounced [bɛrn] ( listen); French: Berne [bɛʁn]; Italian: Berna [ˈbɛrna]; Romansh: Berna  [ˈbɛrnɐ] (help·info); Bernese German: Bärn [b̥æːrn]) is the de facto capital of Switzerland, referred to by the Swiss as their (e.g. in German) Bundesstadt, or "federal city".[note 1] With a population of 140,634 (November 2015), Bern is the fifth most populous city in Switzerland. The Bern agglomeration, which includes 36 municipalities, had a population of 406,900 in 2014. The metropolitan area had a population of 660,000 in 2000. Bern is also the capital of the Canton of Bern, the second most populous of Switzerland\'s cantons. [QS]'

In [None]:
import tensorflow as tf

tf.get_logger().setLevel("ERROR")
#transformers.logging.set_verbosity_error()

# Generation samples tokens (do_sample=True), so seed the random number
# generators to make a rerun of the notebook reproduce the same questions.
transformers.set_seed(RANDOM_STATE)

# Prompts are the same for every round: slice once instead of on every use.
prompts = contexts[:N_CONTEXT_GEN]

# One generation round per requested question: each round yields one sampled
# continuation for every prompt.
# NOTE(review): max_length presumably counts the prompt tokens too — confirm
# that long contexts still leave room for a generated question.
samples_outputs = []
for i in range(QUESTIONS_PER_CONTEXT):
  print(f"Generating one question for each of the {len(prompts)} contexts: round {i+1}/{QUESTIONS_PER_CONTEXT}")
  generation = text_generator(
      prompts,
      do_sample=True,
      max_length=250,
      top_k=50,
      top_p=0.95,
      num_return_sequences=1,
      verbose=True
  )
  samples_outputs.append(generation)

Generated questions for each context: 1/3
Generated questions for each context: 2/3
Generated questions for each context: 3/3


In [None]:
def clean(text):
  """Strip the [CTX], [QS] and [QE] markers from text and trim outer whitespace."""
  # Markers are removed one after the other, in this fixed order.
  for marker in ("[CTX]", "[QS]", "[QE]"):
    text = text.replace(marker, "")
  return text.strip()

In [None]:
# Extract the generated question from every output: the text between the
# "[QS]" marker that closes the prompt and the first "[QE]" that follows it.
# Outputs missing either marker are counted as badly formed and skipped.
pred_questions = {}
count = 0
bad_generations = 0
for output in samples_outputs:
  for batch in output:
    for gen_text in batch:
      text = gen_text["generated_text"]
      try:
        q_start = text.index("[QS]") + 4
        # Search only after the question start, so that q_end can never
        # precede q_start.
        q_end = text.index("[QE]", q_start)
      except ValueError:
        # str.index raises ValueError when a marker is absent; any other
        # error is a real problem and is left to propagate.
        bad_generations += 1
        continue
      # The key mirrors the prompt, so it can be looked up in true_questions.
      context = text[:q_start].replace("[QE]", "")
      pred_questions.setdefault(context, []).append(text[q_start : q_end])
      count += 1


print(f"Expected questions: \t{QUESTIONS_PER_CONTEXT*len(contexts[:N_CONTEXT_GEN])} \t({QUESTIONS_PER_CONTEXT} questions for each {len(contexts[:N_CONTEXT_GEN])} context)")
print(f"Well formed ones: \t{count}")
print(f"Bad formed ones: \t{bad_generations}")

Expected questions: 	528 	(3 questions for each 176 context)
Well formed ones: 	492
Bad formed ones: 	36


In [None]:
# Side-by-side report: for each prompt that produced at least one well-formed
# question, show the wrapped context, the reference questions and the generated ones.
for context, preds in pred_questions.items():
  print("\nContext: \n", paragraph(clean(context)), "\n")

  print("True questions: ")
  for true in true_questions[context]:
    print("-", clean(true))

  print("\nPred questions: ")
  for pred in preds:
    print("-", clean(pred))

  # Separator between contexts.
  print("="*100)


Context: 
 The city of Bern or Berne (German: Bern, pronounced [bɛrn] ( listen); French:
Berne [bɛʁn]; Italian: Berna [ˈbɛrna]; Romansh: Berna  [ˈbɛrnɐ] (help·info);
Bernese German: Bärn [b̥æːrn]) is the de facto capital of Switzerland, referred
to by the Swiss as their (e.g. in German) Bundesstadt, or "federal city".[note
1] With a population of 140,634 (November 2015), Bern is the fifth most
populous city in Switzerland. The Bern agglomeration, which includes 36
municipalities, had a population of 406,900 in 2014. The metropolitan area had
a population of 660,000 in 2000. Bern is also the capital of the Canton of
Bern, the second most populous of Switzerland's cantons. 

True questions: 
- What city is the de facto capital of Switserland?
- What is the second most populous of Switzerland's cantons?
- Which canton is Berne the capital?
- How many municipalities are in the Berne agglomeration?
- What is the population of Berne?
- Where is Bern located?
- How many municiplaities are in

In [None]:
import json

# Persist the generated questions, keyed by prompt. The file name embeds the
# number of well-formed questions (count).
output_path = f"{count}_generated_questions.json"
with open(output_path, "w", encoding="utf-8") as f:
  json.dump(pred_questions, f)
# Report success only after the with-block has closed (and flushed) the file.
print("Results saved.")

Resulst saved.
