In [1]:
%time
%load_ext autotime
%load_ext autoreload
%autoreload 2

# if cannot import the modules, add the parent directory to system path might help
import os, tqdm, sys
parent_dir = os.path.abspath(os.getcwd()+'/..')+'/'
sys.path.append(parent_dir)

from utils.path import dir_HugeFiles
from utils.save import make_dir, save_pickle, load_pickle, save
from utils.tree import instr2tree, tree_distance, build_tree, stem
from utils.evaluation import metrics, spacy_extension

import pandas as pd
import numpy as np
import re

treemaker = instr2tree()
sp = spacy_extension()

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 8.11 µs


In [2]:
class evaluation:
    '''
    Compare generated recipes against reference recipes, keyed by recipe id.

    Attributes:
        dic:  {recipe id (int): {'<tag>_instr': str, '<tag>_ingr': list of str}}
              holding one pair of entries per loaded tag.
        ori:  tag of the reference data loaded by the constructor.
        gens: tags of the generations loaded through append_dic().
        gen:  tag the metric methods currently compare against `ori`. It is set
              by append_dic() (or assigned directly by the caller) and does not
              exist before that.
    '''

    def __init__(self, filename, tag):
        '''
        Args: filename: directory holding the reference files
        Args: tag: name under which the reference data is stored
        '''
        self.dic = self.load_dic({}, filename, tag)
        self.ori = tag
        self.gens = []

    # ------------------------------------------------------------------
    # loading data
    # ------------------------------------------------------------------
    def append_dic(self, filename, tag):
        '''
        Load one generation directory under `tag` and select it as the current
        generation. A tag that was already loaded is only re-selected.
        '''
        if tag in self.gens:
            print('already exist, will not load again')
        else:
            self.dic = self.load_dic(self.dic, filename, tag)
            self.gens.append(tag)
        self.gen = tag

    def load_dic(self, dic, filename, tag):
        '''
        Walk `filename` recursively and merge every recipe file into `dic`.

        A file name is read as '<integer id><field><4 more characters>', e.g.
        '123d.txt': field 'd' is stored as '<tag>_instr' (text) and field 'i'
        as '<tag>_ingr' (list obtained by splitting the text on '$').
        Files whose name does not follow that scheme (editor checkpoints,
        '.DS_Store', ...) are skipped instead of aborting the whole load.

        Args: dic: dict to update in place
        Args: filename: directory to walk
        Args: tag: prefix of the keys written for each recipe
        Returns: the same `dic` object
        '''
        if not os.path.isdir(filename):
            # Kept non-fatal as before, but no longer silent.
            print('not a directory, nothing loaded:', filename)
            return dic

        print('load', filename)
        for dirpath, _, fnames in os.walk(filename):
            for fname in fnames:
                try:
                    name, field = int(fname[:-5]), fname[-5]
                except (ValueError, IndexError):
                    continue  # not a '<id><field>.ext' file
                if field not in ('d', 'i'):
                    continue

                with open(os.path.join(dirpath, fname), 'r') as fp:
                    raw_text = self.remove_end(fp.read())

                entry = dic.setdefault(name, {})
                if field == 'd':
                    entry['%s_instr' % (tag)] = raw_text
                else:
                    entry['%s_ingr' % (tag)] = self.reverse_list(raw_text.split('$'))
        return dic

    # ------------------------------------------------------------------
    # exporting data
    # ------------------------------------------------------------------
    def to_bleu(self, out_dir='../../to_gpt2/'):
        '''
        Write reference and generated ingredients/instructions to
        '<out_dir>generation_<tag>_i.txt' / '..._d.txt' (one recipe per line),
        then run `perl multi-bleu.perl` and the `rouge` command on each
        reference/generation pair and print their output.

        Args: out_dir: directory receiving the four text files
        '''
        ori_i, gen_i = '%s_i' % (self.ori), '%s_i' % (self.gen)
        ori_d, gen_d = '%s_d' % (self.ori), '%s_d' % (self.gen)

        # Collect lines in lists and join once: repeated `+=` on a growing
        # string copies it again and again.
        to_write = {ori_i: [], gen_i: [], ori_d: [], gen_d: []}
        for v in self.dic.values():
            to_write[ori_i].append(self.add_space(' $ '.join(v['%s_ingr' % (self.ori)])) + ' $ \n')
            to_write[gen_i].append(self.add_space(' $ '.join(v['%s_ingr' % (self.gen)])) + ' $ \n')
            to_write[ori_d].append(self.add_space(v['%s_instr' % (self.ori)]) + '\n')
            to_write[gen_d].append(self.add_space(v['%s_instr' % (self.gen)]) + '\n')

        paths = {}
        for k, lines in to_write.items():
            paths[k] = os.path.join(out_dir, 'generation_%s.txt' % (k))
            save(paths[k], ''.join(lines), overwrite=True)

        pairs = ((ori_i, gen_i), (ori_d, gen_d))
        # multi-bleu.perl takes the reference as argument and the hypothesis on stdin.
        for ref, hyp in pairs:
            self._run_tool(['perl', 'multi-bleu.perl', paths[ref]], stdin_path=paths[hyp])
        for ref, hyp in pairs:
            self._run_tool(['rouge', '-f', paths[ref], paths[hyp], '--avg'])

        print()

    def _run_tool(self, argv, stdin_path=None):
        '''
        Run an external command without a shell and print its combined
        stdout/stderr. A command that cannot be started is reported, not raised.
        '''
        import subprocess  # local import keeps the class self-contained

        try:
            with open(stdin_path or os.devnull, 'rb') as stdin:
                proc = subprocess.run(argv, stdin=stdin,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.STDOUT,
                                      universal_newlines=True)
        except OSError as err:
            print('%s: %s' % (argv[0], err))
            return
        print(proc.stdout, end='')

    # ------------------------------------------------------------------
    # shared scoring helpers
    # ------------------------------------------------------------------
    def _average(self, score_fn):
        '''
        Apply score_fn(record) to every recipe (with a progress bar), print and
        return the mean score.
        '''
        value = [score_fn(v) for _, v in tqdm.tqdm(self.dic.items())]
        if not value:
            raise ZeroDivisionError('no recipes loaded, nothing to average')
        avg = sum(value)/len(value)
        print(avg)
        return avg

    def _mean_f1(self, field, extract=None):
        '''
        Mean of metrics(true, pred).f1_freq() over all recipes, where true/pred
        are the '<ori>_<field>' / '<gen>_<field>' entries, optionally passed
        through `extract` first.
        '''
        def score(record):
            true = record['%s_%s' % (self.ori, field)]
            pred = record['%s_%s' % (self.gen, field)]
            if extract is not None:
                true, pred = extract(true), extract(pred)
            return metrics(true, pred).f1_freq()
        return self._average(score)

    # ------------------------------------------------------------------
    # ingredient evaluation
    # ------------------------------------------------------------------
    def ingr_f1_freq(self, root = False):
        '''
        Mean f1_freq between reference and generated ingredient lists.

        Args: root: if True, both lists go through sp.root() first
        '''
        return self._mean_f1('ingr', sp.root if root else None)

    # ------------------------------------------------------------------
    # instruction evaluation
    # ------------------------------------------------------------------
    def instr_tree(self, stem_only = False):
        '''
        Mean normalised tree edit distance between reference and generated
        instructions (see norm_dist).
        '''
        def score(record):
            ori_instr = record['%s_instr' % (self.ori)]
            gen_instr = record['%s_instr' % (self.gen)]
            return self.norm_dist(ori_instr, gen_instr, stem_only = stem_only)
        return self._average(score)

    def state_f1_freq(self):
        '''
        Mean f1_freq over the outputs of sp.match_state() on the instructions.
        '''
        return self._mean_f1('instr', sp.match_state)

    def verb_f1_freq(self):
        '''
        Mean f1_freq over element [1] of sp.instructions() on the instructions
        (presumably the verbs -- confirm against utils.evaluation).
        '''
        return self._mean_f1('instr', lambda text: sp.instructions(text)[1])

    # ------------------------------------------------------------------
    # cleaning data
    # ------------------------------------------------------------------
    def remove_end(self, text):
        '''
        Drop newlines and everything from the first '<' onwards.
        '''
        return text.replace('\n','').split('<')[0]

    def reverse(self, text):
        '''
        Important data cleaning before NY times parser
        '''
        # remove anything in parentheses
        text = re.sub(r'\([^)]*\)', '', text)

        # remove space before punct
        text = re.sub(r'\s([?.!,"](?:\s|$))', r'\1', text)

        # remove consecutive spaces
        text = re.sub(' +',' ',text).strip()
        return text

    def reverse_list(self, listoftext):
        '''
        Clean every string with reverse() and drop the ones that end up empty.
        '''
        cleaned = (self.reverse(text) for text in listoftext)
        return [rev for rev in cleaned if rev]

    def add_space(self, line):
        '''
        Surround . , ! ? ( ) with spaces, then collapse whitespace runs.
        '''
        # add space around punct
        line = re.sub(r'([.,!?()])', r' \1 ', line)
        # raw string: '\s' in a normal literal is an invalid escape sequence
        line = re.sub(r'\s{2,}', ' ', line)
        return line

    # ------------------------------------------------------------------
    # tree edit distance
    # ------------------------------------------------------------------
    def str2tree(self, instr, stem_only):
        '''
        Split the instruction text on '. ', turn the sentences into a tree and
        count its nodes (one per sentence plus one per ingredient).

        Returns: (tree, n_nodes)
        '''
        instr = [x for x in instr.split('. ') if x]
        instr = treemaker.sents2tree(instr)
        if stem_only:
            instr = stem(instr)
        n_nodes = sum(len(line['ingredient']) + 1 for line in instr)
        return build_tree(instr), n_nodes

    def norm_dist(self, ori_instr, gen_instr, stem_only):
        '''
        Tree edit distance divided by the total node count of both trees.

        Args: ori_instr: str
        Args: gen_instr: str
        '''
        ori_tree, ori_nodes = self.str2tree(ori_instr, stem_only = stem_only)
        gen_tree, gen_nodes = self.str2tree(gen_instr, stem_only = stem_only)
        total_nodes = ori_nodes + gen_nodes
        if total_nodes == 0:
            # two empty instructions: nothing to edit, and avoids 0/0
            return 0.0
        tree_dist = tree_distance(ori_tree, gen_tree)
        normed = tree_dist/total_nodes
        return normed

time: 86.4 ms


In [3]:
# Reference recipes of the validation split.
data = evaluation('../../to_gpt2/recipe1M_1118/val/y/', 'ori')

# One generation directory per decoding setting; the tag doubles as the
# directory infix (presumably k* = top-k and p99 = nucleus sampling -- confirm).
for tag in ('k1', 'k3', 'k5', 'k10', 'k30', 'p99'):
    data.append_dic('../../to_gpt2/generation_201911118_%s_val/' % tag, tag)

load ../../to_gpt2/recipe1M_1118/val/y/
load ../../to_gpt2/generation_201911118_k1_val/
load ../../to_gpt2/generation_201911118_k3_val/
load ../../to_gpt2/generation_201911118_k5_val/
load ../../to_gpt2/generation_201911118_k10_val/
load ../../to_gpt2/generation_201911118_k30_val/
load ../../to_gpt2/generation_201911118_p99_val/
time: 9.63 s


In [4]:
# Keep the per-tag scores: the methods only print them, and recomputing takes
# several minutes per tag.
verb_f1 = {}
for tag in data.gens:
    print(tag)
    data.gen = tag  # select which generation is compared against the reference
    verb_f1[tag] = data.verb_f1_freq()

  0%|          | 0/4000 [00:00<?, ?it/s]

k1


  self.warn()
100%|██████████| 4000/4000 [05:56<00:00, 17.09it/s]
  0%|          | 2/4000 [00:00<04:15, 15.67it/s]

0.3174163362428839
k3


100%|██████████| 4000/4000 [06:00<00:00, 10.85it/s]
  0%|          | 1/4000 [00:00<09:11,  7.25it/s]

0.3104074861916508
k5


100%|██████████| 4000/4000 [06:08<00:00, 10.17it/s]
  0%|          | 1/4000 [00:00<11:19,  5.88it/s]

0.2965441320713552
k10


100%|██████████| 4000/4000 [06:20<00:00,  9.32it/s]
  0%|          | 1/4000 [00:00<08:36,  7.75it/s]

0.28372376728385196
k30


100%|██████████| 4000/4000 [06:29<00:00, 13.45it/s]
  0%|          | 1/4000 [00:00<06:59,  9.53it/s]

0.2745738272235941
p99


100%|██████████| 4000/4000 [06:47<00:00,  9.82it/s]

0.25877475399872046
time: 37min 43s





In [5]:
# Keep the per-tag scores: the methods only print them, and recomputing takes
# several minutes per tag.
state_f1 = {}
for tag in data.gens:
    print(tag)
    data.gen = tag  # select which generation is compared against the reference
    state_f1[tag] = data.state_f1_freq()

  0%|          | 0/4000 [00:00<?, ?it/s]

k1


100%|██████████| 4000/4000 [05:56<00:00, 11.17it/s]
  0%|          | 1/4000 [00:00<08:58,  7.43it/s]

0.6121037329611493
k3


100%|██████████| 4000/4000 [06:02<00:00, 10.53it/s]
  0%|          | 1/4000 [00:00<09:19,  7.15it/s]

0.6055824848023792
k5


  self.warn()
100%|██████████| 4000/4000 [06:09<00:00, 10.52it/s]
  0%|          | 1/4000 [00:00<11:17,  5.90it/s]

0.5984877121427107
k10


100%|██████████| 4000/4000 [06:15<00:00,  9.52it/s]
  0%|          | 1/4000 [00:00<08:28,  7.87it/s]

0.59289367609037
k30


100%|██████████| 4000/4000 [06:23<00:00, 10.42it/s]
  0%|          | 1/4000 [00:00<10:36,  6.28it/s]

0.5864356461026095
p99


100%|██████████| 4000/4000 [06:45<00:00,  9.87it/s]

0.5716436389072083
time: 37min 33s





In [6]:
data.gen = 'k5'
# Bind both results: they take over an hour to compute, and a bare last
# expression would be echoed a second time as Out[...] after the method's print.
tree_dist_full = data.instr_tree(stem_only=False)
tree_dist_stem = data.instr_tree(stem_only=True)

100%|██████████| 4000/4000 [41:45<00:00,  1.29it/s]
  0%|          | 0/4000 [00:00<?, ?it/s]

0.5333438888732768


100%|██████████| 4000/4000 [24:00<00:00,  2.28it/s]

0.4884516793939745





0.4884516793939745

time: 1h 5min 45s
