Skip to content

Commit

Permalink
optim: async data fetch and speedup backward data transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
Vinkle-hzt committed Apr 7, 2024
1 parent 87a6961 commit 8c103b9
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions mlora/pipeline/pipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mlora.pipeline.queue import DeviceSwapQueue
from mlora.pipeline.transport import RpcTransport
from mlora.pipeline.stream import CudaStream
from mlora.pipeline.messages import PipeMessageType
from mlora.pipeline.messages import PipeMessage, PipeMessageType
from mlora.pipeline.function import RecvOperator, SendOperator
from mlora.model.model import LLMModel, precompute_mask
from mlora.model.modelargs import LoraBatchDataConfig, MultiLoraBatchData
Expand Down Expand Up @@ -44,6 +45,7 @@ class Pipe():
config_: MLoRAConfig = None

multi_trainer_context_: MultiTrainerContext = None
input_queue_: DeviceSwapQueue = None

def is_stop_signal(self, data: torch.Tensor) -> bool:
    """Return True if *data* is the pipeline's stop sentinel.

    The stop signal is encoded as a single-element ``torch.long`` tensor,
    which distinguishes it from real activation payloads. Note the
    annotation is ``torch.Tensor`` (the class); ``torch.tensor`` is the
    factory function and is not a valid type.
    """
    # torch.numel counts elements across all dimensions, so any
    # one-element long tensor (whatever its shape) counts as the sentinel.
    return data.dtype == torch.long and torch.numel(data) == 1
Expand All @@ -55,7 +57,7 @@ def __init__(self,
device: torch.device,
rank: int,
balance: List[int]) -> None:
self.world_size_ = torch.cuda.device_count()
self.world_size_ = len(balance)
assert self.world_size_ == len(balance)

self.rank_ = rank
Expand All @@ -64,6 +66,8 @@ def __init__(self,

if rank == 0:
self.role_ = WorkerRole.HEAD
self.input_queue_ = DeviceSwapQueue(torch.device('cpu'), device, 4, 'input_data_queue')
self.input_queue_.start()
elif rank == self.world_size_ - 1:
self.role_ = WorkerRole.TAIL
else:
Expand Down Expand Up @@ -114,23 +118,34 @@ def stop(self):
if isinstance(transport, RpcTransport):
transport.stop()
logging.info("Transport stop.")
if self.input_queue_:
self.input_queue_.stop()

def process_input(self):
assert self.role_ == WorkerRole.HEAD
assert not self.input_stop_

if not self.dispatcher_.check_task_done():
def put_train_data():
train_input = self.dispatcher_.get_train_data()
if not train_input:
# avoid the busy loop
time.sleep(1 / 10000000)
return
for lora_config in train_input.lora_batch_data_config_:
logging.info(f'load lora: {lora_config.adapter_name_}')
tokens = torch.tensor(train_input.batch_tokens_,
dtype=torch.int64,
device=self.device_)
data = self.forward(tokens, train_input)
data = torch.tensor(train_input.batch_tokens_, dtype=torch.int64, device="cpu")
msg = PipeMessage(self.device_, self.device_, PipeMessageType.ACTIVATIONS,
0, data, train_input)
self.input_queue_.put(msg)

assert self.role_ == WorkerRole.HEAD
assert not self.input_stop_

if not self.dispatcher_.check_task_done():
put_train_data()
# fetch train data
msg = self.input_queue_.get_nowait()
if not msg:
return
train_input = msg.batch_data_
data = self.forward(msg.tensor_data_, msg.batch_data_)
self.forward_cnt_ += 1
else:
# stop
Expand Down Expand Up @@ -204,6 +219,7 @@ def process_forward(self):
if not self.forward_stop_ and not self.is_stop_signal(message.tensor_data_):
lora_configs = message.batch_data_.lora_batch_data_config_
total_loss = self.multi_trainer_context_.calc_loss(message.batch_data_, data)
message.batch_data_.batch_tokens_ = None # backward doesn't need to save batch_tokens
total_loss.backward()

self.trainer_step(lora_configs)
Expand Down

0 comments on commit 8c103b9

Please sign in to comment.