diff --git a/mergin/cli.py b/mergin/cli.py index 0f98489f..c5a98d88 100755 --- a/mergin/cli.py +++ b/mergin/cli.py @@ -26,6 +26,8 @@ from mergin.client_pull import ( download_project_async, download_project_cancel, + download_file_async, + download_file_finalize, download_project_finalize, download_project_is_running, ) @@ -260,6 +262,41 @@ def download(ctx, project, directory, version): _print_unhandled_exception() +@cli.command() +@click.argument("filepath") +@click.argument("output") +@click.option("--version", help="Project version tag, for example 'v3'") +@click.pass_context +def download_file(ctx, filepath, output, version): + """ + Download project file at specified version. `project` needs to be a combination of namespace/project. + If no version is given, the latest will be fetched. + """ + mc = ctx.obj["client"] + if mc is None: + return + mp = MerginProject(os.getcwd()) + project_path = mp.metadata["name"] + try: + job = download_file_async(mc, project_path, filepath, output, version) + with click.progressbar(length=job.total_size) as bar: + last_transferred_size = 0 + while download_project_is_running(job): + time.sleep(1 / 10) # 100ms + new_transferred_size = job.transferred_size + bar.update(new_transferred_size - last_transferred_size) # the update() needs increment only + last_transferred_size = new_transferred_size + download_file_finalize(job) + click.echo("Done") + except KeyboardInterrupt: + click.secho("Cancelling...") + download_project_cancel(job) + except ClientError as e: + click.secho("Error: " + str(e), fg="red") + except Exception as e: + _print_unhandled_exception() + + def num_version(name): return int(name.lstrip("v")) diff --git a/mergin/client.py b/mergin/client.py index 545ed9b7..7fc2bee7 100644 --- a/mergin/client.py +++ b/mergin/client.py @@ -11,12 +11,20 @@ import dateutil.parser import ssl -from .common import ClientError, LoginError +from .common import ClientError, LoginError, InvalidProject from .merginproject import MerginProject -from .client_pull import download_project_async, download_project_wait, download_project_finalize +from .client_pull import ( + download_file_finalize, + download_project_async, + download_file_async, + download_diffs_async, + download_project_finalize, + download_project_wait, + download_diffs_finalize, +) from .client_pull import pull_project_async, pull_project_wait, pull_project_finalize from .client_push import push_project_async, push_project_wait, push_project_finalize -from .utils import DateTimeEncoder +from .utils import DateTimeEncoder, get_versions_with_file_changes from .version import __version__ this_dir = os.path.dirname(os.path.realpath(__file__)) @@ -618,3 +626,38 @@ def get_projects_by_names(self, projects): resp = self.post("/v1/project/by_names", {"projects": projects}, {"Content-Type": "application/json"}) return json.load(resp) + + def download_file(self, project_dir, file_path, output_filename, version=None): + """ + Download project file at specified version. Get the latest if no version specified. + + :param project_dir: project local directory + :type project_dir: String + :param file_path: relative path of file to download in the project directory + :type file_path: String + :param output_filename: full destination path for saving the downloaded file + :type output_filename: String + :param version: optional version tag for downloaded file + :type version: String + """ + job = download_file_async(self, project_dir, file_path, output_filename, version=version) + pull_project_wait(job) + download_file_finalize(job) + + def get_file_diff(self, project_dir, file_path, output_diff, version_from, version_to): + """ Create concatenated diff for project file diffs between versions version_from and version_to. + + :param project_dir: project local directory + :type project_dir: String + :param file_path: relative path of file to download in the project directory + :type file_path: String + :param output_diff: full destination path for concatenated diff file + :type output_diff: String + :param version_from: starting project version tag for getting diff, for example 'v3' + :type version_from: String + :param version_to: ending project version tag for getting diff + :type version_to: String + """ + job = download_diffs_async(self, project_dir, file_path, version_from, version_to) + pull_project_wait(job) + download_diffs_finalize(job, output_diff) diff --git a/mergin/client_pull.py b/mergin/client_pull.py index c1e94b40..332269a7 100644 --- a/mergin/client_pull.py +++ b/mergin/client_pull.py @@ -14,12 +14,13 @@ import os import pprint import shutil +import tempfile import concurrent.futures from .common import CHUNK_SIZE, ClientError from .merginproject import MerginProject -from .utils import save_to_file +from .utils import save_to_file, get_versions_with_file_changes # status = download_project_async(...) @@ -35,7 +36,10 @@ class DownloadJob: - """ Keeps all the important data about a pending download job """ + """ + Keeps all the important data about a pending download job. + Used for downloading whole projects but also single files. + """ def __init__(self, project_path, total_size, version, update_tasks, download_queue_items, directory, mp, project_info): self.project_path = project_path @@ -48,7 +52,7 @@ def __init__(self, project_path, total_size, version, update_tasks, download_que self.mp = mp # MerginProject instance self.is_cancelled = False self.project_info = project_info # parsed JSON with project info returned from the server - + def dump(self): print("--- JOB ---", self.total_size, "bytes") for task in self.update_tasks: @@ -104,7 +108,6 @@ def download_project_async(mc, project_path, directory, project_version=None): """ Starts project download in background and returns handle to the pending project download. Using that object it is possible to watch progress or cancel the ongoing work. - """ if '/' not in project_path: @@ -226,27 +229,35 @@ def download_project_cancel(job): class UpdateTask: """ - Entry for each file that will be updated. At the end of a successful download of new data, all the tasks are executed. + Entry for each file that will be updated. + At the end of a successful download of new data, all the tasks are executed. """ # TODO: methods other than COPY - def __init__(self, file_path, download_queue_items): + def __init__(self, file_path, download_queue_items, destination_file=None): self.file_path = file_path + self.destination_file = destination_file self.download_queue_items = download_queue_items - + def apply(self, directory, mp): """ assemble downloaded chunks into a single file """ - basename = os.path.basename(self.file_path) #file['diff']['path']) if diff_only else os.path.basename(file['path']) - file_dir = os.path.dirname(os.path.normpath(os.path.join(directory, self.file_path))) - dest_file_path = os.path.join(file_dir, basename) + if self.destination_file is None: + basename = os.path.basename(self.file_path) + file_dir = os.path.dirname(os.path.normpath(os.path.join(directory, self.file_path))) + dest_file_path = os.path.join(file_dir, basename) + else: + file_dir = os.path.dirname(os.path.normpath(self.destination_file)) + dest_file_path = self.destination_file os.makedirs(file_dir, exist_ok=True) # merge chunks together (and delete them afterwards) file_to_merge = FileToMerge(dest_file_path, self.download_queue_items) file_to_merge.merge() - if mp.is_versioned_file(self.file_path): + # Make a copy of the file to meta dir only if there is no user-specified path for the file. + # destination_file is None for full project download and takes a meaningful value for a single file download. + if mp.is_versioned_file(self.file_path) and self.destination_file is None: mp.geodiff.make_copy_sqlite(mp.fpath(self.file_path), mp.fpath_meta(self.file_path)) @@ -287,7 +298,8 @@ def download_blocking(self, mc, mp, project_path): class PullJob: - def __init__(self, project_path, pull_changes, total_size, version, files_to_merge, download_queue_items, temp_dir, mp, project_info, basefiles_to_patch): + def __init__(self, project_path, pull_changes, total_size, version, files_to_merge, download_queue_items, + temp_dir, mp, project_info, basefiles_to_patch): self.project_path = project_path self.pull_changes = pull_changes # dictionary with changes (dict[str, list[dict]] - keys: "added", "updated", ...) self.total_size = total_size # size of data to download (in bytes) @@ -551,3 +563,190 @@ def pull_project_finalize(job): shutil.rmtree(job.temp_dir) return conflicts + + +def download_file_async(mc, project_dir, file_path, output_file, version): + """ + Starts background download project file at specified version. + Returns handle to the pending download. + """ + mp = MerginProject(project_dir) + project_path = mp.metadata["name"] + ver_info = f"at version {version}" if version is not None else "at latest version" + mp.log.info(f"Getting {file_path} {ver_info}") + project_info = mc.project_info(project_path, version=version) + mp.log.info(f"Got project info. version {project_info['version']}") + + # set temporary directory for download + temp_dir = tempfile.mkdtemp(prefix="mergin-py-client-") + + download_list = [] + update_tasks = [] + total_size = 0 + for file in project_info['files']: + if file["path"] == file_path: + file['version'] = version + items = _download_items(file, temp_dir) + task = UpdateTask(file['path'], items, output_file) + download_list.extend(task.download_queue_items) + for item in task.download_queue_items: + total_size += item.size + update_tasks.append(task) + break + if not download_list: + warn = f"No {file_path} exists at version {version}" + mp.log.warning(warn) + shutil.rmtree(temp_dir) + raise ClientError(warn) + + mp.log.info(f"will download file {file_path} in {len(download_list)} chunks, total size {total_size}") + job = DownloadJob( + project_path, total_size, version, update_tasks, download_list, temp_dir, mp, project_info + ) + job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + job.futures = [] + for item in download_list: + future = job.executor.submit(_do_download, item, mc, mp, project_path, job) + job.futures.append(future) + + return job + + +def download_file_finalize(job): + """ + To be called when download_file_async is finished + """ + job.executor.shutdown(wait=True) + + # make sure any exceptions from threads are not lost + for future in job.futures: + if future.exception() is not None: + raise future.exception() + + job.mp.log.info("--- download finished") + + temp_dir = None + for task in job.update_tasks: + task.apply(job.directory, job.mp) + if task.download_queue_items: + temp_dir = os.path.dirname(task.download_queue_items[0].download_file_path) + + # Remove temporary download directory + if temp_dir is not None: + shutil.rmtree(temp_dir) + + +def download_diffs_async(mc, project_directory, file_path, version_from, version_to): + """ + Starts background download project file diffs for specified versions. + Returns handle to the pending download. + + Args: + mc (MerginClient): MerginClient instance. + project_directory (str): local project directory. + file_path (str): file path relative to Mergin project root. + version_from (str): starting project version tag for getting diff, for example 'v3'. + version_to (str): ending project version tag for getting diff. + + Returns: + PullJob/None: a handle for the pending download. + """ + mp = MerginProject(project_directory) + project_path = mp.metadata["name"] + file_history = mc.project_file_history_info(project_path, file_path) + versions_to_fetch = get_versions_with_file_changes( + mc, project_path, file_path, version_from=version_from, version_to=version_to, file_history=file_history + ) + mp.log.info(f"--- version: {mc.user_agent_info()}") + mp.log.info(f"--- start download diffs for {file_path} of {project_path}, versions: {[v for v in versions_to_fetch]}") + + try: + server_info = mc.project_info(project_path) + if file_history is None: + file_history = mc.project_file_history_info(project_path, file_path) + except ClientError as err: + mp.log.error("Error getting project info: " + str(err)) + mp.log.info("--- downloading diffs aborted") + raise + + temp_dir = tempfile.mkdtemp(prefix="mergin-py-client-") + fetch_files = [] + + for version in versions_to_fetch[1:]: + version_data = file_history["history"][version] + diff_data = copy.deepcopy(version_data) + diff_data['version'] = version + diff_data['diff'] = version_data['diff'] + fetch_files.append(diff_data) + + files_to_merge = [] # list of FileToMerge instances + download_list = [] # list of all items to be downloaded + total_size = 0 + for file in fetch_files: + items = _download_items(file, temp_dir, diff_only=True) + dest_file_path = os.path.normpath(os.path.join(temp_dir, os.path.basename(file['diff']['path']))) + files_to_merge.append(FileToMerge(dest_file_path, items)) + download_list.extend(items) + for item in items: + total_size += item.size + + mp.log.info(f"will download {len(download_list)} chunks, total size {total_size}") + + job = PullJob(project_path, None, total_size, None, files_to_merge, download_list, temp_dir, mp, + server_info, {}) + + # start download + job.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + job.futures = [] + for item in download_list: + future = job.executor.submit(_do_download, item, mc, mp, project_path, job) + job.futures.append(future) + + return job + + +def download_diffs_finalize(job, output_diff): + """ To be called after download_diffs_async """ + + job.executor.shutdown(wait=True) + + # make sure any exceptions from threads are not lost + for future in job.futures: + if future.exception() is not None: + job.mp.log.error("Error while pulling data: " + str(future.exception())) + job.mp.log.info("--- diffs download aborted") + raise future.exception() + + job.mp.log.info("finalizing diffs pull") + + # merge downloaded chunks + try: + for file_to_merge in job.files_to_merge: + file_to_merge.merge() + except ClientError as err: + job.mp.log.error("Error merging chunks of downloaded file: " + str(err)) + job.mp.log.info("--- diffs pull aborted") + raise + + job.mp.log.info("--- diffs pull finished") + + # Collect and finally concatenate diffs, if needed + diffs = [] + for file_to_merge in job.files_to_merge: + diffs.append(file_to_merge.dest_file) + + output_dir = os.path.dirname(output_diff) + temp_dir = None + if len(diffs) >= 1: + os.makedirs(output_dir, exist_ok=True) + temp_dir = os.path.dirname(diffs[0]) + if len(diffs) > 1: + job.mp.geodiff.concat_changes(diffs, output_diff) + elif len(diffs) == 1: + shutil.copy(diffs[0], output_diff) + for diff in diffs: + os.remove(diff) + + # remove the diffs download temporary directory + if temp_dir is not None: + shutil.rmtree(temp_dir) diff --git a/mergin/test/test_client.py b/mergin/test/test_client.py index 2a182eed..908d73f2 100644 --- a/mergin/test/test_client.py +++ b/mergin/test/test_client.py @@ -1,3 +1,4 @@ +import json import logging import os import tempfile @@ -10,7 +11,7 @@ from ..client import MerginClient, ClientError, MerginProject, LoginError, decode_token_data, TokenError from ..client_push import push_project_async, push_project_cancel -from ..utils import generate_checksum +from ..utils import generate_checksum, get_versions_with_file_changes from ..merginproject import pygeodiff @@ -810,12 +811,143 @@ def test_server_compatibility(mc): assert mc.is_server_compatible() +def create_versioned_project(mc, project_name, project_dir, updated_file, remove=True): + project = API_USER + '/' + project_name + cleanup(mc, project, [project_dir]) + + # create remote project + shutil.copytree(TEST_DATA_DIR, project_dir) + mc.create_project_and_push(project_name, project_dir) + + mp = MerginProject(project_dir) + + # create versions 2-4 + changes = ("inserted_1_A.gpkg", "inserted_1_A_mod.gpkg", "inserted_1_B.gpkg",) + for change in changes: + shutil.copy(mp.fpath(change), mp.fpath(updated_file)) + mc.push_project(project_dir) + # create version 5 with modified file removed + if remove: + os.remove(os.path.join(project_dir, updated_file)) + mc.push_project(project_dir) + + return mp + + +def test_get_versions_with_file_changes(mc): + """Test getting versions where the file was changed.""" + test_project = 'test_file_modified_versions' + project = API_USER + '/' + test_project + project_dir = os.path.join(TMP_DIR, test_project) + f_updated = "base.gpkg" + + mp = create_versioned_project(mc, test_project, project_dir, f_updated, remove=False) + + project_info = mc.project_info(project) + assert project_info["version"] == "v4" + file_history = mc.project_file_history_info(project, f_updated) + + with pytest.raises(ClientError) as e: + mod_versions = get_versions_with_file_changes( + mc, project, f_updated, version_from="v1", version_to="v5", file_history=file_history + ) + assert "Wrong version parameters: 1-5" in str(e.value) + assert "Available versions: [1, 2, 3, 4]" in str(e.value) + + mod_versions = get_versions_with_file_changes( + mc, project, f_updated, version_from="v2", version_to="v4", file_history=file_history + ) + assert mod_versions == [f"v{i}" for i in range(2, 5)] + + +def check_gpkg_same_content(mergin_project, gpkg_path_1, gpkg_path_2): + """Check if the two GeoPackages have equal content.""" + with tempfile.TemporaryDirectory() as temp_dir: + diff_path = os.path.join(temp_dir, "diff_file") + mergin_project.geodiff.create_changeset(gpkg_path_1, gpkg_path_2, diff_path) + return not mergin_project.geodiff.has_changes(diff_path) + + +def test_download_file(mc): + """Test downloading single file at specified versions.""" + test_project = 'test_download_file' + project = API_USER + '/' + test_project + project_dir = os.path.join(TMP_DIR, test_project) + f_updated = "base.gpkg" + + mp = create_versioned_project(mc, test_project, project_dir, f_updated) + + project_info = mc.project_info(project) + assert project_info["version"] == "v5" + + # Versioned file should have the following content at versions 2-4 + expected_content = ("inserted_1_A.gpkg", "inserted_1_A_mod.gpkg", "inserted_1_B.gpkg") + + # Download the base file at versions 2-4 and check the changes + f_downloaded = os.path.join(project_dir, f_updated) + for ver in range(2, 5): + mc.download_file(project_dir, f_updated, f_downloaded, version=f"v{ver}") + expected = os.path.join(TEST_DATA_DIR, expected_content[ver - 2]) # GeoPackage with expected content + assert check_gpkg_same_content(mp, f_downloaded, expected) + + # make sure there will be exception raised if a file doesn't exist in the version + with pytest.raises(ClientError, match=f"No {f_updated} exists at version v5"): + mc.download_file(project_dir, f_updated, f_downloaded, version=f"v5") + + +def test_download_diffs(mc): + """Test download diffs for a project file between specified project versions.""" + test_project = 'test_download_diffs' + project = API_USER + '/' + test_project + project_dir = os.path.join(TMP_DIR, test_project) + download_dir = os.path.join(project_dir, "diffs") # project for downloading files at various versions + f_updated = "base.gpkg" + diff_file = os.path.join(download_dir, f_updated + ".diff") + + mp = create_versioned_project(mc, test_project, project_dir, f_updated, remove=False) + + project_info = mc.project_info(project) + assert project_info["version"] == "v4" + + # Download diffs of updated file between versions 1 and 2 + mc.get_file_diff(project_dir, f_updated, diff_file, "v1", "v2") + assert os.path.exists(diff_file) + assert mp.geodiff.has_changes(diff_file) + assert mp.geodiff.changes_count(diff_file) == 1 + changes_file = diff_file + ".changes1-2" + mp.geodiff.list_changes_summary(diff_file, changes_file) + with open(changes_file, 'r') as f: + changes = json.loads(f.read())["geodiff_summary"][0] + assert changes["insert"] == 1 + assert changes["update"] == 0 + + # Download diffs of updated file between versions 2 and 4 + mc.get_file_diff(project_dir, f_updated, diff_file, "v2", "v4") + changes_file = diff_file + ".changes2-4" + mp.geodiff.list_changes_summary(diff_file, changes_file) + with open(changes_file, 'r') as f: + changes = json.loads(f.read())["geodiff_summary"][0] + assert changes["insert"] == 0 + assert changes["update"] == 1 + + with pytest.raises(ClientError) as e: + mc.get_file_diff(project_dir, f_updated, diff_file, "v4", "v1") + assert "Wrong version parameters" in str(e.value) + assert "version_from needs to be smaller than version_to" in str(e.value) + + with pytest.raises(ClientError) as e: + mc.get_file_diff(project_dir, f_updated, diff_file, "v4", "v5") + assert "Wrong version parameters" in str(e.value) + assert "Available versions: [1, 2, 3, 4]" in str(e.value) + + def _use_wal(db_file): """ Ensures that sqlite database is using WAL journal mode """ con = sqlite3.connect(db_file) cursor = con.cursor() cursor.execute('PRAGMA journal_mode=wal;') + def _create_test_table(db_file): """ Creates a table called 'test' in sqlite database. Useful to simulate change of database schema. """ con = sqlite3.connect(db_file) @@ -824,6 +956,7 @@ def _create_test_table(db_file): cursor.execute('INSERT INTO test VALUES (123, \'hello\');') cursor.execute('COMMIT;') + def _check_test_table(db_file): """ Checks whether the 'test' table exists and has one row - otherwise fails with an exception. """ #con_verify = sqlite3.connect(db_file) diff --git a/mergin/utils.py b/mergin/utils.py index 84045052..b6010258 100644 --- a/mergin/utils.py +++ b/mergin/utils.py @@ -5,6 +5,7 @@ import re import sqlite3 from datetime import datetime +from .common import ClientError def generate_checksum(file, chunk_size=4096): @@ -99,3 +100,48 @@ def do_sqlite_checkpoint(path, log=None): log.info("checkpoint - new size {} checksum {}".format(new_size, new_checksum)) return new_size, new_checksum + + +def get_versions_with_file_changes( + mc, project_path, file_path, version_from=None, version_to=None, file_history=None): + """ + Get the project version tags where the file was added, modified or deleted. + + Args: + mc (MerginClient): MerginClient instance + project_path (str): project full name (/) + file_path (str): relative path of file to download in the project directory + version_from (str): optional minimum version to fetch, for example "v3" + version_to (str): optional maximum version to fetch + file_history (dict): optional file history info, result of project_file_history_info(). + + Returns: + list of version tags, for example ["v4", "v7", "v8"] + """ + if file_history is None: + file_history = mc.project_file_history_info(project_path, file_path) + all_version_numbers = sorted([int(k[1:]) for k in file_history["history"].keys()]) + version_from = all_version_numbers[0] if version_from is None else int_version(version_from) + version_to = all_version_numbers[-1] if version_to is None else int_version(version_to) + if version_from is None or version_to is None: + err = f"Wrong version parameters: {version_from}-{version_to} while getting diffs for {file_path}. " + err += f"Version tags required in the form: 'v2', 'v11', etc." + raise ClientError(err) + if version_from >= version_to: + err = f"Wrong version parameters: {version_from}-{version_to} while getting diffs for {file_path}. " + err += f"version_from needs to be smaller than version_to." + raise ClientError(err) + if version_from not in all_version_numbers or version_to not in all_version_numbers: + err = f"Wrong version parameters: {version_from}-{version_to} while getting diffs for {file_path}. " + err += f"Available versions: {all_version_numbers}" + raise ClientError(err) + + # Find versions to fetch between the 'from' and 'to' versions + idx_from = idx_to = None + for idx, version in enumerate(all_version_numbers): + if version == version_from: + idx_from = idx + elif version == version_to: + idx_to = idx + break + return [f"v{ver_nr}" for ver_nr in all_version_numbers[idx_from:idx_to + 1]]