Skip to content

Commit

Permalink
Shift from tqdm to rich
Browse files Browse the repository at this point in the history
  • Loading branch information
rasswanth-s committed Aug 4, 2024
1 parent 61a56cb commit f44c200
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 32 deletions.
1 change: 0 additions & 1 deletion packages/syft/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ syft =
requests==2.32.3
RestrictedPython==7.0
result==0.16.1
tqdm==4.66.4
typeguard==4.1.5
typing_extensions==4.12.0
sherlock[filelock]==0.4.1
Expand Down
74 changes: 53 additions & 21 deletions packages/syft/src/syft/client/datasite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
from pathlib import Path
import re
from string import Template
import time
import traceback
from typing import TYPE_CHECKING
from typing import cast

# third party
import markdown
from result import Result
from tqdm import tqdm
from rich.console import Console
from rich.progress import BarColumn
from rich.progress import Progress
from rich.progress import SpinnerColumn
from rich.progress import TextColumn

# relative
from ..abstract_server import ServerSideType
Expand Down Expand Up @@ -119,19 +124,24 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError:
model.code_action_id = model_code_res.id
model_ref_action_ids.append(model_code_res.id)

console = Console()
# Step 2. Upload Model Assets to Action Store

model_size: float = 0.0
with tqdm(
total=len(model.asset_list), colour="green", desc="Uploading"
) as pbar:
with Progress(
SpinnerColumn(),
"[progress.description]{task.description}",
TextColumn("[progress.remaining]{task.completed}/{task.total}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
) as progress:
task = progress.add_task("Uploading...", total=len(model.asset_list))

for asset in model.asset_list:
try:
contains_empty: bool = asset.contains_empty()
# TODO: Add mock model weights
twin = TwinObject(
private_obj=ActionObject.from_obj(
asset.data
), # same on both for now
private_obj=ActionObject.from_obj(asset.data),
mock_obj=ActionObject.from_obj(asset.data),
syft_server_location=self.id,
syft_client_verify_key=self.verify_key,
Expand All @@ -140,19 +150,23 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError:
if isinstance(res, SyftError):
return res
except Exception as e:
tqdm.write(f"Failed to create twin for {asset.name}. {e}")
console.print(
f"Failed to create twin for {asset.name}. {e}", style="bold red"
)
return SyftError(message=f"Failed to create twin. {e}")

if isinstance(res, SyftWarning):
logger.debug(res.message)
# Clear Cache before saving

twin.private_obj._clear_cache()
twin.mock_obj._clear_cache()
response = self.api.services.action.set(
twin, ignore_detached_objs=contains_empty
)
if isinstance(response, SyftError):
tqdm.write(f"Failed to upload asset: {asset.name}")
console.print(
f"Failed to upload asset: {asset.name}", style="bold red"
)
return response

asset.action_id = twin.id
Expand All @@ -164,9 +178,12 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError:
asset.data = None
asset.mock = None

# Update the progress bar and set the dynamic description
pbar.set_description(f"Uploading: {asset.name}")
pbar.update(1)
progress.update(
task, advance=1, description=f"Uploading : {asset.name}"
)
# When the last asset is uploaded, the progress bar is not updated
# so add a small delay to ensure the progress bar is updated
time.sleep(0.1)

# Step 3. Upload Model Ref to Action Store
# Model Ref is a reference to the model code and assets
Expand Down Expand Up @@ -225,9 +242,16 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
)
prompt_warning_message(message=message, confirm=True)

with tqdm(
total=len(dataset.asset_list), colour="green", desc="Uploading"
) as pbar:
console = Console()
with Progress(
SpinnerColumn(),
"[progress.description]{task.description}",
TextColumn("[progress.remaining]{task.completed}/{task.total}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
) as progress:
task = progress.add_task("Uploading...", total=len(dataset.asset_list))

for asset in dataset.asset_list:
try:
contains_empty: bool = asset.contains_empty()
Expand All @@ -241,7 +265,9 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
if isinstance(res, SyftError):
return res
except Exception as e:
tqdm.write(f"Failed to create twin for {asset.name}. {e}")
console.print(
f"Failed to create twin for {asset.name}. {e}", style="bold red"
)
return SyftError(message=f"Failed to create twin. {e}")

if isinstance(res, SyftWarning):
Expand All @@ -250,16 +276,22 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError:
twin, ignore_detached_objs=contains_empty
)
if isinstance(response, SyftError):
tqdm.write(f"Failed to upload asset: {asset.name}")
console.print(
f"Failed to upload asset: {asset.name}", style="bold red"
)
return response

asset.action_id = twin.id
asset.server_uid = self.id
dataset_size += get_mb_size(asset.data)

# Update the progress bar and set the dynamic description
pbar.set_description(f"Uploading: {asset.name}")
pbar.update(1)
progress.update(
task, advance=1, description=f"Uploading : {asset.name}"
)
# When the last asset is uploaded, the progress bar is not updated
# so add a small delay to ensure the progress bar is updated
time.sleep(0.1)

dataset.mb_size = dataset_size
valid = dataset.check()
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,6 @@ def _save_to_blob_storage_(
serialized = serialize(data, to_bytes=True)
size = sys.getsizeof(serialized)
storage_entry = CreateBlobStorageEntry.from_obj(data, file_size=size)
print("storage entry id", storage_entry.id)
if not TraceResultRegistry.current_thread_is_tracing():
self.syft_action_data_cache = self.as_empty_data()
if self.syft_blob_storage_entry_id is not None:
Expand Down
2 changes: 0 additions & 2 deletions packages/syft/src/syft/service/model/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def add(
"""Add a model"""
model = model.to(Model, context=context)

print("got model", model)

result = self.stash.set(
context.credentials,
model,
Expand Down
22 changes: 15 additions & 7 deletions packages/syft/src/syft/store/blob_storage/seaweedfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
from botocore.client import Config
from botocore.exceptions import ConnectionError
import requests
from rich.progress import BarColumn
from rich.progress import Progress
from rich.progress import SpinnerColumn
from rich.progress import TextColumn
from tenacity import retry
from tenacity import retry_if_exception_type
from tenacity import stop_after_delay
from tenacity import wait_fixed
from tqdm import tqdm
from typing_extensions import Self

# relative
Expand Down Expand Up @@ -82,11 +85,16 @@ def write(self, data: BytesIO) -> SyftSuccess | SyftError:
# this is the total nr of chunks in all parts
total_iterations = math.ceil(part_size / chunk_size) * len(self.urls)

with tqdm(
total=total_iterations,
desc=f"Uploading progress", # noqa
colour="green",
) as pbar:
with Progress(
SpinnerColumn(),
"[progress.description]{task.description}",
TextColumn("[progress.remaining]{task.completed}/{task.total}"),
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
) as progress:
task = progress.add_task(
"Uploading to Blob Storage ...", total=total_iterations
)
for part_no, url in enumerate(
self.urls,
start=1,
Expand Down Expand Up @@ -120,7 +128,7 @@ def async_generator(
item = item_queue.get()
while item != 0:
yield item
pbar.update(1)
progress.update(task, advance=1)
item = item_queue.get()

def add_chunks_to_queue(
Expand Down

0 comments on commit f44c200

Please sign in to comment.