12 changes: 7 additions & 5 deletions configs/alpaca/alpaca_standford_llama-7b.py
@@ -9,16 +9,18 @@
    from .._base_.schedules.guanaco import * # noqa: F401,F403

pretrained_model_name_or_path = '/nvme/share_data/llama-7b'
model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path))

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path),
    tokenizer=tokenizer)

train_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405
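Note on the new layout (sketch, not part of the diff): the tokenizer dict now has to be defined before the model dict so that SupervisedFinetune can take it as tokenizer=tokenizer, and the same dict is handed to the dataset below. Roughly, what that shared dict resolves to once TOKENIZER.build is called on it (the path is the one from this config and is assumed to point at a local llama-7b checkpoint):

    from transformers import AutoTokenizer

    tokenizer_obj = AutoTokenizer.from_pretrained(
        '/nvme/share_data/llama-7b',  # local checkpoint path from this config
        use_fast=False,
        padding_side='right')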
12 changes: 7 additions & 5 deletions configs/alpaca/alpaca_standford_llama-7b_deepspeed.py
@@ -9,16 +9,18 @@
    from .._base_.schedules.guanaco_deepspeed import * # noqa: F401,F403

pretrained_model_name_or_path = '/nvme/share_data/llama-7b'
model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path))

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path),
    tokenizer=tokenizer)

train_dataloader['collate_fn']['tokenizer'] = tokenizer # noqa: F405
16 changes: 9 additions & 7 deletions configs/alpaca/alpaca_standford_llama-7b_qlora.py
@@ -12,6 +12,13 @@
    from .._base_.schedules.guanaco import * # noqa: F401,F403

pretrained_model_name_or_path = '/share/gaojianfei/merged_chinese_lora_7b'

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')

model = dict(
    type=SupervisedQloraFinetune,
    llm=dict(
@@ -31,12 +38,7 @@
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')
        task_type='CAUSAL_LM'),
    tokenizer=tokenizer)

train_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405
22 changes: 15 additions & 7 deletions configs/guanaco/gunaco_llama-7b_qlora.py
@@ -5,6 +5,7 @@
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from mmchat.engine import LogSampleHook
from mmchat.models import SupervisedQloraFinetune

with read_base():
@@ -14,6 +15,13 @@
    from .._base_.schedules.guanaco import * # noqa: F401,F403

pretrained_model_name_or_path = '/nvme/share_data/llama-7b'

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')

model = dict(
    type=SupervisedQloraFinetune,
    data_preprocessor=dict(type=BaseDataPreprocessor),
@@ -36,17 +44,17 @@
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')
        task_type='CAUSAL_LM'),
    tokenizer=tokenizer)

train_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405
val_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405
test_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405

val_evaluator['tokenizer'] = tokenizer # noqa: F405
test_evaluator['tokenizer'] = tokenizer # noqa: F405

custom_hooks = [dict(
    type=LogSampleHook,
    tokenizer=tokenizer,
)]
18 changes: 10 additions & 8 deletions configs/guanaco/gunaco_llama-7b_qlora_deepspeed.py
@@ -14,6 +14,13 @@
    from .._base_.schedules.guanaco_deepspeed import * # noqa: F401,F403

pretrained_model_name_or_path = '/nvme/share_data/llama-7b'

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')

model = dict(
    type=SupervisedQloraFinetune,
    data_preprocessor=dict(type=BaseDataPreprocessor),
@@ -36,17 +43,12 @@
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    use_fast=False,
    padding_side='right')
        task_type='CAUSAL_LM'),
    tokenizer=tokenizer)

train_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405
val_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405
test_dataloader['dataset']['tokenizer'] = tokenizer # noqa: F405

val_evaluator['tokenizer'] = tokenizer # noqa: F405
test_evaluator['tokenizer'] = tokenizer # noqa: F405
test_evaluator['tokenizer'] = tokenizer # noqa: F405
13 changes: 12 additions & 1 deletion mmchat/datasets/huggingface.py
@@ -1,5 +1,7 @@
from functools import partial

from mmengine.config.lazy import LazyObject

from mmchat.registry import DATASETS, TOKENIZER
from .utils import Concatenator, encode_fn

@@ -17,7 +19,16 @@ def process_hf_dataset(dataset,
    dataset = DATASETS.build(dataset)
    if isinstance(map_fn, str):
        map_fn = eval(map_fn)
    dataset = dataset.map(map_fn, remove_columns=remove_columns)
    if isinstance(map_fn, list):
        assert all(
            [callable(fn) and isinstance(fn, LazyObject) for fn in map_fn])
        for fn in map_fn[:-1]:
            fn = fn.build()
            dataset = dataset.map(fn)
        dataset = dataset.map(
            map_fn[-1].build(), remove_columns=remove_columns)
    elif map_fn is not None:
        dataset = dataset.map(map_fn, remove_columns=remove_columns)
    for old, new in rename_maps:
        dataset = dataset.rename_column(old, new)
    tokenizer = TOKENIZER.build(tokenizer)
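How the new list branch is meant to be used (sketch): each entry must be a lazy import (LazyObject) from an mmengine-style config, not a plain callable, or the assert fails; the functions run in order, and only the final map drops the original columns. A hypothetical config fragment follows; the dataset loader, dataset path, and column names are illustrative, and only the keyword names visible in this diff (dataset, map_fn, remove_columns, rename_maps, tokenizer) are taken from the code:

    train_dataset = dict(
        type=process_hf_dataset,
        dataset=dict(type=load_dataset, path='tatsu-lab/alpaca'),  # assumed HF loader
        tokenizer=tokenizer,
        map_fn=[alpaca_map_fn, llama2_map_fn],  # dataset-format fn first, model prompt fn last
        remove_columns=['instruction', 'output'],  # illustrative column names
        rename_maps=[])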
10 changes: 6 additions & 4 deletions mmchat/datasets/map_fns/__init__.py
@@ -1,5 +1,7 @@
from .alpaca_map_fn import alpaca_map_fn
from .alpaca_zh_map_fn import alpaca_zh_map_fn
from .oasst1_map_fn import oasst1_map_fn
from .dataset_map_fn import alpaca_map_fn, alpaca_zh_map_fn, oasst1_map_fn
from .model_map_fn import internlm_map_fn, llama2_map_fn

__all__ = ['alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn']
__all__ = [
    'alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn', 'internlm_map_fn',
    'llama2_map_fn'
]
5 changes: 5 additions & 0 deletions mmchat/datasets/map_fns/dataset_map_fn/__init__.py
@@ -0,0 +1,5 @@
from .alpaca_map_fn import alpaca_map_fn
from .alpaca_zh_map_fn import alpaca_zh_map_fn
from .oasst1_map_fn import oasst1_map_fn

__all__ = ['alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn']
4 changes: 4 additions & 0 deletions mmchat/datasets/map_fns/model_map_fn/__init__.py
@@ -0,0 +1,4 @@
from .internlm_map_fn import internlm_map_fn
from .llama2_map_fn import llama2_map_fn

__all__ = ['internlm_map_fn', 'llama2_map_fn']
8 changes: 8 additions & 0 deletions mmchat/datasets/map_fns/model_map_fn/internlm_map_fn.py
@@ -0,0 +1,8 @@
def internlm_map_fn(example):
    user = '<|User|>'
    eoh = '<eoh>'
    eoa = '<eoa>' # noqa:F841
    assistant = '<|Bot|>'
    instruction = example.get('input', '')
    prompt = f'<BOS>{user}:{instruction}{eoh}\n{assistant}:'
    return {'input': prompt}
15 changes: 15 additions & 0 deletions mmchat/datasets/map_fns/model_map_fn/llama2_map_fn.py
@@ -0,0 +1,15 @@
def llama2_map_fn(example):
    B_INST, E_INST = '[INST]', '[/INST]'
    B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'

    DEFAULT_SYSTEM_PROMPT = \
        'You are a helpful, respectful and honest assistant. Always answer ' \
        'as helpfully as possible, while being safe. Your answers should not' \
        ' include any harmful, unethical, racist, sexist, toxic, dangerous, ' \
        'or illegal content. Please ensure that your responses are socially ' \
        'unbiased and positive in nature.'

    instruction = example.get('input', '')
    prompt = f'<BOS>{B_INST} {B_SYS} {DEFAULT_SYSTEM_PROMPT} {E_SYS}' \
        f'{instruction} {E_INST}'
    return {'input': prompt}
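Both new model map_fns follow the same contract as the dataset map_fns: read example['input'], return a dict whose 'input' is the prompt-wrapped text. That shared contract is what lets them be chained after a dataset map_fn via the list support in process_hf_dataset above. A quick illustration (not part of the PR) using llama2_map_fn:

    from mmchat.datasets.map_fns import llama2_map_fn

    example = {'input': 'Give three tips for staying healthy.'}
    print(llama2_map_fn(example)['input'])
    # roughly (newlines elided): <BOS>[INST] <<SYS>> ...system prompt... <</SYS>>
    # Give three tips for staying healthy. [/INST]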
4 changes: 2 additions & 2 deletions mmchat/engine/__init__.py
@@ -1,3 +1,3 @@
from .hooks import SampleGenerateHook
from .hooks import LogSampleHook, SampleGenerateHook

__all__ = ['SampleGenerateHook']
__all__ = ['SampleGenerateHook', 'LogSampleHook']
3 changes: 2 additions & 1 deletion mmchat/engine/hooks/__init__.py
@@ -1,3 +1,4 @@
from .log_data_sample import LogSampleHook
from .sample_generate_hook import SampleGenerateHook

__all__ = ['SampleGenerateHook']
__all__ = ['SampleGenerateHook', 'LogSampleHook']
29 changes: 29 additions & 0 deletions mmchat/engine/hooks/log_data_sample.py
@@ -0,0 +1,29 @@
from mmengine.hooks import Hook

from mmchat.registry import HOOKS, TOKENIZER


@HOOKS.register_module()
class LogSampleHook(Hook):

    def __init__(self, tokenizer):
        self.tokenizer = TOKENIZER.build(tokenizer)

    def log(self, runner, dataset, mode='train'):
        runner.logger.info(f'Num {mode} samples {len(dataset)}')
        runner.logger.info(f'{mode} example:')
        runner.logger.info(self.tokenizer.decode(dataset[0]['input_ids']))

    def before_run(self, runner) -> None:
        do_train = runner.train_loop is not None
        do_eval = runner.val_loop is not None
        do_test = runner.test_loop is not None
        if do_train:
            train_dataset = runner.train_dataloader.dataset
            self.log(runner, train_dataset, mode='train')
        if do_eval:
            eval_dataset = runner.val_dataloader.dataset
            self.log(runner, eval_dataset, mode='eval')
        if do_test:
            test_dataset = runner.test_dataloader.dataset
            self.log(runner, test_dataset, mode='test')
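The hook decodes the first tokenized sample of each split at the start of the run, a cheap sanity check that the map_fn chain and tokenizer produced the prompt you expect. The manual equivalent outside the runner, as a self-contained sketch with a stand-in tokenizer and dataset:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('gpt2')  # any tokenizer works for the illustration
    fake_dataset = [{'input_ids': tok('an example prompt')['input_ids']}]
    print('Num train samples', len(fake_dataset))
    print('train example:')
    print(tok.decode(fake_dataset[0]['input_ids']))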
17 changes: 10 additions & 7 deletions mmchat/models/algorithms/sft.py
@@ -1,13 +1,12 @@
import dataclasses
from typing import Dict

import torch
import transformers
from mmengine import print_log
from mmengine.model import BaseModel
from torch import nn

from mmchat.registry import LLM
from mmchat.registry import LLM, TOKENIZER


def traverse_dict(d):
@@ -28,7 +27,6 @@ def traverse_dict(d):


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
@@ -37,8 +35,9 @@ def smart_tokenizer_and_embedding_resize(
    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model_vocab_size = model.get_output_embeddings().weight.size(0)
    model.resize_token_embeddings(len(tokenizer))
    num_new_tokens = len(tokenizer) - model_vocab_size

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
@@ -51,16 +50,20 @@

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
    elif num_new_tokens < 0:
        raise RuntimeError


class SupervisedFinetune(BaseModel):

    def __init__(self, llm, data_preprocessor=None):
    def __init__(self, llm, data_preprocessor=None, tokenizer=None):
        super().__init__(data_preprocessor)
        self.llm = self._build_from_cfg_or_module(llm, LLM)
        self.llm.config.use_cache = False
        self.llm.config.torch_dtype = torch.float32

        tokenizer = TOKENIZER.build(tokenizer)
        smart_tokenizer_and_embedding_resize(tokenizer, self.llm)

    def _build_from_cfg_or_module(self, cfg_or_mod, registry):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
@@ -97,7 +100,7 @@ def compute_loss(self, data, data_samples=None):
        # import pdb;pdb.set_trace()
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
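What the rewritten resize does (sketch): instead of adding special tokens itself, it measures how many tokens the tokenizer already has beyond the model's output-embedding rows, resizes the embeddings, and, per the lines collapsed in this hunk and following the usual Alpaca recipe, fills each new row with the mean of the pre-existing rows. A toy tensor illustration of that initialization rule; the mean-fill is an assumption based on that recipe, since the middle of the hunk is not shown here:

    import torch

    num_new_tokens = 2
    old_rows = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # 4 known tokens, dim 3
    resized = torch.zeros(4 + num_new_tokens, 3)
    resized[:4] = old_rows  # keep the learned rows
    resized[-num_new_tokens:] = old_rows.mean(dim=0, keepdim=True)  # new rows get the average embedding
    print(resized[-num_new_tokens:])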