feature: support micro_batch in pipeline #199

Closed
wants to merge 2 commits
mlora/dispatcher/pipeline_dispatcher.py (20 additions, 11 deletions)
@@ -6,23 +6,32 @@
 
 
 class PipelineDispatcher(Dispatcher):
-    _adapter_lock_: Dict[str, bool] = {}
+    _adapter_backward_cnt_: Dict[str, int] = {}
+    _adapter_forward_cnt_: Dict[str, int] = {}
+    _adapter_accumulation_step_: Dict[str, int] = {}
 
     def __init__(self,
                  config: MLoRAConfig,
                  tokenizer: Tokenizer) -> None:
         super().__init__(config, tokenizer)
-
-    def activate_adapter(self, adapter_name: str):
-        self._adapter_lock_[adapter_name] = True
-
-    def deactivate_adapter(self, adapter_name: str):
-        self._adapter_lock_[adapter_name] = False
+        for lora_config in config.lora_configs_:
+            adapter_name = lora_config.adapter_name_
+            accumulation_step = lora_config.batch_size_ / lora_config.micro_batch_size_
+            self._adapter_forward_cnt_[adapter_name] = 0
+            self._adapter_backward_cnt_[adapter_name] = 0
+            self._adapter_accumulation_step_[adapter_name] = accumulation_step
+
+    def update_backward_cnt(self, adapter_name: str):
+        self._adapter_backward_cnt_[adapter_name] += 1
+        if self._adapter_backward_cnt_[adapter_name] == self._adapter_accumulation_step_[adapter_name]:
+            self._adapter_forward_cnt_[adapter_name] = 0
+            self._adapter_backward_cnt_[adapter_name] = 0
+
+    def update_forward_cnt(self, adapter_name: str):
+        self._adapter_forward_cnt_[adapter_name] += 1
 
     def __check_adapter_available(self, adapter_name: str) -> bool:
-        if adapter_name in self._adapter_lock_:
-            return self._adapter_lock_[adapter_name]
-        return True
+        return self._adapter_forward_cnt_[adapter_name] < self._adapter_accumulation_step_[adapter_name]
 
     def rigister_strategies(self):
         self.rigister_strategy("pipe", self.pipe_dispatch_strategy)
@@ -36,7 +45,7 @@ def pipe_dispatch_strategy(self) -> Dict[str, List[TrainData]]:
             # check the adapter is available
             if not self.__check_adapter_available(task.adapter_name_):
                 continue
-            self.deactivate_adapter(task.adapter_name_)
+            self.update_forward_cnt(task.adapter_name_)
             ret_train_data[task.adapter_name_] = task.get_train_data()
             cnt += 1
             if cnt >= self.train_lora_simultaneously_num_:
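The dispatcher change above replaces the per-adapter boolean lock with three counters: an adapter keeps receiving micro-batches until it has been dispatched accumulation_step of them (batch_size_ / micro_batch_size_), and both counters reset only after the same number of backward passes has completed. A minimal standalone sketch of that gating logic, with made-up values batch_size = 16 and micro_batch_size = 4 (so accumulation_step = 4), could look like this; integer division is used here for clarity, while the diff computes the step with /:

# Standalone sketch of the micro-batch gating above (hypothetical example, not the project's classes).
batch_size, micro_batch_size = 16, 4           # assumed example values
accumulation_step = batch_size // micro_batch_size

forward_cnt, backward_cnt = 0, 0

def can_dispatch() -> bool:
    # mirrors __check_adapter_available: stop feeding new micro-batches
    # once a full accumulation window is already in flight
    return forward_cnt < accumulation_step

def on_forward():
    global forward_cnt
    forward_cnt += 1                           # mirrors update_forward_cnt

def on_backward():
    global forward_cnt, backward_cnt
    backward_cnt += 1                          # mirrors update_backward_cnt
    if backward_cnt == accumulation_step:      # window complete: reset both counters
        forward_cnt = 0
        backward_cnt = 0

for _ in range(accumulation_step):
    assert can_dispatch()
    on_forward()
assert not can_dispatch()                      # adapter is locked until backward catches up
for _ in range(accumulation_step):
    on_backward()
assert can_dispatch()                          # unlocked again after a full optimizer step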
mlora/pipeline/pipe.py (27 additions, 11 deletions)
@@ -1,6 +1,7 @@
+from mlora.pipeline.queue import DeviceSwapQueue
 from mlora.pipeline.transport import RpcTransport
 from mlora.pipeline.stream import CudaStream
-from mlora.pipeline.messages import PipeMessageType
+from mlora.pipeline.messages import PipeMessage, PipeMessageType
 from mlora.pipeline.function import RecvOperator, SendOperator
 from mlora.model.model import LLMModel, precompute_mask
 from mlora.model.modelargs import LoraBatchDataConfig, MultiLoraBatchData
@@ -44,6 +45,7 @@ class Pipe():
     config_: MLoRAConfig = None
 
     multi_trainer_context_: MultiTrainerContext = None
+    input_queue_: DeviceSwapQueue = None
 
     def is_stop_signal(self, data: torch.tensor) -> bool:
         return data.dtype == torch.long and torch.numel(data) == 1
@@ -55,7 +57,7 @@ def __init__(self,
                  device: torch.device,
                  rank: int,
                  balance: List[int]) -> None:
-        self.world_size_ = torch.cuda.device_count()
+        self.world_size_ = len(balance)
         assert self.world_size_ == len(balance)
 
         self.rank_ = rank
@@ -64,6 +66,8 @@ def __init__(self,
 
         if rank == 0:
             self.role_ = WorkerRole.HEAD
+            self.input_queue_ = DeviceSwapQueue(torch.device('cpu'), device, 4, 'input_data_queue')
+            self.input_queue_.start()
         elif rank == self.world_size_ - 1:
             self.role_ = WorkerRole.TAIL
         else:
@@ -114,23 +118,34 @@ def stop(self):
         if isinstance(transport, RpcTransport):
             transport.stop()
             logging.info("Transport stop.")
+        if self.input_queue_:
+            self.input_queue_.stop()
 
     def process_input(self):
-        assert self.role_ == WorkerRole.HEAD
-        assert not self.input_stop_
-
-        if not self.dispatcher_.check_task_done():
+        def put_train_data():
             train_input = self.dispatcher_.get_train_data()
             if not train_input:
                 # avoid the busy loop
                 time.sleep(1 / 10000000)
                 return
             for lora_config in train_input.lora_batch_data_config_:
                 logging.info(f'load lora: {lora_config.adapter_name_}')
-            tokens = torch.tensor(train_input.batch_tokens_,
-                                  dtype=torch.int64,
-                                  device=self.device_)
-            data = self.forward(tokens, train_input)
+            data = torch.tensor(train_input.batch_tokens_, dtype=torch.int64, device="cpu")
+            msg = PipeMessage(self.device_, self.device_, PipeMessageType.ACTIVATIONS,
+                              0, data, train_input)
+            self.input_queue_.put(msg)
+
+        assert self.role_ == WorkerRole.HEAD
+        assert not self.input_stop_
+
+        if not self.dispatcher_.check_task_done():
+            put_train_data()
+            # fetch train data
+            msg = self.input_queue_.get_nowait()
+            if not msg:
+                return
+            train_input = msg.batch_data_
+            data = self.forward(msg.tensor_data_, msg.batch_data_)
             self.forward_cnt_ += 1
         else:
             # stop
@@ -204,6 +219,7 @@ def process_forward(self):
         if not self.forward_stop_ and not self.is_stop_signal(message.tensor_data_):
             lora_configs = message.batch_data_.lora_batch_data_config_
             total_loss = self.multi_trainer_context_.calc_loss(message.batch_data_, data)
+            message.batch_data_.batch_tokens_ = None  # backward doesn't need to save batch_tokens
             total_loss.backward()
 
             self.trainer_step(lora_configs)
@@ -216,7 +232,7 @@ def trainer_step(self, lora_configs: List[LoraBatchDataConfig]):
             if self.multi_trainer_context_.is_save_step(adapter_name):
                 self.save_model(adapter_name, f"{step_cnt}")
             if self.role_ == WorkerRole.HEAD:
-                self.dispatcher_.activate_adapter(adapter_name)
+                self.dispatcher_.update_backward_cnt(adapter_name)
 
     def save_all_model(self):
         for adapter_name in self.multi_trainer_context_.trainer_context_:
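In pipe.py, the head worker now stages each fetched batch as a CPU tensor in a DeviceSwapQueue (depth 4) and drains it with a non-blocking get before running forward, instead of building the tensor on the device and forwarding immediately. A rough stand-in for that producer/consumer shape, using a plain queue.Queue in place of the project's DeviceSwapQueue (so the background CPU-to-GPU swap is omitted) and an invented toy batch, might look like:

# Rough stand-in for the head worker's staging flow (plain queue.Queue, not DeviceSwapQueue).
import queue

import torch

input_queue = queue.Queue(maxsize=4)           # depth 4, as in the diff

def put_train_data(batch_tokens):
    # stage the batch on CPU; DeviceSwapQueue would move it toward the GPU in the background
    data = torch.tensor(batch_tokens, dtype=torch.int64, device="cpu")
    input_queue.put(data)

def process_input(model_forward):
    put_train_data([[1, 2, 3, 4]])             # invented example batch
    try:
        data = input_queue.get_nowait()        # non-blocking, like input_queue_.get_nowait()
    except queue.Empty:
        return None                            # nothing staged yet; try again on the next tick
    return model_forward(data)

out = process_input(lambda tokens: tokens.float().mean())
print(out)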
pyproject.toml (2 additions, 2 deletions)
@@ -14,7 +14,7 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "torch==2.0.1",
+    "torch==2.2.2",
     "einops==0.6.1",
     "datasets==2.14.5",
     "accelerate==0.24.1",
@@ -23,7 +23,7 @@ dependencies = [
     "sentencepiece==0.1.99",
     "scipy==1.10.1",
     "protobuf==3.20.2",
-    "xformers==0.0.21"
+    "xformers==0.0.25"
 ]
 
 [project.urls]
requirements.txt (2 additions, 2 deletions)
@@ -1,12 +1,12 @@
-torch==2.0.1
+torch==2.2.2
 einops==0.6.1
 datasets==2.14.5
 accelerate==0.24.1
 transformers==4.38.2
 bitsandbytes==0.41.1
 sentencepiece==0.1.99
 scipy==1.10.1
-xformers==0.0.21
+xformers==0.0.25
 flask
 peft==0.10.0
 protobuf==3.20.2