Code for chat inference server #1329

Closed
hobodrifterdavid opened this issue Jul 4, 2023 · 19 comments

@hobodrifterdavid

hobodrifterdavid commented Jul 4, 2023

I made a wrapper around CTranslate2, an API server. It batches requests, supports running these batches on multiple GPUs (sort of round-robin), requests can have different priority levels, requests are immediately rejected when there are more than 100 requests ahead of them in the queue, and it has an endpoint that returns the approximate load. It doesn't do streaming, it just does continuation for up to 80 tokens, which fits with how I am using CTranslate2. It avoids a websocket (which can create networking issues); from the front end you can do requests in a loop until you see the string </s> in the continuation string. It's responsive enough with the 13B model at least; the performance overhead could be tested. @guillaumekln If you want to use it as an example, or put it in a separate repo, please feel free. Knowing myself, I probably won't be able to find time to provide support, but it may be useful for someone. I also have a Python wrapper for faster-whisper. If you prefer to just delete the issue, I won't have hurt feelings. :) If there is interest, I can document the code better.

What you do on the front-end etc.:

let continuation = 'The best things in life are';
let counter = 0;
while(!continuation.includes('</s>') && counter < 5) {
    // Request more tokens from the server
    let data = { prompt: continuation, priority: 'MED' }; // 'HIGH' | 'MED' | 'LOW' | 'VLOW'
    try {
        const response = await axios.post(`http://mydomain.com/dt_llama`, data, { timeout: 10_000});
        if(response.data.status === 'success') {
            continuation += response.data.data.text;
            counter += 1;
            // Here write the continuation to your UI, or observable state etc.
        } else {
            console.error(response.data.error);
            break;
        }
    } catch(err) {
        // network error, timeout, or non-2xx response
        console.error(err);
        break;
    }
}
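The post also mentions an endpoint that reports the approximate load. Here is a minimal client-side sketch for polling it (my own example, not from the original post; it assumes the /dt_stats endpoint and response shape defined in app.py below, the requests library, and a placeholder base URL):

# Hypothetical load-check sketch; mirrors the /dt_stats response shape in app.py below.
import requests

def get_current_load(base_url: str = "http://mydomain.com") -> float:
    resp = requests.get(f"{base_url}/dt_stats", timeout=5)
    resp.raise_for_status()
    samples = resp.json()["data"]["stats"]  # list of [load, queue] pairs, newest last
    if not samples:
        return 0.0
    load, queue_size = samples[-1]
    print(f"load={load:.2f}, queued requests={queue_size}")
    return load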

app.py

# Import necessary libraries
import sys
from typing import List, Tuple
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel
import asyncio
from fastapi.responses import JSONResponse
import os
import ctranslate2
import sentencepiece as spm
import time
from datetime import datetime

model_dir = "/models/vicuna-13b-1.1_ct2/"

MAX_BATCH_SIZE = 6

class gpuProcessor:
    def __init__(self, device_number):

        self.busy = False
        self.load = 0
        self.device_number = device_number

        self.generator = ctranslate2.Generator(
            model_dir, device="cuda", device_index=[device_number], compute_type='int8_float16')  # device_index=[0, 1, 2, 3]
        self.sp = spm.SentencePieceProcessor(
            os.path.join(model_dir, "tokenizer.model"))

    def isBusy(self):
        return self.busy

    def getLoad(self):
        return self.load

    async def generate(self, prompts: List[str]):

        self.busy = True
        self.load = len(prompts)

        now = time.time()

        def sync_func() -> List[str]:

            # now = time.time()

            tokens: List[List[str]] = [
                ["<s>"] + self.sp.encode(prompt, out_type=str) for prompt in prompts]

            # print(f">> Tokenise time {time.time() - now}")
            # now = time.time()

            try:
                results = self.generator.generate_batch(
                    tokens,
                    beam_size=1,
                    sampling_temperature=0.8,
                    sampling_topk=10,
                    num_hypotheses=1,
                    max_length=80, # Hard coded to 80!
                    min_length=1,
                    include_prompt_in_result=False,
                    max_batch_size=MAX_BATCH_SIZE,
                    return_end_token=True,
                    # Edit: don't do this, value is too high and results in poor output:
                    # repetition_penalty=1.5
                )
            except Exception as e:
                print(f"An error occurred during generate_batch: {str(e)}")
                sys.exit(1)

            # print(f">> Generate time {time.time() - now}")
            # now = time.time()

            def decode(this_result) -> str:

                # print(this_result)

                if len(this_result.sequences) == 0:
                    # No sequences..
                    return ''

                if len(this_result.sequences[0]) == 0:
                    # Sequence is len 0
                    return ''

                startsWithSpace = this_result.sequences[0][0][0] == '▁'
                hasEOS = this_result.sequences[0][-1] == '</s>'

                string: str = self.sp.decode(this_result.sequences_ids[0])
                if startsWithSpace:
                    string = ' ' + string
                if hasEOS:
                    string += '</s>'

                return string

            output: List[str] = [decode(result) for result in results]

            # print(f"+++||| {len(prompts)} --> {len(output)}")

            # for i in range(len(prompts)):
            #     print(f"+++ {prompts[i][-5:]} --> {output[i][:16]}")

            # print(f">> Decode time {time.time() - now}")

            return output

        # Run sync_func asynchronously, so we don't block the event loop.
        # Allows other requests to be handled meanwhile.
        loop = asyncio.get_event_loop()

        # print(f"START run_in_executor {self.device_number}")
        output2 = await loop.run_in_executor(None, lambda: sync_func())
        # print(f"END run_in_executor {self.device_number}")

        print(f">> Batch of {len(prompts)} {time.time() - now}")

        self.busy = False
        self.load = 0

        return output2

    async def run(self):

        # print(f"Tick device: {self.device_number} busy: {self.busy}")

        global req_queue_HIGH
        global req_queue_MED
        global req_queue_LOW
        global req_queue_VLOW

        if self.busy:
            return

        queue_count = len(req_queue_HIGH) + len(req_queue_MED) + \
            len(req_queue_LOW) + len(req_queue_VLOW)

        if queue_count == 0:
            # print(f"{self.device_number} queue empty, run returning.")
            return

        space_remaining = MAX_BATCH_SIZE

        # Make a shallow copy of the queued items
        requestsProcessingNow = req_queue_HIGH[:space_remaining]
        # Remove these from the queue
        req_queue_HIGH = req_queue_HIGH[space_remaining:]

        space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_MED[:space_remaining]
            # Remove these from the queue
            req_queue_MED = req_queue_MED[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_LOW[:space_remaining]
            # Remove these from the queue
            req_queue_LOW = req_queue_LOW[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_VLOW[:space_remaining]
            # Remove these from the queue
            req_queue_VLOW = req_queue_VLOW[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if (len(requestsProcessingNow) == 0):
            return

        now = time.time()

        # print(f"{self.device_number} Running batch of {len(requestsProcessingNow)}, {len(req_queue_HIGH)} / {len(req_queue_MED)} / {len(req_queue_LOW)} / {len(req_queue_VLOW)} still queued.")

        results = await self.generate([prompt for prompt, _ in requestsProcessingNow])

        # print(f"{self.device_number} Run batch of {len(requestsProcessingNow)} in {time.time() - now}s")

        for i, (_, future) in enumerate(requestsProcessingNow):
            future.set_result(results[i])


processors = [
    gpuProcessor(0),
    gpuProcessor(1),
    gpuProcessor(2),
    gpuProcessor(3),
    gpuProcessor(4),
    gpuProcessor(5),
    gpuProcessor(6),
    gpuProcessor(7),
]

curGPUStartIndex = 0


def getProcessors():
    global curGPUStartIndex
    curGPUStartIndex += 1
    if (curGPUStartIndex >= len(processors)):
        curGPUStartIndex = 0

    return processors[curGPUStartIndex:] + processors[:curGPUStartIndex]


# High Priority Requests
req_queue_HIGH: List[Tuple[str, asyncio.Future]] = [
]

# Medium Priority Requests
req_queue_MED: List[Tuple[str, asyncio.Future]] = [
]

# Low Priority Requests
req_queue_LOW: List[Tuple[str, asyncio.Future]] = [
]

# Very Low Priority Requests
req_queue_VLOW: List[Tuple[str, asyncio.Future]] = [
]

app = FastAPI()


class InputData(BaseModel):
    prompt: str
    # max_len: int # Not used!
    priority: str  # 'HIGH', 'MED', 'LOW', 'VLOW'


@app.post("/dt_llama")
async def dt_llama(data: InputData):

    global req_queue_HIGH
    global req_queue_MED
    global req_queue_LOW
    global req_queue_VLOW

    # print(f"REQ: current queue {len(req_queue_HIGH)} / {len(req_queue_MED)} / {len(req_queue_LOW)} / {len(req_queue_VLOW)}")

    # Check for space..
    if data.priority == 'HIGH':
        if (len(req_queue_HIGH) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'MED':
        if ((len(req_queue_HIGH) + len(req_queue_MED)) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'LOW':
        if ((len(req_queue_HIGH) + len(req_queue_MED) + len(req_queue_LOW)) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'VLOW':
        if ((len(req_queue_HIGH) + len(req_queue_MED) + len(req_queue_LOW) + len(req_queue_VLOW)) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    else:
        # not HIGH, MED, LOW, or VLOW..
        return {
            'status': 'failure',
            'error': 'BAD_REQUEST'
        }

    try:
        # print(f"Generation request, len(data.prompt) {len(data.prompt)}.")

        future = asyncio.get_running_loop().create_future()

        if data.priority == 'HIGH':
            req_queue_HIGH.append((data.prompt, future))
        elif data.priority == 'MED':
            req_queue_MED.append((data.prompt, future))
        elif data.priority == 'LOW':
            req_queue_LOW.append((data.prompt, future))
        elif data.priority == 'VLOW':
            req_queue_VLOW.append((data.prompt, future))

        output = await future

        # print('Output: ' + output)

        # if data.priority == 'HIGH':
        #     print('MEOW: ' + output)
        # if data.priority == 'MED':
        #     print('MEOW: ' + output)

        return {
            'status': 'success',
            'data': {'text': output}
        }
    except Exception as e:
        print(f"ERROR: {str(e)}")
        # Return a 500 so the client treats this as a server-side failure
        return Response(content='{"status": "failure", "error": "SERVER_ERROR"}', media_type="application/json", status_code=500)


class InputDataBatch(BaseModel):
    prompts: List[str]
    # max_len: int


async def generate_loop():

    start_time = 0

    while True:

        await asyncio.sleep(0.5)

        end_time = time.process_time()
        # print(f"Process time: {end_time - start_time}")
        start_time = time.process_time()

        processors = getProcessors()

        for processor in processors:
            # For each GPU..
            asyncio.create_task(processor.run())

        # print(f"-----")

stats = []


@app.get("/dt_stats")
async def dt_stats():
    return {
        'status': 'success',
        'data': {'stats': stats}
    }


async def gpu_status_loop():
    global stats
    while True:
        # Run every second
        await asyncio.sleep(1)

        total = 0

        for processor in processors:
            total += processor.getLoad()

        # Number from 0 to 1..
        load = total/(MAX_BATCH_SIZE * len(processors))

        queue = len(req_queue_HIGH) + len(req_queue_MED) + \
            len(req_queue_LOW) + len(req_queue_VLOW)

        # last 5 mins
        stats = stats[-299:] + [[load, queue]]

        if (len(stats) > 1):
            # Stats from last 5s
            # print(f"LOAD: {stats[-1][0]}/{stats[-1][1]} {stats[-2][0]}/{stats[-1][1]} {stats[-3][0]}/{stats[-1][1]} {stats[-4][0]}/{stats[-1][1]} {stats[-5][0]}/{stats[-1][1]}")
            print(f"LOAD: {str(load).ljust(5)} QUEUE: H {str(len(req_queue_HIGH)).ljust(3)} / M {str(len(req_queue_MED)).ljust(3)} / L {str(len(req_queue_LOW)).ljust(3)} / VL {str(len(req_queue_VLOW)).ljust(3)}")


root_tasks = []


@app.on_event("startup")
async def startup_event():
    global root_tasks
    root_tasks.append(asyncio.create_task(generate_loop()))
    root_tasks.append(asyncio.create_task(gpu_status_loop()))

Dockerfile

FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04

# Use Python 3.11 for better Python perf
# Update the package lists and install necessary dependencies
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y python3.11 python3.11-dev

# Set Python 3.11 as the default version (for python3)
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1

# Download get-pip.py script
RUN apt install curl -y
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py

# Install pip for Python 3.11
RUN python3 get-pip.py

# Verify Python and pip versions
RUN python3 --version && pip3.11 --version

# Set pip3.11 as the default pip command
RUN update-alternatives --install /usr/bin/pip3 pip3 /usr/local/lib/python3.11/dist-packages/pip 1

ENV PYTHONUNBUFFERED=1

# Install necessary dependencies
# RUN apt-get update && \
#     apt-get install -y python3-pip

# Set the working directory. /app is mounted to the container with -v, 
# but we want to have the right cwd for uvicorn command below
RUN mkdir /app
WORKDIR /app

# # Copy the app code and requirements files
# COPY . /app
# COPY requirements.txt .

# Install the app dependencies
RUN pip3 install "ctranslate2==3.16" torch sentencepiece uvicorn fastapi

# Expose the FastAPI port
EXPOSE 14000

# Start the FastAPI app using Uvicorn web server
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "14000", "--limit-concurrency", "1000"]

docker-dioco-ct2-chat.service [systemd unit file]

[Unit]
Description=Docker Container for dioco-ct2-chat
Requires=docker.service
After=docker.service

[Service]
WorkingDirectory=/root/projects/dioco-ct2-chat
ExecStart=/usr/bin/docker run --gpus '"device=0,1,2,3,4,5,6,7"' --restart=unless-stopped -p 14000:14000 -v /root/projects/chatModels/:/models/ -v /root/projects/dioco-ct2-chat:/app chat

[Install]
WantedBy=default.target
@hobodrifterdavid
Author

I edited the code above to remove the problematic 'repetition_penalty=1.5' parameter.

@silvacarl2

This is insanely cool and what we were also planning to do, checking it out!

@hobodrifterdavid
Author

Okay, I'll add some tips later that may be helpful. :) It's currently running this feature: https://www.languagereactor.com/chatbot

@silvacarl2

NICE!!!!!!!!!!!!!! which model(s) are you running?

@silvacarl2

have you added any more to this?

@suhjohn

suhjohn commented Aug 17, 2023

import argparse
import asyncio
import os
import queue
import random
import threading
import time
from typing import List, Tuple, Dict, Union, AsyncGenerator

import ctranslate2
import sentencepiece as spm
import uvicorn
from ctranslate2 import GenerationResult, GenerationStepResult
from fastapi import FastAPI, responses, status
from pydantic import BaseModel


class DynamicBuffer:
    def __init__(self):
        self.buffer = []  # Stores (timestamp, object) tuples

    def add(self, obj):
        timestamp = time.time()
        self.buffer.append((timestamp, obj))

    def flush(self, n: int):
        n = int(n)
        # Current time
        current_time = time.time()
        if len(self.buffer) == 0:
            return []
        # Flush 1 object if the buffer has 1 object and it's older than 1 second.
        elif len(self.buffer) == 1:
            obj = self.buffer[0]
            if current_time - obj[0] > 1:
                self.buffer = []
                return [obj[1]]
            return []
        # Flush up to n objects if the buffer has 2 or more.
        to_flush = self.buffer[:n]
        self.buffer = self.buffer[n:]
        return [obj for _, obj in to_flush]


MAX_BATCH_SIZE = int(os.environ.get("MAX_BATCH_SIZE", 6))
MAX_BATCH_COUNT = int(os.environ.get("MAX_BATCH_COUNT", 1))
BUFFER = DynamicBuffer()
STATS = []
DAEMON_TASKS = []


class GPUProcessor:
    def __init__(
        self, device_number: int, model_dir: str, compute_type: str, tokenizer_dir: str
    ):
        self.load: int = 0  # number of prompts being handled by the GPU
        self.batch_count: int = 0
        self.device_number: int = device_number
        print(f"Loading model {model_dir} on GPU {device_number}")
        self.generator = ctranslate2.Generator(
            model_dir,
            device="cuda",
            device_index=[device_number],
            compute_type=compute_type,
        )
        print(f"Loading tokenizer {tokenizer_dir}")
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.model")
        self.sp = spm.SentencePieceProcessor(tokenizer_path)

    def get_load(self) -> int:
        return self.load

    async def generate(self, prompts: List[str]):
        # self.is_busy = True
        self.load = len(prompts)

        def _generate() -> Tuple[List[str], int]:
            tokens: List[List[str]] = [
                ["<s>"] + self.sp.encode(prompt, out_type=str) for prompt in prompts
            ]
            try:
                results = self.generator.generate_batch(
                    tokens,
                    beam_size=1,
                    sampling_temperature=0,
                    num_hypotheses=1,
                    max_length=256,
                    min_length=1,
                    include_prompt_in_result=False,
                    max_batch_size=MAX_BATCH_SIZE,
                    return_end_token=False,
                )
            except Exception as e:
                print(f"An error occurred during generate_batch: {str(e)}")
                raise e

            outputs: List[Tuple[str, int]] = [decode(result) for result in results]
            decoded_output: List[str] = [output[0] for output in outputs]
            num_tokens = sum([output[1] for output in outputs])
            return decoded_output, num_tokens

        def decode(result: GenerationResult) -> Tuple[str, int]:
            if len(result.sequences) == 0:
                return "", 0

            if len(result.sequences[0]) == 0:
                return "", 0
            starts_with_space = result.sequences[0][0][0] == "▁"
            has_eos = result.sequences[0][-1] == "</s>"

            string: str = self.sp.decode(result.sequences_ids[0])
            if starts_with_space:
                string = " " + string
            if has_eos:
                string += "</s>"
            return string, len(result.sequences_ids[0])

        loop = asyncio.get_event_loop()
        start = time.time()
        outputs, token_count = await loop.run_in_executor(None, _generate)
        end = time.time()
        print(
            f"[{self.device_number}] Time taken: {end - start}, batch size: {len(prompts)}, token count: {token_count}, time per token: {(end - start) * 1000 / token_count}ms"
        )
        # self.is_busy = False
        self.load = 0
        return outputs

    async def generate_stream(
        self, requests_to_process: List[Tuple[str, asyncio.Queue]]
    ):
        step_result_queue = asyncio.Queue()
        closed_events = [asyncio.Event() for _ in range(len(requests_to_process))]
        start = time.time()

        def _on_success(res: GenerationStepResult):
            target_closed_event = closed_events[res.batch_id]
            step_result_queue.put_nowait((res.batch_id, res))
            if res.is_last:
                step_result_queue.put_nowait((res.batch_id, None))
            return target_closed_event.is_set()

        def _on_error():
            for idx, result in enumerate(results):
                try:
                    result.result()
                except Exception as e:
                    step_result_queue.put_nowait((idx, e))

        token_count = 0
        self.load = len(requests_to_process)
        tokens = [
            ["<s>"] + self.sp.encode(q[0], out_type=str)
            for q in requests_to_process
        ]
        input_token_lengths = [len(token) for token in tokens]
        try:
            results = self.generator.generate_batch(
                tokens,
                beam_size=1,
                sampling_temperature=0,
                num_hypotheses=1,
                max_length=256,
                min_length=1,
                include_prompt_in_result=False,
                max_batch_size=MAX_BATCH_SIZE,
                return_end_token=False,
                asynchronous=True,
                callback=_on_success,
            )
            thread = threading.Thread(target=_on_error, daemon=True)
            thread.start()
            pending_idx = set(range(len(requests_to_process)))
            while pending_idx:
                idx, step_result = await step_result_queue.get()
                target_streaming_queue = requests_to_process[idx][1]
                if step_result is None or isinstance(step_result, Exception):
                    await target_streaming_queue.put(None)
                    pending_idx.remove(idx)
                    continue
                decoded_token = self.sp.decode([step_result.token_id])
                #  NOTE: ▁ != _. Special token for delimiter
                if step_result.token.startswith("▁"):
                    decoded_token = f" {decoded_token}"
                await target_streaming_queue.put(decoded_token)
                token_count += 1
                await asyncio.sleep(0)
            thread.join()
        except Exception as e:
            print(f"An error occurred during generate_stream: {str(e)}")
            for closed_event in closed_events:
                closed_event.set()
        finally:
            self.load = 0

        end = time.time()
        print(
            f"[GPU-{self.device_number}] total inference time: {end - start}, batch size: {len(requests_to_process)}, input_token_lengths: {input_token_lengths}, "
            f"total output token count: {token_count}, time per token: {(end - start) * 1000 / token_count}ms"
        )

    @property
    def is_busy(self) -> bool:
        return self.batch_count == MAX_BATCH_COUNT

    async def handle_stream_request(self):
        global BUFFER
        processing_queue = BUFFER.flush(MAX_BATCH_SIZE)
        if len(processing_queue) == 0:
            return
        try:
            self.batch_count += 1
            print(
                f"[GPU-{self.device_number}] handle_stream_request: processing {len(processing_queue)} requests"
            )
            await self.generate_stream(processing_queue)
        finally:
            self.batch_count -= 1


app = FastAPI()


class GenerateParams(BaseModel):
    prompt: str


async def stream_result(data: GenerateParams) -> AsyncGenerator[str, None]:
    global BUFFER
    new_queue = asyncio.Queue()
    BUFFER.add((data.prompt, new_queue))
    while True:
        try:
            result = await new_queue.get()
            if result is None:
                break
            yield result
            await asyncio.sleep(0)
        except queue.Empty:
            break


@app.post("/generate")
async def generate(data: GenerateParams):
    if len(BUFFER.buffer) > 100:
        return responses.JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={"error": "TOO_MANY_REQUESTS"},
        )
    try:
        output = "".join([value async for value in stream_result(data)])
        return responses.JSONResponse({"generated_text": output})
    except Exception as e:
        print(f"ERROR: {str(e)}")
        return responses.JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"status": "failure", "error": "INTERNAL_SERVER_ERROR"},
        )


@app.post("/generate_stream")
async def generate_stream(data: GenerateParams):
    if len(BUFFER.buffer) > 100:
        return responses.JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={"error": "TOO_MANY_REQUESTS"},
        )
    try:
        return responses.StreamingResponse(stream_result(data))
    except Exception as e:
        print(f"ERROR: {str(e)}")
        return responses.JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"status": "failure", "error": "INTERNAL_SERVER_ERROR"},
        )


async def run_stream_consumer():
    while True:
        await asyncio.sleep(0.25)
        random.shuffle(processors)
        for processor in processors:
            if processor.is_busy:
                continue
            asyncio.create_task(processor.handle_stream_request())


@app.get("/metrics")
async def metrics() -> Dict[str, List[List[Union[float, int]]]]:
    return {"stats": STATS}


async def update_stats():
    """
    Updates the stats every second
    :return:
    """
    global STATS
    while True:
        await asyncio.sleep(1)

        total = 0

        for processor in processors:
            total += processor.get_load()

        load = total / (MAX_BATCH_SIZE * len(processors))

        queue_size = len(BUFFER.buffer)

        STATS = STATS[-299:] + [[load, queue_size]]


@app.on_event("startup")
async def startup_event():
    global DAEMON_TASKS
    DAEMON_TASKS.append(asyncio.create_task(run_stream_consumer()))
    DAEMON_TASKS.append(asyncio.create_task(update_stats()))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FastAPI application with ctranslate2")
    parser.add_argument(
        "--model_dir",
        type=str,
        default=os.environ.get(
            "QUANTIZED_MODEL_OUTPUT_DIR", "/workspace/chronos-hermes-13b-v2-int8"
        ),
        help="Directory containing the model",
    )
    parser.add_argument(
        "--compute_type",
        type=str,
        default=os.environ.get("QUANTIZATION", "int8"),
        help="Compute type of the model. Refer to ctranslate2 docs.",
    )
    parser.add_argument(
        "--tokenizer_dir",
        type=str,
        default=os.environ.get("TOKENIZER_DIR", "/workspace/chronos-hermes-13b-v2"),
        help="Directory containing the tokenizer model.",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=os.environ.get("NUM_GPUS", 1),
        help="Number of GPU processors",
    )
    parser.add_argument("--host", type=str, default=os.environ.get("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=os.environ.get("PORT", 8000))
    args = parser.parse_args()

    processors = [
        GPUProcessor(i, args.model_dir, args.compute_type, args.tokenizer_dir)
        for i in range(args.num_gpus)
    ]
    print(f"Number of GPUs: {len(processors)}")
    uvicorn.run(app, host=args.host, port=args.port)

Cleaned up the code a bit and made streaming possible, but the code is currently buggy, with the wrong response coming back on parallel requests. It might be some non-thread-safe parts of the code that are causing this.
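For reference, here is a minimal sketch of consuming the /generate_stream endpoint above from a client (my own example, not part of the thread; it assumes the server runs on localhost:8000, that the stream carries plain-text token chunks, and that the requests library is available):

# Hypothetical streaming client for the /generate_stream endpoint above.
import requests

def stream_generation(prompt: str, base_url: str = "http://localhost:8000") -> str:
    text = ""
    with requests.post(f"{base_url}/generate_stream", json={"prompt": prompt},
                       stream=True, timeout=120) as resp:
        resp.raise_for_status()
        # Each chunk is a decoded piece of the generated text.
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                print(chunk, end="", flush=True)
                text += chunk
    return text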

@silvacarl2

THIS IS TOTALLY AWESOME!!!!!!!!!!!!!!!!!!

WILL TEST IT OUT!!!!!!!!!!!!!!!!!!

@BBC-Esq

BBC-Esq commented Sep 6, 2023

I edited the code above to remove the problematic 'repetition_penalty=1.5' parameter.

Is there any way I can get a copy of the python wrapper you referred to in your first message? I'm also curious if you have any other resources regarding Python as it relates to ctranslate2 specifically? I've been struggling to write my own scripts utilizing both technologies, although I still think ctranslate2 is awesome.

@guillaumekln
Collaborator

@BBC-Esq For a chat usage, you can see the Llama 2 example:

https://github.com/OpenNMT/CTranslate2/tree/master/examples/llama2

@BBC-Esq

BBC-Esq commented Sep 7, 2023

I just cloned your repo and am going to start studying it a little more to understand it better. I was already aware of that chat script and actually have something based off of it, but I'm still struggling. Will keep trying though because, if it lives up to what I've seen comparatively with Faster Whisper and WhisperX, it'll beat ggml/gguf/gptq and those guys. Appreciate the advice.

@silvacarl2

This package is the best. Also, you should check out faster-whisper.

@Max-Rubino

Cleaned up the code a bit and made streaming possible, but the code is currently buggy, with the wrong response coming back on parallel requests. It might be some non-thread-safe parts of the code that are causing this.

I tried to solve the bug. My solution is not very elegant, but it should solve the problem: for now, process the requests with the most tokens at the beginning of the batch. You only need to modify handle_stream_request.

    async def handle_stream_request(self):
        global BUFFER
        processing_queue = BUFFER.flush(MAX_BATCH_SIZE)
        # Sort the processing_queue by prompt token length, longest prompts first
        processing_queue.sort(key=lambda item: len(self.sp.encode(item[0], out_type=str)), reverse=True)
        if len(processing_queue) == 0:
            return
        try:
            self.batch_count += 1
            print(
                f"[GPU-{self.device_number}] handle_stream_request: processing {len(processing_queue)} requests"
            )
            # print(processing_queue)
            await self.generate_stream(processing_queue)
        finally:
            self.batch_count -= 1
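To sanity-check a fix like this against the wrong-response-on-parallel-requests bug described above, a small smoke test (my own sketch, assuming the server from the earlier comment is running on localhost:8000 and exposes /generate) can fire concurrent requests with distinct prompts so you can eyeball whether any response got swapped:

# Hypothetical concurrency smoke test; not part of the original code.
from concurrent.futures import ThreadPoolExecutor

import requests

PROMPTS = [f"Write one sentence about topic number {i}:" for i in range(8)]

def ask(prompt):
    resp = requests.post("http://localhost:8000/generate",
                         json={"prompt": prompt}, timeout=120)
    resp.raise_for_status()
    return prompt, resp.json().get("generated_text", "")

with ThreadPoolExecutor(max_workers=len(PROMPTS)) as pool:
    for prompt, output in pool.map(ask, PROMPTS):
        # Each output should plausibly continue its own prompt; a mismatch
        # suggests responses are still being crossed between requests.
        print(f"PROMPT: {prompt!r}\nOUTPUT: {output!r}\n")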

@silvacarl2

very cool, checking it out.

@BBC-Esq

BBC-Esq commented Oct 2, 2023

I made a wrapper around CTranslate2, an API server. It batches requests, supports running these batches on multiple gpus (sort of round-robin), requests can have different priority levels, requests are immeadiatly rejected when more than 100 requests in front of it in the queue, and has an endpoint that returns the approximate load. It doesn't do streaming, just does continuation for up to 80 tokens.. this fits with how I am using CTranslate2. It avoids a websocket (which can create networking issues), from the front end you can do requests in a loop until you see the string </s> in the continuation string. It's responsive enough with the 13B model at least, the performance overhead could be tested. @guillaumekln If you want to use it as an example, or put it in a seperate repo, please feel free. Knowing myself, I know I probably am not able to find time to provide support, but it may be useful for someone. I also have a python wrapper for faster-whisper. If you prefer to just delete the issue, I won't have hurt feelings. :) If there is interest, I can document the code better.

What you do on the front-end etc.:

let continuation = 'The best things in life are';
let counter = 0;
while(!continuation.includes('</s>') && counter < 5) {
    // Request more tokens from the server
    let data = { prompt: continuation, priority: 'MED' }; // 'HIGH' | 'MED' | 'LOW' | 'VLOW'
    try {
        const response = await axios.post(`http://mydomain.com/dt_llama`, data, { timeout: 10_000});
        if(response.data.status === 'success') {
            continuation += response.data.data.text;
            counter += 1;
            // Here write the continuation to your UI, or observable state etc.
        } else {
            console.error(response.data.error);
            break;
        }
    } catch(err) {
        // non-2XX/3XX response
        console.error(err);
        break;
    }
}

app.py

# Import necessary libraries
import sys
from typing import List, Tuple
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel
import asyncio
from fastapi.responses import JSONResponse
import os
import ctranslate2
import sentencepiece as spm
import time
from datetime import datetime

model_dir = "/models/vicuna-13b-1.1_ct2/"

MAX_BATCH_SIZE = 6

class gpuProcessor:
    def __init__(self, device_number):

        self.busy = False
        self.load = 0
        self.device_number = device_number

        self.generator = ctranslate2.Generator(
            model_dir, device="cuda", device_index=[device_number], compute_type='int8_float16')  # device_index=[0, 1, 2, 3]
        self.sp = spm.SentencePieceProcessor(
            os.path.join(model_dir, "tokenizer.model"))

    def isBusy(self):
        return self.busy

    def getLoad(self):
        return self.load

    async def generate(self, prompts: List[str]):

        self.busy = True
        self.load = len(prompts)

        now = time.time()

        def sync_func() -> List[str]:

            # now = time.time()

            tokens: List[List[str]] = [
                ["<s>"] + self.sp.encode(prompt, out_type=str) for prompt in prompts]

            # print(f">> Tokenise time {time.time() - now}")
            # now = time.time()

            try:
                results = self.generator.generate_batch(
                    tokens,
                    beam_size=1,
                    sampling_temperature=0.8,
                    sampling_topk=10,
                    num_hypotheses=1,
                    max_length=80, # Hard coded to 80!
                    min_length=1,
                    include_prompt_in_result=False,
                    max_batch_size=MAX_BATCH_SIZE,
                    return_end_token=True,
                    # Edit: don't do this, value is too high and results in poor output:
                    # repetition_penalty=1.5
                )
            except Exception as e:
                print(f"An error occurred during generate_batch: {str(e)}")
                sys.exit(1)

            # print(f">> Generate time {time.time() - now}")
            # now = time.time()

            def decode(this_result) -> str:

                # print(this_result)

                if len(this_result.sequences) == 0:
                    # No sequences..
                    return ''

                if len(this_result.sequences[0]) == 0:
                    # Sequence is len 0
                    return ''

                startsWithSpace = this_result.sequences[0][0][0] == '▁'
                hasEOS = this_result.sequences[0][-1] == '</s>'

                string: str = self.sp.decode(this_result.sequences_ids[0])
                if startsWithSpace:
                    string = ' ' + string
                if hasEOS:
                    string += '</s>'

                return string

            output: List[str] = [decode(result) for result in results]

            # print(f"+++||| {len(prompts)} --> {len(output)}")

            # for i in range(len(prompts)):
            #     print(f"+++ {prompts[i][-5:]} --> {output[i][:16]}")

            # print(f">> Decode time {time.time() - now}")

            return output

        # Run sync_func asyncronously, so we don't block the event loop.
        # Allows other requests to be handled meanwhile.
        loop = asyncio.get_event_loop()

        # print(f"START run_in_executor {self.device_number}")
        output2 = await loop.run_in_executor(None, lambda: sync_func())
        # print(f"END run_in_executor {self.device_number}")

        print(f">> Batch of {len(prompts)} {time.time() - now}")

        self.busy = False
        self.load = 0

        return output2

    async def run(self):

        # print(f"Tick device: {self.device_number} busy: {self.busy}")

        global req_queue_HIGH
        global req_queue_MED
        global req_queue_LOW
        global req_queue_VLOW

        if self.busy:
            return

        queue_count = len(req_queue_HIGH) + len(req_queue_MED) + \
            len(req_queue_LOW) + len(req_queue_VLOW)

        if queue_count == 0:
            # print(f"{self.device_number} queue empty, run returning.")
            return

        space_remaining = MAX_BATCH_SIZE

        # Make a shallow copy of the queued items
        requestsProcessingNow = req_queue_HIGH[:space_remaining]
        # Remove these from the queue
        req_queue_HIGH = req_queue_HIGH[space_remaining:]

        space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_MED[:space_remaining]
            # Remove these from the queue
            req_queue_MED = req_queue_MED[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_LOW[:space_remaining]
            # Remove these from the queue
            req_queue_LOW = req_queue_LOW[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_VLOW[:space_remaining]
            # Remove these from the queue
            req_queue_VLOW = req_queue_VLOW[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if (len(requestsProcessingNow) == 0):
            return

        now = time.time()

        # print(f"{self.device_number} Running batch of {len(requestsProcessingNow)}, {len(req_queue_HIGH)} / {len(req_queue_MED)} / {len(req_queue_LOW)} / {len(req_queue_VLOW)} still queued.")

        results = await self.generate([prompt for prompt, _ in requestsProcessingNow])

        # print(f"{self.device_number} Run batch of {len(requestsProcessingNow)} in {time.time() - now}s")

        for i, (_, future) in enumerate(requestsProcessingNow):
            future.set_result(results[i])


processors = [
    gpuProcessor(0),
    gpuProcessor(1),
    gpuProcessor(2),
    gpuProcessor(3),
    gpuProcessor(4),
    gpuProcessor(5),
    gpuProcessor(6),
    gpuProcessor(7),
]

curGPUStartIndex = 0


def getProcessors():
    global curGPUStartIndex
    curGPUStartIndex += 1
    if (curGPUStartIndex >= len(processors)):
        curGPUStartIndex = 0

    return processors[curGPUStartIndex:] + processors[:curGPUStartIndex]


# High Priority Requests
req_queue_HIGH: List[Tuple[str, asyncio.Future]] = [
]

# Medium Priority Requests
req_queue_MED: List[Tuple[str, asyncio.Future]] = [
]

# Low Priority Requests
req_queue_LOW: List[Tuple[str, asyncio.Future]] = [
]

# Very Low Priority Requests
req_queue_VLOW: List[Tuple[str, asyncio.Future]] = [
]

app = FastAPI()


class InputData(BaseModel):
    prompt: str
    # max_len: int # Not used!
    priority: str  # 'HIGH', 'MED', 'LOW', 'VLOW'


@app.post("/dt_llama")
async def dt_llama(data: InputData):

    global req_queue_HIGH
    global req_queue_MED
    global req_queue_LOW
    global req_queue_VLOW

    # print(f"REQ: current queue {len(req_queue_HIGH)} / {len(req_queue_MED)} / {len(req_queue_LOW)} / {len(req_queue_VLOW)}")

    # Check for space..
    if data.priority == 'HIGH':
        if (len(req_queue_HIGH) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'MED':
        if ((len(req_queue_HIGH) + len(req_queue_MED)) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'LOW':
        if ((len(req_queue_HIGH) + len(req_queue_MED) + len(req_queue_LOW)) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'VLOW':
        if ((len(req_queue_HIGH) + len(req_queue_MED) + len(req_queue_LOW) + len(req_queue_VLOW)) > 100):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    else:
        # not HIGH, MED, LOW..
        return {
            'status': 'failure',
            'error': 'BAD_REQUEST'
        }

    try:
        # print(f"Generation request, len(data.prompt) {len(data.prompt)}.")

        future = asyncio.get_running_loop().create_future()

        if data.priority == 'HIGH':
            req_queue_HIGH.append((data.prompt, future))
        elif data.priority == 'MED':
            req_queue_MED.append((data.prompt, future))
        elif data.priority == 'LOW':
            req_queue_LOW.append((data.prompt, future))
        elif data.priority == 'VLOW':
            req_queue_VLOW.append((data.prompt, future))

        output = await future

        # print('Output: ' + output)

        # if data.priority == 'HIGH':
        #     print('MEOW: ' + output)
        # if data.priority == 'MED':
        #     print('MEOW: ' + output)

        return {
            'status': 'success',
            'data': {'text': output}
        }
    except Exception as e:
        print(f"ERROR: {str(e)}")
        return Response(content='{"status": "failure", "error": "SERVER_ERROR"}', status_code=401)


class InputDataBatch(BaseModel):
    prompts: List[str]
    # max_len: int


async def generate_loop():

    start_time = 0

    while True:

        await asyncio.sleep(0.5)

        end_time = time.process_time()
        # print(f"Process time: {end_time - start_time}")
        start_time = time.process_time()

        processors = getProcessors()

        for processor in processors:
            # For each GPU..
            asyncio.create_task(processor.run())

        # print(f"-----")

stats = []


@app.get("/dt_stats")
async def dt_stats():
    return {
        'status': 'success',
        'data': {'stats': stats}
    }


async def gpu_status_loop():
    global stats
    while True:
        # Run every 1 seconds
        await asyncio.sleep(1)

        total = 0

        for processor in processors:
            total += processor.getLoad()

        # Number from 0 to 1..
        load = total/(MAX_BATCH_SIZE * len(processors))

        queue = len(req_queue_HIGH) + len(req_queue_MED) + \
            len(req_queue_LOW) + len(req_queue_VLOW)

        # last 5 mins
        stats = stats[-299:] + [[load, queue]]

        if (len(stats) > 1):
            # Stats from last 5s
            # print(f"LOAD: {stats[-1][0]}/{stats[-1][1]} {stats[-2][0]}/{stats[-1][1]} {stats[-3][0]}/{stats[-1][1]} {stats[-4][0]}/{stats[-1][1]} {stats[-5][0]}/{stats[-1][1]}")
            print(f"LOAD: {str(load).ljust(5)} QUEUE: H {str(len(req_queue_HIGH)).ljust(3)} / M {str(len(req_queue_MED)).ljust(3)} / L {str(len(req_queue_LOW)).ljust(3)} / VL {str(len(req_queue_VLOW)).ljust(3)}")


root_tasks = []


@ app.on_event("startup")
async def startup_event():
    global root_tasks
    root_tasks.append(asyncio.create_task(generate_loop()))
    root_tasks.append(asyncio.create_task(gpu_status_loop()))

Dockerfile

FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04

# Use Python 3.11 for better Python perf
# Update the package lists and install necessary dependencies
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y python3.11 python3.11-dev

# Set Python 3.11 as the default version (for python3)
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1

# Download get-pip.py script
RUN apt-get install -y curl
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py

# Install pip for Python 3.11
RUN python3 get-pip.py

# Verify Python and pip versions
RUN python3 --version && pip3.11 --version

# Set pip3.11 as the default pip command
RUN update-alternatives --install /usr/bin/pip3 pip3 /usr/local/bin/pip3.11 1

ENV PYTHONUNBUFFERED=1

# Install necessary dependencies
# RUN apt-get update && \
#     apt-get install -y python3-pip

# Set the working directory. /app is mounted to the container with -v, 
# but we want to have the right cwd for uvicorn command below
RUN mkdir /app
WORKDIR /app

# # Copy the app code and requirements file
# COPY . /app
# COPY requirements.txt .

# Install the app dependencies
RUN pip3 install "ctranslate2==3.16" torch sentencepiece uvicorn fastapi

# Expose the FastAPI port
EXPOSE 14000

# Start the FastAPI app using Uvicorn web server
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "14000", "--limit-concurrency", "1000"]

docker-dioco-ct2-chat.service [systemd unit file]

[Unit]
Description=Docker Container for dioco-ct2-chat
Requires=docker.service
After=docker.service

[Service]
WorkingDirectory=/root/projects/dioco-ct2-chat
ExecStart=/usr/bin/docker run --gpus '"device=0,1,2,3,4,5,6,7"' --restart=unless-stopped -p 14000:14000 -v /root/projects/chatModels/:/models/ -v /root/projects/dioco-ct2-chat:/app chat

[Install]
WantedBy=default.target

Can you provide the faster-whisper wrapper code? I'd love to see it!

@StephennFernandes

Hey guys,
I've been working on building a translation API, so I came across this code and refactored it a little to work with the NLLB 3.3B model for translations.

The following is my working code:

# Import necessary libraries
import sys
from typing import List, Tuple
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel
import asyncio
from fastapi.responses import JSONResponse
import os
import ctranslate2
import sentencepiece as spm
import time
from datetime import datetime
import transformers
from transformers import AutoTokenizer as AT

model_dir = "ct2-nllb-3.3Q"
tokenizer_name = "facebook/nllb-200-3.3B"
MAX_BATCH_SIZE = 5

class gpuProcessor:
    def __init__(self, device_number):

        self.busy = False
        self.load = 0
        self.device_number = device_number

        self.translator = ctranslate2.Translator(
            model_dir, device="cuda", device_index=[device_number], compute_type='int8_bfloat16', inter_threads=8)  # device_index=[0, 1, 2, 3]
        self.sp = AT.from_pretrained(tokenizer_name, src_lang="hin_Deva")  # used only for decoding below; src_lang mainly affects encoding

    def isBusy(self):
        return self.busy

    def getLoad(self):
        return self.load

    async def generate(self, input_translations: List[str], src_langs: List[str], trg_langs: List[str]):

        self.busy = True
        self.load = len(input_translations)
        print("src_langs: ",src_langs)
        print("trg_langs: ",trg_langs)
        print("src_langs TYPE :: ", type(src_langs))
        print("trg_langs TYPE :: ", type(trg_langs))
        print("input_translations: ", input_translations)
        print("input_translations TYPE :: ", type(input_translations))
        now = time.time()

        def sync_func() -> List[str]:

            # now = time.time()
            tokens: List[List[str]] = [AT.from_pretrained(tokenizer_name, src_lang=src_lang).convert_ids_to_tokens(AT.from_pretrained(tokenizer_name, src_lang=src_lang).encode(input_translation)) for input_translation,src_lang in zip(input_translations,src_langs)]
            #tokens: List[List[str]] = [self.sp.convert_ids_to_tokens(self.sp.encode(input_translation)) for input_translation in input_translations]
            # print(f">> Tokenise time {time.time() - now}")
            # now = time.time()
            print(tokens)
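            # Editor's note (suggestion, not part of the original code): the tokenization
            # step above calls AT.from_pretrained() twice per sentence on every batch,
            # which re-creates the tokenizer each time and will dominate batch latency.
            # A cached variant could look roughly like this (hypothetical helper names):
            #
            #   _tokenizer_cache = {}
            #
            #   def get_tokenizer(lang):
            #       if lang not in _tokenizer_cache:
            #           _tokenizer_cache[lang] = AT.from_pretrained(tokenizer_name, src_lang=lang)
            #       return _tokenizer_cache[lang]
            #
            #   tokens = [get_tokenizer(src).convert_ids_to_tokens(get_tokenizer(src).encode(text))
            #             for text, src in zip(input_translations, src_langs)]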
            try:
                results = self.translator.translate_batch(
                    tokens,
                    batch_type="tokens",
                    target_prefix=[[trg_lang]for trg_lang in trg_langs],
                    beam_size=1,
                    # Edit: don't do this, value is too high and results in poor output:
                    # repetition_penalty=1.5
                )
            except Exception as e:
                print(f"An error occurred during generate_batch: {str(e)}")
                sys.exit(1)

            def decode(this_result) -> str:

                """if len(this_result.sequences) == 0:
                    return ''
                if len(this_result.sequences[0]) == 0:
                    return ''"""
                print("this_result: TYPE:: ",type(this_result))
                print(this_result)
                string: str = self.sp.decode(self.sp.convert_tokens_to_ids(this_result.hypotheses[0][1:]))
                return string

            output: List[str] = [decode(result) for result in results]
            return output

        # Run sync_func asynchronously, so we don't block the event loop.
        # Allows other requests to be handled in the meantime.
        loop = asyncio.get_event_loop()

        # print(f"START run_in_executor {self.device_number}")
        output2 = await loop.run_in_executor(None, lambda: sync_func())
        # print(f"END run_in_executor {self.device_number}")

        print(f">> Batch of {len(input_translations)} {time.time() - now}")

        self.busy = False
        self.load = 0

        return output2

    async def run(self):

        # print(f"Tick device: {self.device_number} busy: {self.busy}")

        global req_queue_HIGH
        global req_queue_MED
        global req_queue_LOW
        global req_queue_VLOW

        if self.busy:
            return

        queue_count = len(req_queue_HIGH) + len(req_queue_MED) + \
            len(req_queue_LOW) + len(req_queue_VLOW)

        if queue_count == 0:
            # print(f"{self.device_number} queue empty, run returning.")
            return

        space_remaining = MAX_BATCH_SIZE

        # Make a shallow copy of the queued items
        requestsProcessingNow = req_queue_HIGH[:space_remaining]
        # Remove these from the queue
        req_queue_HIGH = req_queue_HIGH[space_remaining:]

        space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_MED[:space_remaining]
            # Remove these from the queue
            req_queue_MED = req_queue_MED[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_LOW[:space_remaining]
            # Remove these from the queue
            req_queue_LOW = req_queue_LOW[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if space_remaining > 0:
            # Make a shallow copy of the queued items
            requestsProcessingNow = requestsProcessingNow + \
                req_queue_VLOW[:space_remaining]
            # Remove these from the queue
            req_queue_VLOW = req_queue_VLOW[space_remaining:]
            # Decrement counter
            space_remaining = MAX_BATCH_SIZE - len(requestsProcessingNow)

        if (len(requestsProcessingNow) == 0):
            return

        now = time.time()

        # print(f"{self.device_number} Running batch of {len(requestsProcessingNow)}, {len(req_queue_HIGH)} / {len(req_queue_MED)} / {len(req_queue_LOW)} / {len(req_queue_VLOW)} still queued.")
        #input_translations, src_langs, trg_langs = zip(*[(input_translations, src_langs, trg_langs) if len(request) == 3 else (request[0], request[1], "") for request in requestsProcessingNow])
        print("requestProcessingNow: ::",requestsProcessingNow)
        print("requestProcessingNow: TYPE :: ", type(requestsProcessingNow))
        
        input_translations = [request[0] for request in requestsProcessingNow]
        src_langs = [request[1] for request in requestsProcessingNow]
        trg_langs = [request[2] for request in requestsProcessingNow]
        results = await self.generate(input_translations, src_langs, trg_langs)

        # print(f"{self.device_number} Run batch of {len(requestsProcessingNow)} in {time.time() - now}s")

        for i, (_,_,_, future) in enumerate(requestsProcessingNow):
            future.set_result(results[i])


processors = [
    gpuProcessor(0),
    gpuProcessor(1),
    gpuProcessor(2),
    gpuProcessor(3),
]

curGPUStartIndex = 0


def getProcessors():
    global curGPUStartIndex
    curGPUStartIndex += 1
    if (curGPUStartIndex >= len(processors)):
        curGPUStartIndex = 0

    return processors[curGPUStartIndex:] + processors[:curGPUStartIndex]


# High Priority Requests
req_queue_HIGH: List[Tuple[str, asyncio.Future]] = [
]

# Medium Priority Requests
req_queue_MED: List[Tuple[str, asyncio.Future]] = [
]

# Low Priority Requests
req_queue_LOW: List[Tuple[str, asyncio.Future]] = [
]

# Very Low Priority Requests
req_queue_VLOW: List[Tuple[str, asyncio.Future]] = [
]

app = FastAPI()


class InputData(BaseModel):
    input_translation: str
    # max_len: int # Not used!
    src_langs: str
    trg_langs: str
    priority: str  # 'HIGH', 'MED', 'LOW', 'VLOW'
    class Config:
        arbitrary_types_allowed = True

@app.post("/dt_llama")
async def dt_llama(data: InputData):

    global req_queue_HIGH
    global req_queue_MED
    global req_queue_LOW
    global req_queue_VLOW

    # print(f"REQ: current queue {len(req_queue_HIGH)} / {len(req_queue_MED)} / {len(req_queue_LOW)} / {len(req_queue_VLOW)}")

    # Check for space..
    if data.priority == 'HIGH':
        if (len(req_queue_HIGH) > 500):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'MED':
        if ((len(req_queue_HIGH) + len(req_queue_MED)) > 500):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'LOW':
        if ((len(req_queue_HIGH) + len(req_queue_MED) + len(req_queue_LOW)) > 500):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    elif data.priority == 'VLOW':
        if ((len(req_queue_HIGH) + len(req_queue_MED) + len(req_queue_LOW) + len(req_queue_VLOW)) > 500):
            return {
                'status': 'failure',
                'error': 'QUEUE_FULL'
            }
    else:
        # Priority is not one of HIGH, MED, LOW, VLOW
        return {
            'status': 'failure',
            'error': 'BAD_REQUEST'
        }

    try:
        # print(f"Generation request, len(data.input_translation) {len(data.input_translation)}.")

        future = asyncio.get_running_loop().create_future()
        request_data = (data.input_translation, data.src_langs, data.trg_langs, future)
        if data.priority == 'HIGH':
            req_queue_HIGH.append(request_data)
        elif data.priority == 'MED':
            req_queue_MED.append(request_data)
        elif data.priority == 'LOW':
            req_queue_LOW.append(request_data)
        elif data.priority == 'VLOW':
            req_queue_VLOW.append(request_data)

        output = await future

        
        return {
            'status': 'success',
            'data': {'text': output}
        }
    except Exception as e:
        print(f"ERROR: {str(e)}")
        return Response(content='{"status": "failure", "error": "SERVER_ERROR"}', status_code=500, media_type="application/json")


class InputDataBatch(BaseModel):
    input_translations: List[str]
    src_langs: List[str]
    trg_langs: List[str]
    # max_len: int
    class Config:
        arbitrary_types_allowed = True

async def generate_loop():

    start_time = 0

    while True:

        await asyncio.sleep(0.5)

        end_time = time.process_time()
        # print(f"Process time: {end_time - start_time}")
        start_time = time.process_time()

        processors = getProcessors()

        for processor in processors:
            # For each GPU..
            asyncio.create_task(processor.run())

        # print(f"-----")

stats = []


@app.get("/dt_stats")
async def dt_stats():
    return {
        'status': 'success',
        'data': {'stats': stats}
    }


async def gpu_status_loop():
    global stats
    while True:
        # Run every 1 seconds
        await asyncio.sleep(1)

        total = 0

        for processor in processors:
            total += processor.getLoad()

        # Number from 0 to 1..
        load = total/(MAX_BATCH_SIZE * len(processors))

        queue = len(req_queue_HIGH) + len(req_queue_MED) + \
            len(req_queue_LOW) + len(req_queue_VLOW)

        # last 5 mins
        stats = stats[-299:] + [[load, queue]]

        if (len(stats) > 1):
            # Stats from last 5s
            # print(f"LOAD: {stats[-1][0]}/{stats[-1][1]} {stats[-2][0]}/{stats[-1][1]} {stats[-3][0]}/{stats[-1][1]} {stats[-4][0]}/{stats[-1][1]} {stats[-5][0]}/{stats[-1][1]}")
            print(f"LOAD: {str(load).ljust(5)} QUEUE: H {str(len(req_queue_HIGH)).ljust(3)} / M {str(len(req_queue_MED)).ljust(3)} / L {str(len(req_queue_LOW)).ljust(3)} / VL {str(len(req_queue_VLOW)).ljust(3)}")


root_tasks = []


@app.on_event("startup")
async def startup_event():
    global root_tasks
    root_tasks.append(asyncio.create_task(generate_loop()))
    root_tasks.append(asyncio.create_task(gpu_status_loop()))
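
For reference, a request against this endpoint could look roughly like the sketch below. The host/port, the example sentence, and the eng_Latn target code are assumptions; the /dt_llama path and the field names come from the InputData model above.

# Hypothetical client call (sketch, not from the original post).
import requests

payload = {
    "input_translation": "यह एक परीक्षण वाक्य है।",  # "This is a test sentence."
    "src_langs": "hin_Deva",   # NLLB source language code (from the code above)
    "trg_langs": "eng_Latn",   # NLLB target language code (assumed)
    "priority": "MED",         # 'HIGH' | 'MED' | 'LOW' | 'VLOW'
}
resp = requests.post("http://localhost:14000/dt_llama", json=payload, timeout=30)
print(resp.json())  # {'status': 'success', 'data': {'text': '...'}} on success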

I have 4 A6000s, and only around 5-6 GB of VRAM is used on each card during inference, while the rest of the VRAM just sits idle.

I was wondering if there is a way to run multiple replicas of the same model on each card, to maximize resources and get higher throughput/concurrency with minimal latency?

I tried reducing the batch size and adding the inter_threads arg, but there isn't much of a difference in load testing. Could someone please let me know what exactly to do?
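
One thing that may be worth trying, since each gpuProcessor owns its own ctranslate2.Translator: build more than one processor per physical card, so the round-robin in getProcessors() spreads batches across replicas as well as GPUs. A minimal sketch under that assumption; NUM_GPUS and REPLICAS_PER_GPU are made-up names, and each replica loads another full copy of the weights into VRAM:

# Hypothetical replica setup (sketch): several gpuProcessor instances per card.
NUM_GPUS = 4          # physical GPUs
REPLICAS_PER_GPU = 2  # each replica loads its own copy of the model weights

processors = [
    gpuProcessor(device_number)
    for device_number in range(NUM_GPUS)
    for _ in range(REPLICAS_PER_GPU)
]

Whether this actually raises throughput depends on whether the cards are already compute-bound; if they are, extra replicas mostly add concurrency for small batches rather than raw speed.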

@StephennFernandes

Hey, I'd also love to check out the similar wrapper for faster-whisper.

@silvacarl2

THIS IS TOTALLY COOL!!!!!!!!!!!!!!!!!!

@rschuurmanbench

This looks nice; any plans to put this on Docker Hub or the GitHub Container Registry? Or other integration plans?

@vince62s
Member

As great as it is, the first poster said they probably won't have time to support it, so we can't commit something that won't be supported. But it stays here. Closing.
