In [1]:
import asyncio
import json
import time
from typing import List

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from pydantic import BaseModel, constr

In [2]:
from lorax import AsyncClient, Client
from utils import endpoint_url, headers

# Synchronous LoRAX client shared by the blocking and streaming cells below.
client = Client(
    endpoint_url,
    headers=headers,
)

## Prefill vs Decode (KV Cache)


In [None]:
t0 = time.time()
resp = client.generate("What is deep learning?", max_new_tokens=32)
duration_s = time.time() - t0

print(resp.generated_text)
print("\n\n-------------")
print("Request duration (s):", duration_s)


In [3]:
duration_s = []
t0 = time.time()
for resp in client.generate_stream("What is deep learning?", max_new_tokens=32):
    duration_s.append(time.time() - t0)
    if not resp.token.special:
        print(resp.token.text, sep="", end="", flush=True)
    t0 = time.time()

print("\n\n-------------")
print("Time to first token (TTFT) (s):", duration_s[0])
print("Throughout (tok / s):", (len(duration_s) - 1) / sum(duration_s[1:]))

In [None]:
plt.plot(duration_s)
plt.show()

## Continuous Batching

In [4]:
# ANSI SGR foreground color codes, one per concurrent stream.
color_codes = [
    "31",  # red
    "32",  # green
    "33",  # yellow (ANSI blue would be "34")
]


def format_text(text, i):
    """Wrap ``text`` in the ANSI color escape assigned to stream ``i``.

    ``i`` is taken modulo ``len(color_codes)`` so that more streams than
    colors reuse the palette instead of raising ``IndexError``.

    Args:
        text: The text to color.
        i: Stream index.

    Returns:
        ``text`` surrounded by ``ESC[<code>m`` and the reset sequence ``ESC[0m``.
    """
    return f"\x1b[{color_codes[i % len(color_codes)]}m{text}\x1b[0m"

In [5]:
async_client = AsyncClient(endpoint_url, headers=headers)

duration_s = [[], [], []]

async def run(max_new_tokens, i):
    t0 = time.time()
    async for resp in async_client.generate_stream("What is deep learning?", max_new_tokens=max_new_tokens):
        duration_s[i].append(time.time() - t0)
        print(format_text(resp.token.text, i), sep="", end="", flush=True)
        t0 = time.time()


t0 = time.time()
all_max_new_tokens = [100, 10, 10]
await asyncio.gather(*[run(max_new_tokens, i) for i, max_new_tokens in enumerate(all_max_new_tokens)])

print("\n\n-------------")
print("Time to first token (TTFT) (s):", [s[0] for s in duration_s])
print("Throughout (tok / s):", [(len(s) - 1) / sum(s[1:]) for s in duration_s])
print("Total duration (s):", time.time() - t0)


NameError: name 'endpoint_url' is not defined

## Multi-LoRA

In [6]:
def run_with_adapter(prompt, adapter_id, max_new_tokens=64, adapter_source="hub"):
    """Stream a completion from ``client`` through a LoRA adapter and report latency.

    Non-special tokens are printed as they arrive, followed by the time to
    first token and the throughput of the remaining tokens.

    Args:
        prompt: Fully formatted prompt text.
        adapter_id: Adapter identifier passed to ``generate_stream``
            (e.g. ``"predibase/cnn"``).
        max_new_tokens: Maximum number of tokens to generate.
        adapter_source: Adapter source passed to ``generate_stream``; this
            notebook always uses ``"hub"``.
    """
    # duration_s[k]: seconds between token k-1 (or request start) and token k.
    duration_s = []

    t0 = time.perf_counter()
    for resp in client.generate_stream(
        prompt,
        adapter_id=adapter_id,
        adapter_source=adapter_source,
        max_new_tokens=max_new_tokens,
    ):
        duration_s.append(time.perf_counter() - t0)
        if not resp.token.special:
            print(resp.token.text, sep="", end="", flush=True)
        t0 = time.perf_counter()

    print("\n\n-------------")
    if not duration_s:
        print("No tokens were streamed.")
        return
    print("Time to first token (TTFT) (s):", duration_s[0])
    decode_s = sum(duration_s[1:])
    # A single-token response has no decode time; skip rather than divide by zero.
    if decode_s > 0:
        print("Throughput (tok / s):", (len(duration_s) - 1) / decode_s)

In [None]:
# Passage-completion prompt for the HellaSwag adapter; {ctx} is filled in below.
# The continuation lines are deliberately flush-left: indenting them inside the
# triple-quoted literal would embed four spaces before every "### Passage:" /
# "### Ending:" marker, unlike the other prompt templates in this notebook.
pt_hellaswag_processed = """You are provided with an incomplete passage below. Please read the passage and then finish it with an appropriate response. For example:

### Passage: My friend and I think alike. We

### Ending: often finish each other's sentences.

Now please continue the following passage:

### Passage: {ctx}

### Ending: """

# Sample context kept verbatim (presumably taken from the dataset, typos included).
ctx = "Numerous people are watching others on a filed. Trainers are playing frisbee with their dogs. the dogs"

run_with_adapter(pt_hellaswag_processed.format(ctx=ctx), adapter_id="predibase/hellaswag_processed")


In [None]:
# Summarization prompt for the CNN adapter, built from adjacent string literals
# (explicit "\n" instead of a triple-quoted block); {article} is filled in below.
pt_cnn = (
    "You are given a news article below. Please summarize the article, including only its highlights.\n"
    "\n"
    "### Article: {article}\n"
    "\n"
    "### Summary: "
)

# Example article, split across literals at sentence boundaries; each piece
# carries its own trailing space so the joined text is unchanged.
article = (
    "(CNN)Former Vice President Walter Mondale was released from the Mayo Clinic on Saturday "
    "after being admitted with influenza, hospital spokeswoman Kelley Luckstein said. "
    "\"He's doing well. We treated him for flu and cold symptoms and he was released today,\" she said. "
    "Mondale, 87, was diagnosed after he went to the hospital for a routine checkup following a fever, "
    "former President Jimmy Carter said Friday. "
    "\"He is in the bed right this moment, but looking forward to come back home,\" "
    "Carter said during a speech at a Nobel Peace Prize Forum in Minneapolis. "
    "\"He said tell everybody he is doing well.\" "
    "Mondale underwent treatment at the Mayo Clinic in Rochester, Minnesota. "
    "The 42nd vice president served under Carter between 1977 and 1981, "
    "and later ran for President, but lost to Ronald Reagan. "
    "But not before he made history by naming a woman, "
    "U.S. Rep. Geraldine A. Ferraro of New York, as his running mate. "
    "Before that, the former lawyer was  a U.S. senator from Minnesota. "
    "His wife, Joan Mondale, died last year."
)

cnn_prompt = pt_cnn.format(article=article)
run_with_adapter(cnn_prompt, adapter_id="predibase/cnn")

In [None]:
# NER prompt for the CoNLL++ adapter, assembled line by line. It starts with a
# newline, and the JSON example's braces are doubled so str.format keeps them
# literal; {inpt} is the only placeholder.
pt_conllpp = (
    "\n"
    "Your task is a Named Entity Recognition (NER) task. Predict the category of\n"
    "each entity, then place the entity into the list associated with the\n"
    "category in an output JSON payload. Below is an example:\n"
    "\n"
    'Input: EU rejects German call to boycott British lamb . Output: {{"person":\n'
    '[], "organization": ["EU"], "location": [], "miscellaneous": ["German",\n'
    '"British"]}}\n'
    "\n"
    "Now, complete the task.\n"
    "\n"
    "Input: {inpt} Output:"
)

inpt = "Only France and Britain backed Fischler 's proposal ."

conllpp_prompt = pt_conllpp.format(inpt=inpt)
run_with_adapter(conllpp_prompt, adapter_id="predibase/conllpp")

In [None]:
duration_s = [[], [], []]

async def run(prompt, adapter_id, i):
    t0 = time.time()
    async for resp in async_client.generate_stream(
        prompt,
        adapter_id=adapter_id,
        adapter_source="hub",
        max_new_tokens=64,
    ):
        duration_s[i].append(time.time() - t0)
        if not resp.token.special:
            print(format_text(resp.token.text, i), sep="", end="", flush=True)
        t0 = time.time()

t0 = time.time()
prompts = [
    pt_hellaswag_processed.format(ctx=ctx),
    pt_cnn.format(article=article),
    pt_conllpp.format(inpt=inpt),
]
adapter_ids = [
    "predibase/hellaswag_processed",
    "predibase/cnn",
    "predibase/conllpp",
]
await asyncio.gather(*[run(prompt, adapter_id, i)
                       for i, (prompt, adapter_id) in enumerate(zip(prompts, adapter_ids))])
print("\n\n-------------")
print("Time to first token (TTFT) (s):", [s[0] for s in duration_s])
print("Throughout (tok / s):", [(len(s) - 1) / sum(s[1:]) for s in duration_s])
print("Total duration (s):", time.time() - t0)

## Bonus: Structured Generation

In [7]:
from pydantic import BaseModel, Field, constr

# No class docstring on purpose: Pydantic uses a model's docstring as the
# schema "description", which would change the schema sent to the server.
class Person(BaseModel):
    # Field(max_length=...) yields the same {"maxLength": 10, "type": "string"}
    # schema as constr(max_length=10), without a call inside the annotation.
    # With no default given, the field stays required.
    name: str = Field(max_length=10)
    age: int


schema = Person.model_json_schema()
schema

{'properties': {'name': {'maxLength': 10, 'title': 'Name', 'type': 'string'},
  'age': {'title': 'Age', 'type': 'integer'}},
 'required': ['name', 'age'],
 'title': 'Person',
 'type': 'object'}

In [None]:
# Ask for a JSON object matching the Person schema, then parse the raw text.
resp = client.generate(
    "Create a person description for me",
    response_format=dict(type="json_object", schema=schema),
)
json.loads(resp.generated_text)

In [None]:
# Same NER instructions as ``pt_conllpp`` but without hard line wraps. ``{inpt}``
# is the only placeholder; the JSON example's braces are doubled so
# str.format keeps them literal.
prompt_template = """
Your task is a Named Entity Recognition (NER) task. Predict the category of each entity, then place the entity into the list associated with the category in an output JSON payload. Below is an example:

Input: EU rejects German call to boycott British lamb . Output: {{"person": [], "organization": ["EU"], "location": [], "miscellaneous": ["German", "British"]}}

Now, complete the task.

Input: {inpt} Output:"""

# Base Mistral-7B
# The keyword must match the template's {inpt} placeholder; passing `input=`
# raises KeyError: 'inpt'.
resp = client.generate(
    prompt_template.format(inpt="Only France and Britain backed Fischler 's proposal ."),
    max_new_tokens=128,
)
resp.generated_text

In [8]:
from typing import List

# Shared field type: every NER category is a list of entity strings.
_StrList = List[str]


# Comments only (no class docstring), so the generated schema has no
# "description" key. Field order here fixes the order of "required".
class Output(BaseModel):
    person: _StrList
    organization: _StrList
    location: _StrList
    miscellaneous: _StrList


schema = Output.model_json_schema()
schema

{'properties': {'person': {'items': {'type': 'string'},
   'title': 'Person',
   'type': 'array'},
  'organization': {'items': {'type': 'string'},
   'title': 'Organization',
   'type': 'array'},
  'location': {'items': {'type': 'string'},
   'title': 'Location',
   'type': 'array'},
  'miscellaneous': {'items': {'type': 'string'},
   'title': 'Miscellaneous',
   'type': 'array'}},
 'required': ['person', 'organization', 'location', 'miscellaneous'],
 'title': 'Output',
 'type': 'object'}

In [None]:
# Same NER prompt, now asking the server for JSON matching the Output schema.
resp = client.generate(
    prompt_template.format(
        inpt="Only France and Britain backed Fischler 's proposal .",
    ),
    max_new_tokens=128,
    response_format=dict(type="json_object", schema=schema),
)
json.loads(resp.generated_text)

In [None]:
# Combine the CoNLL++ adapter with the Output JSON schema in a single request.
resp = client.generate(
    prompt_template.format(
        inpt="Only France and Britain backed Fischler 's proposal .",
    ),
    max_new_tokens=128,
    response_format=dict(type="json_object", schema=schema),
    adapter_source="hub",
    adapter_id="predibase/conllpp",
)
json.loads(resp.generated_text)