Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,17 @@ APP_NAME="ChainReport API"
DEBUG=True
API_V1_PREFIX="/api/v1"
DATABASE_URL="sqlite:///./sql_app.db"

# PostgreSQL Database Configuration
DB_USER="your_db_user"
DB_PASSWORD="your_db_password"
DB_HOST="localhost"
DB_PORT="5432"
DB_NAME="chainreport_db"

# Test PostgreSQL Database Configuration
TEST_DB_USER="postgres"
TEST_DB_PASSWORD="postgres"
TEST_DB_HOST="localhost"
TEST_DB_PORT="5432"
TEST_DB_NAME="test_chainreport_db"
Binary file not shown.
88 changes: 88 additions & 0 deletions backend/app/db/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@

import os
import asyncpg
import asyncio
import logging
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

logger = logging.getLogger(__name__)

class DatabaseConnection:
_pool = None
_pool_lock = asyncio.Lock()

@classmethod
async def connect(cls):
if cls._pool is None:
async with cls._pool_lock:
if cls._pool is None:
try:
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
db_host = os.getenv("DB_HOST")
db_port_str = os.getenv("DB_PORT")
db_name = os.getenv("DB_NAME")

missing_vars = [
name for name, value in {
"DB_USER": db_user,
"DB_PASSWORD": db_password,
"DB_HOST": db_host,
"DB_PORT": db_port_str,
"DB_NAME": db_name,
}.items() if not value
]

if missing_vars:
raise ValueError(
f"Missing or empty database environment variables: {', '.join(missing_vars)}"
)

try:
db_port = int(db_port_str)
except (TypeError, ValueError) as e:
raise ValueError(f"DB_PORT must be an integer: {e}") from e

cls._pool = await asyncpg.create_pool(
user=db_user,
password=db_password,
host=db_host,
port=db_port,
database=db_name,
min_size=1,
max_size=10,
)
logger.info("Database connection pool created successfully.")
except Exception:
logger.error("Error connecting to the database.", exc_info=True)
raise
return cls._pool

@classmethod
async def disconnect(cls):
if cls._pool:
logger.info("Closing database connection pool.")
await cls._pool.close()
cls._pool = None
logger.info("Database connection pool closed.")

@classmethod
async def get_connection(cls):
if cls._pool is None:
await cls.connect()
return await cls._pool.acquire(timeout=30)

@classmethod
async def release_connection(cls, conn):
if cls._pool and conn:
await cls._pool.release(conn)


async def initialize_db_connection():
await DatabaseConnection.connect()

async def close_db_connection():
await DatabaseConnection.disconnect()
Binary file not shown.
58 changes: 58 additions & 0 deletions backend/app/db/tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import os
from dotenv import load_dotenv
from backend.app.db.connection import DatabaseConnection, initialize_db_connection, close_db_connection
import unittest.mock
import asyncpg

# Load environment variables from .env file for testing
load_dotenv()

@pytest.fixture(scope="module", autouse=True)
async def setup_and_teardown_db():
# Mock asyncpg.create_pool and related methods
with unittest.mock.patch('asyncpg.create_pool', new_callable=unittest.mock.AsyncMock) as mock_create_pool:
mock_pool_instance = unittest.mock.AsyncMock()
mock_conn_instance = unittest.mock.AsyncMock()
mock_conn_instance.fetchval.return_value = 1
mock_pool_instance.acquire.return_value = mock_conn_instance
mock_pool_instance.get_size = unittest.mock.MagicMock(return_value=5) # Mock the return value for get_size
mock_create_pool.return_value = mock_pool_instance

# Use mock environment variables for testing database connection
with unittest.mock.patch.dict(os.environ, {
"DB_USER": "test_user",
"DB_PASSWORD": "test_password",
"DB_HOST": "mock_host",
"DB_PORT": "5432",
"DB_NAME": "mock_db",
}):
await initialize_db_connection()
yield
await close_db_connection()

@pytest.mark.asyncio
async def test_database_connection_pool():
pool = await DatabaseConnection.connect()
assert pool is not None
assert pool.get_size() >= 1

conn = await DatabaseConnection.get_connection()
assert conn is not None
await DatabaseConnection.release_connection(conn)

@pytest.mark.asyncio
async def test_get_and_release_connection():
conn = await DatabaseConnection.get_connection()
assert conn is not None
# You can execute a simple query to verify the connection
result = await conn.fetchval("SELECT 1")
assert result == 1
await DatabaseConnection.release_connection(conn)

@pytest.mark.asyncio
async def test_disconnect():
await DatabaseConnection.disconnect()
assert DatabaseConnection._pool is None
# Reconnect for subsequent tests in the same module
await DatabaseConnection.connect()
4 changes: 4 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
asyncio_mode = auto
testpaths = backend/app/db/tests
pythonpath = .
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pydantic==2.8.0
pydantic-settings==2.11.0
python-dotenv==1.0.0
pytest==8.2.0
pytest-asyncio==0.24.0
httpx==0.25.0
alembic==1.12.0
ruff==0.1.4
ruff==0.1.4
asyncpg==0.30.0