Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions services/vector_db_service.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@

"""
Service for storing and retrieving vectors using Pinecone (scaffold).
"""

# To use: pip install pinecone-client
import os
from typing import List
from config.settings import settings

class VectorDBClient:
    def __init__(self, dimension: int = None, index_name: str = None):
        """
        Initialize a Pinecone client for serverless (cloud/region) use.

        API key, cloud, region, index name, and embedding dimension are
        loaded from ``config.settings.settings``; the explicit parameters
        override the configured defaults.

        :param dimension: Vector dimension; falls back to
            ``settings.embedding_dimension`` when omitted.
        :param index_name: Target index name; falls back to
            ``settings.pinecone_index_name`` when omitted.
        """
        # Imported lazily so this module can be imported (and the client
        # mocked) without the pinecone package being installed.
        from pinecone import Pinecone
        self.pc = Pinecone(api_key=settings.pinecone_api_key.get_secret_value())
        self.cloud = settings.pinecone_cloud
        self.region = settings.pinecone_region
        self.index_name = index_name or settings.pinecone_index_name
        self.index = None  # bound by create_index()
        self.dimension = dimension or settings.embedding_dimension

def create_index(self, index_name: str, dimension: int):
def create_index(self, index_name: str = None, dimension: int = None):
from pinecone import ServerlessSpec
idx_name = index_name or self.index_name
dim = dimension or self.dimension
existing = [idx.name for idx in self.pc.list_indexes()]
if index_name not in existing:
if idx_name not in existing:
self.pc.create_index(
name=index_name,
dimension=dimension,
name=idx_name,
dimension=dim,
spec=ServerlessSpec(cloud=self.cloud, region=self.region),
)
self.index = self.pc.Index(index_name)
self.index = self.pc.Index(idx_name)

def upsert_vectors(self, vectors: List[List[float]], ids: List[str]):
# Input validation
Expand All @@ -47,7 +46,6 @@ def upsert_vectors(self, vectors: List[List[float]], ids: List[str]):
raise ValueError("vector dimensionality mismatch with index.")
if not self.index:
raise RuntimeError("Index is not initialized. Call create_index first.")
# Upsert format may vary by pinecone-client version.
self.index.upsert(vectors=[(id, vec) for id, vec in zip(ids, vectors)])

def query(self, vector: List[float], top_k: int = 5):
Expand All @@ -56,3 +54,16 @@ def query(self, vector: List[float], top_k: int = 5):
if self.dimension is not None and len(vector) != self.dimension:
raise ValueError("query vector dimensionality mismatch with index.")
return self.index.query(vector=vector, top_k=top_k)

# Manual smoke test: build a client from settings and ensure the index exists.
if __name__ == "__main__":
    client = VectorDBClient()
    client.create_index()
    print("Index created successfully!")
    # Example upsert and query (uncomment to use):
    # ids = ["id1", "id2"]
    # vectors = [[0.0] * client.dimension, [1.0] * client.dimension]
    # client.upsert_vectors(vectors, ids)
    # print("Upserted vectors.")
    # result = client.query([0.0] * client.dimension)
    # print("Query result:", result)
119 changes: 34 additions & 85 deletions tests/test_vector_db_service.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are getting rid of this file completely — it's testing the Pinecone client, which we don't need.

Original file line number Diff line number Diff line change
@@ -1,97 +1,46 @@

# Patch config.settings and env vars before any other imports
import sys
import types
import os
from pydantic import SecretStr

# Provide the env var the real settings module requires, before any
# project import can trigger settings validation.
os.environ.update({"INTERNAL_API_KEY": "test-key"})

class DummySettings:
    """Minimal stand-in for ``config.settings.settings`` used by these tests."""
    # Pinecone wiring
    pinecone_api_key = SecretStr("test-key")
    pinecone_cloud = "aws"
    pinecone_region = "us-east-1"
    pinecone_index_name = "authormaton-core"
    # Embedding configuration
    embedding_model = "test-model"
    embedding_dimension = 16
    embed_batch_size = 64
    # Upload limits
    max_upload_mb = 25

# Install the dummy settings as a fake "config.settings" module so the
# service under test picks it up at import time.
dummy_config = types.ModuleType("config.settings")
setattr(dummy_config, "settings", DummySettings())
sys.modules["config.settings"] = dummy_config


# Patch Pinecone client for tests to avoid real API calls
from unittest.mock import MagicMock
class MockIndex:
    """In-memory fake of a Pinecone index that records every upsert."""

    def __init__(self, dimension):
        self.dimension = dimension
        self.upserted = []

    def upsert(self, items, namespace=None):
        # Accumulate upserted items for later inspection by the tests.
        self.upserted += list(items)

    def query(self, vector, top_k=8, namespace=None, filter=None):
        # Always return one canned match.
        return {'matches': [{'id': 'test', 'score': 0.99}]}

# Build a fake `pinecone` module so importing the service never performs
# real API calls.
mock_pc = MagicMock(**{
    "list_indexes.return_value": [],
    "create_index.return_value": None,
    "describe_index.return_value": {"dimension": 16},
})
mock_pc.Index.side_effect = lambda name: MockIndex(16)

mock_pinecone = MagicMock()
mock_pinecone.Pinecone = MagicMock(return_value=mock_pc)
mock_pinecone.ServerlessSpec = MagicMock()
sys.modules["pinecone"] = mock_pinecone

# Now import everything else
import pytest
from services.vector_db_service import VectorDBService
from services.vector_db_service import VectorDBClient

class DummyIndex:
    """Fake Pinecone index matching the VectorDBClient call signatures."""

    def __init__(self, dimension):
        self.dimension = dimension
        self.upserted = []

    def upsert(self, vectors=None):
        # VectorDBClient.upsert_vectors calls index.upsert(vectors=[...]).
        if vectors:
            self.upserted.extend(vectors)

    def query(self, vector, top_k=5):
        return {'matches': [{'id': 'id1', 'score': 0.99}]}

def test_ensure_index_idempotent(monkeypatch):
svc = VectorDBService()
# Simulate: first call creates the index; second call sees it and skips creation.
mock_pc.create_index.reset_mock()
import types
mock_pc.list_indexes.side_effect = [
[],
[types.SimpleNamespace(name=dummy_config.settings.pinecone_index_name)],
]

svc.ensure_index(svc.embedding_dimension)
svc.ensure_index(svc.embedding_dimension)
# Should only ever create the index once
assert mock_pc.create_index.call_count == 1
assert svc.index.dimension == svc.embedding_dimension

def test_upsert_dimension_guard(monkeypatch):
svc = VectorDBService()
monkeypatch.setattr(svc, 'index', MockIndex(svc.embedding_dimension))
ids = ['a', 'b']
vectors = [[0.0]*svc.embedding_dimension, [0.0]*svc.embedding_dimension]
metadata = [{}, {}]
count = svc.upsert(namespace='proj', ids=ids, vectors=vectors, metadata=metadata)
assert count == 2
def upsert(self, vectors=None):
if vectors:
self.upserted.extend(vectors)
def query(self, vector, top_k=5):
return {'matches': [{'id': 'id1', 'score': 0.99}]}

@pytest.fixture
def vdb(monkeypatch):
    """A VectorDBClient wired to a DummyIndex instead of real Pinecone."""
    client = VectorDBClient(dimension=8, index_name="test-index")
    monkeypatch.setattr(client, 'index', DummyIndex(client.dimension))
    return client

def test_create_index(monkeypatch):
    """create_index should create a missing index and bind self.index."""
    client = VectorDBClient(dimension=8, index_name="test-index")
    monkeypatch.setattr(client.pc, 'list_indexes', lambda: [])
    monkeypatch.setattr(client.pc, 'create_index', lambda **kwargs: None)
    monkeypatch.setattr(client.pc, 'Index', lambda name: DummyIndex(8))
    client.create_index()
    assert client.index.dimension == 8

def test_upsert_vectors(vdb):
    """upsert_vectors stores matching vectors and rejects wrong dimensions."""
    ids = ["id1", "id2"]
    vectors = [[0.0] * 8, [1.0] * 8]
    vdb.upsert_vectors(vectors, ids)
    assert len(vdb.index.upserted) == 2
    # Wrong dimension must raise before anything is upserted.
    with pytest.raises(ValueError):
        vdb.upsert_vectors([[0.0] * 5, [1.0] * 5], ids)

def test_query_dimension_guard(monkeypatch):
svc = VectorDBService()
monkeypatch.setattr(svc, 'index', MockIndex(svc.embedding_dimension))
vector = [0.0]*svc.embedding_dimension
matches = svc.query(namespace='proj', vector=vector)
assert matches[0]['id'] == 'test'
def test_query(vdb):
    """query returns the dummy match and rejects wrong-dimension vectors."""
    vector = [0.0] * 8
    result = vdb.query(vector)
    assert result['matches'][0]['id'] == 'id1'
    # Wrong dimension must raise.
    with pytest.raises(ValueError):
        vdb.query([0.0] * 5)