Add RoBERTa for RLHF Stage 2 & 3 (test)
RoBERTa for RLHF Stage 2 & 3 (still in testing)
Camille7777 committed Mar 22, 2023
1 parent 1e1b9d2 commit 06741d8
Showing 6 changed files with 134 additions and 4 deletions.
5 changes: 5 additions & 0 deletions applications/ChatGPT/chatgpt/models/roberta/__init__.py
@@ -0,0 +1,5 @@
from .roberta_actor import RoBERTaActor
from .roberta_critic import RoBERTaCritic
from .roberta_rm import RoBERTaRM

__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
35 changes: 35 additions & 0 deletions applications/ChatGPT/chatgpt/models/roberta/roberta_actor.py
@@ -0,0 +1,35 @@
from typing import Optional

from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM

from ..base import Actor


class RoBERTaActor(Actor):
    """
    RoBERTa Actor model.
    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaForCausalLM(config)
        else:
            model = RobertaForCausalLM(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
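Usage sketch (illustrative, not from this commit): the snippet below assumes the chatgpt package from this repository is installed; the 'roberta-base' checkpoint and lora_rank=16 are assumed example values.

from chatgpt.models.roberta import RoBERTaActor

# Load RoBERTa weights into the causal-LM wrapper; 'roberta-base' is an assumed checkpoint.
actor = RoBERTaActor(pretrained='roberta-base', lora_rank=16)
# Generation and action log-prob computation come from the shared chatgpt.models.base.Actor,
# which this commit does not modify.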
38 changes: 38 additions & 0 deletions applications/ChatGPT/chatgpt/models/roberta/roberta_critic.py
@@ -0,0 +1,38 @@
from typing import Optional

import torch.nn as nn
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaModel

from ..base import Critic


class RoBERTaCritic(Critic):
    """
    RoBERTa Critic model.
    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
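Usage sketch (illustrative, not from this commit), under the same assumptions as the actor sketch above:

from chatgpt.models.roberta import RoBERTaCritic

critic = RoBERTaCritic(pretrained='roberta-base', lora_rank=16)
# The critic stacks nn.Linear(hidden_size, 1) on top of RobertaModel; how per-token values
# are pooled into the final estimate is defined by the base Critic class, not by this file.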
39 changes: 39 additions & 0 deletions applications/ChatGPT/chatgpt/models/roberta/roberta_rm.py
@@ -0,0 +1,39 @@
from typing import Optional

import torch.nn as nn
from transformers import RobertaConfig, RobertaModel

from ..base import RewardModel


class RoBERTaRM(RewardModel):
    """
    RoBERTa Reward model.
    Args:
        pretrained (str): Pretrained model name or path.
        config (RobertaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = RobertaModel.from_pretrained(pretrained)
        elif config is not None:
            model = RobertaModel(config)
        else:
            model = RobertaModel(RobertaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()

        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
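The reward model re-initializes its value head with a small standard deviation, std = 1/(hidden_size + 1), so that initial reward predictions stay close to zero. A standalone sketch of that initialization (illustrative, not from this commit), using hidden_size=768, the roberta-base hidden size:

import torch.nn as nn

hidden_size = 768                       # roberta-base hidden size (illustrative)
value_head = nn.Linear(hidden_size, 1)  # one scalar reward per input
value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))  # std ≈ 0.0013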
11 changes: 9 additions & 2 deletions applications/ChatGPT/examples/train_dummy.py
@@ -6,11 +6,12 @@
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.models.opt import OPTActor, OPTCritic
from chatgpt.models.roberta import RoBERTaActor, RoBERTaCritic
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import SaveCheckpoint
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam
@@ -46,6 +47,9 @@ def main(args):
elif args.model == 'opt':
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'roberta':
actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')

@@ -69,6 +73,9 @@ def main(args):
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif args.model == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer.pad_token = tokenizer.eos_token
else:
raise ValueError(f'Unsupported model "{args.model}"')

@@ -128,7 +135,7 @@ def main(args):
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
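The new 'roberta' tokenizer branch, shown in isolation (illustrative, not from this commit): RobertaTokenizer already ships with pad_token '<pad>'; the assignment copied from the hunk above overrides it with eos_token '</s>' so the branch mirrors the pad_token handling of the earlier branches. 'roberta-base' is the checkpoint hard-coded by the script.

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer.pad_token = tokenizer.eos_token  # pad with '</s>', as in the hunk above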
10 changes: 8 additions & 2 deletions applications/ChatGPT/examples/train_reward_model.py
@@ -8,12 +8,13 @@
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.models.roberta import RoBERTaRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from random import randint
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam
@@ -39,6 +40,8 @@ def train(args):
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'roberta':
model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')

@@ -54,6 +57,9 @@ def train(args):
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif args.model == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer.pad_token = tokenizer.eos_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
max_len = args.max_len
@@ -119,7 +125,7 @@ def train(args):
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'roberta'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
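A sketch of how the stage-2 pieces fit together outside the trainer (illustrative, not from this commit). The reward model's forward signature lives in chatgpt.models.base.RewardModel, which this commit does not touch, so the scoring call is left as an assumed, commented-out line.

from chatgpt.models.roberta import RoBERTaRM
from transformers import RobertaTokenizer

model = RoBERTaRM(pretrained='roberta-base', lora_rank=16)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(['a helpful reply', 'an unhelpful reply'], return_tensors='pt', padding=True)
# Assumed call, based on the base RewardModel interface (not shown in this diff):
# rewards = model(batch['input_ids'], attention_mask=batch['attention_mask'])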
