diff --git a/src/aimcore/cli/package/commands.py b/src/aimcore/cli/package/commands.py index c4fe73cd65..50d914a61f 100644 --- a/src/aimcore/cli/package/commands.py +++ b/src/aimcore/cli/package/commands.py @@ -88,3 +88,12 @@ def sync_package(name, repo): package_watcher = PackageSourceWatcher(repo_inst, name, src_path) package_watcher.initialize() package_watcher.start() + + +@package.command('set-active') +@click.option('--name', '-n', required=True, type=str) +@click.option('--repo', default='', type=str) +def set_active_package(name, repo): + repo_inst = Repo.from_path(repo) if repo else Repo.default() + click.echo(f'Setting \'{name}\' as active package for Repo \'{repo_inst}\'') + repo_inst.set_active_package(pkg_name=name) diff --git a/src/aimcore/cli/package/watcher.py b/src/aimcore/cli/package/watcher.py index 497ee96eba..d191aedcc5 100644 --- a/src/aimcore/cli/package/watcher.py +++ b/src/aimcore/cli/package/watcher.py @@ -51,26 +51,32 @@ async def watch_events(self): fs_events.append(self.queue.get()) with self.repo.storage_engine.write_batch(0): for fs_event in fs_events: + file = pathlib.Path(fs_event.src_path) + file_path = str(file.relative_to(self.src_dir)) + if self.is_black_listed(file_path): + continue try: - file = pathlib.Path(fs_event.src_path) - file_path = str(file.relative_to(self.src_dir)) - click.echo(f'Detected change in file \'{file_path}\'. Syncing.') + click.echo(f'Detected change in: \'{file_path}\', type: \'{fs_event.event_type}\'. Syncing...') if fs_event.event_type in (events.EVENT_TYPE_CREATED, events.EVENT_TYPE_MODIFIED): with file.open('r') as fh: contents = fh.read() self.package.sync(str(file_path), contents) + click.echo(f'Updated file {file_path}.') elif fs_event.event_type == events.EVENT_TYPE_DELETED: self.package.remove(str(file_path)) - elif fs_event == events.EVENT_TYPE_MOVED: - dest_path = pathlib.Path(fs_event.dest_path).relative_to(self.src_dir) + click.echo(f'Removed file {file_path}.') + elif fs_event.event_type == events.EVENT_TYPE_MOVED: + dest_path = str(pathlib.Path(fs_event.dest_path).relative_to(self.src_dir)) self.package.move(file_path, dest_path) + click.echo(f'Moved file {file_path} to {dest_path}.') except Exception: pass await asyncio.sleep(5) def start(self): + click.echo(f'Watching for change in package \'{self.package_name}\' sources...') self.observer = Observer() event_hanlder = SourceFileChangeHandler(self.src_dir, self.queue) self.observer.schedule(event_hanlder, self.src_dir, recursive=True) @@ -88,7 +94,15 @@ def initialize(self): click.echo(f'Initializing package \'{self.package_name}\'.') for file_path in self.src_dir.glob('**/*'): file_name = file_path.relative_to(self.src_dir) - if file_path.is_file(): + if file_path.is_file() and not self.is_black_listed(str(file_name)): with file_path.open('r') as fh: self.package.sync(str(file_name), fh.read()) self.package.install() + click.echo(f'Package \'{self.package_name}\' initialized.') + + def is_black_listed(self, file_path: str) -> bool: + if '__pycache__' in file_path: + return True + if file_path.endswith('.pyc'): + return True + return False diff --git a/src/aimcore/cli/server/commands.py b/src/aimcore/cli/server/commands.py index d8f7fdb0c9..faf966c1ef 100644 --- a/src/aimcore/cli/server/commands.py +++ b/src/aimcore/cli/server/commands.py @@ -1,7 +1,7 @@ import os import click -from aimcore.cli.utils import set_log_level +from aimcore.cli.utils import set_log_level, start_uvicorn_app from aim._sdk.repo import Repo from aim._sdk.package_utils import Package from aimcore.transport.config import ( @@ -21,7 +21,6 @@ file_okay=False, dir_okay=True, writable=True)) -@click.option('--package', '--pkg', required=False, default='asp', type=str) @click.option('--ssl-keyfile', required=False, type=click.Path(exists=True, file_okay=True, dir_okay=False, @@ -35,7 +34,7 @@ @click.option('--dev', is_flag=True, default=False) @click.option('-y', '--yes', is_flag=True, help='Automatically confirm prompt') def server(host, port, - repo, package, ssl_keyfile, ssl_certfile, + repo, ssl_keyfile, ssl_certfile, base_path, log_level, dev, yes): # TODO [MV, AT] remove code duplication with aim up cmd implementation if not log_level: @@ -67,9 +66,6 @@ def server(host, port, return os.environ[AIM_SERVER_MOUNTED_REPO_PATH] = repo - if package not in Package.pool: - Package.load_package(package) - click.secho('Running Aim Server on repo `{}`'.format(repo_inst), fg='yellow') click.echo('Server is mounted on {}:{}'.format(host, port), err=True) click.echo('Press Ctrl+C to exit') @@ -79,17 +75,20 @@ def server(host, port, # delete the repo as it needs to be opened in a child process in dev mode del repo_inst - try: - from aimcore.transport.server import start_server + if dev: + import aim + import aimcore + + reload_dirs = [os.path.dirname(aim.__file__), os.path.dirname(aimcore.__file__), dev_package_dir] + else: + reload_dirs = [] - if dev: - import aim - import aimcore - reload_dirs = (os.path.dirname(aim.__file__), os.path.dirname(aim.__file__), dev_package_dir) - start_server(host, port, ssl_keyfile, ssl_certfile, log_level=log_level, reload=dev, reload_dirs=reload_dirs) - else: - start_server(host, port, ssl_keyfile, ssl_certfile, log_level=log_level) + try: + start_uvicorn_app('aimcore.transport.server:app', + host=host, port=port, + ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, log_level=log_level, + reload=dev, reload_dirs=reload_dirs) except Exception: click.echo('Failed to run Aim Tracking Server. ' - 'Please see the logs for details.') - return + 'Please see the logs above for details.') + exit(1) diff --git a/src/aimcore/cli/ui/commands.py b/src/aimcore/cli/ui/commands.py index 54302c4cb7..9ea5089b3d 100644 --- a/src/aimcore/cli/ui/commands.py +++ b/src/aimcore/cli/ui/commands.py @@ -1,14 +1,13 @@ import os import click -from aimcore.cli.utils import set_log_level -from aimcore.cli.ui.utils import build_db_upgrade_command, build_uvicorn_command, get_free_port_num +from aimcore.cli.utils import set_log_level, start_uvicorn_app +from aimcore.cli.ui.utils import build_db_upgrade_command, get_free_port_num from aimcore.web.configs import ( AIM_UI_BASE_PATH, AIM_UI_DEFAULT_HOST, AIM_UI_DEFAULT_PORT, AIM_UI_MOUNTED_REPO_PATH, - AIM_UI_PACKAGE_NAME, AIM_UI_TELEMETRY_KEY, AIM_PROXY_URL, AIM_PROFILER_KEY @@ -34,7 +33,7 @@ file_okay=False, dir_okay=True, writable=True)) -@click.option('--package', '--pkg', required=False, default='asp', type=str) +@click.option('--package', '--pkg', required=False, default='', show_default='asp', type=str) @click.option('--dev', is_flag=True, default=False) @click.option('--ssl-keyfile', required=False, type=click.Path(exists=True, file_okay=True, @@ -57,14 +56,11 @@ def ui(dev, host, port, workers, uds, """ Start Aim UI with the --repo repository. """ - if dev: - os.environ[AIM_ENV_MODE_KEY] = 'dev' - log_level = log_level or 'debug' - else: - os.environ[AIM_ENV_MODE_KEY] = 'prod' + if not log_level: + log_level = 'debug' if dev else 'warning' + set_log_level(log_level) - if log_level: - set_log_level(log_level) + os.environ[AIM_ENV_MODE_KEY] = 'dev' if dev else 'prod' if base_path: # process `base_path` as ui requires leading slash @@ -82,9 +78,11 @@ def ui(dev, host, port, workers, uds, return Repo.init(repo) repo_inst = Repo.from_path(repo, read_only=True) - os.environ[AIM_UI_MOUNTED_REPO_PATH] = repo - os.environ[AIM_UI_PACKAGE_NAME] = package + + dev_package_dir = repo_inst.dev_package_dir + if package: + repo_inst.set_active_package(pkg_name=package) try: db_cmd = build_db_upgrade_command() @@ -92,7 +90,7 @@ def ui(dev, host, port, workers, uds, except ShellCommandException: click.echo('Failed to initialize Aim DB. ' 'Please see the logs above for details.') - return + exit(1) if port == 0: try: @@ -123,12 +121,24 @@ def ui(dev, host, port, workers, uds, if profiler: os.environ[AIM_PROFILER_KEY] = '1' + if dev: + import aim + import aimstack + import aimcore + + reload_dirs = [os.path.dirname(aim.__file__), os.path.dirname(aimcore.__file__), os.path.dirname(aimstack.__file__), dev_package_dir] + else: + reload_dirs = [] + try: - server_cmd = build_uvicorn_command(host, port, workers, uds, ssl_keyfile, ssl_certfile, log_level, package) - exec_cmd(server_cmd, stream_output=True) - except ShellCommandException: - click.echo('Failed to run Aim UI. Please see the logs above for details.') - return + start_uvicorn_app('aimcore.web.run:app', + host=host, port=port, workers=workers, uds=uds, + ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, log_level=log_level, + reload=dev, reload_dirs=reload_dirs) + except Exception: + click.echo('Failed to run Aim UI. ' + 'Please see the logs above for details.') + exit(1) @click.command('up', context_settings={'ignore_unknown_options': True, 'allow_extra_args': True}, hidden=True) diff --git a/src/aimcore/cli/ui/utils.py b/src/aimcore/cli/ui/utils.py index 2b4ddeee07..d46bf9bff3 100644 --- a/src/aimcore/cli/ui/utils.py +++ b/src/aimcore/cli/ui/utils.py @@ -15,40 +15,6 @@ def build_db_upgrade_command(): return [sys.executable, '-m', 'alembic', '-c', ini_file, 'upgrade', 'head'] -def build_uvicorn_command(host, port, num_workers, uds_path, ssl_keyfile, ssl_certfile, log_level, pkg_name): - cmd = [sys.executable, '-m', 'uvicorn', - '--host', host, '--port', f'{port}', - '--workers', f'{num_workers}'] - if os.getenv(AIM_ENV_MODE_KEY, 'prod') == 'prod': - log_level = log_level or 'error' - else: - import aim - import aimstack - from aimcore import web as aim_web - - cmd += ['--reload'] - cmd += ['--reload-dir', os.path.dirname(aim.__file__)] - cmd += ['--reload-dir', os.path.dirname(aim_web.__file__)] - cmd += ['--reload-dir', os.path.dirname(aimstack.__file__)] - - from aim._sdk.package_utils import Package - if pkg_name not in Package.pool: - Package.load_package(pkg_name) - pkg = Package.pool[pkg_name] - cmd += ['--reload-dir', os.path.dirname(pkg._path)] - - log_level = log_level or 'debug' - if uds_path: - cmd += ['--uds', uds_path] - if ssl_keyfile: - cmd += ['--ssl-keyfile', ssl_keyfile] - if ssl_certfile: - cmd += ['--ssl-certfile', ssl_certfile] - cmd += ['--log-level', log_level.lower()] - cmd += ['aimcore.web.run:app'] - return cmd - - def get_free_port_num(): import socket s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/src/aimcore/cli/utils.py b/src/aimcore/cli/utils.py index 80b8a4210a..77bb2fe7ce 100644 --- a/src/aimcore/cli/utils.py +++ b/src/aimcore/cli/utils.py @@ -9,3 +9,8 @@ def set_log_level(log_level): raise ValueError('Invalid log level: %s' % log_level) os.environ[AIM_LOG_LEVEL_KEY] = str(numeric_level) logging.basicConfig(level=numeric_level) + + +def start_uvicorn_app(app: str, **uvicorn_args): + import uvicorn + uvicorn.run(app, **uvicorn_args) diff --git a/src/aimcore/transport/handlers.py b/src/aimcore/transport/handlers.py index b27096497f..43e4c0ea2e 100644 --- a/src/aimcore/transport/handlers.py +++ b/src/aimcore/transport/handlers.py @@ -5,7 +5,6 @@ from aim import Repo from aim._sdk.local_storage import LocalFileManager -from aim._sdk.dev_package import DevPackage from aimcore.cleanup import AutoClean diff --git a/src/aimcore/transport/server.py b/src/aimcore/transport/server.py index ffa4de6d80..c6177bf9ba 100644 --- a/src/aimcore/transport/server.py +++ b/src/aimcore/transport/server.py @@ -137,11 +137,3 @@ async def run_write_instructions(websocket: WebSocket, client_uri: str): app = create_app() - - -def start_server(host, port, ssl_keyfile=None, ssl_certfile=None, *, log_level='info', reload=False, reload_dirs=()): - import uvicorn - uvicorn.run('aimcore.transport.server:app', host=host, port=port, - ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, - log_level=log_level, - reload=reload, reload_dirs=reload_dirs) diff --git a/src/aimcore/transport/tracking.py b/src/aimcore/transport/tracking.py index 6d8f53bd24..1a46021758 100644 --- a/src/aimcore/transport/tracking.py +++ b/src/aimcore/transport/tracking.py @@ -1,5 +1,6 @@ import uuid import base64 +import logging from typing import Dict, Union from fastapi import WebSocket, Request, APIRouter @@ -8,6 +9,8 @@ import aimcore.transport.message_utils as utils from aim._core.storage.treeutils import encode_tree, decode_tree +logger = logging.getLogger(__name__) + def get_handler(): return str(uuid.uuid4()) @@ -109,6 +112,7 @@ async def get_resource(self, except KeyError: pass + logger.debug(f'Caught exception {e}. Sending response 400.') return JSONResponse({ 'exception': utils.build_exception(e), }, status_code=400) @@ -118,6 +122,7 @@ async def release_resource(self, client_uri, resource_handler): self._verify_resource_handler(resource_handler, client_uri) del self.resource_pool[resource_handler] except Exception as e: + logger.debug(f'Caught exception {e}. Sending response 400.') return JSONResponse({ 'exception': utils.build_exception(e), }, status_code=400) @@ -157,6 +162,7 @@ async def run_instruction(self, client_uri: str, return StreamingResponse(utils.pack_stream(encode_tree(result))) except Exception as e: + logger.debug(f'Caught exception {e}. Sending response 400.') return JSONResponse({ 'exception': utils.build_exception(e), }, status_code=400) diff --git a/src/aimcore/web/api/__init__.py b/src/aimcore/web/api/__init__.py index dc4b8346a9..9559232ce3 100644 --- a/src/aimcore/web/api/__init__.py +++ b/src/aimcore/web/api/__init__.py @@ -7,9 +7,8 @@ from fastapi.responses import JSONResponse from aim._sdk.configs import get_aim_repo_name -from aim._sdk.package_utils import Package -from aimcore.web.configs import AIM_PROFILER_KEY, AIM_UI_PACKAGE_NAME +from aimcore.web.configs import AIM_PROFILER_KEY from aimcore.web.middlewares.profiler import PyInstrumentProfilerMiddleware from aimcore.web.utils import get_root_path @@ -74,10 +73,6 @@ def create_app(): api_app.add_middleware(PyInstrumentProfilerMiddleware, repo_path=os.path.join(get_root_path(), get_aim_repo_name())) - ui_pkg_name = os.environ.get(AIM_UI_PACKAGE_NAME) - if ui_pkg_name not in Package.pool: - Package.load_package(ui_pkg_name) - api_app.include_router(dashboard_apps_router, prefix='/apps') api_app.include_router(dashboards_router, prefix='/dashboards') api_app.include_router(boards_router, prefix='/boards') diff --git a/src/aimcore/web/api/boards/views.py b/src/aimcore/web/api/boards/views.py index e529647307..71f18e2200 100644 --- a/src/aimcore/web/api/boards/views.py +++ b/src/aimcore/web/api/boards/views.py @@ -1,6 +1,6 @@ from fastapi import Depends, HTTPException from fastapi.responses import JSONResponse -from aimcore.web.utils import get_root_package +from aimcore.web.utils import get_root_package, get_root_package_name from aimcore.web.api.utils import APIRouter # wrapper for fastapi.APIRouter from aimcore.web.api.boards.pydantic_models import BoardOut, BoardListOut @@ -13,12 +13,18 @@ @boards_router.get('/', response_model=BoardOut) async def board_list_api(package=Depends(get_root_package)): + if package is None: + raise HTTPException(status_code=400, detail=f'Failed to load current active ' + f'package \'{get_root_package_name()}\'.') result = [board.as_posix() for board in package.boards] return JSONResponse(result) @boards_router.get('/{board_path:path}', response_model=BoardListOut) async def board_get_api(board_path: str, package=Depends(get_root_package)): + if package is None: + raise HTTPException(status_code=400, detail=f'Failed to load current active ' + f'package \'{get_root_package_name()}\'.') board: pathlib.Path = package.boards_directory / board_path if not board.exists(): raise HTTPException(status_code=404) diff --git a/src/aimcore/web/api/queries/views.py b/src/aimcore/web/api/queries/views.py index d4fb1d44ee..d71cd49480 100644 --- a/src/aimcore/web/api/queries/views.py +++ b/src/aimcore/web/api/queries/views.py @@ -1,10 +1,11 @@ from fastapi.responses import StreamingResponse -from fastapi import HTTPException +from fastapi import Depends, HTTPException from typing import Optional, Iterable, Dict, List, Iterator, TYPE_CHECKING from aim._sdk.uri_service import URIService from aim.utils import sequence_data, container_data +from aimcore.web.utils import get_root_package from aimcore.web.api.runs.pydantic_models import QuerySyntaxErrorOut from aimcore.web.api.utils import ( checked_query, @@ -76,7 +77,8 @@ async def data_fetch_api(type_: str, q: Optional[str] = '', p: Optional[int] = 500, start: Optional[int] = None, - stop: Optional[int] = None): + stop: Optional[int] = None, + package=Depends(get_root_package)): repo = get_project_repo() query = checked_query(q) if type_ in Container.registry: @@ -99,7 +101,8 @@ async def grouped_data_fetch_api(seq_type: Optional[str] = 'Sequence', q: Optional[str] = '', p: Optional[int] = 500, start: Optional[int] = None, - stop: Optional[int] = None): + stop: Optional[int] = None, + package=Depends(get_root_package)): repo = get_project_repo() query = checked_query(q) if seq_type not in Sequence.registry: @@ -134,7 +137,7 @@ async def fetch_blobs_api(uri_batch: URIBatchIn): @query_router.post('/run/') -async def run_function(func_name: str, request_data: Dict): +async def run_function(func_name: str, request_data: Dict, package=Depends(get_root_package)): repo = get_project_repo() # noqa from aim._sdk.function import Function diff --git a/src/aimcore/web/configs.py b/src/aimcore/web/configs.py index c91c6a3340..2760c5c5be 100644 --- a/src/aimcore/web/configs.py +++ b/src/aimcore/web/configs.py @@ -4,7 +4,6 @@ AIM_UI_MOUNTED_REPO_PATH = '__AIM_UI_MOUNT_REPO_PATH__' AIM_UI_TELEMETRY_KEY = 'AIM_UI_TELEMETRY_ENABLED' AIM_UI_BASE_PATH = '__AIM_UI_BASE_PATH__' -AIM_UI_PACKAGE_NAME = '__AIM_UI_PACKAGE_NAME__' AIM_PROXY_URL = '__AIM_PROXY_URL__' AIM_PROFILER_KEY = '__AIM_PROFILER_ENABLED__' AIM_PROGRESS_REPORT_INTERVAL = 0.5 diff --git a/src/aimcore/web/utils.py b/src/aimcore/web/utils.py index 81b8ae5cfd..5bea43e39c 100644 --- a/src/aimcore/web/utils.py +++ b/src/aimcore/web/utils.py @@ -1,8 +1,11 @@ -from importlib import import_module import os import subprocess -from aimcore.web.configs import AIM_UI_MOUNTED_REPO_PATH, AIM_UI_PACKAGE_NAME +from typing import Optional +from importlib import import_module +from cachetools.func import ttl_cache + +from aimcore.web.configs import AIM_UI_MOUNTED_REPO_PATH from aim._sdk.configs import get_aim_repo_name from aim._sdk.utils import clean_repo_path from aim._sdk.package_utils import Package @@ -94,10 +97,24 @@ def get_root_path(): return clean_repo_path(os.getenv(AIM_UI_MOUNTED_REPO_PATH, os.getcwd())) -def get_root_package(): - ui_pkg_name = os.environ.get(AIM_UI_PACKAGE_NAME) - assert ui_pkg_name in Package.pool - return Package.pool[ui_pkg_name] +@ttl_cache(ttl=5) +def get_root_package_name(): + package_file = os.path.join(get_root_path(), get_aim_repo_name(), 'active_pkg') + if not os.path.exists(package_file): + package_name = Package.default_package_name + else: + with open(package_file, 'r') as pf: + package_name = pf.read().strip() + return package_name + + +def get_root_package() -> Optional[Package]: + pkg_name = get_root_package_name() + pkgs_dir = os.path.join(get_root_path(), get_aim_repo_name(), 'pkgs') + if pkg_name not in Package.pool: + if not Package.load_package(pkg_name, pkgs_dir): + return None + return Package.pool[pkg_name] def get_db_url(): diff --git a/src/python/aim/_sdk/dev_package.py b/src/python/aim/_sdk/dev_package.py index afe244100b..dacef1757f 100644 --- a/src/python/aim/_sdk/dev_package.py +++ b/src/python/aim/_sdk/dev_package.py @@ -28,7 +28,10 @@ def sync(self, file, contents): def remove(self, file): logger.debug(f'Removing file \'{file}\'.') full_path = self._src_path / file - full_path.unlink(missing_ok=True) + try: + full_path.unlink() + except FileNotFoundError: + pass def move(self, src, dest): logger.debug(f'Moving file from \'{src}\' to \'{dest}\'.') diff --git a/src/python/aim/_sdk/package_utils.py b/src/python/aim/_sdk/package_utils.py index ac97f8c240..798af50b9f 100644 --- a/src/python/aim/_sdk/package_utils.py +++ b/src/python/aim/_sdk/package_utils.py @@ -2,14 +2,16 @@ import pkgutil import importlib import logging +import sys -from typing import Iterable, Dict, List +from typing import Iterable, Dict, List, Optional logger = logging.getLogger(__name__) class Package: pool: Dict[str, 'Package'] = {} + default_package_name = 'asp' def __init__(self, name, pkg): self.name = name @@ -55,11 +57,26 @@ def _load_packages(base_pkg, package_list: Iterable[str]): Package.pool[name] = Package(name, pkg) @staticmethod - def load_package(package_name: str): + def load_package(package_name: str, src_lookup_dir: Optional[str] = None) -> bool: if package_name in Package.pool: - return - pkg = importlib.import_module(package_name) - Package.pool[package_name] = Package(package_name, pkg) + return True + try: + pkg = importlib.import_module(package_name) + if package_name not in Package.pool: + Package.pool[package_name] = Package(package_name, pkg) + return True + except ModuleNotFoundError: + if src_lookup_dir is None: + return False + + sys.path.append(str(pathlib.Path(src_lookup_dir) / package_name / 'src')) + try: + pkg = importlib.import_module(package_name) + if package_name not in Package.pool: + Package.pool[package_name] = Package(package_name, pkg) + return True + except ModuleNotFoundError: + return False def register_aim_package_classes(self, name, pkg): if not hasattr(pkg, '__aim_types__'): diff --git a/src/python/aim/_sdk/remote_storage.py b/src/python/aim/_sdk/remote_storage.py index cf5ef1d809..b24cd307cd 100644 --- a/src/python/aim/_sdk/remote_storage.py +++ b/src/python/aim/_sdk/remote_storage.py @@ -559,3 +559,6 @@ def prune(self): def _close_container(self, hash_): return self._rpc_client.run_instruction(-1, self._handler, '_close_container', [hash_]) + + def set_active_package(self, pkg_name: str): + return self._rpc_client.run_instruction(-1, self._handler, 'set_active_package', [pkg_name]) diff --git a/src/python/aim/_sdk/repo.py b/src/python/aim/_sdk/repo.py index 1581b0f906..03e06ab016 100644 --- a/src/python/aim/_sdk/repo.py +++ b/src/python/aim/_sdk/repo.py @@ -13,6 +13,7 @@ from aim._sdk.container import Container from aim._sdk.sequence import Sequence from aim._sdk.function import Function +from aim._sdk.package_utils import Package from aim._sdk.collections import ContainerCollection, SequenceCollection from aim._sdk.query_utils import construct_query_expression from aim._sdk.constants import KeyNames @@ -84,6 +85,7 @@ def __init__(self, path: str, *, read_only: Optional[bool] = True): self.read_only = read_only self._is_remote_repo = False if self.is_remote_path(path): + self.path = os.path.join(path, get_aim_repo_name()) self._is_remote_repo = True self._storage_engine = RemoteStorage(path) self._remote_repo_proxy = RemoteRepoProxy(self._storage_engine._client) @@ -114,9 +116,15 @@ def exists(cls, path: str) -> bool: def init(cls, path: str): aim_repo_path = os.path.join(clean_repo_path(path), get_aim_repo_name()) os.makedirs(aim_repo_path, exist_ok=True) + version_file_path = os.path.join(aim_repo_path, 'VERSION') with open(version_file_path, 'w') as version_fh: version_fh.write('.'.join(map(str, get_data_version())) + '\n') + + active_pkg_file = os.path.join(aim_repo_path, 'active_pkg') + with open(active_pkg_file, 'w') as apf: + apf.write(Package.default_package_name) + return cls.from_path(aim_repo_path, read_only=False) @classmethod @@ -461,3 +469,10 @@ def prune(self): return self._remote_repo_proxy.prune() prune(self) + + def set_active_package(self, pkg_name): + if self._is_remote_repo: + return self._remote_repo_proxy.set_active_package(pkg_name) + active_pkg_file = os.path.join(self.path, 'active_pkg') + with open(active_pkg_file, 'w') as apf: + apf.write(pkg_name)