Skip to content

Commit 1cb8f26

Browse files
committed
[LABIMP-8186] Added support to import users to a project from another project
1 parent 1f049e4 commit 1cb8f26

File tree

5 files changed

+89
-206
lines changed

5 files changed

+89
-206
lines changed

driver.py

Lines changed: 0 additions & 198 deletions
This file was deleted.

labellerr/core/connectors/connections.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,27 @@ def get_connection(client: "LabellerrClient", connection_id: str):
3939
"""Metaclass that combines ABC functionality with factory pattern"""
4040

4141
def __call__(cls, client, connection_id, **kwargs):
42-
# Only intercept calls to the base LabellerrConnection class
42+
# Fetch connection data if not already provided
43+
if "connection_data" not in kwargs:
44+
connection_data = cls.get_connection(client, connection_id)
45+
if connection_data is None:
46+
raise InvalidConnectionError(f"Connection not found: {connection_id}")
47+
kwargs["connection_data"] = connection_data
48+
49+
# Only intercept calls to the base LabellerrConnection class for factory behavior
4350
if cls.__name__ != "LabellerrConnection":
44-
# For subclasses, use normal instantiation
51+
# For subclasses, use normal instantiation with connection_data
4552
instance = cls.__new__(cls)
4653
if isinstance(instance, cls):
4754
instance.__init__(client, connection_id, **kwargs)
4855
return instance
49-
connection_data = cls.get_connection(client, connection_id)
50-
if connection_data is None:
51-
raise InvalidConnectionError(f"Connection not found: {connection_id}")
56+
57+
# Factory behavior for base class
58+
connection_data = kwargs["connection_data"]
5259
connector = connection_data.get("connector")
5360
connection_class = cls._registry.get(connector)
5461
if connection_class is None:
5562
raise InvalidConnectionError(f"Unknown connector type: {connector}")
56-
kwargs["connection_data"] = connection_data
5763
return connection_class(client, connection_id, **kwargs)
5864

5965

labellerr/core/connectors/gcs_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_connection(
6666
@staticmethod
6767
def create_connection(
6868
client: "LabellerrClient", params: GCSConnectionParams
69-
) -> str:
69+
) -> "LabellerrConnection":
7070
"""
7171
Sets up GCP connector for dataset creation (quick connection).
7272

labellerr/core/connectors/s3_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_connection(
6565
@staticmethod
6666
def create_connection(
6767
client: "LabellerrClient", params: AWSConnectionParams
68-
) -> dict:
68+
) -> "LabellerrConnection":
6969
"""
7070
Creates an AWS S3 connection.
7171

labellerr/core/projects/base.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,65 @@ def data_type(self):
102102
def attached_datasets(self):
103103
return self.__project_data.get("attached_datasets")
104104

105+
def status(self) -> Dict:
106+
"""
107+
Poll project status until completion or timeout.
108+
109+
Returns:
110+
Final project data with status information
111+
112+
Examples:
113+
# Get current project status
114+
final_status = project.status()
115+
"""
116+
from .utils import poll
117+
118+
def get_project_status():
119+
unique_id = str(uuid.uuid4())
120+
url = (
121+
f"{constants.BASE_URL}/projects/project/{self.project_id}?client_id={self.client.client_id}"
122+
f"&uuid={unique_id}"
123+
)
124+
125+
response = self.client.make_request(
126+
"GET",
127+
url,
128+
extra_headers={"content-type": "application/json"},
129+
request_id=unique_id,
130+
)
131+
project_data = response.get("response", {})
132+
if project_data:
133+
self.__project_data = project_data
134+
return project_data
135+
136+
def is_completed(project_data):
137+
status_code = project_data.get("status_code", 500)
138+
# Consider project complete when status_code is 300 (success) or >= 400 (error/failed)
139+
return status_code == 300 or status_code >= 400
140+
141+
def on_success(project_data):
142+
status_code = project_data.get("status_code", 500)
143+
if status_code == 300:
144+
logging.info(
145+
"Project %s processing completed successfully!", self.project_id
146+
)
147+
else:
148+
logging.warning(
149+
"Project %s processing finished with status code: %s",
150+
self.project_id,
151+
status_code,
152+
)
153+
return project_data
154+
155+
return poll(
156+
function=get_project_status,
157+
condition=is_completed,
158+
interval=2.0,
159+
timeout=None,
160+
max_retries=None,
161+
on_success=on_success,
162+
)
163+
105164
def detach_dataset_from_project(self, dataset_id=None, dataset_ids=None):
106165
"""
107166
Detaches one or more datasets from an existing project.
@@ -568,3 +627,19 @@ def __fetch_exports_download_url(self, project_id, uuid, export_id, client_id):
568627
return response.get("response")
569628
except Exception as e:
570629
raise LabellerrError(f"Failed to download export: {str(e)}")
630+
631+
def import_users(self, from_project: "LabellerrProject"):
632+
"""
633+
Imports users from a source project to the current project.
634+
635+
:param from_project: The source project to import users from
636+
:return: Dictionary containing import results
637+
:raises LabellerrError: If the import fails
638+
"""
639+
# Validate parameters using Pydantic
640+
unique_id = str(uuid.uuid4())
641+
url = f"{constants.BASE_URL}/users/projects/import_users?selected_project_id={self.project_id}&project_id={from_project.project_id}&client_id={self.client.client_id}&uuid={unique_id}"
642+
response = self.client.make_request(
643+
"POST", url, extra_headers={"Content-Type": "application/json"}
644+
)
645+
return response.get("response")

0 commit comments

Comments
 (0)