In [31]:
import torch

from peft import PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
device = 'cuda'
model_name = "lmsys/vicuna-13b-v1.5"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    use_cache=False,
    torch_dtype=torch.float16,
    load_in_8bit=False
)

In [26]:
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120, padding_idx=0)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
 

# Inference
- [x] Test we can run inference on 13B
- [x] Test we load a lora, trained on 7B, onto 13B
- [ ] Build text-only QA dataset
- [ ] Fine-tune 13B on dataset

## Generic Inference

In [22]:
input_ids = tokenizer("Hi, who are you?", return_tensors='pt').to(device)
output_ids = model.generate(input_ids.input_ids)
tokenizer.batch_decode(output_ids)

["<s> Hi, who are you?\n\nI'm a language model called Vicuna, and I was trained by Large Model Systems Organization (LMSYS) researchers.</s>"]

## Lora

In [25]:
test_7b_lora = "Charlie911/vicuna-7b-v1.5-lora-sharegpt-without-timedial"

In [33]:
peft_config = PeftConfig.from_pretrained(test_7b_lora)
model = peft.get_peft_model(model, peft_config)

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='lmsys/vicuna-7b-v1.5', revision=None, task_type='CAUSAL_LM', inference_mode=True, r=16, target_modules={'q_proj', 'v_proj'}, lora_alpha=32, lora_dropout=0.05, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={})

In [37]:
model.print_trainable_parameters()
model

trainable params: 0 || all params: 13,028,971,520 || trainable%: 0.0


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 5120, padding_idx=0)
        (layers): ModuleList(
          (0-39): 40 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=5120, out_features=5120, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=5120, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=5120, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear(in_features=5120, out_features=5120, bia

In [39]:
input_ids = tokenizer("Hi, how are you?", return_tensors='pt').to(device)
output_ids = model.generate(input_ids.input_ids)
tokenizer.batch_decode(output_ids)

["<s> Hi, how are you?\n\nI'm a new user here and I'm trying to learn more about the platform. I'm interested in creating a new post, but I'm not sure how to get started. Can you help me?\n\nThanks!</s>"]