Skip to content

Commit

Permalink
feat: Support in RT Trainer for multiple entities (#166)
Browse files Browse the repository at this point in the history
* feat: Support in RT Trainer for multiple entities

Solving #143 by expanding the Regression Transformer trainer to support multi-entity discriminations, i.e., supporting the `multientity_cg` collator from the RT repo.

Signed-off-by: Nicolai Ree <nicolairee@hotmail.com>

* test: Add tests for multientity RT trainer

Signed-off-by: Nicolai Ree <nicolairee@hotmail.com>
Co-authored-by: jannisborn <jannis.born@gmx.de>
  • Loading branch information
NicolaiRee and jannisborn committed Nov 7, 2022
1 parent d25c38e commit efacf52
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 10 deletions.
34 changes: 34 additions & 0 deletions src/gt4sd/training_pipelines/regression_transformer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,40 @@ class RegressionTransformerTrainingArguments(
"generation task. Defaults to True."
},
)
# Collator selection for the conditional-generation (cg) task.
cg_collator: str = field(
    default="vanilla_cg",
    metadata={
        "help": "The collator class. Following options are implemented: "
        "'vanilla_cg': Collator class that does not mask the properties but anything else as a regular DataCollatorForPermutationLanguageModeling. Can optionally replace the properties with sampled values. "
        "NOTE: This collator can deal with multiple properties. "
        "'multientity_cg': A training collator for the conditional-generation task that can handle multiple entities. "
        "Default: vanilla_cg."
    },
)
# Index of the entity to mask; -1 means the collator samples one at runtime.
entity_to_mask: int = field(
    default=-1,
    metadata={
        "help": "Only applies if `cg_collator='multientity_cg'`. The entity that is being masked during training. 0 corresponds to first entity and so on. -1 corresponds to "
        "a random sampling scheme where the entity-to-be-masked is determined "
        "at runtime in the collator. NOTE: If 'mask_entity_separator' is true, "
        "this argument will not have any effect. Defaults to -1."
    },
)
# Token that delimits entities in the input sequence ('.' works for SMILES/SELFIES).
entity_separator_token: str = field(
    default=".",
    metadata={
        "help": "Only applies if `cg_collator='multientity_cg'`. The token that is used to separate "
        "entities in the input. Defaults to '.' (applicable to SMILES & SELFIES)"
    },
)
# If True, the separator itself is maskable, making masking span all tokens
# (multientity_cg then behaves like vanilla_cg).
mask_entity_separator: bool = field(
    default=False,
    metadata={
        "help": "Only applies if `cg_collator='multientity_cg'`. Whether or not the entity separator token can be masked. If True, **all** textual tokens can be masked and "
        "the collator behaves like the `vanilla_cg` even though it is a `multientity_cg`. If False, the exact behavior "
        "depends on the entity_to_mask argument. Defaults to False."
    },
)


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,15 @@ def train( # type: ignore
num_tokens_to_mask=None,
mask_token_order=None,
)
alternating_collator = TRAIN_COLLATORS["vanilla_cg"](
alternating_collator = TRAIN_COLLATORS[training_args["cg_collator"]](
tokenizer=self.tokenizer,
property_tokens=self.properties,
plm_probability=training_args["plm_probability"],
max_span_length=training_args["max_span_length"],
do_sample=False,
entity_separator_token=training_args["entity_separator_token"],
mask_entity_separator=training_args["mask_entity_separator"],
entity_to_mask=training_args["entity_to_mask"],
)

# Initialize our Trainer
Expand Down
4 changes: 4 additions & 0 deletions src/gt4sd/training_pipelines/regression_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,17 @@ def get_train_config_dict(
return {
"alternate_steps": training_args["alternate_steps"],
"reset_training_loss": True,
"cg_collator": training_args["cg_collator"],
"cc_loss": training_args["cc_loss"],
"property_tokens": list(properties),
"cg_collator_params": {
"do_sample": False,
"property_tokens": list(properties),
"plm_probability": training_args["plm_probability"],
"max_span_length": training_args["max_span_length"],
"entity_separator_token": training_args["entity_separator_token"],
"mask_entity_separator": training_args["mask_entity_separator"],
"entity_to_mask": training_args["entity_to_mask"],
},
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
text,prop0,prop1,prop2,prop3,prop4,prop5,prop6,prop7
CCCOC(=O)C(c1ccccc1C(c1ccccc1)c1ccccc1)C(F)(F)S(=O)(=O)[O-],460.115601244001,-6.9,2.63,-9.04,2.11,0.85,1.06,2800.0
C=C(COCCC(F)(F)S(=O)(=O)[O-])C(=O)OC1CC(C(O)(C(F)(F)F)C(F)(F)F)CC(C(O)(C(F)(F)F)C(F)(F)F)C1,660.049883616001,-2.64,2.84,-10.8,3.54,0.63,0.51,3946.0
O=C([O-])CC=Cc1ccccc1,148.052429496,-2.44,0.66,-7.51,4.04,0.61,0.6,2761.0
CCCCCCCS(=O)(=O)OC1CC2CC1CC2C(F)(F)C(F)(F)S(=O)(=O)[O-],440.09504336800103,-2.78,2.89,-5.99,4.41,1.14,0.73,5734.0
CO=S(=O)([O-])c1ccc(/N=N/c2ccc(Nc3ccccc3)cc2)cc1,353.08341234,-3.79,2.5,-7.26,4.12,0.63,1.33,5743.0
CO=S(=O)([O-])c1cccc(C(F)(F)F)c1,225.99114968,0.09,-5.12,-9.44,-1.08,0.65,0.81,450.0
CCCOC(=O)C(c1ccccc1C(c1ccccc1)c1ccccc1)C(F)(F)S(=O)(=O)[O-].CCO,460.115601244001,-6.9,2.63,-9.04,2.11,0.85,1.06,2800.0
C=C(COCCC(F)(F)S(=O)(=O)[O-])C(=O)OC1CC(C(O)(C(F)(F)F)C(F)(F)F)CC(C(O)(C(F)(F)F)C(F)(F)F)C1.CC,660.049883616001,-2.64,2.84,-10.8,3.54,0.63,0.51,3946.0
[Cs+].O=C([O-])CC=Cc1ccccc1,148.052429496,-2.44,0.66,-7.51,4.04,0.61,0.6,2761.0
[Li+].CCCCCCCS(=O)(=O)OC1CC2CC1CC2C(F)(F)C(F)(F)S(=O)(=O)[O-],440.09504336800103,-2.78,2.89,-5.99,4.41,1.14,0.73,5734.0
CO=S(=O)([O-])c1ccc(/N=N/c2ccc(Nc3ccccc3)cc2)cc1.[H-],353.08341234,-3.79,2.5,-7.26,4.12,0.63,1.33,5743.0
CO=S(=O)([O-])c1cccc(C(F)(F)F)c1.[H-],225.99114968,0.09,-5.12,-9.44,-1.08,0.65,0.81,450.0
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def test_train():

# Test the pretrained QED model
config["model_args"]["model_path"] = mol_path
config["dataset_args"]["train_data_path"] = raw_path
config["dataset_args"]["test_data_path"] = raw_path
config["dataset_args"]["train_data_path"] = str(raw_path)
config["dataset_args"]["test_data_path"] = str(raw_path)
config["dataset_args"]["augment"] = 2
input_config = combine_defaults_and_user_args(config)
test_pipeline.train(**input_config)
Expand All @@ -180,7 +180,18 @@ def test_train():
config["model_args"]["config_name"] = f_name
del config["model_args"]["model_path"]
config["model_args"]["tokenizer_name"] = mol_path
config["dataset_args"]["data_path"] = raw_path
config["dataset_args"]["data_path"] = str(raw_path)
config["dataset_args"]["augment"] = 2
input_config = combine_defaults_and_user_args(config)
test_pipeline.train(**input_config)

# Test the multientity collator (finetuning QED model)
config["training_args"].update(
{
"cg_collator": "multientity_cg",
"entity_to_mask": 0,
"entity_separator_token": ".",
}
)
input_config = combine_defaults_and_user_args(config)
test_pipeline.train(**input_config)

0 comments on commit efacf52

Please sign in to comment.