In [None]:
!pip install simpletransformers
!pip install rouge

In [2]:
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args
import json
import re
from rouge import Rouge 
from nltk.translate.bleu_score import sentence_bleu
from sklearn.model_selection import train_test_split
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [80]:
def read_json(json_file):
    file = json.load(open(json_file))

    prefix = []
    questions = []
    sparql = []

    for instance in file:
      for k, v in instance.items():
        if k == 'corrected_question':
          prefix.append('generate_sparql')
          questions.append(v)
        elif k == 'sparql_query':
          sparql.append(v)
    return prefix, questions, sparql

In [81]:
def convert_sparql(sparql):
  sent = ""
  labels = []
  for n in sparql:
    n = n.lower()
    n = n.replace('?uri.', '?uri .').replace('?uri}', '?uri }').replace('{?uri', '{ ?uri').replace('count(?uri)', 'count ( ?uri )')
    n = n.split(' ')
    for word in n:
        if word.startswith("<"):
            if "resource" in word:
                word = word.split('/')
                dbr = '<dbr_' + word[4]
                sent += dbr + ' '

            elif "property" in word:
                word = word.split('/')
                dbp = '<dbp_' + word[4]
                sent += dbp + ' '

            elif "ontology" in word:
                word = word.split('/')
                dbo = '<dbo_' + word[4]
                sent += dbo + ' '
        elif word == '{':
            sent += 'bracket_open '
        elif word == '}':
            sent += 'bracket_close'
        elif word == '?uri':
            sent += 'var_uri '
        elif word == '?x':
            sent += 'var_x '
        elif word == '.':
            sent += 'sep_dot '
        elif word == '(':
            sent += 'attr_open '
        elif word == ')':
            sent += 'attr_close '
        else:
            sent += word + ' '
    sent = re.sub(' +', ' ', sent)
    labels.append(sent)
    sent = ""
  return labels

In [82]:
def create_df(json_file):
    prefix, questions, sparql = read_json(json_file)
    labels = convert_sparql(sparql)

    input_text = []
    target_text = []
    for text, label in zip(questions, labels):
      label = label.rstrip()
      if label[-1] != '}':
        target_text.append(label)
        input_text.append(text)

    df = pd.DataFrame()
    df['prefix'] = prefix[:len(input_text)]
    df['input_text'] = input_text
    df['input_text'] = df['input_text'].str.lower()
    df['target_text'] = target_text

    return df, sparql

In [None]:
!wget https://www.dropbox.com/s/j5di3g5jm3e72p8/train-data.json?dl=0
!wget https://www.dropbox.com/s/8kil1x0pkf6c40p/test-data.json?dl=0

In [83]:
df1, sparql_train = create_df('train-data.json?dl=0')
df2, sparql_test = create_df('test-data.json?dl=0')
df = df1.append(df2, ignore_index=True)

In [84]:
df_train, df_test = train_test_split(df, test_size=0.1)

In [85]:
for index, row in df_train[:5].iterrows():
  print("question: {}\nsparql: {}\n".format(row['input_text'], row['target_text']))

question: what prizes have been awarded to the relatives of linn ullmann?
sparql: select distinct var_uri where bracket_open <dbr_linn_ullmann> <dbp_relatives> var_x sep_dot var_x <dbp_awards> var_uri sep_dot bracket_close

question: where was mackenzie miller born?
sparql:  select distinct var_uri where bracket_open <dbr_mackenzie_miller> <dbp_birthplace> var_uri bracket_close

question: where is the thorington train station located?
sparql:  select distinct var_uri where bracket_open <dbr_thorington_railway_station> <dbo_district> var_uri bracket_close

question: how many people currently play for the nyc fc?
sparql: select distinct count attr_open var_uri attr_close where bracket_open var_uri <dbp_currentclub> <dbr_new_york_city_fc> sep_dot bracket_close

question: in which races does coneygree compete?
sparql:  select distinct var_uri where bracket_open <dbr_coneygree> <dbp_race> var_uri bracket_close



In [86]:
for index, row in df_test[:5].iterrows():
  print("question: {}\nsparql: {}\n".format(row['input_text'], row['target_text']))

question: did annie leibovitz do the cover of the road ahead?
sparql: ask where bracket_open <dbr_the_road_ahead_(bill_gates_book)> <dbo_coverartist> <dbr_annie_leibovitz> bracket_close

question: does nintendo have a division called nintendo entertainment planning & development?
sparql: ask where bracket_open <dbr_nintendo> <dbp_divisions> <dbr_nintendo_entertainment_planning_&_development> bracket_close

question: what type of government is elected in kumta?
sparql:  select distinct var_uri where bracket_open <dbr_kumta> <dbo_governmenttype> var_uri bracket_close

question: which writer is famous for works written by neil gaiman?
sparql: select distinct var_uri where bracket_open var_x <dbp_writers> <dbr_neil_gaiman> sep_dot var_uri <dbo_notablework> var_x sep_dot bracket_close

question: where can one find the dzogchen ponolop rinpoche?
sparql:  select distinct var_uri where bracket_open <dbr_dzogchen_ponlop_rinpoche> <dbp_location> var_uri bracket_close



In [57]:
# Specifying model arguments, e.g. learn rate, batch size, epochs
model_args = T5Args()
model_args.num_train_epochs = 20
model_args.overwrite_output_dir = True
model_args.learning_rate = 0.001
model_args.batch_size = 32

model = T5Model(
    't5',
    't5-small',
    args=model_args,
    early_stopping=True,
    use_cuda=True
)

In [None]:
# Training the model with the specified model arguments
model.train_model(df_train)

In [87]:
# Converting test dataframe to separate X_test and Y_test
X_test = df_test['input_text'].tolist()
Y_test = df_test['target_text'].tolist()

In [88]:
# Let the trained model predict on test questions
predict = model.predict(X_test)

Generating outputs:   0%|          | 0/39 [00:00<?, ?it/s]

Decoding outputs:   0%|          | 0/308 [00:00<?, ?it/s]

In [89]:
rouge = Rouge()
rouge.get_scores(predict, Y_test, avg=True)

{'rouge-1': {'r': 0.48459587850821634,
  'p': 0.7785482374768109,
  'f': 0.5917736584881902},
 'rouge-2': {'r': 0.3665513373711424,
  'p': 0.7249149659863968,
  'f': 0.4803122392534425},
 'rouge-l': {'r': 0.48459587850821634,
  'p': 0.7785482374768109,
  'f': 0.5917736584881902}}