Skip to content

Commit

Permalink
Integration init add mode param(#596)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeyi-Lin committed May 31, 2024
1 parent e157f3d commit 935f1f7
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 22 deletions.
9 changes: 5 additions & 4 deletions swanlab/integration/fastai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"This module requires `fastai` to be installed. " "Please install it with command: \n pip install fastai"
)

from typing import Optional
from typing import Optional, Any
import swanlab
from swanlab.log import swanlog as swl

Expand All @@ -24,8 +24,9 @@ def __init__(
description: Optional[str] = None,
workspace: Optional[str] = None,
config: Optional[dict] = None,
cloud: Optional[bool] = True,
mode: Optional[str] = None,
logdir: Optional[str] = None,
**kwargs: Any,
):
store_attr()
self._experiment = swanlab
Expand All @@ -35,7 +36,7 @@ def __init__(
self.workspace = workspace
self.config = config
self.description = description
self.cloud = cloud
self.mode = mode
self.logdir = logdir
self.train_suffix = "train"
self.summary_suffix = "summary"
Expand All @@ -48,7 +49,7 @@ def setup_swanlab(self):
experiment_name=self.experiment_name,
description=self.description,
config=self.config,
cloud=self.cloud,
mode=self.mode,
logdir=self.logdir,
)

Expand Down
6 changes: 3 additions & 3 deletions swanlab/integration/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
cloud: Optional[bool] = True,
mode: Optional[str] = None,
**kwargs: Any,
):
self._initialized = False
Expand All @@ -50,7 +50,7 @@ def __init__(
"experiment_name": experiment_name,
"description": description,
"logdir": logdir,
"cloud": cloud,
"mode": mode,
}

self._swanlab_init.update(**kwargs)
Expand All @@ -60,7 +60,7 @@ def __init__(
self._experiment_name = self._swanlab_init.get("experiment_name")
self._description = self._swanlab_init.get("decsription")
self._logdir = self._swanlab_init.get("logdir")
self._cloud = self._swanlab_init.get("cloud")
self._mode = self._swanlab_init.get("mode")

def setup(self, args, state, model, **kwargs):
self._initialized = True
Expand Down
4 changes: 2 additions & 2 deletions swanlab/integration/integration_utils/autologging.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
resolver: ArgumentResponseResolver,
client=None,
lib_version=None,
cloud: bool = True,
mode: str = None,
) -> None:
"""Autolog API calls to SwanLab."""

Expand All @@ -188,7 +188,7 @@ def __init__(
)
self._name = self._patch_api.name
self._run: Optional[SwanLabRun] = None
self.cloud = cloud
self.mode = mode
self.client: openai.Client = client
self.lib_version = lib_version

Expand Down
2 changes: 1 addition & 1 deletion swanlab/integration/mmengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"experiment_name": "YourExperiment", # experiment name on swanlab
"description": "have fun", # experiment description (can be null)
"workspace": "YourOrganization", # Your Organization on swanlab
# "cloud": False, # Upload to cloud
# "mode": "cloud", # Upload to cloud
},
),
]
Expand Down
6 changes: 3 additions & 3 deletions swanlab/integration/pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
cloud: Optional[bool] = True,
mode: Optional[str] = None,
save_dir: Union[str, Path] = ".",
**kwargs: Any,
):
Expand All @@ -69,7 +69,7 @@ def __init__(
"experiment_name": experiment_name,
"description": description,
"logdir": logdir,
"cloud": cloud,
"mode": mode,
}

self._swanlab_init.update(**kwargs)
Expand All @@ -79,7 +79,7 @@ def __init__(
self._experiment_name = self._swanlab_init.get("experiment_name")
self._description = self._swanlab_init.get("decsription")
self._logdir = self._swanlab_init.get("logdir")
self._cloud = self._swanlab_init.get("cloud")
self._mode = self._swanlab_init.get("mode")

if save_dir is not None:
save_dir = os.fspath(save_dir)
Expand Down
6 changes: 3 additions & 3 deletions swanlab/integration/sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
cloud: Optional[bool] = True,
mode: Optional[bool] = None,
verbose: int = 0,
**kwargs: Any,
):
Expand All @@ -83,7 +83,7 @@ def __init__(
"experiment_name": experiment_name,
"description": description,
"logdir": logdir,
"cloud": cloud,
"mode": mode,
}

self._swanlab_init.update(**kwargs)
Expand All @@ -93,7 +93,7 @@ def __init__(
self._experiment_name = self._swanlab_init.get("experiment_name")
self._description = self._swanlab_init.get("decsription")
self._logdir = self._swanlab_init.get("logdir")
self._cloud = self._swanlab_init.get("cloud")
self._mode = self._swanlab_init.get("mode")

def _init_callback(self) -> None:
args = {"algo": type(self.model).__name__}
Expand Down
71 changes: 65 additions & 6 deletions swanlab/integration/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,45 @@ def add_integration_callbacks(instance):
from ultralytics.models import YOLO
from ultralytics.utils.torch_utils import model_info_for_loggers
from collections import Counter
from typing import Optional, Dict, Any
import swanlab


_processed_plots = {}


class UltralyticsSwanlabCallback:
def __init__(self) -> None:
def __init__(
self,
project: Optional[str] = None,
workspace: Optional[str] = None,
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
mode: Optional[bool] = None,
**kwargs: Any,
) -> None:
self.step_counter = Counter()
self._run = None

self._swanlab_init: Dict[str, Any] = {
"project": project,
"workspace": workspace,
"experiment_name": experiment_name,
"description": description,
"logdir": logdir,
"mode": mode,
}

self._swanlab_init.update(**kwargs)

self._project = self._swanlab_init.get("project")
self._workspace = self._swanlab_init.get("workspace")
self._experiment_name = self._swanlab_init.get("experiment_name")
self._description = self._swanlab_init.get("decsription")
self._logdir = self._swanlab_init.get("logdir")
self._mode = self._swanlab_init.get("mode")

def _log_plots(self, plots: dict, step: int, tag: str):
"""记录指标绘图和推理图像"""
image_list = []
Expand All @@ -67,9 +95,13 @@ def on_pretrain_routine_start(self, trainer):
"""初始化实验记录器"""
if swanlab.get_run() is None:
self._run = swanlab.init(
project=trainer.args.project,
experiment_name=trainer.args.name,
project=trainer.args.project if self._project is None else self._project,
workspace=self._workspace,
experiment_name=trainer.args.name if self._experiment_name is None else self._experiment_name,
config=vars(trainer.args),
description=self._description,
logdir=self._logdir,
mode=self._mode,
)
else:
self._run = swanlab.get_run()
Expand Down Expand Up @@ -112,8 +144,21 @@ def on_train_end(self, trainer):

def add_swanlab_callback(
model: YOLO,
project: Optional[str] = None,
workspace: Optional[str] = None,
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
mode: Optional[bool] = None,
):
ultralytics_swanlabcallback = UltralyticsSwanlabCallback()
ultralytics_swanlabcallback = UltralyticsSwanlabCallback(
project=project,
workspace=workspace,
experiment_name=experiment_name,
description=description,
logdir=logdir,
mode=mode,
)

"""给Ultralytics模型添加swanlab回调函数"""
callbacks = {
Expand All @@ -129,8 +174,22 @@ def add_swanlab_callback(
return model


def return_swanlab_callback():
ultralytics_swanlabcallback = UltralyticsSwanlabCallback()
def return_swanlab_callback(
project: Optional[str] = None,
workspace: Optional[str] = None,
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
mode: Optional[bool] = None,
):
ultralytics_swanlabcallback = UltralyticsSwanlabCallback(
project=project,
workspace=workspace,
experiment_name=experiment_name,
description=description,
logdir=logdir,
mode=mode,
)

callbacks = (
{
Expand Down

0 comments on commit 935f1f7

Please sign in to comment.