In [1]:
%load_ext autoreload
%autoreload 2
import os

os.sys.path.insert(0, '../script')

from evaluation import evaluate_model
from collections import ChainMap, defaultdict, Counter
from template_based import *
import re

In [2]:
# Group 1: text preceding an opening parenthesis; group 2: text between the
# parentheses. Both groups are non-greedy.
PARENTHESIS_RE = re.compile(r'(.*?)\((.*?)\)')
# A lowercase ASCII letter immediately followed by an uppercase ASCII letter.
CAMELCASE_RE = re.compile(r'([a-z])([A-Z])')


def preprocess_so(so):
    """Turn a raw subject/object string into a more readable surface form.

    Steps, applied in order:

    1. Every ``prefix(inner)`` occurrence is rewritten as ``inner prefix``.
    2. Underscores become spaces.
    3. A space is inserted at each lowercase->uppercase boundary (camelCase).
    4. Leading/trailing double quotes and spaces are stripped.

    Args:
        so: The string to clean up.

    Returns:
        The cleaned string.
    """
    # Raw strings are required for the replacement templates: '\g' is not a
    # valid Python string escape, so the non-raw form triggers a
    # DeprecationWarning/SyntaxWarning on modern interpreters.
    parenthesis_preprocessed = PARENTHESIS_RE.sub(r'\g<2> \g<1>', so)
    underline_removed = parenthesis_preprocessed.replace('_', ' ')
    camelcase_preprocessed = CAMELCASE_RE.sub(r'\g<1> \g<2>', underline_removed)

    return camelcase_preprocessed.strip('" ')

In [3]:
class TemplateBasedModel:
    """Three-stage text generator: structure data, select templates, make text.

    The stages are delegated to ``StructureData``, ``SelectTemplate`` and
    ``MakeText`` instances created in ``__init__``.
    """

    def __init__(self, template_db, lexicalization_f):
        """Wire up the pipeline stages.

        Args:
            template_db: Mapping consulted first when templates are looked up.
            lexicalization_f: Callable forwarded to ``MakeText`` as its
                ``lexicalization_f`` argument.
        """
        # Keys missing from `template_db` resolve to a Counter holding a single
        # JustJoinTemplate. Note that only subscript access (`db[key]`) on the
        # ChainMap reaches the defaultdict factory; `key in db` and
        # `db.get(key)` do not create or return the fallback entry.
        fallback_template_db = defaultdict(lambda: Counter([JustJoinTemplate()]))

        self.template_db = ChainMap(template_db, fallback_template_db)
        self.ss = StructureData(self.template_db)
        self.st = SelectTemplate()
        self.mt = MakeText(lexicalization_f=lexicalization_f)

    def predict(self, X):
        """Generate one text per entry of ``X``.

        Args:
            X: Iterable of entries; each entry must expose a ``data``
                attribute, which is what gets passed to the structuring stage.

        Returns:
            A list with the generated text for each entry, in input order.

        Raises:
            Exception: Whatever a pipeline stage raises is propagated after
                the offending entry has been printed.
        """
        result = []

        for x in X:
            try:
                structured_data = self.ss.structure(x.data)
                selected_templates = self.st.select_template(structured_data)
                text = self.mt.make_text(selected_templates)

                result.append(text)

            except Exception:
                # Show which entry failed, then re-raise the active exception
                # unchanged (bare `raise` keeps the original traceback intact).
                print(x)
                raise

        return result

In [4]:
# Baseline run: with an empty template database, template lookups can only be
# served by the model's built-in fallback (a JustJoinTemplate).
tbm = TemplateBasedModel(
    lexicalization_f=preprocess_so,
    template_db={},
)

# Kept as the cell's final expression so the notebook displays its result.
evaluate_model(tbm, 'template-based-model-just-join')

{'bleu': 28.25, 'meteor': 0.3514776430040795, 'ter': 0.6457524388601327}