64 changes: 20 additions & 44 deletions fastdeploy/cache_manager/cache_transfer_manager.py
@@ -28,27 +28,19 @@

from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.cache_manager.ops import (
cuda_host_alloc,
cuda_host_free,
memory_allocated,
set_data_ipc,
set_device,
share_external_data_,
swap_cache_all_layers,
unset_data_ipc,
)
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
cuda_host_free,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
unset_data_ipc,
)
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
cuda_host_alloc,
cuda_host_free,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger


@@ -194,10 +186,7 @@ def __init__(self, args):
suffix=args.engine_worker_queue_port,
create=False,
)

# TODO XPU support RL
if not current_platform.is_xpu():
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()

def _init_gpu_cache(self, args):

@@ -208,10 +197,7 @@ def _init_gpu_cache(self, args):
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")

logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
if current_platform.is_cuda():
paddle.set_device(f"gpu:{self.device}")
elif current_platform.is_xpu():
paddle.set_device(f"xpu:{self.device}")
set_device(self.device)
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
@@ -228,12 +214,8 @@ def _init_gpu_cache(self, args):
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
if current_platform.is_xpu():
key_cache = share_external_data(key_cache, key_name, cache_shape, True)
val_cache = share_external_data(val_cache, val_name, cache_shape, True)
else:
key_cache = share_external_data(key_cache, key_name, cache_shape)
val_cache = share_external_data(val_cache, val_name, cache_shape)
key_cache = share_external_data_(key_cache, key_name, cache_shape, True)
val_cache = share_external_data_(val_cache, val_name, cache_shape, True)

self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[val_name] = val_cache
@@ -247,10 +229,7 @@ def _init_gpu_cache(self, args):
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
if current_platform.is_cuda():
logger.info(
f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")

def _init_cpu_cache(self, args):
if args.num_cpu_blocks == 0:
@@ -513,6 +492,9 @@ def _transfer_data(
)

def clear_or_update_caches(self, args):
# TODO XPU support RL
if unset_data_ipc is None:
return
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
kv_cache_status = np.zeros([1], dtype=np.int32)
@@ -544,10 +526,7 @@ def clear_or_update_caches(self, args):
time.sleep(0.1)

# clear gpu caches
if current_platform.is_cuda():
paddle.set_device(f"gpu:{self.device}")
elif current_platform.is_xpu():
paddle.set_device(f"xpu:{self.device}")
set_device(self.device)
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
self.gpu_cache_kvs.clear()
@@ -617,8 +596,5 @@ def main():
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log")
if current_platform.is_cuda():
paddle.set_device(f"gpu:{args.device_id}")
elif current_platform.is_xpu():
paddle.set_device(f"xpu:{args.device_id}")
set_device(args.device_id)
main()
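For context, the refactor above replaces per-platform branches with ops imported from fastdeploy.cache_manager.ops; ops a platform lacks are exported as None and guarded at the call site, as clear_or_update_caches now does. A minimal, self-contained sketch of that pattern, using illustrative names rather than the real FastDeploy ops:

def load_ops(platform):
    # Illustrative stand-in for the ops module: unsupported ops map to None.
    if platform == "cuda":
        return {"unset_data_ipc": lambda tensor, name: print(f"unset {name}")}
    if platform == "xpu":
        return {"unset_data_ipc": None}  # not implemented on XPU yet
    raise RuntimeError("Unsupported platform")


def clear_caches(ops, caches):
    unset = ops["unset_data_ipc"]
    if unset is None:  # same early-return guard as clear_or_update_caches above
        return
    for name, tensor in caches.items():
        unset(tensor, name)


clear_caches(load_ops("xpu"), {"key_caches_0": object()})   # silently skipped
clear_caches(load_ops("cuda"), {"key_caches_0": object()})  # prints "unset key_caches_0"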
60 changes: 60 additions & 0 deletions fastdeploy/cache_manager/ops.py
@@ -0,0 +1,60 @@
import paddle

from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
cuda_host_free,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
unset_data_ipc,
)

memory_allocated = paddle.device.cuda.memory_allocated
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
cuda_host_alloc,
cuda_host_free,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
)

unset_data_ipc = None
memory_allocated = paddle.device.xpu.memory_allocated

else:
raise RuntimeError("Prefix cache ops only supported CUDA nor XPU platform ")


def set_device(device):
if current_platform.is_cuda():
paddle.set_device(f"gpu:{device}")
elif current_platform.is_xpu():
paddle.set_device(f"xpu:{device}")
else:
raise RuntimeError("No supported platform")


def share_external_data_(cache, cache_name, cache_shape, use_ipc):
if current_platform.is_cuda():
cache = share_external_data(cache, cache_name, cache_shape)
elif current_platform.is_xpu():
cache = share_external_data(cache, cache_name, cache_shape, use_ipc)
else:
raise RuntimeError("No supported platform")
return cache


__all__ = [
"cuda_host_alloc",
"cuda_host_free",
"set_data_ipc",
"share_external_data_",
"swap_cache_all_layers",
"unset_data_ipc", # XPU是 None
"set_device",
"memory_allocated",
]
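A rough usage sketch of this module, assuming FastDeploy and its compiled ops are importable; the tensor name and cache shape below are illustrative values, not something this PR defines:

import paddle

from fastdeploy.cache_manager.ops import (
    memory_allocated,
    set_device,
    share_external_data_,
    unset_data_ipc,
)

set_device(0)  # resolves to "gpu:0" on CUDA or "xpu:0" on XPU

# Attach an externally shared KV-cache tensor. The trailing flag is only
# forwarded on XPU; the CUDA path ignores it inside share_external_data_.
key_cache = paddle.empty(shape=[], dtype="bfloat16")
key_cache = share_external_data_(key_cache, "key_caches_0_rank0.device0", [1024, 8, 64, 128], True)

print("device memory allocated:", memory_allocated())

if unset_data_ipc is not None:  # exported as None on XPU
    unset_data_ipc(key_cache, "key_caches_0_rank0.device0", True, False)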
22 changes: 12 additions & 10 deletions fastdeploy/worker/xpu_worker.py
@@ -56,6 +56,13 @@ def init_device(self):
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"xpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
self.device_id = int(self.device_ids[self.local_rank % self.max_chips_per_node])
assert (
self.device_id is not None
), f"device_id is none for rank {self.local_rank % self.max_chips_per_node}"
assert len(self.device_ids) > (
self.local_rank % self.max_chips_per_node
), f"device number must be greater than local rank, but get device number is {len(self.device_ids)}, rank is {self.local_rank % self.max_chips_per_node}"
paddle.set_default_dtype(self.model_config.dtype)

gc.collect()
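The new device_id bookkeeping indexes the comma-separated device list by the worker's local rank modulo the chips per node. A small illustration of that indexing with assumed values (the real ones come from parallel_config):

device_ids = "4,5,6,7".split(",")  # parallel_config.device_ids
max_chips_per_node = 8
local_rank = 2

idx = local_rank % max_chips_per_node
assert len(device_ids) > idx, (
    f"device number must be greater than local rank, "
    f"but get device number is {len(device_ids)}, rank is {idx}"
)
device_id = int(device_ids[idx])  # -> 6, later passed to the model runner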
@@ -69,7 +76,7 @@ def init_device(self):
fd_config=self.fd_config,
device=self.device,
rank=self.rank,
device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]),
device_id=self.device_id,
local_rank=self.local_rank,
)

@@ -98,14 +105,9 @@ def determine_available_memory(self) -> int:
xpu_get_used_global_memory,
)

assert self.device_ids[self.local_rank] is not None, f"device_id is none for rank {self.local_rank}"
assert (
len(self.device_ids) > self.local_rank
), f"device number must be greater than local rank, but get device number is {len(self.device_ids)}, rank is {self.local_rank}"

total_memory = xpu_get_total_global_memory(int(self.device_ids[self.local_rank]))
used_memory = xpu_get_used_global_memory(int(self.device_ids[self.local_rank]))
free_memory = xpu_get_free_global_memory(int(self.device_ids[self.local_rank]))
total_memory = xpu_get_total_global_memory(self.device_id)
used_memory = xpu_get_used_global_memory(self.device_id)
free_memory = xpu_get_free_global_memory(self.device_id)

logger.info(
f"Before warm up, total_memory: {total_memory}, \
@@ -119,7 +121,7 @@ def determine_available_memory(self) -> int:
set_random_seed(self.fd_config.model_config.seed)

total_available_memory = int(total_memory * self.cache_config.gpu_memory_utilization)
used_memory = xpu_get_used_global_memory(int(self.device_ids[self.local_rank]))
used_memory = xpu_get_used_global_memory(self.device_id)
available_kv_cache_memory = total_available_memory - used_memory
model_block_memory_used = self.cal_theortical_kvcache()
available_kv_cache_memory += model_block_memory_used * self.cache_config.total_block_num
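The available-KV-cache computation itself is unchanged; only the device index now comes from self.device_id. A worked example of the arithmetic with made-up byte counts, for clarity:

# Illustrative numbers only; the real values come from the xpu_get_*_global_memory ops.
total_memory = 64 * 1024**3            # reported device memory, 64 GiB
gpu_memory_utilization = 0.9           # cache_config.gpu_memory_utilization
used_memory = 30 * 1024**3             # memory in use after warm-up
model_block_memory_used = 2 * 1024**2  # bytes per KV block (cal_theortical_kvcache)
total_block_num = 4096                 # blocks counted by cache_config

total_available_memory = int(total_memory * gpu_memory_utilization)
available_kv_cache_memory = total_available_memory - used_memory
# Memory already held by existing cache blocks is added back, as in the worker code above.
available_kv_cache_memory += model_block_memory_used * total_block_num

print(available_kv_cache_memory / 1024**3)  # ~35.6 GiB available for the KV cache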