In [31]:
from sklearn.base import BaseEstimator, RegressorMixin, clone

import numpy as np

import codecs

import subprocess

from itertools import islice, product

import re

import pandas as pd

import os

import sys

# Make the project's helper modules in ../script (e.g. webnlg) importable.
sys.path.insert(0, '../script')
from webnlg import WebNLGCorpus

In [3]:
# Name of this baseline; it matches the file name its texts are written to under ../data/models/.
model_name = "just-join-baseline"

# Algorithm

In [4]:
# The algorithm using the scikit-learn framework

def simple_spo_sort(spo):
    """Order a triple's fields as subject, predicate, object.

    ``spo`` is a mapping read through its 'subject', 'predicate' and
    'object' keys; the three values are returned as a list in that order.
    """
    field_order = ('subject', 'predicate', 'object')
    return [spo[field] for field in field_order]

def simple_ops_sort(spo):
    """Order a triple's fields in reverse: object, predicate, subject.

    ``spo`` is a mapping read through its 'object', 'predicate' and
    'subject' keys; the three values are returned as a list in that order.
    """
    field_order = ('object', 'predicate', 'subject')
    return [spo[field] for field in field_order]


class NLGBaseline(BaseEstimator, RegressorMixin):
    """Rule-based "just join" text generator for sets of triples.

    Each triple is turned into a sentence by ordering its fields with
    ``spo_sort`` and joining them with ``spo_sep``.  The per-triple
    sentences, visited in the order produced by ``sen_sort``, are then
    joined with ``sen_sep``.

    Parameters
    ----------
    spo_sep : str
        Separator placed between the fields of one triple.
    sen_sep : str
        Separator placed between the sentences of one entry.
    spo_sort : callable
        Maps one triple to the list of strings to join (for example
        ``simple_spo_sort``).
    sen_sort : callable
        Maps an entry (an iterable of triples) to the iterable of triples
        to verbalize, e.g. the identity (default) or ``reversed``.

    NOTE(review): RegressorMixin contributes a ``score`` method (R^2) that is
    not meaningful for generated text; it is kept so the public interface
    does not change.
    """

    def __init__(self, spo_sep=' ', sen_sep=',', spo_sort=simple_spo_sort, sen_sort=lambda x: x):
        # scikit-learn convention: store constructor arguments unchanged so
        # get_params / set_params / clone can round-trip them.
        self.spo_sep = spo_sep
        self.sen_sep = sen_sep
        self.spo_sort = spo_sort
        self.sen_sort = sen_sort

    def fit(self, X, y=None):
        """No-op: the baseline is purely rule-based, so nothing is learned.

        Returns ``self`` as scikit-learn estimators must, so that
        ``model.fit(X).predict(X)`` and Pipelines work.
        """
        return self

    def predict_entry(self, x):
        """Generate the text for one entry ``x`` (an iterable of triples)."""
        sentences = (self.spo_sep.join(self.spo_sort(triple)) for triple in self.sen_sort(x))
        return self.sen_sep.join(sentences)

    def predict(self, X, y=None):
        """Generate one text per entry of ``X``; ``y`` is accepted and ignored."""
        return [self.predict_entry(x) for x in X]

# Loading the dataset

In [5]:
test = WebNLGCorpus.load("test_with_lex")

# Entries contain different numbers of triples (Id10 has 1, Id397 has 3), so
# these per-entry lists are ragged. Without dtype=object, NumPy >= 1.24
# raises ValueError ("inhomogeneous shape") on ragged input; older versions
# only warned. The lexicalization lists are presumably ragged as well.
X_test = np.array([t.get_data() for t in test], dtype=object)
y_test = np.array([t.lexes() for t in test], dtype=object)

# Generating texts

In [22]:
%%time

model = NLGBaseline()

texts = model.predict(X_test)

# One generated text per line, stored under the baseline's model_name.
with codecs.open(f'../data/models/{model_name}', 'w', 'utf-8') as f:
    for text in texts:
        f.write(f'{text}\n')

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 20.3 ms


## Sample of generated texts

In [23]:
!head -100 ../data/models/just-join-baseline | tail -10

Castle_(novel) language English_language
Eric_Flint birthPlace Burbank,_California
Farrar,_Straus_and_Giroux parentCompany Macmillan_Publishers
John_Cowper_Powys notableWork A_Glastonbury_Romance
Soho_Press country United_States
The_Secret_Scripture publisher Faber_and_Faber
United_States ethnicGroup Asian_Americans
United_States language English_language
Weymouth_Sands precededBy A_Glastonbury_Romance
A.C._Chievo_Verona manager Rolando_Maran


## Individual samples

### Sample with ntriples = 1

In [11]:
# Entry Id10: a single triple with three reference lexicalizations.
sample = test.sample(eid="Id10")
sample

Triple info: category=Airport eid=Id10

	Modified triples:

Amsterdam_Airport_Schiphol | 1st_runway_Number | 18


	Lexicalizations:

The first runway at Amsterdam's Schiphol Airport is known as Number 18.
The Amsterdam Airport Schiphol's 1st runway number is 18.
The number of the 1st runway at Amsterdam Airport Schiphol is 18.

In [16]:
# Text the baseline generates for this single entry.
model.predict_entry(sample.get_data())

'Amsterdam_Airport_Schiphol 1st_runway_Number 18'

### Sample with ntriples = 3

In [18]:
# Entry Id397: three triples sharing the subject Alan_Bean.
sample = test.sample(eid="Id397")
sample

Triple info: category=Astronaut eid=Id397

	Modified triples:

Alan_Bean | nationality | United_States
Alan_Bean | birthPlace | Wheeler,_Texas
Alan_Bean | status | "Retired"


	Lexicalizations:

American retiree Alan Bean was born in Wheeler, Texas.
Retired US National Alan Bean was born in Wheeler, Texas.
Alan Bean was born in the United States in Wheeler, Texas and has retired.

In [89]:
# Text the baseline generates for this single entry.
model.predict_entry(sample.get_data())

'Alan_Bean nationality United_States,Alan_Bean birthPlace Wheeler,_Texas,Alan_Bean status "Retired"'

# Evaluation

## Evaluation function

In [52]:
# '+' (not '*') so an empty capture can never reach float().
BLEU_RE = re.compile(r'BLEU = ([\d.]+),')
METEOR_RE = re.compile(r'Final score:\s+([\d.]+)\n')

# Evaluation tooling and references, relative to the notebook's directory.
EVAL_DIR = '../evaluation/webnlg2017'
EVAL_SCRIPT = f'{EVAL_DIR}/webnlg-automatic-evaluation-v2/evaluation_v2.py'
MULTI_BLEU = f'{EVAL_DIR}/webnlg-baseline-master/multi-bleu.perl'
BLEU_REFERENCES = [f'{EVAL_DIR}/webnlg-automatic-evaluation/references/gold-all-cat-reference{k}.lex'
                   for k in range(3)]
METEOR_JAR = f'{EVAL_DIR}/meteor-1.5/meteor-1.5.jar'
METEOR_REFERENCE = f'{EVAL_DIR}/webnlg-automatic-evaluation/references/gold-all-cat-reference-3ref.meteor'
METEOR_PARAPHRASES = f'{EVAL_DIR}/meteor-1.5/data/paraphrase-en.gz'


def _write_texts(texts, filepath):
    """Write one generated text per line to ``filepath`` as UTF-8."""
    with codecs.open(filepath, 'w', 'utf-8') as f:
        for text in texts:
            f.write(f'{text}\n')


def _parse_score(pattern, output, tool):
    """Return the first number captured by ``pattern`` in a tool's ``output``.

    Raises RuntimeError (including the raw output) when no score is found,
    rather than an opaque IndexError.
    """
    match = pattern.search(output)
    if match is None:
        raise RuntimeError(f'Could not parse the {tool} score from its output:\n{output}')
    return float(match.group(1))


def _bleu_score(hyp_filepath):
    """Score ``hyp_filepath`` with multi-bleu.perl (``-lc``) against the 3 reference files."""
    with open(hyp_filepath, 'rb') as f:
        proc = subprocess.run([MULTI_BLEU, '-lc', *BLEU_REFERENCES],
                              stdout=subprocess.PIPE, input=f.read(), check=True)
    return _parse_score(BLEU_RE, proc.stdout.decode('utf-8'), 'BLEU')


def _meteor_score(hyp_filepath):
    """Score ``hyp_filepath`` with METEOR 1.5 (English, 3 references) and return its final score."""
    proc = subprocess.run(['java', '-Xmx2G', '-jar', METEOR_JAR,
                           hyp_filepath, METEOR_REFERENCE,
                           '-l', 'en', '-norm', '-r', '3', '-a', METEOR_PARAPHRASES],
                          stdout=subprocess.PIPE, check=True)
    return _parse_score(METEOR_RE, proc.stdout.decode('utf-8'), 'METEOR')


def evaluate_model(base_model, param_grid, X=None):
    """Score every parameter combination of ``base_model`` with BLEU and METEOR.

    Parameters
    ----------
    base_model : estimator
        Template estimator; it is cloned for each combination, so it is not modified.
    param_grid : dict
        Maps parameter names to lists of candidate values. The full cartesian
        product is evaluated, in the iteration order of the dict.
    X : iterable, optional
        Entries to generate texts for. Defaults to the notebook-level ``X_test``.

    Returns
    -------
    pandas.DataFrame
        One row per combination, with the parameter values plus 'bleu' and 'meteor'.

    Side effects: for combination ``i``, writes ``../data/models/{i}`` and
    the evaluation script's output under ``../tmp/{i}/``.

    Raises ``subprocess.CalledProcessError`` if an external tool fails, or
    ``RuntimeError`` if a score cannot be parsed from a tool's output.
    """
    if X is None:
        X = X_test

    results = []
    for i, values in enumerate(product(*param_grid.values())):
        params = dict(zip(param_grid.keys(), values))
        model = clone(base_model)
        model.set_params(**params)

        run_name = str(i)
        texts_filepath = f'../data/models/{run_name}'
        outdir = f'../tmp/{run_name}'

        _write_texts(model.predict(X), texts_filepath)

        # The official WebNLG evaluation script writes <team>_all-cat.txt,
        # which is the hypothesis file scored by BLEU and METEOR below.
        os.makedirs(outdir, exist_ok=True)
        subprocess.run(['python', EVAL_SCRIPT,
                        '--team_name', run_name,
                        '--team_filepath', texts_filepath,
                        '--outdir', outdir],
                       check=True)
        all_cat_filepath = f'{outdir}/{run_name}_all-cat.txt'

        results.append({**params,
                        'bleu': _bleu_score(all_cat_filepath),
                        'meteor': _meteor_score(all_cat_filepath)})

    return pd.DataFrame.from_records(results)

## Models

In [53]:
%%time

def as_is(x):
    """Identity sentence ordering: hand back the entry's triples untouched.

    A named function (rather than a lambda) so the results table shows a
    readable name for this option.
    """
    return x


# Join/ordering variants to compare; every combination (2 x 2 x 2 = 8) is scored.
param_grid = dict(
    spo_sep=[' '],
    sen_sep=[',', ' '],
    spo_sort=[simple_spo_sort, simple_ops_sort],
    sen_sort=[as_is, reversed],
)

results = evaluate_model(NLGBaseline(), param_grid)

CPU times: user 93.8 ms, sys: 1.36 s, total: 1.45 s
Wall time: 1min 44s


In [54]:
# One row per parameter combination, with its 'bleu' and 'meteor' scores.
results

Unnamed: 0,bleu,meteor,sen_sep,sen_sort,spo_sep,spo_sort
0,11.62,0.263191,",",<function as_is at 0x7fbcd2df68c8>,,<function simple_spo_sort at 0x7fbcdb95ef28>
1,11.61,0.263003,",",<class 'reversed'>,,<function simple_spo_sort at 0x7fbcdb95ef28>
2,11.51,0.262994,",",<function as_is at 0x7fbcd2df68c8>,,<function simple_ops_sort at 0x7fbcdb95ec80>
3,11.52,0.26306,",",<class 'reversed'>,,<function simple_ops_sort at 0x7fbcdb95ec80>
4,11.29,0.257693,,<function as_is at 0x7fbcd2df68c8>,,<function simple_spo_sort at 0x7fbcdb95ef28>
5,11.29,0.257685,,<class 'reversed'>,,<function simple_spo_sort at 0x7fbcdb95ef28>
6,11.27,0.257691,,<function as_is at 0x7fbcd2df68c8>,,<function simple_ops_sort at 0x7fbcdb95ec80>
7,11.27,0.257636,,<class 'reversed'>,,<function simple_ops_sort at 0x7fbcdb95ec80>
