In [7]:
import json
import sys
import time
import warnings
from pathlib import Path
from typing import Literal, Optional

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy



from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.lora import GPT, Block, Config, merge_lora_weights
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, lazy_load, quantization
from scripts.prepare_sql import generate_prompt

In [8]:
# LoRA hyperparameters used to rebuild the adapter layers.
# NOTE(review): presumably these must match the values the adapter checkpoint
# was fine-tuned with — confirm against the training run.
lora_r, lora_alpha, lora_dropout = 4, 16, 0.05

# Per-target switches; they are forwarded to the model config as `to_query`,
# `to_key`, `to_value`, `to_projection`, `to_mlp` and `to_head`. All are enabled.
lora_query = lora_key = lora_value = True
lora_projection = lora_mlp = lora_head = True

In [15]:
# --- Checkpoints -----------------------------------------------------------
# Fine-tuned LoRA adapter weights and the directory of the base model they
# are applied to.
lora_path: Path = Path("out") / "lora" / "sql_llama_cosine_scheduler" / "iter-025599-ckpt.pth"
checkpoint_dir: Path = Path("checkpoints") / "meta-llama" / "Llama-2-7b-chat-hf"

# --- Quantization ----------------------------------------------------------
# One of the listed modes, or None.
quantize: Optional[
    Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]
] = "bnb.nf4"

# --- Sampling --------------------------------------------------------------
max_new_tokens: int = 100
top_k: int = 200
temperature: float = 0.8

# --- Fabric ----------------------------------------------------------------
# "fsdp" is expanded into an FSDPStrategy further down; other values are
# passed to Fabric as-is.
strategy: str = "auto"
devices: int = 1
# A falsy value falls back to lit_gpt's default inference precision.
precision: Optional[str] = "bf16-true"

In [16]:
# Keep the configured precision when one is given; otherwise ask lit_gpt for
# its default inference precision.
precision = precision or get_default_supported_precision(training=False)

# The "fsdp" shorthand becomes a concrete strategy object that auto-wraps each
# `Block`; any other value is handed to Fabric untouched.
if strategy == "fsdp":
    strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
fabric.launch()

check_valid_checkpoint_dir(checkpoint_dir)

# Read the base-model configuration shipped with the checkpoint.
with open(checkpoint_dir / "lit_config.json") as config_file:
    checkpoint_config = json.load(config_file)

# LoRA settings go in first, so any key that also appears in lit_config.json
# is overridden by the checkpoint's value.
config_params = {
    "r": lora_r,
    "alpha": lora_alpha,
    "dropout": lora_dropout,
    "to_query": lora_query,
    "to_key": lora_key,
    "to_value": lora_value,
    "to_projection": lora_projection,
    "to_mlp": lora_mlp,
    "to_head": lora_head,
    **checkpoint_config,
}
config = Config(**config_params)

You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


In [19]:
# Quantized inference is only implemented for a single device.
if devices > 1 and quantize is not None:
    raise NotImplementedError

# GPTQ int4 uses its own pre-quantized weight file, which has to exist already;
# every other mode starts from the regular checkpoint.
uses_gptq = quantize == "gptq.int4"
model_file = "lit_model_gptq.4bit.pth" if uses_gptq else "lit_model.pth"
if uses_gptq and not (checkpoint_dir / model_file).is_file():
    raise ValueError("Please run `python quantize/gptq.py` first")
checkpoint_path = checkpoint_dir / model_file

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)

# Build the model under Fabric's module-init context (empty_init=True) and the
# selected quantization context; weights are loaded in a later step.
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
    with quantization(quantize):
        model = GPT(config)
instantiate_seconds = time.perf_counter() - t0
fabric.print(f"Time to instantiate model: {instantiate_seconds:.02f} seconds.", file=sys.stderr)

Loading model 'checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth' with {'org': 'meta-llama', 'name': 'Llama-2-7b-chat-hf', 'block_size': 4096, 'vocab_size': 32000, 'padding_multiple': 64, 'padded_vocab_size': 32000, 'n_layer': 32, 'n_head': 32, 'n_embd': 4096, 'rotary_percentage': 1.0, 'parallel_residual': False, 'bias': False, 'n_query_groups': 32, 'shared_attention_norm': False, '_norm_class': 'RMSNorm', 'norm_eps': 1e-05, '_mlp_class': 'LLaMAMLP', 'intermediate_size': 11008, 'condense_ratio': 1, 'r': 4, 'alpha': 16, 'dropout': 0.05, 'to_query': True, 'to_key': True, 'to_value': True, 'to_projection': True, 'to_mlp': True, 'to_head': True}
Time to instantiate model: 7.38 seconds.


In [20]:
# Overlay the LoRA adapter weights on the base checkpoint, then load the
# combined state dict into the model.
t0 = time.perf_counter()
with lazy_load(checkpoint_path) as checkpoint:
    with lazy_load(lora_path) as lora_checkpoint:
        # The adapter file may nest its weights under "model"; if that key is
        # absent the whole mapping is used.
        lora_weights = lora_checkpoint.get("model", lora_checkpoint)
        checkpoint.update(lora_weights)
        # Key matching is strict only when no quantization is applied.
        model.load_state_dict(checkpoint, strict=(quantize is None))
load_seconds = time.perf_counter() - t0
fabric.print(f"Time to load the model weights: {load_seconds:.02f} seconds.", file=sys.stderr)

# Switch to eval mode, merge the LoRA weights into the model, and hand the
# result to Fabric.
model.eval()
merge_lora_weights(model)
model = fabric.setup(model)

Time to load the model weights: 6.89 seconds.
  rank_zero_warn(


In [90]:
# First turn: build the prompt from an instruction and the table definition it
# refers to, sample a completion, and print the decoded text.
#
# The instruction and table definition get their own names so that the builtin
# `input()` is not shadowed for the rest of the kernel session and `prompt`
# only ever holds the fully formatted prompt.
instruction: str = "Select the "
table_definition: str = "CREATE TABLE food(name VARCHAR, calories INTEGER)"
tokenizer = Tokenizer(checkpoint_dir)
sample = {"instruction": instruction, "input": table_definition}
prompt = generate_prompt(sample)

# LLAMA encoding
encoded = tokenizer.encode(prompt, device=model.device, bos=True)
prompt_length = encoded.size(0)
# The length passed to `generate` is the total (prompt + newly sampled tokens).
max_returned_tokens = prompt_length + max_new_tokens

t0 = time.perf_counter()
y = generate(
    model,
    encoded,
    max_returned_tokens,
    max_seq_length=max_returned_tokens,
    temperature=temperature,
    top_k=top_k,
    eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0

# Clear the model's cache so the next generation starts from a clean state.
model.reset_cache()
output = tokenizer.decode(y)
fabric.print(output)

# Report the measured inference time, in line with the other timed cells.
tokens_generated = len(y) - prompt_length
fabric.print(
    f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
    file=sys.stderr,
)

[INST] <<SYS>>
You are an expert SQL programmer and system administrator. Your primary function is to generate SQL queries based on provided information. Your tone is professional. Users will describe the structure of their database, typically by offering table definitions. An example of this is:
CREATE TABLE A (b INTEGER).
Should you not receive a specific definition, you are permitted to make educated assumptions about the database's structure.
You can make any assumptions about the database's structure if some terminology is loose. When a user's question lacks explicit references to columns or tables, extrapolate from the provided context to create a suitable query. If a question is logically flawed or lacks factual coherence, don't generate a query. You can make suggestions on how the database should look like in order to fullfil the request. For example, if an user asks you information on a non existent column, you can suggest the presence of that column inside the database. If th

In [91]:
# Second turn: the previous decoded output (wrapped in BOS/EOS) followed by a
# new instruction block (prefixed with BOS), concatenated into one token
# sequence.
stringa = "\n[INST]  What is the meaning of life? You can only use Emojis.[/INST]\n"
first_turn_ids = tokenizer.encode(output, eos=True, bos=True, device=model.device)
follow_up_ids = tokenizer.encode(stringa, bos=True, device=model.device)
second_input = torch.cat((first_turn_ids, follow_up_ids))

In [92]:
# Second turn: sample a completion for `second_input`.
#
# The length budget is derived from `second_input` itself. Reusing the first
# turn's `max_returned_tokens` (first prompt length + max_new_tokens) would be
# wrong here: `second_input` already contains the first prompt, the first
# completion and the follow-up instruction, so the old budget would leave fewer
# than `max_new_tokens` tokens for the reply — possibly none at all.
second_prompt_length = second_input.size(0)
second_max_returned_tokens = second_prompt_length + max_new_tokens

t0 = time.perf_counter()
y = generate(
    model,
    second_input,
    second_max_returned_tokens,
    max_seq_length=second_max_returned_tokens,
    temperature=temperature,
    top_k=top_k,
    eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0

# Clear the model's cache so a later generation starts from a clean state.
model.reset_cache()
output = tokenizer.decode(y)
fabric.print(output)

# Report the measured inference time, in line with the other timed cells.
second_tokens_generated = len(y) - second_prompt_length
fabric.print(
    f"Time for inference: {t:.02f} sec total, {second_tokens_generated / t:.02f} tokens/sec",
    file=sys.stderr,
)

[INST] <<SYS>>
You are an expert SQL programmer and system administrator. Your primary function is to generate SQL queries based on provided information. Your tone is professional. Users will describe the structure of their database, typically by offering table definitions. An example of this is:
CREATE TABLE A (b INTEGER).
Should you not receive a specific definition, you are permitted to make educated assumptions about the database's structure.
You can make any assumptions about the database's structure if some terminology is loose. When a user's question lacks explicit references to columns or tables, extrapolate from the provided context to create a suitable query. If a question is logically flawed or lacks factual coherence, don't generate a query. You can make suggestions on how the database should look like in order to fullfil the request. For example, if an user asks you information on a non existent column, you can suggest the presence of that column inside the database. If th