Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ module = [
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.mlflow",
"pytorch_lightning.loggers.neptune",
"pytorch_lightning.loggers.tensorboard",
"pytorch_lightning.loggers.wandb",
"pytorch_lightning.profilers.advanced",
"pytorch_lightning.profilers.base",
Expand Down
18 changes: 9 additions & 9 deletions src/pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
sub_dir: Optional[str] = None,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs,
**kwargs: Any,
):
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
self._save_dir = save_dir
Expand All @@ -108,8 +108,8 @@ def __init__(
self._prefix = prefix
self._fs = get_filesystem(save_dir)

self._experiment = None
self.hparams = {}
self._experiment: Optional["SummaryWriter"] = None
self.hparams: Union[Dict[str, Any], Namespace] = {}
self._kwargs = kwargs

@property
Expand Down Expand Up @@ -138,7 +138,7 @@ def log_dir(self) -> str:
return log_dir

@property
def save_dir(self) -> Optional[str]:
def save_dir(self) -> str:
"""Gets the save directory where the TensorBoard experiments are saved.

Returns:
Expand All @@ -155,7 +155,7 @@ def sub_dir(self) -> Optional[str]:
"""
return self._sub_dir

@property
@property # type: ignore[misc]
@rank_zero_experiment
def experiment(self) -> SummaryWriter:
r"""
Expand Down Expand Up @@ -236,7 +236,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
raise ValueError(m) from ex

@rank_zero_only
def log_graph(self, model: "pl.LightningModule", input_array=None):
def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None:
if self._log_graph:
if input_array is None:
input_array = model.example_input_array
Expand Down Expand Up @@ -281,7 +281,7 @@ def name(self) -> str:
return self._name

@property
def version(self) -> int:
def version(self) -> Union[int, str]:
"""Get the experiment version.

Returns:
Expand All @@ -291,7 +291,7 @@ def version(self) -> int:
self._version = self._get_next_version()
return self._version

def _get_next_version(self):
def _get_next_version(self) -> int:
root_dir = self.root_dir

try:
Expand All @@ -318,7 +318,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
# logging of arrays with dimension > 1 is not supported, sanitize as string
return {k: str(v) if isinstance(v, (Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_experiment"] = None
return state