In [1]:
import torch
import accelerate
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Checkpoint used for this experiment. An alternative tried earlier was
# "Qwen/Qwen2-1.5B-Instruct" -- set `model_name` to it to switch back.
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer_name = model_name  # tokenizer is loaded from the same checkpoint

# trust_remote_code=True allows transformers to execute the custom
# modeling/tokenizer code that ships with the checkpoint repository.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.64s/it]


In [3]:
# Display the module tree (presumably to look up the layer names LoRA should target).
# NOTE(review): the saved output below is a Qwen2ForCausalLM, so it came from a run
# with the Qwen checkpoint rather than the Phi-3 model named in the loading cell, and
# this cell's execution count (In [3]) precedes the loading cell's (In [4]).
# Restart & Run All to refresh it.
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear

In [8]:
# LoRA target modules. These are Phi-3's projection names (fused `qkv_proj` and
# `gate_up_proj`); Qwen2 checkpoints instead expose separate q_proj/k_proj/v_proj
# and gate_proj/up_proj, so this list must change if the checkpoint changes.
peft_target_modules = [
    "o_proj",
    "qkv_proj",
    "gate_up_proj",
    "down_proj",
]
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=128,
    lora_alpha=32,  # NOTE(review): alpha/r = 0.25 scaling -- confirm this is intended
    lora_dropout=0.1,
    target_modules=peft_target_modules,
)
# Wrap only once: without this guard, re-running the cell would call
# get_peft_model on a model that is already a PeftModel.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 201,326,592 || all params: 4,022,406,144 || trainable%: 5.0051


In [3]:
# run inference
def inference(input_text, cuda=False, max_new_tokens=10):
    """Sample a continuation of ``input_text`` from the global ``model``.

    Args:
        input_text: Prompt string to tokenize and continue.
        cuda: If True, move the tokenized inputs to the default CUDA device
            before generating (the model must already be on that device).
        max_new_tokens: Maximum number of tokens to generate beyond the prompt.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    if cuda:
        inputs = {k: v.cuda() for k, v in inputs.items()}
    # `max_length` counts the prompt tokens as well, so a cap of 10 left room for
    # only a few new tokens; `max_new_tokens` bounds the continuation alone.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.5,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [6]:
# Sample a short completion. do_sample=True and no seed is set, so the text differs between runs.
print(inference("Who are you?"))

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
You are not running the flash-attention implementation, expect numerical differences.


Who are you?

**Chatbot


In [6]:
# Create an Accelerator and hand it the model; the object returned by prepare()
# replaces the global `model` that the inference helpers use.
# NOTE(review): presumably prepare() also places the model on the accelerator's
# device -- confirm. Re-running this cell calls prepare() on an already-prepared model.
accelerator = accelerate.Accelerator()
model = accelerator.prepare(model)

In [7]:
# inference 2
def accelerate_inference(input_text, num_samples=2, max_new_tokens=20):
    """Draw several independent sampled continuations of ``input_text``.

    The prompt is repeated ``num_samples`` times into one batch and generated
    with the model unwrapped from the global ``accelerator``. ``top_k=0`` and
    ``top_p=1.0`` turn off top-k and nucleus filtering, so tokens are sampled
    from the full distribution.

    Args:
        input_text: Prompt string to continue.
        num_samples: Number of continuations to sample (batch size); must be >= 1.
        max_new_tokens: Maximum number of tokens to generate per sample.

    Returns:
        List of ``num_samples`` decoded strings (prompt + continuation).

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    # Puts the shared model into eval mode (disables LoRA dropout); this persists
    # for later cells as well.
    model.eval()
    batch = tokenizer([input_text] * num_samples, return_tensors="pt")
    # Follow the accelerator's device instead of hard-coding .cuda().
    batch = {k: v.to(accelerator.device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = accelerator.unwrap_model(model).generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            top_k=0,  # integer 0 disables top-k filtering
            top_p=1.0,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [8]:
# Sample two continuations of the same prompt in one batch (unfiltered sampling, no seed).
print(accelerate_inference("Who are you?"))

["Who are you?January 29, 2013\nYou're a relationship-based", "Who are you? And, where are you from?\nWe live in Washington State.\nWhat's your story"]


In [6]:
# NOTE(review): this repeats the earlier `inference` cell and carries the same execution
# count (In [6]), so it looks like a leftover duplicate from an out-of-order run.
# Its different output comes from do_sample=True with no seed set.
print(inference("Who are you?"))

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
You are not running the flash-attention implementation, expect numerical differences.


Who are you?


## Support

