In [None]:
!pip install simpletransformers
!pip install rouge

In [2]:
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args
import json
import re
from rouge import Rouge 
from nltk.translate.bleu_score import sentence_bleu
from sklearn.model_selection import train_test_split
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [None]:
# Downloading LC-QuAD v2 from Dropbox link
!wget https://www.dropbox.com/s/pygr3g3ansoj043/lcquad_v2.json?dl=0

In [4]:
def read_json(json_file):
    file = json.load(open(json_file))

    prefix = []
    questions = []
    sparql = []

    for instance in file:
      for k, v in instance.items():
        if k == 'paraphrased_question':
          prefix.append('generate sparql')
          questions.append(v)
        elif k == 'sparql_wikidata':
          sparql.append(v)
    return prefix, questions, sparql

In [5]:
prefix, text, label = read_json('lcquad_v2.json?dl=0')

In [6]:
input_text = []
target_text = []
for input, target in zip(text, label):
  target_length = target.split(' ')
  if len(target_length) < 10:
      target = target.replace('{', 'open_bracket').replace('}', ' close_bracket').replace('.', 'sep_dot').replace('?answer', 'var_uri')
      target = re.sub(' +', ' ', target)
      input_text.append(input)
      target_text.append(target)

In [7]:
df = pd.DataFrame()
df['prefix'] = prefix[:len(input_text)]
df['input_text'] = input_text
df['target_text'] = target_text

In [8]:
# Splitting dataframe in train and test set, ratio 90% train, 10% test
df_train, df_test = train_test_split(df, test_size=0.1)

In [9]:
# Converting dataframes to string values and lowercasing 
df_train = df_train.astype(str)
df_test = df_test.astype(str)
df_train['input_text'] = df_train['input_text'].str.lower()
df_test['input_text'] = df_test['input_text'].str.lower()
df_train['target_text'] = df_train['target_text'].str.lower()
df_test['target_text'] = df_test['target_text'].str.lower()

In [10]:
for index, row in df_train[:5].iterrows():
  print("question: {}\nsparql: {}\n".format(row['input_text'], row['target_text']))
print("Total train size: {}".format(len(df_train)))

question: what is the jufo id of the modern york survey of books?
sparql: select distinct var_uri where open_bracket wd:q1426223 wdt:p1277 var_uri close_bracket

question: reunion has been given what iso 3166-1 alpha-3 code?
sparql: select distinct var_uri where open_bracket wd:q17070 wdt:p298 var_uri close_bracket

question: which are the primary help measures of petroleum jelly?
sparql: select distinct var_uri where open_bracket var_uri wdt:p2239 wd:q457239 close_bracket

question: how many works of art are in l'orfeo?
sparql: select distinct var_uri where open_bracket wd:q724008 wdt:p2635 var_uri close_bracket

question: what is the central tibetan administrations executive body?
sparql: select distinct var_uri where open_bracket var_uri wdt:p208 wd:q326343 close_bracket

Total train size: 4089


In [11]:
for index, row in df_test[:5].iterrows():
  print("question: {}\nsparql: {}\n".format(row['input_text'], row['target_text']))
print("Total test size: {}".format(len(df_test)))

question: which is the encyclopædia universalis id of john artist sargent?
sparql: select distinct var_uri where open_bracket wd:q155626 wdt:p3219 var_uri close_bracket

question: what is the dark knight's kmrb rating?
sparql: select distinct var_uri where open_bracket wd:q163872 wdt:p3818 var_uri close_bracket

question: which is the d-u-n-s for college of california, berkeley?
sparql: select distinct var_uri where open_bracket wd:q168756 wdt:p2771 var_uri close_bracket

question: which is daniel schneidermann's canal - u person id?
sparql: select distinct var_uri where open_bracket wd:q146749 wdt:p5243 var_uri close_bracket

question: the three georges dam is which dam?
sparql: select distinct var_uri where open_bracket var_uri wdt:p4792 wd:q12514 close_bracket

Total test size: 455


#T5-small

In [30]:
# Specifying model arguments, e.g. learn rate, batch size, epochs
model_args = T5Args()
model_args.num_train_epochs = 25
model_args.overwrite_output_dir = True
model_args.learning_rate = 0.001
model_args.batch_size = 32

model = T5Model(
    't5',
    't5-small',
    args=model_args,
    early_stopping=True,
    use_cuda=True
)

In [None]:
# Training the model with the specified model arguments
model.train_model(df_train)

In [38]:
# Converting test dataframe to separate X_test and Y_test
X_test = df_test['input_text'].tolist()
Y_test = df_test['target_text'].tolist()

In [39]:
# Let the trained model predict on test questions
predict = model.predict(X_test)

Generating outputs:   0%|          | 0/57 [00:00<?, ?it/s]

Decoding outputs:   0%|          | 0/455 [00:00<?, ?it/s]

In [43]:
rouge = Rouge()
rouge.get_scores(predict, Y_test, avg=True)

{'rouge-1': {'r': 0.595290423861852,
  'p': 0.7945368916797466,
  'f': 0.6805322833391653},
 'rouge-2': {'r': 0.4800366300366302,
  'p': 0.7461538461538509,
  'f': 0.5835701979392663},
 'rouge-l': {'r': 0.595290423861852,
  'p': 0.7945368916797466,
  'f': 0.6805322833391653}}

## ROUGE score if prediction contains empty value

In [41]:
count  = 0
ROUGEL = []
for pred, gold in zip(predict, Y_test):
  count += 1
  try:
    ROUGEL.append(rouge.get_scores(gold, pred))
  except ValueError:
    pass

In [42]:
rougel_f1 = []
for scores in ROUGEL:
  for score in scores:
    for key, value in score.items():
      if key == 'rouge-l':
        for k, v in value.items():
          if k == 'f':
            rougel_f1.append(v)

# Printing ROUGEL f1
rougel_f1 = sum(rougel_f1) / len(rougel_f1)
print("ROUGEL f1-score: {}".format(round(rougel_f1, 4)))

ROUGEL f1-score: 0.6805


## Model prediction on simple questions

In [None]:
questions = ["what is the population of Groningen?", "what is the color of an elephant?", "what color is the flag of Groningen?"]
sparql_pred = model.predict(questions)

In [45]:
for sparql, question in zip(sparql_pred, questions):
  print("question: {}\tSparql: {}".format(question, sparql))

question: what is the population of Groningen?	Sparql: select distinct var_uri where open_bracket wd:q7991
question: what is the color of an elephant?	Sparql: select distinct var_uri where open_bracket wd:q2831
question: what color is the flag of Groningen?	Sparql: select distinct var_uri where open_bracket wd:q25662
