In [40]:
from turtle import title
from typing import Any

import torch
from sentence_transformers import SentenceTransformer


class CatalogTable:
    def __init__(self, title: Any):
        self.llm = SentenceTransformer("all-MiniLM-L6-v2")
        self.title = title
        self.title_embedding = torch.from_numpy(self.llm.encode(title))
        self.uuids = []
        self.embeddings = torch.Tensor()
    
    def extract_possible_list_of_values_from_string(self, possible_list_of_values_as_string: str) -> list[str]:
        possible_list_elements =[
            possible_list_element.strip() for
            possible_list_element in ";".join(possible_list_of_values_as_string.split(",")).split(";")
        ]

        number_of_words = len(possible_list_of_values_as_string.split(" "))
        
        if len(possible_list_elements) > number_of_words / 3:
            return possible_list_elements
        
        return [possible_list_of_values_as_string]

    def extract_value_from_catalog_element(self, catalog_element: Any) -> list[str]:
        if type(catalog_element) is str:
            return self.extract_possible_list_of_values_from_string(catalog_element)
        elif type(catalog_element) is list:
            list_of_values = []
            for _catalog_element in catalog_element:
                if (
                    type(_catalog_element) is int or type(_catalog_element) is float
                ):
                    list_of_values.append(str(_catalog_element))
                else:
                    list_of_values.extend(
                                self.extract_possible_list_of_values_from_string(
                            possible_list_of_values_as_string=_catalog_element
                        )
                    )
            return list_of_values
        elif type(catalog_element) is int or type(catalog_element) is float:
            return list(str(catalog_element))
        else:
            raise RuntimeError(f"Unexpected value for catalog in input: {catalog_element}")

    def add_element(self, identifier: int, element: Any):
        list_of_elements = self.extract_value_from_catalog_element(element)
        for content_element in list_of_elements:
            self.uuids.append(identifier)

            self.embeddings = torch.cat(
                [
                    self.embeddings,
                    torch.from_numpy(self.llm.encode(content_element.lower())).unsqueeze(0)
                ]
            )
    
    def find_elements(self, value: Any):
        value_embedding = torch.from_numpy(self.llm.encode(str(value))).unsqueeze(0)
        pairwise_l2_distances = torch.cdist(value_embedding, self.embeddings)
        index_of_closest_object = torch.argmin(pairwise_l2_distances, dim=1)
        return self.uuids[index_of_closest_object]

class CatalogRetrievalDatabase():
    def __init__(self, data: list[dict[str, Any]]) -> None:
        self.llm = SentenceTransformer("all-MiniLM-L6-v2")
        self.id_to_data_entry = {
            i: data_entry for i, data_entry in enumerate(data)
        }
        self.table_title_to_table: dict[str, CatalogTable] = {}
        self.populate_list_of_tables()

        self.titles = [title for title, _ in self.table_title_to_table.items()]
        self.title_embeddings = torch.stack(
            [
                torch.from_numpy(
                    self.llm.encode(title)
                ) for title, _ in self.table_title_to_table.items()
            ], dim=0
        )

    def populate_list_of_tables(self) -> None:
        for id, catalog_entry in self.id_to_data_entry.items():
            for title, value in catalog_entry.items():
                if title in self.table_title_to_table:
                    table = self.table_title_to_table[title]
                    table.add_element(id, value)
                else:
                    self.table_title_to_table[title] = CatalogTable(title)
                    self.table_title_to_table[title].add_element(id, value)

    def find_by_category_and_value(self, category: str, value: Any) -> dict[str, Any]:
        category_embedding = torch.from_numpy(self.llm.encode(category)).unsqueeze(0)
        pairwise_l2_distances = torch.cdist(category_embedding, self.title_embeddings)
        index_of_closest_category = torch.argmin(pairwise_l2_distances, dim=1)
        corresponding_table = self.table_title_to_table[self.titles[index_of_closest_category]]
        uuid = corresponding_table.find_elements(value)
        return self.id_to_data_entry[uuid]
    

In [41]:
input: list[dict[str, Any]] = [
    {
        "Color": "Black",
        "Price": 4000,
    },
    {"Size": [30, 50]},
    {
        "Size": [10, 20],
        "Color": "Blue, Red; Green",
        "Price": 3000,
    },
]

database = CatalogRetrievalDatabase(input)
database.find_by_category_and_value("Green", 20)

for category, value in (
    ("Price", 3000),
    ("Colour", "Green"),
    ("Colors", "Blue"),
    ("Colrs", "Blck"),
    ("Size", 50),
):
    print(
        database.find_by_category_and_value(
            category=category,
            value=value
        )
    )

{'Color': 'Black', 'Price': 4000}
{'Size': [10, 20], 'Color': 'Blue, Red; Green', 'Price': 3000}
{'Size': [10, 20], 'Color': 'Blue, Red; Green', 'Price': 3000}
{'Color': 'Black', 'Price': 4000}
{'Size': [30, 50]}
