In [1]:
import numpy as np
import requests
import csv

In [82]:
def get_embedding(s: str):
    res = requests.post("http://localhost:1234/v1/embeddings", json={
        "input": s,
        "encoding_format": "float"
    })
    a = np.array(res.json()["data"][0]["embedding"])
    return a / np.linalg.norm(a)

In [3]:
def cosine_similarity(a, b):
    """
    Calculates the cosine similarity between two NumPy vectors.
    """
    dot_product = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    
    if norm_a == 0 or norm_b == 0:
        return 0 # Return 0 if either vector is a zero vector
    
    similarity = dot_product / (norm_a * norm_b)
    return similarity

In [83]:
restaurants = []
with open('test_data.csv', newline='') as csvfile:
    rows = csv.reader(csvfile, delimiter=',', quotechar='"')
    restaurants = [
        (title, description, get_embedding(description))
        for title, description in rows
    ][1:]

In [5]:
def execute_query(query):
    query = get_embedding(query)
    x = sorted([ (cosine_similarity(restaurant[2], query), i) for i, restaurant in enumerate(restaurants)])
    return [ (restaurants[i][0], similarity) for (similarity, i) in x[-10:]]

In [17]:
def execute_query_embedding(embedding):
    # query = get_embedding(query)
    x = sorted([ (cosine_similarity(restaurant[2], embedding), i) for i, restaurant in enumerate(restaurants)])
    return [ (restaurants[i][0], similarity) for (similarity, i) in x[-10:]]

In [None]:
user_votes = [
    (0, False),
    (1, False),
    (2, False),
    (3, False),
    (4, False),
    (5, False),
    (6, True),
    (7, True),
    (8, False),
    (9, False),

    # (10, False)
]
s = ", ".join([ restaurants[votes[0]][0] for votes in user_votes if votes[1]])
print(f"Likes {s}")
s = ", ".join([ restaurants[votes[0]][0] for votes in user_votes if not votes[1]])
print(f"Doesn't like {s}")

Likes Campus Burgers, The Halal Guys
Doesn't like La Victoria Taqueria, Iguanas Burritozilla, Angelou’s Mexican Grill, Taqueria Tlaquepaque, Tacos El Compa, Spartan Taco Truck, Falafel Drive-In, Nick the Greek, Ike’s Love & Sandwiches


In [None]:
embedding_average = np.average([r[2] for r in restaurants]) 
vote_embedding = restaurants[user_votes[0][0]][2].copy()

for i in range(1000):
    for idx, likes in user_votes:
        sub = restaurants[idx][2] - vote_embedding
        if(np.linalg.norm(sub) == 0):
            continue
        sub = sub / np.linalg.norm(sub)
        if(likes):
            vote_embedding = 0.7 * vote_embedding + 0.3 * restaurants[idx][2]
        else:
            vote_embedding = 0.9 * vote_embedding + 0.1 * (2 * embedding_average - restaurants[idx][2])
        vote_embedding = vote_embedding / np.linalg.norm(vote_embedding)

# v = a- b
# c = b - v
# c = 2b - a
# vote_embedding

In [90]:
execute_query_embedding(vote_embedding)

[('Lee’s Sandwiches', np.float64(-0.533795320425249)),
 ('Krispy Kreme', np.float64(-0.5326121538057319)),
 ('Pho 24', np.float64(-0.5323856534140269)),
 ('Nirvana Soul', np.float64(-0.5275866908797465)),
 ('Sharetea', np.float64(-0.5232642000847177)),
 ('Panda Express', np.float64(-0.5204957263633067)),
 ('Dunkin’', np.float64(-0.5079942371138045)),
 ('Gong Cha', np.float64(-0.49978820055298057)),
 ('Marugame Udon', np.float64(-0.4864960429920121)),
 ('Campus Burgers', np.float64(-0.47909979722333224))]

In [58]:
from scipy.cluster import  hierarchy

In [66]:
threshold = 0.3
Z = hierarchy.linkage(np.asarray([ r[2] for r in restaurants]), "average", metric="cosine")
C = hierarchy.fcluster(Z, threshold, criterion="distance")

In [70]:
C

array([26, 22, 22, 23, 22, 25, 35, 36, 40, 36, 32, 32, 30,  9, 32, 20, 13,
       37, 41, 19, 19, 19, 14, 13, 15, 15,  2,  3, 22, 24, 22, 34, 34, 34,
       34, 34, 31, 34, 34, 33, 21, 27, 27, 11, 11, 11, 11, 11, 11, 11, 18,
       12, 18, 18, 18, 32,  8,  8, 10,  8,  9,  8,  7,  7,  8, 16,  6,  5,
        5, 17,  1,  1,  2, 32,  2, 22, 32,  8,  8,  8, 29, 34, 39, 20, 28,
       28, 38, 23, 23, 22, 32, 15,  4, 33,  6,  6], dtype=int32)

In [None]:
", ".join([ restaurants[i][0] for i, c in enumerate(C) if c % ]) 10 == 3

'Taqueria Tlaquepaque, Pho 24, Orenchi Ramen, Spoonfish Poke, Chick-fil-A, Dia de Pesca, Aqui Cal-Mex, Popeye’s Louisiana Kitchen'