Skip to content

Commit

Permalink
test: Adding integration tests for spanner retriever service (#364)
Browse files Browse the repository at this point in the history
The tests take the following inputs from environment variables:
1. DB_PROJECT
2. DB_INSTANCE
3. DB_NAME
4. SERVICE_ACCOUNT_KEY_FILE_PATH

---------

Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com>
  • Loading branch information
gauravpurohit06 and Yuan325 committed Jun 3, 2024
1 parent 67470e9 commit 251f16d
Show file tree
Hide file tree
Showing 3 changed files with 654 additions and 22 deletions.
48 changes: 27 additions & 21 deletions retrieval_service/datastore/providers/spanner_gsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,14 @@ async def initialize_data(

# Insert data into 'flights' table using batch operation
values = [
tuple(getattr(flight, field) for field in self.FLIGHTS_COLUMNS)
tuple(
(
str(getattr(flight, field))
if isinstance(getattr(flight, field), datetime.datetime)
else getattr(flight, field)
)
for field in self.FLIGHTS_COLUMNS
)
for flight in flights
]

Expand Down Expand Up @@ -520,7 +527,7 @@ async def get_airport_by_iata(self, iata: str) -> Optional[models.Airport]:
with self.__database.snapshot() as snapshot:
# Execute SQL query to fetch airport by IATA code (case-insensitive match)
result = snapshot.execute_sql(
sql="SELECT * FROM airports WHERE iata LIKE @iata",
sql="SELECT * FROM airports WHERE LOWER(iata) LIKE LOWER(@iata)",
params={"iata": iata},
param_types={"iata": param_types.STRING},
)
Expand Down Expand Up @@ -560,10 +567,9 @@ async def search_airports(
# Construct SQL query based on provided parameters
query = """
SELECT * FROM airports
WHERE (@country IS NULL OR country LIKE @country)
AND (@city IS NULL OR city LIKE @city)
AND (@name IS NULL OR name LIKE '%' || @name || '%')
LIMIT 10
WHERE (@country IS NULL OR LOWER(country) LIKE LOWER(@country))
AND (@city IS NULL OR LOWER(city) LIKE LOWER(@city))
AND (@name IS NULL OR LOWER(name) LIKE '%' || LOWER(@name) || '%')
"""

# Execute SQL query with parameters
Expand Down Expand Up @@ -649,7 +655,7 @@ async def amenities_search(
COSINE_DISTANCE(embedding, @query_embedding) AS similarity
FROM amenities
) AS sorted_amenities
WHERE (2 - similarity) > @similarity_threshold
WHERE (1 - similarity) > @similarity_threshold
ORDER BY similarity
LIMIT @top_k
"""
Expand All @@ -671,7 +677,7 @@ async def amenities_search(

# Convert query result to model instance using model_validate method
amenities = [
{key: value for key, value in zip(self.AMENITIES_COLUMNS, a)}
{key: value for key, value in zip(self.AMENITIES_COLUMNS[1:], a)}
for a in results
]

Expand Down Expand Up @@ -773,8 +779,8 @@ async def search_flights_by_airports(
# Spread SQL query for readability
query = """
SELECT * FROM flights
WHERE (@departure_airport IS NULL OR departure_airport LIKE @departure_airport)
AND (@arrival_airport IS NULL OR arrival_airport LIKE @arrival_airport)
WHERE (@departure_airport IS NULL OR LOWER(departure_airport) LIKE LOWER(@departure_airport))
AND (@arrival_airport IS NULL OR LOWER(arrival_airport) LIKE LOWER(@arrival_airport))
AND cast(departure_time as TIMESTAMP) >= CAST(@datetime AS TIMESTAMP)
AND cast(departure_time as TIMESTAMP) < TIMESTAMP_ADD(CAST(@datetime AS TIMESTAMP), INTERVAL 1 DAY)
LIMIT 10
Expand All @@ -786,7 +792,7 @@ async def search_flights_by_airports(
params={
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"datetime": datetime.datetime.strptime(date, "%Y-%m-%d"),
"datetime": date,
},
param_types={
"departure_airport": param_types.STRING,
Expand Down Expand Up @@ -819,10 +825,10 @@ async def validate_ticket(
results = snapshot.execute_sql(
sql="""
SELECT * FROM flights
WHERE airline LIKE @airline
AND flight_number LIKE @flight_number
AND departure_airport LIKE @departure_airport
AND arrival_airport LIKE @arrival_airport
WHERE LOWER(airline) LIKE LOWER(@airline)
AND LOWER(flight_number) LIKE LOWER(@flight_number)
AND LOWER(departure_airport) LIKE LOWER(@departure_airport)
AND LOWER(arrival_airport) LIKE LOWER(@arrival_airport)
AND departure_time = @departure_time
AND arrival_time = @arrival_time
""",
Expand All @@ -831,8 +837,8 @@ async def validate_ticket(
"flight_number": flight_number,
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"departure_time": departure_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
"arrival_time": arrival_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
"departure_time": departure_time.strftime("%Y-%m-%d %H:%M:%S"),
"arrival_time": arrival_time.strftime("%Y-%m-%d %H:%M:%S"),
},
param_types={
"airline": param_types.STRING,
Expand Down Expand Up @@ -878,10 +884,10 @@ async def insert_ticket(
arrival_time (str): The arrival time of the flight.
"""
departure_time_datetime = datetime.datetime.strptime(
departure_time, "%Y-%m-%dT%H:%M:%S.%fZ"
departure_time, "%Y-%m-%d %H:%M:%S"
)
arrival_time_datetime = datetime.datetime.strptime(
arrival_time, "%Y-%m-%dT%H:%M:%S.%fZ"
arrival_time, "%Y-%m-%d %H:%M:%S"
)

if not await self.validate_ticket(
Expand Down Expand Up @@ -991,8 +997,8 @@ async def policies_search(
SELECT content, COSINE_DISTANCE(embedding, @query_embedding) AS similarity
FROM policies
) AS sorted_policies
WHERE (2 - similarity) > @similarity_threshold
ORDER BY similarity DESC
WHERE (1 - similarity) > @similarity_threshold
ORDER BY similarity
LIMIT @top_k
"""

Expand Down
Loading

0 comments on commit 251f16d

Please sign in to comment.