In [1]:
import tkinter as tk 
from tkinter.scrolledtext import *
import datetime
 from tkinter import Frame, Text, Scrollbar, Pack, Grid, Place
from tkinter.constants import RIGHT, LEFT, Y, BOTH, CENTER
import os
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import SentenceTransformer, util
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup,AutoTokenizer
import torch.nn as nn
import torch

from model import BERTREC



In [2]:
# Load files
data = pd.read_csv("reviews.csv")
restaurants_names = data["business_name"]
label = restaurants_names.values.tolist()
classes = restaurants_names.drop_duplicates().values.tolist()
reviews = data["text"]
reviews = reviews.values.tolist()
labels, dic = restaurants_names.factorize()

# Load BERT models 
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
restaurant_rec = BERTREC.BERTRECC(bert_model_name = "bert-base-uncased", num_classes = 100)
restaurant_rec.load_state_dict(torch.load("./model/model_weights.pth", map_location=torch.device('cpu')))
model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')

# Use BERT to predict restaurant
def predict_restaurant(text, model, tokenizer, device, max_length=128):
    model.eval()
    encoding = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
    return dic[preds.item()]

def findAverageRating(restaurant, data):
    reviews = data
    averageRating = reviews[reviews['business_name'] == restaurant]['rating'].mean()
    return averageRating

def getReviews(queryResturant):
    reviews = pd.read_csv('reviews.csv')
    reviews = reviews.loc[reviews['business_name'] == queryResturant]
    reviews = reviews['text']
    reviews = reviews.values
    return reviews


# Gets context from all user reviews 
def context_analysis(queryRestaurant):
    reviews = getReviews(queryRestaurant)
    embeddings = model.encode(reviews, convert_to_tensor=True)
    top_k = 3
    used_indices = set()
    usable_corpus = reviews.tolist()
    original_cluster_labels = reviews.tolist()
    clusters = []
    cls_id = 0

    while(len(used_indices) < len(usable_corpus)):
        cls_id += 1
        query = usable_corpus[0]
        usable_corpus = [label for label in reviews if label not in used_indices]
        if query in usable_corpus: 
            usable_corpus.remove(query)
            
        # Semenatic search
        query_embedding = model.encode(query, convert_to_tensor=True)
        embeddings = model.encode(usable_corpus, convert_to_tensor=True)
        cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
        for score in cos_scores:
            if score > 0.5: 
              top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]
              top_matches = [usable_corpus[idx] for idx in top_results[0:top_k]]  
            else: 
                continue
        for idx in top_matches:
            used_indices.add(idx)

        top_matches.append(query)
        used_indices.add(query)
        clusters.append(top_matches)
        
    cluster_labels = []
    for cluster in clusters:
        cluster_embeddings = model.encode(cluster, convert_to_tensor=True)
        cos_scores = util.pytorch_cos_sim(cluster_embeddings, cluster_embeddings)
        cos_scores = cos_scores.cpu().detach().numpy()
        cos_scores = np.max(cos_scores, axis=0)
        top_cluster = np.argmax(cos_scores)
        cluster_labels.append(cluster[top_cluster])
    return cluster_labels

# Compile useful information into final print statement
def print_info(queryRestaurant, averageRating):
    cluster_labels = context_analysis(queryRestaurant)
    
    
    print("I recommend that you try out this restaurant: ", queryRestaurant)
    print("The average rating of the restaurant is: ", averageRating)

    print("From the reviews, the most common experiences are: ")
    for label in cluster_labels:
        print("- ", label)

    "The overall experience of this restaurant is normally good."
    print("This will more than likely be your experience when you go to", queryRestaurant, ".")
    
    # Save overall experience sentiment
    overallExp = "The overall experience of this restaurant is normally awful."
    if averageRating >= 4.0:
        overallExp = "The overall experience of this restaurant is normally good."
    elif averageRating >= 3.0:
        overallExp = "The overall experience of this restaurant is normally average."
    elif averageRating >= 2.0:
        overallExp = "The overall experience of this restaurant is normally bad."
    return queryRestaurant, averageRating,  cluster_labels, overallExp

# print_info(predict_restaurant("doing your mom",restaurant_rec, tokenizer, device), findAverageRating(predict_restaurant("doing your mom",restaurant_rec, tokenizer, device), data=data))


In [3]:
__all__ = ['ScrolledText']

# Make text auto scroll when full (without a scrollbar)
class ScrolledText(Text):
    def __init__(self, master=None, **kw):
        self.frame = Frame(master)
        Text.__init__(self, self.frame, **kw)
        self.pack(side=LEFT, fill=BOTH, expand=True)
        
        # Copy geometry methods of self.frame without overriding Text
        text_meths = vars(Text).keys()
        methods = vars(Pack).keys() | vars(Grid).keys() | vars(Place).keys()
        methods = methods.difference(text_meths)

        for m in methods:
            if m[0] != '_' and m != 'config' and m != 'configure':
                setattr(self, m, getattr(self.frame, m))

    def __str__(self):
        return str(self.frame)

    
#  Clear entry box when clicked 
def click(*args): 
    input_field.delete(0, 'end') 
    
    
# When user submits an input 
def on_submit(event=None):
    user_input = input_field.get().lower()
    
   # User message 
    output_field.config(state='normal')
    output_field.tag_config('user', foreground='white', background='#535353')
    output_field.insert('end', '\n' + user_input + '\n\n', 'user')
    
    # Chatbot response 
    output_field.tag_config('response', foreground='white', background='#333333')
    output_field.insert('end', '\nBert: ', 'response')
    
    # Run twin models 
    restaurantRecommendation, averageRating, cluster_labels, overallExp = print_info(predict_restaurant(user_input, restaurant_rec, tokenizer, device), findAverageRating(predict_restaurant(user_input,restaurant_rec, tokenizer, device), data=data))    
    
    # Print out model recommendations 
    output_field.insert('end', "I recommend that you try out this restaurant: ", 'response') 
    output_field.tag_config('response', foreground='yellow', background='#333333', font=("Futura bold", 15))
    output_field.insert('end', restaurantRecommendation, 'response')
    
    output_field.tag_config('response', foreground='white', background='#333333', font=("Futura", 15))
    output_field.insert('end', ". The average rating of the restaurant is: ", 'response')
    output_field.tag_config('response', foreground='white', background='#333333', font=("Futura bold", 15))
    output_field.insert('end', averageRating, 'response')
    output_field.tag_config('response', foreground='white', background='#333333', font=("Futura", 15))
    output_field.insert('end', ". " + overallExp, 'response')

    output_field.tag_config('response', foreground='white', background='#333333', font=("Futura", 15))
    output_field.insert('end', "From the reviews, the most common experiences are: ", 'response')
    for label in cluster_labels:
        # print("- ", label)
        output_field.insert('end', "\n- " + label, 'response')
    output_field.insert('end', "\n\n", 'response')
                                    
    input_field.delete(0, 'end')
    output_field.yview(tk.END)
    output_field.config(state='disabled')

In [8]:
# Make tkinter root window 
root = tk.Tk()
root.title("Restaurant Recommender")

# previous = tk.StringVar(root, "one")
# print(previous.get())
# previous.set("one") #default value

# Output Frame 
output_frame = tk.Frame(root, width=150, height=500, relief=tk.FLAT, background="#65350F")
output_frame.pack_propagate(False)
output_frame.pack(side='top', fill='both', expand=True, padx=20)
# Autoscrolling output text  
output_field = ScrolledText(output_frame)

# Output Text details 
output_field.pack(side='left', fill='both', expand=True, padx=2, pady=3)
output_field.config(state='disabled')
output_field.tag_config('user', background='#1a3a46', justify='center')  # I don't think this does anything
output_field.tag_config('response', background='lightgreen', justify='center')  # I don't think this does anything 
output_field.config(font=("Futura", 18))
output_field.yview(tk.END)

# Input Frame
input_frame = tk.Frame(root, width=125, height=20, relief=tk.FLAT, highlightbackground = "#65350F", highlightthickness=2, background = "#EFEADD")
input_frame.pack(padx=20)
input_frame.pack(side='bottom', fill='x')

# Input text box with default prompt and disappears when clicked
input_field = tk.Entry(input_frame, width=125, bd=2.5, justify='center', relief = tk.FLAT, highlightbackground ="#EFEADD", highlightthickness=1, background = "#EFEADD")
input_field.insert(0, 'Tell me anything you want from a restaurant?') 
input_field.bind("<Button-1>", click) 
input_field.pack(side='left', padx=5, pady=(2,2))
input_field.bind("<Return>", on_submit)


# Run code 
root.mainloop()

I recommend that you try out this restaurant:  Sema Gozleme
The average rating of the restaurant is:  3.909090909090909
From the reviews, the most common experiences are: 
-  The place is very cozy; the food was fresh and good. They served pickles and tea; both were delicious.
-  Turkish sarma was really tasty; liked it.
This will more than likely be your experience when you go to Sema Gozleme .
