# Distributed Inference for 70B Llama Model

This notebook serves as a tutorial for prompting the 70B Llama 3.1 model on the FASRC cluster. It assumes that you have already set up, or have access to, a server for this model created using vLLM (see instructions [here](https://github.com/KempnerInstitute/distributed-inference-vllm)).

## Getting Set Up

First, let's load a few libraries needed to send requests to the server. In particular, we use the `requests` library for making HTTP requests to the vLLM server on our node.

The `MODEL_PATH` variable corresponds to the directory containing the model weights. This path is what vLLM uses to identify the model.

In [None]:
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
import time
import requests

MODEL_PATH = "/n/netscratch/kempner_dev/Everyone/models/Llama-3.1-70B"

Next, we need to set our sampling parameters. 

The vLLM server allows sampling parameters similar to the OpenAI API, including adjusting temperature and whether or not log probabilities are returned with each token. We'll cover these parameters in greater detail throughout this tutorial. You can also find more details on the available sampling parameters in the [vLLM repo docs](https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html#sampling-parameters).

Below, we build an `InferenceRequestParams` class which uses a subset of the available fields. These will be serialized through the `.dict()` method before being sent to the server.

In [None]:
@dataclass
class InferenceRequestParams:
    model: str # Should be set to MODEL_PATH
    prompt: str
    max_tokens: int
    min_tokens: int = 0
    temperature: float = 1.0
    frequency_penalty: float = 0.0
    top_k: int = -1
    logprobs: int | None = None
    prompt_logprobs: int | None = None

    def dict(self) -> dict:
        return asdict(self)
    
request_params = InferenceRequestParams(MODEL_PATH, "San Francisco is a ", 10, temperature=1.0, logprobs=0, prompt_logprobs=0)
request_params.dict()
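
To make the serialization step concrete, here is a minimal sketch of the JSON body that ends up in the HTTP request. `ExampleParams` is a hypothetical, trimmed-down stand-in for `InferenceRequestParams`, and `/path/to/model` is a placeholder rather than a real model directory.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ExampleParams:
    # Hypothetical, trimmed-down stand-in for InferenceRequestParams.
    model: str
    prompt: str
    max_tokens: int
    temperature: float = 1.0
    logprobs: Optional[int] = None


params = ExampleParams("/path/to/model", "San Francisco is a ", 10, logprobs=0)

# asdict() converts the dataclass into a plain dict keyed by field name.
payload = asdict(params)

# json.dumps() produces the text that would travel in the request body.
# A field left as None would be serialized as JSON null.
body = json.dumps(payload)
print(body)
```

Passing the dictionary to `requests.post` through its `json` argument performs an equivalent encoding for you, so the explicit `json.dumps` call here is only for inspection.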

## Exercise: Inference for a single prompt

Now we're ready to prompt our LLM!

**The details of how we do this:** we send an HTTP POST request to the server's `/v1/completions` endpoint using the `requests.post` function. The `json` argument carries a JSON payload with our sampling parameters. The response comes back as JSON containing the text completion as well as some metadata about the request.

1) **Find the IP of the node hosting the server.** You can find this as the `head node ip` detailed in the output logs of the SLURM job. Set `IP_ADDRESS` in the code below to a string containing the IP address you find in the logs. If you've already opened VSCode remote development on the head compute node, you can use `localhost` instead.

2) **Run the code and take a look at the outputs.** Do you understand what each output represents? 

3) **Try changing the `max_tokens` and `min_tokens` values in the code below** and rerun the code to see what outputs you get.

4) **Try changing the `temperature` value in the code below**. Run the code a couple times for each temperature parameter you try. What do you notice about the output text when temperature is 0 vs higher? 

5) **Try changing the `frequency_penalty` value in the code below**. It can take any value in the range [-2, 2]. Can you figure out what this parameter does?

6) **Try changing the `top_k` value in the code below**. Use temperature=1. How do the outputs change?




In [None]:
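IP_ADDRESS = '10.31.147.3'  # Head node IP from your job logs, or 'localhost'

def send_request(request_params: InferenceRequestParams):
    # POST the serialized sampling parameters to the completions endpoint.
    response = requests.post(f'http://{IP_ADDRESS}:8000/v1/completions', json=request_params.dict())
    return response.json()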

send_request(InferenceRequestParams(MODEL_PATH, "San Francisco is a ", max_tokens=10, min_tokens=0, temperature=0.0, frequency_penalty=0.0, top_k=-1))
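
The response dictionary can be easier to read once the main fields are pulled out. The sketch below assumes the server follows the OpenAI-style completions layout (`choices`, `usage`). The `example_response` values are invented for illustration and are not real model output, and `summarize_completion` is a hypothetical helper.

```python
# Hypothetical response shaped like an OpenAI-style completions reply.
# All values are invented for illustration.
example_response = {
    "id": "cmpl-example",
    "object": "text_completion",
    "model": "/path/to/model",
    "choices": [
        {
            "index": 0,
            "text": "example completion",
            "logprobs": None,
            "finish_reason": "length",
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
}


def summarize_completion(response: dict) -> dict:
    """Pull the generated text and basic token accounting out of a response."""
    choice = response["choices"][0]
    usage = response.get("usage", {})
    return {
        "text": choice["text"],
        "finish_reason": choice.get("finish_reason"),
        "completion_tokens": usage.get("completion_tokens"),
    }


summary = summarize_completion(example_response)
print(summary)
```

Under this layout, a `finish_reason` of `"length"` indicates that generation stopped because it reached `max_tokens`, which is worth checking when you vary `max_tokens` and `min_tokens`.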

## Exercise: Investigating log probabilities

In the `InferenceRequestParams` class, you can change the `logprobs` and `prompt_logprobs` fields to extract the log probabilities of the tokens in the completion and the prompt, respectively.

When set to `None` (the default), the log probabilities are not returned by the server. When set to a non-negative integer `k`, the server returns the `k` highest log probabilities at each token generation step, along with the log probability of the generated token if it is not among them. Note that if `k=0`, the server returns only the log probabilities of the generated tokens.

In the code below, we store the server's output in the variable `response`, which is a dictionary. We then display just the relevant element of the dictionary using the key `choices`.

Note that `response['choices']` is a list of length one. To access the logprobs values, you need to use `response['choices'][0]['logprobs']`. This formatting might seem odd: why is it a list with a single element? The reason is that certain sampling parameter choices can return multiple output sequences, making `response['choices']` a list with multiple elements. However, we won’t cover those parameters in this tutorial.

**Try re-running the following code with different values of `logprobs` and `prompt_logprobs` and inspect the outputs.** 


In [None]:
response = send_request(InferenceRequestParams(MODEL_PATH, "San Francisco is a ", max_tokens=10, temperature=1.0, logprobs=None, prompt_logprobs=None))

response['choices']
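
Once log probabilities are returned, a common next step is to turn them into sequence-level quantities. The sketch below assumes the `logprobs` entry uses the legacy OpenAI completions layout with `tokens` and `token_logprobs` lists; check the keys in your own output before relying on it. The numbers are invented, and `sequence_stats` is a hypothetical helper.

```python
import math

# Hypothetical `logprobs` entry; the tokens and values are invented.
example_logprobs = {
    "tokens": [" city", " in", " California"],
    "token_logprobs": [-0.5, -1.5, -1.0],
}


def sequence_stats(token_logprobs):
    """Summarize per-token log probabilities for one sequence."""
    # Skip None entries: a token with no preceding context (such as the
    # first prompt token) may be reported without a log probability.
    values = [lp for lp in token_logprobs if lp is not None]
    total = sum(values)
    return {
        "total_logprob": total,                        # log P(sequence)
        "sequence_prob": math.exp(total),              # P(sequence)
        "perplexity": math.exp(-total / len(values)),  # exp of mean negative logprob
    }


stats = sequence_stats(example_logprobs["token_logprobs"])
print(stats)
```

Summing log probabilities rather than multiplying raw probabilities avoids numerical underflow on long sequences.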

## Exercise: Inference for multiple prompts

What if we would like to run multiple prompts at once (in parallel), instead of one-at-a-time? 

To process multiple prompts, we can use Python's multithreading to send multiple requests to the server. `ThreadPoolExecutor` handles this for us: it places the prompts on an internal queue and creates `NUM_THREADS` worker threads to process it. Each worker independently pulls a prompt from the queue and sends the corresponding request to the server.

You may be familiar with the idea of batching prompts together, which is what we do when training LLMs. Compared to batching the prompts and processing each batch one by one, this queue method achieves better performance. This is because when a prompt finishes quickly, the corresponding worker will automatically pull the next prompt from the queue. In comparison, in batch processing, you need to wait for the slowest prompt within a batch to finish before the next batch gets processed, even if most of the other prompts within that batch finished quickly.
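
The difference between the two strategies can be illustrated with a toy simulation that needs no server. The per-prompt durations below are invented, and the model deliberately ignores how the server shares the GPU among concurrent requests, so treat it as a sketch of the scheduling argument only. `queue_makespan` and `batch_makespan` are hypothetical helper names.

```python
import heapq


def queue_makespan(durations, num_workers):
    """Total time when each free worker immediately pulls the next prompt."""
    free_at = [0.0] * num_workers  # time at which each worker becomes free
    heapq.heapify(free_at)
    for duration in durations:
        start = heapq.heappop(free_at)  # earliest available worker
        heapq.heappush(free_at, start + duration)
    return max(free_at)


def batch_makespan(durations, batch_size):
    """Total time when each batch must fully finish before the next starts."""
    total = 0.0
    for i in range(0, len(durations), batch_size):
        total += max(durations[i:i + batch_size])  # slowest prompt gates the batch
    return total


# Invented durations: mostly quick prompts with two slow ones.
durations = [1, 1, 1, 8, 1, 1, 1, 8]
print(queue_makespan(durations, 4))  # 10.0
print(batch_makespan(durations, 4))  # 16.0
```

In this example the queue finishes sooner because the three workers that complete their short prompts pick up new work while the fourth is still busy with a slow one.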

1) Given the multithreading description, how would you expect the execution time with 4 prompts and 4 threads to compare with 3 prompts and 4 threads? How about compared to 5 prompts and 4 threads?
2) Try running the code below. Add or remove prompts and change the value of `NUM_THREADS`. Look at the execution time. Does it match your expectations?
3) Try changing `max_tokens` and look at the effect on execution time.


In [None]:
NUM_THREADS = 4
prompts = ["San Francisco is a ", "Boston is a ", "Chicago is a ", "NYC is a "]
params = [InferenceRequestParams(MODEL_PATH, prompt, max_tokens=200, temperature=0.0) for prompt in prompts]

start_time = time.time()
with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
    responses = pool.map(send_request, params)
print(f"Total time: {time.time() - start_time}")
[response['choices'][0]['text'] for response in responses]
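
One detail worth knowing when matching completions back to prompts: `ThreadPoolExecutor.map` yields results in the order of its inputs, even when later tasks finish first. The sketch below demonstrates this with `fake_send_request`, a hypothetical stand-in for `send_request` that sleeps instead of contacting a server.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_send_request(prompt: str) -> str:
    # Hypothetical stand-in for send_request: shorter prompts "finish" sooner.
    time.sleep(0.01 * len(prompt))
    return prompt.upper()


prompts = ["a long prompt here", "short", "mid-length"]
with ThreadPoolExecutor(max_workers=3) as pool:
    # map() returns a lazy iterator that can be consumed only once,
    # so wrap it in list() if you need to reuse the results.
    results = list(pool.map(fake_send_request, prompts))

# Results line up with the inputs, so zip() pairs each prompt with its output.
paired = dict(zip(prompts, results))
print(paired)
```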

## Exercise: Using the output logs to analyze performance

The output logs for the SLURM job running your server record each request as it is received, along with statistics on the number of tokens processed per second and KV cache usage. This can be helpful for debugging and performance analysis.

Run the following code and look at the output log file. Watch as the requests are received. You should see that the KV cache memory usage grows over time as more tokens are generated with each request. The usage should also drop as each request finishes.
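
For a rough sense of scale, the KV cache footprint of a single request can be estimated from the model architecture. The sketch below assumes the published Llama 3.1 70B configuration (80 layers, 8 key/value heads, head dimension 128) and a 16-bit cache; these constants are assumptions, not values read from the server. vLLM allocates cache in blocks and splits it across GPUs, so the logged percentages will not match this estimate exactly.

```python
# Assumed architecture constants for Llama 3.1 70B (not read from the server).
NUM_LAYERS = 80
NUM_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_VALUE = 2  # assumes a 16-bit KV cache


def kv_cache_bytes(num_tokens: int) -> int:
    """Estimate KV cache bytes held by one sequence of num_tokens tokens."""
    # Factor of 2: one key vector and one value vector per layer and KV head.
    return 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE * num_tokens


per_token_kib = kv_cache_bytes(1) / 1024
full_request_gib = kv_cache_bytes(5000) / 1024**3

print(f"{per_token_kib:.0f} KiB per token")            # 320 KiB per token
print(f"{full_request_gib:.2f} GiB for 5000 tokens")   # 1.53 GiB for 5000 tokens
```

Because the estimate is linear in the number of tokens, cache usage should rise steadily while a request generates and drop when it completes.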

In [None]:
NUM_THREADS = 4
params = InferenceRequestParams(MODEL_PATH, "Give me a list of 100 history questions related to American history. 1. How did ", 5000, temperature=0.0)

start_time = time.time()
with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
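    for _ in range(NUM_THREADS):
        # Submit one request, then wait 25 seconds before the next so that
        # each request arrives separately in the server logs.
        pool.submit(send_request, params)
        time.sleep(25)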
f"Total time: {time.time() - start_time}"