Skip to content

Commit

Permalink
LIN-351 Include lineapy.save() in code slice (#634)
Browse files Browse the repository at this point in the history
* Include lineapy.save() in code slice

* Do not include lineapy.save() in code slice by default; but DO include it in pipeline building

* Start slicing from .save() to include all ancestors including lineapy import

* Use consistent input param type

* If not applicable (e.g., lineapy.file_system), use original sink

* Update tests related to pipeline building

* Identify .save() statement more precisely

* Use a clearer param name

* Fix mypy issue

* require lineapy when dumping pipelines because we have that in now.

* Fix typos

* Pull the de-linealizing functions out into utils (added a new file, api_utils, to prevent a circular dependency; utils is big enough now to warrant a subpackage)

* De-Lineate (i.e., use non-Linea serialization) code slices for pipeline building

* Add tutorials to docs

* mock path example

* Ignore cell output for CI

* Update pipeline tests

* Add tutorials to docs

* sample mock for ipython tests

* Update other IPython-related tests

* no dependency on lineapy

* add test case to run the script to check if it works or not!!

* fix the missing import pickle on top

* fix the snapshots that are affected.

Co-authored-by: Shardul Sardesai <shardul@linea.ai>
Co-authored-by: Yifan Wu <yifan1030@gmail.com>
  • Loading branch information
3 people committed May 17, 2022
1 parent 9c16686 commit 272e94a
Show file tree
Hide file tree
Showing 22 changed files with 288 additions and 146 deletions.
4 changes: 2 additions & 2 deletions airflow-requirements.txt
@@ -1,4 +1,4 @@
apache-airflow==2.2.4
pandas
scikit-learn==1.0.2
SQLAlchemy==1.3.24
sklearn
SQLAlchemy==1.3.24
13 changes: 12 additions & 1 deletion docs/source/tutorials/00_api_basics.ipynb
Expand Up @@ -1068,6 +1068,10 @@
" avg_length_setosa = df.query(\"variety == 'Setosa'\")[\"petal.length\"].mean()\r\n",
" avg_length_virginica = df.query(\"variety == 'Virginica'\")[\"petal.length\"].mean()\r\n",
" diff_avg_length = avg_length_setosa - avg_length_virginica\r\n",
" pickle.dump(\r\n",
" diff_avg_length,\r\n",
" open(\"/Users/sangyoonpark/.lineapy/linea_pickles/22oOFAm\", \"wb\"),\r\n",
" )\r\n",
"\r\n",
"\r\n",
"def iris_diff_avg_width():\r\n",
Expand All @@ -1078,11 +1082,16 @@
" )\r\n",
" avg_width_setosa = df.query(\"variety == 'Setosa'\")[\"petal.width\"].mean()\r\n",
" avg_width_virginica = df.query(\"variety == 'Virginica'\")[\"petal.width\"].mean()\r\n",
" diff_avg_width = avg_width_setosa - avg_width_virginica\r\n"
" diff_avg_width = avg_width_setosa - avg_width_virginica\r\n",
" width_artifact = pickle.dump(\r\n",
" diff_avg_width, open(\"/Users/sangyoonpark/.lineapy/linea_pickles/G25qpnr\", \"wb\")\r\n",
" )\r\n"
]
}
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"\n",
"%cat output/00_api_basics/demo_pipeline/demo_pipeline.py"
]
},
Expand Down Expand Up @@ -1147,6 +1156,8 @@
}
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"\n",
"%cat output/00_api_basics/demo_pipeline/demo_pipeline_dag.py"
]
},
Expand Down
13 changes: 12 additions & 1 deletion examples/tutorials/00_api_basics.ipynb
Expand Up @@ -1068,6 +1068,10 @@
" avg_length_setosa = df.query(\"variety == 'Setosa'\")[\"petal.length\"].mean()\r\n",
" avg_length_virginica = df.query(\"variety == 'Virginica'\")[\"petal.length\"].mean()\r\n",
" diff_avg_length = avg_length_setosa - avg_length_virginica\r\n",
" pickle.dump(\r\n",
" diff_avg_length,\r\n",
" open(\"/Users/sangyoonpark/.lineapy/linea_pickles/22oOFAm\", \"wb\"),\r\n",
" )\r\n",
"\r\n",
"\r\n",
"def iris_diff_avg_width():\r\n",
Expand All @@ -1078,11 +1082,16 @@
" )\r\n",
" avg_width_setosa = df.query(\"variety == 'Setosa'\")[\"petal.width\"].mean()\r\n",
" avg_width_virginica = df.query(\"variety == 'Virginica'\")[\"petal.width\"].mean()\r\n",
" diff_avg_width = avg_width_setosa - avg_width_virginica\r\n"
" diff_avg_width = avg_width_setosa - avg_width_virginica\r\n",
" width_artifact = pickle.dump(\r\n",
" diff_avg_width, open(\"/Users/sangyoonpark/.lineapy/linea_pickles/G25qpnr\", \"wb\")\r\n",
" )\r\n"
]
}
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"\n",
"%cat output/00_api_basics/demo_pipeline/demo_pipeline.py"
]
},
Expand Down Expand Up @@ -1147,6 +1156,8 @@
}
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"\n",
"%cat output/00_api_basics/demo_pipeline/demo_pipeline_dag.py"
]
},
Expand Down
14 changes: 7 additions & 7 deletions lineapy/data/graph.py
Expand Up @@ -110,25 +110,25 @@ def visit_order(self) -> Iterator[Node]:
# Then, we add all of its children to the queue, making sure to mark
# for each that we have seen one of its parents
yield node
for child_id in self.get_children(node):
for child_id in self.get_children(node.id):
remaining_parents[child_id] -= 1
if child_id in seen:
continue
child_node = self.ids[child_id]
queue.put(child_node)
seen.add(child_id)

def get_parents(self, node: Node) -> List[LineaID]:
return list(self.nx_graph.predecessors(node.id))
def get_parents(self, node_id: LineaID) -> List[LineaID]:
    """
    Return the IDs of the direct predecessors of ``node_id`` in ``nx_graph``.
    """
    parent_ids = self.nx_graph.predecessors(node_id)
    return [*parent_ids]

def get_ancestors(self, node_id: LineaID) -> List[LineaID]:
    """
    Return the IDs of every node that ``nx.ancestors`` reports for
    ``node_id`` in ``nx_graph``, as a list.
    """
    ancestor_ids = nx.ancestors(self.nx_graph, node_id)
    return list(ancestor_ids)

def get_children(self, node: Node) -> List[LineaID]:
return list(self.nx_graph.successors(node.id))
def get_children(self, node_id: LineaID) -> List[LineaID]:
    """
    Return the IDs of the direct successors of ``node_id`` in ``nx_graph``.
    """
    child_ids = self.nx_graph.successors(node_id)
    return [*child_ids]

def get_descendants(self, node: Node) -> List[LineaID]:
return list(nx.descendants(self.nx_graph, node.id))
def get_descendants(self, node_id: LineaID) -> List[LineaID]:
    """
    Return the IDs of every node that ``nx.descendants`` reports for
    ``node_id`` in ``nx_graph``, as a list.
    """
    descendant_ids = nx.descendants(self.nx_graph, node_id)
    return list(descendant_ids)

def get_leaf_nodes(self) -> List[LineaID]:
return [
Expand Down
13 changes: 13 additions & 0 deletions lineapy/db/db.py
Expand Up @@ -440,6 +440,19 @@ def get_node_value_from_db(
)
return value_orm

def get_node_value_path(
    self, node_id: LineaID, execution_id: LineaID
) -> Optional[str]:
    """
    Get the path to the stored value of a node for a given execution.

    The value record is looked up with ``get_node_value_from_db`` and its
    ``value`` attribute is returned.

    :param node_id: ID of the node whose stored value is looked up.
    :param execution_id: ID of the execution the value belongs to.
    :return: The ``value`` attribute of the stored record (callers treat it
        as a file path to open).
    :raises ValueError: If no value is saved for this node and execution.
    """
    value = self.get_node_value_from_db(node_id, execution_id)
    if not value:
        raise ValueError("No value saved for this node")
    return value.value

def node_value_in_db(
self, node_id: LineaID, execution_id: LineaID
) -> bool:
Expand Down
52 changes: 52 additions & 0 deletions lineapy/graph_reader/api_utils.py
@@ -0,0 +1,52 @@
import re

from lineapy.db.db import RelationalLineaDB


def de_lineate_code(code: str, db: RelationalLineaDB) -> str:
    """
    De-linealize the code by replacing lineapy API references with ``pickle``.

    ``lineapy.save(<var>, "<name>")`` becomes
    ``pickle.dump(<var>,open("<path>","wb"))`` and
    ``lineapy.get("<name>").get_value()`` becomes
    ``pickle.load(open("<path>","rb"))``, where ``<path>`` is looked up in
    ``db`` for the artifact called ``<name>``.

    If at least one replacement was made, ``import pickle`` is prepended, and
    ``import lineapy`` lines are blanked out when no other ``lineapy.``
    reference remains in the result.

    :param code: Source code that may contain lineapy API calls.
    :param db: Database used to resolve an artifact name to its stored
        value path (via ``get_artifact_by_name`` and ``get_node_value_path``).
    :return: The rewritten code; ``code`` unchanged if nothing matched.
        Exceptions raised by the ``db`` lookups are propagated.
    """
    # The dots are escaped and the match is anchored on a word boundary so
    # that only real ``lineapy.<api>`` calls are rewritten (not, e.g.,
    # ``lineapy_save(...)`` or ``mylineapy.save(...)``).
    lineapy_pattern = re.compile(
        r"\blineapy\.(?:"
        r"save\((?P<save_var>\w+),\s*[\"'](?P<save_name>[\w\-\s]+)[\"']\)"
        r"|get\([\"'](?P<get_name>[\w\-\s]+)[\"']\)\.get_value\(\)"
        r")"
    )

    def path_literal(artifact_name: str) -> str:
        # FIXME - there is a potential issue here because we are looking up the artifact by name
        # This does not ensure that the same version of current artifact is being looked up.
        # We support passing a version number to the get_artifact_by_name but it needs to be parsed
        # out in the regex somehow. This would be simpler when we support named versions when saving.
        dep_artifact = db.get_artifact_by_name(artifact_name)
        path_to_use = db.get_node_value_path(
            dep_artifact.node_id, dep_artifact.execution_id
        )
        # Escape backslashes (first) and double quotes so the path stays a
        # valid double-quoted literal in the generated code, e.g. for
        # Windows-style paths.
        escaped = str(path_to_use).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def replace_fun(match) -> str:
        saved_variable = match.group("save_var")
        if saved_variable is not None:
            target = path_literal(match.group("save_name"))
            return f'pickle.dump({saved_variable},open({target},"wb"))'
        # Otherwise the ``get`` alternative matched; this is typically a
        # different artifact than the one being saved.
        source = path_literal(match.group("get_name"))
        return f'pickle.load(open({source},"rb"))'

    swapped, replaces = lineapy_pattern.subn(replace_fun, code)
    if replaces == 0:
        return swapped

    # If we replaced something, pickle was used so add import pickle on top.
    swapped = "import pickle\n" + swapped
    # Conversely, if no lineapy.xxx reference is left, the import lineapy
    # line is not needed anymore; it is replaced by an empty line.
    if not re.search(r"lineapy\.", swapped):
        swapped = re.sub(r"import lineapy\n", "\n", swapped)
    return swapped
128 changes: 37 additions & 91 deletions lineapy/graph_reader/apis.py
Expand Up @@ -4,7 +4,6 @@
from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional, cast
Expand All @@ -17,6 +16,7 @@
from lineapy.db.relational import ArtifactORM
from lineapy.db.utils import FilePickler
from lineapy.execution.executor import Executor
from lineapy.graph_reader.api_utils import de_lineate_code
from lineapy.graph_reader.program_slice import (
get_slice_graph,
get_source_code_from_graph,
Expand Down Expand Up @@ -68,7 +68,7 @@ def get_value(self) -> object:
"""
Get and return the value of the artifact
"""
value = self._get_value_path()
value = self.db.get_node_value_path(self._node_id, self._execution_id)
if value is None:
return None
else:
Expand All @@ -77,45 +77,36 @@ def get_value(self) -> object:
with open(value, "rb") as f:
return FilePickler.load(f)

def _get_value_path(
self, other: Optional[ArtifactORM] = None
) -> Optional[str]:
"""
Get the path to the value of the artifact.
:param other: Additional argument to let you query another artifact's value path.
This is set to be optional and if its not set, we will use the current artifact
"""
if other is not None:
value = self.db.get_node_value_from_db(
other.node_id, other.execution_id
)
else:
value = self.db.get_node_value_from_db(
self._node_id, self._execution_id
)
if not value:
raise ValueError("No value saved for this node")
return value.value

# Note that I removed the @properties because they were not working
# well with the lru_cache
@lru_cache(maxsize=None)
def _get_subgraph(self) -> Graph:
def _get_subgraph(self, keep_lineapy_save: bool = False) -> Graph:
"""
Return the slice subgraph for the artifact
Return the slice subgraph for the artifact.
:param keep_lineapy_save: Whether to retain ``lineapy.save()`` in code slice.
Defaults to ``False``.
"""
return get_slice_graph(self._get_graph(), [self._node_id])
return get_slice_graph(
self._get_graph(), [self._node_id], keep_lineapy_save
)

@lru_cache(maxsize=None)
def get_code(self, use_lineapy_serialization=True) -> str:
def get_code(
self,
use_lineapy_serialization: bool = True,
keep_lineapy_save: bool = False,
) -> str:
"""
Return the slices code for the artifact
:param use_lineapy_serialization: If True, will use the lineapy serialization to get the code.
We will hide the serialization and the value pickler irrespective of the value type.
If False, will use remove all the lineapy references and instead use the underlying serializer directly.
Currently, we use the native `pickle` serializer.
:param use_lineapy_serialization: If ``True``, will use the lineapy serialization to get the code.
We will hide the serialization and the value pickler irrespective of the value type.
If ``False``, will remove all the lineapy references and instead use the underlying serializer directly.
Currently, we use the native ``pickle`` serializer.
:param keep_lineapy_save: Whether to retain ``lineapy.save()`` in code slice.
Defaults to ``False``.
"""
# FIXME: this seems a little heavy to just get the slice?
Expand All @@ -125,24 +116,23 @@ def get_code(self, use_lineapy_serialization=True) -> str:
is_session_code=False,
)
)
return prettify(
self._de_linealize_code(
str(get_source_code_from_graph(self._get_subgraph())),
use_lineapy_serialization,
)
code = str(
get_source_code_from_graph(self._get_subgraph(keep_lineapy_save))
)
if not use_lineapy_serialization:
code = de_lineate_code(code, self.db)
return prettify(code)

@lru_cache(maxsize=None)
def get_session_code(self, use_lineapy_serialization=True) -> str:
"""
Return the raw session code for the artifact. This will include any
comments and non-code lines.
:param use_lineapy_serialization: If True, will use the lineapy serialization to get the code.
We will hide the serialization and the value pickler irrespective of the value type.
If False, will use remove all the lineapy references and instead use the underlying serializer directly.
Currently, we use the native `pickle` serializer.
:param use_lineapy_serialization: If ``True``, will use the lineapy serialization to get the code.
We will hide the serialization and the value pickler irrespective of the value type.
If ``False``, will remove all the lineapy references and instead use the underlying serializer directly.
Currently, we use the native ``pickle`` serializer.
"""
# using this over get_source_code_from_graph because it will process the
Expand All @@ -153,56 +143,12 @@ def get_session_code(self, use_lineapy_serialization=True) -> str:
is_session_code=True,
)
)
return self._de_linealize_code(
self.db.get_source_code_for_session(self._session_id),
use_lineapy_serialization,
)

def _de_linealize_code(
self, code: str, use_lineapy_serialization: bool
) -> str:
"""
De-linealize the code by removing any lineapy api references
"""
if use_lineapy_serialization:
return code
else:
lineapy_pattern = re.compile(
r"(lineapy.(save\(([\w]+),\s*[\"\']([\w\-\s]+)[\"\']\)|get\([\"\']([\w\-\s]+)[\"\']\).get_value\(\)))"
)
# init swapped version

def replace_fun(match):
if match.group(2).startswith("save"):
# TODO - this can be another artifact. find it using the match.group(4)
# dep_artifact = self.db.get_artifact_by_name(match.group(4))
path_to_use = self._get_value_path()
return f'pickle.dump({match.group(3)},open("{path_to_use}","wb"))'

elif match.group(2).startswith("get"):
# this typically will be a different artifact.
dep_artifact = self.db.get_artifact_by_name(match.group(5))
path_to_use = self._get_value_path(dep_artifact)
return f'pickle.load(open("{path_to_use}","rb"))'

swapped, replaces = lineapy_pattern.subn(replace_fun, code)
if replaces > 0:
# If we replaced something, pickle was used so add import pickle on top
# Conversely, if lineapy reference was removed, potentially the import lineapy line is not needed anymore.
remove_pattern = re.compile(r"import lineapy\n")
match_pattern = re.compile(r"lineapy\.(.*)")
swapped = "import pickle\n" + swapped
if match_pattern.search(swapped):
# we still are using lineapy.xxx functions
# so do nothing
pass
else:
swapped, lineareplaces = remove_pattern.subn("", swapped)
logger.debug(f"Removed lineapy {lineareplaces} times")

logger.debug("replaces made: %s", replaces)

return swapped
code = self.db.get_source_code_for_session(self._session_id)
if not use_lineapy_serialization:
code = de_lineate_code(code, self.db)
# NOTE: we are not prettifying this code because we want to preserve what
# the user wrote originally, without processing
return code

@lru_cache(maxsize=None)
def _get_graph(self) -> Graph:
Expand Down

0 comments on commit 272e94a

Please sign in to comment.