diff --git a/tests/unit/test_sqlserver.py b/tests/unit/test_sqlserver.py index 244408652..0abcb1475 100644 --- a/tests/unit/test_sqlserver.py +++ b/tests/unit/test_sqlserver.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging from decimal import Decimal +from typing import Any import boto3 import pyarrow as pa @@ -8,6 +11,7 @@ import awswrangler as wr import awswrangler.pandas as pd +from awswrangler import _databases as _db_utils from .._utils import ensure_data_types, get_df, pandas_equals @@ -17,26 +21,27 @@ @pytest.fixture(scope="module", autouse=True) -def create_sql_server_database(databases_parameters): +def create_sql_server_database(databases_parameters: dict[str, Any]) -> None: + attrs = _db_utils.get_connection_attributes(connection="aws-sdk-pandas-sqlserver") connection_str = ( f"DRIVER={{ODBC Driver 17 for SQL Server}};" - f"SERVER={databases_parameters['sqlserver']['host']},{databases_parameters['sqlserver']['port']};" - f"UID={databases_parameters['user']};" - f"PWD={databases_parameters['password']}" + f"SERVER={attrs.host},{attrs.port};" + f"UID={attrs.user};" + f"PWD={attrs.password}" ) - con = pyodbc.connect(connection_str, autocommit=True) + + database_name = databases_parameters["sqlserver"]["database"] sql_create_db = ( - f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " + f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{database_name}') " "BEGIN " - f"CREATE DATABASE {databases_parameters['sqlserver']['database']} " + f"CREATE DATABASE {database_name} " "END" ) - with con.cursor() as cursor: - cursor.execute(sql_create_db) - con.commit() - con.close() - yield + with pyodbc.connect(connection_str, autocommit=True) as con: + with con.cursor() as cursor: + cursor.execute(sql_create_db) + con.commit() @pytest.fixture(scope="function")