In [1]:
! python -m pip install lm-format-enforcer transformers==4.41.2 bitsandbytes accelerate

Collecting lm-format-enforcer
  Downloading lm_format_enforcer-0.10.9-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.9/43.9 KB[0m [31m503.1 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hCollecting transformers==4.41.2
  Downloading transformers-4.41.2-py3-none-any.whl (9.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.1/9.1 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting tokenizers<0.20,>=0.19
  Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m
Collecting interegul

In [2]:
import json
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    BitsAndBytesConfig,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hugging Face model id used for both the model and the tokenizer.
# The commented-out line is an alternative checkpoint to switch to.
model_id = "microsoft/Phi-3-mini-128k-instruct"
#model_id = "microsoft/Phi-3-mini-4k-instruct"
# Device that input tensors (prompt ids, allowed-token ids) are moved to.
device = "cuda"
#device = "cpu"

In [4]:
# bitsandbytes settings handed to `from_pretrained` when the model is loaded:
# 4-bit weights of type "nf4", float16 compute dtype, double quantization enabled.
# The commented-out keyword arguments are options that are currently switched off.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    #bnb_4bit_block_size=64,
    bnb_4bit_use_double_quant=True,
    #llm_int8_enable_fp32_cpu_offload=True,
)

In [5]:
# Load the causal LM using the quantization settings defined above.
# NOTE(review): trust_remote_code=True runs modelling code downloaded from the
# model repository; consider passing `revision=` with a reviewed commit so that
# code cannot change between runs.
# NOTE(review): attn_implementation="flash_attention_2" presumably needs the
# separate flash-attn package, which the install cell does not list — confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    quantization_config=quantization_config, 
    trust_remote_code=True, 
    device_map="auto",
    attn_implementation="flash_attention_2",
)

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct:
- configuration_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-128k-instruct:
- modeling_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Downloading shards: 100%|██████████| 2/2 [02:33<00:00, 76.86s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.51s/it]


In [6]:
# Load the tokenizer that belongs to the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(
    model_id, 
    trust_remote_code=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Manual enforcement

In [96]:
# Ids of the tokens the model may emit: tokens with no lowercase characters
# (and no ":"), plus the tokenizer's special tokens.
# Special tokens have to be matched by *id*: `all_special_tokens` holds token
# strings, so an integer id is never a member of it. The ids are collected
# once into a set instead of being recomputed for every vocabulary entry.
special_token_ids = set(tokenizer.all_special_ids)
allowed_token_ids = torch.tensor([
    token_id
    for word, token_id in tokenizer.get_vocab().items()
    if (word.upper() == word and ":" not in word)
    or token_id in special_token_ids
]).to(device)

In [11]:
def phi_prompt(*, system: str, user: str) -> str:
    """Render a system message and a user message as a single prompt string.

    Each message is preceded by its role tag on its own line and closed with
    ``<|end|>``; the string ends with the ``<|assistant|>`` tag (no trailing
    newline) so the model continues from there.
    """
    return (
        f"<|system|>\n{system}<|end|>\n"
        f"<|user|>\n{user}<|end|>\n"
        "<|assistant|>"
    )

In [76]:
# Build the demo prompt, tokenize it and move the token ids to the target device.
prompt = phi_prompt(
    user="Who was the first man on the moon?",
    system="You are a helpful history expert",
)
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

In [93]:
def step(input_ids) -> torch.Tensor:
    """Greedily extend ``input_ids`` by one token taken from ``allowed_token_ids``.

    Runs the model on the full sequence, looks at the logits for the next
    position, keeps only the entries for the allowed token ids and appends the
    highest-scoring one.

    Args:
        input_ids: token ids of the sequence so far; the ``squeeze(dim=0)``
            below assumes a batch dimension of size 1.

    Returns:
        A new tensor equal to ``input_ids`` with the chosen token id appended
        along the last dimension.
    """
    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_ids).logits
    # Scores for the token following the last position, without the batch dim.
    prediction_logits = logits[:, -1, :].squeeze(dim=0)
    # Restrict the choice to the allowed tokens.
    allowed_predictions = prediction_logits[allowed_token_ids]
    # argmax indexes into the *filtered* scores, so map it back to a vocabulary id.
    token_idx = torch.argmax(allowed_predictions).item()
    next_token_id = allowed_token_ids[token_idx]
    # Append the new token as a (1, 1) tensor.
    return torch.cat(
        (input_ids, next_token_id.unsqueeze(0).unsqueeze(0)),
        dim=-1,
    )

In [97]:
# Start from the prompt and greedily add 30 tokens (an arbitrary demo length),
# printing the decoded continuation after every step.
seq = input_ids
for i in range(30):
    seq = step(seq)
    # Slice off the prompt tokens so only the generated part is decoded.
    print(tokenizer.decode(seq[0, input_ids.shape[-1]:].squeeze()))

A
ASTR
ASTRON
ASTRONAUT
ASTRONAUTS
ASTRONAUTS NE
ASTRONAUTS NEIL
ASTRONAUTS NEIL AR
ASTRONAUTS NEIL ARM
ASTRONAUTS NEIL ARMSTR
ASTRONAUTS NEIL ARMSTRONG
ASTRONAUTS NEIL ARMSTRONG AND
ASTRONAUTS NEIL ARMSTRONG AND B
ASTRONAUTS NEIL ARMSTRONG AND BILL
ASTRONAUTS NEIL ARMSTRONG AND BILL D
ASTRONAUTS NEIL ARMSTRONG AND BILL DAV
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES W
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE F
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIR
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIRST
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIRST M
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIRST MEN
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIRST MEN TO
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIRST MEN TO S
ASTRONAUTS NEIL ARMSTRONG AND BILL DAVIES WERE THE FIRST MEN TO STE
ASTRONA

## Using 'lm-format-enforcer'

In [98]:
class SummaryShape(BaseModel):
    # Shape of the structured summary; the JSON schema derived from this model
    # is what the format enforcer is given.
    # (Documented with comments rather than a docstring so the generated
    # schema stays exactly as before.)
    name: str
    age: int
    birthplace: str
    profession: str
    nicknames: list[str]

In [99]:
# Maximum number of characters of the page text that is returned.
MAX_PAGE_LENGTH = 40000
def get_plaintext_wikipedia_page(title):
    """Fetch the plain-text extract of an English Wikipedia page.

    Args:
        title: page title as passed to the API's ``titles`` parameter.

    Returns:
        The page's plain-text extract, truncated to ``MAX_PAGE_LENGTH`` characters.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.Timeout: if the API does not answer within the timeout.
        KeyError: if the response contains no extract for ``title``
            (for example because the page does not exist).
    """
    import requests

    url = "https://en.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "prop": "extracts",
        "explaintext": True,
        "titles": title,
        "format": "json"
    }
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, params=params, timeout=30)
    # Fail here on HTTP errors instead of later while decoding the body.
    response.raise_for_status()
    data = response.json()
    # The result is keyed by page id; a single title yields a single entry.
    page = next(iter(data['query']['pages'].values()))
    if 'extract' not in page:
        raise KeyError(f"Wikipedia returned no extract for page {title!r}")
    return page['extract'][:MAX_PAGE_LENGTH]

In [100]:
# Download the article text and preview its first 250 characters.
plaintext_contents = get_plaintext_wikipedia_page("Michael_Jordan")
print(f"{plaintext_contents[:250]} ...")

Michael Jeffrey Jordan (born February 17, 1963), also known by his initials MJ, is an American businessman and former professional basketball player. He played 15 seasons in the National Basketball Association (NBA) between 1984 and 2003, winning six ...


In [101]:
# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, trust_remote_code=True)

In [102]:
# Build the parser from the model's JSON schema. Pydantic v2 exposes it as
# `model_json_schema()` (where `.schema()` is deprecated); fall back to
# `.schema()` on Pydantic v1, which has no `model_json_schema`.
parser = JsonSchemaParser(
    SummaryShape.model_json_schema()
    if hasattr(SummaryShape, "model_json_schema")
    else SummaryShape.schema()
)

In [103]:
# Token filter built from the pipeline's tokenizer and the schema parser; it is
# passed to the pipeline as `prefix_allowed_tokens_fn` during generation.
prefix_function = build_transformers_prefix_allowed_tokens_fn(pipe.tokenizer, parser)

In [104]:
def get_structured_output(prompt: "str | list[dict[str, str]]", max_new_tokens: int = 250):
    """Generate with the format-enforced pipeline and return the parsed JSON.

    Args:
        prompt: either a list of chat messages (dicts with ``role`` and
            ``content``) or a plain prompt string.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The generated text decoded with ``json.loads``.

    Raises:
        json.JSONDecodeError: if the generated text is not valid JSON, e.g.
            because generation stopped at ``max_new_tokens`` before the
            document was closed.
    """
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "prefix_allowed_tokens_fn": prefix_function,
    }
    if isinstance(prompt, str):
        # For a plain string the pipeline would return prompt + completion;
        # only the completion is JSON, so ask for the new text alone.
        generation_kwargs["return_full_text"] = False
    with torch.no_grad():
        output_dict = pipe(prompt, **generation_kwargs)
    generated = output_dict[0]['generated_text']
    # Chat input comes back as the message list with the assistant reply last;
    # string input comes back as the generated string itself.
    text = generated if isinstance(generated, str) else generated[-1]["content"]
    return json.loads(text)

In [105]:
# Ask the model for a structured summary of the downloaded article.
# (Fixes the "Sumamrize" typo in the instruction sent to the model.)
get_structured_output([
    {"role": "system", "content": "You are a helpful assistant that responds in JSON"},
    {"role": "user", "content": "\n".join([
        "Summarize (in JSON format) the following information",
        plaintext_contents,
    ])}
])

{'name': 'Michael Jordan',
 'nicknames': ['Air Jordan', 'His Airness'],
 'profession': 'Basketball Player',
 'birthplace': 'Brooklyn, New York City, New York',
 'age': 61}