In [8]:
import os
import re
import itertools
import numpy as np
import pandas as pd

from bert_score import BERTScorer
from scipy.stats import zscore
from collections import defaultdict

from src.config import substitution_rating_file, substitution_rating_scores, substitution_score_file 
from src.config import label_file, substitution_rating_scores, weighing_scheme
from src.config import BERT_F1_dict_file, lst_ref
from src.io import load_pickle, save_pickle

import warnings
warnings.filterwarnings('ignore')

In [9]:
# BERT scores
if os.path.exists(BERT_F1_dict_file):
    sim_dict = load_pickle(BERT_F1_dict_file)
else:
    sim_dict = {}

print(len(sim_dict))
lst_ref = {v: k.split('__')[-1].replace('_', ' ') for k, v in lst_ref.items()}

43678


In [10]:
# some name is written in the users' format for easy user reference
def process_food_name(s1):
    # separators: ", " + any of (integer, decimal & fraction) +" "
    exp = r", \d+\.\d+ |, \d+\,\d+ |, \d+ |, \d+\/\d+ "
    # remove content in parenthesis for finding the separator
    if s1.count('(') == s1.count(')'):
        s2 = re.sub(r'[(].*?[\)]', ' ', s1)
    else:
        s2 = s1
    try:
        split_by = re.findall(exp, s2)[0]
        return clean_name(s1.split(split_by)[0])
    except:
        return clean_name(s2)

def clean_name(name):
    name = name.replace("\t", " ").replace("\n", " ").replace("w/o", " no ").replace("w/", " ")
    return re.sub(' +', ' ', name.strip()).lower()

def token_transform(t):
    tokens = [t, t+'s', t+'es']
    if t[-1] == 'y':
        tokens.append(t[:-1]+'ies')
    return tokens  

In [11]:
df = pd.read_csv(substitution_rating_file)
df['item_10'] = df['item_1'].apply(process_food_name)
df['item_20'] = df['item_2'].apply(process_food_name)

In [12]:
# index -> label
labels = load_pickle(label_file)
label_with_food = defaultdict(list)

def group_food_name(line):
    s = line['food_name']
    for i in line['label_summary']:
        label_with_food[i].append(s)

labels.apply(group_food_name, axis=1)

l0 = set(i for j in labels['label_summary'].tolist() for i in j)
# labels['max'] = labels['label_summary'].apply(lambda s: max(s) if len(s)>0 else 0)

# '_' connected tokens for l0 tags
concat_list = list(zip(labels['cat_info'], labels['label_summary']))
label_index = dict()

for i,j in concat_list:
    label_index.update(zip(i,j))  

index_label = {k:v for v,k in label_index.items()}
label_index_l0 = {i.split('__')[-1]:v for i,v in label_index.items() if len(i.split('__'))==3}
l0_tags = sorted(label_index_l0.keys())

label_name_l0 = {}

for l,i in label_index_l0.items():
    label_name_l0[l] = label_with_food[i]

name_labels = labels.set_index('food_name').to_dict()['cat_info']
name_label_0 = {k:[s.split('__')[-1] for s in v if len(s.split('__'))==3] for k,v in name_labels.items()}

matched_label_l0, unmatched_label_l0 = [], []

for l, s_lst in label_name_l0.items():    
    l_primes = [token_transform(t) for t in l.split('_')]
    l2 = list(itertools.product(*l_primes))
    l_primes = [' '.join(i) for i in l2]  
 
    matched = False
    for s in s_lst:
        s = ' '.join(s.replace("'", '').replace('&', ' ').split())
        for l_prime in l_primes:
            if l_prime in s:            
                matched = True
            
    if matched:
        matched_label_l0.append(l)
    else:
        unmatched_label_l0.append(l)        

# get all labels associated with item
def match_labels(s):
    if s in name_label_0.keys():
        found = name_label_0[s] 
    else:
        found = []
        s = ' '.join(s.replace("'", '').replace('&', ' ').split())
        for l in matched_label_l0:
            matched = False
            l_primes = [token_transform(t) for t in l.split('_')]
            for l_prime in [' '.join(i) for i in list(itertools.product(*l_primes))]:
                if l_prime in s:  
                    matched = True
            if matched:
                found.append(l) 
                        
    # full label names
    all_labels = []
    if len(found) > 0:
        full_label_l0 = [index_label[label_index_l0[l0_label]] for l0_label in found]
        for l0 in full_label_l0:
            l2 = l0.split('__')[0]
            l1 = '__'.join(l0.split('__')[:2])
            all_labels.extend([l2, l1, l0])
    return sorted(set(all_labels))

# we will use the inverse dict later in label_score() instead of lst_ref
# Two variants: 1) using all tags along the branch; 2) using only the last tag
inv_label_index = {v: " ".join(k.replace("__", " ").split("_")) for k, v in label_index.items()} # use all tags along the branch
# inv_label_index = {v: " ".join(k.split("__")[-1].split("_")) for k, v in label_index.items()} # use only the last tag

In [13]:
import time
start = time.time()
df['item_l1'] = df['item_10'].apply(match_labels)
df['item_l2'] = df['item_20'].apply(match_labels)
df['l1'] = df['item_l1'].apply(lambda s: [label_index[i] for i in s] if len(s)>0 else [])
df['l2'] = df['item_l2'].apply(lambda s: [label_index[i] for i in s] if len(s)>0 else [])
end = time.time()
print(end - start)

21.09162139892578


# BERT

In [14]:
import time
start = time.time()
from bert_score import BERTScorer
scorer = BERTScorer(lang="en", rescale_with_baseline=True)
end = time.time()
print(end - start)

Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


13.750503540039062


In [15]:
def get_scores(single_cands, multi_refs):
    P_mul, R_mul, F_mul = scorer.score([single_cands], [multi_refs])
    return F_mul[0]

def sim_score(l0, l1, sim_dict=sim_dict):
    l0, l1 = tuple(sorted([l0, l1]))
    
    if (l0, l1) in sim_dict.keys():
        return sim_dict[(l0, l1)]
    else:
        if l0 == l1: 
            val = 1
        else:
            val = get_scores(l0.lower(), l1.lower())
        sim_dict[(l0, l1)] = val
        return float(val)
    
def label_score(c_cap_j, c_cap_j_prime, weight, sim_dict=sim_dict):
    if len(c_cap_j)==0 or len(c_cap_j_prime)==0:
        return 0
        #return np.nan
    
    numerator = 0.0
    denominator = 0.0

    for c_t in c_cap_j:
        lambda_t = weight[c_t]
        denominator += lambda_t
        max_val = -1
       
        for c_s in c_cap_j_prime:
            #val = sim_score(lst_ref[c_t], lst_ref[c_s], sim_dict=sim_dict)
            val = sim_score(inv_label_index[c_t], inv_label_index[c_s], sim_dict=sim_dict)
            #debug
            print("'%s', '%s', %.4f" % (inv_label_index[c_t], inv_label_index[c_s],val.item()))
            if val > max_val:
                max_val = val
        
        numerator += lambda_t * max_val
    
    return numerator / denominator

In [45]:
import spacy
nlp = spacy.load('en_core_web_lg')

In [53]:
s1 = "a quick brown fox jumps over the lazy dog"
s2 = "dog the bounty hunter gave an interview to fox news"
s1 = "I bought cheese burger from my mom's favorite restaurant"
s2 = "beef sandwich is made by my mother"
print(nlp(s1).similarity(nlp(s2)))

print(get_scores(s1, s2).item())

0.8579623458539996
0.42279067635536194


In [69]:
# Debugging
weighing='equal'

def debug_pair_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):    
    if(len(line['l1'])==0) or (len(line['l2'])==0):
        return 0
    
    j = line['l1']
    j_prime = line['l2']
    
    if j_prime == j:
        return 1
    else:
        curr_score = label_score(j_prime, j, weight, sim_dict=sim_dict)
    return float(curr_score)

temp_df = df[39:40]
print(temp_df.apply(debug_pair_sim_scores, axis=1))
temp_df

tensor(-0.0785)
'fruit', 'herb spice', 0.1982
'fruit', 'herb spice peppers', 0.1223
'fruit', 'herb spice peppers pepper', 0.0354
'fruit', 'meat', 0.6130
'fruit', 'meat sausage', 0.3689
'fruit', 'meat sausage pepperoni', 0.1651
'fruit', 'staple', 0.2437
'fruit', 'staple wheat', 0.0813
'fruit', 'staple wheat pizza', 0.0007
'fruit temperate', 'herb spice', 0.2265
'fruit temperate', 'herb spice peppers', 0.1227
'fruit temperate', 'herb spice peppers pepper', 0.0419
'fruit temperate', 'meat', 0.3063
'fruit temperate', 'meat sausage', 0.2764
'fruit temperate', 'meat sausage pepperoni', 0.2212
'fruit temperate', 'staple', 0.2162
'fruit temperate', 'staple wheat', 0.0788
'fruit temperate', 'staple wheat pizza', 0.0517
'fruit temperate apple', 'herb spice', 0.2460
'fruit temperate apple', 'herb spice peppers', 0.2058
'fruit temperate apple', 'herb spice peppers pepper', 0.1837
'fruit temperate apple', 'meat', 0.2252
'fruit temperate apple', 'meat sausage', 0.3011
'fruit temperate apple', 'meat 

Unnamed: 0,user,item_1,item_2,rating,item_10,item_20,item_l1,item_l2,l1,l2
39,1,Costco - Pepperoni Pizza,apple,1,costco - pepperoni pizza,apple,"[herb_spice, herb_spice__peppers, herb_spice__...","[fruit, fruit__temperate, fruit__temperate__ap...","[726, 744, 746, 762, 841, 847, 1103, 1147, 1199]","[661, 680, 681]"


In [14]:
weighing='equal'
col = 'hSim-1'

def pair_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    if(len(line['l1'])==0) or (len(line['l2'])==0):
        return 0
    
    j = line['l1']
    j_prime = line['l2']

    if j_prime == j:
        return 1
    else:
        curr_score = label_score(j_prime, j, weight, sim_dict=sim_dict)
    return float(curr_score)

import time
start = time.time()
df[col] = df.apply(pair_sim_scores, axis=1)
end = time.time()
print(end - start)

2772.4561104774475


In [15]:
weighing='124'
col = 'hSim-2'

def pair_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    if(len(line['l1'])==0) or (len(line['l2'])==0):
        return 0
    
    j = line['l1']
    j_prime = line['l2']

    if j_prime == j:
        return 1
    else:
        curr_score = label_score(j_prime, j, weight, sim_dict=sim_dict)
    return float(curr_score)

import time
start = time.time()
df[col] = df.apply(pair_sim_scores, axis=1)
end = time.time()
print(end - start)

1.191934585571289


In [16]:
weighing='freq'
col = 'hSim-freq'

def pair_sim_scores(line, weight=weighing_scheme[weighing], sim_dict=sim_dict):
    if(len(line['l1'])==0) or (len(line['l2'])==0):
        return 0

    j = line['l1']
    j_prime = line['l2']

    if j_prime == j:
        return 1
    else:
        curr_score = label_score(j_prime, j, weight, sim_dict=sim_dict)
    return float(curr_score)

import time
start = time.time()
df[col] = df.apply(pair_sim_scores, axis=1)
end = time.time()
print(end - start)

1.0926084518432617


In [17]:
for col in ['hSim-1', 'hSim-2', 'hSim-freq']:
#     df.loc[df[col]==-1.0, col] = np.nan
    df.loc[df[col]==-1.0, col] = 0.0
cols = ['user', 'item_1', 'item_2', 'rating', 'hSim-1', 'hSim-2', 'hSim-freq']
df = df[cols]
filename = substitution_rating_scores['hSim']
df.to_csv(filename, index=False)

In [None]:
save_pickle(BERT_F1_dict_file, sim_dict)

# Normalization

In [18]:
cols = ['rating_z','hSim-1', 'hSim-2', 'hSim-freq']
dfs = []
for u, df_temp in df.groupby('user'):
    df_temp['rating_z'] = zscore(df_temp['rating'])
    dfs.append(df_temp)
d1_z = pd.concat(dfs)
r2 = d1_z.dropna()[cols].corr().head(1).round(3)
print(' & '.join([str(s) for s in r2.values[0][1:]]), '\n')
r2

0.28 & 0.28 & 0.269 



Unnamed: 0,rating_z,hSim-1,hSim-2,hSim-freq
rating_z,1.0,0.28,0.28,0.269


# Collate all ratings

In [19]:
dfs = []
for metric, filename in substitution_rating_scores.items():
    df_temp = pd.read_csv(filename)
    cols = ['user', 'item_1', 'item_2', 'rating']
    dfs.append(df_temp.set_index(cols))

In [20]:
# df_all = pd.concat(dfs, axis=1).reset_index().drop(columns=['Unnamed: 0'])
df_all = pd.concat(dfs, axis=1).reset_index().drop(columns=['Unnamed: 0'])
df_all.to_csv(substitution_score_file, index=False)