Skip to content

Commit

Permalink
refactor: code refactor due to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Dec 14, 2021
1 parent 2d52593 commit cbac091
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion embeddings/pipeline/lightning_classification.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union, Sequence
from typing import Any, Dict, Optional, Sequence, Union

import datasets
import numpy as np
Expand Down
13 changes: 6 additions & 7 deletions embeddings/task/lightning_task/lightning_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,19 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[Any]]:
)

if self.hparams.use_scheduler:
lr_scheduler = self.configure_scheduler(optimizer=optimizer)
lr_schedulers = self.configure_schedulers(optimizer=optimizer)
else:
lr_scheduler = []
lr_schedulers = []

return [optimizer], lr_scheduler
return [optimizer], lr_schedulers

def configure_scheduler(self, optimizer: Optimizer) -> List[Dict[str, Any]]:
def configure_schedulers(self, optimizer: Optimizer) -> List[Dict[str, Any]]:
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=self.hparams.warmup_steps,
num_training_steps=self.total_steps,
)
self.lr_scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return [self.lr_scheduler]
return [{"scheduler": scheduler, "interval": "step", "frequency": 1}]


class HuggingFaceLightningTask(LightningTask[AutoModel], abc.ABC):
Expand All @@ -143,7 +142,7 @@ def __init__(
self.config_kwargs = config_kwargs if config_kwargs else {}

def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit":
if stage in ["fit", None]:
self.configure_model()
self.configure_metrics()
if self.hparams.use_scheduler:
Expand Down
3 changes: 2 additions & 1 deletion embeddings/task/lightning_task/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
self.train_metrics(preds, batch["labels"])
self.log("train/Loss", loss)
if self.hparams.use_scheduler:
last_lr = self.lr_scheduler["scheduler"].get_last_lr()
assert self.trainer is not None
last_lr = self.trainer.lr_schedulers[0]["scheduler"].get_last_lr()
self.log("train/BaseLR", last_lr[0], prog_bar=True)
self.log("train/LambdaLR", last_lr[1], prog_bar=True)
return {"loss": loss}
Expand Down

0 comments on commit cbac091

Please sign in to comment.