In [1]:
## Import Packages:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import theano
import theano.tensor as T
import keras
from keras import backend as K
from keras import initializers
from keras.regularizers import l1, l2, l1_l2
from keras.models import Sequential, Model
from keras.layers.core import Dense, Lambda, Activation
from keras.layers import Embedding, Input, Dense, Concatenate, Reshape, Multiply, Flatten, Dropout
from keras.optimizers import Adagrad, Adam, SGD, RMSprop
from evaluate import evaluate_model
from Dataset import Dataset
from time import time
import sys
import GMF, MLP
import argparse
from tqdm import tqdm

Using TensorFlow backend.


## Read in toronto user-item interaction csv file to generate dataset:

In [None]:
toronto_user_item_df = pd.read_csv('../yelp_dataset/toronto_user_item.csv', index_col = 0)

In [96]:
toronto_user_item_df.head(3)

Unnamed: 0,business_id,name,address,city,state,postal_code,latitude,longitude,stars_business_avg,review_count,...,categories,hours,review_id,user_id,stars,useful,funny,cool,text,date
0,g6AFW-zY0wDvBl9U82g4zg,Baretto Caffe,1262 Don Mills Road,toronto,ON,M3B 2W7,43.744703,-79.346468,5.0,317,...,"Restaurants, Italian, Cafes","{'Monday': '7:30-18:0', 'Tuesday': '7:30-18:0'...",SKBNW4QKNiclQ6mB2AQ8MQ,q3JSVBWICgXfO-zuLAp5fg,3.0,0,0,0,The customer service is on point. The food was...,2018-10-04 10:57:11
1,g6AFW-zY0wDvBl9U82g4zg,Baretto Caffe,1262 Don Mills Road,toronto,ON,M3B 2W7,43.744703,-79.346468,5.0,317,...,"Restaurants, Italian, Cafes","{'Monday': '7:30-18:0', 'Tuesday': '7:30-18:0'...",0dsaJN8eljlYRCqPWN1JCQ,0zW0RwIRwyJ6Qdirqvs5gA,5.0,0,0,0,The staff and workers are really friendly and ...,2017-04-30 13:40:40
2,g6AFW-zY0wDvBl9U82g4zg,Baretto Caffe,1262 Don Mills Road,toronto,ON,M3B 2W7,43.744703,-79.346468,5.0,317,...,"Restaurants, Italian, Cafes","{'Monday': '7:30-18:0', 'Tuesday': '7:30-18:0'...",aPUINDQsgifg_hSROs4TTA,eurxcv4blzrEs7-IgLGt5w,5.0,0,0,0,This is one great cafe. A little hard to find ...,2015-03-18 22:16:23


In [93]:
toronto_user_item_df.query("categories == 'Window Washing, Restaurants, Pizza, Home Services, Home Cleaning'")['text']

432502    We had a very bad experience with this company...
432503    I also had a very bad experience with this com...
432504    This cleaner stayed 2.5 hrs, for $180. They do...
Name: text, dtype: object

In [95]:
toronto_user_item_df.loc[432502]['text']

"We had a very bad experience with this company.  We never knew when they place had been cleaned, except for the smell and whether or not they'd left comment cards and chocolates.  By the end, they weren't even leaving those.  Al the while, happy to take our money and do a worse and worse job until the last time when they did almost nothing at all.\n\nSave your money and find a more professional, trustworthy company."

In [89]:
toronto_user_item_df.categories.unique

<bound method Series.unique of 0                       Restaurants, Italian, Cafes
1                       Restaurants, Italian, Cafes
2                       Restaurants, Italian, Cafes
3                       Restaurants, Italian, Cafes
4                       Restaurants, Italian, Cafes
                            ...                    
432509    Restaurants, Fast Food, Delis, Sandwiches
432510    Restaurants, Fast Food, Delis, Sandwiches
432511                  Food, Restaurants, Bakeries
432512                  Food, Restaurants, Bakeries
432513                  Food, Restaurants, Bakeries
Name: categories, Length: 432514, dtype: object>

In [3]:
sum(toronto_user_item_df.groupby('user_id')['name'].count() >= 10)

7905

### Filter the dataset to contain only the users who have reviewed at least 10 restaurants or more

In [99]:
grouped = toronto_user_item_df.groupby('user_id')
toronto_user_item_filtered_df = grouped.filter(lambda x: x['name'].count() >= 10)

In [100]:
dataset_to_use = toronto_user_item_filtered_df.copy()

In [101]:
print('There are %d unique users and %d unique items in the dataset after filtering such that each user has \
reviewed at least 10 restaurants.'%(dataset_to_use.user_id.nunique(), dataset_to_use.business_id.nunique()))

There are 7905 unique users and 8546 unique items in the dataset after filtering such that each user has reviewed at least 10 restaurants.


In [102]:
dataset_to_use.head(3)

Unnamed: 0,business_id,name,address,city,state,postal_code,latitude,longitude,stars_business_avg,review_count,...,categories,hours,review_id,user_id,stars,useful,funny,cool,text,date
3,g6AFW-zY0wDvBl9U82g4zg,Baretto Caffe,1262 Don Mills Road,toronto,ON,M3B 2W7,43.744703,-79.346468,5.0,317,...,"Restaurants, Italian, Cafes","{'Monday': '7:30-18:0', 'Tuesday': '7:30-18:0'...",l8FlUGAgrAAOIi0fWV3Lgg,ZWpLKIbOC5xjuPWc7ZKe9Q,5.0,0,0,0,"Wonderful spaghetti, simple yet clean environm...",2018-09-03 18:13:28
5,g6AFW-zY0wDvBl9U82g4zg,Baretto Caffe,1262 Don Mills Road,toronto,ON,M3B 2W7,43.744703,-79.346468,5.0,317,...,"Restaurants, Italian, Cafes","{'Monday': '7:30-18:0', 'Tuesday': '7:30-18:0'...",N_UO6AguthYg7lK2NoduZA,GGI39_EL1ERSqyWX1tEjMA,5.0,11,3,7,A hidden gem near my home. Found this place wh...,2017-08-16 19:45:54
6,g6AFW-zY0wDvBl9U82g4zg,Baretto Caffe,1262 Don Mills Road,toronto,ON,M3B 2W7,43.744703,-79.346468,5.0,317,...,"Restaurants, Italian, Cafes","{'Monday': '7:30-18:0', 'Tuesday': '7:30-18:0'...",I_nbSUj8mv0BB9Zgx6--UQ,x0cMhVpUcYYHoLdrWSNIMg,5.0,3,0,0,Ambiance/decor- 4\nService- 5+\nFood - 5\nStri...,2015-10-09 00:33:14


There are a substantial number of users who have reviewed at least 10 restaurants

### Add NER to reviews

In [32]:
from deeppavlov import configs, build_model
ner_model = build_model(configs.ner.ner_ontonotes_bert, download=True)
#ner_model = build_model(configs.ner.conll2003_m1, download=True)
# test:
ner_model(['Bob Ross lived in Florida'])

2020-11-23 14:46:49.359 INFO in 'deeppavlov.download'['download'] at line 138: Skipped http://files.deeppavlov.ai/deeppavlov_data/ner_ontonotes_bert_v1.tar.gz download because of matching hashes
INFO:deeppavlov.download:Skipped http://files.deeppavlov.ai/deeppavlov_data/ner_ontonotes_bert_v1.tar.gz download because of matching hashes
2020-11-23 14:46:50.309 INFO in 'deeppavlov.download'['download'] at line 138: Skipped http://files.deeppavlov.ai/deeppavlov_data/bert/cased_L-12_H-768_A-12.zip download because of matching hashes
INFO:deeppavlov.download:Skipped http://files.deeppavlov.ai/deeppavlov_data/bert/cased_L-12_H-768_A-12.zip download because of matching hashes
2020-11-23 14:46:50.376 INFO in 'deeppavlov.core.data.simple_vocab'['simple_vocab'] at line 115: [loading vocabulary from C:\Users\chris\.deeppavlov\models\ner_ontonotes_bert\tag.dict]
INFO:deeppavlov.core.data.simple_vocab:[loading vocabulary from C:\Users\chris\.deeppavlov\models\ner_ontonotes_bert\tag.dict]
2020-11-23 1

INFO:tensorflow:Restoring parameters from C:\Users\chris\.deeppavlov\models\ner_ontonotes_bert\model


INFO:tensorflow:Restoring parameters from C:\Users\chris\.deeppavlov\models\ner_ontonotes_bert\model


[[['Bob', 'Ross', 'lived', 'in', 'Florida']],
 [['B-PERSON', 'I-PERSON', 'O', 'O', 'B-GPE']]]

In [85]:
list(toronto_user_item_df.text[toronto_user_item_df.stars>=3.0])[100:105]

['Visited Baretto for lunch on a weekday. The cafe is located on the main floor of a medical office building. The owners and workers were very friendly and accommodating. \n\nMy girlfriend and I ordered a chicken pizza and lasagna. The food was delicious and the portions were fairly large for the price. We were both very full and packed a couple of slices to go. We finished our meals with a cappuccino and latte. Overall a great first experience.',
 'Tried out this place on my way home from work, every lovely place! The food is very delicious and supper cheap, you get a free drink when you check in here too! The service was amazing, the owner was so nice, we didnt know that the stored was closed and yet he still advised us to stay and enjoy the coffee for another extra hour after operation time. Such a great service! Would highly recommend this restaurant.',
 'Fantastic spot. They have the most amazing ginseng latte I have ever tasted! Amazing fresh squeezed orange juice, home made pizz

In [86]:
result = ner_model(list(toronto_user_item_df.text[0:100]))

pdict = {}
cdict = {}
for i in range(len(result[0])):
    for j in range(len(result[0][i])):
        if 'GPE' in result[1][i][j]:
            if result[0][i][j] in set(pdict.keys()):
                cdict[result[0][i][j]] += 1
            else:
                pdict[result[0][i][j]] = result[1][i][j]
                cdict[result[0][i][j]] = 1

In [87]:
import pprint
for i in pdict.keys():
    pprint.pprint([i, pdict[i], cdict[i]])

['Markham', 'B-GPE', 1]
['Toronto', 'B-GPE', 8]
['Italy', 'B-GPE', 18]
['Gta', 'B-GPE', 1]
['Rome', 'B-GPE', 8]
['North', 'B-GPE', 3]
['York', 'I-GPE', 3]
['U', 'B-GPE', 1]
['S', 'B-GPE', 1]
['Niagara', 'B-GPE', 1]
['Falls', 'I-GPE', 1]
['Don', 'B-GPE', 2]
['Mills', 'I-GPE', 2]
['McEwans', 'B-GPE', 1]
['Roma', 'B-GPE', 1]
['Canada', 'B-GPE', 1]


In [74]:
ner_model(['cafe Caffe caffe coffee Cappuccino cappuccino pizza dumplings Chinese food'])

[[['cafe',
   'Caffe',
   'caffe',
   'coffee',
   'Cappuccino',
   'cappuccino',
   'pizza',
   'dumplings',
   'Chinese',
   'food']],
 [['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-NORP', 'O']]]

#### Not good results

### Selenium, use website 'spoonacular' to detect food

In [327]:
import selenium
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from selenium.common.exceptions import TimeoutException
import random


option=webdriver.ChromeOptions()
option.add_argument('headless')
driver = webdriver.Chrome(chrome_options=option)
#driver = webdriver.Chrome()
driver.get("https://spoonacular.com/food-api/docs/detect-foods-in-text")

  if sys.path[0] == '':


In [328]:
def FoodDetect(driver,sample_input):
    if len(sample_input)<=900:
        Dlist,Glist,review = FoodDetect_single(driver,sample_input)
    else:
        n = random.randint(850,900)
        while sample_input[n]!=' ':
            n-=1
        Dlist,Glist,review = FoodDetect_single(driver,sample_input[:n])
        
        n=n-random.randint(30,50)
        while sample_input[n]!=' ':
            n-=1
        Dlist2,Glist2,review2 = FoodDetect_single(driver,sample_input[n+1:])
        Dlist = list(set(Dlist+Dlist2))
        Glist = list(set(Glist+Glist2))
    
    return Dlist,Glist,review
        

def FoodDetect_single(driver,sample_input):
    TextInput = driver.find_element_by_id("detectText")
    TextInput.clear()
    TextInput.send_keys(sample_input)
    #driver.implicitly_wait(0.1)

    DetectButton = driver.find_element_by_xpath("/html/body/div[@class='button blue']")
    DetectButton.click()
    
    WebDriverWait(driver, timeout=3).until(EC.text_to_be_present_in_element((By.ID,"detectionResult"),sample_input[:2]))
    WebDriverWait(driver, timeout=3).until(EC.text_to_be_present_in_element((By.ID,"detectionResult"),sample_input[-3:]))
    #driver.implicitly_wait(0.1)
    
    TextOutput = driver.find_element_by_id("detectionResult")
    review = TextOutput.text
    
    Dlist = []
    Dishes = driver.find_elements_by_xpath('/html/body/div[@id="detectionResult"]/span[@class="typedish"]')
    for i in Dishes:
        Dlist.append(i.text)

    Glist = []
    Gri = driver.find_elements_by_xpath('/html/body/div[@id="detectionResult"]/span[@class="typeingredient"]')
    for i in Gri:
        Glist.append(i.text)
        
    return Dlist,Glist,review

In [329]:
l=len(list(dataset_to_use.text[toronto_user_item_df.stars>=3.0]))

In [330]:
Dish = {}
Ingre = {}

import random
import time
nlist = []
n_count = 20000
while len(nlist)<n_count:
    r = random.randint(0,l)
    if r not in nlist: nlist.append(r)

ticks = time.time()
text_cannot_detect = []
review_prev = []
count = 0
for n in nlist:
    text = list(dataset_to_use.text[toronto_user_item_df.stars>=3.0])[n]
    text = str.lower(text)
    
    rep = 0
    review = []
    while review == [] and rep<5:
        try:
            Dlist,Glist,review = FoodDetect(driver,text)
        except Exception as ex:
            rep+=1
            time.sleep(1)
            print("Exception has been thrown. " + str(ex) + '   Text length:', len(text), ' #:',n)
            driver.close()
            
            option=webdriver.ChromeOptions()
            option.add_argument('headless')
            driver = webdriver.Chrome(chrome_options=option)
            driver.get("https://spoonacular.com/food-api/docs/detect-foods-in-text")
            time.sleep(1)
            
    if rep==5:
        text_cannot_detect.append(text)
    
    if review == review_prev:
        print('Repeating with prev. ',count)
    
    Dlist = list(set(Dlist))
    Glist = list(set(Glist))
    
    for d in Dlist:
        if d in set(Dish.keys()):
            Dish[d]+=1
        else:
            Dish[d]=1
    
    for g in Glist:
        if d in set(Ingre.keys()):
            Ingre[d]+=1
        else:
            Ingre[d]=1
    
    count+=1
    if count % 100 == 0:
        tocks = time.time()
        print(count, (tocks-ticks)/60)
        
    review_prev = review

Exception has been thrown. Message: 
   Text length: 253  #: 121530




Exception has been thrown. Message: 
   Text length: 253  #: 121530
Exception has been thrown. Message: 
   Text length: 253  #: 121530
Exception has been thrown. Message: 
   Text length: 253  #: 121530
Exception has been thrown. Message: 
   Text length: 253  #: 121530
Repeating with prev.  8
Repeating with prev.  23
Repeating with prev.  28
Exception has been thrown. Message: 
   Text length: 1792  #: 119024
Repeating with prev.  76
Exception has been thrown. Message: 
   Text length: 1288  #: 31017
Repeating with prev.  86
100 3.1525816837946574
Repeating with prev.  110
Exception has been thrown. Message: 
   Text length: 627  #: 145100
Exception has been thrown. Message: 
   Text length: 650  #: 77880
Exception has been thrown. Message: 
   Text length: 650  #: 77880
Exception has been thrown. Message: 
   Text length: 650  #: 77880
Exception has been thrown. Message: 
   Text length: 650  #: 77880
Exception has been thrown. Message: 
   Text length: 650  #: 77880
Repeating with 

Repeating with prev.  1655
Repeating with prev.  1670
Repeating with prev.  1687
Exception has been thrown. Message: 
   Text length: 236  #: 107031
Repeating with prev.  1694
1700 43.00846302111943
Repeating with prev.  1725
Exception has been thrown. Message: 
   Text length: 2949  #: 70083
Repeating with prev.  1733
Repeating with prev.  1741
Repeating with prev.  1749
Repeating with prev.  1762
Exception has been thrown. Message: 
   Text length: 644  #: 149491
Repeating with prev.  1775
1800 44.8845480799675
Exception has been thrown. Message: 
   Text length: 381  #: 83801
Repeating with prev.  1818
Repeating with prev.  1823
Repeating with prev.  1835
Exception has been thrown. Message: 
   Text length: 810  #: 183717
Repeating with prev.  1867
Exception has been thrown. Message: 
   Text length: 666  #: 170770
1900 47.14271181821823
Repeating with prev.  1901
Exception has been thrown. Message: 
   Text length: 125  #: 109384
Repeating with prev.  1926
Repeating with prev.  195

Exception has been thrown. Message: 
   Text length: 1158  #: 93203
Exception has been thrown. Message: 
   Text length: 1158  #: 93203
Repeating with prev.  3354
Exception has been thrown. Message: 
   Text length: 1330  #: 169460
Exception has been thrown. Message: 
   Text length: 1330  #: 169460
Repeating with prev.  3390
Exception has been thrown. Message: 
   Text length: 876  #: 40909
3400 87.10853218237558
Exception has been thrown. Message: stale element reference: element is not attached to the page document
  (Session info: headless chrome=87.0.4280.66)
   Text length: 350  #: 105146
Exception has been thrown. Message: 
   Text length: 212  #: 97152
Repeating with prev.  3453
Repeating with prev.  3462
Repeating with prev.  3471
Exception has been thrown. Message: 
   Text length: 1219  #: 60404
Repeating with prev.  3489
3500 89.23109176158906
Repeating with prev.  3500
Repeating with prev.  3502
Exception has been thrown. Message: 
   Text length: 921  #: 140873
Repeating 

Exception has been thrown. Message: 
   Text length: 699  #: 136775
4800 125.55262680451075
Exception has been thrown. Message: 
   Text length: 678  #: 120820
Repeating with prev.  4809
Exception has been thrown. Message: 
   Text length: 85  #: 156439
Repeating with prev.  4847
Exception has been thrown. Message: 
   Text length: 1092  #: 95561
Repeating with prev.  4861
Repeating with prev.  4864
Repeating with prev.  4879
Exception has been thrown. Message: 
   Text length: 1658  #: 81550
4900 128.02094539403916
Repeating with prev.  4927
Exception has been thrown. Message: 
   Text length: 1855  #: 63422
Exception has been thrown. Message: 
   Text length: 800  #: 57870
Repeating with prev.  4945
Repeating with prev.  4950
Repeating with prev.  4953
Repeating with prev.  4963
Exception has been thrown. Message: 
   Text length: 1278  #: 142644
Repeating with prev.  4983
Repeating with prev.  4988
Exception has been thrown. Message: 
   Text length: 2173  #: 126979
Exception has be

Exception has been thrown. Message: 
   Text length: 511  #: 134271
Exception has been thrown. Message: 
   Text length: 511  #: 134271
Exception has been thrown. Message: 
   Text length: 511  #: 134271
Repeating with prev.  5792
Exception has been thrown. Message: 
   Text length: 1630  #: 21513
5800 163.1846166372299
Repeating with prev.  5822
Repeating with prev.  5838
Exception has been thrown. Message: 
   Text length: 1715  #: 153414
Exception has been thrown. Message: 
   Text length: 701  #: 163980
Exception has been thrown. Message: 
   Text length: 701  #: 163980
Exception has been thrown. Message: 
   Text length: 701  #: 163980
Exception has been thrown. Message: 
   Text length: 701  #: 163980
Exception has been thrown. Message: 
   Text length: 701  #: 163980
Exception has been thrown. string index out of range   Text length: 976  #: 167603
Exception has been thrown. string index out of range   Text length: 976  #: 167603
Exception has been thrown. string index out of ra

Exception has been thrown. Message: 
   Text length: 428  #: 34436
Repeating with prev.  7096
7100 200.05615133841832
Repeating with prev.  7114
Exception has been thrown. Message: 
   Text length: 1122  #: 24812
Exception has been thrown. Message: 
   Text length: 1418  #: 135001
Exception has been thrown. Message: 
   Text length: 1418  #: 135001
Exception has been thrown. Message: 
   Text length: 1418  #: 135001
Exception has been thrown. Message: 
   Text length: 1418  #: 135001
Exception has been thrown. Message: 
   Text length: 1418  #: 135001
Repeating with prev.  7123
Exception has been thrown. Message: 
   Text length: 1686  #: 169369
Repeating with prev.  7129
Repeating with prev.  7140
Repeating with prev.  7165
Exception has been thrown. Message: 
   Text length: 924  #: 123505
Repeating with prev.  7174
Repeating with prev.  7192
7200 203.43365632692974
Exception has been thrown. Message: 
   Text length: 416  #: 186222
Repeating with prev.  7224
Exception has been throw

Repeating with prev.  8369
Exception has been thrown. string index out of range   Text length: 1131  #: 114633
Exception has been thrown. string index out of range   Text length: 1131  #: 114633
Exception has been thrown. string index out of range   Text length: 1131  #: 114633
Exception has been thrown. string index out of range   Text length: 1131  #: 114633
Exception has been thrown. string index out of range   Text length: 1131  #: 114633
Repeating with prev.  8390
Repeating with prev.  8396
8400 238.10473519563675
Repeating with prev.  8415
Exception has been thrown. Message: 
   Text length: 1084  #: 158485
Repeating with prev.  8428
Repeating with prev.  8437
Exception has been thrown. Message: 
   Text length: 1365  #: 119551
Exception has been thrown. Message: 
   Text length: 2747  #: 160104
Exception has been thrown. Message: 
   Text length: 2747  #: 160104
Exception has been thrown. Message: 
   Text length: 2747  #: 160104
Exception has been thrown. Message: 
   Text leng

Repeating with prev.  9637
Repeating with prev.  9655
Repeating with prev.  9670
Exception has been thrown. Message: 
   Text length: 697  #: 185407
Exception has been thrown. Message: 
   Text length: 2105  #: 197669
Exception has been thrown. Message: 
   Text length: 2105  #: 197669
Exception has been thrown. Message: 
   Text length: 2105  #: 197669
Exception has been thrown. Message: 
   Text length: 2105  #: 197669
Exception has been thrown. Message: 
   Text length: 2105  #: 197669
Repeating with prev.  9698
9700 277.28828179836273
Exception has been thrown. Message: 
   Text length: 2965  #: 24891
Exception has been thrown. Message: 
   Text length: 1128  #: 125053
Exception has been thrown. Message: 
   Text length: 2115  #: 32183
Exception has been thrown. Message: 
   Text length: 2115  #: 32183
Exception has been thrown. Message: 
   Text length: 2115  #: 32183
Exception has been thrown. Message: 
   Text length: 2115  #: 32183
Exception has been thrown. Message: 
   Text l

Repeating with prev.  10869
Exception has been thrown. Message: 
   Text length: 1817  #: 106282
Repeating with prev.  10898
10900 312.69900952974956
Exception has been thrown. Message: 
   Text length: 732  #: 156033
Repeating with prev.  10918
Repeating with prev.  10920
Exception has been thrown. Message: 
   Text length: 2106  #: 18934
Exception has been thrown. Message: 
   Text length: 2106  #: 18934
Exception has been thrown. Message: 
   Text length: 2106  #: 18934
Exception has been thrown. Message: 
   Text length: 2106  #: 18934
Exception has been thrown. Message: 
   Text length: 2106  #: 18934
Repeating with prev.  10934
Repeating with prev.  10942
Repeating with prev.  10957
Repeating with prev.  10962
Repeating with prev.  10965
Exception has been thrown. Message: 
   Text length: 274  #: 71630
Repeating with prev.  10975
Repeating with prev.  10980
Repeating with prev.  10985
Repeating with prev.  10993
Repeating with prev.  10996
11000 315.5714983145396
Repeating with 

Exception has been thrown. Message: 
   Text length: 138  #: 184687
Exception has been thrown. Message: 
   Text length: 138  #: 184687
Exception has been thrown. Message: 
   Text length: 138  #: 184687
Repeating with prev.  12392
Exception has been thrown. Message: 
   Text length: 536  #: 113394
12400 351.90700441996256
Repeating with prev.  12404
Repeating with prev.  12411
Exception has been thrown. Message: 
   Text length: 2449  #: 173604
Exception has been thrown. Message: 
   Text length: 2449  #: 173604
Exception has been thrown. Message: 
   Text length: 2449  #: 173604
Exception has been thrown. Message: 
   Text length: 2449  #: 173604
Exception has been thrown. Message: 
   Text length: 2449  #: 173604
Repeating with prev.  12429
Repeating with prev.  12445
Repeating with prev.  12450
Repeating with prev.  12454
Exception has been thrown. Message: stale element reference: element is not attached to the page document
  (Session info: headless chrome=87.0.4280.66)
   Text l

Exception has been thrown. Message: 
   Text length: 810  #: 113104
Repeating with prev.  13429
Repeating with prev.  13443
Exception has been thrown. Message: 
   Text length: 95  #: 173126
Repeating with prev.  13464
Exception has been thrown. Message: 
   Text length: 1891  #: 150705
Repeating with prev.  13483
Repeating with prev.  13489
13500 388.00502845446266
Exception has been thrown. Message: 
   Text length: 562  #: 159512
Repeating with prev.  13516
Repeating with prev.  13520
Exception has been thrown. Message: 
   Text length: 393  #: 173340
Repeating with prev.  13564
Repeating with prev.  13573
Exception has been thrown. Message: 
   Text length: 318  #: 80401
Repeating with prev.  13590
13600 390.0278464317322
Repeating with prev.  13604
Repeating with prev.  13607
Repeating with prev.  13614
Exception has been thrown. Message: 
   Text length: 1903  #: 70368
Exception has been thrown. Message: 
   Text length: 684  #: 184301
Exception has been thrown. Message: 
   Text

Repeating with prev.  14824
Repeating with prev.  14843
Exception has been thrown. Message: 
   Text length: 1755  #: 168557
Exception has been thrown. Message: 
   Text length: 301  #: 9185
14900 425.78988438447317
Exception has been thrown. Message: 
   Text length: 2483  #: 150440
Repeating with prev.  14924
Exception has been thrown. Message: 
   Text length: 1674  #: 27527
Exception has been thrown. Message: 
   Text length: 1405  #: 157201
Repeating with prev.  14955
Repeating with prev.  14963
Exception has been thrown. Message: 
   Text length: 1150  #: 159786
Exception has been thrown. Message: 
   Text length: 1150  #: 159786
Repeating with prev.  14978
15000 428.34140375852587
Repeating with prev.  15003
Repeating with prev.  15005
Exception has been thrown. Message: 
   Text length: 1188  #: 115055
Exception has been thrown. Message: 
   Text length: 2602  #: 89555
Repeating with prev.  15013
Repeating with prev.  15014
Repeating with prev.  15017
Repeating with prev.  1502

Repeating with prev.  15885
Exception has been thrown. Message: 
   Text length: 716  #: 184064
15900 460.6294501781464
Repeating with prev.  15921
Exception has been thrown. Message: 
   Text length: 159  #: 190257
Repeating with prev.  15926
Repeating with prev.  15935
Exception has been thrown. Message: 
   Text length: 1743  #: 161128
Repeating with prev.  15945
Repeating with prev.  15952
Repeating with prev.  15958
Repeating with prev.  15965
Repeating with prev.  15978
Exception has been thrown. Message: 
   Text length: 1190  #: 552
Repeating with prev.  15992
Repeating with prev.  15999
16000 462.9941291888555
Exception has been thrown. Message: 
   Text length: 1174  #: 45116
Exception has been thrown. Message: 
   Text length: 1023  #: 182278
Repeating with prev.  16033
Repeating with prev.  16049
Exception has been thrown. Message: 
   Text length: 1318  #: 95849
Repeating with prev.  16060
Repeating with prev.  16082
Exception has been thrown. Message: 
   Text length: 147

Exception has been thrown. Message: 
   Text length: 958  #: 138816
Exception has been thrown. Message: 
   Text length: 958  #: 138816
Repeating with prev.  17537
Repeating with prev.  17541
Repeating with prev.  17565
Exception has been thrown. Message: 
   Text length: 288  #: 196991
Repeating with prev.  17573
Repeating with prev.  17593
17600 502.26424686908723
Exception has been thrown. Message: 
   Text length: 169  #: 54515
Repeating with prev.  17632
Repeating with prev.  17647
Exception has been thrown. Message: 
   Text length: 1074  #: 175208
Repeating with prev.  17653
Repeating with prev.  17667
Exception has been thrown. Message: 
   Text length: 2176  #: 102568
17700 504.6358574350675
Repeating with prev.  17708
Repeating with prev.  17711
Repeating with prev.  17720
Exception has been thrown. Message: 
   Text length: 574  #: 96111
Repeating with prev.  17726
Repeating with prev.  17733
Repeating with prev.  17748
Exception has been thrown. Message: 
   Text length: 99

Exception has been thrown. Message: 
   Text length: 91  #: 171783
Exception has been thrown. Message: 
   Text length: 91  #: 171783
Exception has been thrown. Message: 
   Text length: 91  #: 171783
Repeating with prev.  19054
Repeating with prev.  19075
Exception has been thrown. Message: 
   Text length: 289  #: 52289
Exception has been thrown. Message: 
   Text length: 289  #: 52289
Exception has been thrown. Message: 
   Text length: 289  #: 52289
Exception has been thrown. Message: 
   Text length: 289  #: 52289
Exception has been thrown. Message: 
   Text length: 289  #: 52289
Exception has been thrown. Message: stale element reference: element is not attached to the page document
  (Session info: headless chrome=87.0.4280.66)
   Text length: 1052  #: 107327
19100 543.401229985555
Repeating with prev.  19106
Repeating with prev.  19118
Repeating with prev.  19129
Exception has been thrown. Message: 
   Text length: 285  #: 99665
Repeating with prev.  19141
Repeating with prev. 

In [331]:
nlist = []
n_count = 20000
while len(nlist)<n_count:
    r = random.randint(0,l)
    if r not in nlist: nlist.append(r)
print(len(nlist),len(set(nlist)))

20000 20000


In [332]:
import csv
Dish = dict(sorted(Dish.items(), key=lambda item: item[1], reverse=1))
Ingre = dict(sorted(Ingre.items(), key=lambda item: item[1], reverse=1))
with open('Dish'+str(n_count)+'.csv', 'w') as f:
    for key in Dish.keys():
        f.write("%s,%s\n"%(key,Dish[key]))
with open('Ingre'+str(n_count)+'.csv', 'w') as f:
    for key in Ingre.keys():
        f.write("%s,%s\n"%(key,Ingre[key]))

In [296]:
Dish = dict(sorted(Dish.items(), key=lambda item: item[1], reverse=1))
Ingre = dict(sorted(Ingre.items(), key=lambda item: item[1], reverse=1))

In [333]:
print(len(Dish))
Dish

1621


{'lunch': 1744,
 'dinner': 1550,
 'salad': 1303,
 'fries': 1000,
 'soup': 950,
 'sushi': 837,
 'pizza': 800,
 'coffee': 767,
 'brunch': 726,
 'fish': 694,
 'beer': 674,
 'noodles': 630,
 'burger': 609,
 'fanta': 597,
 'bowl': 594,
 'absolut': 544,
 'sandwich': 512,
 'wine': 510,
 'salmon': 501,
 'shrimp': 494,
 'water': 492,
 'tea': 460,
 'steak': 446,
 'curry': 443,
 'tacos': 411,
 'pasta': 407,
 'ramen': 395,
 'burgers': 366,
 'rolls': 364,
 'sandwiches': 344,
 'ice cream': 343,
 'roll': 334,
 'fried chicken': 328,
 'noodle': 318,
 'pho': 248,
 'cake': 245,
 'pad thai': 244,
 'chips': 239,
 'tuna': 236,
 'lobster': 228,
 'wings': 213,
 'burrito': 204,
 'bun': 203,
 'treat': 194,
 'kimchi': 193,
 'oysters': 174,
 'spring rolls': 174,
 'cocktail': 173,
 'crisp': 173,
 'tempura': 172,
 'milk': 169,
 'pop': 164,
 'ribs': 163,
 'fried rice': 156,
 'gravy': 152,
 'crab': 151,
 'salsa': 150,
 'pancakes': 150,
 'dip': 149,
 'coleslaw': 145,
 'cheesecake': 144,
 'tapas': 144,
 'taco': 142,
 '

In [334]:
print(len(Ingre))
Ingre

876


{'bowl': 1731,
 'beer': 1462,
 'fries': 1183,
 'salad': 1133,
 'lunch': 882,
 'fish': 832,
 'noodles': 735,
 'dinner': 669,
 'curry': 502,
 'tea': 470,
 'pasta': 439,
 'sushi': 438,
 'absolut': 431,
 'cocktail': 388,
 'tacos': 386,
 'buns': 367,
 'fried chicken': 362,
 'pizza': 347,
 'butter chicken': 346,
 'salmon': 329,
 'pho': 322,
 'shrimp': 320,
 'sandwich': 319,
 'burger': 317,
 'matcha': 300,
 'cheesecake': 272,
 'pad thai': 259,
 'burrito': 258,
 'pita': 246,
 'omelette': 246,
 'pudding': 236,
 'wine': 233,
 'shot': 230,
 'tomato sauce': 229,
 'slaw': 226,
 'cake': 225,
 'chili': 223,
 'nutella': 218,
 'caramelized onions': 216,
 'soup': 215,
 'brunch': 201,
 'ramen': 198,
 'gravy': 197,
 'crepes': 189,
 'crisp': 187,
 'peanut butter': 179,
 'chips': 178,
 'green tea': 174,
 'risotto': 169,
 'cod': 166,
 'coffee': 155,
 'waffles': 152,
 'striploin': 147,
 'chicken wings': 145,
 'mash': 137,
 'noodle': 133,
 'taco': 132,
 'wings': 130,
 'lobster': 123,
 'steak': 122,
 'fanta': 1

### Choose top n each to be final features.
#### Smallest n for # final features>=60

In [335]:
Dish, Ingre = {}, {}

with open("Dish20000.csv", "r") as csvFile:
    reader = csv.reader(csvFile)
    for item in reader:
        Dish[item[0]] = int(item[1])
        
with open("Ingre20000.csv", "r") as csvFile:
    reader = csv.reader(csvFile)
    for item in reader:
        Ingre[item[0]] = int(item[1])

Dish = dict(sorted(Dish.items(), key=lambda item: item[1], reverse=1))
Ingre = dict(sorted(Ingre.items(), key=lambda item: item[1], reverse=1))

In [336]:
import nltk
from nltk.stem import WordNetLemmatizer

def constructFeature(n_lim,Dish,Ingre):
    wnl = WordNetLemmatizer()

    n = int(n_lim/2)
    features = set()

    while len(features)<n_lim:
        n+=1
        features = set(list(Dish.keys())[:n]+list(Ingre.keys())[:n])
        features = set([str.lower(wrd) for wrd in features])
        features = set([wnl.lemmatize(wrd) for wrd in features]) # singluarization
        
    # Word "breakfast" not detected. Add to list.
    features = list(features) + ['breakfast']
    
    return features

In [337]:
n_lim = 60
features = constructFeature(n_lim,Dish,Ingre)
pprint.pprint(features)

['treat',
 'salmon',
 'ice cream',
 'dinner',
 'wing',
 'wine',
 'tea',
 'pho',
 'shot',
 'pudding',
 'water',
 'beer',
 'peanut butter',
 'omelette',
 'fry',
 'soup',
 'kimchi',
 'fish',
 'salad',
 'gravy',
 'taco',
 'curry',
 'burrito',
 'matcha',
 'nutella',
 'spring rolls',
 'pizza',
 'sandwich',
 'lunch',
 'burger',
 'crisp',
 'absolut',
 'coffee',
 'fried chicken',
 'tuna',
 'roll',
 'chip',
 'lobster',
 'pasta',
 'fanta',
 'cake',
 'caramelized onions',
 'sushi',
 'pad thai',
 'tomato sauce',
 'bowl',
 'cocktail',
 'slaw',
 'oyster',
 'bun',
 'crepe',
 'brunch',
 'steak',
 'noodle',
 'shrimp',
 'chili',
 'ramen',
 'cheesecake',
 'pita',
 'butter chicken',
 'breakfast']


### Create simpler IDs for users and items

In [347]:
unique_business_id = dataset_to_use.business_id.unique()
mapping_business_id = {}
ctr = 0
for business_id in unique_business_id:
    mapping_business_id[business_id] = ctr
    ctr += 1
    
dataset_to_use['business_id_refined'] = dataset_to_use.business_id.map(mapping_business_id)

In [348]:
unique_user_id = dataset_to_use.user_id.unique()
mapping_user_id = {}
ctr = 0
for user_id in unique_user_id:
    mapping_user_id[user_id] = ctr
    ctr += 1
    
dataset_to_use['user_id_refined'] = dataset_to_use.user_id.map(mapping_user_id)

In [349]:
dataset_to_use[['user_id_refined', 'business_id_refined']].head(5)

Unnamed: 0,user_id_refined,business_id_refined
3,0,0
5,1,0
6,2,0
7,3,0
10,4,0


In [350]:
dataset_to_use.sort_values(by = ['user_id_refined', 'date'], inplace = True)

In [351]:
dataset_to_use.head(3)

Unnamed: 0,business_id,name,address,city,state,postal_code,latitude,longitude,stars_business_avg,review_count,...,review_id,user_id,stars,useful,funny,cool,text,date,business_id_refined,user_id_refined
394705,e49eXgKVuR-lsL0-D4vzDw,Momiji,2111 Sheppard Avenue E,toronto,ON,M2J 1W6,43.775377,-79.333972,3.0,22,...,9kb3ywKCxhCQY0ElsLccNA,ZWpLKIbOC5xjuPWc7ZKe9Q,3.0,3,0,2,I went to Momiji at night wanting to find out ...,2010-11-01 01:50:56,6217,0
270638,ik9VvawL-BeAqlxTI1leew,Gonoe Sushi,1310 Don Mills Road,toronto,ON,M3B 2W6,43.74592,-79.346301,3.5,119,...,ehAgpX1OzHGnkf1fut6Few,ZWpLKIbOC5xjuPWc7ZKe9Q,3.0,2,0,0,I went to this place solely on the recommendat...,2014-12-23 02:53:08,3521,0
119044,Nz44ccUso3nq5S2OlQHNlA,Mexico Lindo,"2600 Birchmount Road, Suite 2586",toronto,ON,M1T 2M5,43.789719,-79.302981,4.0,163,...,nieXZ7BPbe_4X4lJexK--w,ZWpLKIbOC5xjuPWc7ZKe9Q,5.0,0,0,0,"Homemade family style catering, I was welcome ...",2014-12-31 02:27:56,1264,0


### Construct 

In [352]:
num_users = dataset_to_use.user_id.nunique()
num_items = dataset_to_use.business_id.nunique()
print(num_users,num_items)

7905 8546


In [355]:
list(dataset_to_use[dataset_to_use.user_id_refined==0]['text'])

["I went to Momiji at night wanting to find out what this new restaurant is all about. I was greeted instantly by a friendly waiter who took me to my seat. \n\nThey have a variety of Japanese food to choose from and the food quality is fairly good. \n\n At night, their tables are dimly lit with sports games playing on TV providing a bar environment. It's a great place to go if you want to go watch sports games with a group of friends and talk or a place to meet your date.\n\nThe only downside to this restaurant at night is the food tends to be priced slightly higher, but the environment and service provided justifies the price.",
 "I went to this place solely on the recommendations read here on Yelp. \n\nMy first impression of the place is that it's tidy and nicely decorated. That earned a star. \n\nThe waitress were attentive was quick to provide service. 2nd star.\n\n\nI ordered a sushi and sashimi dinner combo and it came in a tiny boat nicely decorated. 3rd star. \n\nWhere it lost 

In [372]:
user_features = np.zeros([num_users,len(features)])

for i in range(num_users):
    user_reviews = list(dataset_to_use[dataset_to_use.user_id_refined==i]['text'])
    for j in range(len(features)):
        for k in range(len(user_reviews)):
            review_temp = user_reviews[k].lower()
            if features[j] in review_temp:
                user_features[i,j]=1

In [374]:
user_features

array([[0., 1., 1., ..., 0., 0., 0.],
       [0., 1., 1., ..., 1., 0., 1.],
       [0., 1., 0., ..., 0., 0., 1.],
       ...,
       [0., 0., 0., ..., 0., 1., 1.],
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]])

In [375]:
user_features

array([[0., 1., 1., ..., 0., 0., 0.],
       [0., 1., 1., ..., 1., 0., 1.],
       [0., 1., 0., ..., 0., 0., 1.],
       ...,
       [0., 0., 0., ..., 0., 1., 1.],
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]])

### data preparation

In [13]:
dataset_to_use.head(3)

Unnamed: 0,business_id,name,address,city,state,postal_code,latitude,longitude,stars_business_avg,review_count,...,review_id,user_id,stars,useful,funny,cool,text,date,business_id_refined,user_id_refined
394705,e49eXgKVuR-lsL0-D4vzDw,Momiji,2111 Sheppard Avenue E,toronto,ON,M2J 1W6,43.775377,-79.333972,3.0,22,...,9kb3ywKCxhCQY0ElsLccNA,ZWpLKIbOC5xjuPWc7ZKe9Q,3.0,3,0,2,I went to Momiji at night wanting to find out ...,2010-11-01 01:50:56,6217,0
270638,ik9VvawL-BeAqlxTI1leew,Gonoe Sushi,1310 Don Mills Road,toronto,ON,M3B 2W6,43.74592,-79.346301,3.5,119,...,ehAgpX1OzHGnkf1fut6Few,ZWpLKIbOC5xjuPWc7ZKe9Q,3.0,2,0,0,I went to this place solely on the recommendat...,2014-12-23 02:53:08,3521,0
119044,Nz44ccUso3nq5S2OlQHNlA,Mexico Lindo,"2600 Birchmount Road, Suite 2586",toronto,ON,M1T 2M5,43.789719,-79.302981,4.0,163,...,nieXZ7BPbe_4X4lJexK--w,ZWpLKIbOC5xjuPWc7ZKe9Q,5.0,0,0,0,"Homemade family style catering, I was welcome ...",2014-12-31 02:27:56,1264,0


In [14]:
Review_set = dataset_to_use[['user_id_refined', 'business_id_refined']].values.tolist()
len(Review_set)

237185

In [15]:
Label_set = dataset_to_use['stars'].values.tolist()
len(Label_set)

237185

### Score normalized to [0,1]

In [16]:
from sklearn.model_selection import train_test_split
Review_train, Review_test, Label_train, Label_test = train_test_split(Review_set, Label_set, test_size=0.2, random_state=42)
Label_train = np.array(Label_train)/5
Label_test = np.array(Label_test)/5

In [17]:
num_users = dataset_to_use.user_id.nunique()
num_items = dataset_to_use.business_id.nunique()

## Model modification

In [18]:
'''
Modified for rating prediction
'''
#################### Arguments ####################
def parse_args():
    parser = argparse.ArgumentParser(description="Run NeuMF.")
    parser.add_argument('--path', nargs='?', default='Data/',
                        help='Input data path.')
    parser.add_argument('--dataset', nargs='?', default='ml-1m',
                        help='Choose a dataset.')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size.')
    parser.add_argument('--num_factors', type=int, default=8,
                        help='Embedding size of MF model.')
    parser.add_argument('--layers', nargs='?', default='[64,32,16,8]',
                        help="MLP layers. Note that the first layer is the concatenation of user and item embeddings. So layers[0]/2 is the embedding size.")
    parser.add_argument('--reg_mf', type=float, default=0,
                        help='Regularization for MF embeddings.')                    
    parser.add_argument('--reg_layers', nargs='?', default='[0,0,0,0]',
                        help="Regularization for each MLP layer. reg_layers[0] is the regularization for embeddings.")
    parser.add_argument('--num_neg', type=int, default=4,
                        help='Number of negative instances to pair with a positive instance.')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate.')
    parser.add_argument('--learner', nargs='?', default='adam',
                        help='Specify an optimizer: adagrad, adam, rmsprop, sgd')
    parser.add_argument('--verbose', type=int, default=1,
                        help='Show performance per X iterations')
    parser.add_argument('--out', type=int, default=1,
                        help='Whether to save the trained model.')
    parser.add_argument('--mf_pretrain', nargs='?', default='',
                        help='Specify the pretrain model file for MF part. If empty, no pretrain will be used')
    parser.add_argument('--mlp_pretrain', nargs='?', default='',
                        help='Specify the pretrain model file for MLP part. If empty, no pretrain will be used')
    return parser.parse_args()

def get_model(num_users, num_items, mf_dim=10, layers=[10], reg_layers=[0], reg_mf=0):
    assert len(layers) == len(reg_layers)
    num_layer = len(layers) #Number of layers in the MLP
    
    # Input variables
    user_input = Input(shape=(1,), dtype='int32', name = 'user_input')
    item_input = Input(shape=(1,), dtype='int32', name = 'item_input')
    
    print('Input layer: ', user_input.shape, user_input.dtype)
    
    # Embedding layer
    MF_Embedding_User = Embedding(input_dim = num_users, output_dim = mf_dim, name = 'mf_embedding_user',
                                  embeddings_initializer = 'random_normal', embeddings_regularizer = l2(reg_mf), input_length=1)
    MF_Embedding_Item = Embedding(input_dim = num_items, output_dim = mf_dim, name = 'mf_embedding_item',
                                  embeddings_initializer = 'random_normal', embeddings_regularizer = l2(reg_mf), input_length=1)  

    MLP_Embedding_User = Embedding(input_dim = num_users, output_dim = int(layers[0]/2), name = "mlp_embedding_user",
                                   embeddings_initializer = 'random_normal', embeddings_regularizer = l2(reg_layers[0]), input_length=1)
    MLP_Embedding_Item = Embedding(input_dim = num_items, output_dim = int(layers[0]/2), name = 'mlp_embedding_item',
                                   embeddings_initializer = 'random_normal', embeddings_regularizer = l2(reg_layers[0]), input_length=1)

    output = MLP_Embedding_User(user_input)
    #print(MLP_Embedding_User.weights)
    
    # MF part
    mf_user_latent = Flatten()(MF_Embedding_User(user_input))
    mf_item_latent = Flatten()(MF_Embedding_Item(item_input))
    
    mf_vector = Multiply()([mf_user_latent, mf_item_latent]) # element-wise multiply

    # MLP part 
    mlp_user_latent = Flatten()(MLP_Embedding_User(user_input))
    mlp_item_latent = Flatten()(MLP_Embedding_Item(item_input))
    
    mlp_vector = Concatenate()([mlp_user_latent, mlp_item_latent])
    for idx in range(1, num_layer):
        layer = Dense(layers[idx], kernel_regularizer= l2(reg_layers[idx]), activation='relu', name="layer%d" %idx)
        mlp_vector = layer(mlp_vector)

    # Concatenate MF and MLP parts
    #mf_vector = Lambda(lambda x: x * alpha)(mf_vector)
    #mlp_vector = Lambda(lambda x : x * (1-alpha))(mlp_vector)
    predict_vector = Concatenate()([mf_vector, mlp_vector])
    
    # Final prediction layer
    prediction = Dense(1, activation='sigmoid', kernel_initializer='lecun_uniform', name = "prediction")(predict_vector)
    
    model = Model(inputs =[user_input, item_input], 
                  outputs =prediction)
    
    return model

def load_pretrain_model(model, gmf_model, mlp_model, num_layers):
    # MF embeddings
    gmf_user_embeddings = gmf_model.get_layer('user_embedding').get_weights()
    gmf_item_embeddings = gmf_model.get_layer('item_embedding').get_weights()
    model.get_layer('mf_embedding_user').set_weights(gmf_user_embeddings)
    model.get_layer('mf_embedding_item').set_weights(gmf_item_embeddings)
    
    # MLP embeddings
    mlp_user_embeddings = mlp_model.get_layer('user_embedding').get_weights()
    mlp_item_embeddings = mlp_model.get_layer('item_embedding').get_weights()
    model.get_layer('mlp_embedding_user').set_weights(mlp_user_embeddings)
    model.get_layer('mlp_embedding_item').set_weights(mlp_item_embeddings)
    
    # MLP layers
    for i in range(1, num_layers):
        mlp_layer_weights = mlp_model.get_layer('layer%d' %i).get_weights()
        model.get_layer('layer%d' %i).set_weights(mlp_layer_weights)
        
    # Prediction weights
    gmf_prediction = gmf_model.get_layer('prediction').get_weights()
    mlp_prediction = mlp_model.get_layer('prediction').get_weights()
    new_weights = np.concatenate((gmf_prediction[0], mlp_prediction[0]), axis=0)
    new_b = gmf_prediction[1] + mlp_prediction[1]
    model.get_layer('prediction').set_weights([0.5*new_weights, 0.5*new_b])    
    return model

def get_train_instances(train, labelRatings):
    user_input, item_input, labels = [],[],[]
    
    for i in range(len(train)):
        
        # positive instance
        user_input.append(train[i][0])
        item_input.append(train[i][1])
        labels.append(labelRatings[i])

    return user_input, item_input, labels

In [22]:
sys.argv = ['NeuMF.py','--epochs','30 ','--batch_size','256','--num_factors','8',
'--layers','[64,32,16,8]','--reg_mf','0','--reg_layers','[0,0,0,0]','--num_neg','5',
'--lr','0.01','--learner','adagrad','--verbose','1','--out','0','--dataset','Toronto']

args = parse_args()

num_epochs = args.epochs
batch_size = args.batch_size
mf_dim = args.num_factors
layers = eval(args.layers)
reg_mf = args.reg_mf
reg_layers = eval(args.reg_layers)
num_negatives = args.num_neg
learning_rate = args.lr
learner = args.learner
verbose = args.verbose
mf_pretrain = args.mf_pretrain
mlp_pretrain = args.mlp_pretrain

topK = 10
evaluation_threads = -1 # mp.cpu_count()
print("NeuMF arguments: %s " %(args))
#model_out_file = 'Pretrain/%s_NeuMF_%d_%s_%d.h5' %(args.dataset, mf_dim, args.layers, time())

NeuMF arguments: Namespace(batch_size=256, dataset='Toronto', epochs=30, layers='[64,32,16,8]', learner='adagrad', lr=0.01, mf_pretrain='', mlp_pretrain='', num_factors=8, num_neg=5, out=0, path='Data/', reg_layers='[0,0,0,0]', reg_mf=0.0, verbose=1) 


In [23]:
from evaluate_rate import evaluate_rate_model

# Build and compile, and check initial performance
model = get_model(num_users, num_items, mf_dim, layers, reg_layers, reg_mf)

if learner.lower() == "adagrad": 
    model.compile(optimizer=Adagrad(lr=learning_rate), loss='binary_crossentropy')
elif learner.lower() == "rmsprop":
    model.compile(optimizer=RMSprop(lr=learning_rate), loss='binary_crossentropy')
elif learner.lower() == "adam":
    model.compile(optimizer=Adam(lr=learning_rate), loss='binary_crossentropy')
else:
    model.compile(optimizer=SGD(lr=learning_rate), loss='binary_crossentropy')
    
# Load pretrain model
if mf_pretrain != '' and mlp_pretrain != '':
    gmf_model = GMF.get_model(num_users,num_items,mf_dim)
    gmf_model.load_weights(mf_pretrain)
    mlp_model = MLP.get_model(num_users,num_items, layers, reg_layers)
    mlp_model.load_weights(mlp_pretrain)
    model = load_pretrain_model(model, gmf_model, mlp_model, len(layers))
    print("Load pretrained GMF (%s) and MLP (%s) models done. " %(mf_pretrain, mlp_pretrain))
    
# Initial performance
(mse, r2) = evaluate_rate_model(model, Review_test, Label_test, evaluation_threads)
print('Init: MSE = %.4f, R2 = %.4f' % (mse, r2))
best_mse, best_r2, best_iter = mse, r2, -1
# if args.out > 0:
#     model.save_weights(model_out_file, overwrite=True) 

Input layer:  (None, 1) <dtype: 'int32'>
Init: MSE = 2.5316, R2 = -0.9948


In [24]:
import random
# Training model
for epoch in range(num_epochs):
    t1 = time()
    # Generate training instances
    #Review_train, Review_test, Label_train, Label_test = train_test_split(Review_set, Label_set,
    #                                                                      test_size=0.2, random_state=random.randint(1,100))
    #Label_train = np.array(Label_train)/5
    #Label_test = np.array(Label_test)/5
    user_input, item_input, labels = get_train_instances(Review_train, Label_train)
    #print('Finished generating')
    
    # Training
    hist = model.fit([np.array(user_input), np.array(item_input)], #input
                     np.array(labels), # labels 
                     batch_size=batch_size, epochs=1, verbose=0, shuffle=True)
    
    t2 = time()
    
    #print('Finished training')

    # Evaluation
    if epoch % verbose == 0:
        (mse, r2) = evaluate_rate_model(model, Review_test, Label_test, evaluation_threads)
        loss = hist.history['loss'][0]
        
        print('Iteration %d [%.1f s]: RMSE = %.4f, R2 = %.4f, loss = %.4f [%.1f s]' 
              % (epoch,  t2-t1, np.sqrt(mse), r2, loss, time()-t2))
        if mse < best_mse:
            best_mse, best_r2, best_iter = mse, r2, epoch
#             if args.out > 0:
#                 model.save_weights(model_out_file, overwrite=True)

print("End. Best Iteration %d:  RMSE = %.4f, R2 = %.4f. " %(best_iter, np.sqrt(best_mse), best_r2))
# if args.out > 0:
#     print("The best NeuMF model is saved to %s" %(model_out_file))

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Iteration 0 [4.3 s]: RMSE = 1.0151, R2 = 0.1881, loss = 0.5720 [0.1 s]
Iteration 1 [3.4 s]: RMSE = 1.0078, R2 = 0.1996, loss = 0.5546 [0.1 s]
Iteration 2 [3.3 s]: RMSE = 1.0076, R2 = 0.2000, loss = 0.5501 [0.2 s]
Iteration 3 [3.4 s]: RMSE = 1.0082, R2 = 0.1992, loss = 0.5473 [0.2 s]
Iteration 4 [3.6 s]: RMSE = 1.0092, R2 = 0.1975, loss = 0.5446 [0.2 s]
Iteration 5 [3.5 s]: RMSE = 1.0111, R2 = 0.1945, loss = 0.5419 [0.1 s]
Iteration 6 [3.4 s]: RMSE = 1.0128, R2 = 0.1918, loss = 0.5390 [0.2 s]
Iteration 7 [3.4 s]: RMSE = 1.0167, R2 = 0.1856, loss = 0.5359 [0.1 s]
Iteration 8 [3.4 s]: RMSE = 1.0202, R2 = 0.1798, loss = 0.5326 [0.1 s]
Iteration 9 [3.5 s]: RMSE = 1.0244, R2 = 0.1732, loss = 0.5292 [0.1 s]
Iteration 10 [3.3 s]: RMSE = 1.0301, R2 = 0.1639, loss = 0.5257 [0.1 s]
Iteration 11 [3.6 s]: RMSE = 1.0372, R2 = 0.1523, loss = 0.5222 [0.1 s]
Iteration 12 [3.4 s]: RMSE = 1.0438, R2 = 0.1415, loss = 0.5188 [0.2 s]
Iteration 13 [3.4 s]: RMSE = 1.0526, R2 = 0.1269, loss = 0.5154 [0.2 s]
It

In [25]:
model_out_file = 'Pretrain/%s_NeuMF_%d_%s_%d.h5' %(args.dataset, mf_dim, args.layers, time())
model_out_file

'Pretrain/Toronto_NeuMF_8_[64,32,16,8]_1605579446.h5'

In [26]:
model.save_weights(model_out_file, overwrite=True) 