In [1]:
import os
import sys

# Put ../script (resolved against the current working directory) at the front
# of the module search path; the local modules imported below are presumably
# found there. `sys` is imported explicitly instead of being reached through
# `os.sys`, which only works because the os module happens to import sys.
sys.path.insert(0, '../script')

from evaluation import evaluate_model

from content_selection import SelectAllContentSelector
from discourse_structuring import DoesntSortDiscourseStructurer
from sentence_aggregation import OneSentenceSentenceAggregator
from lexicalization import PreprocessLexicalizer, preprocess_so
from sentence_generation import JustJoinTripleSentenceGenerator
from text_generation import TextGenerator

In [2]:
# Baseline generator: TextGenerator is handed one component per pipeline
# stage (content selection, discourse structuring, sentence aggregation,
# lexicalization, sentence generation).
model = TextGenerator(
    content_selection_model=SelectAllContentSelector(),
    discourse_structuring_model=DoesntSortDiscourseStructurer(),
    sentence_aggregation_model=OneSentenceSentenceAggregator(),
    lexicalization_model=PreprocessLexicalizer(preprocess=preprocess_so),
    sentence_generation_model=JustJoinTripleSentenceGenerator()
)

In [2]:
from webnlg_corpus import webnlg

# Load the 'webnlg_challenge_2017' dataset through the project's webnlg loader.
corpus = webnlg.load('webnlg_challenge_2017')

In [3]:
# Pick one sample entry (fixed seed=100) and display it; `s` is reused by a
# later cell that feeds `s.data` to the baseline model.
s = corpus.sample(seed=100)
s

Triple info: Category=MeanOfTransportation eid=Id1515 idx=test_MeanOfTransportation_3_Id1515

	Modified Triples:

Aston_Martin_V8 | relatedMeanOfTransportation | Aston_Martin_DBS
Aston_Martin_V8 | assembly | "Newport Pagnell, Buckinghamshire, England, United Kingdom"
Aston_Martin_V8 | engine | 5.3 (litres)


	Lexicalizations:

The Aston Martin V8 is similar to the DBS and has a 5.3 litre engine. It is made in Newport Pagnell, Buckinghamshire, England, UK.


The Aston Martin V8 assembly took place in Newport Pagnell, Buckinghamshire, England, United Kingdom. The Aston Martin V8(has a engine volume of 5.3 litres) and the Aston Martin DBS are related means of transport.


The Aston Martin V8 has a 5.3 litre engine and its assembly took place in Newport Pagnell, Buckinghamshire, England, United Kingdom. The Aston Martin V8 and the Aston Martin DBS are related means of transport.


In [7]:
# Run the baseline generator on the sampled entry's data and show the text.
model.predict_entry(s.data)

'Aston Martin V8 related Mean Of Transportation Aston Martin DBS Aston Martin V8 assembly Newport Pagnell, Buckinghamshire, England, United Kingdom Aston Martin V8 engine litres 5.3'

In [5]:
# Score the baseline; the recorded result reports 'bleu', 'meteor' and 'ter'.
# NOTE(review): 'just-join' looks like a run label — confirm against
# evaluation.evaluate_model.
evaluate_model(model, 'just-join')

{'bleu': 24.83, 'meteor': 0.33796256777976597, 'ter': 0.6401647072405756}

# New version: template-based model

In [10]:
from importlib import reload
import template_based
# Re-execute the template_based module so source edits made since its first
# import take effect, then re-run the star import so the notebook's names are
# rebound to the reloaded definitions. Order matters: reload first, import after.
reload(template_based)
from template_based import *

In [11]:
from collections import ChainMap, defaultdict
from template_based import *

class TemplateBasedModel:
    """Text generator that structures each input, selects templates, and renders text.

    The three steps are delegated to StructureData, SelectTemplate and
    MakeText (all defined outside this class); this class only wires them
    together and supplies a fallback for templates the caller did not provide.
    """

    def __init__(self, template_db):
        """Wire up the pipeline around ``template_db``.

        Args:
            template_db: mapping consulted first for templates. Item lookups
                that raise KeyError on it are answered by a fallback mapping
                whose default value is a one-element list containing a
                JustJoinTemplate.
        """
        # ChainMap searches its maps left to right, so caller-supplied
        # templates take precedence and the defaultdict only answers misses.
        # Gotcha: only item access (db[key]) triggers the default factory;
        # `key in db` and `db.get(key)` do not.
        self.template_db = ChainMap(
            template_db,
            defaultdict(lambda: [JustJoinTemplate()]),
        )
        self.ss = StructureData(self.template_db)
        self.st = SelectTemplate()
        self.mt = MakeText(preprocess_so)

    def predict(self, X):
        """Lazily yield one generated text per element of ``X``.

        Each element goes through structure -> select_template -> make_text,
        in that order; nothing is computed until the generator is iterated.
        """
        for entry in X:
            structured = self.ss.structure(entry)
            chosen = self.st.select_template(structured)
            yield self.mt.make_text(chosen)

In [12]:
# With an empty template database nothing user-supplied can match, so item
# lookups on the model's template_db are served by its JustJoinTemplate fallback.
tbm = TemplateBasedModel({})

evaluate_model(tbm, 'template-based-model')

{'bleu': 26.68, 'meteor': 0.32044680722795377, 'ter': 0.6323834446133413}

In [13]:
from reading_thiagos_templates import read_thiagos_templates

# Build the template database that the next cell passes to TemplateBasedModel.
template_db = read_thiagos_templates()

In [14]:
# Rebuild the model, this time backed by the loaded templates
# (rebinds `tbm`, replacing the empty-database model).
tbm = TemplateBasedModel(template_db)

In [16]:
# Score the template-backed model.
# NOTE(review): in the recorded outputs BLEU and METEOR improve over the
# empty-database run, while TER gets higher (0.63 -> 0.67); TER is an edit
# rate where lower is normally better — worth a closer look.
evaluate_model(tbm, 'template-based-model-w-steroids')

{'bleu': 35.31, 'meteor': 0.3785699895492735, 'ter': 0.6655403975293905}