forked from hpcaitech/ColossalAI
Add RoBERTa for RLHF Stage 2 & 3 (test)
RoBERTa for RLHF Stage 2 & 3 (still in testing)
1 parent: 1e1b9d2 · commit: 06741d8
Showing 6 changed files with 134 additions and 4 deletions.
applications/ChatGPT/chatgpt/models/roberta/__init__.py (5 additions, 0 deletions)
from .roberta_actor import RoBERTaActor
from .roberta_critic import RoBERTaCritic
from .roberta_rm import RoBERTaRM

__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
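Since the package just re-exports the three classes, a quick smoke test is easy to write. A minimal sketch, assuming the ChatGPT application is on the Python path; it uses a tiny randomly initialized RobertaConfig so nothing is downloaded (the config sizes are illustrative, not part of the commit):

from transformers.models.roberta.configuration_roberta import RobertaConfig

from chatgpt.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM

# Tiny config so the test runs offline; the sizes are arbitrary.
tiny = RobertaConfig(hidden_size=64,
                     num_hidden_layers=2,
                     num_attention_heads=2,
                     intermediate_size=128)

actor = RoBERTaActor(config=tiny)    # policy model for RLHF stage 3
critic = RoBERTaCritic(config=tiny)  # value model for RLHF stage 3
rm = RoBERTaRM(config=tiny)          # reward model for RLHF stage 2
print(type(actor).__name__, type(critic).__name__, type(rm).__name__)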
applications/ChatGPT/chatgpt/models/roberta/roberta_actor.py (35 additions, 0 deletions)
from typing import Optional

from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM

from ..base import Actor


class RoBERTaActor(Actor):
    """
    RoBERTa Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Prefer pretrained weights, then an explicit config, then the default config.
        if pretrained is not None:
            model = RobertaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaForCausalLM(config)
        else:
            model = RobertaForCausalLM(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
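One caveat with this actor: RobertaForCausalLM is RoBERTa's encoder reused with a language-modeling head, and transformers logs a warning unless config.is_decoder is set to True so causal masking is applied. A hedged sketch of loading the underlying model with that flag (the 'roberta-base' checkpoint name is an example, not something this commit prescribes):

from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM

# is_decoder=True enables causal attention masking and silences the
# "standalone use" warning that from_pretrained would otherwise emit.
config = RobertaConfig.from_pretrained('roberta-base', is_decoder=True)
model = RobertaForCausalLM.from_pretrained('roberta-base', config=config)
print(model.config.is_decoder)  # True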
applications/ChatGPT/chatgpt/models/roberta/roberta_critic.py (38 additions, 0 deletions)
from typing import Optional

import torch.nn as nn
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaModel

from ..base import Critic


class RoBERTaCritic(Critic):
    """
    RoBERTa Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        # Scalar value head on top of the encoder's hidden states.
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
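The value head here is a plain linear projection from the encoder's hidden size down to one scalar per token; the Critic base class decides how those per-token scores are reduced to a sequence value. A self-contained illustration of just the projection, in plain PyTorch:

import torch
import torch.nn as nn

hidden_size = 768                    # roberta-base hidden size, as an example
value_head = nn.Linear(hidden_size, 1)

hidden_states = torch.randn(2, 10, hidden_size)  # (batch, seq_len, hidden)
values = value_head(hidden_states).squeeze(-1)   # (batch, seq_len): one scalar per token
print(values.shape)  # torch.Size([2, 10])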
applications/ChatGPT/chatgpt/models/roberta/roberta_rm.py (39 additions, 0 deletions)
from typing import Optional

import torch.nn as nn
from transformers import RobertaConfig, RobertaModel

from ..base import RewardModel


class RoBERTaRM(RewardModel):
    """
    RoBERTa Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()

        value_head = nn.Linear(model.config.hidden_size, 1)
        # Small-std init keeps the untrained head's rewards near zero.
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
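Unlike the critic, the reward model initializes its value head with a deliberately small standard deviation, 1/(hidden_size + 1), so an untrained head produces near-zero rewards. A quick standalone check of that scale (hidden_size=768 matches roberta-base and is used only as an example):

import torch.nn as nn

hidden_size = 768
head = nn.Linear(hidden_size, 1)
head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))

print(f"target std:    {1 / (hidden_size + 1):.5f}")     # ~0.00130
print(f"empirical std: {head.weight.std().item():.5f}")  # close to the target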