Refactor huggingface/task/text_classification #35
Conversation
@carmocca usually `IndexError: Target 4 is out of bounds` means that the num classes is not being passed correctly to the input config when making the model.
Co-authored-by: SeanNaren <sean@grid.ai>
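A minimal sketch of that fix, assuming the standard Hugging Face transformers API (the model name and label count here are illustrative, not taken from this PR): setting `num_labels` on the config before building the model gives the classification head the right output size.

```python
# Illustrative sketch, not the PR's code: pass the number of classes to the
# model config so the classification head matches the dataset's label count.
from transformers import AutoConfig, AutoModelForSequenceClassification

num_labels = 6  # e.g. the `emotion` dataset has 6 classes
config = AutoConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", config=config)
```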
I just realised I don't see any specific code around Hydra for type-checking. Any suggestions for that? Would we need to introduce the config store? I also don't know how defaults handling would work; i.e. if a user adds a default to their dataclass, you shouldn't need to specify it in your Hydra config, for example.
I guess. Currently it's all duck typing without any type-checking. I'd have to play around with Hydra. Looking at https://hydra.cc/docs/tutorials/structured_config/minimal_example/. Is the `ConfigStore` a global instance?
I totally agree. Ideally, the dataclasses would be the source of truth and the config files instances of them. Do you want to try to solve these challenges here or in a separate PR?
Great cleanup!
Up to you; either way, we should solve this ASAP to ensure that the foundation is neat before heading onwards.
It's not as simple, which is why I wasn't sure the dataclass approach would work. We would define in our ConfigStore (which I think is a global instance, as you suggested):

```python
cs = ConfigStore.instance()
# Registering the config class with the name 'hf_data_config'.
cs.store(name="hf_data_config", node=HFTransformerDataConfig)
```

And then within our dataset conf file:

```yaml
# @package dataset
defaults:
  - /hf_data_config  # defines the structured config defaults we'd like to use
_target_: lightning_transformers.huggingface.task.text_classification.TextClassificationDataModule
cfg:
  dataset_name: emotion
  train_file: ''
  validation_file: ''
  padding: 'max_length'
  truncation: 'only_first'
  preprocessing_num_workers: 8
  load_from_cache_file: True
  max_length: 128
  # torch data-loader specific arguments
  batch_size: ${training.batch_size}
  num_workers: ${training.num_workers}
```

But this doesn't make sense. What this means is that you'd like to populate your config with defaults from the structured config "hf_data_config". This would throw an error, since keys like `_target_` and `cfg` are not part of `HFTransformerDataConfig`. If you were then to modify the dataclass to:

```python
@dataclass
class HFTransformerDataConfig:
    _target_: str
    dataset_name: Optional[str] = None
    train_val_split: Optional[int] = None
    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    padding: str = "max_length"
    truncation: str = "only_first"
    max_length: int = 128
    preprocessing_num_workers: int = 8
    load_from_cache_file: bool = True
    dataset_config_name: Optional[str] = None
```

and the config to not contain the original `cfg` key:

```yaml
# @package dataset
defaults:
  - /hf_data_config
_target_: lightning_transformers.huggingface.task.text_classification.TextClassificationDataModule
dataset_name: emotion
train_file: ''
validation_file: ''
padding: 'max_length'
truncation: 'only_first'
preprocessing_num_workers: 8
load_from_cache_file: True
max_length: 128
# torch data-loader specific arguments
batch_size: ${training.batch_size}
num_workers: ${training.num_workers}
```

This still won't work, as you'd be calling instantiate and sending all these parameters as positional/kwarg arguments, not as a config. I think this would work, however:

```python
@dataclass
class ActualConfig:
    dataset_name: Optional[str] = None
    train_val_split: Optional[int] = None
    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    padding: str = "max_length"
    truncation: str = "only_first"
    max_length: int = 128
    preprocessing_num_workers: int = 8
    load_from_cache_file: bool = True
    dataset_config_name: Optional[str] = None


@dataclass
class HFTransformerDataConfig:
    _target_: str
    cfg: ActualConfig
```

```yaml
# @package dataset
defaults:
  - /hf_data_config
_target_: lightning_transformers.huggingface.task.text_classification.TextClassificationDataModule
cfg:
  dataset_name: emotion
  train_file: ''
  validation_file: ''
  padding: 'max_length'
  truncation: 'only_first'
  preprocessing_num_workers: 8
  load_from_cache_file: True
  max_length: 128
  # torch data-loader specific arguments
  batch_size: ${training.batch_size}
  num_workers: ${training.num_workers}
```

But now we have to define an additional dataclass.
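For context, a minimal sketch of why the wrapper helps (the names follow the snippets above; the call itself is standard Hydra): with `_target_` plus a single `cfg` key, `hydra.utils.instantiate` passes the whole config as one object instead of splatting each field as a separate kwarg.

```python
from hydra.utils import instantiate


def build_datamodule(cfg):
    # cfg.dataset is assumed to be the node from the yaml above: _target_
    # plus a single `cfg` key. Hydra resolves `_target_` and passes the
    # remaining keys as kwargs, i.e. roughly
    # TextClassificationDataModule(cfg=cfg.dataset.cfg)
    # -- one config object, not individual fields.
    return instantiate(cfg.dataset)
```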
Hey @carmocca, I landed here after hacking around with the minimal example that Hydra offers:

```python
from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf


@dataclass
class DBConfig:
    driver: str = MISSING
    port: int = MISSING
    host: str = "localhost"


@dataclass
class MySQLConfig(DBConfig):
    user: str = MISSING
    password: str = MISSING
    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    user: str = MISSING
    password: str = MISSING
    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10


class PostGreSQLClass:
    def __init__(self, cfg: PostGreSQLConfig):
        self.cfg = cfg


class MySqlClass:
    def __init__(self, cfg: MySQLConfig):
        self.cfg = cfg


@dataclass
class Config:
    db: DBConfig = MISSING


cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(
    group="db",
    name="base_mysql",
    node={'_target_': '__main__.MySqlClass', 'cfg': MySQLConfig},
)
cs.store(
    group="db",
    name="base_postgresql",
    node={'_target_': '__main__.PostGreSQLClass', 'cfg': PostGreSQLConfig},
)


@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: Config) -> None:
    print(OmegaConf.to_yaml(cfg))
    print(hydra.utils.instantiate(cfg.db))


if __name__ == "__main__":
    my_app()
```

This is the primary thing: adding `_target_` + `cfg` explicitly in the config store.

```python
cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(
    group="db",
    name="base_mysql",
    node={'_target_': '__main__.MySqlClass', 'cfg': MySQLConfig},
)
cs.store(
    group="db",
    name="base_postgresql",
    node={'_target_': '__main__.PostGreSQLClass', 'cfg': PostGreSQLConfig},
)
```

It allows you to add the stored node as a default to the conf, and it adds the dataclass defaults to your conf. I don't know if this is the cleanest way to do it, but I've pinged Omry (the Hydra author) to see if there is a cleaner way to do this!
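For reference, a usage sketch: given `config_path="conf"` above, the primary config would live at `conf/config.yaml` (this layout is an assumption, mirroring the Hydra structured-config tutorial rather than anything in this PR) and pull in the stored node through the defaults list:

```yaml
# conf/config.yaml -- assumed layout, not taken from this PR
defaults:
  - base_config      # the schema stored under name="base_config"
  - db: base_mysql   # selects the db group node, bringing in the MySQLConfig defaults
```

Composing this way fills `cfg.db` with the `_target_` and dataclass defaults; the `MISSING` fields (e.g. `db.cfg.user`) can then be overridden on the command line before `hydra.utils.instantiate(cfg.db)` builds a `MySqlClass`.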
Really nice work @carmocca! Overall, I realised the instantiator is necessary (the alternative being to drop the instantiator and hard-code hydra.utils.instantiate everywhere). I also like how it's injected into the codebase as a dependency.
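A rough sketch of that dependency-injection idea (the class and method names here are illustrative, not the actual lightning-transformers API): tasks depend on a small instantiator interface instead of importing Hydra directly, so the Hydra dependency stays at the edges of the codebase.

```python
# Illustrative sketch of the idea only; names are not the actual
# lightning-transformers API.
import hydra.utils


class HydraInstantiator:
    """Thin wrapper so the rest of the codebase never imports Hydra directly."""

    def instantiate(self, cfg, **kwargs):
        return hydra.utils.instantiate(cfg, **kwargs)


class SomeTask:
    def __init__(self, instantiator, model_cfg):
        # The instantiator arrives as an injected dependency; swapping it out
        # (e.g. for tests) requires no changes to the task itself.
        self.model = instantiator.instantiate(model_cfg)
```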
Follow-up on #24
Related to #37
TODO:

```
ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
IndexError: Target 4 is out of bounds.
```