Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 15 additions & 33 deletions pynotiondb/mysql_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Insert,
Null,
Or,
Schema,
Select,
Star,
Update,
Expand All @@ -26,9 +27,10 @@ def _process_string(input_string: list[Expression]) -> list[str]:
def extract_insert_statement_info(self) -> dict:
assert isinstance(self.statement, Insert)
match: Insert = self.statement
schema: Schema = match.this

table_name = match.this.this.this.this
prop_string = match.this.expressions
table_name = schema.this.text("this")
prop_string = schema.expressions
values_string = match.expression.expressions[0].expressions

properties = self._process_string(prop_string)
Expand Down Expand Up @@ -77,13 +79,14 @@ def extract_select_statement_info(self) -> dict:

return {"table_name": table_name, "columns": columns, "conditions": conditions}

def unwrap_where(self, conditions_str) -> dict:
def unwrap_where(self, conditions_str: Binary) -> dict:
assert not isinstance(conditions_str, Where)
assert isinstance(conditions_str, Binary)
if isinstance(conditions_str, And):
return {
"and": [
self.unwrap_where(conditions_str.this),
self.unwrap_where(conditions_str.expression),
self.unwrap_where(conditions_str.left),
self.unwrap_where(conditions_str.right),
]
}
elif isinstance(conditions_str, Or):
Expand All @@ -99,27 +102,13 @@ def unwrap_where(self, conditions_str) -> dict:
def parse_condition(self, op: Binary) -> dict:
operator = type(op).__name__
key = op.left.text("this")
value = op.right

if value.is_int and value.is_number:
value = value.to_py()
elif value.is_string:
value = value.this
elif isinstance(value, Null):
value = None
else:
raise ValueError("Unsupported value type")

return {
"parameter": key,
"operator": operator,
"value": value,
}
value = op.right.to_py()
return {"parameter": key, "operator": operator, "value": value}

def extract_update_statement_info(self) -> dict:
match: Update = self.statement

table_name = match.this.this.this
table_name = match.this.text("this")
set_values_str = match.expressions
where_clause = match.args.get("where")

Expand All @@ -134,26 +123,19 @@ def extract_update_statement_info(self) -> dict:
def extract_delete_statement_info(self) -> dict:
match: Delete = self.statement

table_name = match.this.this.this
table_name = match.this.text("this")
where_clause = match.args.get("where")

return {"table_name": table_name, "where_clause": where_clause}

def extract_set_values(self, set_values_str: list[EQ]) -> list[dict]:
set_values = []
# Split by 'AND', but not within quotes
pairs = set_values_str
for pair in pairs:
for pair in set_values_str:
# Find the position of the first '=' outside quotes

key = pair.this.this.this
value = pair.expression.this

# Handle numeric values
if value.isdigit():
value = int(value)
elif value.replace(".", "", 1).isdigit():
value = float(value)
key = pair.left.text("this")
value = pair.right.to_py()

set_values.append({"key": key, "value": value})
return set_values
Expand Down