-
Notifications
You must be signed in to change notification settings - Fork 3
/
importer.py
147 lines (113 loc) · 5.15 KB
/
importer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Aqueduct module for importing data from instance."""
from __future__ import annotations
import tarfile
from tarfile import TarFile
from typing import Any, Callable, Dict, List, Optional, Sequence
from packaging.version import Version
from sqlalchemy import select
from sqlalchemy.orm import Session
from aqueductcore.backend.models import orm
from aqueductcore.cli.exporter import Exporter
from aqueductcore.cli.models import AqueductData, Experiment, Tag
class Importer:
    """Aqueduct importer class.

    Groups the helpers used to load a previously exported Aqueduct archive:
    version compatibility check, EID conflict lookup, metadata import into the
    database and extraction of the experiment files.
    """

    @classmethod
    def check_version_compatible(cls, aqueduct_version: str, metadata_version: str) -> bool:
        """Check if the metadata version is compatible with the current Aqueduct version.

        Two versions are treated as compatible when their major components are equal.

        Args:
            aqueduct_version: Version string of the running Aqueduct instance.
            metadata_version: Version string stored in the metadata being imported.

        Returns:
            True if the version is compatible, False otherwise.
        """
        version = Version(metadata_version)
        cur_version = Version(aqueduct_version)

        return cur_version.major == version.major

    @classmethod
    def get_conflicting_experiments(
        cls, db_session: Session, experiments: List[Experiment]
    ) -> Sequence[orm.Experiment]:
        """Check if there are conflicting experiments with the same EID.

        Args:
            db_session: Database session used to query the existing experiments.
            experiments: Experiments to be checked.

        Returns:
            List of database experiments that are conflicting with the provided ones.
        """
        eids = [item.eid for item in experiments]
        conflict_statement = select(orm.Experiment).where(orm.Experiment.eid.in_(eids))

        return db_session.execute(conflict_statement).scalars().all()

    @classmethod
    def import_experiments_metadata(cls, db_session: Session, metadata: AqueductData) -> None:
        """Import metadata into the database. Conflicts raise database exceptions.

        Users and tags that already exist in the database (matched by UUID and by
        key respectively) are reused; missing ones are created. A new experiment
        row is always created for every experiment in the metadata. The session is
        neither flushed nor committed here; that is left to the caller.

        Args:
            db_session: Database session the imported objects are attached to.
            metadata: Aqueduct metadata object.
        """
        user_uuids = [item.uuid for item in metadata.users]
        cur_users_statement = select(orm.User).where(orm.User.uuid.in_(user_uuids))
        cur_db_users = db_session.execute(cur_users_statement).scalars().all()
        cur_db_users_dict = {item.uuid: item for item in cur_db_users}

        # Collect the distinct tags used by any experiment; when the same key
        # appears more than once the last occurrence wins.
        metadata_tags: Dict[str, Tag] = {}
        for user in metadata.users:
            for experiment in user.experiments:
                for tag in experiment.tags:
                    metadata_tags[tag.key] = tag

        cur_db_tags_statement = select(orm.Tag).where(orm.Tag.key.in_(list(metadata_tags)))
        cur_db_tags = db_session.execute(cur_db_tags_statement).scalars().all()
        cur_db_tags_dict = {item.key: item for item in cur_db_tags}

        # Reuse tags already stored in the database, create the missing ones.
        db_tags: Dict[str, orm.Tag] = {}
        for key, value in metadata_tags.items():
            existing_tag = cur_db_tags_dict.get(key)
            if existing_tag is not None:
                db_tags[key] = existing_tag
            else:
                db_tags[key] = orm.Tag(key=key, name=value.name)

        for user in metadata.users:
            db_user = cur_db_users_dict.get(user.uuid)
            # Explicit None check so the decision never depends on the
            # truthiness of the ORM object itself.
            if db_user is None:
                db_user = orm.User(
                    uuid=user.uuid,
                    username=user.username,
                )
                # Remember the new user so a repeated UUID in the metadata
                # reuses this object instead of creating a duplicate.
                cur_db_users_dict[user.uuid] = db_user
                db_session.add(db_user)

            for experiment in user.experiments:
                db_experiment = orm.Experiment(
                    uuid=experiment.uuid,
                    title=experiment.title,
                    description=experiment.description,
                    tags=[db_tags[tag.key] for tag in experiment.tags],
                    eid=experiment.eid,
                    created_at=experiment.created_at,
                    updated_at=experiment.updated_at,
                )
                # NOTE(review): new experiments and tags are only attached through
                # the user relationship; presumably the ORM cascade adds them to
                # the session - confirm against the orm models.
                db_user.experiments.append(db_experiment)

    @staticmethod
    def _strip_base_dir(path: str, base_dir: str) -> str:
        """Remove a single leading ``base_dir/`` component from an archive path.

        Only the prefix is removed; later occurrences of ``base_dir/`` inside the
        path are kept. Paths that do not start with ``base_dir/`` are returned
        unchanged.

        Args:
            path: Member path inside the archive (``/`` separated).
            base_dir: Name of the directory to strip, without trailing slash.

        Returns:
            The path relative to ``base_dir`` or the unchanged path.
        """
        prefix = f"{base_dir}/"
        if path.startswith(prefix):
            return path[len(prefix):]
        return path

    @classmethod
    def import_experiment_files(
        cls,
        tar: TarFile,
        experiments_root: str,
        progress: Optional[Callable[[int], Any]] = None,
    ) -> None:
        """Extract the experiments' files from an archive into the experiments root.

        The metadata file of the archive is skipped and the leading experiments
        base directory is removed from every member path, so the files end up
        directly below ``experiments_root``.

        Args:
            tar: Tar file to read the archive.
            experiments_root: Experiments root directory of the Aqueduct instance.
            progress: Call back with processed data information to show progress.
                It receives the size of every archive member, including the
                skipped metadata file.
        """

        def experiments_filter(member: tarfile.TarInfo, _) -> Optional[tarfile.TarInfo]:
            """Extraction filter for progress bar."""
            if progress:
                progress(member.size)

            # The metadata file is not an experiment file; returning None makes
            # tarfile skip the member.
            if member.path == Exporter.METADATA_FILENAME:
                return None

            # Strip only the leading base directory. A plain str.replace would
            # also remove the same text from nested directory names.
            member.path = cls._strip_base_dir(member.path, Exporter.EXPERIMENTS_BASE_DIR_NAME)

            return member

        # NOTE(review): this callable is the only extraction filter applied and it
        # performs no path-safety checks (absolute paths, "..", links), so the
        # archive is assumed to come from a trusted export. Consider chaining
        # tarfile.data_filter if untrusted archives must be supported.
        tar.extractall(path=experiments_root, filter=experiments_filter)