Skip to content

Commit

Permalink
Add sqlite3 interface (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
aerorahul committed Mar 7, 2024
1 parent a99775c commit 97159f7
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/wxflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .fsutils import chdir, cp, mkdir, mkdir_p, rm_p, rmdir
from .jinja import Jinja
from .logger import Logger, logit
from .sqlitedb import SQLiteDB
from .task import Task
from .template import Template, TemplateConstants
from .timetools import *
Expand Down
188 changes: 188 additions & 0 deletions src/wxflow/sqlitedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import sqlite3
from typing import Any, List, Optional, Tuple

__all__ = ["SQLiteDB"]


class SQLiteDB:
"""
A class for interacting with an SQLite3 database.
Parameters:
db_name (str): The name of the SQLite database file.
Attributes:
db_name (str): The name of the SQLite database file.
connection (sqlite3.Connection): The connection object for the database.
"""

def __init__(self, db_name: str) -> None:
self.db_name = db_name
self.connection: Optional[sqlite3.Connection] = None

def connect(self) -> None:
"""
Connects to the SQLite database.
"""

try:
self.connection = sqlite3.connect(self.db_name)
except sqlite3.OperationalError as exc:
raise sqlite3.OperationalError(exc)

def disconnect(self) -> None:
"""
Disconnects from the SQLite database.
"""

if self.connection:
self.connection.close()

def execute_query(self, query: str, params: Optional[Tuple[Any, ...]] = None) -> sqlite3.Cursor:
"""
Executes an SQL query.
Parameters:
query (str): The SQL query to execute.
params (tuple, optional): The parameters to be passed to the query.
Returns:
cursor (sqlite3.Cursor): The cursor object.
"""

cursor = self.connection.cursor()
if params:
cursor.execute(query, params)
else:
cursor.execute(query)
self.connection.commit()
return cursor

def create_table(self, table_name: str, columns: List[str]) -> None:
"""
Creates a table in the database.
Parameters:
table_name (str): The name of the table to create.
columns (list): The list of column definitions.
"""

query = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"
self.execute_query(query)

def add_column(self, table_name: str, column_name: str, column_type: str) -> None:
"""
Adds a column to an existing table.
Parameters:
table_name (str): The name of the table.
column_name (str): The name of the column to add.
column_type (str): The data type of the column.
"""

query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
self.execute_query(query)

def remove_column(self, table_name: str, column_name: str) -> None:
"""
Removes a column from an existing table.
Parameters:
table_name (str): The name of the table.
column_name (str): The name of the column to remove.
"""

try:
query = f"ALTER TABLE {table_name} DROP COLUMN {column_name}"
self.execute_query(query)
except sqlite3.OperationalError as exc:
query = f"PRAGMA table_info({table_name})"
cursor = self.execute_query(query)
columns = [column[1] for column in cursor.fetchall()]
if column_name not in columns:
raise ValueError(f"Column '{column_name}' does not exist in table '{table_name}'")
raise sqlite3.OperationalError(exc)

def update_data(
self,
table_name: str,
column_name: str,
new_value: Any,
condition_column: str,
condition_value: Any
) -> None:
"""
Updates data in a table.
Parameters:
table_name (str): The name of the table.
column_name (str): The name of the column to update.
new_value (any): The new value for the column.
condition_column (str): The column to use for the condition.
condition_value (any): The value to use in the condition.
"""

query = f"UPDATE {table_name} SET {column_name} = ? WHERE {condition_column} = ?"
self.execute_query(query, (new_value, condition_value))

def insert_data(self, table_name: str, values: List[Any]) -> None:
"""
Inserts data into a table.
Parameters:
table_name (str): The name of the table.
values (list): The values to insert.
"""

placeholders = ", ".join(["?"] * len(values))
query = f"INSERT INTO {table_name} VALUES ({placeholders})"
self.execute_query(query, values)

def fetch_data(
self,
table_name: str,
columns: Optional[List[str]] = None,
condition: Optional[str] = None
) -> List[Tuple]:
"""
Fetches data from a table.
Parameters:
table_name (str): The name of the table.
columns (list, optional): The list of columns to fetch.
condition (str, optional): The condition to use in the query.
Returns:
result (list): The fetched data.
"""

column_names = "*" if not columns else ", ".join(columns)
query = f"SELECT {column_names} FROM {table_name}"
if condition:
query += f" WHERE {condition}"
cursor = self.execute_query(query)
return cursor.fetchall()

def remove_data(self, table_name: str, condition_column: str, condition_value: Any) -> None:
"""
Removes data from a table.
Parameters:
table_name (str): The name of the table.
condition_column (str): The column to use for the condition.
condition_value (any): The value to use in the condition.
"""

query = f"DELETE FROM {table_name} WHERE {condition_column} = ?"
self.execute_query(query, (condition_value,))
122 changes: 122 additions & 0 deletions tests/test_sqlitedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest

from wxflow import SQLiteDB


@pytest.fixture(scope="module")
def db():
# Create an in-memory SQLite database for testing
db = SQLiteDB(":memory:")
db.connect()

# Create a test table
table_name = "test_table"
columns = ["id INTEGER PRIMARY KEY", "name TEXT", "age INTEGER"]
db.create_table(table_name, columns)

yield db

# Disconnect from the database
db.disconnect()


def test_create_table(db):
# Verify that the test table exists
assert table_exists(db, "test_table")


def test_add_column(db):
# Add a new column to the test table
column_name = "address"
column_type = "TEXT"
db.add_column("test_table", column_name, column_type)

# Verify that the column exists in the test table
assert column_exists(db, "test_table", column_name)


def test_update_data(db):
# Insert test data into the table
values = [1, "Alice", 25, 'Apt 101']
db.insert_data("test_table", values)

# Update the age of the record
new_age = 30
db.update_data("test_table", "age", new_age, "name", "Alice")

# Fetch the updated data
result = db.fetch_data("test_table", condition="name='Alice'")

# Verify that the age is updated correctly
assert result[0][2] == new_age


def test_remove_column(db):
# Removes a column from the test table
column_name = "address"
db.remove_column("test_table", column_name)

# Verify that the column exists in the test table
assert not column_exists(db, "test_table", column_name)


def test_remove_column_raises_error_when_column_not_exists(db):
table_name = "test_table"
column_name = "vacation address"

with pytest.raises(ValueError, match=f"Column '{column_name}' does not exist in table '{table_name}'"):
db.remove_column("test_table", column_name)


def test_insert_data(db):
# Insert test data into the table
values = [2, "Bob", 35]
db.insert_data("test_table", values)

# Fetch all data from the table
result = db.fetch_data("test_table")

# Verify that the inserted data is present in the table
assert len(result) == 2


def test_fetch_data(db):
# Insert test data into the table
values = [3, "Charlie", 40]
db.insert_data("test_table", values)

# Fetch data from the table
result = db.fetch_data("test_table", condition="age > 30")

# Verify that the fetched data meets the condition
assert len(result) == 2


def test_remove_data(db):
# Insert test data into the table
values = [4, "David", 45]
db.insert_data("test_table", values)

# Remove a record from the table
db.remove_data("test_table", "name", "David")

# Fetch all data from the table
result = db.fetch_data("test_table")

# Verify that the removed data is not present in the table
assert len(result) == 3


# Helper functions

def table_exists(db, table_name):
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'"
cursor = db.execute_query(query)
return cursor.fetchone() is not None


def column_exists(db, table_name, column_name):
query = f"PRAGMA table_info({table_name})"
cursor = db.execute_query(query)
columns = [column[1] for column in cursor.fetchall()]
return column_name in columns

0 comments on commit 97159f7

Please sign in to comment.