diff --git a/pynotiondb/mysql_query_parser.py b/pynotiondb/mysql_query_parser.py index b3a0d0f..8ac51f9 100644 --- a/pynotiondb/mysql_query_parser.py +++ b/pynotiondb/mysql_query_parser.py @@ -8,6 +8,7 @@ Insert, Null, Or, + Schema, Select, Star, Update, @@ -26,9 +27,10 @@ def _process_string(input_string: list[Expression]) -> list[str]: def extract_insert_statement_info(self) -> dict: assert isinstance(self.statement, Insert) match: Insert = self.statement + schema: Schema = match.this - table_name = match.this.this.this.this - prop_string = match.this.expressions + table_name = schema.this.text("this") + prop_string = schema.expressions values_string = match.expression.expressions[0].expressions properties = self._process_string(prop_string) @@ -77,13 +79,14 @@ def extract_select_statement_info(self) -> dict: return {"table_name": table_name, "columns": columns, "conditions": conditions} - def unwrap_where(self, conditions_str) -> dict: + def unwrap_where(self, conditions_str: Binary) -> dict: assert not isinstance(conditions_str, Where) + assert isinstance(conditions_str, Binary) if isinstance(conditions_str, And): return { "and": [ - self.unwrap_where(conditions_str.this), - self.unwrap_where(conditions_str.expression), + self.unwrap_where(conditions_str.left), + self.unwrap_where(conditions_str.right), ] } elif isinstance(conditions_str, Or): @@ -99,27 +102,13 @@ def unwrap_where(self, conditions_str) -> dict: def parse_condition(self, op: Binary) -> dict: operator = type(op).__name__ key = op.left.text("this") - value = op.right - - if value.is_int and value.is_number: - value = value.to_py() - elif value.is_string: - value = value.this - elif isinstance(value, Null): - value = None - else: - raise ValueError("Unsupported value type") - - return { - "parameter": key, - "operator": operator, - "value": value, - } + value = op.right.to_py() + return {"parameter": key, "operator": operator, "value": value} def extract_update_statement_info(self) -> dict: match: Update = self.statement - table_name = match.this.this.this + table_name = match.this.text("this") set_values_str = match.expressions where_clause = match.args.get("where") @@ -134,7 +123,7 @@ def extract_update_statement_info(self) -> dict: def extract_delete_statement_info(self) -> dict: match: Delete = self.statement - table_name = match.this.this.this + table_name = match.this.text("this") where_clause = match.args.get("where") return {"table_name": table_name, "where_clause": where_clause} @@ -142,18 +131,11 @@ def extract_delete_statement_info(self) -> dict: def extract_set_values(self, set_values_str: list[EQ]) -> list[dict]: set_values = [] # Split by 'AND', but not within quotes - pairs = set_values_str - for pair in pairs: + for pair in set_values_str: # Find the position of the first '=' outside quotes - key = pair.this.this.this - value = pair.expression.this - - # Handle numeric values - if value.isdigit(): - value = int(value) - elif value.replace(".", "", 1).isdigit(): - value = float(value) + key = pair.left.text("this") + value = pair.right.to_py() set_values.append({"key": key, "value": value}) return set_values