A thin PyTorch-Lightning wrapper for building configuration-based DL pipelines with Hydra.
- PyTorch Lightning
- Hydra
- Pydantic
- etc.
```shell
$ pip install -U hiraishin
```
Define a model class whose training components are declared with type annotations.

```python
import torch.nn as nn
import torch.optim as optim
from omegaconf import DictConfig

from hiraishin.models import BaseModel


class ToyModel(BaseModel):

    net: nn.Linear
    criterion: nn.CrossEntropyLoss
    optimizer: optim.Adam
    scheduler: optim.lr_scheduler.ExponentialLR

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)
```
Components whose names have the following prefixes are instantiated by their own role-specific logic:

- `net`
- `criterion`
- `optimizer`
- `scheduler`
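For illustration only, the prefix-to-role mapping can be sketched in plain Python. The `role_of` helper below is hypothetical and is not hiraishin's actual implementation; it just shows the idea of dispatching annotated attribute names by prefix.

```python
# Hypothetical sketch of prefix-based role dispatch; hiraishin's real
# logic lives in BaseModel and may differ.
ROLE_PREFIXES = ("net", "criterion", "optimizer", "scheduler")


def role_of(name: str) -> str:
    """Return the training role an annotated attribute name maps to."""
    for prefix in ROLE_PREFIXES:
        if name == prefix or name.startswith(prefix + "_"):
            return prefix
    return "other"  # plain components (e.g. tokenizers) and constants


print(role_of("net_encoder"))  # net
print(role_of("tokenizer"))    # other
```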
The same notation can be used to define components other than the training components listed above (e.g., tokenizers). Constants of built-in types can also be declared, as long as they are YAML-serializable.
```python
class ToyModel(BaseModel):

    net: nn.Linear
    criterion: nn.CrossEntropyLoss
    optimizer: optim.Adam
    scheduler: optim.lr_scheduler.ExponentialLR

    # additional components and constants
    tokenizer: MyTokenizer
    n_classes: int

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)
```
Hiraishin provides a CLI command that generates a configuration file from the type annotations. For example, if `ToyModel` is defined in `models.py` (i.e., `from models import ToyModel` is executable in your code), the following command generates its configuration file automatically.
```shell
$ hiraishin generate models.ToyModel --output_dir config/model
The config has been generated! --> config/model/ToyModel.yaml
```
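The argument is an ordinary dotted import path. As a rough sketch of how such a path can be resolved with the standard library (the `locate` helper is illustrative, not hiraishin's code):

```python
import importlib


def locate(dotted: str):
    """Resolve a 'package.module.ClassName' path to the named object."""
    module_path, _, attr = dotted.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr)


# Works for any importable object, e.g. a stdlib class:
print(locate("collections.OrderedDict").__name__)  # OrderedDict
```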
Let's take a look at the generated file.
```yaml
_target_: models.ToyModel
_recursive_: false
config:
  networks:
    net:
      args:
        _target_: torch.nn.Linear
        out_features: ???
        in_features: ???
      weights:
        initializer: null
        path: null
  losses:
    criterion:
      args:
        _target_: torch.nn.CrossEntropyLoss
      weight: 1.0
  optimizers:
    optimizer:
      args:
        _target_: torch.optim.Adam
        params:
          - ???
      scheduler:
        args:
          _target_: torch.optim.lr_scheduler.ExponentialLR
          gamma: ???
        interval: epoch
        frequency: 1
        strict: true
        monitor: null
  tokenizer:
    _target_: MyTokenizer
  n_classes: ???
```
First of all, the file is compliant with instantiation via `hydra.utils.instantiate`. Positional arguments are filled with `???`, which marks mandatory parameters; override them with the values you want to set.
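A quick way to see which values still need attention is to walk the config and collect every `???`. The `find_missing` helper below is a pure-Python sketch for illustration; OmegaConf itself raises `MissingMandatoryValue` when such a value is accessed.

```python
# Illustrative only: locate "???" placeholders in a plain-dict config.
def find_missing(cfg: dict, path: str = "") -> list:
    missing = []
    for key, value in cfg.items():
        where = f"{path}.{key}" if path else str(key)
        if isinstance(value, dict):
            missing.extend(find_missing(value, where))
        elif value == "???":
            missing.append(where)
    return missing


cfg = {"net": {"args": {"in_features": "???", "out_features": 10}}}
print(find_missing(cfg))  # ['net.args.in_features']
```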
The rest of the model definition is just your training routine, written in the usual PyTorch Lightning style.
```python
import torch


class ToyModel(BaseModel):

    ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def training_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, target = batch
        pred = self.forward(x)
        loss = self.criterion(pred, target)
        self.log('loss/train', loss)
        return loss
```
The defined model can be instantiated from the configuration file. Let's train your models!

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf


def app():
    ...
    config = OmegaConf.load('config/model/ToyModel.yaml')
    model = instantiate(config)
    print(model)
    # ToyModel(
    #     (net): Linear(in_features=1, out_features=1, bias=True)
    #     (criterion): CrossEntropyLoss()
    # )
    trainer.fit(model, ...)
```
Trained models can be loaded easily from the checkpoints that PyTorch Lightning saves by default. Let's test your models!

```python
from hiraishin.utils import load_from_checkpoint

model = load_from_checkpoint('path/to/model.ckpt')
print(model)
# ToyModel(
#     (net): Linear(in_features=1, out_features=1, bias=True)
#     (criterion): CrossEntropyLoss()
# )
```
Hiraishin is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.