In [None]:
# # from sqlalchemy import Column, String, Integer, Float, DateTime, ForeignKey, JSON, PrimaryKeyConstraint
# # from sqlalchemy.sql import func
# # from .base import Base

# class Topic(Base):
#     __tablename__ = "topics"
    
#     id = Column(String, primary_key=True)
#     label = Column(String, nullable=False)
#     icon = Column(String, nullable=False)
#     color = Column(String, nullable=False)

# class Exhibit(Base):
#     __tablename__ = "exhibits"
    
#     id = Column(String, primary_key=True)
#     title = Column(String, nullable=False)
#     description = Column(String, nullable=False)
#     image = Column(String, nullable=False)
#     topic_id = Column(String, ForeignKey("topics.id"), nullable=False)
#     details = Column(JSON, nullable=True)

# class TopicExhibitRelation(Base):
#     __tablename__ = "topic_exhibit_relations"
#     __table_args__ = (
#         PrimaryKeyConstraint("topic_id", "exhibit_id"),
#     )
#     topic_id = Column(String, ForeignKey("topics.id"), nullable=False)
#     exhibit_id = Column(String, ForeignKey("exhibits.id"), nullable=False)
#     strength = Column(Float, nullable=False)

# class User(Base):
#     __tablename__ = "users"
    
#     id = Column(String, primary_key=True)
#     interests = Column(JSON, nullable=True)
#     language = Column(String, nullable=False)
#     reading_level = Column(String, nullable=False)

# class Visit(Base):
#     __tablename__ = "visits"
    
#     id = Column(Integer, primary_key=True)
#     user_id = Column(String, ForeignKey("users.id"), nullable=False)
#     exhibit_id = Column(String, ForeignKey("exhibits.id"), nullable=False)
#     timestamp = Column(DateTime, nullable=False, server_default=func.now())

In [7]:
import json

# Load the generated catalogue: exhibits, topics, and the exhibit<->topic
# links (each with a relevance weight). Paths are relative to the notebook's
# working directory.
# JSON text is UTF-8; pass the encoding explicitly so exhibit text decodes
# the same way regardless of the OS locale default.
with open("../data_gen/exhibits.json", encoding="utf-8") as f:
    exhibits_data = json.load(f)

with open("../data_gen/topics.json", encoding="utf-8") as f:
    topics_data = json.load(f)

with open("../data_gen/exhibit_topics.json", encoding="utf-8") as f:
    exhibit_topics_data = json.load(f)


exhibit_topics_data[:5]

[{'exhibit_id': 'exhibit_0', 'topic_id': 'topic_2', 'relevance': 0.5},
 {'exhibit_id': 'exhibit_0', 'topic_id': 'topic_3', 'relevance': 0.5},
 {'exhibit_id': 'exhibit_0', 'topic_id': 'topic_5', 'relevance': 1.0},
 {'exhibit_id': 'exhibit_0', 'topic_id': 'topic_6', 'relevance': 0.5},
 {'exhibit_id': 'exhibit_0', 'topic_id': 'topic_8', 'relevance': 0.5}]

In [None]:
import numpy as np

# Given a user and their visit history, rank all remaining exhibits by their relevance to the user.
#    Generate candidates using a combination of the user's interests, age, and popularity.
#    Then, rank using an LLM to make the final recommendation.

def get_all_visits_given_user(user):
    """Return the visit history for ``user``.

    Placeholder implementation: the result is hard-coded (two visits by
    ``user_0``, to ``exhibit_0`` and ``exhibit_2``) and ``user`` is not
    consulted. TODO: replace with a real lookup of the user's visits.

    Args:
        user: The user whose visits are requested (currently ignored).

    Returns:
        A new list of visit dicts with ``id``, ``user_id``, ``exhibit_id``
        and ``timestamp`` keys. Every visit has a distinct ``id``.
    """
    return [
        {
            "id": "visit_0",
            "user_id": "user_0",
            "exhibit_id": "exhibit_0",
            "timestamp": "Now",
        },
        {
            # Distinct id: two visits must not share an identifier.
            "id": "visit_1",
            "user_id": "user_0",
            "exhibit_id": "exhibit_2",
            "timestamp": "Now",
        },
    ]

def get_all_topics_related_to_an_exhibit(exhibit_id, exhibit_topics=None, topics=None):
    """Return the exhibit-topic links for ``exhibit_id``, each joined with its topic.

    Args:
        exhibit_id: Id of the exhibit whose topic links are wanted.
        exhibit_topics: Link records with ``exhibit_id``, ``topic_id`` and
            ``relevance`` keys. Defaults to the module-level
            ``exhibit_topics_data``.
        topics: Topic records with an ``id`` key. Defaults to the
            module-level ``topics_data``.

    Returns:
        A list of *copies* of the matching link records, in input order.
        Each copy gains a ``topic_data`` key holding the first topic whose
        ``id`` equals the link's ``topic_id``. Links whose topic is unknown
        are returned without ``topic_data``. The input records are never
        modified.
    """
    if exhibit_topics is None:
        exhibit_topics = exhibit_topics_data
    if topics is None:
        topics = topics_data

    # Index topics once instead of rescanning the list for every link.
    # setdefault keeps the first occurrence, matching a first-match scan.
    topics_by_id = {}
    for topic in topics:
        topics_by_id.setdefault(topic["id"], topic)

    related = []
    for link in exhibit_topics:
        if link["exhibit_id"] != exhibit_id:
            continue
        # Copy so the shared (global) link records are not mutated in place.
        joined = dict(link)
        if link["topic_id"] in topics_by_id:
            joined["topic_data"] = topics_by_id[link["topic_id"]]
        related.append(joined)

    return related

class AbstractRetriever:
    """Base class for components that score exhibits for a user.

    Concrete retrievers override :meth:`retrieve`; the base class only keeps
    the exhibit records it was constructed with.
    """

    def __init__(self, exhibits):
        # Exhibit records this retriever operates on.
        self.exhibits = exhibits

    def retrieve(self, user, visit_history):
        """Score exhibits for ``user`` given ``visit_history``.

        Not implemented here: always raises ``NotImplementedError``.
        """
        raise NotImplementedError()
    
class PopularityRetriever(AbstractRetriever):
    """Scores every exhibit by its popularity rating divided by 5.

    The rating is read from ``details["popularity"]`` of each exhibit.
    NOTE(review): the division by 5 suggests a 0-5 rating, giving scores in
    [0, 1] — confirm against the data generator.
    """

    def __init__(self, exhibits):
        super().__init__(exhibits)

        # exhibit id -> popularity score. Computed once, since popularity
        # does not depend on the user.
        self.popularities = {
            exhibit["id"]: self._popularity_score(exhibit)
            for exhibit in self.exhibits
        }

    @staticmethod
    def _popularity_score(exhibit):
        """Return ``float(popularity) / 5`` for ``exhibit``, or 0 if unavailable.

        ``details`` may be a JSON-encoded string (as this class originally
        assumed) or an already-decoded dict; both are accepted.
        """
        details = exhibit.get("details")
        try:
            if isinstance(details, (str, bytes, bytearray)):
                details = json.loads(details)
            return float(details["popularity"]) / 5
        except (KeyError, TypeError, ValueError):
            # Missing/null details, no "popularity" key, malformed JSON
            # (JSONDecodeError is a ValueError) or a non-numeric rating.
            return 0

    def retrieve(self, user, visit_history):
        """Return every exhibit ranked by popularity, highest first.

        ``user`` and ``visit_history`` are ignored: popularity is global.

        Returns:
            A list of ``{"exhibit_id": ..., "score": ...}`` dicts sorted by
            descending score; ties keep the order of ``self.exhibits``.
        """
        ranked = sorted(self.exhibits, key=lambda exhibit: -self.popularities[exhibit["id"]])
        return [
            {"exhibit_id": exhibit["id"], "score": self.popularities[exhibit["id"]]}
            for exhibit in ranked
        ]

class VisitRetriever(AbstractRetriever):
    """Scores exhibits by whether the user has already visited them."""

    def retrieve(self, user, visit_history):
        """Return a visited/unvisited score for every exhibit.

        Args:
            user: Ignored.
            visit_history: Iterable of visit dicts with an ``exhibit_id``
                key; ``None`` is treated as no visits.

        Returns:
            One ``{"exhibit_id": ..., "score": ...}`` dict per exhibit, in
            ``self.exhibits`` order. The score is -1 for exhibits already
            visited (pushing them down an aggregated ranking) and 1 otherwise.
        """
        visited = {visit["exhibit_id"] for visit in (visit_history or [])}
        return [
            {
                "exhibit_id": exhibit["id"],
                "score": -1 if exhibit["id"] in visited else 1,
            }
            for exhibit in self.exhibits
        ]
    
class InterestBasedRetriever(AbstractRetriever):
    """Scores exhibits by how well their topics match the user's interests.

    Exhibits and users are both represented as topic vectors: one slot per
    distinct topic id in ``topics_data``, holding that topic's relevance.
    An exhibit's score is the dot product of the L2-normalised user vector
    with the exhibit vector. The exhibit vector is not normalised, so
    exhibits tagged with more matching topics score higher.
    """

    def __init__(self, exhibits):
        super().__init__(exhibits)

        # topic id -> slot in every topic vector. Built from topics_data
        # instead of parsing the number out of a "topic_<n>" id, so it does
        # not depend on the id format or on ids forming a dense 0..n-1 range.
        self.topic_index = {}
        for topic in topics_data:
            self.topic_index.setdefault(topic["id"], len(self.topic_index))

        # exhibit id -> topic vector (attribute name kept for existing callers).
        self.exhibit_vetors = {
            exhibit["id"]: self._to_vector(
                get_all_topics_related_to_an_exhibit(exhibit["id"])
            )
            for exhibit in self.exhibits
        }

    def _to_vector(self, weighted_topics):
        """Build a topic vector from dicts with ``topic_id`` and ``relevance`` keys.

        Topic ids not present in ``topics_data`` are ignored. If a topic
        appears more than once, its last relevance wins.
        """
        vector = np.zeros(len(self.topic_index))
        for item in weighted_topics:
            slot = self.topic_index.get(item["topic_id"])
            if slot is not None:
                vector[slot] = item["relevance"]
        return vector

    def retrieve(self, user, visit_history):
        """Return every exhibit ranked by interest similarity, highest first.

        Args:
            user: Dict whose ``interests`` entry is a list of
                ``{"topic_id": ..., "relevance": ...}`` dicts, or ``None``.
            visit_history: Ignored.

        Returns:
            A list of ``{"exhibit_id": ..., "score": ...}`` dicts sorted by
            descending score; ties keep the order of ``self.exhibits``.
            A user with no known interests gets a score of 0 for every exhibit.
        """
        user_vector = self._to_vector(user["interests"] or [])

        # Normalise, but skip the zero vector: dividing by a zero norm would
        # make every score NaN and the resulting sort order meaningless.
        norm = np.linalg.norm(user_vector)
        if norm > 0:
            user_vector = user_vector / norm

        similarities = {
            exhibit_id: np.dot(user_vector, exhibit_vector)
            for exhibit_id, exhibit_vector in self.exhibit_vetors.items()
        }

        ranked = sorted(self.exhibits, key=lambda exhibit: -similarities[exhibit["id"]])
        return [
            {"exhibit_id": exhibit["id"], "score": similarities[exhibit["id"]]}
            for exhibit in ranked
        ]


class AggregateRetriever(AbstractRetriever):
    """Combines several retrievers into one ranking by a weighted score sum."""

    def __init__(self):
        # Registered components as {"retriever": ..., "weight": ...} dicts.
        self.retrievers = []

    def add_retriever(self, retriever, weight=1):
        """Register ``retriever``; its scores are multiplied by ``weight`` when aggregated."""
        self.retrievers.append({
            "retriever": retriever,
            "weight": weight,
        })

    def retrieve(self, user, visit_history):
        """Return exhibits ranked by the weighted sum of all component scores.

        Each registered retriever is called with ``user`` and
        ``visit_history``, in registration order. Every score it returns is
        multiplied by that retriever's weight and added to the exhibit's total.

        Returns:
            A list of ``{"exhibit_id": ..., "score": ...}`` dicts sorted by
            descending total; ties keep the order exhibits were first seen.
        """
        totals = {}
        for component in self.retrievers:
            weight = component["weight"]
            for rec in component["retriever"].retrieve(user, visit_history):
                exhibit_id = rec["exhibit_id"]
                # Apply the weight; previously it was stored but ignored.
                totals[exhibit_id] = totals.get(exhibit_id, 0) + weight * rec["score"]

        combined = [
            {"exhibit_id": exhibit_id, "score": score}
            for exhibit_id, score in totals.items()
        ]
        return sorted(combined, key=lambda entry: -entry["score"])


# Individual candidate generators, each scoring the full exhibit catalogue.
popularity_retriever = PopularityRetriever(exhibits_data)
visit_retriever = VisitRetriever(exhibits_data)
interest_retriever = InterestBasedRetriever(exhibits_data)

# Combine them with equal weights into a single ranking.
aggregate_retriever = AggregateRetriever()
aggregate_retriever.add_retriever(popularity_retriever, weight=1)
aggregate_retriever.add_retriever(visit_retriever, weight=1)
aggregate_retriever.add_retriever(interest_retriever, weight=1)

# Sample user, equally interested in the first two topics.
user = {
    "id": "user_0",
    "interests": [
        {"topic_id": topic_id, "relevance": 1}
        for topic_id in ("topic_0", "topic_1")
    ],
}

visit_history = get_all_visits_given_user(user)

# Last expression of the cell: the notebook displays the aggregated ranking.
aggregate_retriever.retrieve(user, visit_history)


{'exhibit_0', 'exhibit_2'}


[{'exhibit_id': 'exhibit_107', 'score': 3.414213562373095},
 {'exhibit_id': 'exhibit_32', 'score': 3.2142135623730947},
 {'exhibit_id': 'exhibit_109', 'score': 3.060660171779821},
 {'exhibit_id': 'exhibit_124', 'score': 2.8606601717798212},
 {'exhibit_id': 'exhibit_36', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_89', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_110', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_113', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_126', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_128', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_129', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_130', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_131', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_132', 'score': 2.7071067811865475},
 {'exhibit_id': 'exhibit_8', 'score': 2.6606601717798215},
 {'exhibit_id': 'exhibit_146', 'score': 2.5428090415820632},
 {'exhibit_id': 'exhibit_4', 's