102 changes: 102 additions & 0 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -433,6 +433,108 @@ class PassKeyRetrieval128k(AccuracyTask):
MAX_OUTPUT_LEN = 50


class LongBenchV2(AccuracyTask):
DATASET = "longbench_v2"
DATASET_DIR = f"{llm_models_root()}/zai-org/LongBench-v2"

ALPHA = 0.05
BETA = 0.2
SIGMA = 50.0
NUM_SAMPLES = 215

MAX_BATCH_SIZE = 32
MAX_INPUT_LEN = 1280000
MAX_OUTPUT_LEN = 32000

EVALUATOR_CLS = tensorrt_llm.evaluate.LongBenchV2
EVALUATOR_KWARGS = dict(
dataset_path=DATASET_DIR,
length="medium",
max_len=1280000,
apply_chat_template=True,
random_seed=0,
)

@staticmethod
def create_modified_model_dir(original_model_dir: str,
max_position_embeddings: int = 1280000,
model_max_length: int = 1280000) -> str:
"""
Create temporary directory with modified config files for long context evaluation.

This method creates a temporary directory with symlinks to all model files except
config files, which are copied and modified to support longer context lengths.
This is useful for evaluating models on long context tasks that exceed the
original model's max_position_embeddings.

Args:
original_model_dir: Path to the original model directory
max_position_embeddings: New value for max_position_embeddings in config.json
model_max_length: New value for model_max_length in tokenizer_config.json

Returns:
Path to the temporary modified model directory

Note:
The caller is responsible for cleaning up the temporary directory after use.
"""
import tempfile

# Create temporary model directory with symlinks
temp_dir = tempfile.mkdtemp(prefix="longbench_v2_modified_model_")
logger.info(f"Created temporary model directory: {temp_dir}")

# Create symlinks for all files except config files
for item in os.listdir(original_model_dir):
src = os.path.join(original_model_dir, item)
dst = os.path.join(temp_dir, item)

# Skip config files - will handle them separately
if item in ["config.json", "tokenizer_config.json"]:
continue

# Create symlink for other files/directories
os.symlink(src, dst)
logger.info(f" Symlinked: {item}")

# Modify and copy config.json
config_src = os.path.join(original_model_dir, "config.json")
config_dst = os.path.join(temp_dir, "config.json")
if os.path.exists(config_src):
with open(config_src, 'r', encoding='utf-8') as f:
config = json.load(f)

# Modify max_position_embeddings
original_max_pos = config.get('max_position_embeddings')
config['max_position_embeddings'] = max_position_embeddings
logger.info(
f" Modified config.json: max_position_embeddings {original_max_pos} -> {max_position_embeddings}"
)

with open(config_dst, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2, ensure_ascii=False)

# Modify and copy tokenizer_config.json
tokenizer_config_src = os.path.join(original_model_dir,
"tokenizer_config.json")
tokenizer_config_dst = os.path.join(temp_dir, "tokenizer_config.json")
if os.path.exists(tokenizer_config_src):
with open(tokenizer_config_src, 'r', encoding='utf-8') as f:
tokenizer_config = json.load(f)

# Modify model_max_length
original_max_len = tokenizer_config.get('model_max_length')
tokenizer_config['model_max_length'] = model_max_length
logger.info(
f" Modified tokenizer_config.json: model_max_length {original_max_len} -> {model_max_length}"
)

with open(tokenizer_config_dst, 'w', encoding='utf-8') as f:
json.dump(tokenizer_config, f, indent=2, ensure_ascii=False)

return temp_dir


class CliFlowAccuracyTestHarness:
# Model
MODEL_NAME = None
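For reviewers who want to try the new helper in isolation, here is a minimal usage sketch. The model path and import path are placeholders (they depend on how the tests package is laid out on your machine); per the docstring, the caller owns cleanup.

```python
# Minimal sketch of calling LongBenchV2.create_modified_model_dir directly.
# The import path and model path below are hypothetical.
import shutil

from defs.accuracy.accuracy_core import LongBenchV2

temp_dir = LongBenchV2.create_modified_model_dir(
    "/models/DeepSeek-R1/DeepSeek-R1-0528",  # placeholder path
    max_position_embeddings=1280000,
    model_max_length=1280000,
)
try:
    # temp_dir symlinks the original weights and carries the patched
    # config.json / tokenizer_config.json; use it anywhere a model
    # directory is expected.
    print(temp_dir)
finally:
    shutil.rmtree(temp_dir, ignore_errors=True)
```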
9 changes: 9 additions & 0 deletions tests/integration/defs/accuracy/references/longbench_v2.yaml
@@ -0,0 +1,9 @@
DeepSeek-R1-0528:
- quant_algo: FP8_BLOCK_SCALES
kv_cache_quant_algo: FP8
spec_dec_algo: MTP
accuracy: 52.093
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
spec_dec_algo: MTP
accuracy: 52.093
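These reference accuracies are consumed together with the ALPHA, BETA, SIGMA, and NUM_SAMPLES fields on the task class. As a hedged sketch (not the harness's exact code), parameters like these typically turn a reference score into a pass/fail bound via a one-sided z-test on the mean accuracy:

```python
# Hedged sketch of a hypothesis-test pass threshold; the real formula in
# accuracy_core.py may differ in detail.
from math import sqrt

from scipy.stats import norm


def pass_threshold(reference: float, alpha: float = 0.05, beta: float = 0.2,
                   sigma: float = 50.0, num_samples: int = 215) -> float:
    scale = sigma / sqrt(num_samples)  # standard error of the mean accuracy
    # Minimum detectable drop at false-alarm rate alpha and miss rate beta.
    theta = (norm.ppf(1 - alpha) + norm.ppf(1 - beta)) * scale
    return reference - theta


print(round(pass_threshold(52.093), 2))  # ~43.61 with the values above
```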
110 changes: 109 additions & 1 deletion tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -33,7 +33,8 @@
skip_post_blackwell, skip_pre_ada, skip_pre_blackwell,
skip_pre_hopper, skip_ray)
from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
JsonModeEval, LlmapiAccuracyTestHarness)
JsonModeEval, LlmapiAccuracyTestHarness,
LongBenchV2)


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
@@ -4136,3 +4137,110 @@ def test_auto_dtype(self):
extra_evaluator_kwargs=dict(
apply_chat_template=True,
chat_template_kwargs=chat_template_kwargs))


@skip_pre_blackwell
@pytest.mark.skip_less_device_memory(183000)
@pytest.mark.timeout(28800)
class TestDeepSeekR1LongBenchV2(LlmapiAccuracyTestHarness):
MODEL_NAME = "DeepSeek-R1-0528"

@pytest.mark.skip_less_mpi_world_size(8)
def test_fp8_8gpus(self):
original_model_dir = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-0528"
if not os.path.exists(original_model_dir):
pytest.skip(f"Model directory {original_model_dir} does not exist")

temp_dir = None
try:
            # Create modified model directory using the LongBenchV2 static method.
            # This is a WAR (workaround): the shipped model config does not
            # advertise support for the long contexts this task needs.
            # TODO: remove once the model config supports long context natively.
temp_dir = LongBenchV2.create_modified_model_dir(original_model_dir)

# Configure model settings
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
enable_block_reuse=True,
enable_partial_reuse=False,
dtype="fp8")

cuda_graph_config = CudaGraphConfig(enable_padding=True,
max_batch_size=32)

mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3)

moe_config = MoeConfig(backend='DEEPGEMM', max_num_tokens=32000)

pytorch_config = dict(cuda_graph_config=cuda_graph_config,
kv_cache_config=kv_cache_config,
speculative_config=mtp_config,
moe_config=moe_config,
enable_chunked_prefill=True,
enable_autotuner=True)

# Create LLM instance and evaluate
with LLM(temp_dir,
tensor_parallel_size=8,
moe_expert_parallel_size=8,
max_num_tokens=32000,
max_batch_size=32,
**pytorch_config) as llm:

task = LongBenchV2(self.MODEL_NAME)

sampling_params = SamplingParams(max_tokens=32000)

task.evaluate(llm, sampling_params=sampling_params)

finally:
# Cleanup temporary files
if temp_dir and os.path.exists(temp_dir):
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)

@pytest.mark.skip_less_mpi_world_size(4)
def test_nvfp4_4gpus(self):
        original_model_dir = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-0528-FP4"
        if not os.path.exists(original_model_dir):
            pytest.skip(f"Model directory {original_model_dir} does not exist")
        temp_dir = None
try:
# Create modified model directory using LongBenchV2 static method
temp_dir = LongBenchV2.create_modified_model_dir(original_model_dir)

# Configure model settings (no MOE config for FP4 version)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
enable_block_reuse=True,
enable_partial_reuse=False,
dtype="fp8")

cuda_graph_config = CudaGraphConfig(enable_padding=True,
max_batch_size=32)

mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3)

pytorch_config = dict(cuda_graph_config=cuda_graph_config,
kv_cache_config=kv_cache_config,
speculative_config=mtp_config,
enable_chunked_prefill=True,
enable_autotuner=True)

# Create LLM instance and evaluate
with LLM(temp_dir,
tensor_parallel_size=4,
moe_expert_parallel_size=4,
max_num_tokens=32000,
max_batch_size=32,
**pytorch_config) as llm:

assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

task = LongBenchV2(self.MODEL_NAME)

sampling_params = SamplingParams(max_tokens=32000)

task.evaluate(llm, sampling_params=sampling_params)

finally:
# Cleanup temporary files
if temp_dir and os.path.exists(temp_dir):
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
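Both tests repeat the same try/finally cleanup around the temporary model directory. A small context manager could factor that out; the sketch below assumes only the existing static method and is not part of the change itself:

```python
# Sketch: wrap the temp-dir lifecycle so each test reads as one `with` block.
import shutil
from contextlib import contextmanager


@contextmanager
def modified_model_dir(original_model_dir: str, **kwargs):
    # LongBenchV2 is already imported by the test module.
    temp_dir = LongBenchV2.create_modified_model_dir(original_model_dir,
                                                     **kwargs)
    try:
        yield temp_dir
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


# Hypothetical usage inside a test:
# with modified_model_dir(original_model_dir) as temp_dir, \
#         LLM(temp_dir, tensor_parallel_size=8, **pytorch_config) as llm:
#     task.evaluate(llm, sampling_params=sampling_params)
```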
2 changes: 2 additions & 0 deletions tests/integration/test_lists/qa/llm_function_stress.txt
@@ -1,3 +1,5 @@
stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-GUARANTEED_NO_EVICT-pytorch-stress-test-with-accuracy]
stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
stress_test/stress_test.py::test_run_stress_test[DeepSeek-R1_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_fp8_8gpus
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_nvfp4_4gpus