In [56]:
import os
import re
import itertools
import pandas as pd

from bert_score import BERTScorer
from collections import defaultdict

from src.config import substitution_rating_file, preference_rating_file
from src.config import preference_rating_scores
from src.config import label_file, label_file_full, weighing_scheme
from src.config import BERT_F1_dict_file, lst_ref
from src.io import load_pickle, save_pickle

import warnings
warnings.filterwarnings('ignore')

In [57]:
import time
start = time.time()

# BERT scores: reuse the pickled similarity cache when the file exists,
# otherwise start from an empty cache.
sim_dict = load_pickle(BERT_F1_dict_file) if os.path.exists(BERT_F1_dict_file) else {}
print(len(sim_dict))

# Invert the imported mapping: every former value becomes a key that maps to
# the text after the last "__" of its former key, with "_" turned into spaces.
# NOTE(review): this rebinds `lst_ref` from its own previous contents, so the
# cell presumably cannot be re-run without re-importing it first -- confirm.
lst_ref = {
    ref_id: full_name.split('__')[-1].replace('_', ' ')
    for full_name, ref_id in lst_ref.items()
}

end = time.time()
print(end - start)

43678
3.457052707672119


In [58]:
def process_food_name(s1):
    """Return the cleaned food name of a raw item string, without its quantity.

    The quantity separator is ", " followed by an integer, a decimal (with "."
    or ","), or a fraction, followed by a space. It is searched in a copy of
    the string whose parenthesised parts were blanked out (only when the
    parentheses are balanced), so numbers inside parentheses are ignored.

    Args:
        s1: raw item description (str).

    Returns:
        ``clean_name`` of everything in ``s1`` before the first separator, or
        ``clean_name`` of the parenthesis-stripped string when no separator
        is found.
    """
    # separators: ", " + any of (integer, decimal & fraction) +" "
    exp = r", \d+\.\d+ |, \d+\,\d+ |, \d+ |, \d+\/\d+ "
    # remove content in parenthesis for finding the separator
    if s1.count('(') == s1.count(')'):
        s2 = re.sub(r'[(].*?[\)]', ' ', s1)
    else:
        s2 = s1

    # Test for a match explicitly instead of indexing an empty findall()
    # result inside a bare ``except``, which would also hide unrelated errors.
    separator = re.search(exp, s2)
    if separator is None:
        return clean_name(s2)
    # The separator was located in s2 but the split is done on the original
    # string, so parenthesised text before the quantity is kept in the name.
    return clean_name(s1.split(separator.group(0))[0])

def clean_name(name):
    """Normalise a food name.

    Tabs and newlines become spaces, "w/o" becomes " no ", "w/" becomes a
    space; the result is stripped, runs of spaces are collapsed to one, and
    the text is lower-cased. The "w/o" / "w/" replacements run before
    lower-casing, so they only affect lower-case occurrences.
    """
    name = name.replace("\t", " ").replace("\n", " ").replace("w/o", " no ").replace("w/", " ")
    return re.sub(' +', ' ', name.strip()).lower()

In [59]:
import time
start = time.time()

# ground truth: one row per distinct (user, item_1) pair of the substitution
# ratings, plus the cleaned item name
df2 = (
    pd.read_csv(substitution_rating_file)[['user', 'item_1']]
    .drop_duplicates()
    .reset_index(drop=True)
)
df2['item_gt'] = df2['item_1'].apply(process_food_name)

# cleaned ground-truth names collected into one list per user
gt = df2.groupby('user')['item_gt'].apply(list).to_dict()

# predictions: cleaned name of every choice, with the user's ground-truth list
# attached (users that are not keys of `gt` get a missing value from .map)
df = pd.read_csv(preference_rating_file)
df['item'] = df['choice'].apply(process_food_name)
df['gt'] = df['user'].map(gt)

end = time.time()
print(end - start)

0.03760385513305664


In [60]:
labels = load_pickle(label_file)

# Pair each row's 'cat_info' entries with its 'label_summary' entries and merge
# all pairs into one name -> id dict (a later row overwrites an earlier one).
# assumes both columns hold parallel sequences per row -- TODO confirm
concat_list = list(zip(labels['cat_info'], labels['label_summary']))
label_index = {}
for cat_names, cat_ids in concat_list:
    for cat_name, cat_id in zip(cat_names, cat_ids):
        label_index[cat_name] = cat_id

# reverse lookup: id -> name
index_label = {cat_id: cat_name for cat_name, cat_id in label_index.items()}

# '_' connected tokens for l0 tags: the last "__" segment of every name that
# has exactly three "__"-separated segments, mapped to that name's id
label_index_l0 = {}
for cat_name, cat_id in label_index.items():
    segments = cat_name.split('__')
    if len(segments) == 3:
        label_index_l0[segments[-1]] = cat_id
l0_tags = sorted(label_index_l0)

label_with_food = defaultdict(list)
def group_food_name(line):
    """Append the row's 'food_name' to the list of every id in its 'label_summary'.

    Works purely by side effect on the module-level ``label_with_food``;
    returns None.
    """
    s = line['food_name']
    for i in line['label_summary']:
        label_with_food[i].append(s)

# Walk the rows explicitly instead of labels.apply(group_food_name, axis=1):
# the function only has side effects, and pandas releases before 1.1 evaluated
# the first row twice inside DataFrame.apply, which would record that row's
# food name twice.
for _, _row in labels.iterrows():
    group_food_name(_row)

# l0 tag -> list of food names recorded under that tag's id
# (label_with_food is a defaultdict, so an unseen id yields an empty list)
label_name_l0 = {tag: label_with_food[tag_id] for tag, tag_id in label_index_l0.items()}

# food name -> its 'cat_info' entry
name_labels = labels.set_index('food_name')['cat_info'].to_dict()

# food name -> last "__" segment of each of its names that has exactly three
# "__"-separated segments
name_label_0 = {}
for food, cat_names in name_labels.items():
    segment_lists = (cat_name.split('__') for cat_name in cat_names)
    name_label_0[food] = [segments[-1] for segments in segment_lists if len(segments) == 3]

# filled by the matching loop that follows
matched_label_l0 = []
unmatched_label_l0 = []

def token_transform(t):
    """Return the token together with naive plural spellings of it.

    Args:
        t: a single token (str); may be empty.

    Returns:
        ``[t, t + 's', t + 'es']``, followed by the "-ies" form (final "y"
        replaced) when ``t`` ends in "y".
    """
    tokens = [t, t + 's', t + 'es']
    # endswith() instead of t[-1]: an empty token (e.g. produced by splitting
    # a label with a leading or trailing "_") must not raise IndexError.
    if t.endswith('y'):
        tokens.append(t[:-1] + 'ies')
    return tokens

# A tag counts as matched when some spelling variant of it (every combination
# of the per-token variants, joined by spaces) is a substring of at least one
# normalised food name recorded for that tag.
for l, s_lst in label_name_l0.items():
    token_variants = [token_transform(t) for t in l.split('_')]
    l_primes = [' '.join(combo) for combo in itertools.product(*token_variants)]

    # drop apostrophes, turn '&' into a space and collapse whitespace
    normalised = [' '.join(name.replace("'", '').replace('&', ' ').split()) for name in s_lst]
    matched = any(l_prime in name for name in normalised for l_prime in l_primes)

    (matched_label_l0 if matched else unmatched_label_l0).append(l)

n_unmatched = len(unmatched_label_l0)
perc = n_unmatched / (len(matched_label_l0) + n_unmatched)
if perc < 0.05:
    print('Less than 5% labels are not matched: {:.2%}'.format(perc))

# get all labels associated with item
def match_labels(s):
    """Return the sorted, de-duplicated label names along the branches of an item.

    If ``s`` is a key of ``name_label_0`` its l0 tags are taken from there.
    Otherwise ``s`` is normalised (apostrophes removed, '&' replaced by a
    space, whitespace collapsed) and every tag in ``matched_label_l0`` with a
    spelling variant contained in it is collected.

    For each l0 tag the full name is looked up through ``label_index_l0`` and
    ``index_label``; the result contains its first "__" segment, its first
    two segments joined by "__", and the full name itself.
    """
    if s in name_label_0:
        found = name_label_0[s]
    else:
        s = ' '.join(s.replace("'", '').replace('&', ' ').split())
        found = []
        for l in matched_label_l0:
            token_variants = [token_transform(t) for t in l.split('_')]
            if any(' '.join(combo) in s for combo in itertools.product(*token_variants)):
                found.append(l)

    # full label names, resolved up front so a missing key fails before any
    # branch is expanded
    full_label_l0 = [index_label[label_index_l0[l0_label]] for l0_label in found]
    all_labels = set()
    for l0 in full_label_l0:
        segments = l0.split('__')
        all_labels.update((segments[0], '__'.join(segments[:2]), l0))
    return sorted(all_labels)

# label_score() looks labels up in this inverse dict (id -> text) instead of lst_ref.
# Two variants: 1) using all tags along the branch; 2) using only the last tag
# Variant 1 (active): both "__" and "_" in the full name become single spaces.
inv_label_index = {
    label_id: name.replace("__", " ").replace("_", " ")
    for name, label_id in label_index.items()
}
# Variant 2 (disabled), use only the last tag:
# inv_label_index = {v: " ".join(k.split("__")[-1].split("_")) for k, v in label_index.items()}

Less than 5% labels are not matched: 1.33%


In [61]:
def _label_ids(names):
    """Translate a list of label names into their ids via ``label_index``."""
    return [label_index[name] for name in names]

# predictions: label names and their ids for every recommended item
df.loc[:, 'label'] = df['item'].apply(match_labels)
df['label_summary'] = df['label'].apply(_label_ids)

# ground truth: same two columns, then keep only the columns used below
df2.loc[:, 'label'] = df2['item_gt'].apply(match_labels)
df2['label_summary'] = df2['label'].apply(_label_ids)
df2 = df2[['user', 'item_1', 'label', 'label_summary']]

# each prediction row gets the list of its user's ground-truth id lists
gt_lst = df2.groupby('user')['label_summary'].apply(list).to_dict()
df['l2'] = df['user'].map(gt_lst)

# each ground-truth row is paired (outer join on user) with the list of
# prediction id lists of every (user, qn) group
rec_label_lists = df.groupby(['user', 'qn'])['label_summary'].apply(list).reset_index()
rec_label_lists = rec_label_lists.rename(columns={'label_summary': 'l2'})
df1 = df2.merge(rec_label_lists, on='user', how='outer')

In [62]:
labels = load_pickle(label_file_full)
# sorted unique entries over all lists in the 'labels' column
lst = sorted(set(itertools.chain.from_iterable(labels['labels'].tolist())))
# position in the sorted list -> last "__" segment with "_" turned into spaces
lst_ref = {
    position: name.split('__')[-1].replace('_', ' ')
    for position, name in enumerate(lst)
}

# BERT

In [63]:
import time
start = time.time()
# Build the BERTScore scorer once; the timing printed below covers only this
# construction step.
scorer = BERTScorer(
    lang="en",
    rescale_with_baseline=True,
)
end = time.time()
print(end - start)

Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


14.808554649353027


In [64]:
# def get_scores(single_cands, multi_refs):
#     P_mul, R_mul, F_mul = scorer.score([single_cands], [multi_refs])
#     return F_mul[0]

# def sim_score(l0, l1, sim_dict=sim_dict):
#     l0, l1 = tuple(sorted([l0, l1]))
#     if (l0, l1) in sim_dict.keys():
#         return sim_dict[(l0, l1)]
#     else:
#         if l0 == l1: 
#             val = 1
#         else:
#             val = get_scores(l0, l1)
#         sim_dict[(l0, l1)] = val
#         return float(val)
    
# def label_score(c_cap_j, c_cap_j_prime, weight, sim_dict=sim_dict):
#     numerator = 0.0
#     denominator = 0.0

#     for c_t in c_cap_j:
    
#         lambda_t = weight[c_t]
#         denominator += lambda_t

#         max_val = -1
#         for c_s in c_cap_j_prime:
#             val = sim_score(lst_ref[c_t], lst_ref[c_s], sim_dict=sim_dict)
#             if val > max_val:
#                 max_val = val
#         numerator += lambda_t * max_val
        
#     if numerator == 0.0:
#         return -1
        
#     return numerator / denominator  

def get_scores(single_cands, multi_refs):
    """Score one candidate against one reference with the module-level ``scorer``.

    Both arguments are wrapped in one-element lists for ``scorer.score``; the
    first entry of the third returned value (named F in the original code) is
    returned as-is, without conversion.
    """
    _precision, _recall, f_measure = scorer.score([single_cands], [multi_refs])
    return f_measure[0]

def sim_score(l0, l1, sim_dict=sim_dict):
    """Return the cached-or-computed similarity of two label strings as a float.

    The pair is order-independent: the cache key is the sorted tuple of the
    two strings. Identical strings score 1.0 without calling the scorer;
    otherwise ``get_scores`` is called on the lower-cased strings.

    Args:
        l0, l1: label texts (str).
        sim_dict: cache mapping ``(str, str)`` to a score; updated in place on
            a cache miss.

    Returns:
        float. Previously a cache hit returned the raw stored object (a
        scorer result or the int 1) while a miss returned ``float``; both
        paths now return ``float``, and new cache entries are stored as plain
        floats. Entries written by the earlier code are still readable
        because they are converted on the way out.
    """
    key = tuple(sorted([l0, l1]))

    if key not in sim_dict:
        first, second = key
        if first == second:
            sim_dict[key] = 1.0
        else:
            sim_dict[key] = float(get_scores(first.lower(), second.lower()))
    return float(sim_dict[key])
    
def label_score(c_cap_j, c_cap_j_prime, weight, sim_dict=sim_dict):
    """Weighted average, over the labels of ``c_cap_j``, of their best match in ``c_cap_j_prime``.

    For every label id in ``c_cap_j`` the highest ``sim_score`` against any
    label id in ``c_cap_j_prime`` is taken (never lower than -1), using the
    texts from ``inv_label_index``. These best scores are averaged with the
    per-label weights ``weight[label_id]``.

    Returns 0 when either collection is empty.
    """
    if len(c_cap_j) == 0 or len(c_cap_j_prime) == 0:
        return 0

    weighted_sum = 0.0
    weight_total = 0.0
    for c_t in c_cap_j:
        lambda_t = weight[c_t]
        weight_total += lambda_t

        text_t = inv_label_index[c_t]
        # -1 is the floor; later candidates replace it only when strictly greater
        best = -1
        for c_s in c_cap_j_prime:
            best = max(best, sim_score(text_t, inv_label_index[c_s], sim_dict=sim_dict))

        weighted_sum += lambda_t * best

    return weighted_sum / weight_total

def pair_sim_scores(j, j_prime, weight, sim_dict):
    """Similarity of two label collections.

    Returns 0 if either collection is empty, 1 if they compare equal, and
    otherwise ``label_score(j, j_prime, weight, sim_dict=sim_dict)`` converted
    to float.
    """
    # guard clauses first: nothing to compare, or trivially identical
    if len(j) == 0 or len(j_prime) == 0:
        return 0
    if j_prime == j:
        return 1
    return float(label_score(j, j_prime, weight, sim_dict=sim_dict))

# Debugging

In [65]:
# Debugging

weighing = 'equal'
col = 'hSim-1'

# The default for `weight` is looked up when this def statement runs, so the
# function has to be re-defined after `weighing` changes.
def lst_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    """Best pair_sim_scores of the row's 'label_summary' against each entry of its 'l2'.

    Starts from 0, so the result is never below 0.
    """
    own_labels = line['label_summary']
    best = 0
    for other_labels in line['l2']:
        best = max(best, pair_sim_scores(own_labels, other_labels, weight, sim_dict))
    return best

# df1[col] = df1.apply(lst_sim_scores, axis=1)
# df1.sort_values(by=["user", "qn", "item_1"])[0:15]

In [66]:
weighing = 'equal'
col = 'hSim-1'

# The default for `weight` is looked up when this def statement runs, which is
# why the function is defined again right after `weighing` is set.
def lst_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    """Best pair_sim_scores of the row's 'label_summary' against each entry of its 'l2'.

    Starts from 0, so the result is never below 0.
    """
    own_labels = line['label_summary']
    best = 0
    for other_labels in line['l2']:
        best = max(best, pair_sim_scores(own_labels, other_labels, weight, sim_dict))
    return best

import time
start = time.time()
# score every row of both frames under the 'equal' weighting
for frame in (df, df1):
    frame[col] = frame.apply(lst_sim_scores, axis=1)
end = time.time()
print(end - start)

5.8023903369903564


In [67]:
weighing = '124'
col = 'hSim-2'

# The default for `weight` is looked up when this def statement runs, which is
# why the function is defined again right after `weighing` is set.
def lst_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    """Best pair_sim_scores of the row's 'label_summary' against each entry of its 'l2'.

    Starts from 0, so the result is never below 0.
    """
    own_labels = line['label_summary']
    best = 0
    for other_labels in line['l2']:
        best = max(best, pair_sim_scores(own_labels, other_labels, weight, sim_dict))
    return best

import time
start = time.time()
# score every row of both frames under the '124' weighting
for frame in (df, df1):
    frame[col] = frame.apply(lst_sim_scores, axis=1)
end = time.time()
print(end - start)

6.216416835784912


In [68]:
weighing = 'freq'
col = 'hSim-freq'

# The default for `weight` is looked up when this def statement runs, which is
# why the function is defined again right after `weighing` is set.
def lst_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    """Best pair_sim_scores of the row's 'label_summary' against each entry of its 'l2'.

    Starts from 0, so the result is never below 0.
    """
    own_labels = line['label_summary']
    best = 0
    for other_labels in line['l2']:
        best = max(best, pair_sim_scores(own_labels, other_labels, weight, sim_dict))
    return best

import time
start = time.time()
# score every row of both frames under the 'freq' weighting
for frame in (df, df1):
    frame[col] = frame.apply(lst_sim_scores, axis=1)
end = time.time()
print(end - start)

6.044835805892944


In [69]:
# precision: per-recommendation scores, written with the hP-* column names
cols = ['user', 'qn', 'choice', 'rating', 'hSim-1', 'hSim-2', 'hSim-freq']
d = df[cols].rename(columns={
    'choice': 'rec_item',
    'hSim-1': 'hP-Sim-1',
    'hSim-2': 'hP-Sim-2',
    'hSim-freq': 'hP-Sim-freq',
})
d.to_csv(preference_rating_scores['hP-Sim'], index=False)

In [70]:
# recall: per-ground-truth-item scores, written with the hR-* column names
cols = ['user', 'qn', 'item_1', 'hSim-1', 'hSim-2', 'hSim-freq']
d = df1[cols].rename(columns={
    'item_1': 'gt_item',
    'hSim-1': 'hR-Sim-1',
    'hSim-2': 'hR-Sim-2',
    'hSim-freq': 'hR-Sim-freq',
})
d.to_csv(preference_rating_scores['hR-Sim'], index=False)

In [71]:
# Persist the similarity cache for later runs; an empty cache is not written.
if sim_dict:
    save_pickle(BERT_F1_dict_file, sim_dict)