Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pyprql/magic/prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from IPython.core.magic import cell_magic, line_magic, magics_class, needs_local_scope
from IPython.core.magic_arguments import argument, magic_arguments
from prql_python import compile, CompileOptions
from sql import parse
from sql.magic import SqlMagic
from traitlets import Bool, Unicode
import re


@magics_class
Expand Down Expand Up @@ -76,7 +78,7 @@ class PrqlMagic(SqlMagic):
type=str,
help="specify dictionary of connection arguments to pass to SQL driver",
)
@argument("-f", "--file", type=str, help="Run SQL from file at this path")
@argument("-f", "--file", type=str, help="Run PRQL from file at this path")
def prql(
self, line: str = "", cell: str = "", local_ns: dict | None = None
) -> None:
Expand All @@ -95,6 +97,11 @@ def prql(
None
"""
local_ns = local_ns or {}
self.args = parse.magic_args(self.execute, line)
if self.args.file:
with open(self.args.file, "r") as infile:
cell = infile.read()
line = re.sub(r"(\-f|\-\-file)\s+" + self.args.file, "", line)
# If cell is occupied, parsed to SQL
if cell:
cell = compile(
Expand Down
11 changes: 6 additions & 5 deletions pyprql/tests/test_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,23 @@ def test_csv_to_file(ip):
assert len(content.splitlines()) == 3


@pytest.mark.skip("We only support pandas")
def test_sql_from_file(ip):
ip.run_line_magic("config", "SqlMagic.autopandas = False")
ip.run_line_magic("config", "PrqlMagic.autopandas = False")
with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, "test.sql")
with open(fname, "w") as tempf:
tempf.write("from test")
result = ip.run_cell("%sql --file " + fname)
result = ip.run_cell("%prql --file " + fname)
assert result.result == [(1, "foo"), (2, "bar")]
result = ip.run_cell("%prql -f " + fname)
assert result.result == [(1, "foo"), (2, "bar")]


def test_sql_from_nonexistent_file(ip):
ip.run_line_magic("config", "SqlMagic.autopandas = False")
ip.run_line_magic("config", "PrqlMagic.autopandas = False")
with tempfile.TemporaryDirectory() as tempdir:
fname = os.path.join(tempdir, "nonexistent.sql")
result = ip.run_cell("%sql --file " + fname)
result = ip.run_cell("%prql --file " + fname)
assert isinstance(result.error_in_exec, FileNotFoundError)


Expand Down