Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions __snapshots__/test_parser.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,37 @@
# serializer version: 1
# name: test_notion
dict({
'parent': dict({
'database_id': 'PARENT_PAGE',
}),
'properties': dict({
'id': dict({
'number': True,
}),
'title': dict({
'title': dict({
}),
}),
}),
'title': list([
dict({
'text': dict({
'content': 'table1',
}),
}),
]),
})
# ---
# name: test_sql_parser[CREATE TABLE IF NOT EXISTS table1 (title title, id int);]
dict({
'columns': dict({
'id': 'INT',
'title': 'title',
}),
'exists': True,
'table_name': 'table1',
})
# ---
# name: test_sql_parser[DELETE FROM table1 WHERE column1 = 'value1';]
dict({
'table_name': 'table1',
Expand Down
23 changes: 23 additions & 0 deletions pynotiondb/mysql_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
EQ,
And,
Binary,
Create,
Delete,
Expression,
Insert,
Expand Down Expand Up @@ -128,6 +129,23 @@ def extract_delete_statement_info(self) -> dict:

return {"table_name": table_name, "where_clause": where_clause}

def extract_create_statement_info(self) -> dict:
match: Create = self.statement
assert match.kind == "TABLE"
schema = match.this

table_name = schema.this.text("this")
columns = {
col.text("this"): col.kind.args.get("kind", col.kind.this.value)
for col in schema.expressions
}

return {
"table_name": table_name,
"columns": columns,
"exists": match.args.get("exists", False),
}

def extract_set_values(self, set_values_str: list[EQ]) -> list[dict]:
set_values = []
# Split by 'AND', but not within quotes
Expand All @@ -153,6 +171,9 @@ def parse(self) -> dict:
if isinstance(self.statement, Delete):
return self.extract_delete_statement_info()

if isinstance(self.statement, Create):
return self.extract_create_statement_info()

raise ValueError("Invalid SQL statement")

def check_statement(self) -> tuple[bool, str]:
Expand All @@ -164,5 +185,7 @@ def check_statement(self) -> tuple[bool, str]:
return True, "update"
if isinstance(self.statement, Delete):
return True, "delete"
elif isinstance(self.statement, Create):
return True, "create"

return False, "unknown"
46 changes: 43 additions & 3 deletions pynotiondb/notion_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
import logging
from typing import Optional
from functools import lru_cache

import requests
from notion_client import Client

from .exceptions import NotionAPIError
from .mysql_query_parser import MySQLQueryParser

logger = logging.getLogger(__name__)


def format_type(s: str) -> dict:
if s == "INT":
return {"number": True}
elif s == "VARCHAR":
return {"rich_text": {}}
elif s == "title":
return {"title": {}}
else:
raise Exception(s)


class NotionAPI:
SEARCH = "https://api.notion.com/v1/search"
PAGES = "https://api.notion.com/v1/pages"
Expand All @@ -17,7 +31,7 @@ class NotionAPI:
QUERY_DATABASE = "https://api.notion.com/v1/databases/{}/query"
DEFAULT_PAGE_SIZE_FOR_SELECT_STATEMENTS = 20
token: str
databases: dict[str, str]
table_parent_page: Optional[str]

CONDITION_MAPPING = {
"EQ": "equals",
Expand All @@ -27,9 +41,14 @@ class NotionAPI:
">=": "greater_than_or_equal_to",
}

def __init__(self, token: str, databases: dict[str, str]) -> None:
def __init__(
self,
token: str,
*,
table_parent_page: Optional[str] = None,
) -> None:
self.token = token
self.databases = databases
self.table_parent_page = table_parent_page
self.DEFAULT_NOTION_VERSION = "2022-06-28"
self.AUTHORIZATION = "Bearer " + self.token
self.headers = {
Expand All @@ -39,6 +58,7 @@ def __init__(self, token: str, databases: dict[str, str]) -> None:
}
self.session = requests.Session()
self.session.headers.update(self.headers)
self.client = Client(auth=token)

def request_helper(self, url: str, method: str = "GET", payload=None):
response = self.session.request(method, url, json=payload)
Expand Down Expand Up @@ -162,6 +182,11 @@ def get_all_database_info(self, cursor=None, page_size=20):

return data

@property
@lru_cache()
def databases(self):
return {db["title"]: db["id"] for db in self.get_all_database_info()["results"]}

def get_all_database(self):
dbs = self.get_all_database_info()
databases = [db.get("title") for db in dbs.get("results")]
Expand Down Expand Up @@ -419,6 +444,17 @@ def update(self, query) -> None:
payload=payload,
)

def create(self, query: str) -> None:
if not self.table_parent_page:
raise Exception("Parent for new tables must be specified")
parsed_data = MySQLQueryParser(query).parse()
props = {col: format_type(typ) for col, typ in parsed_data["columns"].items()}
return self.client.databases.create(
title=[{"text": {"content": parsed_data["table_name"]}}],
parent={"database_id": self.table_parent_page},
properties=props,
)

def delete(self, query) -> None:
parsed_data = MySQLQueryParser(query).parse()

Expand Down Expand Up @@ -468,10 +504,14 @@ def execute(self, sql, val=None):
elif to_do == "delete":
self.delete(query)

elif to_do == "create":
return self.create(query)

else:
raise ValueError("Unsupported operation")

else:
raise ValueError(
"Invalid SQL statement or type of statement not implemented"
)

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"notion-client>=2.5.0",
"requests>=2.0.0",
"respx>=0.22.0",
"sqlglot>=18.2.0",
]

Expand Down
17 changes: 17 additions & 0 deletions test_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import json

from pytest import mark
from respx import MockRouter

from pynotiondb import NotionAPI
from pynotiondb.mysql_query_parser import MySQLQueryParser

create_sql = "CREATE TABLE IF NOT EXISTS table1 (title title, id int);"


@mark.parametrize(
"sql",
Expand All @@ -13,6 +19,7 @@
"DELETE FROM table1 WHERE column1 = 'value1';",
"SELECT * FROM table1 WHERE column1=1 AND column2='text' OR column3 IS NULL;",
"SELECT *, agg_list(column) FROM table GROUP BY column2 LIMIT 10 OFFSET 5;",
create_sql,
],
)
def test_sql_parser(sql: str, snapshot):
Expand All @@ -21,3 +28,13 @@ def test_sql_parser(sql: str, snapshot):
assert ok

snapshot.assert_match(parser.parse())


def test_notion(snapshot):
with MockRouter(base_url="https://api.notion.com/v1") as req:
call = req.post("/databases").respond(200, json={})
notion = NotionAPI("", table_parent_page="PARENT_PAGE")
notion.execute(create_sql)

assert call.called
assert snapshot == json.loads(req.calls.last.request.content)
112 changes: 112 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.