In [1]:
import os
import psutil
import copy
import time
import signal
import requests
import warnings
import subprocess
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Generator, Union, Optional
from IPython.display import FileLink, display
from concurrent.futures import ThreadPoolExecutor

import transformers
from transformers.utils import logging
from openai import Stream
from nemo_skills.prompt.utils import get_prompt, CodeTags
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.evaluation.math_grader import extract_answer
from nemo_skills.inference.model.code_execution import CodeExecutionWrapper
from nemo_skills.inference.model import get_code_execution_model, get_model

# Turn off the Hugging Face tokenizers' own parallelism via its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Suppress every Python warning and restrict the transformers logger to errors,
# so the benchmark cells below only show their own prints.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()

!nvidia-smi -L | cut -d '(' -f 1

GPU 0: NVIDIA H100 80GB HBM3 
GPU 1: NVIDIA H100 80GB HBM3 


In [2]:
# Directory that all model/engine directories below are resolved against.
BASE_DIR = "./"
# os.path.join is used instead of f"{BASE_DIR}/..." because BASE_DIR already ends
# with a separator; plain interpolation produced paths such as ".//OpenMath-...".
MODEL_DIR_HF = os.path.join(BASE_DIR, "OpenMath-Nemotron-14B-kaggle")
MODEL_DIR_BF16 = os.path.join(BASE_DIR, "OpenMath-Nemotron-14B-kaggle-bf16-trtllm")
MODEL_DIR_FP8 = os.path.join(BASE_DIR, "OpenMath-Nemotron-14B-kaggle-fp8-trtllm")
MODEL_DIR_FP8_DRAFT = os.path.join(BASE_DIR, "OpenMath-Nemotron-14B-kaggle-fp8-redrafter-trtllm")
# Benchmark results keyed by variant name; each value holds 'results' and 'total_time'.
benchmark = {}

In [3]:
def wait_for_server(host, port, timeout=300, interval=1):
    """Block until the HTTP server at ``host:port`` answers, or give up.

    The server is probed with an HTTP PUT to ``http://{host}:{port}``. Any
    response whose status code is not 403 is treated as "ready".

    Args:
        host: Host name or IP address of the server.
        port: TCP port of the server.
        timeout: Maximum number of seconds to keep probing.
        interval: Seconds to sleep between two probes.

    Returns:
        True as soon as a non-403 response is received.

    Raises:
        TimeoutError: If no non-403 response arrives within ``timeout`` seconds.
    """
    url = f"http://{host}:{port}"
    # Bound each individual probe so a connection that is accepted but never
    # answered cannot block past the overall deadline indefinitely.
    probe_timeout = 5
    deadline = time.monotonic() + timeout
    while True:
        try:
            response = requests.put(url, timeout=probe_timeout)
            if response.status_code != 403:
                return True
        except requests.RequestException:
            # Not reachable yet; fall through to the deadline check below.
            pass
        # The deadline check and the sleep apply to 403 answers as well as to
        # connection errors, so a server that keeps replying 403 can neither
        # spin this loop at full speed nor keep it running forever.
        if time.monotonic() >= deadline:
            raise TimeoutError("Server did not respond within timeout period")
        time.sleep(interval)
                
def start_server(model_dir, port=5000):
    """Launch ``tensorrt_llm.commands.serve`` for ``model_dir`` and wait until it answers.

    The server is bound to 127.0.0.1 and started in its own session/process
    group, so the whole group can later be signalled through ``proc.pid``.
    Its stdout and stderr are written to ``<model dir name>_server_logs.log``
    in the current working directory.

    Args:
        model_dir: Directory passed to the ``serve`` command as the model to load.
        port: TCP port the server listens on.

    Returns:
        The ``subprocess.Popen`` handle of the launched server process.

    Raises:
        TimeoutError: Propagated from ``wait_for_server`` when the server never
            becomes ready; the launched process group is terminated first.
    """
    host = "127.0.0.1"
    # An argument list (no shell) avoids quoting problems with unusual paths.
    cmd = [
        "python", "-m", "tensorrt_llm.commands.serve", "serve", str(model_dir),
        "--tokenizer", str(MODEL_DIR_HF),
        "--backend", "trt",
        "--tp_size", "2",
        "--kv_cache_free_gpu_memory_fraction", "0.92",
        "--max_batch_size", "12",
        "--host", host,
        "--port", str(port),
    ]
    print(f"Starting server from {model_dir} at {host}:{port}")
    # Path.name also copes with a trailing "/" (split("/")[-1] would yield "").
    model_name = Path(model_dir).name
    log_path = Path(f"{model_name}_server_logs.log").resolve()
    # The child inherits its own copy of the descriptor, so the parent's handle
    # can be closed right after the spawn instead of leaking for the session.
    with open(log_path, "w", buffering=1) as log_file:
        proc = subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # same effect as preexec_fn=os.setsid
        )
    print("Waiting for server to be ready (might take a while) ...")
    try:
        wait_for_server(host, port)
    except BaseException:
        # Includes KeyboardInterrupt: do not leave a half-started server behind.
        try:
            os.killpg(proc.pid, signal.SIGTERM)
        except ProcessLookupError:
            pass
        raise
    print("Server ready!")
    return proc

def kill_server(proc, port=5000):
    """Stop the server started by ``start_server`` and free its port.

    First sends SIGTERM to the process group led by ``proc``, waits 10 seconds,
    then SIGKILLs every process that still has an inet socket whose local port
    equals ``port``, and waits another 10 seconds.

    Args:
        proc: Process handle whose ``pid`` identifies the server's process group.
        port: Port the server was listening on.
    """
    try:
        os.killpg(proc.pid, signal.SIGTERM)
    except ProcessLookupError:
        # The group already exited; the port sweep below still runs.
        pass
    time.sleep(10)

    for candidate in psutil.process_iter(['pid', 'name']):
        # psutil >= 6 renamed Process.connections() to net_connections().
        list_connections = getattr(candidate, "net_connections", None) or candidate.connections
        try:
            connections = list_connections(kind='inet')
        except psutil.Error:
            # Process vanished, is a zombie, or we may not inspect it: skip it
            # rather than aborting the whole sweep.
            continue
        if any(conn.laddr and conn.laddr.port == port for conn in connections):
            print(f"Killing process {candidate.info['name']} (PID: {candidate.info['pid']}) running on port {port}")
            try:
                os.kill(candidate.info['pid'], signal.SIGKILL)
            except ProcessLookupError:
                pass  # exited between the scan and the kill
            except PermissionError:
                print(f"Not permitted to kill PID {candidate.info['pid']}; skipping.")

    time.sleep(10)
    print('Server closed.')

In [4]:
# Start with the engine in MODEL_DIR_FP8_DRAFT. We keep a reference to the server
# process so kill_server() can stop it before the next engine is launched.
server_process = start_server(MODEL_DIR_FP8_DRAFT)

Starting server from .//OpenMath-Nemotron-14B-kaggle-fp8-redrafter-trtllm at 127.0.0.1:5000
Waiting for server to be ready (might take a while) ...
Server ready!


In [5]:
# Tokenizer that consume_stream() uses to count the tokens of a generated text.
# NOTE(review): this loads from MODEL_DIR_FP8, whereas start_server() passes
# MODEL_DIR_HF as --tokenizer -- confirm the engine directory ships tokenizer files.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_DIR_FP8)

def consume_stream(stream: "Union[Stream, Generator]", thread_id=None):
    """Drain one generation stream and return its text with timing statistics.

    Every chunk is expected to be a mapping with a ``'generation'`` entry;
    entries that are ``None`` are skipped. If iterating the stream raises
    (for example because the generation was cancelled), consumption stops and
    the text collected so far is still returned; the exception is reported
    under ``'error'`` instead of being dropped silently.

    Args:
        stream: Iterable of chunks to consume.
        thread_id: Opaque identifier copied into the returned dict.

    Returns:
        dict with the keys
        ``'result'`` (concatenated text),
        ``'total_time'`` (seconds spent consuming),
        ``'thread_id'``,
        ``'time_to_first_token'`` (seconds until the first non-None chunk, or None),
        ``'num_tokens'`` (length of ``tokenizer.encode(result)``),
        ``'throughput'`` (``num_tokens / total_time``, 0.0 for a zero duration),
        ``'error'`` (``repr`` of the exception that ended the stream, or None).
    """
    start_time = time.time()
    pieces = []
    time_to_first_token = None
    error = None
    try:
        for chunk in stream:
            text = chunk['generation']
            if text is not None:
                pieces.append(text)
                if time_to_first_token is None:
                    time_to_first_token = time.time() - start_time
    except Exception as exc:
        # Best effort by design: keep the partial text, but remember why we stopped.
        error = repr(exc)

    total_time = time.time() - start_time
    result = "".join(pieces)

    num_generated_tokens = len(tokenizer.encode(result))
    return {
        'result': result,
        'total_time': total_time,
        'thread_id': thread_id,
        'time_to_first_token': time_to_first_token,
        'num_tokens': num_generated_tokens,
        # Guard against a zero duration (coarse clocks, empty streams).
        'throughput': num_generated_tokens / total_time if total_time > 0 else 0.0,
        'error': error,
    }

def _max_identical_answers(answers):
    """Return how often the most frequent non-``None`` answer occurs (0 if there is none)."""
    from collections import Counter  # local import keeps this cell self-contained

    counts = Counter(answer for answer in answers if answer is not None)
    return max(counts.values(), default=0)


def stream_generate(
    code_exec_model: CodeExecutionWrapper,
    prompts: list[str | dict],
    code_begin: str | list[str],
    code_end: str | list[str],
    code_output_begin: str | list[str],
    code_output_end: str | list[str],
    code_output_format: str | list[str],
    tokens_to_generate: int | list[int] = 512,
    temperature: float | list[float] = 0.0,
    top_p: float | list[float] = 0.95,
    top_k: int | list[int] = 0,
    min_p: float | list[float] = 0.0,
    repetition_penalty: float | list[float] = 1.0,
    random_seed: int | list[int] = 0,
    stop_phrases: list[str] | list[list[str]] | None = None,
    remove_stop_phrases: bool = True,
    timeout: int | list[int] | None = None,
    max_code_executions: int | list[int] | None = None,
    stop_after_n_completed: Optional[int] = None,
    stop_after_n_seconds: Optional[int] = None,
    stop_after_n_same_answer: Optional[int] = None,
    return_stats: bool = False,
    ) -> list[dict]:
    """Run streaming generation for all prompts concurrently, with optional early stop.

    All generation arguments (``prompts`` through ``max_code_executions``) are
    forwarded unchanged to ``code_exec_model.generate(..., stream=True)``. Each
    returned stream is drained by ``consume_stream`` in its own worker thread
    while this function polls every 0.1 s and, once a stop condition holds,
    calls ``code_exec_model.model.cancel_all_generations()``. The function still
    waits for every worker to finish so that partial results are collected.

    Args:
        code_exec_model: Model wrapper providing ``generate`` and
            ``model.cancel_all_generations``.
        stop_after_n_completed: Cancel once this many streams have finished
            (capped at the number of streams).
        stop_after_n_seconds: Cancel once this many seconds have elapsed.
        stop_after_n_same_answer: Cancel once some answer, as returned by
            ``extract_answer``, occurs at least this many times among the
            finished streams. Streams without an extractable answer (``None``)
            are not counted.
        return_stats: If True return the full ``consume_stream`` dicts,
            otherwise only their ``'result'`` strings.

    Returns:
        One entry per stream, ordered by the stream's position in ``prompts``:
        dicts when ``return_stats`` is True, plain strings otherwise.
    """

    streams = list(code_exec_model.generate(
        prompts=prompts,
        code_begin=code_begin,
        code_end=code_end,
        code_output_begin=code_output_begin,
        code_output_end=code_output_end,
        code_output_format=code_output_format,
        tokens_to_generate=tokens_to_generate,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        repetition_penalty=repetition_penalty,
        random_seed=random_seed,
        stop_phrases=stop_phrases,
        remove_stop_phrases=remove_stop_phrases,
        timeout=timeout,
        max_code_executions=max_code_executions,
        stream=True,
        ))

    if stop_after_n_completed is not None:
        stop_after_n_completed = min(stop_after_n_completed, len(streams))

    current_answers = []    # extracted answers of the streams completed so far
    completed_futures = []  # list of tuples (thread_id, result_data)
    stop_announced = False  # the stop message is printed only once

    # One worker per stream: with the default pool size, streams beyond the
    # worker count would only start being consumed after earlier ones finish.
    with ThreadPoolExecutor(max_workers=max(1, len(streams))) as executor:
        # Submit all streams to thread pool with thread IDs
        futures = [
            (thread_id, executor.submit(consume_stream, stream, thread_id))
            for thread_id, stream in enumerate(streams)
        ]
        start_time = time.time()

        while futures:
            time_elapsed = time.time() - start_time

            # (headline, reason) of the first stop condition that holds, if any.
            stop_info = None
            if stop_after_n_completed is not None and len(completed_futures) >= stop_after_n_completed:
                stop_info = (
                    f"\nStopping after {stop_after_n_completed} completed...",
                    f"{stop_after_n_completed} generations completed",
                )
            elif stop_after_n_seconds is not None and time_elapsed >= stop_after_n_seconds:
                stop_info = (
                    f"\nStopping after {stop_after_n_seconds} seconds...",
                    "maximum time reached",
                )
            elif stop_after_n_same_answer is not None:
                # Require one single answer to occur at least n times.
                if _max_identical_answers(current_answers) >= stop_after_n_same_answer:
                    stop_info = (
                        f"\nStopping after {stop_after_n_same_answer} identical answers. {current_answers}",
                        f"{stop_after_n_same_answer} identical answers",
                    )

            if stop_info is not None:
                # This asynchronously cancels all generations. We don't break
                # here because we want to collect the results produced so far.
                # The cancel is re-issued on every poll: it is not visible here
                # whether a single call also covers generations that start
                # later -- TODO confirm.
                code_exec_model.model.cancel_all_generations()
                if not stop_announced:
                    headline, reason = stop_info
                    print(headline)
                    print(f"Canceling thread @ {time_elapsed:.2f} ({reason})")
                    stop_announced = True

            time.sleep(0.1)

            still_running = []
            for idx, future in futures:
                if future.done():
                    result_data = future.result()
                    completed_futures.append((idx, result_data))
                    current_answers.append(extract_answer(result_data['result']))
                else:
                    still_running.append((idx, future))
            futures = still_running

    # Sort by original index and return results with durations
    completed_futures.sort(key=lambda x: x[0])
    if return_stats:
        return [result_data for _, result_data in completed_futures]
    else:
        return [result_data['result'] for _, result_data in completed_futures]

In [6]:
# Sandbox handle that is handed to the code-execution model below.
sandbox = get_sandbox()  # localhost by default
# Code-execution model client configured for the "trtllm-serve" server type.
llm = get_code_execution_model(server_type="trtllm-serve", sandbox=sandbox)

# Initialize the prompt template
prompt_template = get_prompt('generic/math', 'qwen-instruct')

# Set the code tags directly on the config's code_tags attribute.
# These override whatever tags the template config carried before.
prompt_template.config.code_tags = CodeTags(
    code_begin="<tool_call>\n",
    code_end="</tool_call>\n",
)

In [7]:
# Generation settings shared by every benchmark run below.
# NOTE(review): temperature is 0. -- presumably greedy decoding, in which case
# top_k / top_p would have no effect; verify against the server's semantics.
sampling_params = {
    "tokens_to_generate": 12000,
    "temperature": 0.,
    "top_k": 20,
    "top_p": 0.8,
    "repetition_penalty": 1.0,
    "max_code_executions": 2
}

# Single math problem used as the benchmark workload (raw string: backslashes are literal).
problem = r'''
The Fibonacci numbers are defined as follows: $F_0 = 0$, $F_1 = 1$, and $F_{n+1} = F_n + F_{n-1}$ for $n \geq 1$. 
There are $N$ positive integers $n$ strictly less than $10^{101}$ such that $n^2 + (n+1)^2$ is a multiple of 5 but $F_{n-1}^2 + F_n^2$ is not. 
How many prime factors does $N$ have, counted with multiplicity?
'''

# Copy the sampling parameters so adding "prompts" does not modify sampling_params.
request = copy.deepcopy(sampling_params)
# The same filled prompt repeated 12 times (start_server() uses --max_batch_size 12).
list_of_texts =  [prompt_template.fill({'problem': problem})] * 12
request["prompts"] = list_of_texts

In [8]:
# Display the rendered prompt string; all entries of list_of_texts are identical.
list_of_texts[0]

'<|im_start|>system\n<|im_end|>\n<|im_start|>user\nSolve the following math problem. Make sure to put the answer (and only answer) inside \\boxed{}.\n\n\nThe Fibonacci numbers are defined as follows: $F_0 = 0$, $F_1 = 1$, and $F_{n+1} = F_n + F_{n-1}$ for $n \\geq 1$. \nThere are $N$ positive integers $n$ strictly less than $10^{101}$ such that $n^2 + (n+1)^2$ is a multiple of 5 but $F_{n-1}^2 + F_n^2$ is not. \nHow many prime factors does $N$ have, counted with multiplicity?\n<|im_end|>\n<|im_start|>assistant\n'

In [9]:
def plot_stats_table(model_results):
    """Print a table with one column of throughput statistics per model.

    ``model_results`` maps a model name to a dict holding ``'total_time'``
    (seconds) and ``'results'`` (a list of per-request dicts with
    ``'throughput'`` and ``'num_tokens'``). Rows: total generation time,
    batch throughput (all tokens / total time) and the mean ± standard
    deviation of the per-request throughputs.
    """
    table = {
        'Metric': ['Total Generation Time', 'Batch Throughput (Tok/sec)',  "Avg Request Throughput (Tok/sec)"],
    }

    for model_name, entry in model_results.items():
        samples = entry['results']
        elapsed = entry['total_time']

        request_rates = [sample['throughput'] for sample in samples]
        rate_mean = np.mean(request_rates)
        rate_std = np.std(request_rates)

        token_total = sum(sample['num_tokens'] for sample in samples)
        batch_rate = token_total / elapsed

        table[model_name] = [
            f"{elapsed:.1f}",
            f"{batch_rate:.1f}",
            f"{rate_mean:.1f} ± {rate_std:.2f}",
        ]

    print(pd.DataFrame(table))

In [10]:
# Time one batched run against the server that is currently up (started from
# MODEL_DIR_FP8_DRAFT above). The run is cancelled once 8 requests have
# finished or 200 seconds have passed, whichever comes first.
start_time = time.time()
results_fp8_draft = stream_generate(
    llm,
    **request,
    **prompt_template.get_code_execution_args(),
    stop_after_n_completed=8,
    stop_after_n_seconds=200,
    return_stats=True,
    )
# Wall-clock time of the whole call, including waiting for cancelled streams to end.
total_time = time.time() - start_time
benchmark['fp8_draft'] = {'results': results_fp8_draft, 'total_time': total_time}

In [11]:
# Stop the current server and launch the engine in MODEL_DIR_FP8 on the same port.
kill_server(server_process)
server_process= start_server(MODEL_DIR_FP8)

Server closed.
Starting server from .//OpenMath-Nemotron-14B-kaggle-fp8-trtllm at 127.0.0.1:5000
Waiting for server to be ready (might take a while) ...
Server ready!


In [12]:
# Same timed run as before, now against the MODEL_DIR_FP8 server: cancelled once
# 8 requests have finished or 200 seconds have passed.
start_time = time.time()
results_fp8 = stream_generate(
    llm,
    **request,
    **prompt_template.get_code_execution_args(),
    stop_after_n_completed=8,
    stop_after_n_seconds=200,
    return_stats=True,
    )
# Wall-clock time of the whole call, including waiting for cancelled streams to end.
total_time = time.time() - start_time
benchmark['fp8'] = {'results': results_fp8, 'total_time': total_time}

In [13]:
# Stop the current server and launch the engine in MODEL_DIR_BF16 on the same port.
kill_server(server_process)
server_process= start_server(MODEL_DIR_BF16)

Server closed.
Starting server from .//OpenMath-Nemotron-14B-kaggle-bf16-trtllm at 127.0.0.1:5000
Waiting for server to be ready (might take a while) ...
Server ready!


In [14]:
# Same timed run as before, now against the MODEL_DIR_BF16 server: cancelled once
# 8 requests have finished or 200 seconds have passed.
start_time = time.time()
results_bf16 = stream_generate(
    llm,
    **request,
    **prompt_template.get_code_execution_args(),
    stop_after_n_completed=8,
    stop_after_n_seconds=200,
    return_stats=True,
    )
# Wall-clock time of the whole call, including waiting for cancelled streams to end.
total_time = time.time() - start_time
benchmark['bf16'] = {'results': results_bf16, 'total_time': total_time}

In [15]:
# Print the summary table with one column per entry collected in `benchmark`.
plot_stats_table(benchmark)

                             Metric     fp8_draft          fp8         bf16
0             Total Generation Time          33.8         72.9        170.4
1        Batch Throughput (Tok/sec)        2036.0       1029.3        518.6
2  Avg Request Throughput (Tok/sec)  175.7 ± 0.94  89.9 ± 0.99  44.1 ± 0.26
