diff --git a/.github/workflows/cpp-linter.yml b/.github/workflows/cpp-linter.yml new file mode 100644 index 00000000..e013a62c --- /dev/null +++ b/.github/workflows/cpp-linter.yml @@ -0,0 +1,34 @@ +name: cpp-linter + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "dev*", "main", "*release" ] + + +jobs: + cpp-linter: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - uses: cpp-linter/cpp-linter-action@main + id: linter + continue-on-error: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + style: file + tidy-checks: '-*' + files-changed-only: true + lines-changed-only: diff + format-review: true + thread-comments: ${{ github.event_name == 'pull_request' && 'update' }} + + - name: Fail fast?! + if: steps.linter.outputs.checks-failed != 0 + run: | + echo "some linter checks failed. ${{ steps.linter.outputs.checks-failed }}" + exit 1 diff --git a/docs/source/index.md b/docs/source/index.md index 2352d399..69be815e 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -57,6 +57,7 @@ getting-started/installation_npu user-guide/prefix-cache/index user-guide/sparse-attention/index user-guide/pd-disaggregation/index +user-guide/metrics/metrics ::: :::{toctree} diff --git a/docs/source/user-guide/metrics/metrics.md b/docs/source/user-guide/metrics/metrics.md new file mode 100644 index 00000000..22b53268 --- /dev/null +++ b/docs/source/user-guide/metrics/metrics.md @@ -0,0 +1,193 @@ +# Observability + +UCM (Unified Cache Management) provides detailed metrics monitoring through Prometheus endpoints, allowing in-depth monitoring of cache performance and behavior. This document describes how to enable and configure observability from the embedded vLLM `/metrics` API endpoint. + +--- + +## Quick Start Guide + +### 1) On UCM Side + +First, set the `PROMETHEUS_MULTIPROC_DIR` environment variable. + +```bash +export PROMETHEUS_MULTIPROC_DIR=/vllm-workspace +``` + +Then, start the UCM service. + +```bash +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-14B-Instruct \ + --max-model-len 5000 \ + --tensor-parallel-size 1 \ + --gpu_memory_utilization 0.87 \ + --trust-remote-code \ + --disable-log-requests \ + --no-enable-prefix-caching \ + --enforce-eager \ + --max-num-batched-tokens 40000 \ + --max-num-seqs 10 \ + --host 0.0.0.0 \ + --port 8000 \ + --kv-transfer-config \ + '{ + "kv_connector": "UCMConnector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "UCM_CONFIG_FILE": "/vllm-workspace/unified-cache-management/examples/ucm_config.yaml" + } + }' +``` +**Note**: You can refer to the `ucm_config.yaml` file at https://github.com/ModelEngine-Group/unified-cache-management/tree/develop/examples to configure the `metrics_config_path` parameter. + +You can use the `vllm bench serve` command to run benchmarks: + +```bash +vllm bench serve \ + --backend vllm \ + --model /home/models/Qwen2.5-14B-Instruct \ + --host 127.0.0.1 \ + --port 8000 \ + --dataset-name random \ + --num-prompts 20 \ + --random-input-len 200 \ + --random-output-len 10 \ + --request-rate 1 \ + --ignore-eos +``` + +Once the HTTP server is running, you can access the UCM metrics at the `/metrics` endpoint. + +```bash +curl http://$:8000/metrics | grep ucm: +``` + +You will also find some `.db` files in the `$PROMETHEUS_MULTIPROC_DIR` directory, which are temporary files used by Prometheus. 
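+
+As a quick sanity check, each UCM metric is exposed in the standard Prometheus
+exposition format. The output should look roughly like the following sketch
+(the metric name comes from the table later in this document; bucket boundaries
+and values are illustrative and depend on your `metrics_configs.yaml` and workload):
+
+```text
+# TYPE ucm:load_duration histogram
+ucm:load_duration_bucket{le="10.0"} 3
+ucm:load_duration_bucket{le="100.0"} 17
+ucm:load_duration_bucket{le="+Inf"} 20
+ucm:load_duration_sum 1234.5
+ucm:load_duration_count 20
+```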
+ +### 2) Start Prometheus and Grafana with Docker Compose + +#### Create Docker Compose Configuration Files + +First, create the `docker-compose.yaml` file: + +```yaml +# docker-compose.yaml +version: "3" + +services: + prometheus: + image: prom/prometheus:latest + extra_hosts: + - "host.docker.internal:host-gateway" + ports: + - "9090:9090" + volumes: + - ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml + + grafana: + image: grafana/grafana:latest + depends_on: + - prometheus + ports: + - "3000:3000" +``` + +Then, create the `prometheus.yaml` configuration file: + +```yaml +# prometheus.yaml +global: + scrape_interval: 5s + evaluation_interval: 30s + +scrape_configs: + - job_name: vllm + static_configs: + - targets: + - 'host.docker.internal:8000' +``` + +**Note**: Make sure the port number in `prometheus.yaml` matches the port number used when starting the vLLM service. + +#### Start Services + +Run the following command in the directory containing `docker-compose.yaml` and `prometheus.yaml`: + +```bash +docker compose up +``` + +This will start Prometheus and Grafana services. + +### 3) Configure Grafana Dashboard + +#### Access Grafana + +Navigate to `http://:3000`. Log in with the default username (`admin`) and password (`admin`). You will be prompted to change the password on first login. + +#### Add Prometheus Data Source + +1. Navigate to `http://:3000/connections/datasources/new` and select **Prometheus**. + +2. On the Prometheus configuration page, add the Prometheus server URL in the **Connection** section. For this Docker Compose setup, Grafana and Prometheus run in separate containers, but Docker creates DNS names for each container. You can directly use `http://prometheus:9090`. + +3. Click **Save & Test**. You should see a green checkmark showing "Successfully queried the Prometheus API." + +#### Import Dashboard + +1. Navigate to `http://:3000/dashboard/import`. + +2. Click **Upload JSON file**, then upload the `unified-cache-management/examples/metrics/grafana.json` file. + +3. Select the Prometheus data source configured earlier. + +4. Click **Import** to complete the import. + +You should now be able to see the UCM monitoring dashboard with real-time visualization of all 9 metrics. + +## Available Metrics + +UCM exposes various metrics to monitor its performance. 
The following table lists all available metrics organized by category: + +| Metric Name | Type | Description | +|------------|------|-------------| +| **Load Operation Metrics** | | | +| `ucm:load_requests_num` | Histogram | Number of requests loaded per `start_load_kv` call | +| `ucm:load_blocks_num` | Histogram | Number of blocks loaded per `start_load_kv` call | +| `ucm:load_duration` | Histogram | Time to load KV cache from UCM (milliseconds) | +| `ucm:load_speed` | Histogram | Speed of loading from UCM (GB/s) | +| **Save Operation Metrics** | | | +| `ucm:save_requests_num` | Histogram | Number of requests saved per `wait_for_save` call | +| `ucm:save_blocks_num` | Histogram | Number of blocks saved per `wait_for_save` call | +| `ucm:save_duration` | Histogram | Time to save to UCM (milliseconds) | +| `ucm:save_speed` | Histogram | Speed of saving to UCM (GB/s) | +| **Lookup Hit Rate Metrics** | | | +| `ucm:interval_lookup_hit_rates` | Histogram | Hit rate of UCM lookup requests | + +## Prometheus Configuration + +Metrics configuration is defined in the `unified-cache-management/examples/metrics/metrics_configs.yaml` file: + +```yaml +log_interval: 5 # Interval in seconds for logging metrics + +prometheus: + multiproc_dir: "/vllm-workspace" # Prometheus directory + metric_prefix: "ucm:" # Metric name prefix + + enabled_metrics: + counters: true + gauges: true + histograms: true + + histograms: + - name: "load_requests_num" + documentation: "Number of requests loaded from ucm" + buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000] + # ... other metric configurations +``` + +--- + diff --git a/docs/source/user-guide/prefix-cache/dram_store.md b/docs/source/user-guide/prefix-cache/dram_store.md deleted file mode 100644 index 157e098e..00000000 --- a/docs/source/user-guide/prefix-cache/dram_store.md +++ /dev/null @@ -1,132 +0,0 @@ -# DRAM Store - -This document provides a usage example and configuration guide for the **DRAM Connector**. This connector enables offloading of KV cache from GPU HBM to CPU DRAM, helping reduce memory pressure and supporting larger models or batch sizes. - -## Performance - -### Overview -The following are the multi-concurrency performance test results of UCM in the Prefix Cache scenario under a CUDA environment, showing the performance improvements of UCM on two different models. -During the tests, HBM cache was disabled, and KV Cache was retrieved and matched only from DRAM. - -In the QwQ-32B model, the test used one H20 server with 2 GPUs. - -Here, Full Compute refers to pure VLLM inference, while DRAM80% indicates that after UCM pooling, the DRAM hit rate of the KV cache is 80%. - -The following table shows the results on the QwQ-32B model: -| **QwQ-32B** | | | | | -| ---------------: | -------------: | ------------------: | -------------: | :----------- | -| **Input length** | **Concurrent** | **Full Compute(s)** | **DRAM80%(s)** | **Speedup** | -| 4 000 | 1 | 1.0269 | 0.3102 | **+230.9 %** | -| 8 000 | 1 | 2.0902 | 0.5718 | **+265.5 %** | -| 16 000 | 1 | 4.4852 | 1.1914 | **+276.4 %** | -| 4 000 | 2 | 1.5383 | 0.4209 | **+265.4 %** | -| 8 000 | 2 | 3.1323 | 0.8231 | **+280.5 %** | -| 16 000 | 2 | 6.7984 | 1.7420 | **+290.2 %** | -| 4 000 | 4 | 2.8173 | 0.9444 | **+198.2 %** | -| 8 000 | 4 | 5.2643 | 1.8290 | **+187.8 %** | -| 16 000 | 4 | 11.3651 | 3.6706 | **+209.6 %** | -## Features - -The DRAM connector supports the following functionalities: - -- `dump`: Offload KV cache blocks from HBM to DRAM. -- `load`: Load KV cache blocks from DRAM back to HBM. 
-- `lookup`: Look up KV blocks stored in DRAM by block hash. -- `wait`: Ensure that all copy streams between CPU and GPU have completed. -- `commit`: Mark cache operations as complete and ready for reuse. - -## Configuration - -To use the DRAM connector, you need to configure the `connector_config` dictionary in your model's launch configuration. - -### Required Parameters - -- `max_cache_size` *(optional)*: - Specifies the maximum allowed DRAM memory usage (in **bytes**) for caching in `kv_connector_extra_config["ucm_connector_config"]`. - If not provided, it defaults to **5 GB**. -- `kv_block_size` *(optional)*: - Specifies the memory size (in **bytes**) of a single key or value cache block used in vLLM’s paged attention mechanism, which is calculated as : `block_size * head_size * total_num_kv_heads * element_size`. - -### Example: - -Create a config yaml like following and save it to your own directory: -```yaml -# UCM Configuration File Example -# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details -ucm_connector_name: "UcmDramStore" - -ucm_connector_config: - max_cache_size: 5368709120 - kv_block_size: 262144 -``` - -## Launching Inference - -### Offline Inference - -To start **offline inference** with the DRAM connector,modify the script `examples/offline_inference.py` to include the `kv_connector_extra_config` for DRAM connector usage: - -```python -# In examples/offline_inference.py -ktc = KVTransferConfig( - ... - kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} -) -``` - -Then run the script as follows: - -```bash -cd examples/ -python offline_inference.py -``` - -### Online Inference - -For **online inference** , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol. - -First, specify the python hash seed by: -```bash -export PYTHONHASHSEED=123456 -``` - -Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model: - -```bash -vllm serve /home/models/Qwen2.5-14B-Instruct \ ---max-model-len 20000 \ ---tensor-parallel-size 2 \ ---gpu_memory_utilization 0.87 \ ---trust-remote-code \ ---port 7800 \ ---kv-transfer-config \ -'{ - "kv_connector": "UnifiedCacheConnectorV1", - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", - "kv_role": "kv_both", - "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} -}' -``` - -If you see log as below: - -```bash -INFO: Started server process [32890] -INFO: Waiting for application startup. -INFO: Application startup complete. -``` - -Congratulations, you have successfully started the vLLM server with DRAM Connector! - -After successfully started the vLLM server,You can interact with the API as following: - -```bash -curl http://localhost:7800/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "/home/models/Qwen2.5-14B-Instruct", - "prompt": "Shanghai is a", - "max_tokens": 7, - "temperature": 0 - }' -``` diff --git a/docs/source/user-guide/prefix-cache/index.md b/docs/source/user-guide/prefix-cache/index.md index defe27d3..ba3d16be 100644 --- a/docs/source/user-guide/prefix-cache/index.md +++ b/docs/source/user-guide/prefix-cache/index.md @@ -79,6 +79,5 @@ performance. 
:::{toctree} :maxdepth: 1 -dram_store nfs_store ::: \ No newline at end of file diff --git a/setup.py b/setup.py index 8c462dab..5a617c74 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ def _get_package_data_with_so(): setup( name="uc-manager", - version="0.1.0rc3", + version="0.1.0rc4", description="Unified Cache Management", author="Unified Cache Team", packages=find_packages(), diff --git a/test/common/capture_utils.py b/test/common/capture_utils.py index ee12ed2a..b12b7663 100644 --- a/test/common/capture_utils.py +++ b/test/common/capture_utils.py @@ -1,3 +1,4 @@ +import functools from typing import Any, Dict, List from common.db_utils import write_to_db @@ -44,6 +45,7 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]: # ---------------- decorator ---------------- def export_vars(func): + @functools.wraps(func) def wrapper(*args, **kwargs): result = func(*args, **kwargs) # If the function returns a dict containing '_data' or 'data', post-process it diff --git a/test/common/llmperf/__init__.py b/test/common/llmperf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/common/llmperf/run_inference.py b/test/common/llmperf/run_inference.py new file mode 100644 index 00000000..b04deb1e --- /dev/null +++ b/test/common/llmperf/run_inference.py @@ -0,0 +1,185 @@ +import json +import os +import random +from pathlib import Path +from typing import Any, Dict, List + +import yaml +from common.llmperf.utils.token_benchmark import run_token_benchmark +from common.llmperf.utils.utils import reset_prefill_cache + + +def run_test_cases( + llm_api, + model, + timeout, + max_num_completed_requests, + concurrent_requests, + mean_input_tokens, + stddev_input, + mean_output_tokens, + stddev_output, + additional_sampling_params, + timestamp_dir, + server_url, + tokenizer_path, + hit_rate, +): + print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed") + all_summaries = [] + failed_case = [] + + # Clear proxy environment variables + env = os.environ.copy() + env.pop("http_proxy", None) + env.pop("https_proxy", None) + + for i, ( + mean_input, + mean_output, + max_completed, + concurrent, + additional_sampling_params, + hit_rate_val, + ) in enumerate( + zip( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, + ), + start=1, + ): + # for i, case in enumerate(mean_input_tokens): + print(f"\n>>> Executing test case {i} <<<") + reset_prefill_cache(env, server_url) + # Use a fixed random_seed for each test to control PC hit_rate + random_seed = random.randint(1, 100000) + + try: + # Determine if two runs are needed (PC hit_rate test) + if hit_rate_val == 0: + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + else: + print( + f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate_val} %" + ) + # hit_rate > 0: first prefill mode + prefill_mean_input = int(mean_input * hit_rate_val / 100) + print( + f"[INFO] Prefill execution: 
mean_input_tokens={prefill_mean_input}" + ) + run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=prefill_mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=2, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "prefill"}, + ) + reset_prefill_cache(env, server_url) + # Then run normal mode + print("[INFO] Prefill completed, switching to normal mode execution") + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + all_summaries.append(summary) + except Exception as e: + print(f"[Warning] {e}") + failed_case.append(i) + + return all_summaries, failed_case + + +def inference_results( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, +): + config_file = Path(__file__).parent.parent.parent / "config.yaml" + print("[INFO] Initialization complete, starting main process") + print(f"[INFO] Reading configuration file: {config_file}") + with open(config_file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + llm_api = config.get("llm_connection", {}).get("llm_api", "openai") + model = config.get("llm_connection", {}).get("model", "") + test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000) + stddev_input_tokens = config.get("llm_connection", {}).get( + "stddev_input_tokens", 0 + ) + stddev_output_tokens = config.get("llm_connection", {}).get( + "stddev_output_tokens", 0 + ) + timestamp_dir = Path("results") + timestamp_dir.mkdir(parents=True, exist_ok=True) + server_url = config.get("llm_connection", {}).get("server_url", "") + tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "") + print(f"[INFO] Created results directory: {timestamp_dir}") + + all_summaries, failed_cases = run_test_cases( + llm_api, + model, + test_timeout_s, + max_num_completed_requests, + concurrent_requests, + mean_input_tokens, + stddev_input_tokens, + mean_output_tokens, + stddev_output_tokens, + additional_sampling_params, + timestamp_dir, + server_url, + tokenizer_path, + hit_rate, + ) + total = len(mean_input_tokens) + print( + f"\n[INFO] All tests completed! 
Success: {total - len(failed_cases)}/{total}"
+    )
+    if failed_cases:
+        print(f"[WARN] Failed case indices: {failed_cases}")
+    return all_summaries
diff --git a/test/common/llmperf/utils/__init__.py b/test/common/llmperf/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/common/llmperf/utils/common_metrics.py b/test/common/llmperf/utils/common_metrics.py
new file mode 100644
index 00000000..40e21124
--- /dev/null
+++ b/test/common/llmperf/utils/common_metrics.py
@@ -0,0 +1,17 @@
+# TODO (Avnishn): compute metrics in class
+INTER_TOKEN_LAT = "inter_token_latency_s"
+TTFT = "ttft_s"
+E2E_LAT = "end_to_end_latency_s"
+NUM_INPUT_TOKENS = "number_input_tokens"
+NUM_OUTPUT_TOKENS = "number_output_tokens"
+NUM_TOTAL_TOKENS = "number_total_tokens"
+REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s"
+ERROR_MSG = "error_msg"
+ERROR_CODE = "error_code"
+ERROR_CODE_FREQ = "error_code_frequency"
+NUM_ERRORS = "number_errors"
+OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s"
+NUM_COMPLETED_REQUESTS = "num_completed_requests"
+COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min"
+ERROR_RATE = "error_rate"
+NUM_REQ_STARTED = "num_requests_started"
diff --git a/test/common/llmperf/utils/models.py b/test/common/llmperf/utils/models.py
new file mode 100644
index 00000000..1cbab628
--- /dev/null
+++ b/test/common/llmperf/utils/models.py
@@ -0,0 +1,23 @@
+from typing import Any, Dict, Optional, Tuple
+
+from pydantic import BaseModel
+
+
+class RequestConfig(BaseModel):
+    """The configuration for a request to the LLM API.
+
+    Args:
+        model: The model to use.
+        prompt: The prompt to provide to the LLM API.
+        sampling_params: Additional sampling parameters to send with the request.
+            For more information see the Router app's documentation for the completions endpoint.
+        llm_api: The name of the LLM API to send the request to.
+        metadata: Additional metadata to attach to the request for logging or validation purposes.
+    """
+
+    model: str
+    prompt: Tuple[str, int]
+    sampling_params: Optional[Dict[str, Any]] = None
+    llm_api: Optional[str] = None
+    metadata: Optional[Dict[str, Any]] = None
+    openai_api_base: Optional[str] = ""
diff --git a/test/common/llmperf/utils/openai_chat_completions_client.py b/test/common/llmperf/utils/openai_chat_completions_client.py
new file mode 100644
index 00000000..5023bfa1
--- /dev/null
+++ b/test/common/llmperf/utils/openai_chat_completions_client.py
@@ -0,0 +1,136 @@
+import json
+import os
+import time
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+import requests
+import yaml
+from common.llmperf.utils import common_metrics
+from common.llmperf.utils.models import RequestConfig
+
+config_file = Path(__file__).parent.parent.parent.parent / "config.yaml"
+with open(config_file, "r", encoding="utf-8") as f:
+    config = yaml.safe_load(f)
+stream = config.get("llm_connection", {}).get("stream", True)
+ignore_eos = config.get("llm_connection", {}).get("ignore_eos", True)
+timeout = config.get("llm_connection", {}).get("timeout", 180)
+
+
+class OpenAIChatCompletionsClient:
+    """
+    Client used for sending HTTP requests, receiving token streams, and measuring latency.
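+
+    TTFT is measured at the first streamed content chunk; gaps between
+    subsequent chunks are accumulated into inter_token_latency_s and later
+    normalized to a per-token value by the caller.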
+ """ + + def llm_request( + self, request_config: RequestConfig + ) -> Tuple[Dict[str, Any], str, RequestConfig]: + prompt, prompt_len = request_config.prompt + + message = [ + {"role": "user", "content": prompt}, + ] + model = request_config.model + body = { + "model": model, + "messages": message, + "stream": stream, + "ignore_eos": ignore_eos, + } + sampling_params = request_config.sampling_params + body.update(sampling_params or {}) + + time_to_next_token = [] + tokens_received = 0 + ttft = 0.0 + error_response_code = None + generated_text = "" + error_msg = "" + output_throughput = 0.0 + total_request_time = 0.0 + flag = False + + metrics: Dict[str, Any] = {} + + metrics[common_metrics.ERROR_CODE] = None + metrics[common_metrics.ERROR_MSG] = "" + + start_time = time.monotonic() + most_recent_received_token_time = start_time + + address = request_config.openai_api_base + + if not address: + raise ValueError("the environment variable OPENAI_API_BASE must be set.") + key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg") + if not key: + raise ValueError("the environment variable OPENAI_API_KEY must be set.") + headers = {"Authorization": f"Bearer {key}"} + if not address.endswith("/"): + address = address + "/" + address += "chat/completions" + try: + with requests.post( + address, + json=body, + stream=stream, + timeout=timeout, + headers=headers, + ) as response: + if response.status_code != 200: + error_msg = response.text + error_response_code = response.status_code + response.raise_for_status() + + for chunk in response.iter_lines(chunk_size=None): + if not chunk: + continue + stem = b"data: " + if chunk.startswith(stem): + chunk = chunk[len(stem) :] + # Data might already be bytes or str + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8", errors="ignore") + if chunk.strip() == "[DONE]": + continue + tokens_received += 1 + data = json.loads(chunk) + if "error" in data: + error_msg = data["error"]["message"] + error_response_code = data["error"]["code"] + raise RuntimeError(error_msg) + delta = data["choices"][0]["delta"] + content = delta.get("content", None) or delta.get( + "reasoning_content", "" + ) + if content: + if tokens_received != 0 and flag == False: + ttft = time.monotonic() - start_time + flag = True + else: + time_to_next_token.append( + time.monotonic() - most_recent_received_token_time + ) + most_recent_received_token_time = time.monotonic() + generated_text += content + + total_request_time = time.monotonic() - start_time + if total_request_time > 0: + output_throughput = tokens_received / total_request_time + + except Exception as e: + metrics[common_metrics.ERROR_MSG] = error_msg + metrics[common_metrics.ERROR_CODE] = error_response_code + print(f"Warning Or Error: {e}") + print(error_response_code) + + metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token) + metrics[common_metrics.TTFT] = ttft + metrics[common_metrics.E2E_LAT] = total_request_time + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput + metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len + metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received + metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len + + return metrics, generated_text, request_config diff --git a/test/common/llmperf/utils/sonnet.txt b/test/common/llmperf/utils/sonnet.txt new file mode 100644 index 00000000..9f13ead4 --- /dev/null +++ b/test/common/llmperf/utils/sonnet.txt @@ -0,0 +1,84 @@ +Shall I compare thee to a summer's day? 
+Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Where art thou, Muse, that thou forget'st so long +To speak of that which gives thee all thy might? +Spend'st thou thy fury on some worthless song, +Darkening thy power to lend base subjects light? +Return, forgetful Muse, and straight redeem +In gentle numbers time so idly spent; +Sing to the ear that doth thy lays esteem +And gives thy pen both skill and argument. +Rise, resty Muse, my love's sweet face survey, +If Time have any wrinkle graven there; +If any, be a satire to decay, +And make Time's spoils despised every where. +Give my love fame faster than Time wastes life; +So thou prevent'st his scythe and crooked knife. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +So am I as the rich, whose blessed key +Can bring him to his sweet up-locked treasure, +The which he will not every hour survey, +For blunting the fine point of seldom pleasure. +Therefore are feasts so solemn and so rare, +Since, seldom coming, in the long year set, +Like stones of worth they thinly placed are, +Or captain jewels in the carcanet. +So is the time that keeps you as my chest, +Or as the wardrobe which the robe doth hide, +To make some special instant special blest, +By new unfolding his imprison'd pride. +Blessed are you, whose worthiness gives scope, +Being had, to triumph, being lack'd, to hope. +If there be nothing new, but that which is +Hath been before, how are our brains beguiled, +Which, labouring for invention, bear amiss +The second burden of a former child! +O, that record could with a backward look, +Even of five hundred courses of the sun, +Show me your image in some antique book, +Since mind at first in character was done! 
+That I might see what the old world could say +To this composed wonder of your frame; +Whether we are mended, or whether better they, +Or whether revolution be the same. +O, sure I am, the wits of former days +To subjects worse have given admiring praise. \ No newline at end of file diff --git a/test/common/llmperf/utils/token_benchmark.py b/test/common/llmperf/utils/token_benchmark.py new file mode 100644 index 00000000..67553cf1 --- /dev/null +++ b/test/common/llmperf/utils/token_benchmark.py @@ -0,0 +1,386 @@ +import json +import logging +import random +import re +import time +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig +from common.llmperf.utils.openai_chat_completions_client import ( + OpenAIChatCompletionsClient, +) +from common.llmperf.utils.utils import ( + LLMPerfResults, + randomly_sample_sonnet_lines_prompt, + sample_random_positive_int, +) +from transformers import AutoTokenizer + + +def get_token_throughput_latencies( + model: str, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: Optional[Dict[str, Any]] = None, + concurrent_requests: int = 1, + max_num_completed_requests: int = 500, + test_timeout_s=90, + llm_api="openai", + random_seed: int = None, + openai_api_base: str = "", + tokenizer_path: str = None, +) -> Tuple[Dict[str, Any], List[Dict[str, Any]], float, float]: + """Get the token throughput and latencies for the given model. + + Args: + model: The name of the model to query. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + test_timeout_s: The amount of time to run the test for before reporting results. + llm_api: The name of the llm api to use. Either "openai" or "litellm". + + Returns: + A summary of the performance metrics collected across all completed requests + (e.g. throughput, latencies, etc.) + The individual metrics for each request. + """ + random.seed(random_seed) + + print(f"Using tokenizer:{tokenizer_path}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + get_token_length = lambda text: len(tokenizer.encode(text)) + + if not additional_sampling_params: + additional_sampling_params = {} + + # 1. 
create prompts + prompts: List[Tuple[str, int]] = [] + num_output_tokens_list: List[int] = [] + for i in range(max_num_completed_requests): + num_output = sample_random_positive_int( + mean_output_tokens, stddev_output_tokens + ) + num_output_tokens_list.append(num_output) + prompts.append( + randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean=mean_input_tokens, + prompt_tokens_stddev=stddev_input_tokens, + tokenizer=tokenizer, + ) + ) + start_time = time.monotonic() + completed_requests: List[Dict[str, Any]] = [] + incremental_time_delay = 0.0 + client = OpenAIChatCompletionsClient() + futures = [] + + # 2. Submitting tasks using a thread pool + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + for idx in range(max_num_completed_requests): + sampling = {"max_tokens": num_output_tokens_list[idx]} + sampling.update(additional_sampling_params) + cfg = RequestConfig( + model=model, + prompt=prompts[idx], + sampling_params=sampling, + llm_api=llm_api, + openai_api_base=openai_api_base, + ) + futures.append(executor.submit(client.llm_request, cfg)) + # 3. Waiting for completion or timeout + for future in as_completed(futures, timeout=test_timeout_s): + try: + metrics, gen_text, req_cfg = future.result() + except Exception as e: + logging.warning(f"[WARN] Future raised exception: {e}") + continue + num_output_tokens = get_token_length(gen_text) + if num_output_tokens: + metrics[common_metrics.INTER_TOKEN_LAT] /= ( + (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + if (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + else 1 + ) + metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens + metrics[common_metrics.NUM_TOTAL_TOKENS] = ( + metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + ) + try: + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = ( + num_output_tokens / metrics[common_metrics.E2E_LAT] + ) + except ZeroDivisionError: + logging.error("Division by zero in throughput calculation.") + + completed_requests.append(metrics) + + incremental_time_delay += metrics.get( + common_metrics.INTER_TOKEN_LAT, 0.0 + ) + + end_time = time.monotonic() + + print(f"Results for token benchmark for {model} queried with the {llm_api} api.\n") + if mean_output_tokens == 2: + print(f"[INFO] First token sending pre-embedding completed\n") + return {}, [], 0.0, 0.0 + + ret = metrics_summary(completed_requests, start_time, end_time) + + metadata = { + "model": model, + "mean_input_tokens": mean_input_tokens, + "stddev_input_tokens": stddev_input_tokens, + "mean_output_tokens": mean_output_tokens, + "stddev_output_tokens": stddev_output_tokens, + "concurrent_requests": concurrent_requests, + "additional_sampling_params": additional_sampling_params, + } + + metadata["results"] = ret + elapsed_time = end_time - start_time + return metadata, completed_requests, elapsed_time, incremental_time_delay + + +def compute_throughput( + summary: Dict[str, Any], + completed_requests: List[Dict[str, Any]], + elapsed_time: float, + incremental_time_delay: float, +) -> Tuple[float, float]: + """ + Compute total_throughput (token/s) based on the metrics in summary. + + Formula: (mean_output_tokens * num_completed_requests) / total_e2e_latency_s + + Args: + summary (Dict[str, Any]): A dictionary containing performance metrics. + + Returns: + float: The computed total throughput in tokens per second. Returns 0.0 if latency is zero. 
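+
+    Note: the function returns a (total_throughput, incremental_throughput)
+    tuple; both figures use mean_output_tokens rather than the actual
+    per-request token counts, so they are estimates.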
+ """ + mean_output_tokens = summary.get("mean_output_tokens", 0) + + total_throughput = ( + (mean_output_tokens * len(completed_requests)) / elapsed_time + if elapsed_time > 0 + else 0.0 + ) + incremental_throughput = ( + (mean_output_tokens * len(completed_requests)) / incremental_time_delay + if incremental_time_delay > 0 + else 0.0 + ) + return round(total_throughput, 4), round(incremental_throughput, 4) + + +def metrics_summary( + metrics: List[Dict[str, Any]], start_time: int, end_time: int +) -> Dict[str, Any]: + """Generate a summary over metrics generated from potentially multiple instances of this client. + + Args: + metrics: The metrics to summarize. + start_time: The time the test started. + end_time: The time the test ended. + + Returns: + A summary with the following information: + - Overall throughput (generated tokens / total test time) + - Number of completed requests + - Error rate + - Error code frequency + - Quantiles (p25-p99) for the following metrics: + - Inter token latency + - Time to first token + - User total request time + - Number of tokens processed per request + - Number of tokens generated per request + - User throughput (tokens / s) + """ + ret = {} + + def flatten(item): + for sub_item in item: + if isinstance(sub_item, Iterable) and not isinstance(sub_item, str): + yield from flatten(sub_item) + else: + yield sub_item + + df = pd.DataFrame(metrics) + df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()] + + for key in [ + common_metrics.INTER_TOKEN_LAT, + common_metrics.TTFT, + common_metrics.E2E_LAT, + common_metrics.REQ_OUTPUT_THROUGHPUT, + common_metrics.NUM_INPUT_TOKENS, + common_metrics.NUM_OUTPUT_TOKENS, + ]: + print(key) + ret[key] = {} + series = pd.Series(list(flatten(df_without_errored_req[key]))).dropna() + series = series[series > 0] # Calculate non-zero values + quantiles = series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_dict() + quantiles_reformatted_keys = {} + for quantile, value in quantiles.items(): + reformatted_key = f"p{int(quantile * 100)}" + print(f" {reformatted_key} = {value}") + quantiles_reformatted_keys[reformatted_key] = value + ret[key]["quantiles"] = quantiles_reformatted_keys + mean = series.mean() + print(f" mean = {mean}") + ret[key]["mean"] = mean + print(f" min = {series.min()}") + ret[key]["min"] = series.min() + print(f" max = {series.max()}") + ret[key]["max"] = series.max() + print(f" stddev = {series.std()}") + ret[key]["stddev"] = series.std() + + ret[common_metrics.NUM_REQ_STARTED] = len(metrics) + + error_codes = df[common_metrics.ERROR_CODE].dropna() + num_errors = len(error_codes) + ret[common_metrics.ERROR_RATE] = num_errors / len(metrics) if len(metrics) else 0 + ret[common_metrics.NUM_ERRORS] = num_errors + print(f"Number Of Errored Requests: {num_errors}") + error_code_frequency = dict(error_codes.value_counts()) + if num_errors: + error_code_frequency = dict(error_codes.value_counts()) + print("Error Code Frequency") + print(error_code_frequency) + ret[common_metrics.ERROR_CODE_FREQ] = str(error_code_frequency) + + overall_output_throughput = df_without_errored_req[ + common_metrics.NUM_OUTPUT_TOKENS + ].sum() / (end_time - start_time) + + print(f"Overall Output Throughput: {overall_output_throughput}") + ret[common_metrics.OUTPUT_THROUGHPUT] = overall_output_throughput + + num_completed_requests = len(df_without_errored_req) + num_completed_requests_per_min = ( + num_completed_requests / (end_time - start_time) * 60 + ) + print(f"Number Of Completed Requests: {num_completed_requests}") 
+ print(f"Completed Requests Per Minute: {num_completed_requests_per_min}") + + ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests + ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min + + return ret + + +def run_token_benchmark( + llm_api: str, + model: str, + test_timeout_s: int, + max_num_completed_requests: int, + concurrent_requests: int, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: str, + results_dir: str, + random_seed: int, + openai_api_base: str, + tokenizer_path: str, + user_metadata: Dict[str, Any], +): + """ + Args: + llm_api: The name of the llm api to use. + model: The name of the model to query. + max_num_completed_requests: The number of requests to complete before finishing the test. + test_timeout_s: The amount of time to run the test for before reporting results. + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions. + results_dir: The directory to save the results to. + user_metadata: Additional metadata to include in the results. + """ + if mean_input_tokens < 40: + print( + "the minimum number of input tokens that will be sent is 41" + " because of the prompting logic right now" + ) + + summary, completed_requests, elapsed_time, incremental_time_delay = ( + get_token_throughput_latencies( + model=model, + llm_api=llm_api, + test_timeout_s=test_timeout_s, + max_num_completed_requests=max_num_completed_requests, + mean_input_tokens=mean_input_tokens, + stddev_input_tokens=stddev_input_tokens, + mean_output_tokens=mean_output_tokens, + stddev_output_tokens=stddev_output_tokens, + concurrent_requests=concurrent_requests, + additional_sampling_params=json.loads(additional_sampling_params), + random_seed=random_seed, + openai_api_base=openai_api_base, + tokenizer_path=tokenizer_path, + ) + ) + if mean_output_tokens == 2: + return summary, completed_requests, elapsed_time, incremental_time_delay + + timestamp = int(time.time() * 1000) + if results_dir: + filename = f"{model}_{mean_input_tokens}_{mean_output_tokens}_{timestamp}" + filename = re.sub(r"[^\w\d-]+", "-", filename) + filename = re.sub(r"-{2,}", "-", filename) + summary_filename = f"{filename}_summary" + + # Update to metadata. 
+    summary.update(user_metadata)
+    total_tp, req_tp = compute_throughput(
+        summary, completed_requests, elapsed_time, incremental_time_delay
+    )
+    summary["num_completed_requests"] = len(completed_requests)
+    summary["elapsed_time"] = elapsed_time
+    summary["incremental_time_delay"] = incremental_time_delay
+    summary["total_throughput"] = total_tp
+    summary["incremental_throughput"] = req_tp
+
+    results = LLMPerfResults(name=summary_filename, metadata=summary)
+    results_dir = Path(results_dir)
+    if not results_dir.exists():
+        results_dir.mkdir(parents=True)
+    elif not results_dir.is_dir():
+        raise ValueError(f"{results_dir} is not a directory")
+
+    llmperf_dir = results_dir / "llmperf"
+    if not llmperf_dir.exists():
+        llmperf_dir.mkdir(parents=True)
+    elif not llmperf_dir.is_dir():
+        raise ValueError(f"{llmperf_dir} is not a directory")
+
+    try:
+        with open(llmperf_dir / f"{summary_filename}.json", "w") as f:
+            json.dump(results.to_dict(), f, indent=4, default=str)
+    except Exception as e:
+        print(results.to_dict())
+        raise e
+    return summary
diff --git a/test/common/llmperf/utils/utils.py b/test/common/llmperf/utils/utils.py
new file mode 100644
index 00000000..e2c27087
--- /dev/null
+++ b/test/common/llmperf/utils/utils.py
@@ -0,0 +1,171 @@
+import hashlib
+import json
+import math
+import os
+import pathlib
+import random
+import subprocess
+import time
+from typing import Any, Dict, Tuple
+
+from transformers import LlamaTokenizerFast
+
+RESULTS_VERSION = "2025-10-30"
+
+
+class LLMPerfResults:
+    def __init__(
+        self,
+        name: str,
+        metadata: Dict[str, Any] = None,
+    ):
+        self.name = name
+        self.metadata = metadata or {}
+        self.timestamp = int(time.time())
+        self.metadata["timestamp"] = self.timestamp
+        self.version = RESULTS_VERSION
+
+    def to_dict(self):
+        data = {
+            "version": self.version,
+            "name": self.name,
+        }
+        data.update(self.metadata)
+        data = flatten_dict(data)
+        return data
+
+    def json(self):
+        data = self.to_dict()
+        return json.dumps(data)
+
+
+def upload_to_s3(results_path: str, s3_path: str) -> None:
+    """Upload the results to s3.
+
+    Args:
+        results_path: The path to the results file.
+        s3_path: The s3 path to upload the results to.
+
+    """
+
+    command = ["aws", "s3", "sync", results_path, f"{s3_path}/"]
+    # capture output so that result.stderr is populated if the sync fails
+    result = subprocess.run(command, capture_output=True, text=True)
+    if result.returncode == 0:
+        print("Files uploaded successfully!")
+    else:
+        print("An error occurred:")
+        print(result.stderr)
+
+
+def randomly_sample_sonnet_lines_prompt(
+    prompt_tokens_mean: int = 550,
+    prompt_tokens_stddev: int = 250,
+    tokenizer: LlamaTokenizerFast = None,
+) -> Tuple[str, int]:
+    """Generate a prompt that randomly samples lines from the Shakespeare sonnets in sonnet.txt.
+
+    Args:
+        prompt_tokens_mean: The mean token length of the prompt to generate.
+        prompt_tokens_stddev: The standard deviation of the token length of the
+            prompt to generate.
+        tokenizer: The tokenizer used to count tokens while building the prompt.
+
+    Note:
+        tokens will be counted from the sonnet using the Llama tokenizer. Using one tokenizer
+        ensures a fairer comparison across different LLMs. For example, if gpt 3.5 tokenizes
+        a prompt in fewer tokens than Llama2, then this will be reflected in the results since
+        they will be fed identical prompts.
+
+    Returns:
+        A tuple of the prompt and the length of the prompt.
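+
+        The returned length is the sampled token budget (num_prompt_tokens); the
+        final string may tokenize to slightly fewer tokens because the last line
+        can be truncated mid-word.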
+ """ + get_token_length = lambda text: len(tokenizer.encode(text)) + + prompt = ( + "Randomly stream lines from the following text " + "Don't generate eos tokens:\n\n" + ) + # get a prompt length that is at least as long as the base + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + while num_prompt_tokens < get_token_length(prompt): + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + remaining_prompt_tokens = num_prompt_tokens - get_token_length(prompt) + sonnet_path = pathlib.Path(__file__).parent.resolve() / "sonnet.txt" + with open(sonnet_path, "r") as f: + sonnet_lines = f.readlines() + random.shuffle(sonnet_lines) + sampling_lines = True + while sampling_lines: + for line in sonnet_lines: + line_to_add = line + if remaining_prompt_tokens - get_token_length(line_to_add) < 0: + # This will cut off a line in the middle of a word, but that's ok since an + # llm should be able to handle that. + line_to_add = line_to_add[: int(math.ceil(remaining_prompt_tokens))] + sampling_lines = False + prompt += line_to_add + break + prompt += line_to_add + remaining_prompt_tokens -= get_token_length(line_to_add) + print(hashlib.sha256(prompt.encode("utf-8")).hexdigest()) + return (prompt, num_prompt_tokens) + + +def sample_random_positive_int(mean: int, stddev: int) -> int: + """Sample random numbers from a gaussian distribution until a positive number is sampled. + + Args: + mean: The mean of the gaussian distribution to sample from. + stddev: The standard deviation of the gaussian distribution to sample from. + + Returns: + A random positive integer sampled from the gaussian distribution. + """ + ret = -1 + while ret <= 0: + ret = int(random.gauss(mean, stddev)) + return ret + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def reset_prefill_cache(env, server_url): + """ + prefix cache / HBM + Param: + env + server_url + """ + reset_url = f"{server_url}/reset_prefix_cache" + print(f"[INFO] Resetting prefix cache: {reset_url}") + try: + result = subprocess.run( + ["curl", "-X", "POST", reset_url, "-s", "-f"], + env=env, + check=False, + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + print("[INFO] Prefix cache successfully reset") + else: + print( + f"[ERROR] Unsuccessfully reset prefix cache,error code: {result.returncode}" + ) + except Exception as e: + print(f"[ERROR] Exception in resetting prefix cache: {e}") diff --git a/test/config.yaml b/test/config.yaml index 88d00a61..7ac32f48 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -15,4 +15,13 @@ database: name: "ucm_pytest" user: "root" password: "123456" - charset: "utf8mb4" \ No newline at end of file + charset: "utf8mb4" + +# LLM Connection Configuration +llm_connection: + model: "qwen3" + server_url: "http://141.111.32.70:9382" + tokenizer_path: "/home/models/QwQ-32B" + stream: true # stream output + ignore_eos: true # Ignore the returned terminator + timeout: 180 # request time out \ No newline at end of file diff --git a/test/suites/E2E/test_demo_performance.py b/test/suites/E2E/test_demo_performance.py deleted file mode 100644 index 1b76818f..00000000 --- a/test/suites/E2E/test_demo_performance.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -from 
common.config_utils import config_utils as config_instance - - -# ---------------- Fixture Example ---------------- -class Calculator: - def __init__(self): - print("[Calculator Initialization]") - pass - - def add(self, a, b): - return a + b - - def divide(self, a, b): - if b == 0: - raise ZeroDivisionError("Cannot divide by zero") - return a / b - - -@pytest.fixture(scope="module", name="calc") -def calculator(): - return Calculator() - - -@pytest.mark.feature("mark") -class TestCalculator: - # The calc instance will only be initialized on the first call, see the pytest documentation for more usage - def test_add(self, calc): - assert calc.add(1, 2) == 3 - - def test_divide(self, calc): - assert calc.divide(6, 2) == 3 - - def test_divide_by_zero(self, calc): - with pytest.raises(ZeroDivisionError): - calc.divide(6, 0) - - -# ---------------- Write to DB Example ---------------- -from common.capture_utils import * - - -@pytest.mark.feature("capture") # pytest must be the top -@export_vars -def test_capture_mix(): - """Mixed single + lists via '_name' + '_data'""" - assert 1 == 1 - return { - "_name": "demo", - "_data": { - "length": 10086, # single value - "accuracy": [0.1, 0.2, 0.3], # list - "loss": [0.1, 0.2, 0.3], # list - }, - } - - -# ---------------- Read Config Example ---------------- -from common.config_utils import config_utils as config_instance - - -@pytest.mark.feature("config") -def test_config(): - assert ( - config_instance.get_nested_config("database.host", "localhost") == "127.0.0.1" - ) diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py new file mode 100644 index 00000000..dbec0318 --- /dev/null +++ b/test/suites/E2E/test_uc_performance.py @@ -0,0 +1,158 @@ +import pytest +from common.capture_utils import export_vars +from common.llmperf.run_inference import inference_results + + +@pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]]) +@pytest.mark.parametrize("mean_output_tokens", [[200, 500]]) +@pytest.mark.parametrize("max_num_completed_requests", [[8, 4]]) +@pytest.mark.parametrize("concurrent_requests", [[8, 4]]) +@pytest.mark.parametrize("additional_sampling_params", [["{}", "{}"]]) +@pytest.mark.parametrize("hit_rate", [[0, 50]]) +@pytest.mark.feature("uc_performance_test") +@export_vars +def test_performance( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, +): + all_summaries = inference_results( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, + ) + failed_cases = [] + + value_lists = { + "mean_input_tokens": [], + "mean_output_tokens": [], + "results_inter_token_latency_s_quantiles_p50": [], + "results_inter_token_latency_s_quantiles_p90": [], + "results_inter_token_latency_s_quantiles_p99": [], + "results_inter_token_latency_s_mean": [], + "results_ttft_s_quantiles_p50": [], + "results_ttft_s_quantiles_p90": [], + "results_ttft_s_quantiles_p99": [], + "results_ttft_s_mean": [], + "results_end_to_end_latency_s_quantiles_p50": [], + "results_end_to_end_latency_s_quantiles_p90": [], + "results_end_to_end_latency_s_quantiles_p99": [], + "results_end_to_end_latency_s_mean": [], + "num_completed_requests": [], + "elapsed_time": [], + "incremental_time_delay": [], + "total_throughput": [], + "incremental_throughput": [], + } + + for i, summary in enumerate(all_summaries): + mean_input_tokens = summary["mean_input_tokens"] + 
mean_output_tokens = summary["mean_output_tokens"] + + results_inter_token_latency_s_quantiles_p50 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p50"] + results_inter_token_latency_s_quantiles_p90 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p90"] + results_inter_token_latency_s_quantiles_p99 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p99"] + results_inter_token_latency_s_mean = summary["results"][ + "inter_token_latency_s" + ]["mean"] + + results_ttft_s_quantiles_p50 = summary["results"]["ttft_s"]["quantiles"]["p50"] + results_ttft_s_quantiles_p90 = summary["results"]["ttft_s"]["quantiles"]["p90"] + results_ttft_s_quantiles_p99 = summary["results"]["ttft_s"]["quantiles"]["p99"] + results_ttft_s_mean = summary["results"]["ttft_s"]["mean"] + + results_end_to_end_latency_s_quantiles_p50 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p50"] + results_end_to_end_latency_s_quantiles_p90 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p90"] + results_end_to_end_latency_s_quantiles_p99 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p99"] + results_end_to_end_latency_s_mean = summary["results"]["end_to_end_latency_s"][ + "mean" + ] + + num_completed_requests = summary["num_completed_requests"] + elapsed_time = summary["elapsed_time"] + incremental_time_delay = summary["incremental_time_delay"] + total_throughput = summary["total_throughput"] + incremental_throughput = summary["incremental_throughput"] + + values = [ + mean_input_tokens, + mean_output_tokens, + results_inter_token_latency_s_quantiles_p50, + results_inter_token_latency_s_quantiles_p90, + results_inter_token_latency_s_quantiles_p99, + results_inter_token_latency_s_mean, + results_ttft_s_quantiles_p50, + results_ttft_s_quantiles_p90, + results_ttft_s_quantiles_p99, + results_ttft_s_mean, + results_end_to_end_latency_s_quantiles_p50, + results_end_to_end_latency_s_quantiles_p90, + results_end_to_end_latency_s_quantiles_p99, + results_end_to_end_latency_s_mean, + num_completed_requests, + elapsed_time, + incremental_time_delay, + total_throughput, + incremental_throughput, + ] + + for var_name, val in zip( + [ + "mean_input_tokens", + "mean_output_tokens", + "results_inter_token_latency_s_quantiles_p50", + "results_inter_token_latency_s_quantiles_p90", + "results_inter_token_latency_s_quantiles_p99", + "results_inter_token_latency_s_mean", + "results_ttft_s_quantiles_p50", + "results_ttft_s_quantiles_p90", + "results_ttft_s_quantiles_p99", + "results_ttft_s_mean", + "results_end_to_end_latency_s_quantiles_p50", + "results_end_to_end_latency_s_quantiles_p90", + "results_end_to_end_latency_s_quantiles_p99", + "results_end_to_end_latency_s_mean", + "num_completed_requests", + "elapsed_time", + "incremental_time_delay", + "total_throughput", + "incremental_throughput", + ], + values, + ): + value_lists[var_name].append(val) + if val is None: + failed_cases.append((i, var_name, "missing")) + + try: + assert val > 0, f"value <= 0" + except AssertionError as e: + failed_cases.append((i, var_name, str(e))) + + # Output final result + if failed_cases: + print(f"\n[WARNING] Assertion failed: {len(failed_cases)} abnormal cases found") + for i, key, reason in failed_cases: + print(f" Iteration={i + 1}, key='{key}' -> {reason}") + else: + print("\n[INFO] All values are greater than 0. 
Assertion passed!") + + return {"_name": "llmperf", "_data": value_lists} diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py index 738670ee..2b63838b 100644 --- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py +++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_patch.py @@ -79,6 +79,13 @@ def maybe_execute_sparse_attention_finished( ): if not has_ucm_sparse(): return + ucm_sparse = get_ucm_sparse() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + ucm_sparse.attention_finished( + query, key, value, attn_output, layer_name, forward_context + ) attention_v1.maybe_execute_sparse_attention_finished = ( maybe_execute_sparse_attention_finished diff --git a/ucm/shared/CMakeLists.txt b/ucm/shared/CMakeLists.txt index f44b8522..1f73d1e8 100644 --- a/ucm/shared/CMakeLists.txt +++ b/ucm/shared/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(vendor) +add_subdirectory(infra) add_subdirectory(trans) add_subdirectory(metrics) add_subdirectory(test) diff --git a/ucm/shared/infra/CMakeLists.txt b/ucm/shared/infra/CMakeLists.txt new file mode 100644 index 00000000..ba4345dc --- /dev/null +++ b/ucm/shared/infra/CMakeLists.txt @@ -0,0 +1,22 @@ +file(GLOB_RECURSE UCMINFRA_STATUS_SOURCE_FILES "status/*.*") +add_library(infra_status OBJECT ${UCMINFRA_STATUS_SOURCE_FILES}) +target_include_directories(infra_status PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(infra_status PUBLIC fmt) + +file(GLOB UCMINFRA_LOGGER_SOURCE_FILES "logger/*.*") +file(GLOB_RECURSE UCMINFRA_LOGGER_DETAIL_SOURCE_FILES "logger/${LOGGER_BACKEND}/*.cc") +add_library(infra_logger OBJECT ${UCMINFRA_LOGGER_SOURCE_FILES} ${UCMINFRA_LOGGER_DETAIL_SOURCE_FILES}) +target_include_directories(infra_logger PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(infra_logger PUBLIC fmt spdlog) + +file(GLOB_RECURSE UCMINFRA_TEMPLATE_SOURCE_FILES "template/*.*") +add_library(infra_template OBJECT ${UCMINFRA_TEMPLATE_SOURCE_FILES}) +target_include_directories(infra_template PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +file(GLOB_RECURSE UCMINFRA_THREAD_SOURCE_FILES "thread/*.*") +add_library(infra_thread OBJECT ${UCMINFRA_THREAD_SOURCE_FILES}) +target_include_directories(infra_thread PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +file(GLOB_RECURSE UCMINFRA_TIME_SOURCE_FILES "time/*.*") +add_library(infra_time OBJECT ${UCMINFRA_TIME_SOURCE_FILES}) +target_include_directories(infra_time PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/infra/logger/flux/flux_logger.cc b/ucm/shared/infra/logger/flux/flux_logger.cc similarity index 100% rename from ucm/store/infra/logger/flux/flux_logger.cc rename to ucm/shared/infra/logger/flux/flux_logger.cc diff --git a/ucm/store/infra/logger/logger.h b/ucm/shared/infra/logger/logger.h similarity index 97% rename from ucm/store/infra/logger/logger.h rename to ucm/shared/infra/logger/logger.h index f27dd23d..516b9e66 100644 --- a/ucm/store/infra/logger/logger.h +++ b/ucm/shared/infra/logger/logger.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */
-#ifndef UNIFIEDCACHE_LOGGER_H
-#define UNIFIEDCACHE_LOGGER_H
+#ifndef UNIFIEDCACHE_INFRA_LOGGER_H
+#define UNIFIEDCACHE_INFRA_LOGGER_H

#include
#include

diff --git a/ucm/store/infra/logger/spdlog/spdlog_logger.cc b/ucm/shared/infra/logger/spdlog/spdlog_logger.cc
similarity index 100%
rename from ucm/store/infra/logger/spdlog/spdlog_logger.cc
rename to ucm/shared/infra/logger/spdlog/spdlog_logger.cc
diff --git a/ucm/shared/trans/status.h b/ucm/shared/infra/status/status.h
similarity index 51%
rename from ucm/shared/trans/status.h
rename to ucm/shared/infra/status/status.h
index cab27179..3711de84 100644
--- a/ucm/shared/trans/status.h
+++ b/ucm/shared/infra/status/status.h
@@ -21,17 +21,34 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_TRANS_STATUS_H
-#define UNIFIEDCACHE_TRANS_STATUS_H
+#ifndef UNIFIEDCACHE_INFRA_STATUS_H
+#define UNIFIEDCACHE_INFRA_STATUS_H

+#include <cstdint>
#include <fmt/format.h>
#include <string>

-namespace UC::Trans {
+namespace UC {
+
+template <int32_t i>
+static inline constexpr int32_t __MakeStatusCode()
+{
+    return -50000 - i;
+}

class Status {
    static constexpr int32_t OK_ = 0;
    static constexpr int32_t ERROR_ = -1;
+    static constexpr int32_t EPARAM_ = __MakeStatusCode<0>();
+    static constexpr int32_t EOOM_ = __MakeStatusCode<1>();
+    static constexpr int32_t EOSERROR_ = __MakeStatusCode<2>();
+    static constexpr int32_t EDUPLICATE_ = __MakeStatusCode<3>();
+    static constexpr int32_t ERETRY_ = __MakeStatusCode<4>();
+    static constexpr int32_t ENOOBJ_ = __MakeStatusCode<5>();
+    static constexpr int32_t ESERIALIZE_ = __MakeStatusCode<6>();
+    static constexpr int32_t EDESERIALIZE_ = __MakeStatusCode<7>();
+    static constexpr int32_t EUNSUPPORTED_ = __MakeStatusCode<8>();
+    static constexpr int32_t ENOSPACE_ = __MakeStatusCode<9>();
    int32_t code_;
    std::string message_;
    explicit Status(int32_t code) : code_(code) {}
@@ -39,7 +56,13 @@ class Status {
public:
    bool operator==(const Status& other) const noexcept { return code_ == other.code_; }
    bool operator!=(const Status& other) const noexcept { return !(*this == other); }
-    std::string ToString() const { return fmt::format("({}) {}", code_, message_); }
+    int32_t Underlying() const { return code_; }
+    std::string ToString() const
+    {
+        auto str = std::to_string(code_);
+        if (message_.empty()) { return str; }
+        return fmt::format("{}, {}", str, message_);
+    }
    constexpr bool Success() const noexcept { return code_ == OK_; }
    constexpr bool Failure() const noexcept { return !Success(); }

@@ -47,8 +70,21 @@ class Status {
    Status(int32_t code, std::string message) : code_{code}, message_{std::move(message)} {}
    static Status OK() { return Status{OK_}; }
    static Status Error(std::string message) { return {ERROR_, std::move(message)}; }
+    static Status Error() { return Status{ERROR_}; }
+    static Status InvalidParam() { return Status{EPARAM_}; }
+    static Status OutOfMemory() { return Status{EOOM_}; }
+    static Status OsApiError() { return Status{EOSERROR_}; }
+    static Status DuplicateKey() { return Status{EDUPLICATE_}; }
+    static Status Retry() { return Status{ERETRY_}; }
+    static Status NotFound() { return Status{ENOOBJ_}; }
+    static Status SerializeFailed() { return Status{ESERIALIZE_}; }
+    static Status DeserializeFailed() { return Status{EDESERIALIZE_}; }
+    static Status Unsupported() { return Status{EUNSUPPORTED_}; }
+    static Status NoSpace() { return Status{ENOSPACE_}; }
};

-} // namespace UC::Trans
+inline std::string format_as(const Status& status) { return status.ToString(); }
+
+} // 
namespace UC

#endif
diff --git a/ucm/store/infra/template/hashset.h b/ucm/shared/infra/template/hashset.h
similarity index 98%
rename from ucm/store/infra/template/hashset.h
rename to ucm/shared/infra/template/hashset.h
index b09692bc..102f69b6 100644
--- a/ucm/store/infra/template/hashset.h
+++ b/ucm/shared/infra/template/hashset.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_HASHSET_H
-#define UNIFIEDCACHE_HASHSET_H
+#ifndef UNIFIEDCACHE_INFRA_HASHSET_H
+#define UNIFIEDCACHE_INFRA_HASHSET_H

#include
#include

diff --git a/ucm/store/infra/template/singleton.h b/ucm/shared/infra/template/singleton.h
similarity index 94%
rename from ucm/store/infra/template/singleton.h
rename to ucm/shared/infra/template/singleton.h
index fda4957b..f667288e 100644
--- a/ucm/store/infra/template/singleton.h
+++ b/ucm/shared/infra/template/singleton.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_SINGLETON_H
-#define UNIFIEDCACHE_SINGLETON_H
+#ifndef UNIFIEDCACHE_INFRA_SINGLETON_H
+#define UNIFIEDCACHE_INFRA_SINGLETON_H

namespace UC {

diff --git a/ucm/store/infra/template/timer.h b/ucm/shared/infra/template/timer.h
similarity index 71%
rename from ucm/store/infra/template/timer.h
rename to ucm/shared/infra/template/timer.h
index 1963faa8..0c9db149 100644
--- a/ucm/store/infra/template/timer.h
+++ b/ucm/shared/infra/template/timer.h
@@ -21,16 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_TIMER_H
-#define UNIFIEDCACHE_TIMER_H
+#ifndef UNIFIEDCACHE_INFRA_TIMER_H
+#define UNIFIEDCACHE_INFRA_TIMER_H

-#include
-#include
-#include
#include
+#include
#include
-#include "logger/logger.h"
-#include "status/status.h"
+#include
+#include

namespace UC {

@@ -38,29 +36,30 @@ template <typename Callable>
class Timer {
public:
    Timer(const std::chrono::seconds& interval, Callable&& callable)
-        : interval_(interval), callable_(callable), running_(false) {}
-    ~Timer() {
+        : interval_(interval), callable_(callable), running_(false)
+    {
+    }
+    ~Timer()
+    {
        {
            std::lock_guard lg(this->mutex_);
            this->running_ = false;
+            this->cv_.notify_one();
        }
-
-        this->cv_.notify_one();
        if (this->thread_.joinable()) { this->thread_.join(); }
    }

-    Status Start()
+    bool Start()
    {
        {
            std::lock_guard lg(this->mutex_);
-            if (this->running_) { return Status::OK(); }
+            if (this->running_) { return true; }
        }
        try {
            this->running_ = true;
            this->thread_ = std::thread(&Timer::Runner, this);
-            return Status::OK();
-        } catch (const std::exception& e) {
-            UC_ERROR("Failed({}) to start timer.", e.what());
-            return Status::OutOfMemory();
+            return true;
+        } catch (...) 
{ + return false; } } @@ -68,14 +67,12 @@ class Timer { void Runner() { while (this->running_) { - try { - { - std::unique_lock lg(this->mutex_); - this->cv_.wait_for(lg, this->interval_, [this] { return !this->running_; }); - if (!this->running_) { break; } - } - this->callable_(); - } catch (const std::exception& e) { UC_ERROR("Failed({}) to run timer.", e.what()); } + { + std::unique_lock lg(this->mutex_); + this->cv_.wait_for(lg, this->interval_, [this] { return !this->running_; }); + if (!this->running_) { break; } + } + this->callable_(); } } @@ -88,6 +85,6 @@ class Timer { std::atomic running_; }; -} // namespace UC +} // namespace UC -#endif \ No newline at end of file +#endif diff --git a/ucm/store/infra/template/topn_heap.h b/ucm/shared/infra/template/topn_heap.h similarity index 97% rename from ucm/store/infra/template/topn_heap.h rename to ucm/shared/infra/template/topn_heap.h index 98884c23..737d0b19 100644 --- a/ucm/store/infra/template/topn_heap.h +++ b/ucm/shared/infra/template/topn_heap.h @@ -22,11 +22,12 @@ * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_TOP_N_HEAP_H -#define UNIFIEDCACHE_TOP_N_HEAP_H +#ifndef UNIFIEDCACHE_INFRA_TOP_N_HEAP_H +#define UNIFIEDCACHE_INFRA_TOP_N_HEAP_H #include #include +#include namespace UC { @@ -117,4 +118,4 @@ class TopNFixedHeap : public TopNHeap { } // namespace UC -#endif \ No newline at end of file +#endif diff --git a/ucm/store/infra/thread/index_pool.h b/ucm/shared/infra/thread/index_pool.h similarity index 97% rename from ucm/store/infra/thread/index_pool.h rename to ucm/shared/infra/thread/index_pool.h index 225ee884..4217b7a0 100644 --- a/ucm/store/infra/thread/index_pool.h +++ b/ucm/shared/infra/thread/index_pool.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_INDEX_POOL_H -#define UNIFIEDCACHE_INDEX_POOL_H +#ifndef UNIFIEDCACHE_INFRA_INDEX_POOL_H +#define UNIFIEDCACHE_INFRA_INDEX_POOL_H #include #include diff --git a/ucm/store/infra/thread/latch.h b/ucm/shared/infra/thread/latch.h similarity index 95% rename from ucm/store/infra/thread/latch.h rename to ucm/shared/infra/thread/latch.h index 1837ca27..fb1dcf58 100644 --- a/ucm/store/infra/thread/latch.h +++ b/ucm/shared/infra/thread/latch.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_LATCH_H -#define UNIFIEDCACHE_LATCH_H +#ifndef UNIFIEDCACHE_INFRA_LATCH_H +#define UNIFIEDCACHE_INFRA_LATCH_H #include #include @@ -66,4 +66,4 @@ class Latch { } // namespace UC -#endif // UNIFIEDCACHE_LATCH_H +#endif // UNIFIEDCACHE_INFRA_LATCH_H diff --git a/ucm/store/infra/thread/thread_pool.h b/ucm/shared/infra/thread/thread_pool.h similarity index 98% rename from ucm/store/infra/thread/thread_pool.h rename to ucm/shared/infra/thread/thread_pool.h index c33a0c28..baa514ed 100644 --- a/ucm/store/infra/thread/thread_pool.h +++ b/ucm/shared/infra/thread/thread_pool.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */
-#ifndef UNIFIEDCACHE_THREAD_POOL_H
-#define UNIFIEDCACHE_THREAD_POOL_H
+#ifndef UNIFIEDCACHE_INFRA_THREAD_POOL_H
+#define UNIFIEDCACHE_INFRA_THREAD_POOL_H

#include
#include

diff --git a/ucm/store/infra/time/stopwatch.h b/ucm/shared/infra/time/stopwatch.h
similarity index 95%
rename from ucm/store/infra/time/stopwatch.h
rename to ucm/shared/infra/time/stopwatch.h
index 2386f394..c2a5bb33 100644
--- a/ucm/store/infra/time/stopwatch.h
+++ b/ucm/shared/infra/time/stopwatch.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
* */
-#ifndef UNIFIEDCACHE_STOPWATCH_H
-#define UNIFIEDCACHE_STOPWATCH_H
+#ifndef UNIFIEDCACHE_INFRA_STOPWATCH_H
+#define UNIFIEDCACHE_INFRA_STOPWATCH_H

#include

diff --git a/ucm/shared/metrics/cc/stats/conn_stats.cc b/ucm/shared/metrics/cc/stats/conn_stats.cc
index a2aafa70..edf18ac2 100644
--- a/ucm/shared/metrics/cc/stats/conn_stats.cc
+++ b/ucm/shared/metrics/cc/stats/conn_stats.cc
@@ -24,18 +24,18 @@
#include "conn_stats.h"

namespace UC::Metrics {
- 
+
ConnStats::ConnStats() = default;

-std::string ConnStats::Name() const {
-    return "ConnStats";
-}
+std::string ConnStats::Name() const { return "ConnStats"; }

-void ConnStats::Reset() {
+void ConnStats::Reset()
+{
    for (auto& v : data_) v.clear();
}

-void ConnStats::Update(const std::unordered_map<std::string, double>& params) {
+void ConnStats::Update(const std::unordered_map<std::string, double>& params)
+{
    for (const auto& [k, v] : params) {
        Key id = KeyFromString(k);
        if (id == Key::COUNT) continue;
@@ -43,7 +43,8 @@ void ConnStats::Update(const std::unordered_map<std::string, double>& params) {
    }
}

-std::unordered_map<std::string, std::vector<double>> ConnStats::Data() {
+std::unordered_map<std::string, std::vector<double>> ConnStats::Data()
+{
    std::unordered_map<std::string, std::vector<double>> result;
    result["save_requests_num"] = data_[static_cast<size_t>(Key::save_requests_num)];
    result["save_blocks_num"] = data_[static_cast<size_t>(Key::save_blocks_num)];
@@ -53,27 +54,30 @@ std::unordered_map<std::string, std::vector<double>> ConnStats::Data() {
    result["load_blocks_num"] = data_[static_cast<size_t>(Key::load_blocks_num)];
    result["load_duration"] = data_[static_cast<size_t>(Key::load_duration)];
    result["load_speed"] = data_[static_cast<size_t>(Key::load_speed)];
-    result["interval_lookup_hit_rates"] = data_[static_cast<size_t>(Key::interval_lookup_hit_rates)];
+    result["interval_lookup_hit_rates"] =
+        data_[static_cast<size_t>(Key::interval_lookup_hit_rates)];
    return result;
}

-Key ConnStats::KeyFromString(const std::string& k) {
-    if (k == "save_requests_num") return Key::save_requests_num;
-    if (k == "save_blocks_num") return Key::save_blocks_num;
-    if (k == "save_duration") return Key::save_duration;
-    if (k == "save_speed") return Key::save_speed;
-    if (k == "load_requests_num") return Key::load_requests_num;
-    if (k == "load_blocks_num") return Key::load_blocks_num;
-    if (k == "load_duration") return Key::load_duration;
-    if (k == "load_speed") return Key::load_speed;
-    if (k == "interval_lookup_hit_rates")return Key::interval_lookup_hit_rates;
+Key ConnStats::KeyFromString(const std::string& k)
+{
+    if (k == "save_requests_num") return Key::save_requests_num;
+    if (k == "save_blocks_num") return Key::save_blocks_num;
+    if (k == "save_duration") return Key::save_duration;
+    if (k == "save_speed") return Key::save_speed;
+    if (k == "load_requests_num") return Key::load_requests_num;
+    if (k == "load_blocks_num") return Key::load_blocks_num;
+    if (k == "load_duration") return Key::load_duration;
+    if (k == "load_speed") return Key::load_speed;
+    if (k == "interval_lookup_hit_rates") return Key::interval_lookup_hit_rates;
    return Key::COUNT;
}

-void ConnStats::EmplaceBack(Key id, double value) 
{
+void ConnStats::EmplaceBack(Key id, double value)
+{
    data_[static_cast<size_t>(id)].push_back(value);
}

static Registrar registrar;
-}
\ No newline at end of file
+} // namespace UC::Metrics
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats/conn_stats.h b/ucm/shared/metrics/cc/stats/conn_stats.h
index 34a76a68..e8cc9455 100644
--- a/ucm/shared/metrics/cc/stats/conn_stats.h
+++ b/ucm/shared/metrics/cc/stats/conn_stats.h
@@ -24,21 +24,21 @@
#ifndef UNIFIEDCACHE_CONNSTATS_H
#define UNIFIEDCACHE_CONNSTATS_H

-#include "istats.h"
-#include "stats_registry.h"
#include
-#include
-#include
-#include
#include
+#include
+#include
+#include
+#include "istats.h"
+#include "stats_registry.h"

-namespace UC::Metrics { 
+namespace UC::Metrics {

enum class Key : uint8_t {
    interval_lookup_hit_rates = 0,
    save_requests_num,
    save_blocks_num,
-    save_duration ,
+    save_duration,
    save_speed,
    load_requests_num,
    load_blocks_num,
@@ -66,13 +66,13 @@ class ConnStats : public IStats {
};

struct Registrar {
-    Registrar() {
-        StatsRegistry::RegisterStats("ConnStats", []()->std::unique_ptr<IStats> {
-            return std::make_unique<ConnStats>();
-        });
+    Registrar()
+    {
+        StatsRegistry::RegisterStats(
+            "ConnStats", []() -> std::unique_ptr<IStats> { return std::make_unique<ConnStats>(); });
    }
};

-}
+} // namespace UC::Metrics

-#endif // UNIFIEDCACHE_CONNSTATS_H
\ No newline at end of file
+#endif // UNIFIEDCACHE_CONNSTATS_H
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats/istats.h b/ucm/shared/metrics/cc/stats/istats.h
index 56a6e8e1..6e8de7b3 100644
--- a/ucm/shared/metrics/cc/stats/istats.h
+++ b/ucm/shared/metrics/cc/stats/istats.h
@@ -24,8 +24,8 @@
#ifndef UNIFIEDCACHE_ISTATS_H
#define UNIFIEDCACHE_ISTATS_H

-#include
#include
+#include
#include
#include

@@ -40,6 +40,6 @@ class IStats {
    virtual std::unordered_map<std::string, std::vector<double>> Data() = 0;
};

-}
+} // namespace UC::Metrics

-#endif
\ No newline at end of file
+#endif
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_monitor.cc b/ucm/shared/metrics/cc/stats_monitor.cc
index 8b83920c..2d3d8026 100644
--- a/ucm/shared/metrics/cc/stats_monitor.cc
+++ b/ucm/shared/metrics/cc/stats_monitor.cc
@@ -21,60 +21,62 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE. 
* */
-#include "stats/istats.h"
-#include "stats_registry.h"
#include "stats_monitor.h"
#include
#include
+#include "stats/istats.h"
+#include "stats_registry.h"

namespace UC::Metrics {

-StatsMonitor::StatsMonitor() {
+StatsMonitor::StatsMonitor()
+{
    auto& registry = StatsRegistry::GetInstance();
    for (const auto& name : registry.GetRegisteredStatsNames()) {
        stats_map_[name] = registry.CreateStats(name);
    }
}

-void StatsMonitor::CreateStats(const std::string& name) {
+void StatsMonitor::CreateStats(const std::string& name)
+{
    std::lock_guard lock(mutex_);
    auto& registry = StatsRegistry::GetInstance();
    stats_map_[name] = registry.CreateStats(name);
}

-std::unordered_map<std::string, std::vector<double>> StatsMonitor::GetStats(const std::string& name) {
+std::unordered_map<std::string, std::vector<double>> StatsMonitor::GetStats(const std::string& name)
+{
    std::lock_guard lock(mutex_);
    return stats_map_[name]->Data();
}

-void StatsMonitor::ResetStats(const std::string& name) {
+void StatsMonitor::ResetStats(const std::string& name)
+{
    std::lock_guard lock(mutex_);
    stats_map_[name]->Reset();
}

-std::unordered_map<std::string, std::vector<double>> StatsMonitor::GetStatsAndClear(const std::string& name) {
+std::unordered_map<std::string, std::vector<double>>
+StatsMonitor::GetStatsAndClear(const std::string& name)
+{
    std::lock_guard lock(mutex_);
    auto result = stats_map_[name]->Data();
    stats_map_[name]->Reset();
    return result;
}

-void StatsMonitor::UpdateStats(
-    const std::string& name,
-    const std::unordered_map<std::string, double>& params)
+void StatsMonitor::UpdateStats(const std::string& name,
+                               const std::unordered_map<std::string, double>& params)
{
    std::lock_guard lock(mutex_);
    auto it = stats_map_.find(name);
-    if (it != stats_map_.end()) {
-        it->second->Update(params);
-    }
+    if (it != stats_map_.end()) { it->second->Update(params); }
}

-void StatsMonitor::ResetAllStats() {
+void StatsMonitor::ResetAllStats()
+{
    std::lock_guard lock(mutex_);
-    for (auto& [n, ptr] : stats_map_) {
-        ptr->Reset();
-    }
+    for (auto& [n, ptr] : stats_map_) { ptr->Reset(); }
}

-}
\ No newline at end of file
+} // namespace UC::Metrics
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_monitor.h b/ucm/shared/metrics/cc/stats_monitor.h
index b8bd688a..1545d4b5 100644
--- a/ucm/shared/metrics/cc/stats_monitor.h
+++ b/ucm/shared/metrics/cc/stats_monitor.h
@@ -24,18 +24,19 @@
#ifndef UNIFIEDCACHE_MONITOR_H
#define UNIFIEDCACHE_MONITOR_H

-#include "stats/istats.h"
-#include
#include
+#include
#include
+#include
#include
+#include "stats/istats.h"

namespace UC::Metrics {

class StatsMonitor {
public:
-
-    static StatsMonitor& GetInstance() {
+    static StatsMonitor& GetInstance()
+    {
        static StatsMonitor inst;
        return inst;
    }
@@ -44,16 +45,14 @@ class StatsMonitor {

    void CreateStats(const std::string& name);

-    std::unordered_map<std::string, std::vector<double>>
-    GetStats(const std::string& name);
-
+    std::unordered_map<std::string, std::vector<double>> GetStats(const std::string& name);
+
    void ResetStats(const std::string& name);

-    std::unordered_map<std::string, std::vector<double>>
-    GetStatsAndClear(const std::string& name);
+    std::unordered_map<std::string, std::vector<double>> GetStatsAndClear(const std::string& name);

    void UpdateStats(const std::string& name,
-        const std::unordered_map<std::string, double>& params);
+                     const std::unordered_map<std::string, double>& params);

    void ResetAllStats();

@@ -66,6 +65,6 @@ class StatsMonitor {
    StatsMonitor& operator=(const StatsMonitor&) = delete;
};

-}
+} // namespace UC::Metrics

-#endif // UNIFIEDCACHE_MONITOR_H
\ No newline at end of file
+#endif // UNIFIEDCACHE_MONITOR_H
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_registry.cc b/ucm/shared/metrics/cc/stats_registry.cc
index 4a64516b..c2551d9a 100644
--- a/ucm/shared/metrics/cc/stats_registry.cc
+++ 
b/ucm/shared/metrics/cc/stats_registry.cc
@@ -25,26 +25,29 @@

namespace UC::Metrics {

-StatsRegistry& StatsRegistry::GetInstance() {
-    static StatsRegistry inst;
+StatsRegistry& StatsRegistry::GetInstance()
+{
+    static StatsRegistry inst;
    return inst;
}

-void StatsRegistry::RegisterStats(std::string name, Creator creator) {
-    auto& reg = GetInstance();
+void StatsRegistry::RegisterStats(std::string name, Creator creator)
+{
+    auto& reg = GetInstance();
    std::lock_guard lk(reg.mutex_);
    reg.registry_[name] = creator;
}

-std::unique_ptr<IStats> StatsRegistry::CreateStats(const std::string& name) {
+std::unique_ptr<IStats> StatsRegistry::CreateStats(const std::string& name)
+{
    auto& reg = GetInstance();
    std::lock_guard lk(reg.mutex_);
-    if (auto it = reg.registry_.find(name); it != reg.registry_.end())
-        return it->second();
+    if (auto it = reg.registry_.find(name); it != reg.registry_.end()) return it->second();
    return nullptr;
}

-std::vector<std::string> StatsRegistry::GetRegisteredStatsNames() {
+std::vector<std::string> StatsRegistry::GetRegisteredStatsNames()
+{
    auto& reg = GetInstance();
    std::lock_guard lk(reg.mutex_);
    std::vector<std::string> names;
@@ -53,4 +56,4 @@ std::vector<std::string> StatsRegistry::GetRegisteredStatsNames() {
    return names;
}

-} // namespace UC::Metrics
\ No newline at end of file
+} // namespace UC::Metrics
\ No newline at end of file
diff --git a/ucm/shared/metrics/cc/stats_registry.h b/ucm/shared/metrics/cc/stats_registry.h
index f7bb32fb..c22b6617 100644
--- a/ucm/shared/metrics/cc/stats_registry.h
+++ b/ucm/shared/metrics/cc/stats_registry.h
@@ -24,14 +24,14 @@
#ifndef UNIFIEDCACHE_REGISTRY_H
#define UNIFIEDCACHE_REGISTRY_H

-#include "stats/istats.h"
-#include
#include
#include
+#include
+#include "stats/istats.h"

namespace UC::Metrics {

-using Creator = std::unique_ptr<IStats>(*)();
+using Creator = std::unique_ptr<IStats> (*)();

class StatsRegistry {
public:
@@ -53,6 +53,6 @@ class StatsRegistry {
    std::unordered_map<std::string, Creator> registry_;
};

-}
+} // namespace UC::Metrics

-#endif // UNIFIEDCACHE_REGISTRY_H
\ No newline at end of file
+#endif // UNIFIEDCACHE_REGISTRY_H
\ No newline at end of file
diff --git a/ucm/shared/metrics/cpy/metrics.py.cc b/ucm/shared/metrics/cpy/metrics.py.cc
index fd0a3aea..10bfc2f9 100644
--- a/ucm/shared/metrics/cpy/metrics.py.cc
+++ b/ucm/shared/metrics/cpy/metrics.py.cc
@@ -28,19 +28,20 @@
namespace py = pybind11;

namespace UC::Metrics {

-void bind_monitor(py::module_& m) {
+void bind_monitor(py::module_& m)
+{
    py::class_<StatsMonitor>(m, "StatsMonitor")
-        .def_static("get_instance", &StatsMonitor::GetInstance,
-            py::return_value_policy::reference)
+        .def_static("get_instance", &StatsMonitor::GetInstance, py::return_value_policy::reference)
        .def("update_stats", &StatsMonitor::UpdateStats)
        .def("reset_all", &StatsMonitor::ResetAllStats)
        .def("get_stats", &StatsMonitor::GetStats)
        .def("get_stats_and_clear", &StatsMonitor::GetStatsAndClear);
}

-} // namespace UC
+} // namespace UC::Metrics

-PYBIND11_MODULE(ucmmonitor, module) {
+PYBIND11_MODULE(ucmmonitor, module)
+{
    module.attr("project") = UCM_PROJECT_NAME;
    module.attr("version") = UCM_PROJECT_VERSION;
    module.attr("commit_id") = UCM_COMMIT_ID;
diff --git a/ucm/store/test/case/infra/hashset_test.cc b/ucm/shared/test/case/infra/hashset_test.cc
similarity index 100%
rename from ucm/store/test/case/infra/hashset_test.cc
rename to ucm/shared/test/case/infra/hashset_test.cc
diff --git a/ucm/shared/test/case/trans/trans_test.cc b/ucm/shared/test/case/trans/trans_test.cc
index f38769ce..4f2415b1 100644
--- a/ucm/shared/test/case/trans/trans_test.cc
+++ b/ucm/shared/test/case/trans/trans_test.cc 
@@ -28,7 +28,7 @@ class UCTransUnitTest : public ::testing::Test {}; TEST_F(UCTransUnitTest, CopyDataWithCE) { - const auto ok = UC::Trans::Status::OK(); + const auto ok = UC::Status::OK(); constexpr int32_t deviceId = 0; constexpr size_t size = 36 * 1024; constexpr size_t number = 64 * 61; @@ -60,7 +60,7 @@ TEST_F(UCTransUnitTest, CopyDataWithCE) TEST_F(UCTransUnitTest, CopyDataWithSM) { - const auto ok = UC::Trans::Status::OK(); + const auto ok = UC::Status::OK(); constexpr int32_t deviceId = 0; constexpr size_t size = 36 * 1024; constexpr size_t number = 64 * 61; diff --git a/ucm/shared/trans/CMakeLists.txt b/ucm/shared/trans/CMakeLists.txt index bbf001fc..57a1bd0a 100644 --- a/ucm/shared/trans/CMakeLists.txt +++ b/ucm/shared/trans/CMakeLists.txt @@ -8,6 +8,7 @@ if(RUNTIME_ENVIRONMENT STREQUAL "simu") add_subdirectory(simu) endif() target_include_directories(trans PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) +target_link_libraries(trans PUBLIC infra_status) file(GLOB_RECURSE UCMTRANS_CPY_SOURCE_FILES "./cpy/*.cc") pybind11_add_module(ucmtrans ${UCMTRANS_CPY_SOURCE_FILES}) diff --git a/ucm/shared/trans/buffer.h b/ucm/shared/trans/buffer.h index 918b067f..a7375251 100644 --- a/ucm/shared/trans/buffer.h +++ b/ucm/shared/trans/buffer.h @@ -25,7 +25,7 @@ #define UNIFIEDCACHE_TRANS_BUFFER_H #include -#include "status.h" +#include "status/status.h" namespace UC::Trans { diff --git a/ucm/shared/trans/stream.h b/ucm/shared/trans/stream.h index 3cb0c368..42561796 100644 --- a/ucm/shared/trans/stream.h +++ b/ucm/shared/trans/stream.h @@ -25,7 +25,7 @@ #define UNIFIEDCACHE_TRANS_STREAM_H #include -#include "status.h" +#include "status/status.h" namespace UC::Trans { diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index c7047f87..d8316cc6 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -307,7 +307,7 @@ def maybe_register_static_data(self, forward_context: ForwardContext): self.init_static_flag = True def wait_transfer_task_done(self): - assert len(self.tasks) > 0 + # assert len(self.tasks) > 0 for task_hash, task in self.tasks.items(): # TODO: handle exceptions ret = self.store_instance.wait(task) @@ -352,9 +352,10 @@ def wait_retrieval_and_start_load(self): self.pre_topk_block_hashes, diff_blocks = diff_two_map( self.pre_topk_block_hashes, target_map ) - self.launch_transfer_task( - "load", list(diff_blocks.values()), list(diff_blocks.keys()) - ) + if diff_blocks: + self.launch_transfer_task( + "load", list(diff_blocks.values()), list(diff_blocks.keys()) + ) ## 2. 
load all # self.launch_transfer_task( @@ -438,7 +439,8 @@ def attention_begin( self.k_cache[vllm_block_ids[-local_window_sz:]] = self.local_window self.start_retrieval(query, forward_context) self.wait_retrieval_and_start_load() - self.wait_transfer_task_done() + if len(self.tasks) > 0: + self.wait_transfer_task_done() def attention_finished( self, diff --git a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp index dacf3868..17d1e92a 100644 --- a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp +++ b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp @@ -1,16 +1,16 @@ // retrieval_backend.cpp -#include -#include +#include #include #include #include +#include +#include #include +#include #include #include #include -#include -#include #ifdef NUMA_ENABLED #include #endif @@ -20,8 +20,7 @@ namespace py = pybind11; class RetrievalWorkerBackend { public: - RetrievalWorkerBackend(py::array_t data, - py::dict cpu_idx_tbl) + RetrievalWorkerBackend(py::array_t data, py::dict cpu_idx_tbl) : data_array_(data), stop_workers_(false), next_req_id_(0) { py::buffer_info info = data_array_.request(); @@ -40,17 +39,18 @@ class RetrievalWorkerBackend { // 核心绑定代码 cpu_set_t cpuset; CPU_ZERO(&cpuset); - CPU_SET(core_id, &cpuset); // 绑定每个线程到指定的核心 + CPU_SET(core_id, &cpuset); // 绑定每个线程到指定的核心 pthread_t thread = worker_threads_.back().native_handle(); - + // 设置 CPU 亲和性 int rc = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); if (rc != 0) { - std::cerr << "Error binding thread " << i << " to CPU core " << core_id << std::endl; + std::cerr << "Error binding thread " << i << " to CPU core " << core_id + << std::endl; } - #ifdef NUMA_ENABLED +#ifdef NUMA_ENABLED int numaId = cpu_idx.first.cast(); // 设置内存亲和性 unsigned long nodeMask = 1UL << numaId; @@ -58,28 +58,26 @@ class RetrievalWorkerBackend { if (rc != 0) { std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; } - #else - std::cerr << "NUMA support is disabled." 
<< std::endl; - #endif +#endif } - } } - ~RetrievalWorkerBackend() { + ~RetrievalWorkerBackend() + { { std::lock_guard lock(mutex_); stop_workers_ = true; cond_.notify_all(); } - for (auto& t: worker_threads_) t.join(); + for (auto& t : worker_threads_) t.join(); } - int submit(py::array_t query, int topk, py::array_t indexes) { + int submit(py::array_t query, int topk, py::array_t indexes) + { py::buffer_info qinfo = query.request(); py::buffer_info iinfo = indexes.request(); - if (qinfo.shape[1] != dim_) - throw std::runtime_error("Query dim mismatch"); + if (qinfo.shape[1] != dim_) throw std::runtime_error("Query dim mismatch"); if ((size_t)iinfo.shape[0] != (size_t)qinfo.shape[0]) throw std::runtime_error("Query and indexes batch mismatch"); @@ -108,12 +106,14 @@ class RetrievalWorkerBackend { return req_id; } - bool poll(int req_id) { + bool poll(int req_id) + { std::lock_guard lock(mutex_); return results_.find(req_id) != results_.end(); } - void wait(int req_id) { + void wait(int req_id) + { std::shared_ptr s; { std::lock_guard lock(mutex_); @@ -125,7 +125,8 @@ class RetrievalWorkerBackend { s->cv.wait(lk2, [&] { return s->done; }); } - py::dict get_result(int req_id) { + py::dict get_result(int req_id) + { std::lock_guard lock(mutex_); auto it = results_.find(req_id); if (it == results_.end()) throw std::runtime_error("Result not ready"); @@ -167,12 +168,13 @@ class RetrievalWorkerBackend { bool done = false; }; - void worker_loop() { + void worker_loop() + { while (true) { Request req; { std::unique_lock lock(mutex_); - cond_.wait(lock, [&]{ return stop_workers_ || !requests_.empty(); }); + cond_.wait(lock, [&] { return stop_workers_ || !requests_.empty(); }); if (stop_workers_ && requests_.empty()) return; req = std::move(requests_.front()); requests_.pop(); @@ -216,7 +218,7 @@ class RetrievalWorkerBackend { } int curr_topk = std::min((int)heap.size(), req.topk); std::partial_sort(heap.begin(), heap.begin() + curr_topk, heap.end(), - [](const auto& a, const auto& b){ return a.first > b.first; }); + [](const auto& a, const auto& b) { return a.first > b.first; }); for (int k = 0; k < curr_topk; ++k) { res.scores[b].push_back(heap[k].first); @@ -250,7 +252,8 @@ class RetrievalWorkerBackend { std::atomic next_req_id_; }; -PYBIND11_MODULE(retrieval_backend, m) { +PYBIND11_MODULE(retrieval_backend, m) +{ py::class_(m, "RetrievalWorkerBackend") .def(py::init, py::dict>()) .def("submit", &RetrievalWorkerBackend::submit) diff --git a/ucm/sparse/esa/retrieval/retrieval_worker.py b/ucm/sparse/esa/retrieval/retrieval_worker.py index ebed1ed1..7209d604 100644 --- a/ucm/sparse/esa/retrieval/retrieval_worker.py +++ b/ucm/sparse/esa/retrieval/retrieval_worker.py @@ -1,10 +1,11 @@ import time +from collections import defaultdict import numpy as np import torch -# import retrieval_backend from ucm.sparse.esa.retrieval import retrieval_backend +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank class RetrievalWorker: @@ -42,7 +43,19 @@ def wait(self, req_id): data = torch.rand(kv_cache_blocks, dim).to(torch.float32) print("data created", data.shape) - backend = retrieval_backend.RetrievalWorkerBackend(data) + ratio = 0.75 + total_tp_size = 4 + local_tp_rank = 0 + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + total_tp_size, local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + + backend = retrieval_backend.RetrievalWorkerBackend(data, bind_info_dict) 
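+    # bind_info_dict maps numa_id -> [core_id, ...] (e.g. {0: [0, 1, 2]},
+    # illustrative values): the backend spawns one worker thread per listed
+    # core, pins it there via pthread_setaffinity_np, and, when compiled with
+    # NUMA_ENABLED, also binds the thread's allocations to the owning NUMA node.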
worker = RetrievalWorker(backend) topk = 3000 search_blocks_range = 8000 diff --git a/ucm/sparse/kvcomp/hash_encoder.py b/ucm/sparse/kvcomp/hash_encoder.py index db76079d..7546aa71 100644 --- a/ucm/sparse/kvcomp/hash_encoder.py +++ b/ucm/sparse/kvcomp/hash_encoder.py @@ -31,6 +31,124 @@ logger = init_logger(__name__) +if hasattr(torch, "cuda") and torch.cuda.is_available(): + from vllm.triton_utils import tl, triton + + @triton.jit + def triton_hash_code_kernel( + x_ptr, + code_ptr, + pack_w_ptr, + hash_out_ptr, + M, + K, + N, + stride_xm, + stride_xk, + stride_codek, + stride_coden, + stride_pack_w, + stride_om, + stride_on, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # sample dimension + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # hash_rbits dimension + offs_k = tl.arange(0, BLOCK_K) # input_dim dimension + + # Matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), + other=0.0, + ) + code = tl.load( + code_ptr + + offs_k[:, None] * stride_codek + + offs_n[None, :] * stride_coden, + mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), + other=0.0, + ) + acc += tl.dot(x, code) + offs_k += BLOCK_K + + # Binarize and pack + bits = (acc > 0).to(tl.uint8) # Binarize + bits = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8)) # Reshape for packing + + # Load the packing weights (ensure it has the correct shape) + pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_pack_w) + packed = tl.sum(bits * pack_w[None, None, :], axis=-1).to(tl.uint8) + + # Store results + offs_n = pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8) + hash_out_ptrs = ( + hash_out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + ) + tl.store( + hash_out_ptrs, + packed, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < (N // 8)), + ) + + def triton_hash_code(x, code, pack_weight): + input_dim = x.shape[-1] + samples = x.shape[0] + hash_bits = code.shape[-1] + assert (pack_weight.shape[0] == 8) and (hash_bits % 8 == 0) + hash_out = torch.empty( + (samples, hash_bits // 8), dtype=pack_weight.dtype, device=x.device + ) + + grid = lambda opts: ( + triton.cdiv(samples, opts["BLOCK_M"]), + triton.cdiv(input_dim, opts["BLOCK_N"]), + ) + + triton_hash_code_kernel[grid]( + x, + code, + pack_weight, + hash_out, + samples, + input_dim, + hash_bits, + x.stride(0), + x.stride(1), + code.stride(0), + code.stride(1), + pack_weight.stride(0), + hash_out.stride(0), + hash_out.stride(1), + BLOCK_M=32, + BLOCK_K=64, + BLOCK_N=16, + ) + + return hash_out.view(-1) # [samples * hash_numbers] + + +@torch.compile() +def torch_hash_code(x, code, pack_weight): + # [N, hash_bits] + x = x @ code + m = x.shape[:-1] + # [N, hash_bits] -- > [N, hash_bits // 8, 8] + x = (x > 0).to(torch.uint8).view(*m, -1, 8) + # 8bit -> 1bit + # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8] + # then sum along the last dimension to get [N, hash_numbers] + x = torch.sum(x * pack_weight, dim=-1, dtype=torch.uint8) + x = x.view(-1) # [N * hash_numbers] + return x + class HashEncoder: """ @@ -105,8 +223,6 @@ def _init_bit_masks(self) -> None: self.bit_masks = torch.pow( 2, torch.arange(8, dtype=torch.uint8, device=self.device) ) - # shape (1, 1, 8) - self.bit_masks = 
self.bit_masks.unsqueeze(0).unsqueeze(0) def compute_hash(self, x: torch.Tensor) -> torch.Tensor: """ @@ -136,29 +252,24 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: if x_flat.dtype != self.dtype: x_flat = x_flat.to(self.dtype) - # [N, hash_bits] - xW = torch.matmul(x_flat, self.hash_weights) - - # [N * hash_bits] - xW_flat = xW.view(-1) - if self.device.type == "npu": + # [N, hash_bits] + xW = torch.matmul(x_flat, self.hash_weights) + # [N * hash_bits] + xW_flat = xW.view(-1) # [N*hash_numbers], where hash_numbers = hash_bits // 8 packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1) - elif self.device.type == "cuda" or self.device.type == "cpu": - # (TODO) improve performance later on CUDA ops and CPU SIMD instructions - # [N, hash_bits] - projected = (xW > 0).to(torch.uint8) - # [N, hash_numbers, 8] - binary_codes = projected.view(-1, self.hash_numbers, 8) - - # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8] - # then sum along the last dimension to get [N, hash_numbers] - packed_codes_flat = torch.sum( - binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8 - ) # [N, hash_numbers] - packed_codes_flat = packed_codes_flat.view(-1) # [N * hash_numbers] + elif self.device.type == "cuda": + packed_codes_flat = triton_hash_code( + x_flat, self.hash_weights, self.bit_masks + ) # [N * hash_numbers] + + elif self.device.type == "cpu": + packed_codes_flat = torch_hash_code( + x_flat, self.hash_weights, self.bit_masks + ) # [N * hash_numbers] + else: raise ValueError(f"Unsupported device type: {self.device.type}") @@ -213,7 +324,7 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: ) # expand last dim to 8 # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8] - unpacked_bits = (expanded & self.bit_masks) > 0 + unpacked_bits = (expanded & self.bit_masks.unsqueeze(0).unsqueeze(0)) > 0 # 0 -> -1, 1 -> 1 unpacked_bits = unpacked_bits * 2 - 1 @@ -232,20 +343,22 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": + torch.manual_seed(42) + + print("test HashEncoder...") + dtype = torch.float16 if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device("npu:0") elif hasattr(torch, "cuda") and torch.cuda.is_available(): device = torch.device("cuda:0") + dtype = torch.float32 else: device = torch.device("cpu") print("Using device:", device) + encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=dtype, device=device) - torch.manual_seed(42) - - encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=torch.float16, device=device) - - x = torch.randn(2, 8, device=device, dtype=torch.float16) + x = torch.randn(2, 8, device=device, dtype=dtype) print("x:", x) hash_codes = encoder.compute_hash(x) @@ -262,3 +375,31 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: print( f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}" ) + + if hasattr(torch, "cuda") and torch.cuda.is_available(): + print("test cuda triton and torch hash code functions...") + x = torch.randn((1024, 512), device="cuda:0", dtype=torch.bfloat16) + code = torch.randn((512, 512), device="cuda:0", dtype=torch.bfloat16) + pack_weight = torch.tensor( + [128, 64, 32, 16, 8, 4, 2, 1], device="cuda:0", dtype=torch.uint8 + ) + + torch_output = torch_hash_code(x, code, pack_weight) + triton_output = triton_hash_code(x, code, pack_weight) + assert torch_output.shape == triton_output.shape + print(f"x_shape: {x.shape} code_shape: {code.shape}") 
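+        # Both paths binarize x @ code and pack 8 sign bits per byte MSB-first
+        # (pack_weight = [128, 64, ..., 1]); any disagreement reported below
+        # should be confined to projections that land near zero, where the
+        # Triton kernel's fp32 accumulation and the bf16 torch matmul can
+        # round the sign differently.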
+ print("torch_output", torch_output) + print("triton_output", triton_output) + print( + f"The maximum difference between Torch and Triton is" + f" {torch.max(torch.abs(torch_output.to(torch.int32) - triton_output.to(torch.int32)))}" + ) + # benchmark + print( + "torch:", + triton.testing.do_bench(lambda: torch_hash_code(x, code, pack_weight)), + ) + print( + "triton:", + triton.testing.do_bench(lambda: triton_hash_code(x, code, pack_weight)), + ) diff --git a/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp index d1eedf05..cbeced42 100644 --- a/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp +++ b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp @@ -1,39 +1,148 @@ // hash_retrieval_backend.cpp -#include -#include -#include -#include +#include #include +#include // 用于UINT16_MAX #include +#include +#include +#include #include +#include +#include +#include #include +#include #include #include #include -#include -#include -#include -#include // 用于UINT16_MAX -#include #ifdef NUMA_ENABLED +#include #include #endif -#ifdef __ARM_NEON -#include // ARM NEON SIMD 指令集头文件 -#elif defined(__x86_64__) || defined(_M_X64) -#include // x86_64 SSE SIMD 指令集头文件 +#include +#include +#include +#include + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +#include +#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) +#include // SSE/AVX +#include // POPCNT (SSE4.2) #endif - #define VEC_SIZE 16 +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + +using vec16u = uint8x16_t; + +static inline vec16u vec_loadu16(const uint8_t* p) { return vld1q_u8(p); } + +static inline vec16u vec_xor(vec16u a, vec16u b) { return veorq_u8(a, b); } + +static inline uint16_t vec_sum_u8(vec16u v) +{ +#if defined(__aarch64__) || defined(_M_ARM64) + return vaddvq_u8(v); +#else + uint16x8_t s16 = vpaddlq_u8(v); + uint32x4_t s32 = vpaddlq_u16(s16); + uint64x2_t s64 = vpaddlq_u32(s32); + return (uint16_t)(vgetq_lane_u64(s64, 0) + vgetq_lane_u64(s64, 1)); +#endif +} + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) +{ + vec16u va = vec_loadu16(a); + vec16u vb = vec_loadu16(b); + vec16u vx = vec_xor(va, vb); + vec16u pc = vcntq_u8(vx); + return vec_sum_u8(pc); +} + +static inline uint16_t vec_popcnt_xor_sum16_vec(vec16u qa, const uint8_t* b) +{ + vec16u vb = vec_loadu16(b); + vec16u vx = vec_xor(qa, vb); + vec16u pc = vcntq_u8(vx); + return vec_sum_u8(pc); +} + +void print_uint8x16(uint8x16_t vec) +{ + uint8_t array[16]; + vst1q_u8(array, vec); + for (int i = 0; i < 16; ++i) { std::cout << static_cast(array[i]) << " "; } + std::cout << std::endl; +} + +#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) + +using vec16u = __m128i; + +static inline vec16u vec_loadu16(const uint8_t* p) +{ + return _mm_loadu_si128(reinterpret_cast(p)); +} + +static inline vec16u vec_xor(vec16u a, vec16u b) { return _mm_xor_si128(a, b); } + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) +{ + __m128i va = _mm_loadu_si128(reinterpret_cast(a)); + __m128i vb = _mm_loadu_si128(reinterpret_cast(b)); + __m128i vx = _mm_xor_si128(va, vb); + + uint64_t lo, hi; +#if defined(__SSE4_1__) + lo = static_cast(_mm_extract_epi64(vx, 0)); + hi = static_cast(_mm_extract_epi64(vx, 1)); +#else + alignas(16) uint64_t tmp[2]; + _mm_storeu_si128(reinterpret_cast<__m128i*>(tmp), vx); + lo = tmp[0]; + hi = tmp[1]; +#endif + return 
(uint16_t)(__builtin_popcountll(lo) + __builtin_popcountll(hi)); +} + +static inline uint16_t vec_popcnt_xor_sum16_vec(vec16u qa, const uint8_t* b) +{ + __m128i vb = _mm_loadu_si128(reinterpret_cast(b)); + __m128i vx = _mm_xor_si128(qa, vb); + + uint64_t lo, hi; +#if defined(__SSE4_1__) + lo = static_cast(_mm_extract_epi64(vx, 0)); + hi = static_cast(_mm_extract_epi64(vx, 1)); +#else + alignas(16) uint64_t tmp[2]; + _mm_storeu_si128(reinterpret_cast<__m128i*>(tmp), vx); + lo = tmp[0]; + hi = tmp[1]; +#endif + return (uint16_t)(__builtin_popcountll(lo) + __builtin_popcountll(hi)); +} + +#else + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) +{ + uint16_t s = 0; + for (int i = 0; i < 16; ++i) s += __builtin_popcount((unsigned)(a[i] ^ b[i])); + return s; +} + +#endif + namespace py = pybind11; class HashRetrievalWorkerBackend { public: - HashRetrievalWorkerBackend(py::array_t data, - py::dict cpu_idx_tbl) + HashRetrievalWorkerBackend(py::array_t data, py::dict cpu_idx_tbl) : data_array_(data), stop_workers_(false), next_req_id_(0) { py::buffer_info info = data_array_.request(); @@ -41,6 +150,8 @@ class HashRetrievalWorkerBackend { block_size_ = info.shape[1]; dim_ = info.shape[2]; vec_per_dim_ = dim_ / VEC_SIZE; // data_每个值类型uint8_t,组成8*16_t进行simd加速 + tail_dim_ = dim_ % VEC_SIZE; + tail_start_ = vec_per_dim_ * VEC_SIZE; data_ = static_cast(info.ptr); // Start worker threads @@ -54,17 +165,18 @@ class HashRetrievalWorkerBackend { // 核心绑定代码 cpu_set_t cpuset; CPU_ZERO(&cpuset); - CPU_SET(core_id, &cpuset); // 绑定每个线程到指定的核心 + CPU_SET(core_id, &cpuset); // 绑定每个线程到指定的核心 pthread_t thread = worker_threads_.back().native_handle(); - + // 设置 CPU 亲和性 int rc = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); if (rc != 0) { - std::cerr << "Error binding thread " << i << " to CPU core " << core_id << std::endl; + std::cerr << "Error binding thread " << i << " to CPU core " << core_id + << std::endl; } - #ifdef NUMA_ENABLED +#ifdef NUMA_ENABLED int numaId = cpu_idx.first.cast(); // 设置内存亲和性 unsigned long nodeMask = 1UL << numaId; @@ -72,35 +184,33 @@ class HashRetrievalWorkerBackend { if (rc != 0) { std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; } - #else - std::cerr << "NUMA support is disabled." 
<< std::endl; - #endif - +#endif } - } } - ~HashRetrievalWorkerBackend() { + ~HashRetrievalWorkerBackend() + { { std::lock_guard lock(mutex_); stop_workers_ = true; cond_.notify_all(); } - for (auto& t: worker_threads_) t.join(); + for (auto& t : worker_threads_) t.join(); } - int submit(py::array_t query, int topk, py::array_t indexes) { + int submit(py::array_t query, int topk, py::array_t indexes) + { py::buffer_info qinfo = query.request(); py::buffer_info iinfo = indexes.request(); - if (qinfo.shape[1] != dim_) - throw std::runtime_error("Query dim mismatch"); + if (qinfo.shape[1] != dim_) throw std::runtime_error("Query dim mismatch"); if ((size_t)iinfo.shape[0] != (size_t)qinfo.shape[0]) throw std::runtime_error("Query and indexes batch mismatch"); int req_id = next_req_id_.fetch_add(1); - auto q = std::vector((uint8_t*)qinfo.ptr, (uint8_t*)qinfo.ptr + qinfo.shape[0] * dim_); + auto q = + std::vector((uint8_t*)qinfo.ptr, (uint8_t*)qinfo.ptr + qinfo.shape[0] * dim_); // Parse indexes to vector> size_t n_requests = iinfo.shape[0], max_index_number = iinfo.shape[1]; @@ -123,12 +233,14 @@ class HashRetrievalWorkerBackend { return req_id; } - bool poll(int req_id) { + bool poll(int req_id) + { std::lock_guard lock(mutex_); return results_.find(req_id) != results_.end(); } - void wait(int req_id) { + void wait(int req_id) + { std::shared_ptr s; { std::lock_guard lock(mutex_); @@ -140,7 +252,8 @@ class HashRetrievalWorkerBackend { s->cv.wait(lk2, [&] { return s->done; }); } - py::dict get_result(int req_id) { + py::dict get_result(int req_id) + { std::lock_guard lock(mutex_); auto it = results_.find(req_id); if (it == results_.end()) throw std::runtime_error("Result not ready"); @@ -183,55 +296,13 @@ class HashRetrievalWorkerBackend { bool done = false; }; -#ifdef __ARM_NEON - static inline uint16_t vaddvq_u8_compat(uint8x16_t v) { - #if defined(__aarch64__) || defined(_M_ARM64) - return vaddvq_u8(v); - #else - uint16x8_t s16 = vpaddlq_u8(v); - uint32x4_t s32 = vpaddlq_u16(s16); - uint64x2_t s64 = vpaddlq_u32(s32); - return (uint16_t)(vgetq_lane_u64(s64, 0) + vgetq_lane_u64(s64, 1)); - #endif - } - - void print_uint8x16(uint8x16_t vec) { - uint8_t array[16]; - vst1q_u8(array, vec); - for (int i = 0; i < 16; ++i) { - std::cout << static_cast(array[i]) << " "; - } - std::cout << std::endl; - } - -#elif defined(__x86_64__) || defined(_M_X64) - // 采用 Brian Kernighan's 算法计算 64 位数的 Hamming Weight - unsigned int popcnt64(uint64_t x) { - x -= (x >> 1) & 0x5555555555555555; // 将相邻的两位合并 - x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333); // 合并四位 - x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F; // 合并八位 - x = x + (x >> 8); // 合并十六位 - x = x + (x >> 16); // 合并三十二位 - x = x + (x >> 32); // 合并六十四位 - return x & 0x7F; // 返回最后的1的个数,0x7F表示最多返回 7 位 - } - - // 计算 128 位向量中 1 的个数 - int popcount_128(__m128i xor_result) { - // 将 128 位数据拆成两个 64 位整数 - uint64_t* result = (uint64_t*)&xor_result; - - // 分别计算每个 64 位的 Hamming 权重并返回结果之和 - return popcnt64(result[0]) + popcnt64(result[1]); - } -#endif - - void worker_loop() { + void worker_loop() + { while (true) { Request req; { std::unique_lock lock(mutex_); - cond_.wait(lock, [&]{ return stop_workers_ || !requests_.empty(); }); + cond_.wait(lock, [&] { return stop_workers_ || !requests_.empty(); }); if (stop_workers_ && requests_.empty()) return; req = std::move(requests_.front()); requests_.pop(); @@ -248,24 +319,20 @@ class HashRetrievalWorkerBackend { std::vector> heap; heap.reserve(allowed.size()); +#if defined(__ARM_NEON) || defined(__ARM_NEON__) || 
defined(__x86_64__) || defined(_M_X64) || \ + defined(__i386) || defined(_M_IX86) // 1.预加载 query 向量 - #ifdef __ARM_NEON - uint8x16_t q_vecs[vec_per_dim_]; // 存储 query 向量 - for (size_t v = 0; v < vec_per_dim_; ++v) { - q_vecs[v] = vld1q_u8(q_ptr + v * VEC_SIZE); - } - #elif defined(__x86_64__) || defined(_M_X64) - __m128i q_vecs[vec_per_dim_]; // 存储 query 向量 + vec16u q_vecs[vec_per_dim_]; // 存储query向量 for (size_t v = 0; v < vec_per_dim_; ++v) { - q_vecs[v] = _mm_loadu_si128(reinterpret_cast(q_ptr + v * VEC_SIZE)); + q_vecs[v] = vec_loadu16(q_ptr + v * VEC_SIZE); } - #endif +#endif // 2.遍历允许的索引 for (auto idx : allowed) { const uint8_t* base_idx_ptr = data_ + idx * block_size_ * dim_; - int score = UINT16_MAX; // 初始化为最大值 + int score = UINT16_MAX; // 初始化为最大值 // 3.内层向量化计算 // #pragma omp parallel for @@ -274,45 +341,28 @@ class HashRetrievalWorkerBackend { const uint8_t* k_base = base_idx_ptr + t_idx * dim_; // 计算每个向量的相似度 +#if defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(__x86_64__) || defined(_M_X64) || \ + defined(__i386) || defined(_M_IX86) for (size_t v = 0; v < vec_per_dim_; ++v) { - #ifdef __ARM_NEON - uint8x16_t k = vld1q_u8(k_base + v * VEC_SIZE); - sum += vaddvq_u8_compat(vcntq_u8(veorq_u8(q_vecs[v], k))); - #elif defined(__x86_64__) || defined(_M_X64) - __m128i k = _mm_loadu_si128(reinterpret_cast(k_base + v * VEC_SIZE)); - __m128i xor_result = _mm_xor_si128(q_vecs[v], k); // 16 * 8 - int popcount_result = popcount_128(xor_result); // 计算128位 xor_result 中所有位为 1 的个数 - sum += popcount_result; // 获取每个字节的累计值 - #endif + sum += vec_popcnt_xor_sum16_vec(q_vecs[v], k_base + v * VEC_SIZE); } - - // 处理不足16字节的部分 - ssize_t tail_dim = dim_ % VEC_SIZE; - if (tail_dim != 0) { - uint8_t q_tmp[16] = { 0 }; // 初始化填充为0 - uint8_t k_tmp[16] = { 0 }; - memcpy(q_tmp, q_ptr, dim_); - memcpy(k_tmp, k_base, dim_); - - #ifdef __ARM_NEON - uint8x16_t q = vld1q_u8(q_tmp); - uint8x16_t k = vld1q_u8(k_tmp); - sum += vaddvq_u8_compat(vcntq_u8(veorq_u8(q, k))); - #elif defined(__x86_64__) || defined(_M_X64) - __m128i q = _mm_loadu_si128(reinterpret_cast(q_tmp)); - __m128i k = _mm_loadu_si128(reinterpret_cast(k_tmp)); - __m128i xor_result = _mm_xor_si128(q, k); - int popcount_result = popcount_128(xor_result); // 计算128位 xor_result 中所有位为 1 的个数 - sum += popcount_result; // 获取每个字节的累计值 - #endif +#else + for (size_t v = 0; v < vec_per_dim_; ++v) { + sum += + vec_popcnt_xor_sum16(q_ptr + v * VEC_SIZE, k_base + v * VEC_SIZE); + } +#endif + if (tail_dim_ != 0) { + for (size_t t = 0; t < tail_dim_; ++t) { + uint8_t x = q_ptr[tail_start_ + t] ^ k_base[tail_start_ + t]; + sum += __builtin_popcount((unsigned)x); + } } // 如果得分为0,则跳出循环 if (sum < score) { score = sum; - if (score == 0) { - break; - } + if (score == 0) { break; } } } @@ -325,7 +375,7 @@ class HashRetrievalWorkerBackend { // 对堆进行部分排序,获取TopK std::partial_sort(heap.begin(), heap.begin() + curr_topk, heap.end(), - [](const auto& a, const auto& b) { return a.first < b.first; }); + [](const auto& a, const auto& b) { return a.first < b.first; }); // 保存TopK结果 for (int k = 0; k < curr_topk; ++k) { @@ -350,7 +400,7 @@ class HashRetrievalWorkerBackend { py::array_t data_array_; const uint8_t* data_ = nullptr; ssize_t dim_; - size_t num_blocks_, block_size_, vec_per_dim_; + size_t num_blocks_, block_size_, vec_per_dim_, tail_dim_, tail_start_; std::queue requests_; std::unordered_map results_; std::vector worker_threads_; @@ -361,11 +411,12 @@ class HashRetrievalWorkerBackend { std::atomic next_req_id_; }; -PYBIND11_MODULE(hash_retrieval_backend, m) { 
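+// Python-facing entry point (intended usage, inferred from the bindings
+// below): submit() enqueues a Hamming-distance top-k query and returns a
+// request id; poll()/wait() check or block on completion; get_result()
+// returns the scores and indices computed for each query.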
+PYBIND11_MODULE(hash_retrieval_backend, m) +{ py::class_<HashRetrievalWorkerBackend>(m, "HashRetrievalWorkerBackend") .def(py::init, py::dict>()) .def("submit", &HashRetrievalWorkerBackend::submit) .def("poll", &HashRetrievalWorkerBackend::poll) .def("get_result", &HashRetrievalWorkerBackend::get_result) .def("wait", &HashRetrievalWorkerBackend::wait); -} \ No newline at end of file +} diff --git a/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py index 7a77b05a..5faf83dc 100644 --- a/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py +++ b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py @@ -1,10 +1,12 @@ import time +from collections import defaultdict import numpy as np import torch from ucm.sparse.kvcomp.hash_encoder import HashEncoder from ucm.sparse.kvcomp.hash_retrieval import hash_retrieval_backend +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank class HashRetrievalWorker: @@ -37,15 +39,16 @@ def wait(self, req_id): if __name__ == "__main__": ################# data batch_size = 2 - dim = 1024 - kv_cache_blocks = 25600 - data = torch.rand(kv_cache_blocks, dim).to(torch.float32) + block_size = 2 + head_dim = 128 + head_num = 1 + dim = head_dim * head_num + kv_cache_blocks = 2560 + data = torch.rand(kv_cache_blocks, block_size, dim).to(torch.float32) print("data created", data.shape) - backend = hash_retrieval_backend.HashRetrievalWorkerBackend(data) - worker = HashRetrievalWorker(backend) - topk = 3000 - search_blocks_range = 8000 + topk = 10 + search_blocks_range = 100 tpot = 30 / 1000 indexes = np.arange(batch_size * search_blocks_range).reshape( @@ -54,8 +57,35 @@ def wait(self, req_id): query = torch.rand(batch_size, dim).to(torch.float32) + hash_encoder = HashEncoder( + input_dim=dim, + hash_bits=dim, + dtype=torch.float32, + device=torch.device("cpu"), + ) + + hash_query = hash_encoder.compute_hash(query) + hash_key_cache = hash_encoder.compute_hash(data) + + ratio = 0.75 + total_tp_size = 4 + local_tp_rank = 0 + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + total_tp_size, local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + + backend = hash_retrieval_backend.HashRetrievalWorkerBackend( + hash_key_cache, bind_info_dict + ) + worker = HashRetrievalWorker(backend) + #################### cpp async version - req_id = worker.submit(query, topk=topk, indexes=indexes) + req_id = worker.submit(hash_query, topk=topk, indexes=indexes) #################### LLM decode begin time.sleep(tpot * 3) @@ -66,28 +96,24 @@ def wait(self, req_id): worker.wait(req_id) result = worker.get_result(req_id) print("cpp spent:", time.time() - begin) + cpp_indices = np.sort(result["indices"], 1) + print(f"cpp indices={cpp_indices}") ################### numpy version + unpacked_hash_query = hash_encoder._unpack_hash(hash_query) + unpacked_hash_key_cache = hash_encoder._unpack_hash(hash_key_cache) begin = time.time() - data_indexed = ( - data[indexes.flatten()].reshape(indexes.shape[0], indexes.shape[1], dim).numpy() + data_indexed = unpacked_hash_key_cache[indexes.flatten()].reshape( + indexes.shape[0], indexes.shape[1], block_size, dim ) - query = HashRetrievalWorker.handle_input(query) - scores = np.matmul(query[:, None, :], data_indexed.transpose((0, 2, 1))) - scores = scores[:, 0, :] - topk_elements = np.partition(scores, -topk, -1)[:, -topk:] - topk_indices = np.argpartition(scores, -topk, -1)[:, -topk:]
- topk_indices = indexes[np.arange(indexes.shape[0])[:, None], topk_indices] - print("numpy spent: ", time.time() - begin) + scores = torch.einsum("td, tnjd->tnj", unpacked_hash_query, data_indexed) - ## compare - cpp_elements = np.sort(result["scores"], 1) - cpp_indices = np.sort(result["indices"], 1) - - np_elements = np.sort(topk_elements, 1) - np_indices = np.sort(topk_indices, 1) + block_scores_ret = torch.max(scores, dim=-1) + blocks_scores = block_scores_ret.values - diff_elements = np.abs(np_elements - cpp_elements) - diff_indices = np.abs(np_indices - cpp_indices) - - print(f"diff topk: {diff_indices.max()}") + topk_ret = torch.topk(blocks_scores, topk, dim=-1) + topk_index = topk_ret.indices + topk_index = topk_index.sort(dim=-1).values + topk_index = indexes[np.arange(indexes.shape[0])[:, None], topk_index] + print("numpy spent: ", time.time() - begin) + print(f"numpy indices={topk_index}") diff --git a/ucm/sparse/kvcomp/kvcomp.py b/ucm/sparse/kvcomp/kvcomp.py index c1713300..8a1f6123 100644 --- a/ucm/sparse/kvcomp/kvcomp.py +++ b/ucm/sparse/kvcomp/kvcomp.py @@ -186,10 +186,10 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device(f"npu:{self.rank}") - elif torch.cuda.is_available(): + elif hasattr(torch, "cuda") and torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") else: - device = torch.device("npu") + device = torch.device("cpu") self.hash_encoder = HashEncoder( input_dim=self.kvcomp_config.head_dim, diff --git a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp index a8fc080a..b504b106 100644 --- a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp +++ b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp @@ -5,19 +5,22 @@ #include "retrieve_task_runner.h" namespace KVStar { -RetrieveTaskQueue::~RetrieveTaskQueue() { +RetrieveTaskQueue::~RetrieveTaskQueue() +{ { std::unique_lock<std::mutex> lk(this->_mutex); if (!this->_running) { return; } this->_running = false; } - if (this->_worker.joinable()){ + if (this->_worker.joinable()) { this->_cv.notify_all(); this->_worker.join(); } } -void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::promise<Status>& started) { +void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, + std::promise<Status>& started) +{ cpu_set_t cpuset; CPU_ZERO(&cpuset); CPU_SET(bindCoreId, &cpuset); @@ -37,18 +40,17 @@ void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::prom started.set_value(Status::OsApiError()); return; } -#else - KVSTAR_DEBUG("NUMA support is disabled."); #endif - KVSTAR_DEBUG("Bind current thread {} to numa {} core {} and set memory affinity success.", thread, numaId, bindCoreId); + KVSTAR_DEBUG("Bind current thread {} to numa {} core {} and set memory affinity success.", + thread, numaId, bindCoreId); RetrieveTaskRunner runner; started.set_value(Status::OK()); Status status = Status::OK(); - for(;;){ + for (;;) { std::unique_lock<std::mutex> lk(this->_mutex); this->_cv.wait(lk, [this] { return !this->_taskQ.empty() || !this->_running; }); if (!this->_running) { return; } @@ -62,22 +64,23 @@ void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::prom if (!_failureSet->Exist(workItem.task.allocTaskId)) { if ((status = runner.Run(workItem.task, *workItem.result)).Failure()) { - KVSTAR_ERROR("Failed({}) to run retrieve task({}).", status, workItem.task.allocTaskId);
+ KVSTAR_ERROR("Failed({}) to run retrieve task({}).", status, + workItem.task.allocTaskId); this->_failureSet->Insert(workItem.task.allocTaskId); workItem.result->status = TaskStatus::FAILURE; } else { - KVSTAR_DEBUG("Process current task success, task id: {}.", workItem.task.allocTaskId); + KVSTAR_DEBUG("Process current task success, task id: {}.", + workItem.task.allocTaskId); workItem.result->status = TaskStatus::SUCCESS; } } workItem.task.waiter->Done(); } - } - -Status RetrieveTaskQueue::Setup(const int numaId, const int bindCoreId, RetrieveTaskSet* failureSet) { +Status RetrieveTaskQueue::Setup(const int numaId, const int bindCoreId, RetrieveTaskSet* failureSet) +{ this->_failureSet = failureSet; { std::unique_lock<std::mutex> lk(this->_mutex); } std::promise<Status> started; auto fut = started.get_future(); - this->_worker = std::thread([&]{ this->Worker(numaId, bindCoreId, started); }); + this->_worker = std::thread([&] { this->Worker(numaId, bindCoreId, started); }); return fut.get(); } -void RetrieveTaskQueue::Push(WorkItem&& item) { +void RetrieveTaskQueue::Push(WorkItem&& item) +{ { std::unique_lock<std::mutex> lk(this->_mutex); this->_taskQ.push_back(std::move(item)); } this->_cv.notify_one(); } - -} \ No newline at end of file +} // namespace KVStar \ No newline at end of file diff --git a/ucm/store/dramstore/CMakeLists.txt b/ucm/store/dramstore/CMakeLists.txt index 15295544..e69de29b 100644 --- a/ucm/store/dramstore/CMakeLists.txt +++ b/ucm/store/dramstore/CMakeLists.txt @@ -1,12 +0,0 @@ -file(GLOB_RECURSE UCMSTORE_DRAM_CC_SOURCE_FILES "./cc/*.cc") -add_library(dramstore STATIC ${UCMSTORE_DRAM_CC_SOURCE_FILES}) -target_include_directories(dramstore PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/cc/api - ${CMAKE_CURRENT_SOURCE_DIR}/cc/domain -) -target_link_libraries(dramstore PUBLIC storeinfra storedevice storetask) - -file(GLOB_RECURSE UCMSTORE_DRAM_CPY_SOURCE_FILES "./cpy/*.cc") -pybind11_add_module(ucmdramstore ${UCMSTORE_DRAM_CPY_SOURCE_FILES}) -target_link_libraries(ucmdramstore PRIVATE dramstore) -set_target_properties(ucmdramstore PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/dramstore/cc/api/dramstore.cc b/ucm/store/dramstore/cc/api/dramstore.cc deleted file mode 100644 index c59b7f2b..00000000 --- a/ucm/store/dramstore/cc/api/dramstore.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#include "dramstore.h" -#include "logger/logger.h" -#include "status/status.h" -#include "trans/dram_trans_manager.h" -#include "memory/memory_pool.h" - -namespace UC { - -class DRAMStoreImpl : public DRAMStore { -public: - int32_t Setup(const Config& config) { - auto status = this->memPool_.Setup(config.deviceId, config.capacity, config.blockSize); - if (status.Failure()) { - UC_ERROR("Failed({}) to setup MemoryPool.", status); - return status.Underlying(); - } - status = this->transMgr_.Setup(config.deviceId, config.streamNumber, &this->memPool_, config.timeoutMs); - if (status.Failure()) { - UC_ERROR("Failed({}) to setup TsfTaskManager.", status); - return status.Underlying(); - } - return Status::OK().Underlying(); - } - int32_t Alloc(const std::string& block) override { return this->memPool_.NewBlock(block).Underlying(); } - bool Lookup(const std::string& block) override { return this->memPool_.LookupBlock(block); } - void Commit(const std::string& block, const bool success) override { this->memPool_.CommitBlock(block, success).Underlying(); } - std::list<int32_t> Alloc(const std::list<std::string>& blocks) override - { - std::list<int32_t> results; - for (const auto &block : blocks) { - results.emplace_back(this->Alloc(block)); - } - return results; - } - std::list<bool> Lookup(const std::list<std::string>& blocks) override - { - std::list<bool> founds; - for (const auto &block : blocks) { - founds.emplace_back(this->Lookup(block)); - } - return founds; - } - void Commit(const std::list<std::string>& blocks, const bool success) override { - for (const auto &block : blocks) { - this->Commit(block, success); - } - } - size_t Submit(Task&& task) override { - auto taskId = Task::invalid; - auto status = this->transMgr_.Submit(std::move(task), taskId); - if (status.Failure()) { taskId = Task::invalid; } - return taskId; } - - int32_t Wait(const size_t task) override { - return this->transMgr_.Wait(task).Underlying(); - } - - int32_t Check(const size_t task, bool& finish) override { - return this->transMgr_.Check(task, finish).Underlying(); - } - - -private: - - DramTransManager transMgr_; - MemoryPool memPool_; - -}; - -int32_t DRAMStore::Setup(const Config& config) -{ - auto impl = new (std::nothrow) DRAMStoreImpl(); - if (!impl) { - UC_ERROR("Out of memory."); - return Status::OutOfMemory().Underlying(); - } - this->impl_ = impl; - return impl->Setup(config); -} - -} // namespace UC diff --git a/ucm/store/dramstore/cc/api/dramstore.h b/ucm/store/dramstore/cc/api/dramstore.h deleted file mode 100644 index 25d72612..00000000 --- a/ucm/store/dramstore/cc/api/dramstore.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_DRAMSTORE_H -#define UNIFIEDCACHE_DRAMSTORE_H - -#include "ucmstore.h" - -namespace UC { - -class DRAMStore : public CCStore<> { -public: - struct Config { - size_t capacity; - size_t blockSize; - int32_t deviceId; - size_t streamNumber; - size_t timeoutMs; - Config(const size_t capacity, const size_t blockSize, const int32_t deviceId, const size_t streamNumber, const size_t timeoutMs) - : capacity{capacity}, blockSize{blockSize}, deviceId{deviceId}, streamNumber{streamNumber}, timeoutMs{timeoutMs} - { - } - }; - -public: - DRAMStore() : impl_{nullptr} {} - ~DRAMStore() override - { - if (this->impl_) { delete this->impl_; } - } - int32_t Setup(const Config& config); - int32_t Alloc(const std::string& block) override { return this->impl_->Alloc(block); } - bool Lookup(const std::string& block) override { return this->impl_->Lookup(block); } - void Commit(const std::string& block, const bool success) override - { - this->impl_->Commit(block, success); - } - std::list<int32_t> Alloc(const std::list<std::string>& blocks) override - { - return this->impl_->Alloc(blocks); - } - std::list<bool> Lookup(const std::list<std::string>& blocks) override - { - return this->impl_->Lookup(blocks); - } - void Commit(const std::list<std::string>& blocks, const bool success) override - { - this->impl_->Commit(blocks, success); - } - size_t Submit(Task&& task) override { return this->impl_->Submit(std::move(task)); } - int32_t Wait(const size_t task) override { return this->impl_->Wait(task); } - int32_t Check(const size_t task, bool& finish) override - { - return this->impl_->Check(task, finish); - } - -private: - DRAMStore* impl_; -}; - -} // namespace UC - -#endif diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc deleted file mode 100644 index f9835612..00000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ - -#include "dram_trans_manager.h" - -namespace UC { - -Status DramTransManager::Setup(const int32_t deviceId, const size_t streamNumber, const MemoryPool* memPool, size_t timeoutMs) { - this->timeoutMs_ = timeoutMs; - auto status = Status::OK(); - for (size_t i = 0; i < streamNumber; i++) { - auto q = std::make_shared<DramTransQueue>(); - status = - q->Setup(deviceId, &this->failureSet_, memPool, timeoutMs); - if (status.Failure()) { break; } - this->queues_.emplace_back(std::move(q)); - } - return status; -} - -} \ No newline at end of file diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h deleted file mode 100644 index 7f9ef51b..00000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE.
- * */ -#ifndef UNIFIEDCACHE_DRAM_TRANS_MANAGER_H -#define UNIFIEDCACHE_DRAM_TRANS_MANAGER_H - -#include "task_manager.h" -#include "dram_trans_queue.h" - -namespace UC { - -class DramTransManager : public TaskManager { -public: - Status Setup(const int32_t deviceId, const size_t streamNumber, const MemoryPool* memPool, size_t timeoutMs); -}; - -} // namespace UC - -#endif diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc deleted file mode 100644 index cf7a3577..00000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ - -#include "dram_trans_queue.h" - -namespace UC { - -Status DramTransQueue::Setup(const int32_t deviceId, TaskSet* failureSet, - const MemoryPool* memPool, const size_t timeoutMs) { - this->deviceId_ = deviceId; - this->failureSet_ = failureSet; - this->memPool_ = memPool; - auto success = - this->backend_.SetWorkerInitFn([this](auto& device) { return this->Init(device); }) - .SetWorkerFn([this](auto& shards, const auto& device) { this->Work(shards, device); }) - .SetWorkerExitFn([this](auto& device) { this->Exit(device); }) - .Run(); - return success ? Status::OK() : Status::Error();
-} - -void DramTransQueue::Push(std::list& shards) noexcept { - this->backend_.Push(std::move(shards)); -} - -bool DramTransQueue::Init(Device& device) { - if (this->deviceId_ < 0) { return true; } - device = DeviceFactory::Make(this->deviceId_, 262144, 512); - if (!device) { - return false; - } - return device->Setup().Success(); -} - -void DramTransQueue::Exit(Device& device) { - device.reset(); -} - -void DramTransQueue::Work(std::list& shards, const Device& device) { - auto it = shards.begin(); - if (this->failureSet_->Contains(it->owner)) { - this->Done(shards, device, true); - } - auto status = Status::OK(); - if (it->type == Task::Type::DUMP) { - status = this->D2H(shards, device); - } else { - status = this->H2D(shards, device); - } - this->Done(shards, device, status.Success()); -} - -Status DramTransQueue::H2D(std::list& shards, const Device& device) { - size_t pool_offset = 0; - std::vector host_addrs(shards.size()); - std::vector device_addrs(shards.size()); - int shard_index = 0; - for (auto& shard : shards) { - bool found = this->memPool_->GetOffset(shard.block, &pool_offset); - if (!found) { - return Status::Error(); - } - auto host_addr = this->memPool_->GetStartAddr().get() + pool_offset + shard.offset; - auto device_addr = shard.address; - host_addrs[shard_index] = host_addr; - device_addrs[shard_index] = reinterpret_cast(device_addr); - shard_index++; - } - auto it = shards.begin(); - return device->H2DBatchSync(device_addrs.data(), const_cast(host_addrs.data()), shards.size(), it->length); -} - -Status DramTransQueue::D2H(std::list& shards, const Device& device) { - size_t pool_offset = 0; - std::vector host_addrs(shards.size()); - std::vector device_addrs(shards.size()); - int shard_index = 0; - for (auto& shard : shards) { - bool found = this->memPool_->GetOffset(shard.block, &pool_offset); - if (!found) { - return Status::Error(); - } - auto host_addr = this->memPool_->GetStartAddr().get() + pool_offset + shard.offset; - auto device_addr = shard.address; - host_addrs[shard_index] = host_addr; - device_addrs[shard_index] = reinterpret_cast(device_addr); - shard_index++; - } - auto it = shards.begin(); - return device->D2HBatchSync(host_addrs.data(), const_cast(device_addrs.data()), shards.size(), it->length); -} - -void DramTransQueue::Done(std::list& shards, const Device& device, const bool success) { - auto it = shards.begin(); - if (!success) { this->failureSet_->Insert(it->owner); } - for (auto& shard : shards) { - if (shard.done) { - if (device) { - if (device->Synchronized().Failure()) { this->failureSet_->Insert(shard.owner); } - } - shard.done(); - } - } -} - -} // namespace UC \ No newline at end of file diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h deleted file mode 100644 index 72350709..00000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_DRAM_TRANS_QUEUE_H -#define UNIFIEDCACHE_DRAM_TRANS_QUEUE_H - -#include "device/idevice.h" -#include "status/status.h" -#include "task_queue.h" -#include "task_set.h" -#include "thread/thread_pool.h" -#include "memory/memory_pool.h" - -namespace UC { - -class DramTransQueue : public TaskQueue { - using Device = std::unique_ptr<IDevice>; - int32_t deviceId_{-1}; - TaskSet* failureSet_{nullptr}; - const MemoryPool* memPool_{nullptr}; - ThreadPool, Device> backend_{}; - -public: - Status Setup(const int32_t deviceId, - TaskSet* failureSet, - const MemoryPool* memPool, - const size_t timeoutMs); - void Push(std::list& shards) noexcept override; - -private: - bool Init(Device& device); - void Exit(Device& device); - void Work(std::list& shards, const Device& device); - void Done(std::list& shards, const Device& device, const bool success); - Status H2D(std::list& shards, const Device& device); - Status D2H(std::list& shards, const Device& device); -}; - -} // namespace UC - -#endif diff --git a/ucm/store/dramstore/cpy/dramstore.py.cc b/ucm/store/dramstore/cpy/dramstore.py.cc deleted file mode 100644 index cb76d5d1..00000000 --- a/ucm/store/dramstore/cpy/dramstore.py.cc +++ /dev/null @@ -1,123 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE.
- * */ -#include "dramstore.h" -#include <pybind11/pybind11.h> - -namespace py = pybind11; - -namespace UC { - -class DRAMStorePy : public DRAMStore { -public: - void* CCStoreImpl() { return this; } - py::list AllocBatch(const py::list& blocks) - { - py::list results; - for (auto& block : blocks) { results.append(this->Alloc(block.cast<std::string>())); } - return results; - } - py::list LookupBatch(const py::list& blocks) - { - py::list founds; - for (auto& block : blocks) { founds.append(this->Lookup(block.cast<std::string>())); } - return founds; - } - void CommitBatch(const py::list& blocks, const bool success) - { - for (auto& block : blocks) { this->Commit(block.cast<std::string>(), success); } - } - py::tuple CheckPy(const size_t task) - { - auto finish = false; - auto ret = this->Check(task, finish); - return py::make_tuple(ret, finish); - } - size_t Load(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths) - { - return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::LOAD, - Task::Location::DEVICE, "DRAM::H2D"); - } - size_t Dump(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths) - { - return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::DUMP, - Task::Location::DEVICE, "DRAM::D2H"); - } - -private: - size_t SubmitPy(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths, Task::Type&& type, Task::Location&& location, - std::string&& brief) - { - Task task{std::move(type), std::move(location), std::move(brief)}; - auto blockId = blockIds.begin(); - auto offset = offsets.begin(); - auto address = addresses.begin(); - auto length = lengths.begin(); - while ((blockId != blockIds.end()) && (offset != offsets.end()) && - (address != addresses.end()) && (length != lengths.end())) { - task.Append(blockId->cast<std::string>(), offset->cast(), - address->cast(), length->cast()); - blockId++; - offset++; - address++; - length++; - } - return this->Submit(std::move(task)); - } -}; - -} // namespace UC - -PYBIND11_MODULE(ucmdramstore, module) -{ - module.attr("project") = UCM_PROJECT_NAME; - module.attr("version") = UCM_PROJECT_VERSION; - module.attr("commit_id") = UCM_COMMIT_ID; - module.attr("build_type") = UCM_BUILD_TYPE; - auto store = py::class_<UC::DRAMStorePy>(module, "DRAMStore"); - auto config = py::class_<UC::DRAMStorePy::Config>(store, "Config"); - config.def(py::init<const size_t, const size_t, const int32_t, const size_t, const size_t>(), - py::arg("capacity"), py::arg("blockSize"), py::arg("deviceId"), py::arg("streamNumber"), py::arg("timeoutMs")); - config.def_readwrite("capacity", &UC::DRAMStorePy::Config::capacity); - config.def_readwrite("blockSize", &UC::DRAMStorePy::Config::blockSize); - config.def_readwrite("deviceId", &UC::DRAMStorePy::Config::deviceId); - config.def_readwrite("streamNumber", &UC::DRAMStorePy::Config::streamNumber); - config.def_readwrite("timeoutMs", &UC::DRAMStorePy::Config::timeoutMs); - store.def(py::init<>()); - store.def("CCStoreImpl", &UC::DRAMStorePy::CCStoreImpl); - store.def("Setup", &UC::DRAMStorePy::Setup); - store.def("Alloc", py::overload_cast<const std::string&>(&UC::DRAMStorePy::Alloc)); - store.def("AllocBatch", &UC::DRAMStorePy::AllocBatch); - store.def("Lookup", py::overload_cast<const std::string&>(&UC::DRAMStorePy::Lookup)); - store.def("LookupBatch", &UC::DRAMStorePy::LookupBatch); - store.def("Load", &UC::DRAMStorePy::Load); - store.def("Dump", &UC::DRAMStorePy::Dump); - store.def("Wait", &UC::DRAMStorePy::Wait); - store.def("Check", &UC::DRAMStorePy::Check); - store.def("Commit", - py::overload_cast<const std::string&, const bool>(&UC::DRAMStorePy::Commit)); - store.def("CommitBatch", &UC::DRAMStorePy::CommitBatch); -}
diff --git a/ucm/store/infra/CMakeLists.txt b/ucm/store/infra/CMakeLists.txt index f3e0ce72..6bc8dc4a 100644 --- a/ucm/store/infra/CMakeLists.txt +++ b/ucm/store/infra/CMakeLists.txt @@ -1,21 +1,11 @@ add_library(storeinfra STATIC) target_include_directories(storeinfra PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) file(GLOB_RECURSE UCMSTORE_COMMON_FILE_SOURCE_FILES "file/*.cc") -if(LOGGER_BACKEND STREQUAL "spdlog") - file(GLOB_RECURSE UCMSTORE_COMMON_LOGGER_SOURCE_FILES "logger/spdlog/*.cc") -endif() -if(LOGGER_BACKEND STREQUAL "flux") - file(GLOB_RECURSE UCMSTORE_COMMON_LOGGER_SOURCE_FILES "logger/flux/*.cc") -endif() -file(GLOB_RECURSE UCMSTORE_COMMON_STATUS_SOURCE_FILES "status/*.cc") -file(GLOB_RECURSE UCMSTORE_COMMON_TEMPLATE_SOURCE_FILES "template/*.cc") -file(GLOB_RECURSE UCMSTORE_COMMON_THREAD_SOURCE_FILES "thread/*.cc") target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_FILE_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_LOGGER_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_STATUS_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_TEMPLATE_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_THREAD_SOURCE_FILES}) -target_link_libraries(storeinfra PUBLIC fmt) -if(LOGGER_BACKEND STREQUAL "spdlog") - target_link_libraries(storeinfra PUBLIC spdlog) -endif() +target_link_libraries(storeinfra PUBLIC + infra_status + infra_logger + infra_template + infra_thread + infra_time +) diff --git a/ucm/store/infra/memory/memory_pool.h b/ucm/store/infra/memory/memory_pool.h deleted file mode 100644 index 200d1286..00000000 --- a/ucm/store/infra/memory/memory_pool.h +++ /dev/null @@ -1,174 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE.
- * */ -#ifndef UNIFIEDCACHE_MEMORY_POOL_H -#define UNIFIEDCACHE_MEMORY_POOL_H - -#include -#include -#include -#include -#include -#include -#include "status/status.h" -#include "device/idevice.h" -#include -#include -#include -#include "logger/logger.h" - -namespace UC { - -class MemoryPool { - - std::string DUMMY_SLOT_PREFIX{"__slot_"}; - using Device = std::unique_ptr<IDevice>; -public: - - Status Setup(int32_t deviceId, size_t capacity, size_t blockSize) { - capacity_ = capacity; - blockSize_ = blockSize; - device_ = DeviceFactory::Make(deviceId, blockSize, static_cast(capacity / blockSize)); - if (!device_) { - UC_ERROR("MemoryPool: failed to create device"); - return Status::Error(); - } - Status status = device_->Setup(); - if (!status.Success()) { - UC_ERROR("MemoryPool: failed to set up device"); - return Status::Error(); - } - pool_ = device_->GetBuffer(capacity_); - if (!pool_) { - UC_ERROR("MemoryPool: failed to get pool memory space"); - return Status::Error(); - } - - size_t slotNum = capacity_ / blockSize_; - for (size_t i = 0; i < slotNum; ++i) { - std::string dummy = DUMMY_SLOT_PREFIX + std::to_string(i); - size_t offset = i * blockSize_; - lruList_.push_front(dummy); - lruIndex_[dummy] = lruList_.begin(); - offsetMap_[dummy] = offset; - } - return Status::OK(); - - } - - Status NewBlock(const std::string& blockId) { - if (offsetMap_.count(blockId)) { - return Status::DuplicateKey(); - } - if (lruList_.empty()) { - // every block in the pool is still being written, so nothing can be allocated - return Status::Error(); - } - size_t offset = LRUEvictOne(); - offsetMap_[blockId] = offset; - return Status::OK(); - } - - bool LookupBlock(const std::string& blockId) const { - return availableBlocks_.count(blockId); - } - - bool GetOffset(const std::string& blockId, size_t* offset) const { - auto it = offsetMap_.find(blockId); - if (it == offsetMap_.end()) { - return false; - } - *offset = it->second; - return true; - } - - Status CommitBlock(const std::string& blockId, bool success) { - if (success) { - availableBlocks_.insert(blockId); - touchUnsafe(blockId); - } else { - resetSpaceOfBlock(blockId); - } - return Status::OK(); - } - - std::shared_ptr GetStartAddr() const { - return pool_; - } - -private: - std::shared_ptr pool_ = nullptr; - Device device_ = nullptr; - size_t capacity_; - size_t blockSize_; - - std::unordered_map<std::string, size_t> offsetMap_; - std::set<std::string> availableBlocks_; - - using ListType = std::list<std::string>; - ListType lruList_; - std::unordered_map<std::string, ListType::iterator> lruIndex_; - - void touchUnsafe(const std::string& blockId) { - auto it = lruIndex_.find(blockId); - if (it != lruIndex_.end()) { - lruList_.splice(lruList_.begin(), lruList_, it->second); - } - else { - lruList_.push_front(blockId); // once accessed, the block becomes the most recently used, so move it to the head of the LRU list; this is the standard LRU logic - lruIndex_[blockId] = lruList_.begin(); - } - } - - size_t LRUEvictOne() { - const std::string& victim = lruList_.back(); - // only real data blocks need to be erased from availableBlocks_ - if (victim.rfind(DUMMY_SLOT_PREFIX, 0) != 0) { - availableBlocks_.erase(victim); - } - size_t offset = offsetMap_[victim]; - offsetMap_.erase(victim); - lruIndex_.erase(victim); - lruList_.pop_back(); - return offset; - } - - void resetSpaceOfBlock(const std::string& blockId) { - auto it = offsetMap_.find(blockId); - size_t offset = it->second; - std::string dummy = DUMMY_SLOT_PREFIX + std::to_string(offset / blockSize_); - offsetMap_.erase(blockId); - - auto lit = lruIndex_.find(blockId); - if (lit != lruIndex_.end()) { - lruList_.erase(lit->second); - lruIndex_.erase(lit); - } - lruList_.push_back(dummy); // after a block is committed with false, reclaim the memory previously allocated for it and put the slot at the tail of the LRU list, so it is written first next time (it has higher priority than blocks already written)
- lruIndex_[dummy] = std::prev(lruList_.end()); - offsetMap_[dummy] = offset; - } -}; - -} // namespace UC -#endif \ No newline at end of file diff --git a/ucm/store/infra/status/status.h b/ucm/store/infra/status/status.h deleted file mode 100644 index 809d2459..00000000 --- a/ucm/store/infra/status/status.h +++ /dev/null @@ -1,134 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_STATUS_H -#define UNIFIEDCACHE_STATUS_H - -#include <cstdint> - -namespace UC { - -class Status { - enum class Code { -#define UC_MAKE_STATUS_CODE(i) (-50000 - (i)) - OK = 0, - ERROR = -1, - EPARAM = UC_MAKE_STATUS_CODE(0), - EOOM = UC_MAKE_STATUS_CODE(1), - EOSERROR = UC_MAKE_STATUS_CODE(2), - EDUPLICATE = UC_MAKE_STATUS_CODE(3), - ERETRY = UC_MAKE_STATUS_CODE(4), - ENOOBJ = UC_MAKE_STATUS_CODE(5), - ESERIALIZE = UC_MAKE_STATUS_CODE(6), - EDESERIALIZE = UC_MAKE_STATUS_CODE(7), - EUNSUPPORTED = UC_MAKE_STATUS_CODE(8), - ENOSPACE = UC_MAKE_STATUS_CODE(9), -#undef UC_MAKE_STATUS_CODE - }; - -public: - static Status& OK() - { - static Status s{Code::OK}; - return s; - } - static Status& Error() - { - static Status s{Code::ERROR}; - return s; - } - static Status& InvalidParam() - { - static Status s{Code::EPARAM}; - return s; - } - static Status& OutOfMemory() - { - static Status s{Code::EOOM}; - return s; - } - static Status& OsApiError() - { - static Status s{Code::EOSERROR}; - return s; - } - static Status& DuplicateKey() - { - static Status s{Code::EDUPLICATE}; - return s; - } - static Status& Retry() - { - static Status s{Code::ERETRY}; - return s; - } - static Status& NotFound() - { - static Status s{Code::ENOOBJ}; - return s; - } - static Status& SerializeFailed() - { - static Status s{Code::ESERIALIZE}; - return s; - } - static Status& DeserializeFailed() - { - static Status s{Code::EDESERIALIZE}; - return s; - } - static Status& Unsupported() - { - static Status s{Code::EUNSUPPORTED}; - return s; - } - static Status& NoSpace() - { - static Status s{Code::ENOSPACE}; - return s; - } -public: - Status(const Status& status) { this->code_ = status.code_; } - Status& operator=(const Status& status) - { - if (this != &status) { this->code_ = status.code_; } - return *this; - } - bool operator==(const Status& status) const { return this->code_ == status.code_; } - bool operator!=(const Status& status) const { return this->code_ != status.code_; }
- int32_t Underlying() const { return static_cast<int32_t>(this->code_); } - bool Success() const { return this->code_ == Code::OK; } - bool Failure() const { return this->code_ != Code::OK; } - -private: - Status(const Code code) : code_{code} {} - -private: - Code code_; -}; - -inline int32_t format_as(const Status& status) { return status.Underlying(); } - -} // namespace UC - -#endif diff --git a/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h b/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h index add777ba..d549788e 100644 --- a/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h +++ b/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h @@ -26,23 +26,26 @@ #define UNIFIEDCACHE_HOTNESS_TIMER_H #include <chrono> #include <functional> +#include "logger/logger.h" #include "template/timer.h" namespace UC { class HotnessTimer { public: - void SetInterval(const size_t interval) { this->interval_ = std::chrono::seconds(interval); } - Status Start(std::function<void()> callable) - { + void SetInterval(const size_t interval) { this->interval_ = std::chrono::seconds(interval); } + Status Start(std::function<void()> callable) + { try { - this->timer_ = std::make_unique<Timer<std::chrono::seconds, std::function<void()>>>(this->interval_, std::move(callable)); + this->timer_ = std::make_unique<Timer<std::chrono::seconds, std::function<void()>>>(this->interval_, + std::move(callable)); } catch (const std::exception& e) { UC_ERROR("Failed({}) to start hotness timer.", e.what()); return Status::OutOfMemory(); } - return this->timer_->Start(); - } + return this->timer_->Start() ? Status::OK() : Status::Error(); + } + private: std::chrono::seconds interval_; std::unique_ptr<Timer<std::chrono::seconds, std::function<void()>>> timer_; @@ -50,4 +53,4 @@ class HotnessTimer { } // namespace UC -#endif \ No newline at end of file +#endif diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc index 8e2958fa..83b8ce45 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc @@ -39,7 +39,7 @@ void TransQueue::DeviceWorker(BlockTask&& task) auto done = task.done; auto devPtrs = (void**)task.shards.data(); auto hostPtr = task.buffer.get(); - auto s = Trans::Status::OK(); + auto s = Status::OK(); if (task.type == TransTask::Type::LOAD) { s = stream_->HostToDevice(hostPtr, devPtrs, size, number); } else { diff --git a/ucm/store/test/CMakeLists.txt b/ucm/store/test/CMakeLists.txt index 859c185c..0c4974ef 100644 --- a/ucm/store/test/CMakeLists.txt +++ b/ucm/store/test/CMakeLists.txt @@ -4,7 +4,7 @@ if(BUILD_UNIT_TESTS) add_executable(ucmstore.test ${UCMSTORE_TEST_SOURCE_FILES}) target_include_directories(ucmstore.test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/case) target_link_libraries(ucmstore.test PRIVATE - dramstore nfsstore localstore storeinfra storedevice + nfsstore localstore storeinfra storedevice gtest_main gtest mockcpp ) gtest_discover_tests(ucmstore.test) diff --git a/ucm/store/test/case/infra/mem_pool_test.cc b/ucm/store/test/case/infra/mem_pool_test.cc deleted file mode 100644 index f9ea0438..00000000 --- a/ucm/store/test/case/infra/mem_pool_test.cc +++ /dev/null @@ -1,169 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ - -#include "infra/memory/memory_pool.h" -#include <gtest/gtest.h> - -class UCMemoryPoolTest : public ::testing::Test {}; - -TEST_F(UCMemoryPoolTest, NewBlockAllocateAndCommit) -{ - UC::MemoryPool memPool; // initialize the memory pool - ASSERT_EQ(memPool.Setup(-1, 10, 2), UC::Status::OK()); - const std::string block1 = "block1"; - size_t offset = 10; - ASSERT_FALSE(memPool.LookupBlock(block1)); - // ASSERT_EQ(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), false); - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); - ASSERT_FALSE(memPool.LookupBlock(block1)); - // ASSERT_NE(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), true); - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::DuplicateKey()); - ASSERT_EQ(memPool.CommitBlock(block1, true), UC::Status::OK()); - ASSERT_TRUE(memPool.LookupBlock(block1)); -} - -TEST_F(UCMemoryPoolTest, EvictOldBlock) -{ - UC::MemoryPool memPool; // initialize the memory pool - ASSERT_EQ(memPool.Setup(-1, 10, 5), UC::Status::OK()); - const std::string block1 = "block1"; - const std::string block2 = "block2"; - const std::string block3 = "block3"; - size_t offset = 10; - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), true); - ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block2), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - memPool.CommitBlock(block1, true); - memPool.CommitBlock(block2, true); - ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block3), nullptr); - ASSERT_EQ(memPool.GetOffset(block3, &offset), true); - // ASSERT_EQ(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), false); - // ASSERT_NE(memPool.GetOffset(block2), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - ASSERT_FALSE(memPool.LookupBlock(block1)); - ASSERT_TRUE(memPool.LookupBlock(block2)); - ASSERT_FALSE(memPool.LookupBlock(block3)); -} - -TEST_F(UCMemoryPoolTest, OldBlockCommitFalse) -{ - UC::MemoryPool memPool; // initialize the memory pool - ASSERT_EQ(memPool.Setup(-1, 32, 8), UC::Status::OK()); - const std::string block1 = "block1"; - const std::string block2 = "block2"; - const std::string block3 = "block3"; - const std::string block4 = "block4"; - const std::string block5 = "block5"; - size_t offset = 32; - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK());
- // ASSERT_NE(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), true); - ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block2), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block3), nullptr); - ASSERT_EQ(memPool.GetOffset(block3, &offset), true); - memPool.CommitBlock(block1, true); - memPool.CommitBlock(block2, false); - ASSERT_TRUE(memPool.LookupBlock(block1)); - ASSERT_FALSE(memPool.LookupBlock(block2)); - ASSERT_FALSE(memPool.LookupBlock(block3)); - ASSERT_EQ(memPool.NewBlock(block4), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block4), 8); - ASSERT_EQ(memPool.GetOffset(block4, &offset), true); - ASSERT_EQ(offset, 8); - ASSERT_EQ(memPool.NewBlock(block5), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block5), 24); - ASSERT_EQ(memPool.GetOffset(block5, &offset), true); - ASSERT_EQ(offset, 24); - memPool.CommitBlock(block3, true); - memPool.CommitBlock(block4, true); - memPool.CommitBlock(block5, true); - ASSERT_TRUE(memPool.LookupBlock(block1)); - ASSERT_FALSE(memPool.LookupBlock(block2)); - ASSERT_TRUE(memPool.LookupBlock(block3)); - ASSERT_TRUE(memPool.LookupBlock(block4)); - ASSERT_TRUE(memPool.LookupBlock(block5)); - - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::DuplicateKey()); - ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block2), 0); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - ASSERT_EQ(offset, 0); - ASSERT_FALSE(memPool.LookupBlock(block1)); - ASSERT_FALSE(memPool.LookupBlock(block2)); - memPool.CommitBlock(block2, true); - ASSERT_TRUE(memPool.LookupBlock(block2)); -} - -TEST_F(UCMemoryPoolTest, NoCommittedBlock) -{ - UC::MemoryPool memPool; // initialize the memory pool - ASSERT_EQ(memPool.Setup(-1, 32, 8), UC::Status::OK()); - const std::string block1 = "block1"; - const std::string block2 = "block2"; - const std::string block3 = "block3"; - const std::string block4 = "block4"; - const std::string block5 = "block5"; - const std::string block6 = "block6"; - size_t offset = 32; - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); - ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); - ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); - ASSERT_EQ(memPool.NewBlock(block4), UC::Status::OK()); - ASSERT_EQ(memPool.NewBlock(block5), UC::Status::Error()); - memPool.CommitBlock(block1, true); - ASSERT_TRUE(memPool.LookupBlock(block1)); - ASSERT_EQ(memPool.NewBlock(block5), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block5), 0); - ASSERT_EQ(memPool.GetOffset(block5, &offset), true); - ASSERT_EQ(offset, 0); - ASSERT_FALSE(memPool.LookupBlock(block1)); - ASSERT_EQ(memPool.NewBlock(block6), UC::Status::Error()); - // ASSERT_EQ(memPool.GetOffset(block2), 8); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - ASSERT_EQ(offset, 8); - memPool.CommitBlock(block2, false); - // ASSERT_EQ(memPool.GetOffset((block2)), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), false); - ASSERT_FALSE(memPool.LookupBlock(block1)); - ASSERT_EQ(memPool.NewBlock(block6), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block6), 8); - ASSERT_EQ(memPool.GetOffset(block6, &offset), true); - ASSERT_EQ(offset, 8); - ASSERT_FALSE(memPool.LookupBlock(block6)); - memPool.CommitBlock(block6, true); - ASSERT_TRUE(memPool.LookupBlock(block6)); - // ASSERT_EQ(memPool.GetOffset(block6), 8); - ASSERT_EQ(memPool.GetOffset(block6, &offset), true); - ASSERT_EQ(offset, 8); -} \ No newline at end of file