Merged
10 changes: 7 additions & 3 deletions dataci/models/script.py
@@ -202,6 +202,8 @@ def copy(
dst_path = Path(dst)
dst_path.mkdir(parents=True, exist_ok=dirs_exist_ok)
for file in self.filelist:
# Create the parent directory if it does not exist
(dst_path / file).parent.mkdir(parents=True, exist_ok=True)
copy_function(self.dir / file, dst_path / file)

return dst
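For reference, a standalone sketch of the same fix (names are illustrative, not the PR's API): when the file list contains nested relative paths, each destination parent directory has to exist before shutil.copy2 is called.

import shutil
from pathlib import Path

def copy_filelist(src_dir: Path, dst_dir: Path, filelist):
    # Mirror a nested file list from src_dir into dst_dir.
    dst_dir.mkdir(parents=True, exist_ok=True)
    for rel_path in filelist:
        dst_file = dst_dir / rel_path
        # Create the nested parent directory first, as in the change above.
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_dir / rel_path, dst_file)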
@@ -309,7 +311,7 @@ def replace_source_segment(source, nodes, replace_segments):
replace_segment = indent(dedent(replace_segment), indent_prefix).strip()
# Replace code segment
new_script += source[prev_end:start] + replace_segment
prev_end = end + 1
prev_end = end
break
new_script += source[prev_end:]
return new_script
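A quick standalone check (illustrative values, not from the PR) of why the exclusive end offset is carried forward unchanged: slicing with end + 1 would drop the character immediately after the replaced segment, typically a newline.

# source[start:end] is the segment being replaced; `end` is exclusive.
source = "a = 1\nb = 2\nc = 3\n"
start = source.index("b = 2")
end = start + len("b = 2")
patched = source[:start] + "b = 20" + source[end:]      # keeps the trailing '\n'
broken = source[:start] + "b = 20" + source[end + 1:]   # loses the trailing '\n'
assert patched == "a = 1\nb = 20\nc = 3\n"
assert broken == "a = 1\nb = 20c = 3\n"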
@@ -364,7 +366,7 @@ def get_syntax_lines(syntax: Syntax) -> int:
Tuple[int, List[Renderable]]: The number of lines and syntax renderables.
"""
# convert to renderables
renderables = list(text_syntax.__rich_console__(console, console.options))
renderables = list(syntax.__rich_console__(console, console.options))
# count the '\n' segments in the renderables
segments = itertools.chain(*map(lambda x: x.segments, renderables))
num_lines = list(map(lambda x: x.text, segments)).count('\n')
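For reference, a self-contained sketch of the same line-counting idea using rich's public Console.render, which yields low-level segments whose text can be scanned for newlines (the snippet is illustrative):

from rich.console import Console
from rich.syntax import Syntax

console = Console()
syntax = Syntax("x = 1\nprint(x)\n", "python", line_numbers=False, word_wrap=True)
# Render to low-level segments and count newlines to get the displayed line count.
segments = console.render(syntax, console.options)
num_lines = sum(segment.text.count("\n") for segment in segments)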
@@ -412,7 +414,7 @@ def get_syntax_lines(syntax: Syntax) -> int:
table.add_column('old-new-sep', justify='right', width=2, style='blue')
table.add_column('new_lineno', justify='right', width=lineno_width, style='white')
table.add_column('new-text-sep', justify='right', width=2, style='blue')
table.add_column('text', justify='left', style='white')
table.add_column('text', justify='left', min_width=80, style='white')

old_lineno, new_lineno = 0, 0

@@ -436,6 +438,7 @@ def get_syntax_lines(syntax: Syntax) -> int:
'python',
line_numbers=False,
word_wrap=True,
theme='github-dark',
background_color='default',
line_range=(1, 1),
),
@@ -454,6 +457,7 @@ def get_syntax_lines(syntax: Syntax) -> int:
'python',
line_numbers=False,
word_wrap=True,
theme='github-dark',
background_color='default',
line_range=(old_lineno, old_lineno),
code_width=80,
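A minimal sketch of the rendering approach in this file (illustrative column layout, not the PR's exact table): a rich table whose text column holds single-line Syntax renderables, so every diff row is highlighted with the same 'github-dark' theme over a transparent background.

from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table

console = Console()
table = Table.grid(padding=(0, 1))
table.add_column('lineno', justify='right', style='white')
table.add_column('text', justify='left', min_width=80, style='white')
for lineno, line in enumerate(["import os", "print(os.getcwd())"], start=1):
    table.add_row(
        str(lineno),
        # One-line Syntax per row, matching the theme/background settings above.
        Syntax(line, 'python', line_numbers=False, word_wrap=True,
               theme='github-dark', background_color='default'),
    )
console.print(table)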
51 changes: 37 additions & 14 deletions dataci/models/workflow.py
@@ -9,13 +9,16 @@
import itertools
import json
import logging
import multiprocessing as mp
import shutil
import sys
from abc import ABC
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

import cloudpickle
import networkx as nx

from dataci.db.workflow import (
@@ -132,20 +135,39 @@ def from_dict(cls, config: 'dict'):
@classmethod
def from_path(cls, script_dir: 'Union[str, os.PathLike]', entry_path: 'Union[str, os.PathLike]'):
# TODO: make the build process more secure with sandbox / allowed safe methods
local_dict = dict()
def _import_module(entry_module, shared_import_pickle):
import cloudpickle
import importlib
import os
from dataci.models import Workflow

mod = importlib.import_module(entry_module)
# get all variables from the module
for k, v in mod.__dict__.items():
if not k.startswith('__') and isinstance(v, Workflow):
shared_import_pickle['__return__'] = cloudpickle.dumps(v)
break
else:
raise ValueError(f'Workflow not found in directory: {os.getcwd()}')

with cwd(script_dir):
import sys
entry_file = Path(entry_path)
sys_path = sys.path.copy()
# Prepend the local dir to sys.path
sys.path.insert(0, '')
entry_module = '.'.join(entry_file.parts[:-1] + (entry_file.stem,))
exec(
f'import os, sys; sys.path.insert(0, os.getcwd()); from {entry_module} import *',
local_dict, local_dict
)
for v in local_dict.copy().values():
if isinstance(v, Workflow):
self = v
break
else:
raise ValueError(f'Workflow not found in directory: {script_dir}')
with mp.Manager() as manager:
import_pickle = manager.dict()
p = mp.Process(target=_import_module, args=(entry_module, import_pickle,))
p.start()
p.join()
try:
self = cloudpickle.loads(import_pickle['__return__'])
except KeyError:
raise ValueError(f'Workflow not found in directory: {script_dir}')
# restore sys path
sys.path = sys_path

return self
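A standalone sketch of the isolation pattern adopted here (all names are illustrative and the target object is assumed to be cloudpickle-able): the module is imported in a child process and the object of interest is handed back through a Manager dict as a pickled payload, so the parent interpreter's module state is left untouched.

import multiprocessing as mp

import cloudpickle

def _import_and_capture(module_name, shared):
    # Runs in the child process: import the module and pickle the object of interest.
    import importlib
    mod = importlib.import_module(module_name)
    shared['__return__'] = cloudpickle.dumps(mod.wf)  # `wf` is an assumed attribute name

def load_from_module(module_name):
    with mp.Manager() as manager:
        shared = manager.dict()
        p = mp.Process(target=_import_and_capture, args=(module_name, shared))
        p.start()
        p.join()
        try:
            return cloudpickle.loads(shared['__return__'])
        except KeyError:
            raise ValueError(f'No workflow object found in module {module_name!r}')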

@@ -187,6 +209,7 @@ def reload(self, config=None):
self.create_date = datetime.fromtimestamp(config['timestamp']) if config['timestamp'] else None
self.trigger = [Event.from_str(evt) for evt in config['trigger']]
if 'script' in config:
# fixme: reload the object if the script hash has changed
self._script = Script.from_dict(config['script'])
if 'dag' in config:
self._stage_script_paths.clear()
@@ -287,12 +310,12 @@ def patch(self, verbose=True, **kwargs):
for k, stage in kwargs.items():
# Convert k to full stage name
full_k = f'{self.workspace.name}.{k}' if self.workspace else k
# Check if the stage is in the workflow
if full_k not in self.stages:
raise ValueError(f'Cannot find stage name={k} in workflow {self.name}')
if stage.name != k:
raise ValueError(f'Cannot patch stage {stage.name} to {k} in workflow {self.name}')
# TODO: Check if the stage has the same signature
# Warn if the new stage has a different signature from the old stage
new_workflow = patch_func(self, source_name=full_k, target=stage, logger=self.logger, verbose=verbose)
new_workflow = self.reload(new_workflow.dict())

return new_workflow

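On the signature TODO above, one possible check (a sketch, not part of this PR, assuming each stage exposes its underlying callable) is to compare inspect.signature of the old and new stages and warn on mismatch:

import inspect
import warnings

def warn_on_signature_change(old_func, new_func, name):
    # A mismatch usually means callers inside the workflow will break.
    old_sig, new_sig = inspect.signature(old_func), inspect.signature(new_func)
    if old_sig != new_sig:
        warnings.warn(
            f'Stage {name!r}: new signature {new_sig} differs from old signature {old_sig}'
        )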