graph.ops.KuzuOps should provide a general method to update_embeddings #2

@DataLabTechTV

Description

We probably need a table_names: list[str] parameter so that the ALTER TABLE statements are generated per table instead of being hard-coded. An optional parameter could also let callers customize the name of the embeddings property. The current implementation is:

def update_embeddings(self, embeddings: dict[int, list[float]]):
    # Node tables are hard-coded here; this is what needs generalizing.
    self.conn.execute("ALTER TABLE User ADD IF NOT EXISTS embedding DOUBLE[]")
    self.conn.execute("ALTER TABLE Genre ADD IF NOT EXISTS embedding DOUBLE[]")
    self.conn.execute("ALTER TABLE Track ADD IF NOT EXISTS embedding DOUBLE[]")

    # One row per node: its node_id and the embedding vector to store.
    batch = [dict(nid=nid, e=e) for nid, e in embeddings.items()]

    # Match nodes by node_id across all tables and set the embedding property.
    self.conn.execute(
        """
        UNWIND $batch AS batch
        MATCH (n {node_id: batch.nid})
        SET n.embedding = batch.e
        """,
        parameters=dict(batch=batch),
    )
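
One possible shape for the generalized method, as a rough sketch rather than a final design. The class name KuzuOpsSketch, its constructor, the property_name parameter, and the _check_identifier helper are all hypothetical and not part of graph.ops today. The sketch assumes self.conn.execute accepts a query string and a parameters keyword, as in the code above. Table and property names end up interpolated into the query text, since identifiers cannot be passed as query parameters, so the sketch validates them before running anything.

```python
import re
from typing import Any

# Conservative pattern for table and property names (an assumption of this
# sketch, not a rule taken from the Kuzu documentation).
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def _check_identifier(name: str) -> str:
    """Hypothetical helper: reject names that are unsafe to interpolate."""
    if not _IDENTIFIER.fullmatch(name):
        raise ValueError(f"Invalid identifier: {name!r}")
    return name


class KuzuOpsSketch:
    """Hypothetical stand-in for graph.ops.KuzuOps, holding only a connection."""

    def __init__(self, conn: Any):
        self.conn = conn

    def update_embeddings(
        self,
        embeddings: dict[int, list[float]],
        table_names: list[str],
        property_name: str = "embedding",
    ) -> None:
        # Validate every name up front so no table is altered if one is bad.
        prop = _check_identifier(property_name)
        tables = [_check_identifier(t) for t in table_names]

        # Same ALTER TABLE statement as before, generated once per table.
        for table in tables:
            self.conn.execute(
                f"ALTER TABLE {table} ADD IF NOT EXISTS {prop} DOUBLE[]"
            )

        batch = [dict(nid=nid, e=e) for nid, e in embeddings.items()]

        # Doubled braces are literal braces inside the f-string.
        self.conn.execute(
            f"""
            UNWIND $batch AS batch
            MATCH (n {{node_id: batch.nid}})
            SET n.{prop} = batch.e
            """,
            parameters=dict(batch=batch),
        )
```

With this signature, the current behavior would correspond to calling update_embeddings(embeddings, ["User", "Genre", "Track"]). Whether table_names should be required or default to the three existing tables is left open.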

Labels: enhancement (New feature or request)
