This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Refactor huggingface/task/text_classification #35

Merged
merged 15 commits into master from refactor-text-classification
Jan 12, 2021

Conversation

carmocca
Contributor

@carmocca carmocca commented Jan 10, 2021

Follow-up on #24
Related to #37

TODO:

  • Fix: ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index); IndexError: Target 4 is out of bounds.
  • Discuss conf structure:
    • Removed conf/backbone, we can always add it if useful
    • Added conf/module. This links the module used by the task with the module arguments (optimizer, scheduler...)
    • conf/task/huggingface/*.yaml now contains the backbone definition and a reference to the module config
    • conf/tokenizer: uses the current task backbone by default

conf/trainer/default.yaml
@SeanNaren
Contributor

@carmocca usually:

ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index); IndexError: Target 4 is out of bounds.

means that the number of classes is not being passed correctly into the config when building the model.
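
A quick illustrative repro (not from this PR), with a head built for only 4 classes but a label of 4 present in the batch:

import torch
import torch.nn.functional as F

# the classification head only has 4 output classes...
logits = torch.randn(3, 4)
# ...but the dataset contains a label with index 4 (needs num_classes >= 5)
targets = torch.tensor([0, 2, 4])

# raises the error above via torch._C._nn.nll_loss:
# IndexError: Target 4 is out of bounds.
loss = F.cross_entropy(logits, targets)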

train.py
@carmocca carmocca marked this pull request as ready for review January 10, 2021 20:45
train.py
Co-authored-by: SeanNaren <sean@grid.ai>
@SeanNaren
Contributor

I just realised I don't see any specific code around Hydra for type-checking. Any suggestions for that? Would we need to introduce the config store?

I also don't know how defaults handling would work; i.e. if a user adds a default to their dataclass, they shouldn't need to specify it in their Hydra config, for example.

@carmocca
Contributor Author

carmocca commented Jan 11, 2021

I just realised I don't see any specific code around Hydra for type-checking. Any suggestions for that? Would we need to introduce the config store?

I guess so. Currently it is all duck typing without any type-checking.

I'd have to play around with Hydra. Looking at https://hydra.cc/docs/tutorials/structured_config/minimal_example/

Is the ConfigStore a global instance shared for the process? Can we create it and register the dataclass within each Instantiator function? If so, we could add _target_ attributes to our config to instantiate their corresponding dataclasses as soon as possible instead of passing DictConfigs around. Would need some experimenting.
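
Something like this is what I'm imagining (purely a sketch with hypothetical names, would need experimenting):

from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


# hypothetical schema, not the actual lightning-transformers dataclass
@dataclass
class HFTransformerDataConfig:
    dataset_name: str = "emotion"
    max_length: int = 128


def register_data_schema() -> None:
    # ConfigStore.instance() returns a process-wide singleton, so calling
    # this from inside each instantiator registers into the same store
    cs = ConfigStore.instance()
    cs.store(group="dataset", name="hf_data_config", node=HFTransformerDataConfig)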

I also don't know how defaults handling would work; i.e. if a user adds a default to their dataclass, they shouldn't need to specify it in their Hydra config, for example.

I totally agree. Ideally, the dataclasses would be the source of truth and the config files would be instances of them.

Do you want to try to solve these challenges here or in a separate PR?

Contributor

@tchaton tchaton left a comment

Great cleanup!

conf/dataset/text_classification/emotion.yaml
lightning_transformers/core/data.py
train.py
lightning_transformers/huggingface/model.py
train.py
@SeanNaren
Contributor

Do you want to try to solve these challenges here or in a separate PR?

Up to you; either way, we should solve this ASAP to ensure the foundation is solid before moving forward.

I'd have to play around with Hydra. Looking at https://hydra.cc/docs/tutorials/structured_config/minimal_example/

Is the ConfigStore a global instance shared for the process? Can we create it and register the dataclass within each Instantiator function? If so, we could add _target_ attributes to our config to instantiate their corresponding dataclasses as soon as possible instead of passing DictConfigs around. Would need some experimenting.

It's not that simple, which is why I wasn't sure the dataclass approach would work:

We would register this in our ConfigStore (which I think is a global instance, as you suggested):

cs = ConfigStore.instance()
# Registering the config schema under the name 'hf_data_config'.
cs.store(name="hf_data_config", node=HFTransformerDataConfig)

And then within our dataset conf file:

# @package dataset
defaults:
    - /hf_data_config # Defines the structured config defaults we'd like to use
_target_: lightning_transformers.huggingface.task.text_classification.TextClassificationDataModule
cfg:
  dataset_name: emotion
  train_file: ''
  validation_file: ''
  padding: 'max_length'
  truncation: 'only_first'
  preprocessing_num_workers: 8
  load_from_cache_file: True
  max_length: 128

  # torch data-loader specific arguments
  batch_size: ${training.batch_size}
  num_workers: ${training.num_workers}

But this doesn't make sense: it says you'd like to populate your config with defaults from the structured config "hf_data_config", and it would throw an error saying _target_ or cfg is not defined in the dataclass.

If you were then to modify the dataclass to:

from dataclasses import dataclass
from typing import Optional


@dataclass
class HFTransformerDataConfig:
    _target_: str
    dataset_name: Optional[str] = None
    train_val_split: Optional[int] = None
    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    padding: str = "max_length"
    truncation: str = "only_first"
    max_length: int = 128
    preprocessing_num_workers: int = 8
    load_from_cache_file: bool = True
    dataset_config_name: Optional[str] = None

and the config so that it no longer contains the original cfg:

# @package dataset
defaults:
    - /hf_data_config
_target_: lightning_transformers.huggingface.task.text_classification.TextClassificationDataModule

dataset_name: emotion
train_file: ''
validation_file: ''
padding: 'max_length'
truncation: 'only_first'
preprocessing_num_workers: 8
load_from_cache_file: True
max_length: 128

# torch data-loader specific arguments
batch_size: ${training.batch_size}
num_workers: ${training.num_workers}

This still won't work: you'd be calling instantiate and sending all these parameters as positional/keyword arguments, not as a single config object.
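
To make that concrete (DataModule below is a made-up stand-in, just showing the call shape instantiate ends up making):

from hydra.utils import instantiate
from omegaconf import OmegaConf


class DataModule:
    # stand-in for TextClassificationDataModule: it expects one config object
    def __init__(self, cfg):
        self.cfg = cfg


# flattened config: instantiate expands the extra keys as keyword arguments,
# i.e. DataModule(dataset_name="emotion", max_length=128) -> TypeError
flat = OmegaConf.create(
    {"_target_": "__main__.DataModule", "dataset_name": "emotion", "max_length": 128}
)

# nested config: instantiate calls DataModule(cfg=...), which is the shape
# the datamodule actually wants
nested = OmegaConf.create(
    {"_target_": "__main__.DataModule", "cfg": {"dataset_name": "emotion", "max_length": 128}}
)
dm = instantiate(nested)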

I think this would work, however:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ActualConfig:
    dataset_name: Optional[str] = None
    train_val_split: Optional[int] = None
    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    padding: str = "max_length"
    truncation: str = "only_first"
    max_length: int = 128
    preprocessing_num_workers: int = 8
    load_from_cache_file: bool = True
    dataset_config_name: Optional[str] = None


@dataclass
class HFTransformerDataConfig:
    _target_: str
    cfg: ActualConfig

with the conf file:

# @package dataset
defaults:
    - /hf_data_config
_target_: lightning_transformers.huggingface.task.text_classification.TextClassificationDataModule
cfg:
  dataset_name: emotion
  train_file: ''
  validation_file: ''
  padding: 'max_length'
  truncation: 'only_first'
  preprocessing_num_workers: 8
  load_from_cache_file: True
  max_length: 128

  # torch data-loader specific arguments
  batch_size: ${training.batch_size}
  num_workers: ${training.num_workers}

But now we have to define an additional dataclass.

@SeanNaren
Contributor

Hey @carmocca I landed here after hacking around with the minimal example that Hydra offers:

from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf


@dataclass
class DBConfig:
    driver: str = MISSING
    port: int = MISSING
    host: str = "localhost"


@dataclass
class MySQLConfig(DBConfig):
    user: str = MISSING
    password: str = MISSING
    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    user: str = MISSING
    password: str = MISSING
    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10


class PostGreSQLClass:
    def __init__(self, cfg: PostGreSQLConfig):
        self.cfg = cfg


class MySqlClass:
    def __init__(self, cfg: MySQLConfig):
        self.cfg = cfg


@dataclass
class Config:
    db: DBConfig = MISSING


cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(
    group="db",
    name="base_mysql",
    node={'_target_': '__main__.MySqlClass', 'cfg': MySQLConfig}
)
cs.store(
    group="db",
    name="base_postgresql",
    node={'_target_': '__main__.PostGreSQLClass', 'cfg': PostGreSQLConfig}
)


@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: Config) -> None:
    print(OmegaConf.to_yaml(cfg))
    print(hydra.utils.instantiate(cfg.db))


if __name__ == "__main__":
    my_app()

This is the primary thing: adding _target_ + cfg explicitly in the config store.

cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(
    group="db",
    name="base_mysql",
    node={'_target_': '__main__.MySqlClass', 'cfg': MySQLConfig}
)
cs.store(
    group="db",
    name="base_postgresql",
    node={'_target_': '__main__.PostGreSQLClass', 'cfg': PostGreSQLConfig}
)

It allows you to add it as a default to the conf:

defaults:
    - base_postgresql

And it adds the dataclass defaults to your conf.

I don't know if this is the cleanest way to do it, but I've pinged Omry (the Hydra author) to see if there is a cleaner way to do this!

@SeanNaren
Contributor

Really nice work @carmocca! Overall I realised the instantiator is necessary (the alternative being dropping it and hard-coding hydra.utils.instantiate everywhere). I also like how it's injected into the codebase as a dependency.
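
For anyone following along, the pattern is roughly this (a simplified sketch, not the exact code in this PR):

from typing import Any

import hydra
from omegaconf import DictConfig


class HydraInstantiator:
    # thin wrapper so the rest of the codebase never imports Hydra directly
    def instantiate(self, cfg: DictConfig, **kwargs: Any) -> Any:
        return hydra.utils.instantiate(cfg, **kwargs)


def build_datamodule(instantiator: HydraInstantiator, cfg: DictConfig) -> Any:
    # the factory only depends on the injected instantiator
    return instantiator.instantiate(cfg.dataset)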

@carmocca carmocca requested a review from tchaton January 12, 2021 11:05
@SeanNaren SeanNaren merged commit 0166b49 into master Jan 12, 2021
@SeanNaren SeanNaren deleted the refactor-text-classification branch January 12, 2021 12:20
@SeanNaren SeanNaren mentioned this pull request Jan 14, 2021