In [1]:
import subprocess
from rouge import Rouge, FilesRouge
from torchnlp._third_party.lazy_loader import LazyLoader
from torchnlp.metrics import get_moses_multi_bleu

def full_moses_multi_bleu(original, generation):
    '''
    Run Moses ``multi-bleu.perl`` on two texts and return its summary line as numbers.

    Both strings are written to fixed scratch files under ``../../to_gpt2/``, the
    perl script (resolved relative to the current working directory) is executed,
    and the line of the form
    ``BLEU = b, b1/b2/b3/b4 (BP=.., ratio=.., hyp_len=.., ref_len=..)`` is parsed.

    Note: need to add space before punctuations

    NOTE(review): ``generation`` is passed to the script as its file argument and
    ``original`` is fed on stdin. multi-bleu.perl documents its usage as
    ``multi-bleu.perl reference < hypothesis``, so ``generation`` is scored as the
    reference here. The caller visible in this notebook passes ``(gen, ori)``,
    which compensates -- confirm before changing either side.

    :original: string, saved to ``generation_ori.txt`` and piped to the script's stdin
    :generation: string, saved to ``generation_gen.txt`` and passed as the script's argument
    :return: dict of floats keyed by BLEU, B1, B2, B3, B4, BP, ratio, hyp_len, ref_len
    :raises subprocess.CalledProcessError: when the perl script exits with a non-zero status
    :raises ValueError: when no BLEU summary line can be found in the script output
    '''
    # Imported locally because no cell of the notebook imports `re` explicitly.
    import re

    filereference = '../../to_gpt2/generation_ori.txt'
    filehypothesis = '../../to_gpt2/generation_gen.txt'
    save(filereference, original ,overwrite = True, print_ = False)
    save(filehypothesis, generation ,overwrite = True, print_ = False)

    bleu_cmd = ['perl', 'multi-bleu.perl', filehypothesis]
    # The file only needs to stay open while the subprocess reads it.
    with open(filereference, "r") as read_pred:
        bleu_out = subprocess.check_output(bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT)
    bleu_out = bleu_out.decode("utf-8")

    # One anchored pattern for the whole summary line; float() below validates each field.
    summary = re.search(
        r"BLEU = ([^,]+), ([^/]+)/([^/]+)/([^/]+)/(\S+) "
        r"\(BP=([^,]+), ratio=([^,]+), hyp_len=([^,]+), ref_len=([^)]+)\)",
        bleu_out)
    if summary is None:
        raise ValueError('could not find a BLEU summary in multi-bleu.perl output: %r' % bleu_out)

    keys = ('BLEU', 'B1', 'B2', 'B3', 'B4', 'BP', 'ratio', 'hyp_len', 'ref_len')
    return {key: float(value) for key, value in zip(keys, summary.groups())}

In [2]:
from dependency import parent_dir
from common.basics import *
from common.save import save_pickle, load_pickle, save
import spacy

In [3]:
from utils.evaluation import evaluation
from utils.metrics import metrics

In [4]:
# Notebook-level lookup used by ev.instr: a noun-chunk root lemma is kept only if it is `in database`.
# NOTE(review): load_pickle presumably unpickles this file; unpickling runs arbitrary code,
# so only point it at a trusted file -- confirm.
database = load_pickle('../big_data/database.pickle')

In [5]:
class ev(evaluation):
    """ load the generation results and ground truth, then calculate the specified metrics

    Methods that compare against generated text read ``self.gen``; it is not set by
    ``__init__`` and must be assigned by the caller (the notebook cells do
    ``data.gen = tag``) before calling them.
    """
    def __init__(self, filename, tag):
        '''
        Args:
          filename: A directory to the files
          tag: A name we assign to the directory, can be any string

        '''
        self.dic = self.load_dic({}, filename, tag)
        self.ori = tag
        self.gens = []
        self.sp = spacy.load('en_core_web_lg')
        self.rouge = Rouge()
        self.filesrouge = FilesRouge()

    def ingr_f1(self, root = True):
        '''
        Average the per-recipe ingredient F1 between ``self.ori`` and ``self.gen``.

        Args:
          root: when True, both ingredient lists are reduced with ``self.ingr`` first

        Returns:
          dict with 'ingredient_f1' (mean of ``metrics(true, pred).f1()``) and
          'average ingr number' (mean count of distinct predicted ingredients)
        '''
        value, number = [], []
        for v in tqdm.tqdm(self.dic.values()):
            true, pred = v['%s_ingr'%(self.ori)], v['%s_ingr'%(self.gen)]
            if root:
                true, pred = self.ingr(true), self.ingr(pred)
            scores = metrics(true, pred)
            value.append(scores.f1())
            number.append(len(set(pred)))
        avg = sum(value)/len(value)
        return {'ingredient_f1':avg, 'average ingr number':sum(number)/len(number) }

    def jaccard(self, generate = 'ingr'):
        '''
        Mean Jaccard similarity between the ground-truth ingredient roots and a second set.

        Args:
          generate: which second set to use
            'ingr'  -- roots of the generated ingredient list (``self.gen``)
            'instr' -- roots found by ``self.instr`` in the generated instructions (``self.gen``)
            'human' -- roots found by ``self.instr`` in the ground-truth instructions (``self.ori``)

        Returns:
          dict with the single key 'jaccard similarity'
        '''
        assert generate in ['ingr','instr','human']
        jaccard = []
        for v in tqdm.tqdm(self.dic.values()):
            true = self.ingr(v['%s_ingr'%(self.ori)])
            if generate == 'ingr':
                pred = self.ingr(v['%s_ingr'%(self.gen)])
            elif generate == 'instr':
                pred = self.instr(v['%s_instr'%(self.gen)])
            else:
                pred = self.instr(v['%s_instr'%(self.ori)])
            true, pred = set(true), set(pred)

            intersect = len(true & pred)
            union = len(true) + len(pred) - intersect
            # Both sets empty: score the pair as 0.0 instead of dividing by zero.
            jaccard.append(intersect / union if union else 0.0)

        return {'jaccard similarity': sum(jaccard)/len(jaccard)}

    def instr(self, directions):
        '''
        Collect lemmas of noun-chunk final tokens in ``directions`` that appear in the
        notebook-level ``database``.

        Args:
          directions: instruction text (string) to parse with the spaCy pipeline

        Returns:
          list of lemma strings, one per matching noun chunk (duplicates kept)
        '''
        # Use this instance's pipeline; the bare name `sp` is not defined in this notebook.
        instr = self.sp(directions)
        root_instr = []
        for chunk in instr.noun_chunks:
            idx_rootnoun = chunk.end - 1
            str_rootnoun = instr[idx_rootnoun].lemma_
            if str_rootnoun in database:
                root_instr.append(str_rootnoun)
        return root_instr

    def ingr(self, lst):
        '''
        Args: lst: a list of ingredient names
        used when len(lst) must equal to root_match

        Returns a list of the same length as ``lst``. Each entry is the lemma of the
        detected head word, or 'CANNOT_DETECT' when none is found.
        '''
        root_match = []
        for ingr in lst:
            if ' ' not in ingr:
                doc = self.sp(ingr)
                # An empty name yields an empty doc; keep the one-to-one length contract.
                root_match.append(doc[0].lemma_ if len(doc) else 'CANNOT_DETECT')
                continue

            # Embed the phrase in a sentence so the parser yields noun chunks for it.
            doc = self.sp('Mix the %s and water.'%ingr)
            last_chunk = None
            for chunk in doc.noun_chunks:
                if chunk.text != 'water':
                    last_chunk = chunk
            if not last_chunk:
                root_match.append('CANNOT_DETECT')
                continue

            # Accept the chunk's final token only if it occurs inside one of the
            # space-separated words of the original ingredient name.
            root_token = doc[last_chunk.end - 1]
            if any(root_token.text in word for word in ingr.split(' ')):
                root_match.append(root_token.lemma_)
            else:
                root_match.append('CANNOT_DETECT')

        assert len(root_match) == len(lst)
        return root_match

    @staticmethod
    def _run_and_print(cmd, stdin = None):
        ''' Run ``cmd`` (argument list, no shell) and print its combined stdout/stderr. '''
        try:
            done = subprocess.run(cmd, stdin=stdin, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        except OSError as err:
            # Report a missing executable instead of raising, as the shell would.
            print('%s: %s' % (cmd[0], err))
            return
        print(done.stdout.decode('utf-8', errors='replace'), end='')

    def to_bleu(self):
        '''
        Write ingredient ('_i') and instruction ('_d') text files for ``self.ori`` and
        ``self.gen``, then print multi-bleu.perl and ``rouge`` CLI results for each pair.

        multi-bleu.perl receives the ``self.ori`` file as its argument and the
        ``self.gen`` file on stdin; ``rouge -f`` receives the ``self.ori`` file first
        and the ``self.gen`` file second. Nothing is returned.
        '''
        to_write = {'%s_i'%(self.ori):[],
                    '%s_i'%(self.gen):[],
                    '%s_d'%(self.ori):[],
                    '%s_d'%(self.gen):[]}

        for v in self.dic.values():
            to_write['%s_i'%(self.ori)].append(self.add_space(' $ '.join(v['%s_ingr'%(self.ori)]))+ ' $ \n')
            to_write['%s_i'%(self.gen)].append(self.add_space(' $ '.join(v['%s_ingr'%(self.gen)])) + ' $ \n')

            to_write['%s_d'%(self.ori)].append(self.add_space(v['%s_instr'%(self.ori)])+ '\n')
            to_write['%s_d'%(self.gen)].append(self.add_space(v['%s_instr'%(self.gen)])+ '\n')

        for k, lines in to_write.items():
            save('../../to_gpt2/generation_%s.txt'%(k), ''.join(lines) ,overwrite = True)

        # (ori file, gen file) for ingredients first, then directions.
        pairs = [('../../to_gpt2/generation_%s_%s.txt'%(self.ori, part),
                  '../../to_gpt2/generation_%s_%s.txt'%(self.gen, part))
                 for part in ('i', 'd')]

        for ori_path, gen_path in pairs:
            with open(gen_path, 'r') as gen_file:
                self._run_and_print(['perl', 'multi-bleu.perl', ori_path], stdin=gen_file)

        for ori_path, gen_path in pairs:
            self._run_and_print(['rouge', '-f', ori_path, gen_path, '--avg'])

        print()

    def full_bleu(self):
        '''
        Compute BLEU statistics and ROUGE-L F for the instructions of ``self.gen``
        against those of ``self.ori``.

        Returns:
          the dict produced by ``full_moses_multi_bleu`` plus the key 'R-L'
        '''
        ori, gen = [], []
        for v in self.dic.values():
            ori.append(self.add_space(v['%s_instr'%(self.ori)])+ '\n')
            gen.append(self.add_space(v['%s_instr'%(self.gen)])+ '\n')
        ans = full_moses_multi_bleu(''.join(gen), ''.join(ori))

        # full_moses_multi_bleu saves its first argument to generation_ori.txt and its
        # second to generation_gen.txt. It is called with (gen, ori) above, so despite
        # the file names generation_ori.txt holds the generated text and
        # generation_gen.txt holds the ground truth. FilesRouge.get_scores takes the
        # hypothesis path first, so pass the file that actually contains the generation.
        generated_path = '../../to_gpt2/generation_ori.txt'
        truth_path = '../../to_gpt2/generation_gen.txt'
        scores = self.filesrouge.get_scores(generated_path, truth_path, avg = True)
        ans.update({'R-L': scores['rouge-l']['f']})
        return ans

## Table 1a

In [None]:
# Validation split: ground truth under tag 'ori', one generation run under tag 'k1'.
data = ev('../../to_gpt2/recipe1M_1218/val/y/', 'ori')
data.append_dic('../../to_gpt2/val/generation_1220_k1_val/', 'k1')
#data.append_dic('../../to_gpt2/val/generation_1220_k3_val/', 'k3')
#data.append_dic('../../to_gpt2/val/generation_1220_k5_val/', 'k5')
#data.append_dic('../../to_gpt2/val/generation_1220_k10_val/', 'k10')
#data.append_dic('../../to_gpt2/val/generation_1220_k30_val/', 'k30')

# One row of metrics per generation tag; the three metric dicts are merged left to right.
results = {}
for tag in data.gens:
    print(tag)
    data.gen = tag
    ans = {
        **data.full_bleu(),
        **data.ingr_f1(root=True),
        **data.instr_tree(stem_only = False),
    }
    results[tag] = ans

display(pd.DataFrame(results).T)

load ../../to_gpt2/recipe1M_1218/val/y/
load ../../to_gpt2/val/generation_1220_k1_val/
k1


  5%|▍         | 197/4000 [01:08<19:32,  3.24it/s]

## Table 1b

In [18]:
# Test split: ground truth under tag 'ori' plus three generation runs.
data = ev('../../to_gpt2/recipe1M_1218/test/y/', 'ori')
data.append_dic('../../to_gpt2/generation_1220_k3_test/', '117M')
data.append_dic('../../to_gpt2/generation_scratch_k3_test/', 'scratch')
data.append_dic('../../to_gpt2/generation_medium_k3_test/', '345M')

# Keep each tag's scores: a bare call inside a loop is neither stored nor displayed.
ingr_scores = {}
for tag in data.gens:
    print(tag)
    data.gen = tag
    ingr_scores[tag] = data.ingr_f1(root=True)
ingr_scores

load ../../to_gpt2/recipe1M_1218/test/y/


KeyboardInterrupt: 

In [None]:
# Reuses `data` from an earlier cell. For every generation tag, select it via
# data.gen and let to_bleu() write the text files and print the BLEU / ROUGE output.
for tag in data.gens:
    print(tag)
    data.gen = tag  # to_bleu() reads self.gen to pick the generated fields
    data.to_bleu()

In [None]:
# Reuses `data` from an earlier cell and calls instr_tree() once per generation tag.
# NOTE(review): instr_tree is not defined in this notebook's `ev` class (presumably
# inherited from `evaluation`); its return value is discarded here -- confirm that it
# prints what is needed.
for tag in data.gens:
    print(tag)
    data.gen = tag
    data.instr_tree(stem_only = False)

## Jaccard similarity

In [None]:
data = ev('../../to_gpt2/recipe1M_1218/test/y/', 'ori')
data.append_dic('../../to_gpt2/generation_1220_k3_test/', '117M')
# jaccard() reads data.gen for the 'ingr' and 'instr' modes; select the run explicitly,
# as the other cells do, instead of relying on leftover kernel state.
data.gen = '117M'
# Collect all three modes so every result is displayed, not only the last expression.
{mode: data.jaccard(generate = mode) for mode in ('ingr', 'instr', 'human')}